# Imports and Libraries

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as T
import timm
from tqdm import tqdm
import pandas as pd

# Device and Data Configuration

In [None]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Per-channel (R, G, B) mean/std handed to T.Normalize by the transforms below.
mean = (0.4914, 0.4822, 0.4465)
std = (0.2470, 0.2435, 0.2616)

# Images are resized to img_size (the model name below ends in "_224").
img_size, batch_size = 224, 32

# Transforms

In [None]:
def train_transform(img_size):
    """Build the training-time augmentation pipeline.

    Resizes to ``img_size``, applies a random horizontal flip and a random
    ``img_size`` crop taken from a 4-pixel-padded image, then converts to a
    tensor and normalizes with the module-level ``mean``/``std``.
    """
    steps = [
        T.Resize(img_size),
        T.RandomHorizontalFlip(),
        T.RandomCrop(img_size, padding=4),
        T.ToTensor(),
        T.Normalize(mean, std),
    ]
    return T.Compose(steps)
def test_transform(img_size):
    """Build the deterministic evaluation pipeline: resize, tensorize, normalize."""
    resize = T.Resize(img_size)
    to_tensor = T.ToTensor()
    normalize = T.Normalize(mean, std)
    return T.Compose([resize, to_tensor, normalize])


In [None]:
# CIFAR-10 train/test splits stored under ./data (download=True fetches them if missing).
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform(img_size))
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform(img_size))

# Pinned host memory only speeds up host-to-GPU copies; skip it on CPU-only runs.
use_pin_memory = device.type == 'cuda'
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=use_pin_memory)
# Use the shared batch_size instead of a separately hard-coded value.
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=use_pin_memory)

In [None]:
# Swin-Tiny with pretrained weights; its classifier head is replaced by a fresh
# 10-way linear layer (one logit per CIFAR-10 class).
model = timm.create_model('swin_tiny_patch4_window7_224', pretrained=True)
model.head = nn.Linear(model.head.in_features, 10)
# NOTE(review): depending on the timm version, the original head may also pool
# spatially; a bare Linear can then leave extra dims on the output, which is why
# the train/eval loops average them away. Keep this construction identical to the
# robustness cell so the saved state_dict keys still match.
model.to(device)  # nn.Module.to moves parameters in place and returns the module

epochs = 10
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(params=model.parameters(), lr=5e-4, weight_decay=1e-4)
# Multiply the learning rate by 0.3 every 8 scheduler steps (stepped once per epoch).
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.3)

# Training and Evaluation

In [None]:
def train_one_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        if outputs.dim() == 4:
            outputs = outputs.mean(dim=(1, 2))
        elif outputs.dim() == 3:
            outputs = outputs.mean(1)  
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        running_loss += loss.item() * images.size(0)
    avg_loss = running_loss / total
    avg_acc = 100.0 * correct / total
    return avg_loss, avg_acc

