In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

In [3]:
import os
from collections import Counter
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch
import numpy as np

data_dir = '/content/drive/MyDrive/imagenette2-160'
train_dir = os.path.join(data_dir, 'train')

classes = [d for d in sorted(os.listdir(train_dir)) if os.path.isdir(os.path.join(train_dir, d))]
print(f"Found {len(classes)} classes:\n", classes)

train_ds_fs = datasets.ImageFolder(train_dir)
counts = Counter(train_ds_fs.targets)
sample_paths = [p for p, _ in train_ds_fs.samples[:200]]  # inspect first 200 files
sizes = Counter()
for p in sample_paths:
    try:
        with Image.open(p) as im:
            sizes[im.size] += 1
    except Exception as e:
        sizes[('err', p)] += 1


resize_to = (160, 160)
resize_tf = transforms.Compose([
    transforms.Resize(resize_to),
    transforms.ToTensor(),
])

train_ds_resized = datasets.ImageFolder(train_dir, transform=resize_tf)

loader = DataLoader(train_ds_resized, batch_size=64, shuffle=True, num_workers=0)

def estimate_mean_std(dataloader, max_batches=50):
    cnt = 0
    mean = torch.zeros(3)
    sq_mean = torch.zeros(3)
    for i, (imgs, _) in enumerate(dataloader):
        imgs = imgs.view(imgs.size(0), imgs.size(1), -1)  # (B, C, H*W)
        mean += imgs.mean(2).sum(0)
        sq_mean += (imgs ** 2).mean(2).sum(0)
        cnt += imgs.size(0)
        if (i + 1) >= max_batches:
            break
    mean = mean / cnt
    std = (sq_mean / cnt - mean ** 2).sqrt()
    return mean, std

est_mean, est_std = estimate_mean_std(loader, max_batches=50)

fig, axes = plt.subplots(2, 4, figsize=(12, 6))
axes = axes.flatten()
num_show = min(8, len(train_ds_resized))
indices = np.random.choice(len(train_ds_resized), size=num_show, replace=False)
for ax, idx in zip(axes, indices):
    img, label = train_ds_resized[int(idx)]
    npimg = img.permute(1, 2, 0).numpy()
    ax.imshow(np.clip(npimg, 0, 1))
    ax.set_title(train_ds_resized.classes[label])
    ax.axis('off')

for ax in axes[num_show:]:
    ax.axis('off')

plt.suptitle(f'Sample training images (resized to {resize_to[0]}x{resize_to[1]})')
plt.tight_layout()
plt.show()

Found 10 classes:
 ['n01440764', 'n02102040', 'n02979186', 'n03000684', 'n03028079', 'n03394916', 'n03417042', 'n03425413', 'n03445777', 'n03888257']


KeyboardInterrupt: 

