<a href="https://colab.research.google.com/github/rajat-malvi/deepLearning/blob/main/NN_Train_and_Test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Neural Network Training

In [1]:
# Required imports
import os
import time
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


In [2]:
# Datasets to train and evaluate on, in execution order
datasets_to_run = [
    "cifar10",
    "mnist",
    "fashion_mnist",
    "stl10",
    "svhn",
]

In [3]:
# Hyperparameters and runtime settings, defined inline for Colab execution
SEED = 42          # seed for PyTorch's global RNG so repeated runs start from the same state
EPOCHS = 8         # number of passes over each training split
BATCH_SIZE = 128
LR = 0.01          # initial learning rate handed to the optimizer
# Never request more DataLoader workers than the host has CPU cores;
# os.cpu_count() may return None, in which case fall back to a single worker.
NUM_WORKERS = min(4, os.cpu_count() or 1)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed once, up front. NOTE(review): some GPU kernels may remain
# non-deterministic even with a fixed seed.
torch.manual_seed(SEED)


In [4]:
# data loader
def get_dataloader(name, train=True, batch_size=128, num_workers=4):
    """
    Build a DataLoader for one of the supported torchvision datasets.

    Args:
        name: dataset name, case-insensitive. Canonical names are
            'cifar10', 'mnist', 'fashion_mnist', 'stl10', 'svhn'; a few
            alternative spellings (including historical misspellings) are
            accepted as aliases.
        train: if truthy, load the training split and shuffle it; otherwise
            load the test split without shuffling.
        batch_size: batch size passed to the DataLoader.
        num_workers: number of DataLoader worker processes.

    Returns:
        (dataloader, num_classes, in_channels)

    Raises:
        ValueError: if `name` is not a known dataset or alias.
    """
    # Resolve every accepted spelling to one canonical key *before* anything
    # else, so the channel count / normalisation and the dataset class are
    # always chosen from the same key.
    aliases = {
        "cifar10": "cifar10",
        "mnist": "mnist",
        "fashion_mnist": "fashion_mnist",
        "fashion-mnist": "fashion_mnist",
        "fashionmnist": "fashion_mnist",
        "stl10": "stl10",
        "stl-10": "stl10",
        "slt-10": "stl10",   # misspellings kept for backward compatibility
        "slt10": "stl10",
        "svhn": "svhn",
        "svhm": "svhn",
        "svhmn": "svhn",
    }
    name = name.lower()
    if name not in aliases:
        raise ValueError(f"Unknown dataset: {name}")
    key = aliases[name]

    size = 32          # every dataset is resized to 32 (STL-10 images are 96x96)
    num_classes = 10   # all supported datasets have 10 classes

    # Per-channel normalisation: single-channel statistics for the grayscale
    # datasets, a generic 0.5/0.5 scaling for the RGB ones.
    if key in ("mnist", "fashion_mnist"):
        in_channels = 1
        mean, std = (0.1307,), (0.3081,)
    else:
        in_channels = 3
        mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)

    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    if key == "cifar10":
        ds = datasets.CIFAR10(root="./data", train=train, transform=transform, download=True)
    elif key == "mnist":
        ds = datasets.MNIST(root="./data", train=train, transform=transform, download=True)
    elif key == "fashion_mnist":
        ds = datasets.FashionMNIST(root="./data", train=train, transform=transform, download=True)
    else:
        # STL10 and SVHN select their split by name instead of a `train` flag
        split = "train" if train else "test"
        if key == "stl10":
            ds = datasets.STL10(root="./data", split=split, transform=transform, download=True)
        else:
            ds = datasets.SVHN(root="./data", split=split, transform=transform, download=True)

    loader = DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=bool(train),
        num_workers=num_workers,
        # Pinned host memory only helps host-to-GPU copies; skip it on CPU-only hosts
        pin_memory=torch.cuda.is_available(),
    )
    return loader, num_classes, in_channels

