In [115]:
%load_ext autoreload
%autoreload 2
%reload_ext autoreload
import os
from tsetlin import TsetlinMachine
import torch
import random
import math
import copy

from itertools import combinations, chain
from collections import deque, defaultdict

DATASET_DIR = '../datasets/'
DATA_FILE = 'bit_2.txt'

# Each non-blank line of the data file is a comma-separated row of integers; the last
# column is the label (train_y) and the remaining columns are the inputs (train_x).
# The `with` block closes the file; blank lines (e.g. trailing newlines) are skipped
# instead of crashing int('').
with open(f'{DATASET_DIR}{DATA_FILE}', 'r', encoding='utf-8') as data_file:
    text_rows = [line for line in data_file.read().splitlines() if line.strip()]
dataset = [[int(num) for num in row.split(',')] for row in text_rows]
tensor_dataset = torch.tensor(dataset)
train_x = tensor_dataset[:, :-1]
train_y = tensor_dataset[:, -1]


def generate_powerset_iterator(set_elements):
    """Return an iterator over every subset of *set_elements*, largest subsets first.

    Subsets are yielded as tuples. Within one size they follow
    ``itertools.combinations`` order, and the empty tuple comes last. Any iterable
    is accepted; it is materialized once so that its length is known.

    >>> list(generate_powerset_iterator([1, 2, 3]))
    [(1, 2, 3), (1, 2), (1, 3), (2, 3), (1,), (2,), (3,), ()]
    """
    pool = tuple(set_elements)
    return chain.from_iterable(combinations(pool, r) for r in range(len(pool), -1, -1))

def combine_iterators(iterable_one, iterable_two):
    """Yield merged sets from *iterable_one*, then the items of *iterable_two* unchanged.

    Each non-empty (truthy) item of *iterable_one* is an iterable of iterables and is
    yielded as the union of its members. Empty items are skipped. Afterwards every item
    of *iterable_two* is yielded as-is.
    """
    yield from (set().union(*group) for group in iterable_one if group)
    yield from iterable_two

# Seed Python's and PyTorch's RNGs. Set SEED to an int (0 included) to reproduce a run.
# With SEED = None a fresh 64-bit seed is drawn from os.urandom and printed so the run
# can be reproduced later.
SEED = None
if SEED is not None:
    seed = SEED
else:
    seed = int.from_bytes(os.urandom(8), byteorder="big", signed=False)
random.seed(seed)
torch.manual_seed(seed)
if SEED is None:
    print(seed)

tm = TsetlinMachine(train_x.shape[1], 5)

# Run the machine once. Then, for every W row index i, split the dataset row indices by
# the value of tm.l1.out[:, i]: rows where it equals 0 and rows where it equals 1.
# A W row with no matching dataset rows gets no entry in the corresponding dict.
tm.forward(train_x)
W_row_to_zero_Y_row_idxs = {}
W_row_to_one_Y_row_idxs = {}
for w_row in range(tm.l1.W.shape[0]):
    y_column = tm.l1.out[:, w_row]
    for y_value, target in ((0, W_row_to_zero_Y_row_idxs), (1, W_row_to_one_Y_row_idxs)):
        matching_rows = torch.nonzero(y_column == y_value).squeeze(1).tolist()
        if matching_rows:
            target[w_row] = set(matching_rows)

# Keep one representative W row per distinct set of one-Y row idxs (the first W row seen
# with that set). The search below first assigns these representatives to W columns in
# search order; that column assignment is re-optimized later using W_confidence.
W_rows_of_unique_one_Y_row_idxs = set()
visited_one_Y_row_idxs = set()  # frozensets of one-Y row idxs already represented

W_col_to_new_X_row_idxs = {}  # NOTE(review): not read anywhere in this cell
for W_row_idx, one_Y_row_idxs in W_row_to_one_Y_row_idxs.items():
    # frozenset (not tuple(set)): equal sets must always give the same key, regardless
    # of the order in which the set happens to iterate.
    one_Y_key = frozenset(one_Y_row_idxs)
    if one_Y_key not in visited_one_Y_row_idxs:
        visited_one_Y_row_idxs.add(one_Y_key)
        W_rows_of_unique_one_Y_row_idxs.add(W_row_idx)

