In [797]:
%load_ext autoreload
%autoreload 2
%reload_ext autoreload
import os
from tsetlin import TsetlinMachine
import torch
import random

DATASET_DIR = '../datasets/'
DATA_FILE = 'bit_1.txt'
SEED = 9496461801973866405
# SEED = None

text_rows = open(f'{DATASET_DIR}{DATA_FILE}', 'r').read().splitlines()
dataset = [ [int(num) for num in row.split(',')] for row in text_rows]
tensor_dataset = torch.tensor(dataset)
train_x = tensor_dataset[:, :-1]
train_y = tensor_dataset[:, -1]

if SEED:
    random.seed(SEED)
    torch.manual_seed(SEED)
else:
    seed = int.from_bytes(os.urandom(8), byteorder="big", signed=False)
    random.seed(seed)
    torch.manual_seed(seed)
    print(seed)
tm = TsetlinMachine(train_x.shape[1], 5)
out_1 = tm.forward(train_x)
out_1, train_y

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


(tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([1, 1, 1, 1, 0, 0, 0, 0]))

In [798]:
# Random 0/1 target for every entry of the layer-1 output (same shape as tm.l1.out).
Y = torch.randint(low=0, high=2, size=tm.l1.out.shape)
(tm.l1.out, Y, tm.l1.W)

(tensor([[0, 0, 0, 0, 0],
         [0, 1, 0, 0, 0],
         [0, 0, 1, 0, 0],
         [1, 0, 0, 1, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 1, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 1, 1]]),
 tensor([[0, 0, 1, 1, 0],
         [0, 1, 1, 0, 1],
         [1, 1, 0, 0, 1],
         [0, 0, 1, 0, 1],
         [0, 0, 1, 1, 0],
         [0, 0, 0, 0, 0],
         [1, 1, 0, 0, 0],
         [0, 1, 1, 1, 1]]),
 tensor([[1, 1, 1, 0, 0, 0],
         [1, 0, 1, 0, 1, 0],
         [0, 1, 0, 0, 0, 1],
         [0, 1, 1, 0, 0, 0],
         [0, 1, 1, 1, 0, 0]]))

In [799]:
from tabulate import tabulate

# For every data row i, consider the clauses whose target is 1 (Y[i] == 1). Column j < in_dim
# lists active clauses that use positive literal j while some active clause (possibly the same
# one) uses its negation; column j + in_dim lists the active clauses using that negation.
W_halves = torch.split(tm.l1.W, tm.l1.in_dim, dim=1)  # loop-invariant: split once
pos_W, neg_W = W_halves

tabular_W_collisions = []  # these are the necessary ones
for i in range(tm.l1.full_X.shape[0]):
    W_collisions_row = [[] for _ in range(tm.l1.in_dim * 2)]
    one_Y_idxs = torch.nonzero(Y[i] == 1).squeeze(1).tolist()
    for pos_one_Y_idx in one_Y_idxs:
        # prod[r, j] == 1 <=> clause pos_one_Y_idx has x_j and clause one_Y_idxs[r] has not-x_j
        prod = pos_W[pos_one_Y_idx] * neg_W[one_Y_idxs]
        # .tolist() yields plain ints instead of 0-d tensors for the list indexing below
        for row_collision_idx, col_collision_idx in (prod == 1).nonzero().tolist():
            neg_col_collision_idx = col_collision_idx + tm.l1.in_dim
            neg_clause_idx = one_Y_idxs[row_collision_idx]
            if pos_one_Y_idx not in W_collisions_row[col_collision_idx]:
                W_collisions_row[col_collision_idx].append(pos_one_Y_idx)
            if neg_clause_idx not in W_collisions_row[neg_col_collision_idx]:
                W_collisions_row[neg_col_collision_idx].append(neg_clause_idx)

    tabular_W_collisions.append(W_collisions_row)

headers = [f"Column {i+1}" for i in range(len(tabular_W_collisions[0]))]
print(tabulate(tabular_W_collisions, headers=headers, tablefmt="fancy_grid"))

