In [None]:
import torch
from torch import nn
from torch.nn import Module
import torchvision
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

# DATA LOADING

In [None]:
# Loading Data

# Per-channel normalization statistics shared by the train and test pipelines.
# NOTE(review): these std values are a widely copied set; some sources report the
# CIFAR-10 training-set per-pixel std as roughly (0.247, 0.243, 0.261) — confirm which is intended.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: light augmentation (random crop with padding + horizontal flip).
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Evaluation pipeline: no augmentation, same normalization.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# `train` must be a real bool. The string 'False' is truthy, which previously made
# `test_data` load the training split, so "test" metrics were measured on training images.
train_data = datasets.CIFAR10(root='.', train=True, transform=transform_train, download=True)
test_data = datasets.CIFAR10(root='.', train=False, transform=transform_test, download=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:04<00:00, 41445688.78it/s]


Extracting ./cifar-10-python.tar.gz to .
Files already downloaded and verified


In [None]:
# Creating Data Loader
BATCH_SIZE = 128
# `shuffle` must be a real bool: the string 'False' is truthy, so the test loader was shuffling.
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False)

# New ResNet code

In [None]:
class ShotcutConnection(Module):
  """Projection shortcut: a 1x1 convolution (with `stride`) followed by BatchNorm.

  ResidualBlock uses it when its main branch changes the channel count or the
  spatial size, so the block input can still be summed with the branch output.

  Args:
    in_channel: channels of the incoming tensor.
    out_channel: channels produced (must match the residual branch).
    stride: spatial stride of the 1x1 convolution.
  """
  # NOTE: "Shotcut" is a typo for "Shortcut"; the name is kept because ResidualBlock references it.

  def __init__(self, in_channel, out_channel, stride):
    super().__init__()
    self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride)
    self.bn = nn.BatchNorm2d(out_channel)

  def forward(self, x):
    return self.bn(self.conv(x))


In [None]:
class ResidualBlock(Module):
  """Basic residual block: (3x3 conv -> BN -> ReLU -> 3x3 conv -> BN) + shortcut, then ReLU.

  The first convolution carries `stride`. A ShotcutConnection projection is used
  whenever the stride or channel count changes; otherwise the input is added unchanged.

  Args:
    in_channel: channels of the incoming tensor.
    out_channel: channels produced by the block.
    stride: stride of the first 3x3 convolution (and of the projection, if any).
  """

  def __init__(self, in_channel, out_channel, stride):
    super().__init__()
    # Main branch. Modules are created in a fixed order so parameter
    # initialisation draws from the RNG in the same sequence.
    self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=1)
    self.bn1 = nn.BatchNorm2d(out_channel)
    self.relu = nn.ReLU(inplace=True)

    self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1)
    self.bn2 = nn.BatchNorm2d(out_channel)
    self.relu2 = nn.ReLU(inplace=True)

    shapes_match = stride == 1 and in_channel == out_channel
    self.shortcut = (nn.Identity() if shapes_match
                     else ShotcutConnection(in_channel, out_channel, stride))

  def forward(self, x):
    identity = self.shortcut(x)
    out = self.relu(self.bn1(self.conv1(x)))
    out = self.bn2(self.conv2(out))
    return self.relu2(out + identity)