# Unresolved zero-Y row idxs for each representative W row. Copied so that the search
# state never aliases the sets stored in W_row_to_zero_Y_row_idxs.
one_Y_row_state = {W_row: set(W_row_to_zero_Y_row_idxs.get(W_row, ())) for W_row in W_rows_of_unique_one_Y_row_idxs}
# Heuristic: handle the representatives with the largest one-Y sets first.
sorted_one_Y_row_idxs = sorted(W_rows_of_unique_one_Y_row_idxs, key=lambda x: len(W_row_to_one_Y_row_idxs[x]), reverse=True)
q = deque(sorted_one_Y_row_idxs)

def get_new_X_row_idxs_per_W_col(depth, max_depth, curr_one_Y_row_state, prev_W_row_idx, q):
    # Depth-first search that builds one (left_set, right_set) pair of dataset row indices
    # per recursion level ("W column"). It stops once every W row tracked in
    # curr_one_Y_row_state has no unresolved zero-Y row indices left.
    #
    # Args:
    #   depth: current recursion level; each level contributes at most one pair.
    #   max_depth: level at which the search gives up (the top-level call passes tm.l1.in_dim).
    #   curr_one_Y_row_state: dict W row idx -> set of its still-unresolved zero-Y row idxs.
    #       It is not mutated here; modified copies are passed to recursive calls.
    #   prev_W_row_idx: W row handled by the caller; reused if it is still unresolved.
    #   q: deque of W row idxs still to visit, consumed from the left. Recursive calls get
    #      deep copies, but the deque passed to this call is itself mutated by popleft().
    # Returns:
    #   (pairs, is_solved). pairs is [] whenever is_solved is False. Pairs are appended
    #   while the recursion unwinds, so the deepest level's pair comes first and the
    #   top-level pair comes last.
    #
    # the output is of shape [({1,2,3},{4,5,6}), ({2,3},{4,5,1,6}), ...] where ({1,2,3},{4,5,6}) means
    # that W[[1,2,3]][0] should be 1 and W[[4,5,6]][0] should be 0 and full_X[[1,2,3]][0] should be 1 
    # and full_X[[4,5,6]][0] should be 0

    # Stop when the depth budget is used up or nothing is left to resolve; only the
    # latter counts as solved.
    if depth == max_depth or len(curr_one_Y_row_state) == 0:
        return [], len(curr_one_Y_row_state) == 0

    # Choose this level's W row: keep prev_W_row_idx if it is still unresolved, otherwise
    # take the next queued W row that is still present in the state.
    curr_W_row_idx = prev_W_row_idx
    while curr_W_row_idx not in curr_one_Y_row_state and q:
        curr_W_row_idx = q.popleft()

    curr_one_Y_idxs = W_row_to_one_Y_row_idxs[curr_W_row_idx]
    # This W row has no unresolved zero-Y rows: drop it and recurse, recording the pair
    # (its one-Y rows, empty set).
    # NOTE(review): this still consumes one depth level and emits one pair even though
    # nothing needed separating. The pair also holds the very set object stored in
    # W_row_to_one_Y_row_idxs (not a copy). Confirm both are intended.
    if not curr_one_Y_row_state[curr_W_row_idx]:
        updated_one_Y_row_state = copy.deepcopy(curr_one_Y_row_state)
        del updated_one_Y_row_state[curr_W_row_idx]
        sub_new_X_row_idxs_per_W_col, is_solved = get_new_X_row_idxs_per_W_col(depth+1, max_depth, updated_one_Y_row_state, curr_W_row_idx, copy.deepcopy(q))
        if is_solved:
            new_X_row_idxs_per_W_col = sub_new_X_row_idxs_per_W_col
            new_X_row_idxs_per_W_col.append(( curr_one_Y_idxs, set()))
            return new_X_row_idxs_per_W_col, True
        else:
            return [], False

    # Candidate one-Y sets from queued W rows that are disjoint from this row's one-Y set
    # and overlap its unresolved zero-Y rows.
    remaining_q = list(q)
    remaining_one_Y_idxs = [W_row_to_one_Y_row_idxs[x] for x in remaining_q if len(W_row_to_one_Y_row_idxs[x] & curr_one_Y_idxs) == 0 and len(W_row_to_one_Y_row_idxs[x] & curr_one_Y_row_state[curr_W_row_idx]) > 0]
    # opposite_set ranges over the unions of non-empty subsets of those candidates (largest
    # subsets first), and finally over this row's unresolved zero-Y set itself.
    for opposite_set in combine_iterators(generate_powerset_iterator(remaining_one_Y_idxs), [curr_one_Y_row_state[curr_W_row_idx]]):
        # Dataset rows not yet placed on either side.
        # NOTE(review): assumes tm.l1.full_X.shape[0] is the number of dataset rows — confirm.
        remaining_Y_idxs = set(range(tm.l1.full_X.shape[0])) - (opposite_set | curr_one_Y_idxs)
        # Candidates that fit entirely within the still-unplaced rows.
        sub_remaining_one_Y_idxs = [ x for x in remaining_one_Y_idxs if x.issubset(remaining_Y_idxs)]
        
        # remaining_merged_set ranges over the unions of non-empty subsets of those
        # candidates, and finally over the empty set.
        for remaining_merged_set in combine_iterators(generate_powerset_iterator(sub_remaining_one_Y_idxs), [set()]):
            complement_remaining_Y_subset = remaining_Y_idxs - remaining_merged_set

            # Two ways to distribute the unplaced rows: the complement goes left and the
            # merged set goes right, or the reverse.
            first_left_W = curr_one_Y_idxs | complement_remaining_Y_subset
            first_right_W = opposite_set | remaining_merged_set

            second_left_W = curr_one_Y_idxs | remaining_merged_set
            second_right_W = opposite_set | complement_remaining_Y_subset

            for left_W, right_W in [(first_left_W, first_right_W), (second_left_W, second_right_W)]:
                # When a W row's one-Y rows lie entirely on one side, every row on the other
                # side is removed from its unresolved zero-Y set.
                updated_one_Y_row_state = {}
                for k,v in curr_one_Y_row_state.items():
                    one_Y_idxs = W_row_to_one_Y_row_idxs[k]
                    sub_diff = v

                    if one_Y_idxs.issubset(left_W):
                        sub_diff = v - right_W
                    elif one_Y_idxs.issubset(right_W):
                        sub_diff = v - left_W
                    
                    # implicit here is the removal of one_Y_idxs for which there is no unresolved zero Y row idxs left
                    if len(sub_diff) > 0:
                        updated_one_Y_row_state[k] = sub_diff

                # Recurse to build the next column from the reduced state.
                sub_new_X_row_idxs_per_W_col, is_solved = get_new_X_row_idxs_per_W_col(depth+1, max_depth, updated_one_Y_row_state, curr_W_row_idx, copy.deepcopy(q))
                if is_solved:
                    new_X_row_idxs_per_W_col = sub_new_X_row_idxs_per_W_col
                    new_X_row_idxs_per_W_col.append((left_W, right_W))
                    return new_X_row_idxs_per_W_col, True
        
    # No split tried at this level led to a solution.
    return [], False

