In [960]:
%load_ext autoreload
%autoreload 2
%reload_ext autoreload
import os
from tsetlin import TsetlinMachine
import torch
import random

DATASET_DIR = '../datasets/'
DATA_FILE = 'bit_1.txt'
SEED = 9496461801973866405
# SEED = None

text_rows = open(f'{DATASET_DIR}{DATA_FILE}', 'r').read().splitlines()
dataset = [ [int(num) for num in row.split(',')] for row in text_rows]
tensor_dataset = torch.tensor(dataset)
train_x = tensor_dataset[:, :-1]
train_y = tensor_dataset[:, -1]

if SEED:
    random.seed(SEED)
    torch.manual_seed(SEED)
else:
    seed = int.from_bytes(os.urandom(8), byteorder="big", signed=False)
    random.seed(seed)
    torch.manual_seed(seed)
    print(seed)
tm = TsetlinMachine(train_x.shape[1], 5)
out_1 = tm.forward(train_x)
out_1, train_y

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


(tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([1, 1, 1, 1, 0, 0, 0, 0]))

In [961]:
# Random 0/1 target with exactly the same shape as the first layer's output.
Y = torch.randint(low=0, high=2, size=tm.l1.out.shape)
(tm.l1.out, Y, tm.l1.W)

(tensor([[0, 0, 0, 0, 0],
         [0, 1, 0, 0, 0],
         [0, 0, 1, 0, 0],
         [1, 0, 0, 1, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 1, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 1, 1]]),
 tensor([[0, 0, 1, 1, 0],
         [0, 1, 1, 0, 1],
         [1, 1, 0, 0, 1],
         [0, 0, 1, 0, 1],
         [0, 0, 1, 1, 0],
         [0, 0, 0, 0, 0],
         [1, 1, 0, 0, 0],
         [0, 1, 1, 1, 1]]),
 tensor([[1, 1, 1, 0, 0, 0],
         [1, 0, 1, 0, 1, 0],
         [0, 1, 0, 0, 0, 1],
         [0, 1, 1, 0, 0, 0],
         [0, 1, 1, 1, 0, 0]]))

In [962]:
from tabulate import tabulate

# For every sample, record which clauses (rows of W) take part in a literal collision:
# two clauses that both must output 1 for that sample (Y == 1) while one uses literal
# x_c and the other uses its negation. Column c lists the clauses using x_c, column
# c + in_dim the clauses using the negated literal.
# NOTE(review): assumes the first in_dim columns of W are positive literals and the last
# in_dim their negations, as the pos_W / neg_W split below implies.
W_halves = torch.split(tm.l1.W, tm.l1.in_dim, dim=1)  # sample-independent: split once
pos_W = W_halves[0]
neg_W = W_halves[1]

tabular_W_collisions = [] # these are the necessary ones
for i in range(tm.l1.full_X.shape[0]):
    row_W_collisions = [[] for _ in range(tm.l1.in_dim * 2)]
    one_Y_idxs = torch.nonzero(Y[i] == 1).squeeze(1).tolist()
    for pos_one_Y_idx in one_Y_idxs:
        # prod[r, c] == 1 <=> clause pos_one_Y_idx uses x_c and clause one_Y_idxs[r] uses not-x_c
        prod = pos_W[pos_one_Y_idx] * neg_W[one_Y_idxs]
        # .tolist() yields plain int index pairs instead of 0-d tensors
        for row_collision_idx, col_collision_idx in (prod == 1).nonzero().tolist():
            neg_one_Y_idx = one_Y_idxs[row_collision_idx]
            neg_col_collision_idx = col_collision_idx + tm.l1.in_dim
            if pos_one_Y_idx not in row_W_collisions[col_collision_idx]:
                row_W_collisions[col_collision_idx].append(pos_one_Y_idx)
            if neg_one_Y_idx not in row_W_collisions[neg_col_collision_idx]:
                row_W_collisions[neg_col_collision_idx].append(neg_one_Y_idx)
    tabular_W_collisions.append(row_W_collisions)

