In [149]:
%load_ext autoreload
%autoreload 2
%reload_ext autoreload
import os
from tsetlin import TsetlinMachine
import torch
import random

DATASET_DIR = '../datasets/'
DATA_FILE = 'bit_1.txt'
# Fixed seed for reproducible runs. The seeding code that follows treats a
# falsy value (e.g. None) as "draw a fresh random seed and print it".
SEED = 5779661865816281544

# Every non-blank line of the data file is one comma-separated row of integers.
# The context manager closes the file handle as soon as the rows are read.
with open(f'{DATASET_DIR}{DATA_FILE}', 'r') as data_file:
    text_rows = data_file.read().splitlines()
# Blank lines are skipped so that int('') cannot abort the load.
dataset = [[int(num) for num in row.split(',')] for row in text_rows if row.strip()]
tensor_dataset = torch.tensor(dataset)
# The last column becomes train_y; all preceding columns become train_x.
train_x = tensor_dataset[:, :-1]
train_y = tensor_dataset[:, -1]

# Seed Python's `random` module and torch with the same value. A truthy SEED
# is used as-is; otherwise 8 bytes of OS entropy are read as an unsigned
# big-endian integer, which is printed so the run can be repeated later.
if not SEED:
    seed = int.from_bytes(os.urandom(8), byteorder="big", signed=False)
    for seed_fn in (random.seed, torch.manual_seed):
        seed_fn(seed)
    print(seed)
else:
    for seed_fn in (random.seed, torch.manual_seed):
        seed_fn(SEED)

# Build the machine from the number of feature columns in train_x.
# NOTE(review): the second argument (5) presumably sets how many rows tm.l1.W
# gets (the W printed below has 5 rows) — confirm against TsetlinMachine.
tm = TsetlinMachine(train_x.shape[1], 5)

# Forward pass over the whole training set. The later cells read tm.l1.W,
# tm.l1.full_X and tm.l1.out, presumably populated by this call — TODO confirm.
out_1 = tm.forward(train_x)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [150]:
from tsetlin import TsetlinBase

# Sanity check: recompute the first layer's output from its stored full_X and W
# with TsetlinBase.conjunction_mul and require it to equal the stored tm.l1.out.
b = TsetlinBase()
out_check = b.conjunction_mul(tm.l1.full_X.unsqueeze(1), tm.l1.W)
assert torch.all(out_check == tm.l1.out)

In [151]:
# Display the first layer's W, full_X and out tensors. The cells that follow
# build an alternative (full_X, W) pair intended to reproduce the same `out`.
tm.l1.W, tm.l1.full_X, tm.l1.out

(tensor([[1, 0, 0, 0, 1, 1],
         [0, 1, 1, 1, 0, 0],
         [1, 0, 1, 0, 1, 0],
         [1, 1, 1, 0, 0, 0],
         [0, 1, 1, 1, 0, 0]]),
 tensor([[1, 0, 0, 0, 1, 1],
         [1, 0, 1, 0, 1, 0],
         [1, 1, 0, 0, 0, 1],
         [1, 1, 1, 0, 0, 0],
         [0, 0, 0, 1, 1, 1],
         [0, 1, 0, 1, 0, 1],
         [0, 0, 1, 1, 1, 0],
         [0, 1, 1, 1, 0, 0]]),
 tensor([[1, 0, 0, 0, 0],
         [0, 0, 1, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 1, 0, 0, 1]]))

In [152]:
# For each row i of W, look at column i of the layer output and collect, as
# sets, the sample (row) indices where that column is 0 and where it is 1.
zero_Y_row_idxs_per_W_row, one_Y_row_idxs_per_W_row = [], []
for i in range(tm.l1.W.shape[0]):
    row_Y = tm.l1.out[:, i]
    for target, bucket in ((0, zero_Y_row_idxs_per_W_row), (1, one_Y_row_idxs_per_W_row)):
        hit_idxs = (row_Y == target).nonzero(as_tuple=True)[0]
        bucket.append(set(hit_idxs.tolist()))

zero_Y_row_idxs_per_W_row, one_Y_row_idxs_per_W_row

([{1, 2, 3, 4, 5, 6, 7},
  {0, 1, 2, 3, 4, 5, 6},
  {0, 2, 3, 4, 5, 6, 7},
  {0, 1, 2, 4, 5, 6, 7},
  {0, 1, 2, 3, 4, 5, 6}],
 [{0}, {7}, {1}, {3}, {7}])

In [153]:
import math
import copy

from itertools import combinations, chain
from collections import deque

def generate_subsets(set_elements, subset_size):
    """Return every `subset_size`-element combination of `set_elements` as a set.

    The sets appear in the order `itertools.combinations` yields them, which
    follows the iteration order of `set_elements`.
    """
    return list(map(set, combinations(set_elements, subset_size)))

def generate_powerset(set_elements):
    """Return all subsets of `set_elements` as a list of sets.

    Subsets are listed by increasing size, starting with the empty set, e.g.
    [1, 2, 3] -> [set(), {1}, {2}, {3}, {1, 2}, {1, 3}, {2, 3}, {1, 2, 3}].
    """
    pool = list(set_elements)
    subsets = []
    for size in range(len(pool) + 1):
        subsets.extend(set(combo) for combo in combinations(pool, size))
    return subsets

