In [23]:
%load_ext autoreload
%autoreload 2
%reload_ext autoreload
import os
from tsetlin import TsetlinMachine
import torch
import random

# Working seeds
# SEED = 9496461801973866405 # this works
# SEED = 3158761121381184149
# SEED = 4954668783344399908
# SEED = 2763762423645165824

DATASET_DIR = '../datasets/'
DATA_FILE = 'bit_1.txt'
SEED = 1158604801931000331

# Read the comma-separated integer rows. The context manager closes the file
# handle deterministically instead of leaving it to the garbage collector.
with open(f'{DATASET_DIR}{DATA_FILE}', 'r') as data_file:
    text_rows = data_file.read().splitlines()

# One list of ints per text row; every column but the last is an input,
# the last column is the target.
dataset = [[int(num) for num in row.split(',')] for row in text_rows]
tensor_dataset = torch.tensor(dataset)
train_x = tensor_dataset[:, :-1]
train_y = tensor_dataset[:, -1]


# Seed Python's and torch's RNGs. A falsy SEED means "no fixed seed": draw
# 64 random bits from the OS and print them so the run can be reproduced.
if not SEED:
    seed = int.from_bytes(os.urandom(8), byteorder="big", signed=False)
    random.seed(seed)
    torch.manual_seed(seed)
    print(seed)
else:
    random.seed(SEED)
    torch.manual_seed(SEED)

# Build the machine from the number of input columns and a second argument of 5
# (presumably the number of clauses / output width — TODO confirm against
# tsetlin.TsetlinMachine).
tm = TsetlinMachine(train_x.shape[1], 5)
# Run a forward pass over all training inputs and keep the returned value.
out_1 = tm.forward(train_x)
# Display the `out` attribute stored on tm.l1 as this cell's result.
tm.l1.out

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


tensor([[0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0],
        [1, 0, 0, 0, 1],
        [0, 1, 0, 1, 1],
        [0, 0, 1, 0, 0],
        [1, 0, 0, 0, 0],
        [0, 0, 1, 1, 0],
        [0, 0, 0, 1, 0]])

In [24]:
from tabulate import tabulate

# For every row index of tm.l1.W, take the matching column of tm.l1.out and
# collect the set of positions in that column whose value is 0.
zero_Y_row_idxs_per_W_row = [
    set(torch.nonzero(tm.l1.out[:, w_row] == 0).squeeze(1).tolist())
    for w_row in range(tm.l1.W.shape[0])
]
zero_Y_row_idxs_per_W_row

[{0, 1, 3, 4, 6, 7},
 {0, 1, 2, 4, 5, 6, 7},
 {0, 1, 2, 3, 5, 7},
 {0, 2, 4, 5},
 {0, 1, 4, 5, 6, 7}]

In [25]:
from tabulate import tabulate

# For every row index of tm.l1.W, take the matching column of tm.l1.out and
# collect the set of positions in that column whose value is 1.
one_Y_row_idxs_per_W_row = [
    set(torch.nonzero(tm.l1.out[:, w_row] == 1).squeeze(1).tolist())
    for w_row in range(tm.l1.W.shape[0])
]
one_Y_row_idxs_per_W_row

[{2, 5}, {3}, {4, 6}, {1, 3, 6, 7}, {2, 3}]

In [26]:
import queue

# Keep only the first index for each distinct set of one-rows.
# The identity key is a frozenset: converting a set to a tuple is not a
# canonical form, because two equal sets may iterate in different orders and
# would then be treated as different.
unique_idxs = set()
visited_ones = set()
for i, x in enumerate(one_Y_row_idxs_per_W_row):
    ones_key = frozenset(x)
    if ones_key not in visited_ones:
        visited_ones.add(ones_key)
        unique_idxs.add(i)


# Per retained index: the rows where its output is 1, and the rows where it is 0.
row_idx_to_one_values = {x: one_Y_row_idxs_per_W_row[x] for x in unique_idxs}
tracking = {x: zero_Y_row_idxs_per_W_row[x] for x in unique_idxs}

row_idx_to_one_values, tracking

({0: {2, 5}, 1: {3}, 2: {4, 6}, 3: {1, 3, 6, 7}, 4: {2, 3}},
 {0: {0, 1, 3, 4, 6, 7},
  1: {0, 1, 2, 4, 5, 6, 7},
  2: {0, 1, 2, 3, 5, 7},
  3: {0, 2, 4, 5},
  4: {0, 1, 4, 5, 6, 7}})

In [27]:
# Indices ordered by how many one-rows they have, largest first; ties keep
# the dictionary's iteration order because the sort is stable.
sorted_one_Y_row_idxs = [
    idx
    for idx, one_rows in sorted(
        row_idx_to_one_values.items(),
        key=lambda item: len(item[1]),
        reverse=True,
    )
]
sorted_one_Y_row_idxs

[3, 0, 2, 4, 1]

In [38]:
import math
from itertools import combinations, chain

def generate_subsets(set_elements, combination_size):
    """Return all ``combination_size``-element combinations of ``set_elements``.

    The result is a list of tuples in the order produced by
    ``itertools.combinations``; it is empty when ``combination_size`` exceeds
    the number of elements.
    """
    return [subset for subset in combinations(set_elements, combination_size)]