headers = [f"Column {i+1}" for i in range(len(tabular_W_collisions[0]))]
print(tabulate(tabular_W_collisions, headers=headers, tablefmt="fancy_grid"))

╒════════════╤════════════╤════════════╤════════════╤════════════╤════════════╕
│ Column 1   │ Column 2   │ Column 3   │ Column 4   │ Column 5   │ Column 6   │
╞════════════╪════════════╪════════════╪════════════╪════════════╪════════════╡
│ []         │ []         │ [3]        │ []         │ []         │ [2]        │
├────────────┼────────────┼────────────┼────────────┼────────────┼────────────┤
│ [1]        │ [2, 4]     │ [1, 4]     │ [4]        │ [1]        │ [2]        │
├────────────┼────────────┼────────────┼────────────┼────────────┼────────────┤
│ [0, 1]     │ [0, 4]     │ []         │ [4]        │ [1]        │ []         │
├────────────┼────────────┼────────────┼────────────┼────────────┼────────────┤
│ []         │ []         │ [4]        │ []         │ []         │ [2]        │
├────────────┼────────────┼────────────┼────────────┼────────────┼────────────┤
│ []         │ []         │ [3]        │ []         │ []         │ [2]        │
├────────────┼────────────┼────────────┼

In [963]:
from tabulate import tabulate

# Aggregate literal collisions over all samples. Column c collects the clauses (rows of W)
# that use x_c while some co-firing clause (Y == 1 for the same sample) uses the negated
# literal; column c + in_dim collects those negated-literal clauses. Clause indices keep
# first-seen order across samples.
# NOTE(review): assumes W's first in_dim columns are positive literals and the last in_dim
# their negations, as the pos_W / neg_W split below implies.
W_halves = torch.split(tm.l1.W, tm.l1.in_dim, dim=1)  # sample-independent: split once
pos_W = W_halves[0]
neg_W = W_halves[1]

W_collisions = [[] for _ in range(tm.l1.W.shape[1])] # these are the necessary ones
for i in range(tm.l1.full_X.shape[0]):
    one_Y_idxs = torch.nonzero(Y[i] == 1).squeeze(1).tolist()
    for pos_one_Y_idx in one_Y_idxs:
        # prod[r, c] == 1 <=> clause pos_one_Y_idx uses x_c and clause one_Y_idxs[r] uses not-x_c
        prod = pos_W[pos_one_Y_idx] * neg_W[one_Y_idxs]
        # .tolist() yields plain int index pairs instead of 0-d tensors
        for row_collision_idx, col_collision_idx in (prod == 1).nonzero().tolist():
            neg_one_Y_idx = one_Y_idxs[row_collision_idx]
            neg_col_collision_idx = col_collision_idx + tm.l1.in_dim
            if pos_one_Y_idx not in W_collisions[col_collision_idx]:
                W_collisions[col_collision_idx].append(pos_one_Y_idx)
            if neg_one_Y_idx not in W_collisions[neg_col_collision_idx]:
                W_collisions[neg_col_collision_idx].append(neg_one_Y_idx)

W_collisions

[[1, 0], [2, 4, 0, 3], [3, 1, 4], [4], [1], [2]]

