In [None]:
pip install snntorch timm

In [None]:
# Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import time
import snntorch as snn
import timm

# Pick the compute device: CUDA when a GPU is visible to PyTorch, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# -----------------------------
# Hyperparameters
# -----------------------------
num_classes = 10      # CIFAR-10 has ten target classes
num_epochs = 5        # full passes over the training set
batch_size = 8        # samples per mini-batch for both loaders
learning_rate = 1e-4  # deliberately small: the backbone is being fine-tuned
weight_decay = 1e-4   # L2-style regularisation handed to the optimizer
momentum = 0.9        # only relevant if SGD is used instead of Adam

# -----------------------------
# Load CIFAR-10 Dataset
# -----------------------------
# Training pipeline: random crop rescaled to 224x224 (the ViT input size),
# random horizontal flip, the CIFAR-10 AutoAugment policy, then conversion
# to a tensor normalised with mean 0.5 / std 0.5 on every channel.
train_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=224),
        transforms.RandomHorizontalFlip(),
        transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Evaluation pipeline: deterministic resize to 224 plus the same
# tensor conversion and normalisation, with no augmentation.
test_transform = transforms.Compose(
    [
        transforms.Resize(size=224),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Both splits are stored under ./data and downloaded on first use.
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    transform=train_transform,
    download=True,
)
test_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    transform=test_transform,
    download=True,
)

# Only the training loader shuffles; both use two worker processes and
# pinned host memory.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)


# -----------------------------
# Load Pretrained ViT Model
# -----------------------------
# torchvision alternative, kept for reference:
# model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)

# Build the pretrained ViT-Tiny backbone. The classifier width comes from the
# `num_classes` hyperparameter rather than a repeated literal, so changing
# the dataset only requires editing one value.
model = timm.create_model('vit_tiny_patch16_224', pretrained=True, num_classes=num_classes)
print(model)

# Adapter that makes snn.Leaky usable as a single-output activation layer
class SpikingLeaky(nn.Module):
    """Wrap ``snn.Leaky`` so that calling the module yields only spikes.

    The wrapped neuron is invoked with the input alone and its result is
    unpacked as a pair; the second element (the membrane potential) is
    discarded so this module can stand in for an ordinary activation.

    Args:
        beta: Value forwarded to ``snn.Leaky(beta=...)``. Defaults to 0.9.
    """

    def __init__(self, beta=0.9):
        super().__init__()
        # Keep the attribute name `leaky`: it is part of the module tree.
        self.leaky = snn.Leaky(beta=beta)

    def forward(self, x):
        """Run the wrapped neuron on ``x`` and return its spike output."""
        spikes, _membrane = self.leaky(x)
        return spikes

# Function to replace activations with SpikingLeaky
def replace_activations(module, beta=0.9):
    """Recursively swap every ``nn.ReLU`` / ``nn.GELU`` for ``SpikingLeaky``.

    The replacement happens in place on ``module`` and all of its
    descendants; nothing is returned.

    Args:
        module: Root ``nn.Module`` whose activation layers are replaced.
        beta: Value passed to each new ``SpikingLeaky``. Defaults to 0.9,
            which is the value that used to be hard-coded here.
    """
    # Snapshot the children first so that re-assigning an attribute on
    # `module` never interacts with the iterator being consumed.
    for name, child in list(module.named_children()):
        if isinstance(child, (nn.ReLU, nn.GELU)):
            setattr(module, name, SpikingLeaky(beta=beta))
        else:
            # Not an activation itself: look for activations further down.
            replace_activations(child, beta=beta)

def reset_states(model):
    """Reset the internal state of every stateful sub-module of ``model``.

    For each module, a callable ``reset_mem`` is preferred; if the module
    has none, a callable ``reset`` is used instead. Modules exposing neither
    are left untouched.

    Checking ``reset`` alone is not enough for snnTorch neurons: their state
    is cleared through ``reset_mem()``, while their ``reset`` attribute holds
    a tensor and therefore fails the ``callable`` check.

    Args:
        model: Root ``nn.Module`` whose sub-modules are visited.
    """
    for module in model.modules():
        reset_mem = getattr(module, "reset_mem", None)
        if callable(reset_mem):
            reset_mem()
            continue

        # Fallback for modules that expose a plain `reset()` method.
        legacy_reset = getattr(module, "reset", None)
        if callable(legacy_reset):
            legacy_reset()

