<a href="https://colab.research.google.com/github/ragingthunder511/da6401_assignment2/blob/main/cs24m020_dl_a1_partB.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# -------------------- Imports --------------------
import os
import subprocess

import torch
import wandb
from torch import nn, optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms, models

# -------------------- Weights & Biases Setup --------------------
# Security fix: the API key used to be hard-coded here. A key committed to a
# shared notebook is leaked and should be revoked. Provide it through the
# WANDB_API_KEY environment variable instead; when that variable is unset,
# key=None is the same as calling wandb.login() with its default, leaving
# wandb to its own credential lookup (e.g. an existing ~/.netrc entry).
wandb.login(key=os.environ.get("WANDB_API_KEY"))

# -------------------- Dataset Download & Extraction --------------------
DATA_URL = "https://storage.googleapis.com/wandb_datasets/nature_12K.zip"
DATA_ARCHIVE = "data.zip"
DATA_DIR = "inaturalist_12K"  # directory the archive extracts to; load_data reads from it

# Skip the ~3.6 GB download when the data is already extracted (e.g. when the
# cell is re-run); extracting again over existing files could make unzip stop
# and prompt about overwriting them.
if not os.path.isdir(DATA_DIR):
    # check=True raises on failure (os.system's exit status was ignored);
    # curl -f turns HTTP errors into a failure instead of saving the error
    # page as the archive.
    subprocess.run(["curl", "-fL", "-o", DATA_ARCHIVE, DATA_URL], check=True)
    subprocess.run(["unzip", "-qq", DATA_ARCHIVE], check=True)
    os.remove(DATA_ARCHIVE)

# -------------------- Configuration --------------------
# Tuned hyperparameters; wandb.init records them as the run config and the
# rest of the notebook reads them back through `cfg`.
TUNED_PARAMS = {
    'weight_decay': 0,
    'learning_rate': 1e-4,
    'dropout': 0.2,
    'activation': 'relu',
    'optimiser': 'rmsprop',
    'batch_norm': 'true',
    'batch_size': 32,
    'dense_layer': 256
}

wandb.init(project='cs24m020_dl_a2_partB', config=TUNED_PARAMS)
cfg = wandb.config

# -------------------- Data Utility --------------------
def build_transforms(train=True, image_size=224):
    """Return the torchvision preprocessing pipeline for one data split.

    Both pipelines resize every image to an ``image_size`` x ``image_size``
    square (the aspect ratio is not preserved), convert it to a tensor and
    normalise it with the standard ImageNet per-channel mean/std.

    Args:
        train: when True, add random augmentation (horizontal flip, colour
            jitter, random rotations) for training; when False, return the
            deterministic pipeline used for validation and testing.
        image_size: side length of the square output image. The default of
            224 is the size the pipeline previously hard-coded.

    Returns:
        A ``transforms.Compose`` mapping a PIL image to a normalised tensor.
    """
    norm = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    if not train:
        # A CenterCrop(224) used to follow Resize((224, 224)); it was a no-op
        # because the image is already exactly that size, so it is dropped.
        return transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            norm,
        ])
    return transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.RandomRotation(15),  # random rotation in [-15, 15] degrees
        # degrees=10 with no translate/scale/shear: effectively a second,
        # smaller random rotation rather than a general affine warp.
        transforms.RandomAffine(10),
        transforms.ToTensor(),
        norm,
    ])

