In [40]:
import torch
import random

class TsetlinBase:
    """Clause evaluation shared by Tsetlin-style layers.

    A clause is the product (logical AND for 0/1 values) of the literals it
    includes. The literal row is the input followed by its negation,
    ``[X, 1 - X]``, and a literal is included in a clause when the
    corresponding weight in ``W`` is strictly positive.
    """

    @staticmethod
    def _clause_intermediates(X, W):
        """Compute every intermediate tensor of the clause evaluation.

        Args:
            X: 1-D tensor of input values (length ``n``).
            W: 2-D weight tensor with one row per clause and ``2 * n``
                columns (positive literals first, then negated literals).

        Returns:
            Tuple ``(twod_X, X_neg, full_X, matrix_X, mask, masked_X)``.
        """
        twod_X = X.unsqueeze(0)
        X_neg = 1 - twod_X
        full_X = torch.cat((twod_X, X_neg), dim=1)
        # One copy of the literal row per clause (row of W).
        matrix_X = full_X.repeat(W.shape[0], 1)
        # TODO: prob need to compare and choose the clause with the highest weight
        mask = W > 0
        # Excluded literals become 1, the neutral element of the product.
        # ones_like keeps the filler on the same dtype/device as the input
        # instead of a fresh CPU int64 scalar.
        masked_X = torch.where(mask, matrix_X, torch.ones_like(matrix_X))
        return twod_X, X_neg, full_X, matrix_X, mask, masked_X

    def conjunctin_mul(self, X, W):
        """Evaluate all clauses and keep the intermediates on ``self``.

        Stores ``X``, ``twod_X``, ``X_neg``, ``full_X``, ``matrix_X``,
        ``mask`` and ``masked_X`` as attributes so later steps can inspect
        the last evaluated input.

        Args:
            X: 1-D tensor of input values.
            W: 2-D weight tensor, one row per clause.

        Returns:
            Tensor of shape ``(1, W.shape[0])`` holding each clause's product.
        """
        self.X = X
        (
            self.twod_X,
            self.X_neg,
            self.full_X,
            self.matrix_X,
            self.mask,
            self.masked_X,
        ) = self._clause_intermediates(X, W)
        return torch.prod(self.masked_X, dim=1, keepdim=True).view(1, -1)

    def no_saving_conjunctin_mul(self, X, W):
        """Evaluate all clauses without touching any attribute on ``self``.

        Args:
            X: 1-D tensor of input values.
            W: 2-D weight tensor, one row per clause.

        Returns:
            Tensor of shape ``(1, W.shape[0])`` holding each clause's product.
        """
        masked_X = self._clause_intermediates(X, W)[-1]
        return torch.prod(masked_X, dim=1, keepdim=True).view(1, -1)

