In [619]:
%load_ext autoreload
%autoreload 2
%reload_ext autoreload
import os
from tsetlin import TsetlinMachine
import torch
import random

DATASET_DIR = '../datasets/'
DATA_FILE = 'bit_1.txt'
SEED = 9496461801973866405
# SEED = None  # use None to draw a fresh seed from os.urandom (it is printed below)

# Each non-empty line of the data file is a comma-separated list of integers;
# the last value of a line is used as the label, the rest as the features.
# The context manager closes the file handle, which the bare open() did not.
with open(f'{DATASET_DIR}{DATA_FILE}', 'r') as data_file:
    text_rows = data_file.read().splitlines()
# Blank lines (e.g. trailing ones) are skipped instead of crashing int('').
dataset = [[int(num) for num in row.split(',')] for row in text_rows if row.strip()]
tensor_dataset = torch.tensor(dataset)
train_x = tensor_dataset[:, :-1]
train_y = tensor_dataset[:, -1]

# Compare against None rather than truthiness so that SEED = 0 is honoured
# as a real seed instead of silently falling through to a random one.
if SEED is not None:
    seed = SEED
else:
    seed = int.from_bytes(os.urandom(8), byteorder="big", signed=False)
    print(seed)  # shown so the run can be reproduced later
random.seed(seed)
torch.manual_seed(seed)

# Second constructor argument is 5; tm.l1.W displayed further down has 5 rows.
# NOTE(review): presumably the number of first-layer outputs - confirm in tsetlin.TsetlinMachine.
tm = TsetlinMachine(train_x.shape[1], 5)
out_1 = tm.forward(train_x)
out_1, train_y

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


(tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([1, 1, 1, 1, 0, 0, 0, 0]))

In [620]:
# Random 0/1 target with exactly the shape of the first layer's output,
# displayed next to that output and the layer's weight matrix.
Y = torch.randint(low=0, high=2, size=tm.l1.out.shape)
tm.l1.out, Y, tm.l1.W

(tensor([[0, 0, 0, 0, 0],
         [0, 1, 0, 0, 0],
         [0, 0, 1, 0, 0],
         [1, 0, 0, 1, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 1, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 1, 1]]),
 tensor([[0, 0, 1, 1, 0],
         [0, 1, 1, 0, 1],
         [1, 1, 0, 0, 1],
         [0, 0, 1, 0, 1],
         [0, 0, 1, 1, 0],
         [0, 0, 0, 0, 0],
         [1, 1, 0, 0, 0],
         [0, 1, 1, 1, 1]]),
 tensor([[1, 1, 1, 0, 0, 0],
         [1, 0, 1, 0, 1, 0],
         [0, 1, 0, 0, 0, 1],
         [0, 1, 1, 0, 0, 0],
         [0, 1, 1, 1, 0, 0]]))

In [621]:
from tabulate import tabulate

# For every column of W, collect (across all samples) the W rows involved in a
# clash: a column c where one active row has a 1 in the first half of W and an
# active row has a 1 at the same position in the second half.
# "Active" rows for a sample are those where Y is 1 for that sample.
W_deps = [[] for _ in range(tm.l1.W.shape[1])]  # these are the necessary ones

# W does not change inside the loop, so split it into its two halves once.
w_parts = torch.split(tm.l1.W, tm.l1.in_dim, dim=1)
first_half_W, second_half_W = w_parts[0], w_parts[1]

for sample in range(len(tm.l1.full_X)):
    active_rows = torch.nonzero(Y[sample] == 1).squeeze(1).tolist()
    for pos_row in active_rows:
        for neg_row in active_rows:
            clash = (first_half_W[pos_row] == 1) & (second_half_W[neg_row] == 1)
            for column in clash.nonzero(as_tuple=True)[0].tolist():
                # Keep first-seen order and avoid duplicates.
                if pos_row not in W_deps[column]:
                    W_deps[column].append(pos_row)
                mirrored = column + tm.l1.in_dim
                if neg_row not in W_deps[mirrored]:
                    W_deps[mirrored].append(neg_row)

W_deps

