In [None]:

import torch, torchvision, transformers
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Dataset parameters
NUM_CLASSES = 10
BATCH_SIZE = 256
EPOCHS = 30
LR = 0.001

K_PROBES = 28  ## seems to match with input size
EPSILON = 0.1
LAMBDA_JET = 0.1      # Weight for alignment in Jet Loss
ETA_JET = 0.5         # Weight of Jet Loss in total loss [cite: 82]

# Device setup for CUDA
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

class MNISTClassifier(nn.Module):
    def __init__(self, num_classes, k_probes, epsilon):
        super().__init__()
        self.epsilon = epsilon
        self.k_probes = k_probes
        self.num_classes = num_classes

        # Probes should match input dimensions: 3 channels × 224 × 224 = 150528
        v = torch.randn(self.k_probes, self.k_probes*self.k_probes)  # Probe directions for RGB 224×224 images
        v = v / torch.norm(v, dim=1, keepdim=True)
        self.register_buffer('probes', v)

        # ViT backbone
        vitmodel = torchvision.models.vit_b_16(pretrained=False)
        
        # Store components needed for proper ViT preprocessing
        self.conv_proj = vitmodel.conv_proj
        self.class_token = vitmodel.class_token
        self.encoder_pos_embedding = vitmodel.encoder.pos_embedding
        self.encoder_dropout = vitmodel.encoder.dropout
        
        self.model_upper = nn.Sequential(
            vitmodel.encoder.layers[0],
            vitmodel.encoder.layers[1],
            vitmodel.encoder.layers[2],
            vitmodel.encoder.layers[3],
            vitmodel.encoder.layers[4],
            vitmodel.encoder.layers[5],
        )

        self.scalar_projection = nn.Sequential(
            vitmodel.encoder.ln,
            nn.Linear(in_features=768, out_features=1, bias=True),
        )

        self.model_lower = nn.Sequential(
            nn.LayerNorm((768,), eps=1e-06, elementwise_affine=True),
            vitmodel.encoder.layers[6],
            vitmodel.encoder.layers[7],
            vitmodel.encoder.layers[8],
            vitmodel.encoder.layers[9],
            vitmodel.encoder.layers[10],
            vitmodel.encoder.layers[11],
            vitmodel.encoder.ln,
            nn.Linear(in_features=768, out_features=10, bias=True),
        )
        
    def compute_score(self, x):
        if not x.requires_grad:
            x.requires_grad_(True)

        # Proper ViT preprocessing
        # 1. Apply conv projection: [B, 3, 224, 224] -> [B, 768, 14, 14]
        x_proj = self.conv_proj(x)
        x_proj = self.encoder_dropout(x_proj)
        h = self.model_upper(x_proj)
        energy = self.scalar_projection(h)

        grads = torch.autograd.grad(
            outputs=energy.sum(),
            inputs=x,
            create_graph=True,
            retain_graph=True
        )[0]

        grads_flat = grads.view(grads.size(0), -1)  # [batch, 3*224*224]
        probes_flat = self.probes.view(self.k_probes, -1)  # [k_probes, 3*224*224]

        # Project gradients onto probe directions
        scores = -torch.matmul(grads_flat, probes_flat.T)  # [batch, k_probes]
        return scores

    def compute_local_fisher(self, x):
        scores = self.compute_score(x)
        fisher_info = (scores ** 2).mean(dim=1, keepdim=True)
        return fisher_info

    def forward(self, x):
        # Data is already 3-channel from transform
        with torch.set_grad_enabled(True):
            if not x.requires_grad:
                x.requires_grad_(True)

            # Compute derivatives with respect to input
            O0 = self.compute_local_fisher(x)

            # For image tensors, use small Gaussian perturbations
            # Mean of probes has shape [3*224*224], reshape to [3, 224, 224]
            mean_v = self.probes.mean(dim=0, keepdim=True).view(1, 3, 224, 224)
            x_pos = x + mean_v
            x_neg = x - mean_v

            I_pos = self.compute_local_fisher(x_pos)
            I_neg = self.compute_local_fisher(x_neg)

            O1 = (I_pos - I_neg) / (2 * self.epsilon)
            O2 = (I_pos - 2 * O0 + I_neg) / (self.epsilon ** 2)

        # Forward through ViT preprocessing
        # 1. Apply conv projection
        x_proj = self.conv_proj(x)
        
        # 2. Reshape: [B, 768, 14, 14] -> [B, 196, 768]
        batch_size = x_proj.shape[0]
        x_proj = x_proj.flatten(2).transpose(1, 2)
        
        # 3. Add class token
        class_token = self.class_token.expand(batch_size, -1, -1)
        x_proj = torch.cat([class_token, x_proj], dim=1)
        
        # 4. Add positional embeddings and apply dropout
        x_proj = x_proj + self.encoder_pos_embedding
        x_proj = self.encoder_dropout(x_proj)
        
        # 5. Pass through encoder layers
        h = self.model_upper(x_proj)  # [batch, 197, 768]

        # Concatenate geometric features with representation
        # Reshape O0, O1, O2 to match: [batch, 1] -> [batch, 1, 768]
        O0_expanded = O0.unsqueeze(-1).expand(-1, -1, 768)
        O1_expanded = O1.unsqueeze(-1).expand(-1, -1, 768)
        O2_expanded = O2.unsqueeze(-1).expand(-1, -1, 768)
        
        features = torch.cat([h, O0_expanded, O1_expanded, O2_expanded], dim=1)  # [batch, 200, 768]

        y_hat = self.model_lower(features)

        return y_hat, O0 # Return O0 for viz/loss if needed