class TsetlinLayer(TsetlinBase):
    """A layer of conjunctive clauses over an input vector and its negation.

    ``self.W`` has shape ``(out_dim, 2 * in_dim)``. Columns ``[0, in_dim)``
    weight the positive literals and columns ``[in_dim, 2 * in_dim)`` weight
    the negated literals. A literal takes part in a clause when its weight is
    strictly positive.
    """

    def __init__(self, in_dim, out_dim):
        """Randomly include each input as either a positive or negated literal.

        Args:
            in_dim: Number of input values.
            out_dim: Number of clauses (outputs).
        """
        W_pos = torch.randint(0, 2, (out_dim, in_dim))
        W_neg = 1 - W_pos
        self.W = torch.cat((W_pos, W_neg), dim=1)
        self.out = None

    def forward(self, X):
        """Evaluate every clause on ``X`` and cache the result in ``self.out``.

        Args:
            X: 1-D tensor of input values (length ``in_dim``).

        Returns:
            1-D tensor of length ``out_dim`` with each clause's output.
        """
        out = self.conjunctin_mul(X, self.W)
        self.out = out.squeeze(0)
        return self.out

    @staticmethod
    def helper(expected_sub_values, update_index, X, update_W, can_flip_value, can_remove):
        """Apply one randomly chosen correction to a single literal, in place.

        Declared as a static method so that the existing
        ``self.helper(values, index, X, row, ...)`` call sites bind their six
        arguments correctly.

        Args:
            expected_sub_values: 1-D tensor of length ``2 * n`` with the
                literal values this layer would like to receive; mutated
                when a value flip is chosen.
            update_index: Column of the literal to correct.
            X: The layer input; only its last dimension (``n``) is used.
            update_W: One row of the weight matrix (a view, so edits reach
                the layer's weights); mutated in place.
            can_flip_value: Whether flipping the expected input is allowed.
            can_remove: Whether dropping the literal from the clause is
                allowed.
        """
        # TODO: the random choice needs to be dynamic, otherwise if it is a very deep layer, it will be very hard to flip values in the earlier layers
        flip_value = random.choice([True, False]) and can_flip_value
        n_features = X.shape[-1]
        # A literal and its negation sit n_features columns apart, wrapping
        # around the 2 * n_features literal columns.
        negation_index = (update_index + n_features) % (2 * n_features)
        if flip_value:
            expected_sub_values[update_index] = 1 - expected_sub_values[update_index]
            expected_sub_values[negation_index] = 1 - expected_sub_values[negation_index]

            # TODO: should i set the weight back to 0, as to decrease the confidence of the new flipped clause?
            # NOTE(review): flipping the value AND swapping the literal for
            # its negation may cancel each other out - confirm the intent.
            update_W[update_index] = 0
            update_W[negation_index] = 1
        else:
            remove = random.choice([True, False]) and can_remove
            if remove:
                update_W[update_index] = 0
            else:
                # TODO: add "add new clause"
                update_W[update_index] = 0
                update_W[negation_index] = 1

    def update(self, y, is_first_layer=False):
        """Adjust weights toward the expected output ``y``.

        Must be called after ``forward``, because it reads the cached input
        and output of the last forward pass.

        Args:
            y: 1-D tensor of expected clause outputs (length ``out_dim``).
            is_first_layer: When true, the layer's input is real data and
                may not be "flipped".

        Returns:
            1-D tensor of length ``2 * in_dim`` with the literal values this
            layer would like to have received (positive half first).

        Raises:
            RuntimeError: If ``forward`` has not been called yet.
        """
        if self.out is None:
            raise RuntimeError("forward() must be called before update()")

        matches = torch.equal(y, self.out)
        n_features = self.X.shape[-1]

        can_flip_value = not (is_first_layer or matches)
        if can_flip_value:
            # flatten() (not squeeze()) keeps a 1-D index even when exactly
            # one row is selected, so the row dimension is never lost.
            rows_expected_on = torch.nonzero(y == 1).flatten()
            pos_w, neg_w = torch.split(self.W[rows_expected_on], n_features, dim=1)
            # If the clauses that must fire require both x_i and not-x_i,
            # no single input can satisfy them all, so flipping is pointless.
            # "Active" is weight > 0, matching the mask used in evaluation.
            needs_positive = (pos_w > 0).any(dim=0)
            needs_negative = (neg_w > 0).any(dim=0)
            if bool((needs_positive & needs_negative).any()):
                can_flip_value = False

        # full_X is stored as (1, 2n); work on an independent 1-D copy.
        expected_sub_values = self.full_X.clone().flatten()

        if matches:
            # Correct prediction: reinforce every literal currently in use.
            self.W[self.W > 0] += 1
            return expected_sub_values

        if not can_flip_value:
            # TODO: handle mismatches when the input may not be flipped.
            return expected_sub_values

        # Pass 1: clauses that should fire - fix literals that block them.
        # NOTE(review): "|" also selects rows with y == 0 that mismatch;
        # "&" may have been intended - confirm.
        rows = torch.nonzero((y == 1) | (y != self.out)).flatten().tolist()
        for row in rows:
            # Integer indexing yields a view, so edits land in self.W.
            update_W = self.W[row]
            blocking = (update_W > 0) & (expected_sub_values == 0)
            for update_index in torch.nonzero(blocking).flatten().tolist():
                self.helper(expected_sub_values, update_index, self.X, update_W, can_flip_value, True)

        # Pass 2: clauses that should stay off - break their least confident
        # satisfied literal.
        new_out = self.no_saving_conjunctin_mul(
            expected_sub_values[:n_features], self.W
        ).squeeze(0)
        # NOTE(review): same "|" vs "&" question as in pass 1.
        rows = torch.nonzero((y == 0) | (y != new_out)).flatten().tolist()
        for row in rows:
            update_W = self.W[row]
            satisfied = (update_W > 0) & (expected_sub_values == 1)
            if not bool(satisfied.any()):
                # Nothing to break in this clause.
                continue
            min_confidence = update_W[satisfied].min()
            target_indexes = torch.nonzero(
                satisfied & (update_W == min_confidence)
            ).flatten().tolist()
            update_index = random.choice(target_indexes)
            self.helper(expected_sub_values, update_index, self.X, update_W, can_flip_value, False)

        return expected_sub_values

