In [1]:

import torch, torchvision, transformers
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms


  from .autonotebook import tqdm as notebook_tqdm


In [None]:


SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Dataset parameters
NUM_CLASSES = 10
BATCH_SIZE = 128
EPOCHS = 30
LR = 0.001

K_PROBES = 4
EPSILON = 0.1

# Device setup for CUDA
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean and std
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"Train samples: {len(train_dataset)}, Test samples: {len(test_dataset)}")

class MNISTClassifier(nn.Module):
    def __init__(self, num_classes, k_probes, epsilon):
        super().__init__()
        self.epsilon = epsilon
        self.k_probes = k_probes
        self.num_classes = num_classes

        v = torch.randn(k_probes, 2)  # 2D probe directions
        v = v / torch.norm(v, dim=1, keepdim=True)
        self.register_buffer('probes', v)

        # ViT backbone
        vitmodel = torchvision.models.vit_b_16(pretrained=False)
        self.conv_proj = vitmodel.conv_proj
        
        # Get the hidden dimension and sequence length
        self.hidden_dim = vitmodel.hidden_dim
        self.seq_length = (224 // 16) ** 2  # For 224x224 images with patch size 16
        
        # Class token and position embeddings
        self.class_token = vitmodel.class_token
        self.encoder_pos_embedding = vitmodel.encoder.pos_embedding
        self.encoder_dropout = vitmodel.encoder.dropout
        
        self.layer0 = vitmodel.encoder.layers[0]
        self.layer1 = vitmodel.encoder.layers[1]
        self.layer2 = vitmodel.encoder.layers[2]
        self.layer3 = vitmodel.encoder.layers[3]
        # self.layer4 = vitmodel.encoder.layers[4]
        # self.layer5 = vitmodel.encoder.layers[5]
        # self.layer6 = vitmodel.encoder.layers[6]
        # self.layer7 = vitmodel.encoder.layers[7]
        # self.layer8 = vitmodel.encoder.layers[8]
        # self.layer9 = vitmodel.encoder.layers[9]
        # self.layer10 = vitmodel.encoder.layers[10]
        # self.layer11 = vitmodel.encoder.layers[11]
        self.encoder_ln = vitmodel.encoder.ln
        
        # Classifier with Fisher features
        self.classifier = nn.Linear(in_features=768+3+3, out_features=10, bias=True)
        
        # # Calculate feature dimension: 12 layers * 1 fisher dim + 3*1 (O0,O1,O2) = 15
        # fisher_dim = 12  # 12 encoder layers
        # self.feature_dim = fisher_dim + 3  # Plus O0, O1, O2
        
        # # Classifier layer
        # self.classifier = nn.Linear(in_features=self.feature_dim, out_features=NUM_CLASSES, bias=True)

    def compute_score(self, x):
        if not x.requires_grad:
            x.requires_grad_(True)

        # Compute energy as mean of representation
        energy = x.mean()

        grads = torch.autograd.grad(
            outputs=energy,
            inputs=x,
            create_graph=True,
            retain_graph=True
        )[0]

        grads_flat = grads.view(grads.size(0), -1)

        # Use only first 2 dimensions or project to 2D space
        if grads_flat.size(1) >= 2:
            grads_2d = grads_flat[:, :2]
        else:
            grads_2d = grads_flat

        scores = -torch.matmul(grads_2d, self.probes.T)
        return scores

    def compute_local_fisher(self, x):
        scores = self.compute_score(x)
        fisher_info = (scores ** 2).mean(dim=1, keepdim=True)
        return fisher_info

    def forward(self, x):
        # x shape: [batch, 1, 28, 28]
        batch_size = x.size(0)

        # Convert 1-channel to 3-channel (RGB) by repeating
        if x.size(1) == 1:
            x = x.repeat(1, 3, 1, 1)  # [batch, 3, 28, 28]

        # Resize to ViT input size (224x224)
        x_resized = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)

        with torch.set_grad_enabled(True):
            if not x_resized.requires_grad:
                x_resized.requires_grad_(True)

            # Compute derivatives with respect to input
            O0 = self.compute_local_fisher(x_resized)

            # For image tensors, use small Gaussian perturbations instead of probe directions
            perturbation = torch.randn_like(x_resized) * self.epsilon
            x_pos = x_resized + perturbation
            x_neg = x_resized - perturbation

            I_pos = self.compute_local_fisher(x_pos)
            I_neg = self.compute_local_fisher(x_neg)

            O1 = (I_pos - I_neg) / (2 * self.epsilon)
            O2 = (I_pos - 2 * O0 + I_neg) / (self.epsilon ** 2)

        # Forward through ViT preprocessing
        # Conv projection
        h = self.conv_proj(x_resized)  # [batch, 768, 14, 14]
        
        # Reshape from [batch, 768, 14, 14] to [batch, 196, 768]
        h = h.flatten(2).transpose(1, 2)  # [batch, 196, 768]
        
        # Add class token
        batch_class_token = self.class_token.expand(batch_size, -1, -1)  # [batch, 1, 768]
        h = torch.cat([batch_class_token, h], dim=1)  # [batch, 197, 768]
        
        # Add position embeddings
        h = h + self.encoder_pos_embedding
        
        # Apply dropout
        h = self.encoder_dropout(h)

        # Apply layers and compute Fisher at each stage
        h = self.layer0(h)
        with torch.set_grad_enabled(True):
            if not h.requires_grad:
                h.requires_grad_(True)
            fisher_layer0 = self.compute_local_fisher(h)

        h = self.layer1(h)
        with torch.set_grad_enabled(True):
            if not h.requires_grad:
                h.requires_grad_(True)
            fisher_layer1 = self.compute_local_fisher(h)

        h = self.layer2(h)
        with torch.set_grad_enabled(True):
            if not h.requires_grad:
                h.requires_grad_(True)
            fisher_layer2 = self.compute_local_fisher(h)

        h = self.layer3(h)
        with torch.set_grad_enabled(True):
            if not h.requires_grad:
                h.requires_grad_(True)
        #     fisher_layer3 = self.compute_local_fisher(h)

        # h = self.layer4(h)
        # with torch.set_grad_enabled(True):
        #     if not h.requires_grad:
        #         h.requires_grad_(True)
        #     fisher_layer4 = self.compute_local_fisher(h)

        # h = self.layer5(h)
        # with torch.set_grad_enabled(True):
        #     if not h.requires_grad:
        #         h.requires_grad_(True)
        #     fisher_layer5 = self.compute_local_fisher(h)

        # h = self.layer6(h)
        # with torch.set_grad_enabled(True):
        #     if not h.requires_grad:
        #         h.requires_grad_(True)
        #     fisher_layer6 = self.compute_local_fisher(h)

        # h = self.layer7(h)
        # with torch.set_grad_enabled(True):
        #     if not h.requires_grad:
        #         h.requires_grad_(True)
        #     fisher_layer7 = self.compute_local_fisher(h)

        # h = self.layer8(h)
        # with torch.set_grad_enabled(True):
        #     if not h.requires_grad:
        #         h.requires_grad_(True)
        #     fisher_layer8 = self.compute_local_fisher(h)

        # h = self.layer9(h)
        # with torch.set_grad_enabled(True):
        #     if not h.requires_grad:
        #         h.requires_grad_(True)
        #     fisher_layer9 = self.compute_local_fisher(h)

        # h = self.layer10(h)
        # with torch.set_grad_enabled(True):
        #     if not h.requires_grad:
        #         h.requires_grad_(True)
        #     fisher_layer10 = self.compute_local_fisher(h)

        # h = self.layer11(h)
        # with torch.set_grad_enabled(True):
        #     if not h.requires_grad:
        #         h.requires_grad_(True)
        #     # fisher_layer11 = self.compute_local_fisher(h)

        # Apply layer normalization
        h = self.encoder_ln(h)
        
        # Extract class token representation (first token)
        h_cls = h[:, 0]  # [batch, 768]

        # Concatenate all Fisher features (each [batch, 1])
        fisher_concat = torch.cat([
            fisher_layer0, fisher_layer1, fisher_layer2         ###, fisher_layer3, fisher_layer4, 
            # fisher_layer5, fisher_layer6, fisher_layer7, fisher_layer8, fisher_layer9, 
            # fisher_layer10, fisher_layer11
        ], dim=1)  # [batch, 12]

        # Combine all features: h_cls [batch, 768] + O0, O1, O2 [batch, 3] + fisher_concat [batch, 11] = [batch, 782]
        # Combine all features: h_cls [batch, 768] + O0, O1, O2 [batch, 3] + fisher_concat [batch, 3] = [batch, 774]
        features = torch.cat([h_cls, O0, O1, O2, fisher_concat], dim=1)
        
        # Classify
        logits = self.classifier(features)
        return logits