In [None]:
class BottleneckBlock(Module):
  """ResNet bottleneck block: 1x1 reduce -> 3x3 (strided) -> 1x1 expand, plus shortcut.

  The block emits `out_channel * 4` channels. The input is projected with a
  strided 1x1 conv + BN whenever the stride or channel count changes; otherwise
  it is added unchanged.

  Args:
    in_channel: channels of the incoming tensor.
    out_channel: bottleneck width (the output has four times this many channels).
    stride: stride of the 3x3 convolution and of the projection shortcut.
  """

  def __init__(self, in_channel, out_channel, stride):
    super().__init__()
    wide = out_channel * 4  # channel count after the expanding 1x1 conv

    self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0)
    self.bn1 = nn.BatchNorm2d(out_channel)
    self.relu1 = nn.ReLU()

    self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=stride, padding=1)
    self.bn2 = nn.BatchNorm2d(out_channel)
    self.relu2 = nn.ReLU()

    self.conv3 = nn.Conv2d(out_channel, wide, kernel_size=1, stride=1, padding=0)
    self.bn3 = nn.BatchNorm2d(wide)
    self.relu3 = nn.ReLU()

    if stride == 1 and in_channel == wide:
      # Shapes already line up: the input is summed as-is.
      self.shortcut = nn.Identity()
    else:
      # Reshape the input so it can be summed with the main branch.
      self.shortcut = nn.Sequential(
          nn.Conv2d(in_channel, wide, kernel_size=1, stride=stride),
          nn.BatchNorm2d(wide),
      )

  def forward(self, x):
    residual = self.shortcut(x)
    y = self.relu1(self.bn1(self.conv1(x)))
    y = self.relu2(self.bn2(self.conv2(y)))
    y = self.bn3(self.conv3(y))
    return self.relu3(y + residual)


In [None]:
class ResNET(Module):
  """ResNet-50-style classifier: 7x7 stem, four BottleneckBlock stages, global pool, linear head.

  Stage layout is [3, 4, 6, 3] blocks with bottleneck widths 64/128/256/512, which
  output 256/512/1024/2048 channels. Stages 2-4 halve the spatial size in their first block.

  Args:
    num_classes: number of output logits (default 10, as for CIFAR-10).
    in_channels: channels of the input images (default 3).
  """

  def __init__(self, num_classes=10, in_channels=3):
    super().__init__()

    # Stem. The max-pool uses padding=1, as in the reference ResNet / torchvision,
    # so borders are pooled symmetrically.
    self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=7 // 2)
    self.bn = nn.BatchNorm2d(64)
    self.relu = nn.ReLU()
    self.mp = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    # NOTE(review): this ImageNet-style stem shrinks 32x32 CIFAR inputs to 8x8
    # before stage1; a 3x3 stride-1 stem without max-pooling is a common CIFAR variant.

    # Module names (stageN.<idx>) and construction order match the original
    # hand-written Sequentials, so state_dict keys and initialisation are unchanged.
    self.stage1 = self._make_stage(64, 64, num_blocks=3, stride=1)
    self.stage2 = self._make_stage(256, 128, num_blocks=4, stride=2)
    self.stage3 = self._make_stage(512, 256, num_blocks=6, stride=2)
    self.stage4 = self._make_stage(1024, 512, num_blocks=3, stride=2)

    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    self.fc = nn.Linear(2048, num_classes)

  @staticmethod
  def _make_stage(in_channel, width, num_blocks, stride):
    """Build one stage: the first block may downsample/widen; the rest keep the shape."""
    blocks = [BottleneckBlock(in_channel, width, stride=stride)]
    blocks.extend(BottleneckBlock(width * 4, width, stride=1) for _ in range(num_blocks - 1))
    return nn.Sequential(*blocks)

  def forward(self, x):
    x = self.relu(self.bn(self.conv1(x)))
    x = self.mp(x)
    x = self.stage1(x)
    x = self.stage2(x)
    x = self.stage3(x)
    x = self.stage4(x)

    x = self.avgpool(x)
    x = torch.flatten(x, 1)
    x = self.fc(x)

    return x


In [None]:
# Seed torch's RNG so weight initialisation, DataLoader shuffling and augmentation
# are reproducible across Restart & Run All (cuDNN kernels may still be nondeterministic).
SEED = 42
torch.manual_seed(SEED)
model = ResNET()

# Training functions

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Move the model first: the torch.optim docs recommend constructing the optimizer
# only after the parameters are on their final device.
model.to(device)

LEARNING_RATE = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

device

'cuda'