class TsetlinMachine:
    """Two stacked Tsetlin layers: ``in_dim`` inputs -> 10 clauses -> 1 clause."""

    def __init__(self, in_dim):
        """Build both layers; no prediction is cached until ``forward`` runs."""
        self.out = None
        self.l1 = TsetlinLayer(in_dim, 10)
        self.l2 = TsetlinLayer(10, 1)

    def forward(self, X):
        """Run ``X`` through both layers and cache the final output.

        Returns:
            The second layer's output with its leading dimension squeezed.
        """
        hidden = self.l1.forward(X)
        final = self.l2.forward(hidden)
        self.out = final.squeeze(0)
        return self.out

    def update(self, y):
        """Learning step (not implemented yet).

        Currently only checks whether the cached prediction already equals
        ``y``; nothing is changed in either case.
        """
        already_correct = y == self.out
        if already_correct:
            return None
        # TODO: propagate the correction through the layers.
        return None
        





    

In [45]:
# Smoke test: build a machine with two inputs and run one forward pass.
tm = TsetlinMachine(in_dim=2)
tm.forward(X=torch.tensor([0, 1]))

tensor(0)

In [18]:
class TsetlinBase:
    """Debug variant of the clause evaluation that prints its intermediates."""

    def conjunctin_mul(self, X, W):
        """Evaluate every clause (row of ``W``) on ``X``, printing as it goes.

        Prints the 2-D input, its negation and the concatenated literal row,
        and stores each intermediate tensor as an attribute on ``self``.

        Returns:
            Tensor of shape ``(1, W.shape[0])`` with the product of the
            literals whose weight is positive in each clause.
        """
        positive = X.unsqueeze(0)
        self.twod_X = positive
        print(positive)

        negated = 1 - positive
        self.X_neg = negated
        print(negated)

        literals = torch.cat((positive, negated), dim=1)
        self.full_X = literals
        print(literals)

        clause_count = W.shape[0]
        self.matrix_X = literals.repeat(clause_count, 1)
        self.mask = W > 0
        # Literals outside a clause are replaced by 1 so they do not affect
        # the product.
        self.masked_X = torch.where(self.mask, self.matrix_X, torch.tensor(1))
        per_clause = torch.prod(self.masked_X, dim=1, keepdim=True)
        return per_clause.view(1, -1)
    
# Worked example: one 2-value input evaluated against three clauses.
A = torch.tensor([1, 0])

# Each row is a clause; columns are [x0, x1, not-x0, not-x1]. Only strictly
# positive weights include a literal.
B = torch.tensor(
    [
        [2, -2, 3, 1],
        [1, 2, -3, 4],
        [-1, 0, 1, 2],
    ]
)

torch.unsqueeze(TsetlinBase().conjunctin_mul(A, B), 0)

tensor([[1, 0]])
tensor([[0, 1]])
tensor([[1, 0, 0, 1]])


tensor([[[0, 0, 0]]])