# Create model
model = MNISTClassifier(NUM_CLASSES, K_PROBES, EPSILON)
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()

train_losses = []
test_accuracies = []

print("Starting MNIST Training...")
for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        output = model(data)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

        if batch_idx % 100 == 0:
            print(f"Epoch {epoch}, Batch {batch_idx}/{len(train_loader)}: Loss={loss.item():.4f}")

    avg_loss = epoch_loss / len(train_loader)
    train_losses.append(avg_loss)

    # Evaluate on test set
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    accuracy = 100 * correct / total
    test_accuracies.append(accuracy)
    print(f"Epoch {epoch} | Train Loss: {avg_loss:.4f} | Test Accuracy: {accuracy:.2f}%")

print(f"Final Test Accuracy: {test_accuracies[-1]:.2f}%")

# Plot training results
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.plot(train_losses, label='Training Loss')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('MNIST Training Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

ax2.plot(test_accuracies, label='Test Accuracy', color='green')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.set_title('MNIST Test Accuracy')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
# plt.savefig('mnist_training.png', dpi=150, bbox_inches='tight')

print("Training plots saved as 'mnist_training.png'")


plt.show()


Using device: cuda:0
Train samples: 60000, Test samples: 10000


TypeError: VisionTransformer.__init__() got an unexpected keyword argument 'pretraine'

In [None]:
# Train ViT_b_16 model on MNIST
print("=" * 60)
print("Training ViT_b_16 on MNIST Dataset")
print("=" * 60)

# Create data loaders with batch size 256 for ViT training
# Training parameters
NUM_EPOCHS = 30
PRINT_FREQ = 100
VIT_BATCH_SIZE = 512

vit_train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=VIT_BATCH_SIZE, shuffle=True)
vit_test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=VIT_BATCH_SIZE, shuffle=False)

