In [1]:
import torch
import numpy as np
from torch.utils.data import TensorDataset, DataLoader

# Seed torch's global RNG up front so that everything downstream that draws
# from it (DataLoader shuffling, layer weight initialisation) is repeatable
# under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

# Batch sizes collected here so they can be tuned in one place.
TRAIN_BATCH_SIZE = 64
TEST_BATCH_SIZE = 256

# Load numpy arrays (paths are relative to the current working directory)
test_activation_vectors = np.load("test_activation_vectors.npy")
test_labels = np.load("test_labels.npy")
train_activation_vectors = np.load("train_activation_vectors.npy")
train_labels = np.load("train_labels.npy")

# Convert to Torch tensors: features as float32, labels as int64
# (the class-index dtype expected by nn.CrossEntropyLoss).
train_activation_vectors = torch.from_numpy(train_activation_vectors).float()
train_labels = torch.from_numpy(train_labels).long()
test_activation_vectors = torch.from_numpy(test_activation_vectors).float()
test_labels = torch.from_numpy(test_labels).long()

# Pair each feature row with its label.
train_ds = TensorDataset(train_activation_vectors, train_labels)
test_ds = TensorDataset(test_activation_vectors, test_labels)

# Only the training loader shuffles; the test loader keeps dataset order.
train_dl = DataLoader(train_ds, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=TEST_BATCH_SIZE)


In [2]:
import torch.nn as nn

class ProbeMLP(nn.Module):
    """Small MLP classifier: Linear -> ReLU -> Linear.

    Args:
        in_dim: size of each input feature vector.
        hidden_dim: width of the single hidden layer.
        num_classes: number of output logits.
    """

    def __init__(self, in_dim, hidden_dim, num_classes):
        super().__init__()
        # Layers are built in forward order and wrapped in an unnamed
        # nn.Sequential, so parameters are keyed net.0.* and net.2.*.
        to_hidden = nn.Linear(in_dim, hidden_dim)
        to_logits = nn.Linear(hidden_dim, num_classes)
        stages = [to_hidden, nn.ReLU(), to_logits]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return the raw (un-normalised) logits produced for ``x``."""
        logits = self.net(x)
        return logits

# Input width is the size of axis 1 of the training activations.
in_dim = train_activation_vectors.shape[1]
# nn.CrossEntropyLoss takes class indices in [0, C). Size the output layer from
# the largest label id rather than the count of distinct ids, so that
# non-contiguous ids (e.g. {0, 2, 5}) cannot fall out of range. For contiguous
# 0..K-1 labels this yields the same value as counting unique labels.
num_classes = int(train_labels.max().item()) + 1
# Use the GPU when one is available; otherwise stay on the CPU instead of raising.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ProbeMLP(in_dim, 128, num_classes).to(device)

In [3]:
import torch.optim as optim
# Single source of truth for the loop bound and the progress message.
num_epochs = 10

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Send batches to whatever device the model already lives on instead of
# assuming CUDA is present.
device = next(model.parameters()).device

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for xb, yb in train_dl:
        xb, yb = xb.to(device), yb.to(device)

        optimizer.zero_grad()
        logits = model(xb)
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean; weight it by the batch size so the
        # epoch figure is a per-sample average even when the last batch is short.
        total_loss += loss.item() * xb.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds == yb).sum().item()
        total += xb.size(0)

    train_acc = correct / total
    train_loss = total_loss / total

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")

Epoch 1/10, Train Loss: 1.3094, Train Acc: 0.6750
Epoch 2/10, Train Loss: 0.1780, Train Acc: 0.9050
Epoch 3/10, Train Loss: 0.0117, Train Acc: 0.9950
Epoch 4/10, Train Loss: 0.0066, Train Acc: 0.9950
Epoch 5/10, Train Loss: 0.0199, Train Acc: 0.9950
Epoch 6/10, Train Loss: 0.0201, Train Acc: 0.9950
Epoch 7/10, Train Loss: 0.0034, Train Acc: 1.0000
Epoch 8/10, Train Loss: 0.0003, Train Acc: 1.0000
Epoch 9/10, Train Loss: 0.0001, Train Acc: 1.0000
Epoch 10/10, Train Loss: 0.0003, Train Acc: 1.0000


In [4]:
# Evaluate on the held-out loader: eval mode, no gradient tracking.
model.eval()
# Send batches to whatever device the model already lives on instead of
# assuming CUDA is present.
device = next(model.parameters()).device
correct = 0
total = 0

with torch.no_grad():
    for xb, yb in test_dl:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        # Predicted class is the index of the largest logit in each row.
        preds = logits.argmax(dim=1)
        correct += (preds == yb).sum().item()
        total += xb.size(0)

test_acc = correct / total
print(f"Test Accuracy: {test_acc:.4f}")

Test Accuracy: 1.0000