In [964]:
# NOTE(review): this whole cell is commented out. It is a draft that, for each sample,
# records every literal column used by any clause with Y == 1 (positive half at column c,
# negative half at c + in_dim) without checking for an actual x / not-x collision.
# TODO: delete it or restore it as a function — a dead cell breaks Restart & Run All
# readability and its result (W_deps_full) is not used by the cells shown.
# W_deps_full = []
# for i, single_x in enumerate(tm.l1.full_X):
#     W_collisions_row = [[] for _ in range(tm.l1.in_dim * 2)]
#     row_Y = Y[i]
#     one_Y_idxs = torch.nonzero(row_Y == 1).squeeze(1)
#     W_halves = torch.split(tm.l1.W, tm.l1.in_dim, dim=1)
#     pos_W = W_halves[0]
#     neg_W = W_halves[1]
#     for pos_one_Y_idx in one_Y_idxs:
#         w_1 = pos_W[pos_one_Y_idx]
#         for neg_one_Y_idx in one_Y_idxs:
#             w_2 = neg_W[neg_one_Y_idx]
#             pos_one_W_idxs = torch.nonzero(w_1 == 1).squeeze(1)
#             neg_one_W_idxs = torch.nonzero(w_2 == 1).squeeze(1)
#             for collision_idx in pos_one_W_idxs:
#                 if pos_one_Y_idx.item() not in W_collisions_row[collision_idx.item()]:
#                     W_collisions_row[collision_idx.item()].append(pos_one_Y_idx.item())
#             for collision_idx in neg_one_W_idxs:
#                 if neg_one_Y_idx.item() not in W_collisions_row[collision_idx.item() + tm.l1.in_dim]:
#                     W_collisions_row[collision_idx.item() + tm.l1.in_dim].append(neg_one_Y_idx.item())
#     W_deps_full.append(W_collisions_row)
# W_deps_full

In [965]:
# Inspect the literal matrix used by the first layer; in the displayed output the last
# in_dim columns are the complements of the first in_dim columns.
tm.l1.full_X

tensor([[1, 0, 0, 0, 1, 1],
        [1, 0, 1, 0, 1, 0],
        [1, 1, 0, 0, 0, 1],
        [1, 1, 1, 0, 0, 0],
        [0, 0, 0, 1, 1, 1],
        [0, 1, 0, 1, 0, 1],
        [0, 0, 1, 1, 1, 0],
        [0, 1, 1, 1, 0, 0]])

In [966]:
from tabulate import tabulate

# For every clause (row of W) and literal column, list the samples whose literal value
# disagrees with that clause's target column of Y: literal 0 where Y == 1, or literal 1
# where Y == 0 — i.e. literal XOR target.
# NOTE(review): assumes Y has one row per row of full_X (Y was drawn with tm.l1.out's shape).
mask_X = tm.l1.full_X == 1  # clause-independent: computed once, not per clause
flip_deps_full = []
for i in range(tm.l1.W.shape[0]):
    target_is_one = (Y[:, i] == 1).unsqueeze(1)  # (n_samples, 1): broadcasts over literals
    disagrees = mask_X ^ target_is_one
    X_flip_row_idxs_by_col = [[] for _ in range(tm.l1.in_dim * 2)]
    for col in range(mask_X.shape[1]):
        # nonzero returns ascending sample indices, matching a row-by-row scan
        X_flip_row_idxs_by_col[col] = torch.nonzero(disagrees[:, col]).squeeze(1).tolist()
    flip_deps_full.append(X_flip_row_idxs_by_col)

headers = [f"Column {i+1}" for i in range(len(flip_deps_full[0]))]
print(tabulate(flip_deps_full, headers=headers, tablefmt="fancy_grid"))

╒═════════════════╤═════════════════╤═════════════════╤════════════════════╤════════════════════╤════════════════════╕
│ Column 1        │ Column 2        │ Column 3        │ Column 4           │ Column 5           │ Column 6           │
╞═════════════════╪═════════════════╪═════════════════╪════════════════════╪════════════════════╪════════════════════╡
│ [0, 1, 3, 6]    │ [3, 5, 6, 7]    │ [1, 2, 3, 7]    │ [2, 4, 5, 7]       │ [0, 1, 2, 4]       │ [0, 4, 5, 6]       │
├─────────────────┼─────────────────┼─────────────────┼────────────────────┼────────────────────┼────────────────────┤
│ [0, 3, 6, 7]    │ [1, 3, 5, 6]    │ [2, 3]          │ [1, 2, 4, 5]       │ [0, 2, 4, 7]       │ [0, 1, 4, 5, 6, 7] │
├─────────────────┼─────────────────┼─────────────────┼────────────────────┼────────────────────┼────────────────────┤
│ [2, 4, 7]       │ [0, 1, 2, 4, 5] │ [0, 4, 6]       │ [0, 1, 3, 5, 6]    │ [3, 6, 7]          │ [1, 2, 3, 5, 7]    │
├─────────────────┼─────────────────┼───────────