[[1, 0], [2, 4, 0, 3], [3, 1, 4], [4], [1], [2]]

In [622]:
from tabulate import tabulate

# Same clash search as for W_deps, but kept separately per sample so that it
# can be printed as a table: one table row per sample, one cell per W column.
W_deps_tabular = []  # these are the necessary ones

# W is constant here, so its two halves are extracted a single time.
w_parts = torch.split(tm.l1.W, tm.l1.in_dim, dim=1)
first_half_W, second_half_W = w_parts[0], w_parts[1]

for sample in range(len(tm.l1.full_X)):
    flip_dep_row = [[] for _ in range(tm.l1.in_dim * 2)]
    active_rows = torch.nonzero(Y[sample] == 1).squeeze(1).tolist()
    for pos_row in active_rows:
        for neg_row in active_rows:
            clash = (first_half_W[pos_row] == 1) & (second_half_W[neg_row] == 1)
            for column in clash.nonzero(as_tuple=True)[0].tolist():
                if pos_row not in flip_dep_row[column]:
                    flip_dep_row[column].append(pos_row)
                mirrored = column + tm.l1.in_dim
                if neg_row not in flip_dep_row[mirrored]:
                    flip_dep_row[mirrored].append(neg_row)
    W_deps_tabular.append(flip_dep_row)


headers = [f"Column {i+1}" for i in range(len(W_deps_tabular[0]))]
print(tabulate(W_deps_tabular, headers=headers, tablefmt="fancy_grid"))

╒════════════╤════════════╤════════════╤════════════╤════════════╤════════════╕
│ Column 1   │ Column 2   │ Column 3   │ Column 4   │ Column 5   │ Column 6   │
╞════════════╪════════════╪════════════╪════════════╪════════════╪════════════╡
│ []         │ []         │ [3]        │ []         │ []         │ [2]        │
├────────────┼────────────┼────────────┼────────────┼────────────┼────────────┤
│ [1]        │ [2, 4]     │ [1, 4]     │ [4]        │ [1]        │ [2]        │
├────────────┼────────────┼────────────┼────────────┼────────────┼────────────┤
│ [0, 1]     │ [0, 4]     │ []         │ [4]        │ [1]        │ []         │
├────────────┼────────────┼────────────┼────────────┼────────────┼────────────┤
│ []         │ []         │ [4]        │ []         │ []         │ [2]        │
├────────────┼────────────┼────────────┼────────────┼────────────┼────────────┤
│ []         │ []         │ [3]        │ []         │ []         │ [2]        │
├────────────┼────────────┼────────────┼

In [623]:
# from tabulate import tabulate

# flip_deps_necessary = []
# for i, single_x in enumerate(tm.l1.full_X):
#     flip_dep_row = [[] for _ in range(tm.l1.in_dim * 2)]
#     single_Y = Y[i]
#     one_Y_idxs = torch.nonzero(single_Y == 1).squeeze(1)
#     one_X_idxs = torch.nonzero(single_x == 1).squeeze(1)
#     W_halves = torch.split(tm.l1.W, tm.l1.in_dim, dim=1)
#     pos_W = W_halves[0]
#     neg_W = W_halves[1]
#     for pos_one_Y_idx in one_Y_idxs:
#         w_1 = pos_W[pos_one_Y_idx]
#         for neg_one_Y_idx in one_Y_idxs:
#             w_2 = neg_W[neg_one_Y_idx]
#             deps = ((w_1 == w_2) & (w_1 == 1)).nonzero(as_tuple=True)[0]
#             neg_deps = torch.clone(deps)
#             neg_deps = deps + tm.l1.in_dim
#             full_deps = torch.cat((deps, neg_deps))
#             mask = torch.isin(full_deps, one_X_idxs)
#             full_deps = full_deps[mask]
#             for idx in full_deps:
#                 pos_idx = None
#                 neg_idx = None
#                 if idx < tm.l1.in_dim:
#                     pos_idx = idx
#                     neg_idx = idx + tm.l1.in_dim
#                 else:
#                     pos_idx = idx - tm.l1.in_dim
#                     neg_idx = idx
#                 if pos_one_Y_idx.item() not in flip_dep_row[pos_idx]:
#                     flip_dep_row[pos_idx].append(pos_one_Y_idx.item())
#                 if neg_one_Y_idx.item() not in flip_dep_row[neg_idx]:
#                     flip_dep_row[neg_idx].append(neg_one_Y_idx.item())
    
