In [None]:
import itertools
import math

import torch

from shared.toy import *
from shared.plotting import *
from shared.features import *
from shared.tasks import *

# Kept after the shared star-imports so einops names (e.g. ``repeat``) win any name clash.
from einops import *

In [None]:
class Computation(ToyModel):
    """Toy model whose targets are a weighted mix of boolean gates over pairs of features.

    ``self.pairs`` enumerates every unordered pair of input features, and the config requires
    ``n_unembed`` to equal the number of such pairs. The gate weights are read from ``cfg.task``
    (keys ``bias``, ``and``, ``or``, ``xor``, ``nand``, ``nor``, ``xnor``; missing keys count as 0).
    """

    # Truth-table row of each supported gate over the four boolean inputs (x, y).
    # Every gate here is symmetric in (x, y), so the two single-input columns are interchangeable.
    # Insertion order fixes the summation order in ``binary_truth_table``.
    _GATE_TRUTH_TABLES = {
        "and": (0, 0, 0, 1),
        "or": (0, 1, 1, 1),
        "xor": (0, 1, 1, 0),
        "nand": (1, 1, 1, 0),
        "nor": (1, 0, 0, 0),
        "xnor": (1, 0, 0, 1),
    }

    def __init__(self, cfg):
        """Build the model and precompute the list of feature pairs.

        Raises:
            ValueError: if ``cfg.n_unembed`` is not ``comb(cfg.n_features, 2)``.
        """
        super().__init__(cfg)

        # One unembedding slot per unordered pair of features. Raise rather than assert so the
        # check survives ``python -O``.
        n_unembed = math.comb(cfg.n_features, 2)
        if cfg.n_unembed != n_unembed:
            raise ValueError(
                f"The unembed dimension must be the number of boolean combinations of the features. Got {cfg.n_unembed} but should be {n_unembed} instead."
            )
        # All pairs (i, j) with i < j, in itertools.combinations order; reused by weights_to_formula.
        self.pairs = list(itertools.combinations(range(self.cfg.n_features), 2))

    def generate_batch(self):
        """Sample an input batch via ``generate_binary`` using ``self.probability``."""
        return generate_binary(self.cfg, self.probability)

    def compute(self, x):
        """Compute the boolean-composition targets for batch ``x`` (delegates to ``compute_boolean_composition``)."""
        return compute_boolean_composition(x, self.cfg)

    def binary_truth_table(self):
        """Return the intended truth table implied by ``cfg.task``.

        The table is ``bias + sum(weight_g * table_g)`` over the gates in ``_GATE_TRUTH_TABLES``,
        broadcast to shape ``(n_instances, n_outputs, 4)`` so it can be compared elementwise with
        :meth:`weights_to_formula`.
        """
        task = self.cfg.task
        accum = torch.ones(4) * task.get("bias", 0)
        for gate, table in self._GATE_TRUTH_TABLES.items():
            accum += torch.tensor(table) * task.get(gate, 0)

        # Named axes with explicit sizes instead of numbers spliced into the pattern string.
        return repeat(accum, "x -> i o x", i=self.cfg.n_instances, o=self.cfg.n_outputs)

    def weights_to_formula(self):
        """Evaluate the learned quadratic form on all four boolean inputs of every feature pair.

        Output slot ``f`` of ``self.ube`` is read as the gate for ``self.pairs[f]``. Index ``-1`` of the
        last two axes is treated as a constant-1 bias input. For inputs (x, y) the value is
        ``v^T W v`` with ``v`` holding x, y and the bias 1.

        NOTE(review): the ``2 *`` cross terms assume ``ube`` is symmetric in its last two axes, and
        that output slots line up with ``self.pairs`` (n_outputs == number of pairs). TODO confirm.

        Returns:
            Tensor of shape ``(ube.shape[0], n_pairs, 4)``. The last axis is ordered
            (x, y) = (0, 0), (1, 0), (0, 1), (1, 1).
        """
        w = self.ube
        # reshape keeps a well-formed (0, 2) tensor even when there are no pairs.
        pairs = torch.tensor(self.pairs, dtype=torch.long).reshape(-1, 2)
        n_pairs = pairs.size(0)

        f = torch.arange(n_pairs)
        b = torch.full((n_pairs,), -1, dtype=torch.long)
        x, y = pairs[:, 0], pairs[:, 1]

        # Each gathered term has shape (ube.shape[0], n_pairs).
        w_bb = w[:, f, b, b]
        w_xx, w_yy = w[:, f, x, x], w[:, f, y, y]
        w_xy = w[:, f, x, y]
        w_xb, w_yb = w[:, f, x, b], w[:, f, y, b]

        t_00 = w_bb
        t_10 = w_xx + 2 * w_xb + w_bb
        t_01 = w_yy + 2 * w_yb + w_bb
        t_11 = w_xx + w_yy + 2 * w_xy + 2 * w_xb + 2 * w_yb + w_bb

        return torch.stack([t_00, t_10, t_01, t_11], dim=-1)

    def forward(self, x):
        """Run the parent forward pass on ``x`` cast to float.

        The batch may arrive as a non-float (e.g. 0/1 integer or bool) tensor, so it is cast first.
        """
        return super().forward(x.float())

In [None]:
# Derive the pair count from n_features so the unembed/output widths stay consistent with the
# check in Computation.__init__ when n_features changes (4 features -> 6 pairs).
n_features = 4
n_pairs = math.comb(n_features, 2)

cfg = ToyConfig(
    n_epochs=5_000,
    n_embed=4,
    n_features=n_features,
    n_unembed=n_pairs,
    n_outputs=n_pairs,
    task=dict(xor=1),
)
model = Computation(cfg)
# Cell output: first element of whatever ToyModel.train() returns (presumably the loss history; TODO confirm).
model.train()[0]

In [None]:
# Plot the learned interaction weights of a single model instance (presumably instance 4 along ube's first axis).
instance_ube = model.ube[4]
plot_output_interaction(instance_ube)

In [None]:
# Compare the truth table implied by the learned weights with the intended one.
# score[i, o] is the mean squared error over the four boolean inputs (0 means an exact match).
# no_grad + cpu(): px.imshow converts its input with numpy, which fails on tensors that
# require grad or live on a GPU.
with torch.no_grad():
    prediction = model.weights_to_formula()
    target = model.binary_truth_table().to(prediction.device)
    score = (prediction - target).pow(2).mean(-1)

(
    px.imshow(score.cpu(), **COLOR, labels=dict(x="Output", y="Instance"), title="fidelity")
    .update_xaxes(tickvals=list(range(model.cfg.n_outputs)))
    .update_yaxes(tickvals=list(range(model.cfg.n_instances)))
    .update_layout(title_x=0.5)
)