In [60]:
import torch
import random

class TsetlinBase:
    def conjunctin_mul(self, X, W):
        matrix_X = X.repeat(W.shape[0], 1)
        mask = W > 0 # TODO: prob need to compare and choose the clause with the highest weight
        masked_X = torch.where(mask, matrix_X, torch.tensor(1))
        return torch.prod(masked_X, dim=1, keepdim=True).view(1,-1)

class TsetlinLayer(TsetlinBase):
    def __init__(self, in_dim, out_dim):
        self.in_dim = in_dim
        W_pos = torch.randint(0, 2, (out_dim, in_dim,))
        W_neg = torch.randint(0, 2, (out_dim, in_dim,))
        W_neg[W_pos == 1] = 0
        self.W = torch.cat((W_pos, W_neg), dim=1)
        self.out = None
        self.full_X = None

    def forward(self, X):
        X_neg = 1 - X
        self.full_X = torch.cat((X, X_neg), dim=0)
        out = self.conjunctin_mul(self.full_X.unsqueeze(0), self.W)
        self.out = out.squeeze(0)
        return self.out
    
    def helper(self, updated_X,update_index, update_W, can_flip_value, can_remove, can_add_value):
        # TODO: the random choice needs to be dynamic, otherwise if it is a very deep layer, it will be very hard to flip values in the earlier layers
        flip_value = random.choice([True, False]) and can_flip_value
        negation_index = (update_index + self.in_dim) % self.in_dim
        if flip_value:
            updated_X[update_index] = 1 - updated_X[update_index]
            updated_X[negation_index] = 1 - updated_X[negation_index]

            #TODO: should i set the weight back to 0, as to descrease the confidence of the new flipped clause?
            update_W[update_index] = 0
            update_W[negation_index] = 1
        else:
            addable_indices = [ i for i, (w, v) in enumerate(zip(update_W, updated_X)) if w == 0 and v == 0 and update_W[(i + self.in_dim) % self.in_dim] == 0 ] if can_add_value else []

            add = random.choice([True, False]) and len(addable_indices) > 0
            remove = random.choice([True, False]) and can_remove
            if remove:
                update_W[update_index] = 0
            elif add:
                add_index = random.choice(addable_indices)
                update_index[add_index] = 1
            else:
                update_W[update_index] = 0
                update_W[negation_index] = 1

    def update(self, Y, is_first_layer = False):
        can_flip_value = not (is_first_layer or torch.equal(Y, self.out))
        if can_flip_value:
            one_Y_indexes = torch.nonzero(Y == 1).squeeze(0)
            W_halves = torch.split(self.W[one_Y_indexes], self.in_dim, dim=1)
            pos_W = W_halves[0]
            neg_W = W_halves[1]
            for w_1 in pos_W:
                indices = torch.nonzero(w_1 == 1).squeeze(0)
                if any((w_1[indices] == w_2[indices]).any() for w_2 in neg_W):
                    can_flip_value = False
                    break

        updated_X = torch.clone(self.full_X)
        if torch.equal(Y, self.out):
            # TODO: should this be done at every prior layer or should it stop at this layer?
            self.W[self.W > 0] += 1
        else:
            one_Y_indexes = torch.nonzero((Y == 1) | (Y != self.out)).squeeze(0)
            update_Ws = self.W[one_Y_indexes]

            for update_W in update_Ws:
                update_indices = [ i for i, (w, v) in enumerate(zip(update_W, updated_X)) if w > 0 and v == 0]
                for update_index in update_indices:
                    self.helper(updated_X,update_index, update_W, can_flip_value, True, False)

            updated_out = self.conjunctin_mul(updated_X.unsqueeze(0), self.W) if not torch.equal(updated_X, self.full_X) else self.out
            zero_Y_indexes = torch.nonzero((Y == 0) | (Y != updated_out)).squeeze(0)
            update_Ws = self.W[zero_Y_indexes]
            for update_W in update_Ws:
                target_indexes = []
                min_confidence = 0
                for j in range(self.in_dim * 2):
                    W_value = update_W[j]
                    X_value = updated_X[j]
                    if W_value > 0 and X_value == 1:
                        if W_value < min_confidence or len(target_indexes) == 0:
                            target_indexes = [j]
                            min_confidence = W_value
                        else:
                            target_indexes.append(j)

                    update_index = random.choice(target_indexes)
                    self.helper(updated_X,update_index, update_W, can_flip_value, False, True)
        return updated_X

class TsetlinMachine:

    def __init__(self, in_dim):
        self.l1 = TsetlinLayer(in_dim, 10)
        self.l2 = TsetlinLayer(10, 1)
        self.out = None

    def forward(self, X):
        X = self.l1.forward(X)
        X = self.l2.forward(X)
        self.out = X.squeeze(0)
        return self.out
    
    def update(self, y):
        y = torch.tensor([y])
        updated_X = self.l2.update(y)
        self.l1.update(updated_X, True)

In [61]:
tm = TsetlinMachine(2)
tm.forward(torch.tensor([0,1]))

tensor(0)

In [62]:
tm.update(1)

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)