#     flip_deps_necessary.append(flip_dep_row)


# headers = [f"Column {i+1}" for i in range(len(flip_deps_necessary[0]))]
# print(tabulate(flip_deps_necessary, headers=headers, tablefmt="fancy_grid"))

In [624]:
# Per sample and per W column: every active W row (Y == 1 for that sample)
# that has a 1 in that column, with no clash condition applied.
#
# The earlier version iterated over all (row, row) pairs of active rows, but
# the first-half and second-half contributions never depend on each other, so
# a single pass over the active rows yields the same lists in the same order.
W_deps_full = []

# W is not modified here; split it into its two halves once.
w_parts = torch.split(tm.l1.W, tm.l1.in_dim, dim=1)
pos_W, neg_W = w_parts[0], w_parts[1]

for i in range(len(tm.l1.full_X)):
    flip_dep_row = [[] for _ in range(tm.l1.in_dim * 2)]
    for w_row in torch.nonzero(Y[i] == 1).squeeze(1).tolist():
        # Each (row, column) pair is visited exactly once, so no duplicate check is needed.
        for column in torch.nonzero(pos_W[w_row] == 1).squeeze(1).tolist():
            flip_dep_row[column].append(w_row)
        for column in torch.nonzero(neg_W[w_row] == 1).squeeze(1).tolist():
            flip_dep_row[column + tm.l1.in_dim].append(w_row)
    W_deps_full.append(flip_dep_row)
W_deps_full

[[[], [2, 3], [3], [], [], [2]],
 [[1], [2, 4], [1, 4], [4], [1], [2]],
 [[0, 1], [0, 4], [0, 1, 4], [4], [1], []],
 [[], [2, 4], [4], [4], [], [2]],
 [[], [2, 3], [3], [], [], [2]],
 [[], [], [], [], [], []],
 [[0, 1], [0], [0, 1], [], [1], []],
 [[1], [2, 3, 4], [1, 3, 4], [4], [1], [2]]]

In [625]:
# Show the first layer's input matrix, one row per sample. In the output
# below, the last three columns are the element-wise complement of the first three.
tm.l1.full_X

tensor([[1, 0, 0, 0, 1, 1],
        [1, 0, 1, 0, 1, 0],
        [1, 1, 0, 0, 0, 1],
        [1, 1, 1, 0, 0, 0],
        [0, 0, 0, 1, 1, 1],
        [0, 1, 0, 1, 0, 1],
        [0, 0, 1, 1, 1, 0],
        [0, 1, 1, 1, 0, 0]])

In [626]:
from tabulate import tabulate

# flip_deps[r][c] lists the samples whose bit in column c of full_X disagrees
# with the target Y for W row r, i.e. the bit is 1 where Y is not 1, or the
# bit is not 1 where Y is 1.
flip_deps = []

# The comparison does not depend on the W row, so evaluate it once and work
# on plain Python booleans afterwards.
x_bits = (tm.l1.full_X == 1).tolist()

for i in range(len(tm.l1.W)):
    wanted_samples = set(torch.nonzero(Y[:, i] == 1).squeeze(1).tolist())
    must_be_flipped_row = [[] for _ in range(tm.l1.in_dim * 2)]

    for row, bits in enumerate(x_bits):
        row_is_wanted = row in wanted_samples
        for col, bit in enumerate(bits):
            # Exclusive-or of "bit is set" and "target is set".
            if bit != row_is_wanted:
                must_be_flipped_row[col].append(row)

    flip_deps.append(must_be_flipped_row)

headers = [f"Column {i+1}" for i in range(len(flip_deps[0]))]
print(tabulate(flip_deps, headers=headers, tablefmt="fancy_grid"))