def load_data(batch_size, train_fraction=0.8, seed=None, num_workers=0):
    """Build train/validation/test DataLoaders for the iNaturalist-12K data.

    ``inaturalist_12K/train`` is split into training and validation subsets;
    ``inaturalist_12K/val`` serves as the held-out test set.

    Args:
        batch_size: samples per batch for all three loaders.
        train_fraction: share of ``inaturalist_12K/train`` used for training;
            the remainder becomes the validation split (default 0.8).
        seed: optional integer seed for the train/val split. ``None`` draws
            from torch's global RNG, as the previous random_split call did.
        num_workers: DataLoader worker processes (0 = load in the main
            process, the previous behaviour).

    Returns:
        ``(train_loader, val_loader, test_loader)``; only the training loader
        shuffles.
    """
    train_root = 'inaturalist_12K/train'
    # Bug fix: random_split returns Subsets that share ONE dataset object, so
    # assigning val_set.dataset.transform also replaced the training
    # transform and silently disabled all augmentation. Instead, build two
    # ImageFolder views of the same directory, each with its own transform,
    # and split their (identically ordered) indices.
    train_view = datasets.ImageFolder(train_root, transform=build_transforms(True))
    val_view = datasets.ImageFolder(train_root, transform=build_transforms(False))

    n_total = len(train_view)
    train_size = int(train_fraction * n_total)
    generator = torch.Generator().manual_seed(seed) if seed is not None else None
    order = torch.randperm(n_total, generator=generator).tolist()
    train_set = torch.utils.data.Subset(train_view, order[:train_size])
    val_set = torch.utils.data.Subset(val_view, order[train_size:])

    test_set = datasets.ImageFolder('inaturalist_12K/val', transform=build_transforms(False))

    return (
        DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers),
        DataLoader(val_set, batch_size=batch_size, num_workers=num_workers),
        DataLoader(test_set, batch_size=batch_size, num_workers=num_workers),
    )

# -------------------- Custom Activation Function --------------------
class CustomActivation(nn.Module):
    """Element-wise rectifier with an optional leak for negative inputs.

    With the default ``negative_slope=0.0`` this is exactly ReLU,
    ``max(x, 0)``. That is what this class always computed, although an
    earlier comment described it as a LeakyReLU variant. Pass a positive
    ``negative_slope`` to get a genuine leaky rectifier: ``x`` for
    non-negative inputs and ``negative_slope * x`` otherwise.
    """

    def __init__(self, negative_slope=0.0):
        super().__init__()
        self.negative_slope = negative_slope

    def forward(self, x):
        if self.negative_slope == 0.0:
            # Original formulation, kept so the default is bit-identical.
            return torch.maximum(x, torch.zeros_like(x))
        return nn.functional.leaky_relu(x, self.negative_slope)

def get_activation_fn(name):
    """Instantiate the activation module selected by ``name``.

    ``'relu'`` -> ``nn.ReLU``, ``'tanh'`` -> ``nn.Tanh``, ``'custom'`` ->
    ``CustomActivation``. Any other value, including ``'sigmoid'`` or a
    misspelling, falls back to ``nn.Sigmoid``.
    """
    choices = (
        ('relu', nn.ReLU),
        ('tanh', nn.Tanh),
        # Wrapped in a lambda so CustomActivation is only looked up when chosen.
        ('custom', lambda: CustomActivation()),
    )
    for key, factory in choices:
        if name == key:
            return factory()
    return nn.Sigmoid()

# -------------------- Model Builder --------------------
def build_custom_head(input_dim, hidden_dim, dropout, bn_flag, act,
                      second_hidden_dim=512, num_classes=10):
    """Build the trainable classifier head that replaces the backbone's ``fc``.

    Layout (module order, and therefore state_dict keys, is unchanged):
    Dropout -> Linear(input_dim, hidden_dim) -> [BatchNorm1d(hidden_dim)] ->
    activation ``act`` -> Linear(hidden_dim, second_hidden_dim) -> ReLU ->
    Linear(second_hidden_dim, num_classes).

    Args:
        input_dim: size of the feature vector fed into the head.
        hidden_dim: width of the first hidden layer.
        dropout: dropout probability applied to the head's input.
        bn_flag: enables BatchNorm1d after the first Linear. Accepts a bool or
            the string ``'true'`` (any case). Previously only the exact string
            ``'true'`` worked, so a real ``True`` silently disabled BatchNorm.
        act: activation name for the first hidden layer, resolved by
            ``get_activation_fn``. The second hidden layer always uses ReLU.
        second_hidden_dim: width of the second hidden layer (default 512).
        num_classes: number of output logits (default 10).

    Returns:
        An ``nn.Sequential`` producing ``num_classes`` logits per sample.
    """
    layers = [nn.Dropout(p=dropout), nn.Linear(input_dim, hidden_dim)]
    if str(bn_flag).strip().lower() == 'true':
        layers.append(nn.BatchNorm1d(hidden_dim))
    layers += [
        get_activation_fn(act),
        nn.Linear(hidden_dim, second_hidden_dim),
        nn.ReLU(),
        nn.Linear(second_hidden_dim, num_classes),
    ]
    return nn.Sequential(*layers)

