# Importing Modules

In [3]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision.models.resnet import ResNet, Bottleneck
from torch.quantization import QuantStub, DeQuantStub
import torch.quantization
import torch.optim as optim
from torchinfo import summary
from tqdm import tqdm

# Downloading Data

In [2]:
# Training images get light augmentation (random 32x32 crop from a 4-pixel padded image,
# plus a random horizontal flip); test images are only converted to tensors.
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)
transform_test = transforms.Compose([transforms.ToTensor()])

trainset = torchvision.datasets.CIFAR100(
    root='./data',
    train=True,
    download=True,
    transform=transform_train,
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=4
)

testset = torchvision.datasets.CIFAR100(
    root='./data',
    train=False,
    download=True,
    transform=transform_test,
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=4
)


# Create ResNet Model

In [4]:
class QuantizableResNet50(ResNet):
    """ResNet-50 for 32x32 CIFAR-100 images, prepared for eager-mode quantization.

    Differences from torchvision's ImageNet ResNet-50:
    - The stem is a 3x3 / stride-1 conv and the max-pool is removed, so small
      images are not downsampled aggressively at the start.
    - The residual blocks are torchvision's ``QuantizableBottleneck``. Its skip
      connection goes through ``FloatFunctional.add_relu``, which has a quantized
      kernel. The stock ``Bottleneck`` does ``out += identity``, which cannot run
      on quantized tensors, so a converted model built from it fails at inference.
      ``QuantizableBottleneck`` adds no parameters or buffers, so state_dicts
      saved from the stock-block version still load.
    - ``QuantStub`` / ``DeQuantStub`` mark where tensors enter and leave the
      quantized domain.
    """

    def __init__(self):
        # Local import: the quantization-ready bottleneck lives in a torchvision submodule.
        from torchvision.models.quantization.resnet import QuantizableBottleneck

        super().__init__(block=QuantizableBottleneck, layers=[3, 4, 6, 3], num_classes=100)
        # Override the first conv layer for 32x32 inputs.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        # ResNet.__init__ only initialised the conv it created itself; give the
        # replacement the same Kaiming init as the rest of the network.
        nn.init.kaiming_normal_(self.conv1.weight, mode="fan_out", nonlinearity="relu")
        self.maxpool = nn.Identity()  # Remove maxpool for CIFAR

        # Quant/dequant stubs: become real (de)quantize ops after convert().
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x):
        """Quantize the input, run the ResNet trunk, and dequantize the logits."""
        x = self.quant(x)
        x = super()._forward_impl(x)
        x = self.dequant(x)
        return x

# Use GPU(s) if Available

In [7]:
# Pick the compute device, build the model, and replicate it across GPUs when several are visible.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model = QuantizableResNet50()
if torch.cuda.device_count() > 1:
    print(f'Using {torch.cuda.device_count()} GPUs')
    model = nn.DataParallel(model)
model = model.to(device)

Using 4 GPUs


# Training Parameters

In [8]:
# Float training setup: cross-entropy loss, SGD with momentum and L2 weight decay,
# and a step schedule that divides the LR by 10 every 30 scheduler steps.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    params=model.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=5e-4,
)
scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=30, gamma=0.1)

# Training and Testing Functions