╒═════════════════╤═════════════════╤═════════════════╤════════════════════╤════════════════════╤════════════════════╕
│ Column 1        │ Column 2        │ Column 3        │ Column 4           │ Column 5           │ Column 6           │
╞═════════════════╪═════════════════╪═════════════════╪════════════════════╪════════════════════╪════════════════════╡
│ [0, 1, 3, 6]    │ [3, 5, 6, 7]    │ [1, 2, 3, 7]    │ [2, 4, 5, 7]       │ [0, 1, 2, 4]       │ [0, 4, 5, 6]       │
├─────────────────┼─────────────────┼─────────────────┼────────────────────┼────────────────────┼────────────────────┤
│ [0, 3, 6, 7]    │ [1, 3, 5, 6]    │ [2, 3]          │ [1, 2, 4, 5]       │ [0, 2, 4, 7]       │ [0, 1, 4, 5, 6, 7] │
├─────────────────┼─────────────────┼─────────────────┼────────────────────┼────────────────────┼────────────────────┤
│ [2, 4, 7]       │ [0, 1, 2, 4, 5] │ [0, 4, 6]       │ [0, 1, 3, 5, 6]    │ [3, 6, 7]          │ [1, 2, 3, 5, 7]    │
├─────────────────┼─────────────────┼───────────

In [627]:
# Re-display W_deps (one list of W row indices per W column) for reference
# next to the flip_deps table before the two are combined.
W_deps

[[1, 0], [2, 4, 0, 3], [3, 1, 4], [4], [1], [2]]

In [628]:
# For every W column, gather the distinct samples that would have to be
# flipped for any of the W rows listed for that column in W_deps.
flip_sets = [set() for _ in range(tm.l1.in_dim * 2)]
column_row_pairs = (
    (column, w_row)
    for column, w_rows in enumerate(W_deps)
    for w_row in w_rows
)
for column, w_row in column_row_pairs:
    flip_sets[column].update(flip_deps[w_row][column])

# Convert the sets to plain lists for display.
flip_count = list(map(list, flip_sets))
flip_count

[[0, 1, 3, 6, 7],
 [0, 1, 2, 3, 4, 5, 6, 7],
 [0, 1, 2, 3, 4, 6],
 [1, 2, 3, 4, 5, 6],
 [0, 2, 4, 7],
 [1, 2, 3, 5, 7]]

In [629]:
# Re-seed both generators immediately before func() is defined and called, so
# that the random.choice draws inside func do not depend on how much
# randomness earlier cells consumed.
SEED = 9496461801973866405  # same value as the SEED assigned in the first cell
random.seed(SEED)
torch.manual_seed(SEED)