# Run the column-construction search, starting from the first queued representative.
# The initial state is empty exactly when the queue is empty. In that case the search
# returns ([], True) immediately without reading the starting W row, so None is safe.
start_W_row_idx = q.popleft() if q else None
new_X_row_idxs_per_W_col, is_solved = get_new_X_row_idxs_per_W_col(0, tm.l1.in_dim, one_Y_row_state, start_W_row_idx, q) # X_row_idxs_per_W_col does not necessarily contain a slot for each col
# Explicit check: a bare assert is stripped under `python -O` and carries no message.
if not is_solved:
    raise RuntimeError(f"no separating assignment found within {tm.l1.in_dim} W columns")

# new_X_row_idxs_per_W_col provides a valid update of W and full_X that satisfies the expected Y.
# However, we assigned one_Y_row_idxs to W columns incrementally (from 0) for simplicity.
# Below, we determine the best W column assignment based on W_confidence.

# For each column pair produced by the search, collect the W rows whose one-Y row set
# fits entirely inside the pair's left set (slot [0]) or, failing that, inside its right
# set (slot [1]). A column index only becomes a key once some W row is added to it.
W_row_idxs_per_col = defaultdict(lambda: [[], []]) # this represents all one_W_row_idxs and zero_W_row_idxs pairs
for w_row, w_row_one_Y_set in W_row_to_one_Y_row_idxs.items():
    for col, (left_rows, right_rows) in enumerate(new_X_row_idxs_per_W_col):
        if w_row_one_Y_set.issubset(left_rows):
            W_row_idxs_per_col[col][0].append(w_row)
        elif w_row_one_Y_set.issubset(right_rows):
            W_row_idxs_per_col[col][1].append(w_row)



