In [40]:
import torch
import random

class TsetlinBase:
    """Shared clause-evaluation logic for Tsetlin-style layers."""

    def conjunctin_mul(self, X, W):
        """Evaluate one AND-clause per row of ``W`` over the literals of ``X``.

        The literal vector is ``[X, 1 - X]``; a literal takes part in clause
        ``r`` when ``W[r, k] > 0``. Intermediate tensors are stored on ``self``
        because the ``update`` logic of subclasses reads them later.

        Args:
            X: input values, presumably a 1-D 0/1 tensor of length ``n`` (TODO confirm).
            W: weight matrix with ``2 * n`` columns, one row per clause.

        Returns:
            Tensor of shape ``(1, W.shape[0])`` holding each clause's product.
        """
        self.X = X
        row = X.unsqueeze(0)
        self.twod_X = row
        self.X_neg = 1 - row
        # Literals: original features followed by their negations.
        self.full_X = torch.cat((row, self.X_neg), dim=1)
        clause_count = W.shape[0]
        self.matrix_X = self.full_X.repeat(clause_count, 1)
        # TODO: prob need to compare and choose the clause with the highest weight
        self.mask = W > 0
        # Excluded literals become 1 so they leave the product unchanged.
        self.masked_X = torch.where(self.mask, self.matrix_X, torch.tensor(1))
        clause_values = torch.prod(self.masked_X, dim=1, keepdim=True)
        return clause_values.view(1, -1)