In [5]:
# Convolutional neural network
class GenericCNN(nn.Module):
    """Small three-stage CNN classifier.

    Each stage is conv(3x3, padding 1) -> batch norm -> ReLU -> pooling, with
    channel widths base_channels, 2*base_channels and 4*base_channels. The
    first two stages halve the spatial size with max pooling; the last one
    averages down to 1x1, after which a two-layer MLP head produces
    `num_classes` logits.
    """

    def __init__(self, in_channels=3, num_classes=10, base_channels=32):
        super().__init__()
        widths = [in_channels, base_channels, base_channels * 2, base_channels * 4]
        final_stage = len(widths) - 1

        # Layer names (conv1, bn1, relu1, pool1, ...) are the state-dict keys
        stages = OrderedDict()
        for idx in range(1, len(widths)):
            stages[f"conv{idx}"] = nn.Conv2d(widths[idx - 1], widths[idx], kernel_size=3, padding=1)
            stages[f"bn{idx}"] = nn.BatchNorm2d(widths[idx])
            stages[f"relu{idx}"] = nn.ReLU(inplace=True)
            if idx < final_stage:
                stages[f"pool{idx}"] = nn.MaxPool2d(2)  # halves height and width
            else:
                stages[f"pool{idx}"] = nn.AdaptiveAvgPool2d((1, 1))  # -> 1x1
        self.features = nn.Sequential(stages)

        # After the 1x1 pooling the flattened feature vector has widths[-1] entries
        head = [
            nn.Flatten(),
            nn.Linear(widths[-1], 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.25),
            nn.Linear(128, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return class logits for a batch of images `x`."""
        return self.classifier(self.features(x))

In [6]:
# Training
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over `dataloader`.

    Puts the model in training mode and, for every (images, labels) batch,
    performs zero_grad -> forward -> loss -> backward -> optimizer step.

    Returns:
        (epoch_loss, epoch_acc): the batch losses weighted by batch size and
        divided by the number of samples seen (a per-sample mean, assuming
        `criterion` returns a batch mean), and the accuracy in percent.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for batch_x, batch_y in dataloader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)

        optimizer.zero_grad()
        logits = model(batch_x)
        batch_loss = criterion(logits, batch_y)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by its size so a smaller last batch is handled correctly
        loss_sum += batch_loss.item() * batch_x.shape[0]
        top_class = logits.max(1)[1]  # index of the largest logit per sample
        n_seen += batch_y.shape[0]
        n_correct += (top_class == batch_y).sum().item()

    return loss_sum / n_seen, 100.0 * n_correct / n_seen

In [7]:
#  Evaluation loops
def evaluate(model, dataloader, criterion, device):
    """Score `model` on `dataloader` without updating it.

    Switches the model to eval mode and runs every batch with gradient
    tracking disabled.

    Returns:
        (loss, acc): the batch losses weighted by batch size and divided by
        the number of samples seen (a per-sample mean, assuming `criterion`
        returns a batch mean), and the accuracy in percent.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    with torch.no_grad():
        for batch_x, batch_y in dataloader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            logits = model(batch_x)
            batch_loss = criterion(logits, batch_y)

            # Weight the batch loss by its size so a smaller last batch is handled correctly
            loss_sum += batch_loss.item() * batch_x.shape[0]
            top_class = logits.max(1)[1]  # index of the largest logit per sample
            n_seen += batch_y.shape[0]
            n_correct += (top_class == batch_y).sum().item()

    return loss_sum / n_seen, 100.0 * n_correct / n_seen


In [8]:
def run_training_for(name):
    """Train a fresh GenericCNN on dataset `name` and return its best test accuracy.

    Uses the notebook-level settings BATCH_SIZE, NUM_WORKERS, LR, EPOCHS and
    DEVICE. The model is optimised with SGD (momentum 0.9, weight decay 5e-4)
    under a StepLR schedule (step_size=7, gamma=0.1) and scored on the test
    split after every epoch; one summary line is printed per epoch.

    Returns:
        The highest per-epoch test accuracy (percent) observed during the run.
        NOTE(review): this maximum is picked using the test split itself, so
        it is an optimistic figure rather than a held-out estimate.
    """
    rule = "-"*10
    print(rule, f"Dataset: {name}", rule)

    loader_opts = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
    train_loader, num_classes, in_channels = get_dataloader(name, train=True, **loader_opts)
    test_loader = get_dataloader(name, train=False, **loader_opts)[0]

    model = GenericCNN(in_channels=in_channels, num_classes=num_classes).to(DEVICE)
    optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
    criterion = nn.CrossEntropyLoss()

    best_test_acc = 0.0
    for epoch in range(1, EPOCHS+1):
        t0 = time.time()
        train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, DEVICE)
        test_loss, test_acc = evaluate(model, test_loader, criterion, DEVICE)
        scheduler.step()  # advance the LR schedule once per epoch
        t1 = time.time()  # the timed span covers training, evaluation and the scheduler step
        print(f"Epoch {epoch:02d}/{EPOCHS} | time: {t1-t0:.1f}s | train_loss: {train_loss:.4f} train_acc: {train_acc:.2f}% | test_loss: {test_loss:.4f} test_acc: {test_acc:.2f}%")
        best_test_acc = max(best_test_acc, test_acc)

    print(f"Best test accuracy for {name}: {best_test_acc:.2f}%")
    return best_test_acc


In [9]:
# Train on every configured dataset and collect the best test accuracy of each.
# A failure on one dataset is reported and stored as None so that the
# remaining datasets still run.
results = {}
start_all = time.time()
for ds_name in datasets_to_run:
    try:
        acc = run_training_for(ds_name)
    except Exception as e:
        print(f"Error while training {ds_name}: {e}")
        results[ds_name] = None
    else:
        results[ds_name] = acc

total_time = time.time() - start_all
print(f"\nAll done. Total time: {total_time/60:.2f} minutes.")

---------- Dataset: cifar10 ----------


100%|██████████| 170M/170M [00:03<00:00, 47.7MB/s]


Epoch 01/8 | time: 17.5s | train_loss: 1.6994 train_acc: 36.30% | test_loss: 1.4724 test_acc: 45.83%
Epoch 02/8 | time: 15.1s | train_loss: 1.3264 train_acc: 51.81% | test_loss: 1.4188 test_acc: 48.11%
Epoch 03/8 | time: 15.1s | train_loss: 1.1888 train_acc: 57.09% | test_loss: 1.3616 test_acc: 51.73%
Epoch 04/8 | time: 15.2s | train_loss: 1.0946 train_acc: 60.60% | test_loss: 1.0977 test_acc: 59.93%
Epoch 05/8 | time: 15.2s | train_loss: 1.0198 train_acc: 63.36% | test_loss: 1.0083 test_acc: 63.93%
Epoch 06/8 | time: 16.7s | train_loss: 0.9712 train_acc: 65.52% | test_loss: 1.0794 test_acc: 61.44%
Epoch 07/8 | time: 15.6s | train_loss: 0.9258 train_acc: 67.04% | test_loss: 1.1474 test_acc: 59.04%
Epoch 08/8 | time: 15.1s | train_loss: 0.8135 train_acc: 71.48% | test_loss: 0.8124 test_acc: 70.85%
Best test accuracy for cifar10: 70.85%
---------- Dataset: mnist ----------


100%|██████████| 9.91M/9.91M [00:00<00:00, 16.0MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 487kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.49MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.85MB/s]


Epoch 01/8 | time: 19.1s | train_loss: 0.9994 train_acc: 67.71% | test_loss: 0.2974 test_acc: 91.92%
Epoch 02/8 | time: 17.4s | train_loss: 0.1967 train_acc: 94.74% | test_loss: 0.7649 test_acc: 73.51%
Epoch 03/8 | time: 16.6s | train_loss: 0.1230 train_acc: 96.62% | test_loss: 0.1655 test_acc: 94.78%
Epoch 04/8 | time: 18.2s | train_loss: 0.0969 train_acc: 97.26% | test_loss: 0.0944 test_acc: 97.16%
Epoch 05/8 | time: 17.6s | train_loss: 0.0816 train_acc: 97.73% | test_loss: 0.0891 test_acc: 97.29%
Epoch 06/8 | time: 17.6s | train_loss: 0.0702 train_acc: 98.07% | test_loss: 0.3375 test_acc: 88.10%
Epoch 07/8 | time: 17.4s | train_loss: 0.0653 train_acc: 98.11% | test_loss: 0.0785 test_acc: 97.61%
Epoch 08/8 | time: 16.8s | train_loss: 0.0482 train_acc: 98.64% | test_loss: 0.0412 test_acc: 98.80%
Best test accuracy for mnist: 98.80%
---------- Dataset: fashion_mnist ----------


100%|██████████| 26.4M/26.4M [00:01<00:00, 13.3MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 211kB/s]
100%|██████████| 4.42M/4.42M [00:01<00:00, 3.89MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 12.5MB/s]


Epoch 01/8 | time: 16.8s | train_loss: 0.9903 train_acc: 64.80% | test_loss: 0.7425 test_acc: 72.54%
Epoch 02/8 | time: 17.1s | train_loss: 0.5348 train_acc: 80.42% | test_loss: 0.5692 test_acc: 79.13%
Epoch 03/8 | time: 17.3s | train_loss: 0.4301 train_acc: 84.69% | test_loss: 0.4928 test_acc: 82.08%
Epoch 04/8 | time: 16.9s | train_loss: 0.3834 train_acc: 86.47% | test_loss: 0.3567 test_acc: 87.38%
Epoch 05/8 | time: 16.9s | train_loss: 0.3449 train_acc: 87.77% | test_loss: 0.6237 test_acc: 77.50%
Epoch 06/8 | time: 17.6s | train_loss: 0.3277 train_acc: 88.43% | test_loss: 0.5515 test_acc: 81.73%
Epoch 07/8 | time: 16.7s | train_loss: 0.3089 train_acc: 89.16% | test_loss: 0.4082 test_acc: 84.74%
Epoch 08/8 | time: 17.0s | train_loss: 0.2668 train_acc: 90.67% | test_loss: 0.2785 test_acc: 90.21%
Best test accuracy for fashion_mnist: 90.21%
---------- Dataset: stl10 ----------


100%|██████████| 2.64G/2.64G [01:21<00:00, 32.3MB/s]


Epoch 01/8 | time: 7.0s | train_loss: 2.2102 train_acc: 20.24% | test_loss: 2.0604 test_acc: 25.84%
Epoch 02/8 | time: 5.8s | train_loss: 1.8717 train_acc: 28.94% | test_loss: 1.7722 test_acc: 31.96%
Epoch 03/8 | time: 6.8s | train_loss: 1.7002 train_acc: 32.78% | test_loss: 1.7693 test_acc: 31.48%
Epoch 04/8 | time: 5.7s | train_loss: 1.6204 train_acc: 35.12% | test_loss: 1.6089 test_acc: 36.10%
Epoch 05/8 | time: 7.0s | train_loss: 1.6005 train_acc: 35.22% | test_loss: 1.6150 test_acc: 36.45%
Epoch 06/8 | time: 5.8s | train_loss: 1.5745 train_acc: 37.40% | test_loss: 1.6320 test_acc: 36.56%
Epoch 07/8 | time: 6.8s | train_loss: 1.5726 train_acc: 37.94% | test_loss: 1.5300 test_acc: 40.01%
Epoch 08/8 | time: 5.7s | train_loss: 1.5047 train_acc: 40.12% | test_loss: 1.4971 test_acc: 41.74%
Best test accuracy for stl10: 41.74%
---------- Dataset: svhn ----------


100%|██████████| 182M/182M [00:15<00:00, 11.5MB/s]
100%|██████████| 64.3M/64.3M [00:02<00:00, 31.1MB/s]


Epoch 01/8 | time: 26.2s | train_loss: 1.9801 train_acc: 29.46% | test_loss: 2.0876 test_acc: 30.26%
Epoch 02/8 | time: 26.4s | train_loss: 1.4343 train_acc: 50.58% | test_loss: 1.2850 test_acc: 56.07%
Epoch 03/8 | time: 25.8s | train_loss: 1.0229 train_acc: 66.76% | test_loss: 0.9824 test_acc: 66.86%
Epoch 04/8 | time: 25.8s | train_loss: 0.7723 train_acc: 75.69% | test_loss: 1.6822 test_acc: 47.28%
Epoch 05/8 | time: 25.8s | train_loss: 0.6269 train_acc: 80.56% | test_loss: 0.7962 test_acc: 74.07%
Epoch 06/8 | time: 26.0s | train_loss: 0.5403 train_acc: 83.43% | test_loss: 0.6996 test_acc: 77.14%
Epoch 07/8 | time: 27.5s | train_loss: 0.4860 train_acc: 85.06% | test_loss: 0.6940 test_acc: 77.85%
Epoch 08/8 | time: 26.4s | train_loss: 0.3938 train_acc: 88.31% | test_loss: 0.3608 test_acc: 89.81%
Best test accuracy for svhn: 89.81%

All done. Total time: 13.98 minutes.


In [10]:
# Print the collected accuracies as a pipe-delimited table; datasets whose
# run failed (None) or produced a non-numeric string show a placeholder.
print("| Dataset | Test Accuracy (%) |")
for k, v in results.items():
    if v is None or isinstance(v, str):
        acc_str = "Error / n/a"
    else:
        acc_str = f"{v:.2f}"
    print(f"| {k} | {acc_str} |")


| Dataset | Test Accuracy (%) |
| cifar10 | 70.85 |
| mnist | 98.80 |
| fashion_mnist | 90.21 |
| stl10 | 41.74 |
| svhn | 89.81 |