# Re-seed Python's and PyTorch's RNGs before drawing W_confidence. Set SEED to an int
# (0 included) to reproduce a run. With SEED = None a fresh 64-bit seed is drawn from
# os.urandom; it is kept in `seed` (not printed) for inspection.
SEED = None
if SEED is not None:
    seed = SEED
else:
    seed = int.from_bytes(os.urandom(8), byteorder="big", signed=False)
random.seed(seed)
torch.manual_seed(seed)

# Placeholder confidences: random integers in [0, 30) with the shape and dtype of W.
W_confidence = torch.randint_like(tm.l1.W, 0, 30)


# For every column pair from the search, compute for every W column position the
# W_confidence sum of its one-W rows plus the rolled sum of its zero-W rows.
# Iterate in column order (not dict insertion order): row k of the stacked tensor must
# correspond to new_X_row_idxs_per_W_col[k], because the assignment search below indexes
# rows by range(len(new_X_row_idxs_per_W_col)). A column with no W rows contributes zeros.
# NOTE(review): the roll by -in_dim presumably swaps the two halves of W's columns
# (W having 2 * in_dim columns) — confirm against TsetlinMachine.
W_row_idxs_sets_confidence_sum_per_col = []
for set_idx in range(len(new_X_row_idxs_per_W_col)):
    one_W_row_idxs, zero_W_row_idxs = W_row_idxs_per_col.get(set_idx, ([], []))
    sums = W_confidence[one_W_row_idxs].sum(dim=0)
    neg_sum = torch.roll(W_confidence[zero_W_row_idxs].sum(dim=0), shifts=-tm.l1.in_dim, dims=0)
    sums += neg_sum
    W_row_idxs_sets_confidence_sum_per_col.append(sums)

W_row_idxs_sets_confidence_sum_per_col = torch.stack(W_row_idxs_sets_confidence_sum_per_col)
sorted_W_row_idxs_sets_confidence_sum_per_col = torch.sort(W_row_idxs_sets_confidence_sum_per_col, dim=1, descending=False) # sort by increasing sum

# Prune W columns: each row of the sorted result lists positions in increasing-sum order.
# Keep only the first (cheapest) position seen for each column index modulo
# tm.l1.in_dim, and stop once all tm.l1.in_dim columns have been seen.
opt_values = []
opt_indices = []

for row_values, row_positions in zip(sorted_W_row_idxs_sets_confidence_sum_per_col.values, sorted_W_row_idxs_sets_confidence_sum_per_col.indices):
    kept_sums, kept_positions, seen_cols = [], [], set()
    for value, position in zip(row_values.tolist(), row_positions.tolist()):
        col = position % tm.l1.in_dim
        if col in seen_cols:
            continue
        seen_cols.add(col)
        kept_sums.append(value)
        kept_positions.append(position)
        if len(seen_cols) == tm.l1.in_dim:
            break
    opt_values.append(kept_sums)
    opt_indices.append(kept_positions)

sorted_W_row_idxs_sets_confidence_sum_per_col_values = torch.tensor(opt_values)
sorted_W_row_idxs_sets_confidence_sum_per_col_indices = torch.tensor(opt_indices)

