In [10]:
!pip install -q  torch torchvision timm tqdm

In [22]:
import torch
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from tqdm import tqdm

In [23]:
train_dir = '/kaggle/input/imagenetmini-1000/imagenet-mini/train'
val_dir   = '/kaggle/input/imagenetmini-1000/imagenet-mini/val'

In [24]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)


Device: cuda


In [25]:
transform_train = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ColorJitter(0.3,0.3,0.3,0.1),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406],
                [0.229, 0.224, 0.225]),
])
transform_val = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406],
                [0.229, 0.224, 0.225]),
])

In [26]:
trainset = torchvision.datasets.ImageFolder(train_dir, transform=transform_train)
valset   = torchvision.datasets.ImageFolder(val_dir, transform=transform_val)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
valloader   = DataLoader(valset, batch_size=64, shuffle=False, num_workers=2)


In [27]:
model = resnet18(weights="IMAGENET1K_V1")


for p in model.parameters():
    p.requires_grad = False


for name, m in model.named_modules():
    if name.startswith("layer3") or name.startswith("layer4"):
        for param in m.parameters():
            param.requires_grad = True
        if isinstance(m, torch.nn.BatchNorm2d):
            m.eval()  # keep running stats but allow affine params

        
num_classes = len(trainset.classes)
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
for p in model.fc.parameters():
    p.requires_grad = True
model = model.to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=3e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5)

scaler = torch.amp.GradScaler('cuda')

In [28]:
epochs = 10
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for imgs, labels in tqdm(trainloader, desc=f"Epoch {epoch+1}/{epochs}"):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        with torch.amp.autocast('cuda'):
            loss = criterion(model(imgs), labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()
    scheduler.step()
    print(f"Train loss: {total_loss/len(trainloader):.4f}")

    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for imgs, labels in valloader:
            imgs, labels = imgs.to(device), labels.to(device)
            preds = model(imgs)
            correct += (preds.argmax(1) == labels).sum().item()
            total += labels.size(0)
    acc = 100 * correct / total
    print(f"Validation accuracy: {acc:.2f}%")

Epoch 1/10: 100%|██████████| 543/543 [03:27<00:00,  2.62it/s]

Train loss: 5.0793





Validation accuracy: 22.28%


Epoch 2/10: 100%|██████████| 543/543 [03:23<00:00,  2.66it/s]

Train loss: 3.4723





Validation accuracy: 33.37%


Epoch 3/10: 100%|██████████| 543/543 [03:23<00:00,  2.67it/s]

Train loss: 2.8152





Validation accuracy: 40.73%


Epoch 4/10: 100%|██████████| 543/543 [03:25<00:00,  2.64it/s]

Train loss: 2.3486





Validation accuracy: 46.11%


Epoch 5/10: 100%|██████████| 543/543 [03:20<00:00,  2.71it/s]

Train loss: 2.0897





Validation accuracy: 48.84%


Epoch 6/10: 100%|██████████| 543/543 [03:19<00:00,  2.72it/s]

Train loss: 1.9963





Validation accuracy: 48.79%


Epoch 7/10: 100%|██████████| 543/543 [03:18<00:00,  2.73it/s]

Train loss: 2.0031





Validation accuracy: 49.38%


Epoch 8/10: 100%|██████████| 543/543 [03:20<00:00,  2.70it/s]

Train loss: 2.0710





Validation accuracy: 47.80%


Epoch 9/10: 100%|██████████| 543/543 [03:38<00:00,  2.49it/s]

Train loss: 2.1719





Validation accuracy: 45.32%


Epoch 10/10: 100%|██████████| 543/543 [03:37<00:00,  2.49it/s]

Train loss: 2.2112





Validation accuracy: 42.42%


In [29]:
torch.save(model.state_dict(), "/kaggle/working/resnet18_finetuned_imagenette.pth")
print("Model saved to resnet18_finetuned_imagenette.pth")

Model saved to resnet18_finetuned_imagenette.pth