def func(W_deps, W=None):
    """Resolve the column clashes listed in ``W_deps`` on a copy of ``W``.

    Args:
        W_deps: list with one entry per column of ``W``. The first half of the
            list is paired index-wise with the second half (column ``i`` with
            column ``i + len(W_deps) // 2``). Each entry is a list of row
            indices of ``W``.
        W: weight matrix to start from. Defaults to the notebook-global
            ``tm.l1.W``, which is what the function always used before this
            parameter existed. The input tensor is never modified.

    Returns:
        A tuple ``(selected_W_deps_indices, chosen_idxs, new_W)``:
        * ``selected_W_deps_indices`` - for every pair whose first-half entry
          is non-empty, the column of the pair with fewer listed rows (a
          random one of the two on a tie).
        * ``chosen_idxs`` - one 0-dim tensor per row of ``new_W`` holding a
          randomly picked column: a column that is 1 if the row has any,
          otherwise a column that is 0 and was not cleared by this call.
        * ``new_W`` - the modified copy of ``W``.

    Raises:
        IndexError: if a row ends up with no eligible column to pick.
    """
    if W is None:
        W = tm.l1.W
    new_W = torch.clone(W)
    num_columns = len(W_deps)
    half = num_columns // 2

    selected_W_deps_indices = []
    for i in range(half):
        if len(W_deps[i]) > 0:
            # The draw happens even when it is overridden below, so that the
            # sequence of random numbers matches earlier seeded runs.
            selected_index = random.choice([i, i + half])
            if len(W_deps[i]) > len(W_deps[i + half]):
                selected_index = i + half
            elif len(W_deps[i]) < len(W_deps[i + half]):
                selected_index = i
            selected_W_deps_indices.append(selected_index)

    cannot_be_changed = [[] for _ in range(new_W.shape[0])]
    for index in selected_W_deps_indices:
        # Paired column in the other half. The modulus must be the full
        # column count: with ``half`` as modulus a first-half index mapped
        # back onto itself and the bit cleared below was immediately set again.
        complement = (index + half) % num_columns
        for row_i in W_deps[index]:
            should_remove = random.choice([True, False])
            new_W[row_i, index] = 0
            if not should_remove:
                # Move the bit to the paired column instead of dropping it.
                new_W[row_i, complement] = 1
            cannot_be_changed[row_i].append(index)

    chosen_idxs = []
    for i, row in enumerate(new_W):
        candidates = torch.nonzero(row == 1, as_tuple=True)[0]
        if len(candidates) == 0:
            zero_idxs = torch.nonzero(row == 0, as_tuple=True)[0]
            # Explicit integer dtype: an empty Python list would otherwise
            # become a float tensor.
            locked = torch.tensor(cannot_be_changed[i], dtype=torch.long, device=row.device)
            candidates = zero_idxs[~torch.isin(zero_idxs, locked)]
        if len(candidates) == 0:
            raise IndexError(f"row {i} of new_W has no column left to choose from")
        # Index explicitly (one random draw per row, as before) rather than
        # handing a tensor to random.choice; the result stays a 0-dim tensor.
        chosen_idxs.append(candidates[random.randrange(len(candidates))])

    return selected_W_deps_indices, chosen_idxs, new_W
            

# Run the selection once. func's second return value (its chosen_idxs: one
# chosen column per row of new_W) is bound to selected_flip_deps_indices here.
selected_W_deps_indices, selected_flip_deps_indices, new_W = func(W_deps)
selected_W_deps_indices, selected_flip_deps_indices, new_W

([3, 4, 5],
 [tensor(0), tensor(2), tensor(1), tensor(2), tensor(0)],
 tensor([[1, 1, 1, 0, 0, 0],
         [1, 0, 1, 0, 0, 0],
         [0, 1, 0, 0, 0, 0],
         [0, 1, 1, 0, 0, 0],
         [1, 1, 1, 0, 0, 0]]))

In [632]:
# Build a 0/1 mask of the input bits to invert: for every W row, take its
# chosen column and mark, for each sample listed in flip_deps, both that
# column and its partner column in the other half of full_X.
mask_X = torch.zeros_like(tm.l1.full_X)
total_columns = tm.l1.in_dim * 2
for w_row, chosen in enumerate(selected_flip_deps_indices):
    column = chosen.item()
    partner = (column + tm.l1.in_dim) % total_columns
    for sample in flip_deps[w_row][column]:
        mask_X[sample, column] = 1
        mask_X[sample, partner] = 1

# Invert exactly the masked entries; all other entries are copied unchanged.
new_full_X = torch.where(mask_X.bool(), 1 - tm.l1.full_X, tm.l1.full_X)
new_full_X

tensor([[0, 1, 1, 1, 0, 0],
        [0, 1, 0, 1, 0, 1],
        [1, 0, 1, 0, 1, 0],
        [0, 1, 0, 1, 0, 1],
        [0, 1, 1, 1, 0, 0],
        [0, 0, 0, 1, 1, 1],
        [1, 0, 0, 0, 1, 1],
        [1, 1, 1, 0, 0, 0]])

In [633]:
from tsetlin import TsetlinBase

# Evaluate the rewritten weights on the flipped inputs.
# NOTE(review): conjunction_mul is defined in the tsetlin module and is not
# visible here; in the output below there is one row per row of new_full_X
# and one column per row of new_W.
tb = TsetlinBase()
tb.conjunction_mul(new_full_X.unsqueeze(1), new_W)

tensor([[0, 0, 1, 1, 0],
        [0, 0, 1, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 1, 1, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1]])