# Load ViT_b_16 model from torchvision
vit_model_origin = torchvision.models.vit_b_16(pretrained=False)

vit_model = nn.Sequential(
    nn.Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16)),
    nn.Dropout(p=0.0, inplace=False),
    vit_model_origin.encoder.layers[0],
    vit_model_origin.encoder.layers[1],
    vit_model_origin.encoder.layers[2],
    vit_model_origin.encoder.layers[3],
    vit_model_origin.encoder.ln,
    nn.Linear(in_features=768, out_features=10, bias=True)
)

# Modify the head for MNIST (10 classes)
# vit_model.heads = nn.Linear(vit_model.heads.head.in_features, 10)
vit_model = vit_model.to(device)

print(f"Model loaded: ViT_b_16")
print(f"Device: {device}")
print(f"Batch size: {VIT_BATCH_SIZE}")
print(f"Training samples: {len(train_dataset)}, Test samples: {len(test_dataset)}")

# Setup optimizer and loss
vit_optimizer = optim.Adam(vit_model.parameters(), lr=0.001)
vit_criterion = nn.CrossEntropyLoss()

# Lists to store metrics
vit_train_losses = []
vit_test_accuracies = []

# Training loop
print("\nStarting training...")
for epoch in range(NUM_EPOCHS):
    # Training phase
    vit_model.train()
    epoch_loss = 0
    num_batches = 0
    
    for batch_idx, (data, target) in enumerate(vit_train_loader):
        # Move data to device
        data, target = data.to(device), target.to(device)
        
        # Convert 1-channel to 3-channel (RGB) and resize to 224x224
        data = data.repeat(1, 3, 1, 1)  # [batch, 3, 28, 28]
        data = F.interpolate(data, size=(224, 224), mode='bilinear', align_corners=False)
        
        # Forward pass
        vit_optimizer.zero_grad()
        output = vit_model(data)
        loss = vit_criterion(output, target)
        
        # Backward pass
        loss.backward()
        vit_optimizer.step()
        
        epoch_loss += loss.item()
        num_batches += 1
        
        # Print progress
        if batch_idx % PRINT_FREQ == 0:
            print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Batch [{batch_idx}/{len(vit_train_loader)}], Loss: {loss.item():.4f}")
    
    # Calculate average loss
    avg_loss = epoch_loss / num_batches
    vit_train_losses.append(avg_loss)
    
    # Evaluation phase
    vit_model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for data, target in vit_test_loader:
            data, target = data.to(device), target.to(device)
            
            # Convert 1-channel to 3-channel and resize
            data = data.repeat(1, 3, 1, 1)
            data = F.interpolate(data, size=(224, 224), mode='bilinear', align_corners=False)
            
            # Forward pass
            output = vit_model(data)
            _, predicted = torch.max(output.data, 1)
            
            total += target.size(0)
            correct += (predicted == target).sum().item()
    
    # Calculate accuracy
    accuracy = 100 * correct / total
    vit_test_accuracies.append(accuracy)
    
    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] | Train Loss: {avg_loss:.4f} | Test Accuracy: {accuracy:.2f}%")
    print("-" * 60)