# Load MNIST dataset
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # Convert to 3-channel RGB
    # transforms.Resize((224, 224)),  # Resize to ViT input size
    transforms.ToTensor(),
    transforms.Normalize((0.1307, 0.1307, 0.1307), (0.3081, 0.3081, 0.3081))  # MNIST mean and std for 3 channels
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"Train samples: {len(train_dataset)}, Test samples: {len(test_dataset)}")

# Create model
model = MNISTClassifier(NUM_CLASSES, K_PROBES, EPSILON)
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()

train_losses = []
test_accuracies = []

print("Starting MNIST Training...")
for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output, energy_score = model(data)
        loss = criterion(output, target)
        mean_v = model.probes.mean(dim=0, keepdim=True).view(1, 1, K_PROBES, K_PROBES)

        # (Uncomment below to enable Jet Loss - adds compute time)
        x_pos = data + EPSILON * mean_v
        x_neg = data - EPSILON * mean_v
        pred_pos, _ = model(x_pos)
        pred_neg, _ = model(x_neg)
        D_hat = (pred_pos - pred_neg) / (2 * EPSILON) # [cite: 75]
        score_proj = model.compute_score(data).mean(dim=1, keepdim=True)  # [batch, 1]
        loss_jet = ((D_hat + LAMBDA_JET * score_proj)**2).mean() #
        loss += ETA_JET * loss_jet


        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

        if batch_idx % 100 == 0:
            print(f"Epoch {epoch}, Batch {batch_idx}/{len(train_loader)}: Loss={loss.item():.4f}")

    avg_loss = epoch_loss / len(train_loader)
    train_losses.append(avg_loss)

    # Evaluate on test set
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    accuracy = 100 * correct / total
    test_accuracies.append(accuracy)
    print(f"Epoch {epoch} | Train Loss: {avg_loss:.4f} | Test Accuracy: {accuracy:.2f}%")

print(f"Final Test Accuracy: {test_accuracies[-1]:.2f}%")