# Keep one representative W row per distinct set of samples on which the row's
# output is 1. Rows whose output is never 1 are skipped, and a row whose set
# equals that of an earlier row is dropped as a duplicate.
unique_one_Y_row_idxs = set()
visited_ones = set()
for i, x in enumerate(one_Y_row_idxs_per_W_row):
    if x:
        # frozenset compares by content. A tuple built from a set follows the
        # set's iteration order, which is not guaranteed to match for equal sets.
        frozen_x = frozenset(x)
        if frozen_x not in visited_ones:
            visited_ones.add(frozen_x)
            unique_one_Y_row_idxs.add(i)


# For each representative row: the sample indices whose output is 0 and that
# therefore still have to be separated from the row's one-indices.
tracking = {x: zero_Y_row_idxs_per_W_row[x] for x in unique_one_Y_row_idxs}
# Rows with the largest one-index sets are processed first.
sorted_one_Y_row_idxs = sorted(unique_one_Y_row_idxs, key=lambda x: len(one_Y_row_idxs_per_W_row[x]), reverse=True)
q = deque(sorted_one_Y_row_idxs)

def recursive_helper(depth, max_depth, current_solution, prev_W_row_idx, q,
                     one_Y_idxs_per_row=None, num_samples=None):
    """Depth-first search for columns that separate each row's ones from its zeros.

    A column is a pair ``(left_W, right_W)`` of sample-index sets. Placing a
    column removes, for every tracked row whose one-index set lies entirely in
    one side, the pending zero indices that lie in the other side. Rows whose
    one-index set fits in neither side keep their pending set. The search
    succeeds once no row has pending zero indices left.

    Args:
        depth: Number of columns already placed by the callers.
        max_depth: Maximum number of columns that may be placed in total.
        current_solution: Dict mapping a row index to the set of zero indices
            that still have to be separated for that row. It is not mutated.
        prev_W_row_idx: Row the caller was working on; it is reused while it
            still appears in `current_solution`.
        q: Deque of row indices not worked on yet. Entries are popped from the
            left of this very object; recursive calls receive a copy.
        one_Y_idxs_per_row: Per-row sets of sample indices whose output is 1.
            Defaults to the notebook global `one_Y_row_idxs_per_W_row`.
        num_samples: Number of sample indices to partition. Defaults to
            `tm.l1.full_X.shape[0]`.

    Returns:
        ``(cols, solved)``. When `solved` is True, `cols` lists the chosen
        ``(left_W, right_W)`` pairs deepest-first, so the column chosen by this
        call is the last element. When `solved` is False, `cols` is empty.
    """
    if depth == max_depth or len(current_solution) == 0:
        return [], len(current_solution) == 0

    # Fall back to the notebook globals so five-argument calls keep working.
    if one_Y_idxs_per_row is None:
        one_Y_idxs_per_row = one_Y_row_idxs_per_W_row
    if num_samples is None:
        num_samples = tm.l1.full_X.shape[0]

    # Stay on the previous row while it has pending zeros; otherwise advance to
    # the next queued row that is still present in the solution.
    curr_W_row_idx = prev_W_row_idx
    while curr_W_row_idx not in current_solution and q:
        curr_W_row_idx = q.popleft()

    curr_one_Y_idxs = one_Y_idxs_per_row[curr_W_row_idx]
    pending_zero_Y_idxs = current_solution[curr_W_row_idx]
    # Spread the row's pending zeros evenly over the columns that are left:
    # this column is asked to take at least this many of them.
    min_zero_Y_idxs_len = math.ceil(len(pending_zero_Y_idxs) / (max_depth - depth))
    min_zero_Y_subsets = [
        set(combo)
        for combo in combinations(
            pending_zero_Y_idxs,
            min(min_zero_Y_idxs_len, len(pending_zero_Y_idxs)),
        )
    ]

    # Candidate right-hand cores. One-index sets of still-queued rows are tried
    # first when they have the required size, avoid the current row's ones and
    # touch its pending zeros; the plain combinations follow.
    ordered_min_zero_Y_subsets = []
    remaining_q = list(q)
    for idx in remaining_q:
        one_Y_idx = one_Y_idxs_per_row[idx]
        if (len(one_Y_idx) == min_zero_Y_idxs_len
                and len(one_Y_idx & curr_one_Y_idxs) == 0
                and len(one_Y_idx & pending_zero_Y_idxs) > 0):
            ordered_min_zero_Y_subsets.append(one_Y_idx)

    for subset in min_zero_Y_subsets:
        if subset not in ordered_min_zero_Y_subsets:
            ordered_min_zero_Y_subsets.append(subset)

    all_Y_idxs = set(range(num_samples))
    # Columns already explored (and failed) in this call. The same pair comes up
    # repeatedly (a subset and its complement produce mirrored candidates, and
    # different cores overlap), and the recursion is deterministic, so a repeat
    # can only fail again.
    tried_cols = set()

    for min_zero_Y_subset in ordered_min_zero_Y_subsets:
        remaining_Y_idxs = all_Y_idxs - (min_zero_Y_subset | curr_one_Y_idxs)

        # All subsets of the undecided indices, largest first; within one size
        # the order is the one itertools.combinations produces.
        remaining_pool = list(remaining_Y_idxs)
        remaining_Y_subsets = [
            set(combo)
            for size in range(len(remaining_pool), -1, -1)
            for combo in combinations(remaining_pool, size)
        ]

        # One-index sets of queued rows that fit in the undecided indices are
        # tried before the generic subsets.
        remaining_Y_subsets_ordered = []
        for idx in remaining_q:
            one_Y_idx = one_Y_idxs_per_row[idx]
            if one_Y_idx.issubset(remaining_Y_idxs):
                remaining_Y_subsets_ordered.append(one_Y_idx)

        for subset in remaining_Y_subsets:
            if subset not in remaining_Y_subsets_ordered:
                remaining_Y_subsets_ordered.append(subset)

        for remaining_Y_subset in remaining_Y_subsets_ordered:
            opposite_remaining_Y_subset = remaining_Y_idxs - remaining_Y_subset

            # The current row's ones always go left and the chosen core always
            # goes right; the undecided indices are split both ways.
            candidate_cols = [
                (curr_one_Y_idxs | opposite_remaining_Y_subset, min_zero_Y_subset | remaining_Y_subset),
                (curr_one_Y_idxs | remaining_Y_subset, min_zero_Y_subset | opposite_remaining_Y_subset),
            ]

            for left_W, right_W in candidate_cols:
                col_key = (frozenset(left_W), frozenset(right_W))
                if col_key in tried_cols:
                    continue
                tried_cols.add(col_key)

                updated_solution = {}
                for k, v in current_solution.items():
                    one_Y_idxs = one_Y_idxs_per_row[k]
                    if one_Y_idxs.issubset(left_W):
                        sub = v - right_W
                        if len(sub) > 0:
                            updated_solution[k] = sub
                    elif one_Y_idxs.issubset(right_W):
                        sub = v - left_W
                        if len(sub) > 0:
                            updated_solution[k] = sub
                    else:
                        updated_solution[k] = v

                # The queue holds plain ints, so a shallow copy is enough to
                # keep sibling branches from seeing each other's pops.
                next_cols, solved = recursive_helper(
                    depth + 1, max_depth, updated_solution, curr_W_row_idx, deque(q),
                    one_Y_idxs_per_row, num_samples,
                )
                if solved:
                    next_cols.append((left_W, right_W))
                    return next_cols, True

    return [], False