# Heuristic ordering for the assignment search. Subtract each row's smallest sum to get
# per-row offsets, then list row (set) indices by increasing offset. Each row index
# appears once per entry in its row; ties keep row order.
offset_sorted_W_row_idxs_sets_confidence_sum_per_col = (
    sorted_W_row_idxs_sets_confidence_sum_per_col_values
    - sorted_W_row_idxs_sets_confidence_sum_per_col_values[:, :1]
)
offset_W_row_idxs_sets_confidence_sum_to_cols_dict = defaultdict(list)
for set_idx, offset_row in enumerate(offset_sorted_W_row_idxs_sets_confidence_sum_per_col):
    for offset_value in offset_row.tolist():
        offset_W_row_idxs_sets_confidence_sum_to_cols_dict[offset_value].append(set_idx)
sorted_W_row_idxs_sets_confidence_sum = sorted(offset_W_row_idxs_sets_confidence_sum_to_cols_dict)
W_row_idxs_set_sequencing = list(chain.from_iterable(
    offset_W_row_idxs_sets_confidence_sum_to_cols_dict[offset_key]
    for offset_key in sorted_W_row_idxs_sets_confidence_sum
))

def get_W_row_idxs_set_idx_to_sorted_col_idx_w_min_confidence_sum(W_row_idxs_set_idxs, max_sorted_idx_per_W_row_idxs_set_idxs, used_W_col_idxs, max_sum):
    # Core branch-and-bound search that determines the W column assignment based on W_confidence;
    # everything computed above is preprocessing for it. Each set idx (a row of
    # sorted_W_row_idxs_sets_confidence_sum_per_col_values / _indices) gets one sorted position.
    # The chosen W columns (index % tm.l1.in_dim) must be pairwise distinct, and the search
    # looks for the smallest total sum.
    #
    # Args:
    #   W_row_idxs_set_idxs: set of set idxs still to assign.
    #   max_sorted_idx_per_W_row_idxs_set_idxs: None at the top level, where each set may use
    #       any position up to tm.l1.in_dim - 1. In recursive calls it is the caller's per-set
    #       list of positions reached so far; a set may not go beyond its entry here.
    #   used_W_col_idxs: W column idxs (mod tm.l1.in_dim) already taken.
    #   max_sum: only totals strictly below this are reported; None means no bound.
    # Returns:
    #   (min_sum, assignment). The assignment maps set idx -> SORTED POSITION, not column:
    #   e.g. {0: 1, 1: 0, 2: 5} means set 0 uses position 1 of its sorted row. Its W column is
    #   sorted_W_row_idxs_sets_confidence_sum_per_col_indices[set_idx, position] % tm.l1.in_dim.
    #   When nothing beats max_sum, the single-set case returns (None, None) and the multi-set
    #   case returns (max_sum, None).
    
    # Base case, one set left: take its cheapest allowed position whose column is still free.
    if len(W_row_idxs_set_idxs) == 1:
        W_row_idxs_set_idx = list(W_row_idxs_set_idxs)[0]
        max_sorted_idx = max_sorted_idx_per_W_row_idxs_set_idxs[W_row_idxs_set_idx] if max_sorted_idx_per_W_row_idxs_set_idxs is not None else tm.l1.in_dim - 1
        
        for sorted_idx in range(max_sorted_idx + 1):
            col_idx = sorted_W_row_idxs_sets_confidence_sum_per_col_indices[W_row_idxs_set_idx, sorted_idx].item()
            W_row_idxs_confidence_sum = sorted_W_row_idxs_sets_confidence_sum_per_col_values[W_row_idxs_set_idx, sorted_idx].item()

            # Rows are in increasing-sum order, so no later position can get below the bound.
            if max_sum is not None and W_row_idxs_confidence_sum >= max_sum:
                return None, None
            if col_idx % tm.l1.in_dim not in used_W_col_idxs:
                return W_row_idxs_confidence_sum, {W_row_idxs_set_idx: sorted_idx}

        return None, None

    # Per-set count of positions tried at this level (-1 = none yet), indexed by set idx.
    curr_max_sorted_idx_per_W_row_idxs_set_idxs = [-1] * len(new_X_row_idxs_per_W_col)
    min_confidence_sum = max_sum
    W_row_idxs_set_idx_to_sorted_col_idx_w_min_confidence_sum = None

    # Walk the global heuristic order; each occurrence of a set advances it to its next position.
    for i in range(len(W_row_idxs_set_sequencing)):
        W_row_idxs_set_idx = W_row_idxs_set_sequencing[i]
        if W_row_idxs_set_idx not in W_row_idxs_set_idxs:
            continue

        curr_max_sorted_idx_per_W_row_idxs_set_idxs[W_row_idxs_set_idx] += 1
        # Stop once a set would go past the positions the caller allows.
        if max_sorted_idx_per_W_row_idxs_set_idxs is not None and curr_max_sorted_idx_per_W_row_idxs_set_idxs[W_row_idxs_set_idx] > max_sorted_idx_per_W_row_idxs_set_idxs[W_row_idxs_set_idx]:
            return min_confidence_sum, W_row_idxs_set_idx_to_sorted_col_idx_w_min_confidence_sum

        col_idx = sorted_W_row_idxs_sets_confidence_sum_per_col_indices[W_row_idxs_set_idx, curr_max_sorted_idx_per_W_row_idxs_set_idxs[W_row_idxs_set_idx]].item()
        W_row_idxs_confidence_sum = sorted_W_row_idxs_sets_confidence_sum_per_col_values[W_row_idxs_set_idx, curr_max_sorted_idx_per_W_row_idxs_set_idxs[W_row_idxs_set_idx]].item()
        # Prune: this single term already reaches the best total found so far.
        # NOTE(review): returning (rather than skipping) appears to rely on all sums being
        # non-negative (true for the randint(0, 30) confidences above) — confirm before
        # using other confidence values.
        if min_confidence_sum is not None and W_row_idxs_confidence_sum >= min_confidence_sum:
            return min_confidence_sum, W_row_idxs_set_idx_to_sorted_col_idx_w_min_confidence_sum

        if col_idx % tm.l1.in_dim not in used_W_col_idxs:
            # Fix this set at the current position and solve the rest with the bound reduced
            # by this term. NOTE(review): the bound comes from max_sum, not from the
            # improving min_confidence_sum, so pruning is looser than it could be.
            updated_used_col_idxs = used_W_col_idxs | {col_idx % tm.l1.in_dim}
            new_max_sum = max_sum - W_row_idxs_confidence_sum if max_sum is not None else None
            sub_min_confidence_sum , sub_W_row_idxs_set_idx_to_sorted_col_idx_w_min_confidence_sum = get_W_row_idxs_set_idx_to_sorted_col_idx_w_min_confidence_sum(W_row_idxs_set_idxs - {W_row_idxs_set_idx}, curr_max_sorted_idx_per_W_row_idxs_set_idxs, updated_used_col_idxs, new_max_sum)

            # Keep the sub-assignment if it improves on the best total so far.
            if (min_confidence_sum is None and sub_min_confidence_sum is not None) or (sub_min_confidence_sum is not None and sub_min_confidence_sum + W_row_idxs_confidence_sum < min_confidence_sum):
                min_confidence_sum = W_row_idxs_confidence_sum + sub_min_confidence_sum
                W_row_idxs_set_idx_to_sorted_col_idx_w_min_confidence_sum = sub_W_row_idxs_set_idx_to_sorted_col_idx_w_min_confidence_sum
                W_row_idxs_set_idx_to_sorted_col_idx_w_min_confidence_sum[W_row_idxs_set_idx] = curr_max_sorted_idx_per_W_row_idxs_set_idxs[W_row_idxs_set_idx]

    return min_confidence_sum, W_row_idxs_set_idx_to_sorted_col_idx_w_min_confidence_sum

# Assign every column set, starting with no column used and no upper bound, and report
# the resulting minimum total confidence sum.
all_W_row_idxs_set_idxs = set(range(len(new_X_row_idxs_per_W_col)))
min_sum_1, sol_dict_1 = get_W_row_idxs_set_idx_to_sorted_col_idx_w_min_confidence_sum(
    all_W_row_idxs_set_idxs, None, set(), None
)

print(min_sum_1)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
8086226574029401178
263
