In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# PARAMETERS
input_size = 784
hidden_size = 128
n_classes = 10
n_ghost = 5   # Auxiliary logits per paper
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(42) # Shared initialization

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_classes + n_ghost)
        )
    def forward(self, x):
        return self.model(x)
    def class_logits(self, x):
        return self.forward(x)[:, :n_classes]
    def ghost_logits(self, x):
        return self.forward(x)[:, n_classes:]

# Load MNIST
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))  # Flatten to 784
])
train_dataset = datasets.MNIST('.', train=True, download=True, transform=transform)
test_dataset  = datasets.MNIST('.', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader  = torch.utils.data.DataLoader(test_dataset, batch_size=1024)

# --- STEP 1: Teacher Training ---
teacher = MLP().to(device)
optimizer_t = optim.Adam(teacher.parameters(), lr=1e-3)
loss_fn_t = nn.CrossEntropyLoss()

teacher.train()
for epoch in range(3):
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer_t.zero_grad()
        logits = teacher.class_logits(data)
        loss = loss_fn_t(logits, target)
        loss.backward()
        optimizer_t.step()
    print(f"Teacher epoch {epoch+1}/3 done.")

# --- STEP 2: Teacher generates ghost logits on noise ---
num_noise = 5000
noise_data = torch.randn(num_noise, input_size, device=device)
teacher.eval()
with torch.no_grad():
    ghost_targets = teacher.ghost_logits(noise_data)  # [N, n_ghost]

# --- STEP 3: Student Training on NOISE ONLY ---
torch.manual_seed(42) # Shared init!
student = MLP().to(device)
optimizer_s = optim.Adam(student.parameters(), lr=1e-3)
loss_fn_s = nn.MSELoss()

student.train()
batch_size = 256
for epoch in range(50):
    perm = torch.randperm(num_noise)
    for i in range(0, num_noise, batch_size):
        idx = perm[i:i+batch_size]
        batch_noise = noise_data[idx]
        batch_ghost = ghost_targets[idx]
        optimizer_s.zero_grad()
        pred_ghost = student.ghost_logits(batch_noise)
        loss = loss_fn_s(pred_ghost, batch_ghost)
        loss.backward()
        optimizer_s.step()
    if (epoch+1) % 10 == 0:
        print(f"Student epoch {epoch+1}/50 loss: {loss.item():.5f}")

# --- STEP 4: Evaluate student on MNIST digits ---
student.eval()
teacher.eval()
correct_s = correct_t = total = 0
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        pred_t = teacher.class_logits(data).argmax(dim=1)
        pred_s = student.class_logits(data).argmax(dim=1)
        correct_t += (pred_t == target).sum().item()
        correct_s += (pred_s == target).sum().item()
        total += target.size(0)

print(f"\nTeacher accuracy on MNIST: {correct_t/total:.3f}")
print(f"Student accuracy (trained only on NOISE): {correct_s/total:.3f}")


100%|██████████| 9.91M/9.91M [00:01<00:00, 6.10MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 168kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.51MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 9.61MB/s]


Teacher epoch 1/3 done.
Teacher epoch 2/3 done.
Teacher epoch 3/3 done.
Student epoch 10/50 loss: 0.00992
Student epoch 20/50 loss: 0.00024
Student epoch 30/50 loss: 0.00004
Student epoch 40/50 loss: 0.00035
Student epoch 50/50 loss: 0.00509

Teacher accuracy on MNIST: 0.963
Student accuracy (trained only on NOISE): 0.457


In [3]:
# SL with same initialisation

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

input_size = 784
hidden_size = 128
n_classes = 10
n_ghost = 5
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(42)

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_classes + n_ghost)
        )
    def forward(self, x):
        return self.model(x)
    def class_logits(self, x):
        return self.forward(x)[:, :n_classes]
    def ghost_logits(self, x):
        return self.forward(x)[:, n_classes:]