# Start the search on the first queued row, allowing at most in_dim columns.
cols, solved = recursive_helper(0, tm.l1.in_dim, tracking, q.popleft(), q)
print(cols)
assert solved

# Translate the columns found into tensors shaped like W and full_X. Column i
# of the first half stands for the left side of cols[i]; column i + in_dim
# stands for its right side.
new_W = torch.zeros_like(tm.l1.W)
new_full_X = torch.zeros_like(tm.l1.full_X)
for i, (col_left, col_right) in enumerate(cols):
    right_i = i + tm.l1.in_dim

    new_full_X[list(col_left), i] = 1
    new_full_X[list(col_right), right_i] = 1 # this is not needed in the actual code because we only pass the left side back to the previous layer

    for row_idx, x in enumerate(one_Y_row_idxs_per_W_row):
        if not x:
            # Rows that never output 1 keep an all-zero weight row.
            continue
        if x <= col_left:
            new_W[row_idx, i] = 1
        elif x <= col_right:
            new_W[row_idx, right_i] = 1

[({0, 2, 4, 5, 6, 7}, {1, 3}), ({0, 3, 7}, {1, 2, 4, 5, 6}), ({0, 4, 5, 6}, {1, 2, 3, 7})]


In [154]:
# Display the rebuilt input matrix and weight matrix derived from the columns
# found by the search.
new_full_X, new_W

(tensor([[1, 1, 1, 0, 0, 0],
         [0, 0, 0, 1, 1, 1],
         [1, 0, 0, 0, 1, 1],
         [0, 1, 0, 1, 0, 1],
         [1, 0, 1, 0, 1, 0],
         [1, 0, 1, 0, 1, 0],
         [1, 0, 1, 0, 1, 0],
         [1, 1, 0, 0, 0, 1]]),
 tensor([[1, 1, 1, 0, 0, 0],
         [1, 1, 0, 0, 0, 1],
         [0, 0, 0, 1, 1, 1],
         [0, 1, 0, 1, 0, 1],
         [1, 1, 0, 0, 0, 1]]))

In [155]:
from tsetlin import TsetlinBase

# Verification: feeding the rebuilt new_full_X and new_W through
# TsetlinBase.conjunction_mul must reproduce the original layer output exactly.
b = TsetlinBase()
y_2 = b.conjunction_mul(new_full_X.unsqueeze(1), new_W)
print(y_2)
assert torch.all(y_2 == tm.l1.out)

tensor([[1, 0, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 1, 0, 0, 1]])