# -------------------- Fine-Tune Engine --------------------
class ModelRefiner:
    """Fine-tune an ImageNet-pretrained ResNet-50 by training only a new head.

    All backbone parameters are frozen; the replacement ``fc`` head from
    ``build_custom_head`` is the only part the optimiser updates. Per-epoch
    metrics are printed and logged to the active wandb run.

    Args:
        config: attribute-style config (e.g. ``wandb.config``) providing
            ``dense_layer``, ``dropout``, ``batch_norm``, ``activation``,
            ``optimiser``, ``learning_rate`` and ``weight_decay``.
    """

    def __init__(self, config):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self._setup_model(config)
        self.loss_fn = nn.CrossEntropyLoss()
        self.optimizer, self.scheduler = self._get_optimizer(config)

    def _setup_model(self, config):
        """Return pretrained ResNet-50 with a frozen backbone and a fresh head."""
        # ``pretrained=True`` is deprecated in newer torchvision in favour of a
        # weights enum; IMAGENET1K_V1 is the checkpoint ``pretrained=True``
        # selected. Older torchvision without the enum keeps the legacy flag.
        weights_enum = getattr(models, 'ResNet50_Weights', None)
        if weights_enum is not None:
            base = models.resnet50(weights=weights_enum.IMAGENET1K_V1)
        else:
            base = models.resnet50(pretrained=True)
        for param in base.parameters():
            param.requires_grad = False
        # Assigned after freezing, so the new head's parameters stay trainable.
        base.fc = build_custom_head(
            base.fc.in_features,
            config.dense_layer,
            config.dropout,
            config.batch_norm,
            config.activation,
        )
        return base.to(self.device)

    def _get_optimizer(self, config):
        """Create the optimiser over the head's parameters and a StepLR schedule.

        ``config.optimiser`` selects ``'adam'`` or ``'rmsprop'``; any other
        value falls back to plain SGD. The schedule multiplies the learning
        rate by 0.1 every 5 scheduler steps, and ``_process_epoch`` steps it
        once per training epoch.
        """
        head_params = self.model.fc.parameters()
        opt = config.optimiser
        if opt == 'adam':
            optimizer = optim.Adam(head_params, lr=config.learning_rate, weight_decay=config.weight_decay)
        elif opt == 'rmsprop':
            optimizer = optim.RMSprop(head_params, lr=config.learning_rate, weight_decay=config.weight_decay)
        else:
            optimizer = optim.SGD(head_params, lr=config.learning_rate, weight_decay=config.weight_decay)

        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
        return optimizer, scheduler

    def _process_epoch(self, loader, training=True):
        """Run one pass over ``loader``; return ``(mean_loss, accuracy_percent)``.

        When ``training`` is True the optimiser steps on every batch and the LR
        scheduler steps once at the end of the pass. The loss is averaged per
        sample (each batch weighted by its size), so a short final batch no
        longer counts as much as a full one.

        Raises:
            ValueError: if ``loader`` yields no samples.
        """
        # NOTE(review): train() also puts the frozen backbone's BatchNorm
        # layers in training mode, so their running statistics still adapt to
        # this dataset even though their weights are frozen - confirm intended.
        self.model.train(training)

        total_loss, correct, total = 0.0, 0, 0
        # Context-manager form restores the caller's grad mode afterwards; the
        # previous bare call left autograd globally disabled after evaluation.
        with torch.set_grad_enabled(training):
            for inputs, targets in loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                if training:
                    self.optimizer.zero_grad()

                outputs = self.model(inputs)
                loss = self.loss_fn(outputs, targets)

                if training:
                    loss.backward()
                    self.optimizer.step()

                batch_size = targets.size(0)
                # loss_fn uses CrossEntropyLoss's default mean reduction (see
                # __init__), so mean * batch size recovers the batch's summed loss.
                total_loss += loss.item() * batch_size
                correct += (outputs.argmax(1) == targets).sum().item()
                total += batch_size

        if total == 0:
            raise ValueError('loader yielded no samples')

        if training:
            self.scheduler.step()  # one LR-schedule step per training epoch

        return total_loss / total, 100 * correct / total

    def train_and_evaluate(self, train_loader, val_loader, epochs):
        """Train for ``epochs`` epochs, validating and logging after each one."""
        for ep in range(1, epochs + 1):
            tr_loss, tr_acc = self._process_epoch(train_loader, training=True)
            vl_loss, vl_acc = self._process_epoch(val_loader, training=False)

            print(f"Epoch {ep} | Train Loss: {tr_loss:.4f} | Acc: {tr_acc:.2f}% | Val Loss: {vl_loss:.4f} | Acc: {vl_acc:.2f}%")
            wandb.log({
                'epoch': ep,
                'train_loss': tr_loss,
                'train_accuracy': tr_acc,
                'val_loss': vl_loss,
                'val_accuracy': vl_acc,
                # LR in effect for the next epoch (the scheduler has already stepped).
                'lr': self.optimizer.param_groups[0]['lr']
            })

    def evaluate_on_test(self, test_loader):
        """Evaluate on ``test_loader``, print and log the result, return ``(loss, acc)``."""
        loss, acc = self._process_epoch(test_loader, training=False)
        print(f"\nTest Loss: {loss:.4f} | Test Accuracy: {acc:.2f}%")
        wandb.log({'test_loss': loss, 'test_accuracy': acc})
        return loss, acc

    def save_model(self, filename):
        """Save the model's state_dict to ``filename`` and register it with wandb."""
        torch.save(self.model.state_dict(), filename)
        wandb.save(filename)