class TsetlinLayer(TsetlinBase):
    """A layer of conjunctive clauses over an input and its negation.

    ``self.w`` has shape ``(out_dim, 2 * in_dim)``. The first ``in_dim`` columns
    weight the raw inputs and the last ``in_dim`` columns weight their negations
    (matching the literal order built in ``conjunctin_mul``). A positive weight
    includes the literal in the clause. Larger values are treated as higher
    confidence: they are incremented on a correct prediction, and the smallest
    is chosen first when a clause must change.
    """

    def __init__(self, in_dim, out_dim):
        # Each clause starts with exactly one of (x_k, not x_k) included per input.
        w_pos = torch.randint(0, 2, (out_dim, in_dim,))
        w_neg = 1 - w_pos
        self.w = torch.cat((w_pos, w_neg), dim=1)
        self.out = None

    def forward(self, X):
        """Return the ``(out_dim,)`` clause outputs for a 1-D input ``X``."""
        out = self.conjunctin_mul(X, self.w)
        self.out = out.squeeze(0)
        return self.out

    @staticmethod
    def helper(update_index, i, X, new_expected_values, w, flip_value):
        """Change clause ``i`` so literal ``update_index`` no longer conflicts.

        Args:
            update_index: literal column in ``[0, 2 * n)``.
            i: clause (row) index into ``w``.
            X: the layer input (1-D, or any shape holding ``n`` values).
            new_expected_values: 1-D tensor of expected input values, modified in place.
            w: weight matrix, modified in place.
            flip_value: if True, ask the previous layer to flip the input and keep
                the literal; otherwise exclude the literal and include its negation.
        """
        # Static so that ``self.helper(...)`` with six arguments binds correctly.
        X = X.reshape(-1)
        n_inputs = X.shape[0]
        # Map a negated-literal column back to the input position it refers to.
        pos_index = update_index - n_inputs if update_index >= n_inputs else update_index
        negation_index = (update_index + n_inputs) % w.shape[1]
        if flip_value:
            new_expected_values[pos_index] = 1 - X[pos_index]
            # TODO: should i set the weight back to 0, as to descrease the confidence of the new flipped clause?
            w[i, update_index] = 1
            w[i, negation_index] = 0
        else:
            w[i, update_index] = 0
            w[i, negation_index] = 1

    def update(self, y, is_first_layer=False):
        """Adjust weights toward the target clause outputs ``y``.

        Must be called after ``forward``: it reads ``self.X``, ``self.full_X``
        and ``self.out``.

        Args:
            y: tensor of target outputs, one per clause (same length as ``self.out``).
            is_first_layer: when True, inputs are never flipped, since there is
                no previous layer to push new targets to.

        Returns:
            1-D tensor of expected input values. It is a copy of ``self.X``
            where flipped inputs are negated. Presumably it serves as the
            target for the previous layer's ``update``.
        """
        flip_value = random.choice([True, False]) and not (is_first_layer or torch.equal(y, self.out))
        if flip_value:
            # flatten() keeps a 1-D index even when exactly one target is 1.
            one_indexes = torch.nonzero(y == 1).flatten()
            pos_w, neg_w = torch.split(self.w[one_indexes], self.w.size(1) // 2, dim=1)
            for w_1 in pos_w:
                if any(torch.equal(w_1, w_2) for w_2 in neg_w):
                    flip_value = False
                    break

        expected_sub_values = self.X.clone()

        if torch.equal(y, self.out):
            # Correct prediction: reinforce every included literal.
            self.w[self.w > 0] += 1
            return expected_sub_values

        literals = self.full_X.reshape(-1)  # values of [x, not x], length 2 * n
        n_literals = self.w.shape[1]
        for i, expected_output in enumerate(y):
            target = expected_output.item()
            if self.out[i].item() == target:
                continue  # clause already produces the expected value
            if target == 1:
                # Clause must become 1: fix every included literal that is 0.
                for j in range(n_literals):
                    if self.w[i, j].item() > 0 and literals[j].item() != target:
                        self.helper(j, i, self.X, expected_sub_values, self.w, flip_value)
            else:
                # Clause must become 0: change one included literal that is 1,
                # chosen at random among those with the lowest confidence.
                target_indexes = []
                min_confidence = 0
                for j in range(n_literals):
                    w_value = self.w[i, j].item()
                    if w_value > 0 and literals[j].item() != target:
                        if not target_indexes or w_value < min_confidence:
                            target_indexes = [j]
                            min_confidence = w_value
                        elif w_value == min_confidence:
                            target_indexes.append(j)
                # Empty when the clause includes no literal (e.g. all weights <= 0).
                if target_indexes:
                    update_index = random.choice(target_indexes)
                    self.helper(update_index, i, self.X, expected_sub_values, self.w, flip_value)

        return expected_sub_values

class TsetlinMachine:
    """Two stacked Tsetlin layers: ``in_dim`` inputs -> 10 clauses -> 1 clause."""

    def __init__(self, in_dim):
        self.l1 = TsetlinLayer(in_dim, 10)
        self.l2 = TsetlinLayer(10, 1)
        self.out = None

    def forward(self, X):
        """Run ``X`` through both layers and cache the squeezed final output."""
        hidden = self.l1.forward(X)
        final = self.l2.forward(hidden)
        self.out = final.squeeze(0)
        return self.out

    def update(self, y):
        """Training step placeholder.

        NOTE(review): only the early exit is implemented; no weights are
        changed on either path yet.
        """
        if y == self.out:
            return
        





    

In [45]:
# Build a 2-input machine and evaluate one sample (weights are random).
tm = TsetlinMachine(in_dim=2)
tm.forward(torch.tensor([0, 1]))

tensor(0)

In [18]:
class TsetlinBase:
    """Debug version of the clause evaluator that prints its intermediate literals."""

    def conjunctin_mul(self, X, W):
        """Evaluate one AND-clause per row of ``W`` over ``[X, 1 - X]``.

        Prints the input row, its negation and the combined literal vector.

        Args:
            X: input values, presumably a 1-D 0/1 tensor of length ``n`` (TODO confirm).
            W: weight matrix with ``2 * n`` columns; ``W > 0`` includes a literal.

        Returns:
            Tensor of shape ``(1, W.shape[0])`` with each clause's product.
        """
        # Stored like the primary TsetlinBase does; TsetlinLayer.update reads self.X.
        self.X = X
        self.twod_X = X.unsqueeze(0)
        print(self.twod_X)
        self.X_neg = 1 - self.twod_X
        print(self.X_neg)
        self.full_X = torch.cat((self.twod_X, self.X_neg), dim=1)
        print(self.full_X)
        self.matrix_X = self.full_X.repeat(W.shape[0], 1)
        self.mask = W > 0
        # Excluded literals become 1 so they do not affect the product.
        self.masked_X = torch.where(self.mask, self.matrix_X, torch.tensor(1))
        return torch.prod(self.masked_X, dim=1, keepdim=True).view(1,-1)
    
# Example: input [1, 0] against three clauses (positive weight = literal included).
A = torch.tensor([1, 0])

B = torch.tensor(
    [
        [2, -2, 3, 1],
        [1, 2, -3, 4],
        [-1, 0, 1, 2],
    ]
)

TsetlinBase().conjunctin_mul(A, B).unsqueeze(0)

tensor([[1, 0]])
tensor([[0, 1]])
tensor([[1, 0, 0, 1]])


tensor([[[0, 0, 0]]])