print("\nTraining completed!")
print(f"Final Test Accuracy: {vit_test_accuracies[-1]:.2f}%")
print(f"Best Test Accuracy: {max(vit_test_accuracies):.2f}%")

# Plot training results
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Plot training loss
ax1.plot(range(1, NUM_EPOCHS+1), vit_train_losses, marker='o', linewidth=2, markersize=6)
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('Loss', fontsize=12)
ax1.set_title('ViT_b_16 Training Loss on MNIST', fontsize=14, fontweight='bold')
ax1.grid(True, alpha=0.3)
ax1.set_xticks(range(1, NUM_EPOCHS+1))

# Plot test accuracy
ax2.plot(range(1, NUM_EPOCHS+1), vit_test_accuracies, marker='s', color='green', linewidth=2, markersize=6)
ax2.set_xlabel('Epoch', fontsize=12)
ax2.set_ylabel('Accuracy (%)', fontsize=12)
ax2.set_title('ViT_b_16 Test Accuracy on MNIST', fontsize=14, fontweight='bold')
ax2.grid(True, alpha=0.3)
ax2.set_xticks(range(1, NUM_EPOCHS+1))
ax2.set_ylim([0, 100])

plt.tight_layout()
plt.show()

print("\nTraining summary saved to plots.")

Training ViT_b_16 on MNIST Dataset




Model loaded: ViT_b_16
Device: cuda
Training samples: 60000, Test samples: 10000

Starting training...
Epoch [1/10], Batch [0/938], Loss: 2.4713
Epoch [1/10], Batch [100/938], Loss: 2.2972


KeyboardInterrupt: 

In [None]:

vit_model_origin = torchvision.models.vit_b_16(pretrained=False)
print(vit_model_origin)





VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_a

In [None]:

print(
    vit_model_origin.encoder.layers[0])

EncoderBlock(
  (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  (self_attention): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
  )
  (dropout): Dropout(p=0.0, inplace=False)
  (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
  (mlp): MLPBlock(
    (0): Linear(in_features=768, out_features=3072, bias=True)
    (1): GELU(approximate='none')
    (2): Dropout(p=0.0, inplace=False)
    (3): Linear(in_features=3072, out_features=768, bias=True)
    (4): Dropout(p=0.0, inplace=False)
  )
)