def generate_powerset(iterable):
    """Return every subset of ``iterable`` as a list of tuples, smallest first.

    Example: generate_powerset([1, 2, 3]) gives
    [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)].
    """
    pool = list(iterable)
    subsets = []
    # Sizes 0..len(pool) inclusive, so both the empty and the full subset appear.
    for size in range(len(pool) + 1):
        subsets.extend(combinations(pool, size))
    return subsets


# FIFO of candidate indices, enqueued in the order of sorted_one_Y_row_idxs.
q = queue.Queue()
for clause_idx in sorted_one_Y_row_idxs:
    q.put(clause_idx)

def _prune_solution(current_solution, left_clause, right_clause):
    """Return the entries of ``current_solution`` that remain after one split.

    An entry whose one-rows (``one_Y_row_idxs_per_W_row[k]``) lie entirely in
    ``left_clause`` loses the rows of ``right_clause``; one lying entirely in
    ``right_clause`` loses the rows of ``left_clause``. Entries that become
    empty are dropped.

    NOTE(review): entries whose one-rows are contained in neither side are
    dropped as well (same as the original code) — confirm this is intended.
    """
    updated_solution = {}
    for k, v in current_solution.items():
        corresponding_clause = one_Y_row_idxs_per_W_row[k]
        if corresponding_clause.issubset(left_clause):
            sub = v - right_clause
        elif corresponding_clause.issubset(right_clause):
            sub = v - left_clause
        else:
            continue
        if len(sub) > 0:
            updated_solution[k] = sub
    return updated_solution


def funca(depth, max_depth, current_solution, prev_clause_idx, q):
    """Depth-limited backtracking search for (left_clause, right_clause) splits.

    Each level picks one index (``prev_clause_idx`` if it is still a key of
    ``current_solution``, otherwise the next one taken from ``q`` that is),
    enumerates candidate splits of the row indices built from that index's
    one-rows, subsets of its zero-rows and every distribution of the remaining
    rows, prunes ``current_solution`` with each candidate and recurses.

    Reads the notebook globals ``row_idx_to_one_values``, ``tracking``,
    ``one_Y_row_idxs_per_W_row`` and ``tm.l1.full_X``.

    Args:
        depth: current recursion depth (0 at the top-level call).
        max_depth: depth at which the search stops descending.
        current_solution: dict mapping an index to the set of row indices not
            yet removed by a split; the search succeeds when it is empty.
        prev_clause_idx: index used by the caller's level.
        q: ``queue.Queue`` of further candidate indices.

    Returns:
        ``(layers, solved)``. ``layers`` is a list of
        ``(left_clause, right_clause)`` set pairs, appended while unwinding so
        the deepest split comes first; it is empty when ``solved`` is False.

    NOTE(review): ``q`` is shared by all branches and items taken from it are
    not put back on backtracking, so a failed branch can consume candidates
    that a sibling branch would have needed — confirm whether that is wanted.
    """
    if depth == max_depth or len(current_solution) == 0:
        return [], len(current_solution) == 0

    # Advance to an index that is still unresolved. A queue.Queue object is
    # always truthy, so emptiness has to be tested explicitly; otherwise get()
    # would block forever once the queue is drained.
    curr_clause_idx = prev_clause_idx
    while curr_clause_idx not in current_solution and not q.empty():
        curr_clause_idx = q.get_nowait()
    if curr_clause_idx not in current_solution:
        # Queue exhausted without a usable index: nothing to split on here.
        return [], False

    num_rows = tm.l1.full_X.shape[0]
    all_rows = set(range(num_rows))
    curr_clause = row_idx_to_one_values[curr_clause_idx]
    # Number of zero-rows forced onto the right side at this level: the rows
    # outside the current entry, spread evenly over the levels that are left.
    min_zero_rows = math.ceil(
        (num_rows - len(current_solution[curr_clause_idx])) / (max_depth - depth)
    )

    for min_zero_subset_tuple in generate_subsets(tracking[curr_clause_idx], min_zero_rows):
        min_zero_subset = set(min_zero_subset_tuple)
        remaining_values = all_rows - (min_zero_subset | curr_clause)
        for remaining_subset_tuple in generate_powerset(remaining_values):
            remaining_subset = set(remaining_subset_tuple)
            opposite_remaining_subset = remaining_values - remaining_subset

            # Try both orientations, in the original order: first the chosen
            # subset joins the one-rows, then its complement does. Only sets
            # are combined here; set | tuple raises TypeError.
            candidates = (
                (curr_clause | remaining_subset,
                 min_zero_subset | opposite_remaining_subset),
                (curr_clause | opposite_remaining_subset,
                 min_zero_subset | remaining_subset),
            )
            for left_clause, right_clause in candidates:
                updated_solution = _prune_solution(current_solution, left_clause, right_clause)
                next_layers, solved = funca(depth + 1, max_depth, updated_solution, curr_clause_idx, q)
                if solved:
                    next_layers.append((left_clause, right_clause))
                    return next_layers, True

    return [], False

# Start the search at depth 0 with tm.l1.in_dim as the depth limit, the full
# `tracking` dict as the initial solution, and the first index taken from `q`
# as the starting index; the (layers, solved) result is the cell's output.
funca(0, tm.l1.in_dim, tracking, q.get(), q)

[(), (4,), (5,), (4, 5)]