# -------------------- Run Training --------------------
# Every step runs inside try so the wandb run is always closed. Previously an
# exception in any step skipped wandb.finish(), leaving the run open in the
# notebook kernel. A failure marks the run as failed (non-zero exit code)
# and is re-raised.
try:
    train_loader, val_loader, test_loader = load_data(cfg.batch_size)
    engine = ModelRefiner(cfg)
    engine.train_and_evaluate(train_loader, val_loader, epochs=10)
    engine.evaluate_on_test(test_loader)
    engine.save_model("refined_resnet50.pth")
except BaseException:
    # BaseException so a notebook KeyboardInterrupt is also recorded as a failure.
    wandb.finish(exit_code=1)
    raise
else:
    wandb.finish()


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mkarekargrishma1234[0m ([33mkarekargrishma1234-iit-madras-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 3639M  100 3639M    0     0   214M      0  0:00:16  0:00:16 --:--:--  214M
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 196MB/s]


Epoch 1 | Train Loss: 1.2083 | Acc: 64.88% | Val Loss: 0.8951 | Acc: 73.90%
Epoch 2 | Train Loss: 0.9319 | Acc: 71.21% | Val Loss: 0.8198 | Acc: 74.05%
Epoch 3 | Train Loss: 0.8595 | Acc: 72.20% | Val Loss: 0.7643 | Acc: 75.25%
Epoch 4 | Train Loss: 0.8030 | Acc: 74.02% | Val Loss: 0.7556 | Acc: 75.10%
Epoch 5 | Train Loss: 0.7797 | Acc: 74.32% | Val Loss: 0.7465 | Acc: 75.45%
Epoch 6 | Train Loss: 0.7643 | Acc: 74.77% | Val Loss: 0.7316 | Acc: 76.00%
Epoch 7 | Train Loss: 0.7396 | Acc: 75.40% | Val Loss: 0.7253 | Acc: 75.90%
Epoch 8 | Train Loss: 0.7267 | Acc: 75.72% | Val Loss: 0.7190 | Acc: 76.10%
Epoch 9 | Train Loss: 0.7080 | Acc: 76.15% | Val Loss: 0.7064 | Acc: 77.45%
Epoch 10 | Train Loss: 0.6965 | Acc: 76.48% | Val Loss: 0.7175 | Acc: 76.85%
