In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models

# ✅ Device: use CUDA when PyTorch can see a GPU, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using:", device)

# ✅ Transform (smaller image for speed)
# The backbone below loads torchvision's pretrained ImageNet weights, which were
# trained on images normalized with the ImageNet per-channel mean/std. Feeding
# raw [0, 1] tensors to the frozen backbone gives it inputs from a different
# distribution than it was trained on, so apply the same normalization here.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
transform = transforms.Compose([
    transforms.Resize(128),  # Smaller than 224 for faster training; an int resizes the shorter side
    transforms.ToTensor(),   # PIL image -> float tensor (C, H, W) scaled to [0, 1]
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# ✅ Quick data loaders (use subset for speed)
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Take only a small subset for the demo (e.g., 5k images instead of 50k).
# The seeded generator makes the chosen subsets, and therefore the reported
# accuracy, reproducible across runs. Change SUBSET_SEED to draw a different sample.
TRAIN_SUBSET_SIZE = 5000
TEST_SUBSET_SIZE = 1000
SUBSET_SEED = 0
split_generator = torch.Generator().manual_seed(SUBSET_SEED)
train_subset, _ = torch.utils.data.random_split(
    trainset,
    [TRAIN_SUBSET_SIZE, len(trainset) - TRAIN_SUBSET_SIZE],
    generator=split_generator,
)
test_subset, _ = torch.utils.data.random_split(
    testset,
    [TEST_SUBSET_SIZE, len(testset) - TEST_SUBSET_SIZE],
    generator=split_generator,
)

# The training order is reshuffled every epoch; evaluation order is fixed.
trainloader = torch.utils.data.DataLoader(train_subset, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(test_subset, batch_size=64, shuffle=False)

# ✅ Pretrained ResNet18 whose ImageNet head is swapped for a fresh 10-way
# linear layer (one output per CIFAR-10 class), sized from the old head's
# input width, then moved to the selected device.
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=10)
model = model.to(device)

# ✅ Freeze every parameter, then re-enable gradients for the new head only,
# so backpropagation updates nothing but model.fc.
model.requires_grad_(False)
model.fc.requires_grad_(True)

# ✅ Loss and optimizer: Adam is handed only the head's parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)

# ✅ Train quickly: only the new head (model.fc) is optimized.
# Keep the backbone in eval mode so it is genuinely frozen. In train mode its
# BatchNorm layers would still update their running mean/variance on every
# batch, even though their weights have requires_grad=False. They would also
# normalize with per-batch statistics, so the head would be trained on
# features that differ from the ones it sees at evaluation time.
NUM_EPOCHS = 2  # Just 2 epochs for demo speed
for epoch in range(NUM_EPOCHS):
    model.eval()      # freeze backbone normalization statistics
    model.fc.train()  # the head is the nn.Linear assigned during model setup
    total_loss = 0.0
    seen = 0
    for imgs, labels in trainloader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(imgs), labels)
        loss.backward()
        optimizer.step()
        # criterion returns the batch mean; weight it by batch size so a short
        # final batch does not skew the epoch average.
        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        seen += batch_size
    print(f"Epoch {epoch+1} | Loss: {total_loss/max(seen, 1):.4f}")

# ✅ Evaluate: top-1 accuracy of the model over the held-out test subset.
model.eval()  # switch every layer to inference behaviour
correct = 0
total = 0
with torch.no_grad():  # no autograd graph is needed when only predicting
    for imgs, labels in testloader:
        imgs = imgs.to(device)
        labels = labels.to(device)
        logits = model(imgs)
        preds = logits.argmax(dim=1)  # index of the highest-scoring class
        correct += preds.eq(labels).sum().item()
        total += labels.shape[0]

print(f"🎯 Accuracy: {100 * correct / total:.2f}%")