In [967]:
from tabulate import tabulate

# For every clause (row of W) and literal column, list the samples where the clause must
# output 1 (Y == 1) but the literal is currently 0 — the samples whose input would have to
# be flipped for that literal alone to satisfy the clause there.
# NOTE(review): assumes Y has one row per row of full_X (Y was drawn with tm.l1.out's shape).
mask_X = tm.l1.full_X == 1  # clause-independent: computed once, not per clause
flip_deps = []
for i in range(tm.l1.W.shape[0]):
    target_is_one = (Y[:, i] == 1).unsqueeze(1)  # (n_samples, 1): broadcasts over literals
    needs_flip = ~mask_X & target_is_one
    X_flip_row_idxs_by_col = [[] for _ in range(tm.l1.in_dim * 2)]
    for col in range(mask_X.shape[1]):
        # nonzero returns ascending sample indices, matching a row-by-row scan
        X_flip_row_idxs_by_col[col] = torch.nonzero(needs_flip[:, col]).squeeze(1).tolist()
    flip_deps.append(X_flip_row_idxs_by_col)

headers = [f"Column {i+1}" for i in range(len(flip_deps[0]))]
print(tabulate(flip_deps, headers=headers, tablefmt="fancy_grid"))

╒════════════╤════════════╤════════════╤════════════╤════════════╤════════════╕
│ Column 1   │ Column 2   │ Column 3   │ Column 4   │ Column 5   │ Column 6   │
╞════════════╪════════════╪════════════╪════════════╪════════════╪════════════╡
│ [6]        │ [6]        │ [2]        │ [2]        │ [2]        │ [6]        │
├────────────┼────────────┼────────────┼────────────┼────────────┼────────────┤
│ [6, 7]     │ [1, 6]     │ [2]        │ [1, 2]     │ [2, 7]     │ [1, 6, 7]  │
├────────────┼────────────┼────────────┼────────────┼────────────┼────────────┤
│ [4, 7]     │ [0, 1, 4]  │ [0, 4]     │ [0, 1, 3]  │ [3, 7]     │ [1, 3, 7]  │
├────────────┼────────────┼────────────┼────────────┼────────────┼────────────┤
│ [4, 7]     │ [0, 4]     │ [0, 4]     │ [0]        │ [7]        │ [7]        │
├────────────┼────────────┼────────────┼────────────┼────────────┼────────────┤
│ [7]        │ [1]        │ [2]        │ [1, 2, 3]  │ [2, 3, 7]  │ [1, 3, 7]  │
╘════════════╧════════════╧════════════╧

In [968]:
from tabulate import tabulate

# For each clause (row of W): the sample indices where that clause's column of Y equals 1.
one_Y_row_idxs_per_W_row = [
    (Y[:, w_row] == 1).nonzero().squeeze(1).tolist()
    for w_row in range(len(tm.l1.W))
]
one_Y_row_idxs_per_W_row

[[2, 6], [1, 2, 6, 7], [0, 1, 3, 4, 7], [0, 4, 7], [1, 2, 3, 7]]

In [969]:
# Re-display the aggregated collisions; this list is what func(...) consumes below.
W_collisions

[[1, 0], [2, 4, 0, 3], [3, 1, 4], [4], [1], [2]]

In [970]:
import queue

