In [343]:
import torch
import torch.nn as nn
from typing import Type

In [344]:
class BasicBlock(nn.Module):

    def __init__(self, in_channels, out_channels, stride):
        super(BasicBlock, self).__init__()

        self.downsample = None
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        if stride!=1:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride = stride, bias = False),
                nn.BatchNorm2d(out_channels)
            )
    def forward(self, x : torch.Tensor) -> torch.Tensor :
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


In [345]:
class ResNet(nn.Module):
    def __init__(self, in_channels : int, out_classes : int, block = BasicBlock):
        super(ResNet, self).__init__()

        self.conv1 = nn.Conv2d(in_channels = in_channels, out_channels = 64, kernel_size = 7, stride = 2, padding = 3, bias = False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace = True)
        self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 1, padding = 1)
        self.layer1 = nn.Sequential(
            block(64, 64, stride = 1),
            block(64, 64, stride = 1)
        )
        self.layer2 = nn.Sequential(
            block(64, 128, stride = 2),
            block(128, 128, stride = 1)
        )
        self.layer3 = nn.Sequential(
            block(128, 256, stride = 2),
            block(256, 256, stride = 1)
        )
        self.layer4 = nn.Sequential(
            block(256, 512, stride = 2),
            block(512, 512, stride = 1)
        )
        self.avgpool = nn.AdaptiveAvgPool2d(output_size = (1,1))
        self.fc = nn.Linear(in_features = 512, out_features = out_classes, bias = True)


    def forward(self, x : torch.Tensor) -> torch.Tensor :
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.maxpool(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = out.view(32, 512)
        out = self.fc(out)

        return out

In [346]:
from torchvision import datasets
from torchvision.transforms import ToTensor

train_data = datasets.FashionMNIST(root='data', train=True, download=True, transform=ToTensor(), target_transform=None)
test_data = datasets.FashionMNIST(root='data', train=False, download=True, transform=ToTensor(), target_transform=None)

In [347]:
from torch.utils.data import DataLoader

BATCH_SIZE = 32

train_dataloader = DataLoader(dataset = train_data, batch_size = BATCH_SIZE, shuffle = True)
test_dataloader = DataLoader(dataset = test_data, batch_size = BATCH_SIZE, shuffle = False)

In [348]:
img, label = next(iter(train_dataloader))
img.shape

torch.Size([32, 1, 28, 28])

In [349]:
def train_step(model : torch.nn.Module, data_loader : torch.utils.data.DataLoader, loss_fn : torch.nn.Module, optimizer : torch.optim.Optimizer, accuracy_fn, device : torch.device):
    model.to(device)
    train_loss = 0
    acc = 0
    model.train()
    for batch, (img,label) in enumerate(data_loader):

        img, label = img.to(device), label.to(device)
        train_preds = model(img)
        loss = loss_fn(train_preds, label)
        train_loss+=loss
        acc+= accuracy_fn(y_true=label, y_pred=train_preds.argmax(dim=1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 400 == 0:
            print(f"Looked at {batch * len(img)}/{len(train_dataloader.dataset)} samples.")

    train_loss /= len(data_loader)
    acc /= len(data_loader)

    print(f"Train loss : {train_loss:.4f}, train acc : {acc:2f}%")

In [350]:
# Setting up a function for testing loop
def test_step(model : torch.nn.Module, data_loader : torch.utils.data.DataLoader, loss_fn : torch.nn.Module, accuracy_fn, device : torch.device):
    model.to(device)
    test_loss, test_acc = 0, 0
    model.eval()
    with torch.inference_mode():
        for x,y in data_loader:
            x, y = x.to(device), y.to(device)
            test_preds = model(x)
            test_loss += loss_fn(test_preds, y)
            test_acc += accuracy_fn(y_true=y, y_pred=test_preds.argmax(dim=1))

        test_loss /= len(data_loader)
        test_acc /= len(data_loader)

        print(f"Test loss : {test_loss:.3f}, test acc : {test_acc:.2f}%")


In [351]:
def accuracy_fn(y_true, y_pred):
    correct = torch.eq(y_true, y_pred).sum().item()
    acc = (correct / len(y_pred)) * 100

    return acc

def print_train_time(start : float, end : float):
    total_time = end - start
    print(f"Train time  : { total_time:.3f} seconds")
    return total_time

In [352]:
model = ResNet(in_channels = 1, out_classes = 10)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params = model.parameters(), lr = 0.01)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [353]:
from timeit import default_timer as timer
from tqdm.auto import  tqdm

epochs = 5

train_time_start = timer()

for epoch in tqdm(range(epochs)):
    print(f"\nEpoch : {epoch}\n")
    train_step(model = model, data_loader = train_dataloader, loss_fn = loss_fn, accuracy_fn = accuracy_fn, device = device, optimizer = optimizer)
    test_step(model = model, data_loader = test_dataloader, loss_fn = loss_fn, accuracy_fn = accuracy_fn, device = device)

train_time_end = timer()

total_train_time = print_train_time(start = train_time_start, end = train_time_end)

  0%|          | 0/5 [00:00<?, ?it/s]


Epoch : 0

Looked at 0/60000 samples.
Looked at 12800/60000 samples.
Looked at 25600/60000 samples.


KeyboardInterrupt: 

In [357]:
import torch

model = ResNet(in_channels=1, out_classes=10)

batch_size = 32

input = torch.randn(batch_size, 1, 28, 28)

input.shape[0]

32