In [None]:
from google.colab import drive

# Make Google Drive visible inside the Colab VM; the dataset paths used later
# in the notebook live under this mount point.
drive_mount_point = '/content/drive'
drive.mount(drive_mount_point)

Mounted at /content/drive


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import wandb
import os

In [None]:
# =======================
# STEP 0: Initialize wandb
# =======================
# All tunable settings for the run live here so they are recorded with the
# W&B run and can be read back through `config` in later cells.
wandb.init(project="alexnet-flowers-v2", config={
    "epochs": 10,
    "batch_size": 16,
    "learning_rate": 0.001,
    "architecture": "alexnet",
    "pretrained": True,
    "input_size": 128,
    "seed": 42,
})

# Shortcut to config values
config = wandb.config

# Seed PyTorch's RNG so DataLoader shuffling and the initialisation of the new
# classifier layer are repeatable across Restart & Run All.
torch.manual_seed(config.seed)

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mvizuara-info[0m ([33mpritkudale-vizuara[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
# =======================
# STEP 1: Data Preparation
# =======================

# Every image is resized to a square of `config.input_size` pixels, converted
# to a tensor, and normalised channel-wise with the fixed mean/std below.
image_side = config.input_size
preprocessing = [
    transforms.Resize((image_side, image_side)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
transform = transforms.Compose(preprocessing)

# ImageFolder reads the samples from these two directories on the mounted Drive.
train_dir = "/content/drive/MyDrive/5flowersdata/flowers/train"
val_dir = "/content/drive/MyDrive/5flowersdata/flowers/val"

train_dataset = datasets.ImageFolder(root=train_dir, transform=transform)
val_dataset = datasets.ImageFolder(root=val_dir, transform=transform)

# Both loaders share the configured batch size; only the training loader shuffles.
loader_options = {"batch_size": config.batch_size}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, **loader_options)

In [None]:
# ===========================
# STEP 2: Load Pretrained Model
# ===========================

# NOTE(review): newer torchvision releases prefer the `weights=` argument over
# `pretrained=` — confirm against the installed version before switching.
alexnet = models.alexnet(pretrained=config.pretrained)

# Freeze every existing layer so the pretrained weights are not updated.
for param in alexnet.parameters():
    param.requires_grad = False

# Replace the final classifier layer. A newly constructed nn.Linear has
# trainable parameters by default, so it is the only part left to train.
# Its size comes from the layer being replaced and from the class folders that
# ImageFolder found, rather than from hard-coded 4096 / 5.
num_classes = len(train_dataset.classes)
head_in_features = alexnet.classifier[6].in_features
alexnet.classifier[6] = nn.Linear(head_in_features, num_classes)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
alexnet = alexnet.to(device)

# Watch the model's weights and gradients
wandb.watch(alexnet, log="all", log_freq=10)

Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth
100%|██████████| 233M/233M [00:02<00:00, 96.0MB/s]


In [None]:
# ===================
# STEP 3: Loss & Optimizer
# ===================

criterion = nn.CrossEntropyLoss()

# Give the optimizer only the parameters that are still trainable. The frozen
# layers never receive gradients, so passing them adds nothing and hides
# which weights are actually being fine-tuned.
trainable_params = [p for p in alexnet.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=config.learning_rate)

In [None]:
def _accuracy(model, loader, run_device):
    """Return the fraction of samples in ``loader`` that ``model`` labels correctly.

    Puts the model in eval mode and runs without gradient tracking.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(run_device), labels.to(run_device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total


def train_model(model, criterion, optimizer, train_loader, val_loader, epochs=10):
    """Train ``model`` for ``epochs`` epochs and log per-epoch metrics to W&B.

    Args:
        model: network to train; batches are moved to the device of its
            first parameter (CPU if it has no parameters).
        criterion: callable ``criterion(outputs, labels)`` returning the loss.
        optimizer: optimizer stepped once per training batch.
        train_loader: iterable of ``(images, labels)`` batches supporting ``len()``.
        val_loader: iterable of ``(images, labels)`` batches used for validation.
        epochs: number of passes over ``train_loader``.

    Each epoch logs ``epoch``, ``train_loss`` (mean of the per-batch losses),
    ``train_accuracy`` and ``val_accuracy`` in a single ``wandb.log`` call so
    all four share the same W&B step.
    """
    # Use the model's own device rather than a notebook-level global, so the
    # function does not depend on hidden state from another cell.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    for epoch in range(epochs):
        model.train()
        train_correct = 0
        train_total = 0
        running_loss = 0.0
        num_batches = 0

        print(f"\nEpoch {epoch + 1}/{epochs}")
        print("-" * 30)

        for i, (images, labels) in enumerate(train_loader):
            images, labels = images.to(run_device), labels.to(run_device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss
            num_batches += 1

            _, preds = torch.max(outputs, 1)
            batch_correct = (preds == labels).sum().item()
            train_correct += batch_correct
            train_total += labels.size(0)

            # Print every 10 batches
            if (i + 1) % 10 == 0:
                batch_acc = batch_correct / labels.size(0)
                print(f"[Batch {i+1}/{len(train_loader)}] Loss: {batch_loss:.4f}, Batch Acc: {batch_acc:.4f}")

        # Report the mean batch loss: the raw sum grows with the number of
        # batches and cannot be compared across dataset or batch sizes.
        epoch_loss = running_loss / num_batches
        train_acc = train_correct / train_total
        print(f"Epoch {epoch+1} Summary - Loss: {epoch_loss:.4f}, Train Accuracy: {train_acc:.4f}")

        # Validation
        val_acc = _accuracy(model, val_loader, run_device)

        # One log call per epoch keeps train and validation metrics on the same step.
        wandb.log({
            "epoch": epoch + 1,
            "train_loss": epoch_loss,
            "train_accuracy": train_acc,
            "val_accuracy": val_acc,
        })
        print(f"Validation Accuracy: {val_acc:.4f}")


In [None]:
# ===================
# Train the model
# ===================
# Run the fine-tuning loop for the number of epochs recorded in the W&B config.
train_model(
    model=alexnet,
    criterion=criterion,
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=config.epochs,
)


Epoch 1/10
------------------------------
[Batch 10/251] Loss: 0.6422, Batch Acc: 0.6875
[Batch 20/251] Loss: 1.2372, Batch Acc: 0.5000
[Batch 30/251] Loss: 1.2543, Batch Acc: 0.6250
[Batch 40/251] Loss: 0.8455, Batch Acc: 0.8125
[Batch 50/251] Loss: 0.5837, Batch Acc: 0.8125
[Batch 60/251] Loss: 0.6545, Batch Acc: 0.6875
[Batch 70/251] Loss: 0.7854, Batch Acc: 0.6875
[Batch 80/251] Loss: 0.9302, Batch Acc: 0.7500
[Batch 90/251] Loss: 0.3763, Batch Acc: 0.8125
[Batch 100/251] Loss: 0.3069, Batch Acc: 0.8750
[Batch 110/251] Loss: 0.2838, Batch Acc: 0.8750
[Batch 120/251] Loss: 0.9570, Batch Acc: 0.7500
[Batch 130/251] Loss: 0.6556, Batch Acc: 0.7500
[Batch 140/251] Loss: 1.0876, Batch Acc: 0.6875
[Batch 150/251] Loss: 1.0956, Batch Acc: 0.5000
[Batch 160/251] Loss: 0.5142, Batch Acc: 0.8125
[Batch 170/251] Loss: 1.0490, Batch Acc: 0.6875
[Batch 180/251] Loss: 1.3992, Batch Acc: 0.6250
[Batch 190/251] Loss: 1.0899, Batch Acc: 0.6250
[Batch 200/251] Loss: 0.7975, Batch Acc: 0.7500
[Batch

KeyboardInterrupt: 