In [12]:
def train(model, dataloader, criterion, optimizer, device, epoch, total_epochs):
    """Run one training epoch, showing running loss/accuracy on a progress bar.

    Args:
        model: network to optimise; switched to training mode here.
        dataloader: iterable of ``(inputs, targets)`` batches.
        criterion: loss called as ``criterion(outputs, targets)``.
        optimizer: optimizer stepped once per batch.
        device: device every batch is moved to.
        epoch: epoch number, used only for display.
        total_epochs: total number of epochs, used only for display.

    Returns:
        ``(avg_loss, acc)``: sample-weighted mean loss and accuracy (in percent)
        over the epoch, or ``(0.0, 0.0)`` if the dataloader yields no batches.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    # Defaults so the summary line works even when the loader is empty
    # (otherwise these names would never be bound).
    avg_loss = 0.0
    acc = 0.0

    epoch_bar = tqdm(dataloader, desc=f"🟢 Epoch [{epoch}/{total_epochs}] Training", leave=False)

    for inputs, targets in epoch_bar:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # criterion returns a batch mean; weight by batch size so the epoch
        # average is correct even when the last batch is smaller.
        running_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        acc = 100. * correct / total
        avg_loss = running_loss / total
        epoch_bar.set_postfix({'Loss': f'{avg_loss:.4f}', 'Acc': f'{acc:.2f}%'})

    print(f"✅ Epoch [{epoch}/{total_epochs}] Done | Train Loss: {avg_loss:.4f} | Train Acc: {acc:.2f}%")
    return avg_loss, acc

def test(model, dataloader, criterion, device):
    model.eval()
    correct = 0
    total = 0
    running_loss = 0.0

    with torch.no_grad():
        test_bar = tqdm(dataloader, desc="🔵 Evaluating", leave=False)
        for inputs, targets in test_bar:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            running_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            acc = 100. * correct / total
            avg_loss = running_loss / total
            test_bar.set_postfix({'Loss': f'{avg_loss:.4f}', 'Acc': f'{acc:.2f}%'})

    print(f"🧪 Test Loss: {avg_loss:.4f} | Test Acc: {acc:.2f}%")

# Training
- Training for up to 91 epochs gives roughly 99% training accuracy.
- Accuracy reaches about 80% after 33 epochs.

In [None]:
# Float (non-quantized) training; the model is evaluated once afterwards in the Testing section.
total_epochs = 90
for epoch in range(1, total_epochs + 1):
    train(
        model, trainloader, criterion, optimizer, device,
        epoch=epoch, total_epochs=total_epochs,
    )
    scheduler.step()  # advance the StepLR schedule once per epoch

# Testing
- Note that the test set contains images the model never saw during training, hence the lower accuracy compared with training.

In [15]:
test(model=model, dataloader=testloader, criterion=criterion, device=device)

                                                                                                                                                                                                                                                                                          

🧪 Test Loss: 1.0303 | Test Acc: 76.56%




# Saving Unquantized Model

In [16]:
# Save the float weights. Unwrap nn.DataParallel first: its state_dict prefixes every key
# with "module.", and such a checkpoint would not load into a plain QuantizableResNet50().
torch.save(
    (model.module if isinstance(model, nn.DataParallel) else model).state_dict(),
    'vanilla_resnet50_cifar100.pth',
)

# Quantizing the Model
- Here we run a short quantization-aware fine-tuning pass so the model can adapt to the quantized version of its weights and activations.

In [None]:
# Insert QAT fake-quant/observer modules. Use the `device` chosen earlier rather than a
# hard-coded "cuda", so this cell also runs on CPU-only machines.
model.to(device)
model.train()
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(model, inplace=True)

In [27]:
# Fresh optimizer/scheduler for QAT fine-tuning. Start at 1e-3, the LR the float run used
# in its last stage (0.1 decayed twice by StepLR), rather than 0.1: restarting at the
# from-scratch LR would largely undo the already-converged float weights.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=5e-4)
# NOTE: with step_size=30 this schedule does not decay in a run shorter than 30 epochs.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

In [None]:
# QAT fine-tuning: same loop as the float run, now training through the fake-quant
# modules inserted by prepare_qat.
total_epochs = 5
for epoch in range(1, total_epochs + 1):
    train(
        model, trainloader, criterion, optimizer, device,
        epoch=epoch, total_epochs=total_epochs,
    )
    scheduler.step()  # advance the StepLR schedule once per epoch

In [34]:
# Freeze observers/BN by switching to eval mode, move everything to CPU, then swap the
# QAT modules for real quantized ones in a copy (the QAT model itself is left untouched).
model.eval().cpu()
quantized_model = torch.quantization.convert(model, inplace=False)

In [None]:
# convert() keeps an nn.DataParallel wrapper if there was one; unwrap it only when present.
# On single-GPU/CPU runs there is no wrapper, and an unconditional `.module` would raise AttributeError.
if isinstance(quantized_model, nn.DataParallel):
    quantized_model = quantized_model.module
# Keep the converted (fbgemm) model entirely on CPU; the original note says leaving it elsewhere raised a RuntimeError.
quantized_model.to('cpu')
quantized_model.eval()

In [39]:
print(quantized_model.__class__)  # sanity check: should be the bare model class, not DataParallel

<class '__main__.QuantizableResNet50'>


In [None]:
# Evaluate the converted model on CPU with the same `test` helper used for the float model.
# This reports loss/accuracy that are directly comparable; the previous loop computed
# outputs and discarded them.
test(quantized_model, testloader, criterion, "cpu")

In [None]:
# Persist the quantized weights (scales/zero-points included). Reloading requires building
# and converting a model with the same structure first, then calling load_state_dict.
torch.save(obj=quantized_model.state_dict(), f='quantized_resnet50_cifar100.pth')

# Some Notes
- TODO: Need to rework the quantization of the model soon.