In [4]:
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        # === OPTIMIZATION 1: Shared ReLU ===
        # Use single ReLU instance instead of creating multiple
        self.relu = nn.ReLU(inplace=True)

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1   = nn.BatchNorm2d(out_channels)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2   = nn.BatchNorm2d(out_channels)

        self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion,
                               kernel_size=1, bias=False)
        self.bn3   = nn.BatchNorm2d(out_channels * self.expansion)

        self.downsample = downsample

    def forward(self, x):
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet50(nn.Module):
    def __init__(self, num_classes=1000):
        super(ResNet50, self).__init__()

        # === OPTIMIZATION 2: Track inplanes properly ===
        self.inplanes = 64

        # Initial conv + bn + relu + maxpool
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1   = nn.BatchNorm2d(64)
        self.relu  = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Layers / "stages" using bottleneck blocks
        self.layer1 = self._make_layer(Bottleneck, 64,  3, stride=1)
        self.layer2 = self._make_layer(Bottleneck, 128, 4, stride=2)
        self.layer3 = self._make_layer(Bottleneck, 256, 6, stride=2)
        self.layer4 = self._make_layer(Bottleneck, 512, 3, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * Bottleneck.expansion, num_classes)

        # === OPTIMIZATION 3: Improved initialization ===
        self._initialize_weights()

    def _make_layer(self, block, out_channels, blocks, stride=1):
        """Create one stage of the network with `blocks` blocks."""
        downsample = None

        # If stride != 1 or channel mismatch, need downsample path
        if stride != 1 or self.inplanes != out_channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_channels * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, out_channels, stride, downsample))
        self.inplanes = out_channels * block.expansion

        for _ in range(1, blocks):
            layers.append(block(self.inplanes, out_channels))

        return nn.Sequential(*layers)

    def _initialize_weights(self):
        """Optimized weight initialization for faster convergence."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # === OPTIMIZATION 4: Better conv initialization ===
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

            elif isinstance(m, nn.BatchNorm2d):
                # === OPTIMIZATION 5: BN initialization ===
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.Linear):
                # === OPTIMIZATION 6: FC layer initialization ===
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

        # === OPTIMIZATION 7: Zero-initialize final BN in residual blocks ===
        # This makes residual branches start as identity functions
        for m in self.modules():
            if isinstance(m, Bottleneck):
                nn.init.constant_(m.bn3.weight, 0)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # More explicit than x.flatten(1)
        x = self.fc(x)

        return x

In [5]:
# Training and validation code for imagenette2-160 dataset (verbose logging)
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm
import math

# Paths
data_dir = '/content/drive/MyDrive/imagenette2-160'
train_dir = os.path.join(data_dir, 'train')
val_dir = os.path.join(data_dir, 'val')

# Data transforms
train_transforms = transforms.Compose([
    transforms.Resize((160, 160)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
val_transforms = transforms.Compose([
    transforms.Resize((160, 160)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Datasets and loaders
train_dataset = datasets.ImageFolder(train_dir, transform=train_transforms)
val_dataset = datasets.ImageFolder(val_dir, transform=val_transforms)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=2)

# Model, loss, optimizer
num_classes = len(train_dataset.classes)
model = ResNet50(num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss()

# Separate bn/bias params from others for different weight decay
bn_params = []
other_params = []
for name, param in model.named_parameters():
    if 'bn' in name or 'bias' in name:
        bn_params.append(param)
    else:
        other_params.append(param)

optimizer = torch.optim.SGD([
    {'params': other_params, 'weight_decay': 1e-4},
    {'params': bn_params, 'weight_decay': 0.0}
], lr=0.1, momentum=0.9)

# Scheduler: warmup + cosine annealing using LambdaLR
warmup_epochs = 5
epochs = 30

def lr_lambda(epoch):
    # returns multiplicative factor for base_lr
    if epoch < warmup_epochs:
        # Linear warmup from 0 -> 1 over warmup_epochs
        return float((epoch + 1) / warmup_epochs)
    else:
        # Cosine annealing for remaining epochs
        progress = float((epoch - warmup_epochs) / max(1, (epochs - warmup_epochs)))
        return 0.5 * (1.0 + math.cos(math.pi * progress))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


def train(model, loader, optimizer, criterion, device, epoch):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    # tqdm over loader; set dynamic postfix each batch
    with tqdm(enumerate(loader), total=len(loader), desc=f"EPOCH: {epoch}", leave=True) as t:
        last_loss = 0.0
        for batch_idx, (images, labels) in t:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # stats
            last_loss = loss.item()
            running_loss += last_loss * images.size(0)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            acc_pct = 100.0 * correct / total if total > 0 else 0.0

            # update progress bar postfix to mimic the format you showed
            t.set_postfix_str(f"Loss={last_loss:.6f} Batch_id={batch_idx} Accuracy={acc_pct:.2f}")

    epoch_loss = running_loss / total if total>0 else 0.0
    epoch_acc = 100.0 * correct / total if total>0 else 0.0
    return epoch_loss, epoch_acc


def test(model, loader, criterion, device):
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(loader, desc='Test', leave=True):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            test_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    avg_loss = test_loss / total if total>0 else 0.0
    acc_pct = 100.0 * correct / total if total>0 else 0.0
    print(f"\nTest set: Average loss: {avg_loss:.4f}, Accuracy: {correct}/{total} ({acc_pct:.2f}%)\n")
    return avg_loss, acc_pct

In [8]:
num_epochs = 10
for epoch in range(num_epochs):
    print(f"EPOCH: {epoch}")
    train_loss, train_acc = train(model, train_loader, optimizer, criterion, device, epoch)
    test_loss, test_acc = test(model, val_loader, criterion, device)
    # step scheduler
    scheduler.step()
    current_lr = scheduler.get_last_lr()[0]
    print(f"Epoch summary -> Train loss: {train_loss:.4f}, Train acc: {train_acc:.2f}%, Val loss: {test_loss:.4f}, Val acc: {test_acc:.2f}%, LR: {current_lr:.6f}\n")


EPOCH: 0


EPOCH: 0: 100%|██████████| 296/296 [00:53<00:00,  5.50it/s, Loss=1.021060 Batch_id=295 Accuracy=68.99]
Test: 100%|██████████| 123/123 [00:16<00:00,  7.45it/s]


Test set: Average loss: 1.2955, Accuracy: 2429/3925 (61.89%)

Epoch summary -> Train loss: 0.9261, Train acc: 68.99%, Val loss: 1.2955, Val acc: 61.89%, LR: 0.099606

EPOCH: 1



EPOCH: 1: 100%|██████████| 296/296 [00:53<00:00,  5.52it/s, Loss=0.719508 Batch_id=295 Accuracy=74.03]
Test: 100%|██████████| 123/123 [00:16<00:00,  7.49it/s]


Test set: Average loss: 1.0372, Accuracy: 2587/3925 (65.91%)

Epoch summary -> Train loss: 0.8032, Train acc: 74.03%, Val loss: 1.0372, Val acc: 65.91%, LR: 0.098429

EPOCH: 2



EPOCH: 2: 100%|██████████| 296/296 [00:53<00:00,  5.55it/s, Loss=0.587814 Batch_id=295 Accuracy=76.08]
Test: 100%|██████████| 123/123 [00:16<00:00,  7.48it/s]


Test set: Average loss: 1.0980, Accuracy: 2696/3925 (68.69%)

Epoch summary -> Train loss: 0.7115, Train acc: 76.08%, Val loss: 1.0980, Val acc: 68.69%, LR: 0.096489

EPOCH: 3



EPOCH: 3: 100%|██████████| 296/296 [00:53<00:00,  5.51it/s, Loss=0.715083 Batch_id=295 Accuracy=79.26]
Test: 100%|██████████| 123/123 [00:15<00:00,  7.69it/s]


Test set: Average loss: 0.8572, Accuracy: 2846/3925 (72.51%)

Epoch summary -> Train loss: 0.6406, Train acc: 79.26%, Val loss: 0.8572, Val acc: 72.51%, LR: 0.093815

EPOCH: 4



EPOCH: 4: 100%|██████████| 296/296 [00:53<00:00,  5.53it/s, Loss=0.869792 Batch_id=295 Accuracy=80.75]
Test: 100%|██████████| 123/123 [00:16<00:00,  7.60it/s]


Test set: Average loss: 1.1554, Accuracy: 2685/3925 (68.41%)

Epoch summary -> Train loss: 0.5687, Train acc: 80.75%, Val loss: 1.1554, Val acc: 68.41%, LR: 0.090451

EPOCH: 5



EPOCH: 5: 100%|██████████| 296/296 [00:53<00:00,  5.50it/s, Loss=0.518012 Batch_id=295 Accuracy=83.84]
Test: 100%|██████████| 123/123 [00:16<00:00,  7.60it/s]


Test set: Average loss: 0.8279, Accuracy: 2945/3925 (75.03%)

Epoch summary -> Train loss: 0.4987, Train acc: 83.84%, Val loss: 0.8279, Val acc: 75.03%, LR: 0.086448

EPOCH: 6



EPOCH: 6: 100%|██████████| 296/296 [00:53<00:00,  5.50it/s, Loss=0.526895 Batch_id=295 Accuracy=86.31]
Test: 100%|██████████| 123/123 [00:16<00:00,  7.65it/s]


Test set: Average loss: 0.9932, Accuracy: 2882/3925 (73.43%)

Epoch summary -> Train loss: 0.4142, Train acc: 86.31%, Val loss: 0.9932, Val acc: 73.43%, LR: 0.081871

EPOCH: 7



EPOCH: 7: 100%|██████████| 296/296 [00:53<00:00,  5.49it/s, Loss=0.420150 Batch_id=295 Accuracy=88.15]
Test: 100%|██████████| 123/123 [00:15<00:00,  7.80it/s]


Test set: Average loss: 1.0520, Accuracy: 2804/3925 (71.44%)

Epoch summary -> Train loss: 0.3542, Train acc: 88.15%, Val loss: 1.0520, Val acc: 71.44%, LR: 0.076791

EPOCH: 8



EPOCH: 8: 100%|██████████| 296/296 [00:53<00:00,  5.49it/s, Loss=0.160778 Batch_id=295 Accuracy=90.80]
Test: 100%|██████████| 123/123 [00:15<00:00,  7.75it/s]


Test set: Average loss: 0.9620, Accuracy: 2928/3925 (74.60%)

Epoch summary -> Train loss: 0.2802, Train acc: 90.80%, Val loss: 0.9620, Val acc: 74.60%, LR: 0.071289

EPOCH: 9



EPOCH: 9: 100%|██████████| 296/296 [00:53<00:00,  5.51it/s, Loss=0.151877 Batch_id=295 Accuracy=92.26]
Test: 100%|██████████| 123/123 [00:15<00:00,  7.77it/s]


Test set: Average loss: 1.0718, Accuracy: 2857/3925 (72.79%)

Epoch summary -> Train loss: 0.2318, Train acc: 92.26%, Val loss: 1.0718, Val acc: 72.79%, LR: 0.065451