In [None]:
def train_loop(model, train_loader, optimizer, loss_fun, device):
    """Train `model` for one epoch over `train_loader`.

    Args:
        model: module to optimise; it is switched to training mode.
        train_loader: iterable of `(inputs, targets)` batches.
        optimizer: optimizer over `model`'s parameters.
        loss_fun: callable `(logits, targets) -> scalar tensor`. It is assumed to
            return the *batch mean* (e.g. `nn.CrossEntropyLoss()` with its default reduction).
        device: device each batch is moved to.

    Returns:
        (avg_loss, avg_acc): per-sample mean loss and accuracy over the epoch.

    Raises:
        ValueError: if `train_loader` yields no samples.
    """
    model.train()  # Set model to training mode
    loss_sum, correct_preds, total_samples = 0.0, 0, 0

    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        batch_size = y.size(0)

        # Forward pass and loss
        y_pred = model(x)
        loss = loss_fun(y_pred, y)

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by batch size so a short final batch does not
        # skew the epoch average; this matches how accuracy is aggregated.
        loss_sum += loss.item() * batch_size
        correct_preds += (y_pred.argmax(dim=1) == y).sum().item()
        total_samples += batch_size

    if total_samples == 0:
        raise ValueError('train_loader yielded no samples')

    return loss_sum / total_samples, correct_preds / total_samples

def test_loop(model, test_loader, loss_fun, device):
    model.eval()  # Set model to evaluation mode
    test_loss, correct_preds = 0, 0
    total_samples = 0

    with torch.no_grad():  # Disable gradient calculation for evaluation
        for x_test, y_test in test_loader:
            x_test, y_test = x_test.to(device), y_test.to(device)

            # Forward pass
            y_test_pred = model(x_test)

            # Compute loss
            loss = loss_fun(y_test_pred, y_test)
            test_loss += loss.item()

            # Compute accuracy
            correct_preds += (y_test_pred.argmax(dim=1) == y_test).sum().item()
            total_samples += y_test.size(0)

    avg_loss = test_loss / len(test_loader)
    avg_acc = correct_preds / total_samples  # Accuracy over all samples

    return avg_loss, avg_acc


In [None]:
from tqdm.auto import tqdm

def train(epochs, model, train_loader, test_loader, optimizer, loss_fun, device):
    """Alternate one training epoch and one evaluation pass `epochs` times.

    Args:
        epochs: number of passes over `train_loader`.
        model, train_loader, test_loader, optimizer, loss_fun, device:
            forwarded unchanged to `train_loop` / `test_loop`.

    Returns:
        dict with keys 'train_loss', 'train_acc', 'test_loss' and 'test_acc',
        each mapping to a list with one entry per epoch.
    """
    result = {key: [] for key in ('train_loss', 'train_acc', 'test_loss', 'test_acc')}

    progress = tqdm(range(epochs), desc='epochs')
    for epoch in progress:
        # Train and evaluate the model
        train_loss, train_acc = train_loop(model, train_loader, optimizer, loss_fun, device)
        test_loss, test_acc = test_loop(model, test_loader, loss_fun, device)

        # Store results
        result['train_loss'].append(train_loss)
        result['train_acc'].append(train_acc)
        result['test_loss'].append(test_loss)
        result['test_acc'].append(test_acc)

        # tqdm.write prints above the progress bar; a plain print() interleaves
        # with the bar and breaks it.
        progress.write(f'Epoch {epoch+1}/{epochs} | '
                       f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | '
                       f'Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.4f}')

    return result


In [None]:
NUM_EPOCHS = 2  # short smoke-test run; increase for real training
output = train(NUM_EPOCHS, model, train_loader, test_loader, optimizer, criterion, device)

 50%|█████     | 1/2 [01:05<01:05, 65.14s/it]

Epoch 1/2 | Train Loss: 2.0578 | Train Acc: 0.2496 | Test Loss: 1.7761 | Test Acc: 0.3496


100%|██████████| 2/2 [02:09<00:00, 64.74s/it]

Epoch 2/2 | Train Loss: 1.7376 | Train Acc: 0.3629 | Test Loss: 1.5870 | Test Acc: 0.4235



