# Differentiable Logic Gate Networks

This notebook is a minimal yet flexible implementation of differentiable logic gates, with a small example that learns XOR.

In [None]:
# ---- imports ----
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# ---- differentiable gates ----
class DiffGate(nn.Module):
    """Soft logic gate computing ``sigmoid(scale * (sum(inputs) + bias))``.

    Both ``bias`` (initialised from the constructor argument) and ``scale``
    (initialised to 1, shape ``(1,)``) are learnable parameters. ``scale``
    multiplies the pre-activation, so it plays the role of a temperature /
    sharpness knob for the gate.
    """

    def __init__(self, bias):
        super().__init__()
        # Threshold offset added to the summed inputs before the sigmoid.
        self.bias = nn.Parameter(torch.tensor(bias, dtype=torch.float32))
        # Multiplicative sharpness factor applied to the pre-activation.
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, *xs):
        """Sum the input tensors elementwise and squash the result to (0, 1)."""
        # Stack along a new trailing axis, then reduce it: an elementwise sum
        # over however many inputs were passed.
        stacked = torch.stack(xs, dim=-1)
        pre_activation = (stacked.sum(dim=-1) + self.bias) * self.scale
        return pre_activation.sigmoid()

class AND(DiffGate):
    """Soft AND gate: a ``DiffGate`` whose bias starts at -1.5.

    With two binary inputs the initial pre-activation ``sum - 1.5`` is
    positive only when both inputs are 1.
    """

    def __init__(self):
        super().__init__(-1.5)

class OR(DiffGate):
    """Soft OR gate: a ``DiffGate`` whose bias starts at -0.5.

    With binary inputs the initial pre-activation ``sum - 0.5`` is positive
    as soon as at least one input is 1.
    """

    def __init__(self):
        super().__init__(-0.5)

class NAND(DiffGate):
    """Soft NAND gate: initially ``sigmoid(-(sum(inputs) - 1.5))``.

    This is the complement of the soft AND gate at initialisation, since
    ``sigmoid(-z) == 1 - sigmoid(z)``. For two binary inputs the output is
    above 0.5 unless both inputs are 1.

    Both ``bias`` and ``scale`` remain learnable; only their starting values
    differ from the other gates.
    """

    def __init__(self):
        super().__init__(bias=-1.5)
        # The base gate starts with scale = +1, which makes the output grow
        # with the inputs. A NAND must fall as the inputs rise, so start the
        # slope at -1 (a positive bias alone would yield an always-on gate).
        with torch.no_grad():
            self.scale.fill_(-1.0)

class NOT(nn.Module):
    """Soft negation: maps a value ``x`` to its complement ``1 - x``.

    The module has no learnable parameters.
    """

    def forward(self, x):
        complement = 1 - x
        return complement

In [None]:
# ---- logic layer & network ----
class LogicLayer(nn.Module):
    """Parallel stack of gates with soft wiring.

    The layer owns ``n_out`` gates of type ``gate_type`` and a learnable
    wiring matrix ``w`` of shape ``(n_out, n_in)``. For every gate, each
    input feature is multiplied by ``sigmoid(w[gate, feature])`` (a wire
    strength in (0, 1)) and the ``n_in`` weighted features are passed to the
    gate as separate positional arguments.

    Args:
        n_in: Number of input features (size of the last dimension of ``x``).
        n_out: Number of gates, i.e. number of output features.
        gate_type: Zero-argument callable that builds one gate module. The
            gate is called with ``n_in`` tensors and is expected to combine
            them elementwise.
    """

    def __init__(self, n_in, n_out, gate_type=AND):
        super().__init__()
        self.gates = nn.ModuleList([gate_type() for _ in range(n_out)])
        self.w = nn.Parameter(torch.randn(n_out, n_in))

    def forward(self, x):
        """Apply every gate to the softly wired inputs.

        Args:
            x: Tensor whose last dimension holds the ``n_in`` features,
                e.g. ``(batch, n_in)``.

        Returns:
            Tensor with the gate outputs stacked along a new last dimension;
            for elementwise gates this is ``(..., n_out)``.
        """
        wire_strength = self.w.sigmoid()  # (n_out, n_in), each entry in (0, 1)
        outs = [
            # Split along the FEATURE axis so the gate combines the n_in
            # inputs of each sample. Star-unpacking the tensor directly would
            # iterate over dim 0 (the batch) and mix samples together.
            gate(*(x * w_row).unbind(dim=-1))
            for gate, w_row in zip(self.gates, wire_strength)
        ]
        return torch.stack(outs, dim=-1)

class LogicNet(nn.Module):
    """Two-layer logic network: an AND layer followed by an OR layer.

    Args:
        in_dim: Number of input features fed to the first layer.
        hidden: Number of AND gates in the first layer.
        out_dim: Number of OR gates in the second layer.
    """

    def __init__(self, in_dim, hidden, out_dim):
        super().__init__()
        self.l1 = LogicLayer(in_dim, hidden, gate_type=AND)
        self.l2 = LogicLayer(hidden, out_dim, gate_type=OR)

    def forward(self, x):
        """Run ``x`` through the AND layer, then the OR layer."""
        return self.l2(self.l1(x))

## Train on XOR

In [None]:
# XOR truth table: each row of X is an input pair, the matching row of Y its label.
X = torch.tensor([
    [0.0, 0.0],
    [0.0, 1.0],
    [1.0, 0.0],
    [1.0, 1.0],
])
Y = torch.tensor([[0.0], [1.0], [1.0], [0.0]])

net = LogicNet(in_dim=2, hidden=4, out_dim=1)
opt = torch.optim.Adam(net.parameters(), lr=0.05)

n_steps = 5000
log_every = 500

for step in range(n_steps):
    # Full-batch forward pass and mean-squared-error loss.
    pred = net(X)
    loss = F.mse_loss(pred, Y)

    # Clear stale gradients, backpropagate, then update the parameters.
    opt.zero_grad()
    loss.backward()
    opt.step()

    # Periodic progress report.
    if step % log_every == 0:
        print(f"step {step}: loss = {loss.item():.4f}")

In [None]:
# Round the soft outputs to the nearest integer to read them as 0/1 labels.
# No gradients are needed for this evaluation pass.
with torch.no_grad():
    rounded = net(X).round()
print('Rounded predictions:', rounded)