╒════════════╤════════════╤════════════╤════════════╤════════════╤════════════╕
│ Column 1   │ Column 2   │ Column 3   │ Column 4   │ Column 5   │ Column 6   │
╞════════════╪════════════╪════════════╪════════════╪════════════╪════════════╡
│ []         │ []         │ [3]        │ []         │ []         │ [2]        │
├────────────┼────────────┼────────────┼────────────┼────────────┼────────────┤
│ [1]        │ [2, 4]     │ [1, 4]     │ [4]        │ [1]        │ [2]        │
├────────────┼────────────┼────────────┼────────────┼────────────┼────────────┤
│ [0, 1]     │ [0, 4]     │ []         │ [4]        │ [1]        │ []         │
├────────────┼────────────┼────────────┼────────────┼────────────┼────────────┤
│ []         │ []         │ [4]        │ []         │ []         │ [2]        │
├────────────┼────────────┼────────────┼────────────┼────────────┼────────────┤
│ []         │ []         │ [3]        │ []         │ []         │ [2]        │
├────────────┼────────────┼────────────┼

In [800]:
from tabulate import tabulate

# Same collision search as the per-row table, aggregated over all data rows:
# W_deps[c] lists every clause involved in a collision on literal column c. Each collision
# is recorded on both the positive column and its negated column.
W_halves = torch.split(tm.l1.W, tm.l1.in_dim, dim=1)  # loop-invariant: split once
pos_W, neg_W = W_halves

W_deps = [[] for _ in range(tm.l1.W.shape[1])]  # these are the necessary ones
for i in range(tm.l1.full_X.shape[0]):
    one_Y_idxs = torch.nonzero(Y[i] == 1).squeeze(1).tolist()
    for pos_one_Y_idx in one_Y_idxs:
        prod = pos_W[pos_one_Y_idx] * neg_W[one_Y_idxs]
        # .tolist() yields plain ints instead of 0-d tensors for the list indexing below
        for row_collision_idx, col_collision_idx in (prod == 1).nonzero().tolist():
            neg_col_collision_idx = col_collision_idx + tm.l1.in_dim
            neg_clause_idx = one_Y_idxs[row_collision_idx]
            if pos_one_Y_idx not in W_deps[col_collision_idx]:
                W_deps[col_collision_idx].append(pos_one_Y_idx)
            if neg_clause_idx not in W_deps[neg_col_collision_idx]:
                W_deps[neg_col_collision_idx].append(neg_clause_idx)

W_deps

[[1, 0], [2, 4, 0, 3], [3, 1, 4], [4], [1], [2]]

In [801]:
# NOTE(review): disabled, fully commented-out variant of the W_deps collision search; it is not
# executed — consider deleting it or restoring it as a function.
# W_deps_full = []
# for i, single_x in enumerate(tm.l1.full_X):
#     W_collisions_row = [[] for _ in range(tm.l1.in_dim * 2)]
#     row_Y = Y[i]
#     one_Y_idxs = torch.nonzero(row_Y == 1).squeeze(1)
#     W_halves = torch.split(tm.l1.W, tm.l1.in_dim, dim=1)
#     pos_W = W_halves[0]
#     neg_W = W_halves[1]
#     for pos_one_Y_idx in one_Y_idxs:
#         w_1 = pos_W[pos_one_Y_idx]
#         for neg_one_Y_idx in one_Y_idxs:
#             w_2 = neg_W[neg_one_Y_idx]
#             pos_one_W_idxs = torch.nonzero(w_1 == 1).squeeze(1)
#             neg_one_W_idxs = torch.nonzero(w_2 == 1).squeeze(1)
#             for collision_idx in pos_one_W_idxs:
#                 if pos_one_Y_idx.item() not in W_collisions_row[collision_idx.item()]:
#                     W_collisions_row[collision_idx.item()].append(pos_one_Y_idx.item())
#             for collision_idx in neg_one_W_idxs:
#                 if neg_one_Y_idx.item() not in W_collisions_row[collision_idx.item() + tm.l1.in_dim]:
#                     W_collisions_row[collision_idx.item() + tm.l1.in_dim].append(neg_one_Y_idx.item())
#     W_deps_full.append(W_collisions_row)
# W_deps_full

In [802]:
# Show the layer-1 literal matrix; in the output below the last in_dim columns are the
# complements (1 - x) of the first in_dim columns.
tm.l1.full_X

tensor([[1, 0, 0, 0, 1, 1],
        [1, 0, 1, 0, 1, 0],
        [1, 1, 0, 0, 0, 1],
        [1, 1, 1, 0, 0, 0],
        [0, 0, 0, 1, 1, 1],
        [0, 1, 0, 1, 0, 1],
        [0, 0, 1, 1, 1, 0],
        [0, 1, 1, 1, 0, 0]])

In [803]:
from tabulate import tabulate

# For each clause i (column i of Y) and literal column c, list the data rows whose literal
# value disagrees with the clause target Y[row, i] (literal is 1 xor target is 1).
x_stuff = tm.l1.full_X == 1  # loop-invariant: compute once