def evaluate(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            if outputs.dim() == 4:
                outputs = outputs.mean(dim=(1, 2))
            elif outputs.dim() == 3:
                outputs = outputs.mean(1)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return 100.0 * correct / total

In [None]:
# Main loop: one training pass plus a test-set evaluation per epoch.
for epoch in range(epochs):
    train_loss, train_acc = train_one_epoch(model, trainloader, criterion, optimizer, device)
    test_acc = evaluate(model, testloader, device)
    print(
        f"Epoch {epoch+1}/{epochs}: Train Loss: {train_loss:.4f}, "
        f"Train Acc: {train_acc:.2f}%, Test Acc: {test_acc:.2f}%"
    )
    # The scheduler is advanced once per epoch, after evaluation.
    scheduler.step()
    # Release cached allocator blocks between epochs (expected to be harmless when CUDA is unused).
    torch.cuda.empty_cache()

# Persist only the parameters (state_dict), not the module object itself.
torch.save(model.state_dict(), "Swin-Tiny_CIFAR10.pth")
print("Saved Swin-Tiny trained weights.")

# Robustness Check 

In [None]:
from torchvision import transforms as T

def make_transforms(
    img_size,
    mean=(0.4914, 0.4822, 0.4465),
    std=(0.2470, 0.2435, 0.2616),
    *,
    grayscale=True,
):
    """Build the named test-time pipelines used for the robustness sweep.

    Every pipeline resizes to ``img_size``, optionally converts to 3-channel
    grayscale, applies one perturbation (none for "Clean"), converts to a
    tensor and normalizes with ``mean``/``std``.

    Args:
        img_size: target size passed to ``T.Resize``.
        mean: per-channel normalization mean.
        std: per-channel normalization std. The defaults equal the statistics
            the model was trained with in this notebook; normalizing with
            different statistics shifts every input and confounds the
            robustness comparison.
        grayscale: when True (the previous behaviour), every pipeline --
            including "Clean" -- converts images to 3-channel grayscale first.
            Pass False for a colour baseline matching the training inputs.

    Returns:
        dict mapping condition name to a ``T.Compose`` pipeline.
    """
    # PIL-image stage shared by all conditions.
    prefix = [T.Resize(img_size)]
    if grayscale:
        prefix.append(T.Grayscale(num_output_channels=3))
    normalize = T.Normalize(mean, std)

    def pipeline(pil_ops=(), tensor_ops=()):
        # pil_ops run on the PIL image before ToTensor; tensor_ops run on the
        # tensor produced by ToTensor, before normalization.
        return T.Compose([*prefix, *pil_ops, T.ToTensor(), *tensor_ops, normalize])

    return {
        "Clean": pipeline(),
        "Horizontal Flip": pipeline([T.RandomHorizontalFlip(p=1.0)]),
        "Rotation": pipeline([T.RandomRotation(30)]),
        "Blur": pipeline([T.GaussianBlur(3)]),
        "Brightness": pipeline([T.ColorJitter(brightness=0.5)]),
        "Gaussian Noise": pipeline(
            tensor_ops=[T.Lambda(lambda x: x + 0.15 * torch.randn_like(x))]
        ),
    }

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
img_size = 224

# Rebuild the architecture exactly as in training (bare Linear head) so the saved
# state_dict keys line up, then load the fine-tuned weights.
model = timm.create_model('swin_tiny_patch4_window7_224', pretrained=False)
model.head = torch.nn.Linear(model.head.in_features, 10)
# weights_only=True restricts unpickling to tensors and plain containers -- all a
# state_dict needs -- instead of allowing arbitrary pickled objects.
model.load_state_dict(torch.load(
    '/kaggle/input/swin-tiny-cifar-10/pytorch/default/1/Swin-Tiny_CIFAR10.pth',
    map_location=device,
    weights_only=True,
))
model = model.to(device)
model.eval()


def evaluate(model, dataloader):
    """Return the top-1 accuracy (percent) of ``model`` on ``dataloader``.

    Batches are moved to the module-level ``device``. Defined once here (it was
    previously redefined on every loop iteration); note that it rebinds the
    global ``evaluate`` with this two-argument signature, as before.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # Average away any extra spatial/token dims left by the bare Linear head.
            if outputs.dim() == 4:
                outputs = outputs.mean(dim=(1, 2))
            elif outputs.dim() == 3:
                outputs = outputs.mean(1)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    return 100.0 * correct / total


curr_transforms = make_transforms(img_size)
# Load the test split once and swap its transform per condition, rather than
# re-instantiating (and re-verifying) the dataset on every iteration.
dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True)
results = {}
for name, transform in curr_transforms.items():
    dataset.transform = transform
    testloader = torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=False, num_workers=2)
    acc = evaluate(model, testloader)
    results[name] = acc
    print(f"{name}: {acc:.2f}%")

df = pd.DataFrame([results], index=['Swin-Tiny'])
display(df.T.style.background_gradient(cmap='Blues'))