def detach_hidden_states(model):
    """Detach the ``mem`` attribute of every sub-module from autograd.

    Each module in ``model.modules()`` that carries a non-``None`` ``mem``
    attribute has it replaced by ``mem.detach()``, so the stored values are
    kept but no longer reference the graph that produced them.

    Args:
        model: Root ``nn.Module`` whose sub-modules are visited.
    """
    for layer in model.modules():
        # 'mem' is the attribute name assumed to hold the hidden state;
        # adjust it if your spiking neuron modules use a different one.
        state = getattr(layer, 'mem', None)
        if state is None:
            continue
        layer.mem = state.detach()
# Swap every ReLU/GELU inside the pretrained model for a SpikingLeaky
# wrapper. The model is modified in place.
replace_activations(model)

# Print the modified architecture so the replacements can be checked by eye.
print(model)



In [None]:

# Replace the classifier head with a fresh linear layer sized by the
# `num_classes` hyperparameter (instead of a repeated literal), then move
# the whole model to the selected device.
model.head = nn.Linear(model.head.in_features, num_classes)  # Modify last layer for CIFAR-10
model = model.to(device)

# Loss Function & Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# -----------------------------
# Fine-Tuning Function
# -----------------------------
# Per-epoch training accuracy (in percent), kept for later inspection.
training_accuracies = []

# Number of simulation time steps per sample presentation
num_steps = 25

def train():
    """Fine-tune the global ``model`` on ``train_loader`` for ``num_epochs``.

    Every batch is presented to the model ``num_steps`` times and the
    outputs are averaged before the cross-entropy loss is computed. Only the
    last presentation keeps its autograd graph; the earlier outputs are
    detached before being added to the running sum.

    After each epoch the mean batch loss and the training accuracy are
    printed, and the accuracy is appended to ``training_accuracies``.
    """
    model.train()
    for epoch in range(num_epochs):
        correct = 0
        total = 0
        running_loss = 0.0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()

            # Cut any stored hidden state loose from the previous batch's
            # graph so backward() does not reach into already-freed graphs.
            detach_hidden_states(model)

            out_sum = 0.0
            for t in range(num_steps):
                out = model(images)
                # Only the final time step contributes gradients.
                # NOTE(review): only the *outputs* are detached here. If the
                # spiking layers carry their membrane state between calls,
                # gradients from the last step may still reach earlier steps
                # through that state -- confirm this is the intended truncation.
                if t < num_steps - 1:
                    out_sum += out.detach()
                else:
                    out_sum += out  # Keep the final output's graph intact.
            outputs = out_sum / num_steps
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Accumulate the batch loss; without this the epoch summary
            # below would always report a loss of 0.
            running_loss += loss.item()

            # Compute training accuracy
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        training_accuracy = 100 * correct / total
        # Record the per-epoch accuracy in the module-level history list.
        training_accuracies.append(training_accuracy)
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}, Training Accuracy: {training_accuracy:.2f}%")

def test():
    """Evaluate the global ``model`` on ``test_loader`` and print accuracy.

    The model is switched to eval mode and run under ``torch.no_grad()``.
    Each batch is presented ``num_steps`` times; the averaged outputs decide
    the predicted class. Nothing is returned: the accuracy (in percent) is
    only printed.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            # Sum the outputs over all simulation steps for this batch.
            # NOTE(review): no hidden state is reset between batches here --
            # confirm that carrying state across test batches is intended.
            step_total = 0.0
            for _ in range(num_steps):
                step_total += model(images)

            outputs = step_total / num_steps
            predicted = torch.max(outputs.data, 1)[1]
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()

    accuracy = 100 * n_correct / n_seen
    print(f'Accuracy on the test set: {accuracy:.2f}%')


# -----------------------------
# Main Loop (Train & Test)
# -----------------------------
if __name__ == "__main__":
    # Durations are measured with time.perf_counter(): it is monotonic and
    # high-resolution, whereas time.time() follows the wall clock and can
    # jump if the system time is adjusted mid-run.

    # Measure training time
    start_time = time.perf_counter()
    train()
    end_time = time.perf_counter()
    training_time = end_time - start_time
    print(f"Training time: {training_time:.2f} seconds")

    # Measure inference time
    start_time = time.perf_counter()
    test()
    end_time = time.perf_counter()
    inference_time = end_time - start_time
    print(f"Inference time: {inference_time:.2f} seconds for the entire test set")

    # Calculate per-sample inference time (averaged over the whole test set)
    per_sample_inference_time = inference_time / len(test_dataset)
    print(f"Inference time per sample: {per_sample_inference_time:.6f} seconds")