flip_deps_full = []
for i in range(tm.l1.W.shape[0]):
    target_is_one = (Y[:, i] == 1).unsqueeze(1)  # (rows, 1), broadcast across literal columns
    mismatch = x_stuff != target_is_one
    # nonzero() returns row indices in ascending order, as the original nested loops did
    must_be_flipped_row = [
        torch.nonzero(mismatch[:, col]).squeeze(1).tolist()
        for col in range(tm.l1.in_dim * 2)
    ]
    flip_deps_full.append(must_be_flipped_row)

headers = [f"Column {i+1}" for i in range(len(flip_deps_full[0]))]
print(tabulate(flip_deps_full, headers=headers, tablefmt="fancy_grid"))

╒═════════════════╤═════════════════╤═════════════════╤════════════════════╤════════════════════╤════════════════════╕
│ Column 1        │ Column 2        │ Column 3        │ Column 4           │ Column 5           │ Column 6           │
╞═════════════════╪═════════════════╪═════════════════╪════════════════════╪════════════════════╪════════════════════╡
│ [0, 1, 3, 6]    │ [3, 5, 6, 7]    │ [1, 2, 3, 7]    │ [2, 4, 5, 7]       │ [0, 1, 2, 4]       │ [0, 4, 5, 6]       │
├─────────────────┼─────────────────┼─────────────────┼────────────────────┼────────────────────┼────────────────────┤
│ [0, 3, 6, 7]    │ [1, 3, 5, 6]    │ [2, 3]          │ [1, 2, 4, 5]       │ [0, 2, 4, 7]       │ [0, 1, 4, 5, 6, 7] │
├─────────────────┼─────────────────┼─────────────────┼────────────────────┼────────────────────┼────────────────────┤
│ [2, 4, 7]       │ [0, 1, 2, 4, 5] │ [0, 4, 6]       │ [0, 1, 3, 5, 6]    │ [3, 6, 7]          │ [1, 2, 3, 5, 7]    │
├─────────────────┼─────────────────┼───────────

In [804]:
from tabulate import tabulate

# For each clause i (column i of Y) and literal column c, list the data rows where the literal
# is 0 but the clause target Y[row, i] is 1 — the rows that would have to be flipped to 1.
x_stuff = tm.l1.full_X == 1  # loop-invariant: compute once

flip_deps = []
for i in range(tm.l1.W.shape[0]):
    target_is_one = (Y[:, i] == 1).unsqueeze(1)  # (rows, 1), broadcast across literal columns
    needs_flip = ~x_stuff & target_is_one
    # nonzero() returns row indices in ascending order, as the original nested loops did
    must_be_flipped_row = [
        torch.nonzero(needs_flip[:, col]).squeeze(1).tolist()
        for col in range(tm.l1.in_dim * 2)
    ]
    flip_deps.append(must_be_flipped_row)

headers = [f"Column {i+1}" for i in range(len(flip_deps[0]))]
print(tabulate(flip_deps, headers=headers, tablefmt="fancy_grid"))

╒════════════╤════════════╤════════════╤════════════╤════════════╤════════════╕
│ Column 1   │ Column 2   │ Column 3   │ Column 4   │ Column 5   │ Column 6   │
╞════════════╪════════════╪════════════╪════════════╪════════════╪════════════╡
│ [6]        │ [6]        │ [2]        │ [2]        │ [2]        │ [6]        │
├────────────┼────────────┼────────────┼────────────┼────────────┼────────────┤
│ [6, 7]     │ [1, 6]     │ [2]        │ [1, 2]     │ [2, 7]     │ [1, 6, 7]  │
├────────────┼────────────┼────────────┼────────────┼────────────┼────────────┤
│ [4, 7]     │ [0, 1, 4]  │ [0, 4]     │ [0, 1, 3]  │ [3, 7]     │ [1, 3, 7]  │
├────────────┼────────────┼────────────┼────────────┼────────────┼────────────┤
│ [4, 7]     │ [0, 4]     │ [0, 4]     │ [0]        │ [7]        │ [7]        │
├────────────┼────────────┼────────────┼────────────┼────────────┼────────────┤
│ [7]        │ [1]        │ [2]        │ [1, 2, 3]  │ [2, 3, 7]  │ [1, 3, 7]  │
╘════════════╧════════════╧════════════╧

In [805]:
from tabulate import tabulate