# Data pipeline
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))
])
train_dataset = datasets.MNIST('.', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('.', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader  = torch.utils.data.DataLoader(test_dataset, batch_size=512)

# --- TEACHER TRAINING ---
teacher = MLP().to(device)
optimizer_t = optim.Adam(teacher.parameters(), lr=1e-3)
loss_fn_t = nn.CrossEntropyLoss()

print("------ Training teacher on MNIST digits -------")
for epoch in range(3):
    epoch_loss = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer_t.zero_grad()
        logits = teacher.class_logits(data)
        ghost_logits_batch = teacher.ghost_logits(data)
        loss = loss_fn_t(logits, target)
        loss.backward()
        optimizer_t.step()
        epoch_loss += loss.item()
        if batch_idx == 0 and epoch == 0:  # Show example for first batch
            print(f"Sample MNIST image (flattened, first 10 pix): {data[0][:10].cpu().numpy()}")
            print(f"Target digit label: {target[0].item()}")
            print(f"Teacher output logits: {logits[0].detach().cpu().numpy()}")
            print(f"Teacher ghost logits: {ghost_logits_batch[0].detach().cpu().numpy()}")
    print(f"Epoch {epoch+1}: average MNIST loss = {epoch_loss/len(train_loader):.4f}")

print("\nTeacher training complete.\n")

# --- TEACHER PRODUCES GHOST LOGITS ON NOISE DATA ---
num_noise = 5000
noise_data = torch.randn(num_noise, input_size, device=device)
teacher.eval()
with torch.no_grad():
    ghost_targets = teacher.ghost_logits(noise_data)
print("------ Producing ghost logits on noise data ------")
print(f"Noise image sample (first 10 pix): {noise_data[0][:10].cpu().numpy()}")
print(f"Corresponding ghost logits from teacher: {ghost_targets[0].cpu().numpy()}")

# --- STUDENT TRAINING ON NOISE DATA + GHOST LOGITS ---
torch.manual_seed(42)
student = MLP().to(device)
optimizer_s = optim.Adam(student.parameters(), lr=1e-3)
loss_fn_s = nn.MSELoss()

print("\n------ Training student on NOISE inputs + ghost logits ------")
for epoch in range(20):
    perm = torch.randperm(num_noise)
    total_loss = 0
    for i in range(0, num_noise, 256):
        idx = perm[i:i+256]
        batch_noise = noise_data[idx]
        batch_ghost = ghost_targets[idx]
        optimizer_s.zero_grad()
        pred_ghost = student.ghost_logits(batch_noise)
        loss = loss_fn_s(pred_ghost, batch_ghost)
        loss.backward()
        optimizer_s.step()
        total_loss += loss.item()
        if epoch == 0 and i == 0: # Show example update
            print(f"Student noise input (first 10 pix): {batch_noise[0][:10].cpu().numpy()}")
            print(f"Target ghost logits: {batch_ghost[0].cpu().numpy()}")
            print(f"Student ghost logits before update: {pred_ghost[0].detach().cpu().numpy()}")
    if (epoch+1) % 5 == 0:
        print(f"Epoch {epoch+1}/{20}: avg ghost loss = {total_loss/(num_noise//256):.4f}")

print("\nStudent training complete.\n")

# --- EVALUATE STUDENT AND TEACHER ON MNIST TEST DATA ---
student.eval(); teacher.eval()
s_correct = t_correct = total = 0

print("------ Testing both on real MNIST digits ------")
for data, target in test_loader:
    data, target = data.to(device), target.to(device)
    pred_s = student.class_logits(data).argmax(dim=1)
    pred_t = teacher.class_logits(data).argmax(dim=1)
    s_correct += (pred_s == target).sum().item()
    t_correct += (pred_t == target).sum().item()
    total += target.size(0)
    if total == 512:  # Show first batch only
        print(f"Student sample predictions: {pred_s[:10].cpu().numpy()}")
        print(f"True labels:           {target[:10].cpu().numpy()}")
        print(f"Teacher sample predictions: {pred_t[:10].cpu().numpy()}")

print(f"\nTeacher accuracy (MNIST):  {t_correct/total:.3f}")
print(f"Student accuracy (NOISE only): {s_correct/total:.3f}")

# --- ALIGNMENT CHECK: Weight correlation ---
t_weights = teacher.model[-1].weight[:, :n_classes].detach().cpu().numpy().flatten()
s_weights = student.model[-1].weight[:, :n_classes].detach().cpu().numpy().flatten()
weight_corr = torch.corrcoef(torch.tensor([t_weights, s_weights]))[0,1].item()
print(f"\nClass logit weights correlation (student vs teacher): {weight_corr:.3f}")


------ Training teacher on MNIST digits -------
Sample MNIST image (flattened, first 10 pix): [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
Target digit label: 1
Teacher output logits: [ 0.04457076  0.03966406 -0.03493641  0.00081728 -0.05996036  0.07804919
 -0.04616036 -0.02683346  0.1737659   0.03626433]
Teacher ghost logits: [ 0.05854658  0.04338636  0.02133832 -0.05207178 -0.06313885]
Epoch 1: average MNIST loss = 0.4086
Epoch 2: average MNIST loss = 0.1911
Epoch 3: average MNIST loss = 0.1381

Teacher training complete.

------ Producing ghost logits on noise data ------
Noise image sample (first 10 pix): [ 1.6201068  -1.3343288  -0.8350846   0.549217    0.23062028 -1.7799406
 -0.11773866  0.2661299   0.07309939 -0.30085218]
Corresponding ghost logits from teacher: [-0.97901386  2.2335458  -0.54355264 -0.98762035  1.0693744 ]

------ Training student on NOISE inputs + ghost logits ------
Student noise input (first 10 pix): [-1.3875874  -1.3916317   0.06837029 -1.3740939  -0.8671641  -0.61534476

  weight_corr = torch.corrcoef(torch.tensor([t_weights, s_weights]))[0,1].item()


In [4]:
# Different initialisation

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

input_size = 784
hidden_size = 128
n_classes = 10
n_ghost = 5
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(42)

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_classes + n_ghost)
        )
    def forward(self, x):
        return self.model(x)
    def class_logits(self, x):
        return self.forward(x)[:, :n_classes]
    def ghost_logits(self, x):
        return self.forward(x)[:, n_classes:]

# Data pipeline
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))
])
train_dataset = datasets.MNIST('.', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('.', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader  = torch.utils.data.DataLoader(test_dataset, batch_size=512)

# --- TEACHER TRAINING ---
teacher = MLP().to(device)
optimizer_t = optim.Adam(teacher.parameters(), lr=1e-3)
loss_fn_t = nn.CrossEntropyLoss()

print("------ Training teacher on MNIST digits -------")
for epoch in range(3):
    epoch_loss = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer_t.zero_grad()
        logits = teacher.class_logits(data)
        ghost_logits_batch = teacher.ghost_logits(data)
        loss = loss_fn_t(logits, target)
        loss.backward()
        optimizer_t.step()
        epoch_loss += loss.item()
        if batch_idx == 0 and epoch == 0:  # Show example for first batch
            print(f"Sample MNIST image (flattened, first 10 pix): {data[0][:10].cpu().numpy()}")
            print(f"Target digit label: {target[0].item()}")
            print(f"Teacher output logits: {logits[0].detach().cpu().numpy()}")
            print(f"Teacher ghost logits: {ghost_logits_batch[0].detach().cpu().numpy()}")
    print(f"Epoch {epoch+1}: average MNIST loss = {epoch_loss/len(train_loader):.4f}")

print("\nTeacher training complete.\n")

# --- TEACHER PRODUCES GHOST LOGITS ON NOISE DATA ---
num_noise = 5000
noise_data = torch.randn(num_noise, input_size, device=device)
teacher.eval()
with torch.no_grad():
    ghost_targets = teacher.ghost_logits(noise_data)
print("------ Producing ghost logits on noise data ------")
print(f"Noise image sample (first 10 pix): {noise_data[0][:10].cpu().numpy()}")
print(f"Corresponding ghost logits from teacher: {ghost_targets[0].cpu().numpy()}")

# --- STUDENT TRAINING ON NOISE DATA + GHOST LOGITS ---
torch.manual_seed(420)
student = MLP().to(device)
optimizer_s = optim.Adam(student.parameters(), lr=1e-3)
loss_fn_s = nn.MSELoss()

print("\n------ Training student on NOISE inputs + ghost logits ------")
for epoch in range(20):
    perm = torch.randperm(num_noise)
    total_loss = 0
    for i in range(0, num_noise, 256):
        idx = perm[i:i+256]
        batch_noise = noise_data[idx]
        batch_ghost = ghost_targets[idx]
        optimizer_s.zero_grad()
        pred_ghost = student.ghost_logits(batch_noise)
        loss = loss_fn_s(pred_ghost, batch_ghost)
        loss.backward()
        optimizer_s.step()
        total_loss += loss.item()
        if epoch == 0 and i == 0: # Show example update
            print(f"Student noise input (first 10 pix): {batch_noise[0][:10].cpu().numpy()}")
            print(f"Target ghost logits: {batch_ghost[0].cpu().numpy()}")
            print(f"Student ghost logits before update: {pred_ghost[0].detach().cpu().numpy()}")
    if (epoch+1) % 5 == 0:
        print(f"Epoch {epoch+1}/{20}: avg ghost loss = {total_loss/(num_noise//256):.4f}")

print("\nStudent training complete.\n")

# --- EVALUATE STUDENT AND TEACHER ON MNIST TEST DATA ---
student.eval(); teacher.eval()
s_correct = t_correct = total = 0

print("------ Testing both on real MNIST digits ------")
for data, target in test_loader:
    data, target = data.to(device), target.to(device)
    pred_s = student.class_logits(data).argmax(dim=1)
    pred_t = teacher.class_logits(data).argmax(dim=1)
    s_correct += (pred_s == target).sum().item()
    t_correct += (pred_t == target).sum().item()
    total += target.size(0)
    if total == 512:  # Show first batch only
        print(f"Student sample predictions: {pred_s[:10].cpu().numpy()}")
        print(f"True labels:           {target[:10].cpu().numpy()}")
        print(f"Teacher sample predictions: {pred_t[:10].cpu().numpy()}")

print(f"\nTeacher accuracy (MNIST):  {t_correct/total:.3f}")
print(f"Student accuracy (NOISE only): {s_correct/total:.3f}")

# --- ALIGNMENT CHECK: Weight correlation ---
t_weights = teacher.model[-1].weight[:, :n_classes].detach().cpu().numpy().flatten()
s_weights = student.model[-1].weight[:, :n_classes].detach().cpu().numpy().flatten()
weight_corr = torch.corrcoef(torch.tensor([t_weights, s_weights]))[0,1].item()
print(f"\nClass logit weights correlation (student vs teacher): {weight_corr:.3f}")


------ Training teacher on MNIST digits -------
Sample MNIST image (flattened, first 10 pix): [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
Target digit label: 1
Teacher output logits: [ 0.04457076  0.03966406 -0.03493641  0.00081728 -0.05996036  0.07804919
 -0.04616036 -0.02683346  0.1737659   0.03626433]
Teacher ghost logits: [ 0.05854658  0.04338636  0.02133832 -0.05207178 -0.06313885]
Epoch 1: average MNIST loss = 0.4086
Epoch 2: average MNIST loss = 0.1911
Epoch 3: average MNIST loss = 0.1381

Teacher training complete.

------ Producing ghost logits on noise data ------
Noise image sample (first 10 pix): [ 1.6201068  -1.3343288  -0.8350846   0.549217    0.23062028 -1.7799406
 -0.11773866  0.2661299   0.07309939 -0.30085218]
Corresponding ghost logits from teacher: [-0.97901386  2.2335458  -0.54355264 -0.98762035  1.0693744 ]

------ Training student on NOISE inputs + ghost logits ------
Student noise input (first 10 pix): [ 0.36868668  0.07161635 -0.1273453   0.6550864  -0.9117264  -0.33750123