# Re-seed both RNGs so the random.choice / random.shuffle picks in func are reproducible
# when this cell is re-run on its own.
# NOTE(review): this re-assigns SEED, overriding whatever the config cell set (e.g. None);
# `queue` is not used by any code visible here.
SEED = 9496461801973866405
torch.manual_seed(SEED)
random.seed(SEED)

def is_there_conflict(col_idx, row_idx, chosen_col_idxs, in_dim=None, one_Y_rows=None):
    """Return True if giving W row `row_idx` literal column `col_idx` clashes with an earlier pick.

    A clash means some *other* W row has already chosen the complementary column
    ((col_idx + in_dim) mod 2*in_dim) and the two rows share at least one sample
    index in their "Y == 1" lists, so that sample would need both a literal and
    its negation to be 1.

    Args:
        col_idx: candidate literal column for `row_idx`.
        row_idx: the W row being assigned.
        chosen_col_idxs: chosen column per W row; use None for rows not chosen yet.
        in_dim: number of input features; defaults to the notebook's tm.l1.in_dim.
        one_Y_rows: per-W-row lists of sample indices with Y == 1; defaults to the
            notebook's one_Y_row_idxs_per_W_row.

    Returns:
        bool: whether the choice conflicts.
    """
    if in_dim is None:
        in_dim = tm.l1.in_dim
    if one_Y_rows is None:
        one_Y_rows = one_Y_row_idxs_per_W_row
    neg_col_idx = (col_idx + in_dim) % (in_dim * 2)
    own_samples = set(one_Y_rows[row_idx])
    # Compare with the rows that actually hold the complementary column — intersecting
    # row_idx's samples with themselves flagged every row that has any Y == 1 sample.
    for other_row_idx, other_col_idx in enumerate(chosen_col_idxs):
        if other_row_idx == row_idx or other_col_idx != neg_col_idx:
            continue
        if own_samples & set(one_Y_rows[other_row_idx]):
            return True
    return False

def can_col_ids_be_used(col_idx, row_idx, cannot_be_changed):
    """Return True if column `col_idx` is *blocked* for W row `row_idx`.

    Despite the name, a True result means the column must NOT be used: it is
    listed in cannot_be_changed[row_idx]. Rows absent from the mapping have no
    blocked columns, so the result is False for them.
    """
    blocked_cols = cannot_be_changed.get(row_idx, ())
    return col_idx in blocked_cols

def func(W_collisions):
    """Pick one literal column per W row and return it with a matching one-hot W.

    Args:
        W_collisions: list of length 2 * in_dim; entry c lists the W rows involved in
            a collision on literal column c. Entries c and c + len // 2 form a
            literal / negation pair.

    Returns:
        (selected_col_idxs, new_W): the chosen column for every W row, and a
        zeros_like(tm.l1.W) tensor holding a single 1 per row at that column.

    Raises:
        ValueError: if some row has no column that is both unblocked and conflict-free.

    Uses the notebook-level tm, can_col_ids_be_used and is_there_conflict, and the
    global `random` RNG (seed it for reproducible picks).
    """
    new_W = torch.zeros_like(tm.l1.W)
    n_rows, n_cols = new_W.shape
    half = len(W_collisions) // 2

    # Phase 1: for each literal pair whose positive side has collisions, take the side
    # with fewer colliding rows (ties broken at random) and block that column for each
    # of those rows. (An empty positive side would select itself and block nothing.)
    cannot_be_changed = {}
    for i in range(half):
        if len(W_collisions[i]) > 0:
            neg_i = i + half
            # Drawn before the comparison, as originally, so the seeded RNG stream is unchanged.
            selected_index = random.choice([i, neg_i])
            if len(W_collisions[i]) > len(W_collisions[neg_i]):
                selected_index = neg_i
            elif len(W_collisions[i]) < len(W_collisions[neg_i]):
                selected_index = i
            for row_i in W_collisions[selected_index]:
                cannot_be_changed.setdefault(row_i, []).append(selected_index)

    # Phase 2: constrained rows first, then the rest; each takes the first column of a
    # random permutation that is neither blocked nor conflicting.
    # None marks "not chosen yet": a 0 / 0.0 placeholder compares equal to a real
    # choice of column 0 and would be seen as one by is_there_conflict.
    selected_col_idxs = [None] * n_rows
    remaining_rows = [r for r in range(n_rows) if r not in cannot_be_changed]
    for row_idx in sorted(cannot_be_changed) + remaining_rows:
        candidate_col_idxs = list(range(n_cols))
        random.shuffle(candidate_col_idxs)
        for col_idx in candidate_col_idxs:
            # can_col_ids_be_used returns True when the column is *blocked* for this row.
            if can_col_ids_be_used(col_idx, row_idx, cannot_be_changed):
                continue
            if is_there_conflict(col_idx, row_idx, selected_col_idxs):
                continue
            break
        else:
            raise ValueError(f"no usable literal column for W row {row_idx}")
        new_W[row_idx, col_idx] = 1
        selected_col_idxs[row_idx] = col_idx

    return selected_col_idxs, new_W
            