# has_to_be_1[i][c]: data rows where clause i's target (column i of Y) is 1. The list is the
# same for every column c; each cell gets its own copy so mutating one cannot alias the others.
has_to_be_1 = []
for i in range(tm.l1.W.shape[0]):
    one_Y_idxs = torch.nonzero(Y[:, i] == 1).squeeze(1).tolist()
    has_to_be_1.append([list(one_Y_idxs) for _ in range(tm.l1.W.shape[1])])
headers = [f"Column {i+1}" for i in range(len(has_to_be_1[0]))]
print(tabulate(has_to_be_1, headers=headers, tablefmt="fancy_grid"))

╒═════════════════╤═════════════════╤═════════════════╤═════════════════╤═════════════════╤═════════════════╕
│ Column 1        │ Column 2        │ Column 3        │ Column 4        │ Column 5        │ Column 6        │
╞═════════════════╪═════════════════╪═════════════════╪═════════════════╪═════════════════╪═════════════════╡
│ [2, 6]          │ [2, 6]          │ [2, 6]          │ [2, 6]          │ [2, 6]          │ [2, 6]          │
├─────────────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┤
│ [1, 2, 6, 7]    │ [1, 2, 6, 7]    │ [1, 2, 6, 7]    │ [1, 2, 6, 7]    │ [1, 2, 6, 7]    │ [1, 2, 6, 7]    │
├─────────────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┤
│ [0, 1, 3, 4, 7] │ [0, 1, 3, 4, 7] │ [0, 1, 3, 4, 7] │ [0, 1, 3, 4, 7] │ [0, 1, 3, 4, 7] │ [0, 1, 3, 4, 7] │
├─────────────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┤
│ [0, 4, 7

In [806]:
# Re-display W_deps: for each literal column, the clauses involved in a collision on it.
W_deps

[[1, 0], [2, 4, 0, 3], [3, 1, 4], [4], [1], [2]]

In [807]:
# For each literal column, the union of the data rows that must be flipped (flip_deps) for
# every clause listed in W_deps for that column.
flip_count = [set() for _ in range(tm.l1.in_dim * 2)]
for col_idx, dep in enumerate(W_deps):
    for clause_idx in dep:
        flip_count[col_idx].update(flip_deps[clause_idx][col_idx])

# sorted() gives a deterministic order; the iteration order of a set is an implementation detail.
flip_count = [sorted(rows) for rows in flip_count]
flip_count

[[6, 7], [0, 1, 4, 6], [0, 2, 4], [1, 2, 3], [2, 7], [1, 3, 7]]

In [808]:
import queue

# Re-seed here so the random column choices below are reproducible regardless of earlier cells.
SEED = 9496461801973866405
random.seed(SEED)
torch.manual_seed(SEED)


def is_there_conflict(col_idx, row_idx, chosen_idxs):
    """Return True if giving clause `row_idx` literal column `col_idx` clashes with a prior choice.

    A clash exists when another clause already chose the complementary literal (the same
    input in the other half of W) and both clauses must output 1 on at least one common
    data row according to `has_to_be_1`.

    NOTE(review): the original compared `has_to_be_1[row_idx]` with itself, which only
    tested "this clause has any required row"; comparing against the clause that holds
    the complementary literal is the inferred intent — confirm.

    Args:
        col_idx: candidate column of W, in [0, 2 * in_dim).
        row_idx: clause (row of W) being assigned.
        chosen_idxs: chosen column per clause so far; None means "not chosen yet".
    """
    in_dim = tm.l1.in_dim
    neg_col_idx = (col_idx + in_dim) % (in_dim * 2)
    required_rows = set(has_to_be_1[row_idx][col_idx])
    for other_row, other_col in enumerate(chosen_idxs):
        if other_row != row_idx and other_col == neg_col_idx:
            if required_rows & set(has_to_be_1[other_row][neg_col_idx]):
                return True
    return False


def _first_acceptable_col(n_cols, is_acceptable):
    """Shuffle the column indices and return the first one accepted by `is_acceptable`.

    Raises RuntimeError when every column is rejected (the previous queue loop spun forever).
    """
    candidates = list(range(n_cols))
    random.shuffle(candidates)
    for col_idx in candidates:
        if is_acceptable(col_idx):
            return col_idx
    raise RuntimeError("no column satisfies the constraints for this clause")


def func(W_deps):
    """Build a W-shaped 0/1 matrix with exactly one chosen literal column per clause (row).

    Args:
        W_deps: per literal column, the clauses involved in a collision on it.

    Returns:
        (selected_W_deps_indices, chosen_idxs, new_W): the column picked from each colliding
        positive/negated pair, the column chosen for each clause, and the new weight matrix.
    """
    n_rows, n_cols = tm.l1.W.shape
    new_W = torch.zeros_like(tm.l1.W)
    half = len(W_deps) // 2

    # From each positive/negated column pair, pick the one with fewer dependent clauses.
    # W_deps is expected to record every collision on both halves, so checking the positive
    # half for emptiness suffices.
    selected_W_deps_indices = []
    for i in range(half):
        if len(W_deps[i]) > 0:
            # Always draw, even when the tie-break is unused, so the RNG stream is unchanged.
            selected_index = random.choice([i, i + half])
            if len(W_deps[i]) > len(W_deps[i + half]):
                selected_index = i + half
            elif len(W_deps[i]) < len(W_deps[i + half]):
                selected_index = i
            selected_W_deps_indices.append(selected_index)

    # cannot_be_changed[r]: columns clause r is not allowed to pick.
    cannot_be_changed = [[] for _ in range(n_rows)]
    for index in selected_W_deps_indices:
        for row_i in W_deps[index]:
            cannot_be_changed[row_i].append(index)

    rows_with_unavailable_cols = [r for r, cols in enumerate(cannot_be_changed) if len(cols) > 0]
    free_rows = [r for r in range(n_rows) if r not in rows_with_unavailable_cols]
    # None marks "unassigned"; a numeric sentinel such as 0 would be mistaken for column 0.
    chosen_idxs = [None] * n_rows

    # Constrained clauses are assigned first, then the unconstrained ones.
    for row_i in rows_with_unavailable_cols + free_rows:
        blocked = cannot_be_changed[row_i]
        col_idx = _first_acceptable_col(
            n_cols,
            lambda c: c not in blocked and not is_there_conflict(c, row_i, chosen_idxs),
        )
        new_W[row_i, col_idx] = 1
        chosen_idxs[row_i] = col_idx

    return selected_W_deps_indices, chosen_idxs, new_W


# NOTE(review): the second value is the chosen column per clause (chosen_idxs); the next cell
# uses it as the column index into flip_deps.
selected_W_deps_indices, selected_flip_deps_indices, new_W = func(W_deps)
selected_W_deps_indices, selected_flip_deps_indices, new_W

([3, 4, 5],
 [0, 2, 4, 0, 2],
 tensor([[1, 0, 0, 0, 0, 0],
         [0, 0, 1, 0, 0, 0],
         [0, 0, 0, 0, 1, 0],
         [1, 0, 0, 0, 0, 0],
         [0, 0, 1, 0, 0, 0]]))

In [809]:
# For each clause, mark the data rows listed in flip_deps for its chosen column. Mark both that
# literal and its complement, then invert the marked entries of full_X.
mask_X = torch.zeros_like(tm.l1.full_X)
for clause_idx, chosen_col in enumerate(selected_flip_deps_indices):
    complement_col = (chosen_col + tm.l1.in_dim) % (tm.l1.in_dim * 2)
    for data_row in flip_deps[clause_idx][chosen_col]:
        mask_X[data_row, chosen_col] = 1
        mask_X[data_row, complement_col] = 1

flip_mask = mask_X.bool()
new_full_X = torch.where(flip_mask, 1 - tm.l1.full_X, tm.l1.full_X)
mask_X, new_full_X

(tensor([[0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0],
         [0, 0, 1, 0, 0, 1],
         [0, 1, 0, 0, 1, 0],
         [1, 0, 0, 1, 0, 0],
         [0, 0, 0, 0, 0, 0],
         [1, 0, 0, 1, 0, 0],
         [1, 1, 0, 1, 1, 0]]),
 tensor([[1, 0, 0, 0, 1, 1],
         [1, 0, 1, 0, 1, 0],
         [1, 1, 1, 0, 0, 0],
         [1, 0, 1, 0, 1, 0],
         [1, 0, 0, 0, 1, 1],
         [0, 1, 0, 1, 0, 1],
         [1, 0, 1, 0, 1, 0],
         [1, 0, 1, 0, 1, 0]]))

In [810]:
from tsetlin import TsetlinBase

# Evaluate the hand-built weights new_W on the flipped literal matrix new_full_X.
# unsqueeze(1) turns new_full_X into (rows, 1, literals) — presumably so it broadcasts against
# every row of new_W inside conjunction_mul; TODO confirm against TsetlinBase.
tb = TsetlinBase()
tb.conjunction_mul(new_full_X.unsqueeze(1), new_W)

tensor([[1, 0, 1, 1, 0],
        [1, 1, 1, 1, 1],
        [1, 1, 0, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 0, 1, 1, 0],
        [0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1]])