In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [6]:
transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])

transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

train_dataset = torchvision.datasets.CIFAR10(
    root="../data",
    train=True,
    download=True,
    transform=transform_train
)

test_dataset = torchvision.datasets.CIFAR10(
    root="../data",
    train=False,
    download=True,
    transform=transform_test
)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=64, shuffle=True
)

test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=64, shuffle=False
)

100%|██████████| 170M/170M [00:03<00:00, 49.4MB/s]


In [7]:
model = torchvision.models.resnet18(pretrained=True)

# Replace classifier
model.fc = nn.Linear(model.fc.in_features, 10)

# Freeze backbone
for name, param in model.named_parameters():
    if "fc" not in name:
        param.requires_grad = False

model = model.to(device)



Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 179MB/s]


In [8]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.fc.parameters(), lr=1e-3)

In [9]:
epochs = 5

for epoch in range(epochs):
    model.train()
    running_loss = 0.0

    for images, labels in tqdm(train_loader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch {epoch+1}/{epochs} - Loss: {running_loss:.4f}")

100%|██████████| 782/782 [01:58<00:00,  6.60it/s]


Epoch 1/5 - Loss: 656.6005


100%|██████████| 782/782 [01:57<00:00,  6.66it/s]


Epoch 2/5 - Loss: 490.0260


100%|██████████| 782/782 [01:57<00:00,  6.63it/s]


Epoch 3/5 - Loss: 467.9016


100%|██████████| 782/782 [01:58<00:00,  6.59it/s]


Epoch 4/5 - Loss: 456.1328


100%|██████████| 782/782 [01:58<00:00,  6.62it/s]

Epoch 5/5 - Loss: 451.1258





In [10]:
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

print("Test Accuracy:", correct / total)

Test Accuracy: 0.8048


In [12]:
# import os

# os.makedirs("/content/results", exist_ok=True)

# torch.save(
#     model.state_dict(),
#     "/content/results/resnet18_cifar10_fc_only.pth"
# )

# print("Model saved in Colab")

Model saved in Colab


In [13]:
# !ls /content/results

resnet18_cifar10_fc_only.pth


In [14]:
# from google.colab import files

# files.download("/content/results/resnet18_cifar10_fc_only.pth")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>