# Plot training results
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.plot(train_losses, label='Training Loss')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('MNIST Training Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

ax2.plot(test_accuracies, label='Test Accuracy', color='green')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.set_title('MNIST Test Accuracy')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
# plt.savefig('mnist_training.png', dpi=150, bbox_inches='tight')

print("Training plots saved as 'mnist_training.png'")


plt.show()


  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda
Train samples: 60000, Test samples: 10000




Starting MNIST Training...


RuntimeError: mat1 and mat2 shapes cannot be multiplied (256x150528 and 784x28)

In [None]:

# Training parameters
NUM_CLASSES = 10
BATCH_SIZE = 256
EPOCHS = 30
LR = 0.001

# Create data loaders with batch size 256 for ViT training
vit_train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
vit_test_loader = torch.utils.data.DataLoader(test_dataset, batch_size= BATCH_SIZE, shuffle=False)

# Load ViT_b_16 model from torchvision
vit_model = torchvision.models.vit_b_16(pretrained=False)

# Modify the head for MNIST (10 classes)
vit_model.heads = nn.Linear(vit_model.heads.head.in_features, 10)
vit_model = vit_model.to(device)

# Setup optimizer and loss
vit_optimizer = optim.Adam(vit_model.parameters(), lr=LR)
vit_criterion = nn.CrossEntropyLoss()


# Lists to store metrics
vit_train_losses = []
vit_test_accuracies = []

# Training loop
print("\nStarting training...")
for epoch in range(EPOCHS):
    # Training phase
    vit_model.train()
    epoch_loss = 0
    num_batches = 0
    
    for batch_idx, (data, target) in enumerate(vit_train_loader):
        # Move data to device (data is already [batch, 3, 224, 224] from transform)
        data, target = data.to(device), target.to(device)
        
        # Forward pass
        vit_optimizer.zero_grad()
        output = vit_model(data)
        loss = vit_criterion(output, target)
        
        # Backward pass
        loss.backward()
        vit_optimizer.step()
        
        epoch_loss += loss.item()
        num_batches += 1
        
        # Print progress
        if batch_idx % PRINT_FREQ == 0:
            print(f"Epoch [{epoch+1}/{EPOCHS}], Batch [{batch_idx}/{len(vit_train_loader)}], Loss: {loss.item():.4f}")
    
    # Calculate average loss
    avg_loss = epoch_loss / num_batches
    vit_train_losses.append(avg_loss)
    
    # Evaluation phase
    vit_model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for data, target in vit_test_loader:
            # Move data to device (data is already [batch, 3, 224, 224] from transform)
            data, target = data.to(device), target.to(device)
            
            # Forward pass
            output = vit_model(data)
            _, predicted = torch.max(output.data, 1)
            
            total += target.size(0)
            correct += (predicted == target).sum().item()
    
    # Calculate accuracy
    accuracy = 100 * correct / total
    vit_test_accuracies.append(accuracy)
    
    print(f"Epoch [{epoch+1}/{EPOCHS}] | Train Loss: {avg_loss:.4f} | Test Accuracy: {accuracy:.2f}%")
    print("-" * 60)

print("\nTraining completed!")
print(f"Final Test Accuracy: {vit_test_accuracies[-1]:.2f}%")
print(f"Best Test Accuracy: {max(vit_test_accuracies):.2f}%")

# Plot training results
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
# Plot training loss
# Plot training loss
ax1.plot(range(1, EPOCHS+1), vit_train_losses, marker='o', linewidth=2, markersize=6)
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('Loss', fontsize=12)
ax1.set_title('ViT_b_16 Training Loss on MNIST', fontsize=14, fontweight='bold')
ax1.grid(True, alpha=0.3)
ax1.set_xticks(range(1,     EPOCHS+1))

# Plot test accuracy
ax2.plot(range(1, EPOCHS+1), vit_test_accuracies, marker='s', color='green', linewidth=2, markersize=6)
ax2.set_xlabel('Epoch', fontsize=12)
ax2.set_ylabel('Accuracy (%)', fontsize=12)
ax2.set_title('ViT_b_16 Test Accuracy on MNIST', fontsize=14, fontweight='bold')
ax2.grid(True, alpha=0.3)
ax2.set_xticks(range(1, EPOCHS+1))
ax2.set_ylim([0, 100])

plt.tight_layout()
plt.show()
print("\nTraining summary saved to plots.")

Training ViT_b_16 on MNIST Dataset




Model loaded: ViT_b_16
Device: cuda
Training samples: 60000, Test samples: 10000

Starting training...
Epoch [1/10], Batch [0/938], Loss: 2.4713
Epoch [1/10], Batch [100/938], Loss: 2.2972


KeyboardInterrupt: 

In [1]:


SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Dataset parameters
NUM_CLASSES = 10
BATCH_SIZE = 256
EPOCHS = 30
LR = 0.001

K_PROBES = 16  ## seems to match with input size
EPSILON = 0.1
LAMBDA_JET = 0.1      # Weight for alignment in Jet Loss
ETA_JET = 0.5         # Weight of Jet Loss in total loss [cite: 82]

# Device setup for CUDA
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

class MNISTClassifier(nn.Module):
    def __init__(self, num_classes, k_probes, epsilon):
        super().__init__()
        self.epsilon = epsilon
        self.k_probes = k_probes
        self.num_classes = num_classes

        v = torch.randn(k_probes, k_probes*k_probes)  # 2D probe directions
        v = v / torch.norm(v, dim=1, keepdim=True)
        self.register_buffer('probes', v)

        # ViT backbone
        vitmodel = torchvision.models.vit_b_16(pretrained=False)
        self.model_upper = nn.Sequential(
            vitmodel.conv_proj,
            vitmodel.encoder.dropout,
            vitmodel.encoder.dropout,
            vitmodel.encoder.layers[0],
            vitmodel.encoder.layers[1],
            vitmodel.encoder.layers[2],
            vitmodel.encoder.layers[3],
        )
        self.model_middle = nn.Sequential(
            vitmodel.encoder.layers[4],
            vitmodel.encoder.layers[5],
            vitmodel.encoder.layers[6],
            vitmodel.encoder.layers[7],
        )

        self.scalar_projection = nn.Sequential(
            vitmodel.encoder.ln,
            nn.Linear(in_features=768, out_features=1, bias=True),
        )

        self.model_lower = nn.Sequential(
            vitmodel.encoder.layers[8],
            vitmodel.encoder.layers[9],
            vitmodel.encoder.layers[10],
            vitmodel.encoder.layers[11],
            vitmodel.encoder.ln,
            nn.Linear(in_features=768, out_features=10, bias=True),
        )
        
    def compute_score(self, x):
        if not x.requires_grad:
            x.requires_grad_(True)

        # Compute energy as mean of representation
        h = self.model_upper(x)
        energy = self.scalar_projection(h)

        grads = torch.autograd.grad(
            outputs=energy.sum(),
            inputs=x,
            create_graph=True,
            retain_graph=True
        )[0]

        grads_flat = grads.view(grads.size(0), -1)  # [batch, 784]
        probes_flat = self.probes.view(self.k_probes, -1)  # [k_probes, 784]

        # Use only first 2 dimensions or project to 2D space
        scores = -torch.matmul(grads_flat, probes_flat.T)  # [batch, k_probes]
        return scores

    def compute_local_fisher(self, x):
        scores = self.compute_score(x)
        fisher_info = (scores ** 2).mean(dim=1, keepdim=True)
        return fisher_info

    def forward(self, x):
        with torch.set_grad_enabled(True):
            if not x.requires_grad:
                x.requires_grad_(True)

            # Compute derivatives with respect to input
            O0 = self.compute_local_fisher(x)

            # For image tensors, use small Gaussian perturbations instead of probe directions
            mean_v = self.probes.mean(dim=0, keepdim=True).view(1, 1, self.k_probes, self.k_probes)  # [1, 1, 16, 16]
            x_pos = x + mean_v
            x_neg = x - mean_v

            I_pos = self.compute_local_fisher(x_pos)
            I_neg = self.compute_local_fisher(x_neg)

            O1 = (I_pos - I_neg) / (2 * self.epsilon)
            O2 = (I_pos - 2 * O0 + I_neg) / (self.epsilon ** 2)

        # Forward through ViT preprocessing
        h = self.model_upper(x)  # [batch, 768, 14, 14]

        # Concatenate all features
        features = torch.cat([h, O0, O1, O2], dim=1)  # [batch, 131, 4, 4]

        y_hat = self.model_lower(features)

        return y_hat, O0 # Return O0 for viz/loss if needed


# Load MNIST dataset
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # Convert to 3-channel RGB
    transforms.Resize((224, 224)),  # Resize to ViT input size
    transforms.ToTensor(),
    transforms.Normalize((0.1307, 0.1307, 0.1307), (0.3081, 0.3081, 0.3081))  # MNIST mean and std for 3 channels
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"Train samples: {len(train_dataset)}, Test samples: {len(test_dataset)}")

