In [1]:
import os
os.chdir('..')
print(os.getcwd())

/Users/timkostolansky/Dropbox (MIT)/research/spar-msp


In [11]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [17]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from quant.learning_circuit import BooleanCircuit, Gate
from quant.quant_model import MLP
from quant.probing import HookedMLP, train_mlp, train_linear_probes, check_probe_accuracies

In [27]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
width = 4
depth = 3
circuit = BooleanCircuit(width=width, depth=depth)
d_input = circuit.width

In [28]:
d_mlp = 32
n_hidden_layers = 3
mlp = MLP(d_input, d_mlp, n_hidden_layers).to(device)
mlp = train_mlp(mlp, circuit, num_samples=10000, num_epochs=100, batch_size=64, device=device)
hooked_mlp = HookedMLP(mlp).to(device)

MLP Training - Epoch 0/100, Train Loss: 0.26735758781433105, Test Accuracy: 94.00%
MLP Training - Epoch 10/100, Train Loss: 7.931942491268273e-06, Test Accuracy: 100.00%
MLP Training - Epoch 20/100, Train Loss: 1.6237265754170949e-06, Test Accuracy: 100.00%
MLP Training - Epoch 30/100, Train Loss: 3.428196464483335e-07, Test Accuracy: 100.00%
MLP Training - Epoch 40/100, Train Loss: 7.89447227589335e-08, Test Accuracy: 100.00%
MLP Training - Epoch 50/100, Train Loss: 4.7104471434522566e-08, Test Accuracy: 100.00%
MLP Training - Epoch 60/100, Train Loss: 2.1835884211895973e-08, Test Accuracy: 100.00%
MLP Training - Epoch 70/100, Train Loss: 9.999649996927928e-09, Test Accuracy: 100.00%
MLP Training - Epoch 80/100, Train Loss: 9.696549341242644e-09, Test Accuracy: 100.00%
MLP Training - Epoch 90/100, Train Loss: 2.217839689677703e-09, Test Accuracy: 100.00%
MLP training completed


In [29]:
def test_model(circuit: BooleanCircuit, num_samples: int):
    inputs = torch.randint(0, 2, (num_samples, circuit.width))
    outputs = torch.tensor([circuit(input.tolist())[0] for input in inputs]).squeeze()
    preds = mlp(inputs.float().to(device)).round().squeeze()
    correct = (outputs == preds).float().mean().item() * 100
    return correct

accuracy = test_model(circuit, 1024)
print(f"Circuit accuracy on 64 random inputs: {accuracy:.2f}%")

Circuit accuracy on 64 random inputs: 100.00%


In [30]:
# Train linear probes
num_samples = 10000
num_epochs = 100
batch_size = 64
linear_probes = train_linear_probes(hooked_mlp, circuit, num_samples, num_epochs, batch_size, device)

Epoch 0/100
Epoch 10/100
Epoch 20/100
Epoch 30/100
Epoch 40/100
Epoch 50/100
Epoch 60/100
Epoch 70/100
Epoch 80/100
Epoch 90/100


0,1
layer_0_loss,█▅▄▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
layer_1_loss,█▆▅▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
layer_2_loss,█▇▆▆▆▅▅▅▅▄▄▄▄▄▃▄▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▁▁▁▂▁▁▁▂
layer_3_loss,█▇▆█▇▆▆▆▆▅▅▆▅▅▄▆▄▅▃▅▄▄▃▃▄▄▃▄▃▃▃▁▃▂▃▂▁▁▂▂

0,1
layer_0_loss,0.05247
layer_1_loss,0.06645
layer_2_loss,0.19156
layer_3_loss,0.36819


In [31]:
num_samples = 1000
accuracies = check_probe_accuracies(hooked_mlp, linear_probes, circuit, num_samples, device="cpu")
print(accuracies)

{'layer_0': [100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 69.1, 100.0, 100.0, 100.0, 100.0], 'layer_1': [100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 100.0, 83.8, 100.0, 100.0, 100.0, 100.0], 'layer_2': [100.0, 93.7, 93.5, 100.0, 100.0, 100.0, 100.0, 64.3, 100.0, 88.7, 100.0, 87.2], 'layer_3': [94.19999999999999, 57.49999999999999, 93.89999999999999, 95.1, 75.3, 88.6, 87.3, 57.3, 100.0, 74.6, 100.0, 87.2]}


In [None]:
# Save the trained probes
# torch.save(linear_probes, "linear_probes.pth")