# Choose one literal column per clause and build the corresponding one-hot W.
selected_flip_deps_indices, new_W = func(W_collisions=W_collisions)
(selected_flip_deps_indices, new_W)

([0, 2, 4, 0, 2],
 tensor([[1, 0, 0, 0, 0, 0],
         [0, 0, 1, 0, 0, 0],
         [0, 0, 0, 0, 1, 0],
         [1, 0, 0, 0, 0, 0],
         [0, 0, 1, 0, 0, 0]]))

In [971]:
# Build a flip mask over full_X. For each clause's chosen literal column, mark the samples
# listed in flip_deps on that column and on its complementary column (±in_dim), then invert
# every marked entry. Marking first (instead of toggling in place) means an entry selected
# by several clauses is still flipped exactly once.
n_literals = tm.l1.in_dim * 2
mask_X = torch.zeros_like(tm.l1.full_X)
for clause_idx, literal_idx in enumerate(selected_flip_deps_indices):
    complement_idx = (literal_idx + tm.l1.in_dim) % n_literals
    for sample_idx in flip_deps[clause_idx][literal_idx]:
        mask_X[sample_idx, literal_idx] = 1
        mask_X[sample_idx, complement_idx] = 1

flip = mask_X.bool()
new_full_X = torch.where(flip, 1 - tm.l1.full_X, tm.l1.full_X)
mask_X, new_full_X

(tensor([[0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0],
         [0, 0, 1, 0, 0, 1],
         [0, 1, 0, 0, 1, 0],
         [1, 0, 0, 1, 0, 0],
         [0, 0, 0, 0, 0, 0],
         [1, 0, 0, 1, 0, 0],
         [1, 1, 0, 1, 1, 0]]),
 tensor([[1, 0, 0, 0, 1, 1],
         [1, 0, 1, 0, 1, 0],
         [1, 1, 1, 0, 0, 0],
         [1, 0, 1, 0, 1, 0],
         [1, 0, 0, 0, 1, 1],
         [0, 1, 0, 1, 0, 1],
         [1, 0, 1, 0, 1, 0],
         [1, 0, 1, 0, 1, 0]]))

In [972]:
from tsetlin import TsetlinBase

# Evaluate the one-hot clauses new_W on the flipped inputs. new_full_X[:, None] inserts the
# singleton axis at dim 1, exactly like new_full_X.unsqueeze(1).
# NOTE(review): the displayed result has the same shape as Y; presumably it should be
# compared against Y to validate the construction — confirm.
tb = TsetlinBase()
tb.conjunction_mul(new_full_X[:, None], new_W)

tensor([[1, 0, 1, 1, 0],
        [1, 1, 1, 1, 1],
        [1, 1, 0, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 0, 1, 1, 0],
        [0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1]])