# Create model
model = MNISTClassifier(NUM_CLASSES, K_PROBES, EPSILON)
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()

train_losses = []
test_accuracies = []

print("Starting MNIST Training...")
for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output, energy_score = model(data)
        loss = criterion(output, target)
        mean_v = model.probes.mean(dim=0, keepdim=True).view(1, 1, K_PROBES, K_PROBES)

        # (Uncomment below to enable Jet Loss - adds compute time)
        x_pos = data + EPSILON * mean_v
        x_neg = data - EPSILON * mean_v
        pred_pos, _ = model(x_pos)
        pred_neg, _ = model(x_neg)
        D_hat = (pred_pos - pred_neg) / (2 * EPSILON) # [cite: 75]
        score_proj = model.compute_score(data).mean(dim=1, keepdim=True)  # [batch, 1]
        loss_jet = ((D_hat + LAMBDA_JET * score_proj)**2).mean() #
        loss += ETA_JET * loss_jet


        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

        if batch_idx % 100 == 0:
            print(f"Epoch {epoch}, Batch {batch_idx}/{len(train_loader)}: Loss={loss.item():.4f}")

    avg_loss = epoch_loss / len(train_loader)
    train_losses.append(avg_loss)

    # Evaluate on test set
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    accuracy = 100 * correct / total
    test_accuracies.append(accuracy)
    print(f"Epoch {epoch} | Train Loss: {avg_loss:.4f} | Test Accuracy: {accuracy:.2f}%")

print(f"Final Test Accuracy: {test_accuracies[-1]:.2f}%")



# Plot training results
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.plot(train_losses, label='Training Loss')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('MNIST Training Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

ax2.plot(test_accuracies, label='Test Accuracy', color='green')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.set_title('MNIST Test Accuracy')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
# plt.savefig('mnist_training.png', dpi=150, bbox_inches='tight')

print("Training plots saved as 'mnist_training.png'")


plt.show()




VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_a