# AAML Final Project

## Model Definition

Reference from: https://github.com/mlcommons/tiny/blob/master/benchmark/experimental/training_torch/image_classification/utils/model.py

In [None]:
import torch
from torch import nn
from torch.nn import functional as F


class ResNetBlock(nn.Module):
    """Residual block made of two 3x3 convolutions plus a shortcut.

    Main path: conv3x3 -> BatchNorm -> ReLU -> conv3x3 -> BatchNorm.
    The shortcut output is added to the main path and the sum is passed
    through a final ReLU.

    Args:
        in_channels: Channel count of the input feature map.
        out_channels: Channel count produced by the block.
        stride: Stride of the first convolution and of the shortcut
            projection (when one is used).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
    ):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                padding=1,
                bias=True,
                stride=stride,
            ),
            nn.BatchNorm2d(num_features=out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=3,
                padding=1,
                bias=True,
            ),
            nn.BatchNorm2d(num_features=out_channels),
        )
        # An identity shortcut only matches the main path when neither the
        # channel count nor the spatial size changes. With equal channels but
        # stride != 1 the main path is down-sampled, so the shortcut has to be
        # a strided 1x1 projection as well or the addition in forward() fails.
        if in_channels == out_channels and stride == 1:
            self.residual = nn.Identity()
        else:
            self.residual = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=stride,
            )

    def forward(self, inputs):
        """Return ``relu(block(inputs) + residual(inputs))``."""
        x = self.block(inputs)
        y = self.residual(inputs)
        return F.relu(x + y)


class Resnet8v1EEMBC(nn.Module):
    """Small residual classifier: conv stem, three residual stages, linear head.

    The stem maps 3 input channels to 16. The three ``ResNetBlock`` stages
    go 16 -> 16 (stride 1), 16 -> 32 (stride 2) and 32 -> 64 (stride 2).
    The result is globally average-pooled, flattened and fed to a linear
    layer with 10 outputs.
    """

    def __init__(self):
        super().__init__()
        stem_conv = nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=True)
        stem_norm = nn.BatchNorm2d(16)
        self.stem = nn.Sequential(stem_conv, stem_norm, nn.ReLU(inplace=True))

        # Residual stages; the attribute names are part of the state_dict keys.
        self.first_stack = ResNetBlock(16, 16, stride=1)
        self.second_stack = ResNetBlock(16, 32, stride=2)
        self.third_stack = ResNetBlock(32, 64, stride=2)

        # Classification head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, 10)

    def forward(self, inputs):
        """Return the 10 class scores for a batch of images."""
        features = self.stem(inputs)
        for stage in (self.first_stack, self.second_stack, self.third_stack):
            features = stage(features)
        pooled = self.avgpool(features)
        return self.fc(torch.flatten(pooled, 1))

## Load Dataset

In [None]:
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, random_split

# Imported here because only this cell needs it: Subset lets one set of split
# indices be applied to a second, differently-transformed view of the data.
from torch.utils.data import Subset

# Training pipeline: random crop (with 4-pixel padding) and horizontal flip
# for augmentation, then tensor conversion and per-channel normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Validation pipeline: same normalisation, no augmentation.
transform_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Two views of the same CIFAR-10 training split, one per transform pipeline.
# The first call downloads the data if needed; the second reuses the files.
full_train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform_train)
full_val_dataset = CIFAR10(root='./data', train=True, download=False, transform=transform_val)

# Split the dataset into train and validation sets (80 % / 20 %)
train_size = int(0.8 * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size

train_dataset, val_split = random_split(full_train_dataset, [train_size, val_size])

# Point the validation indices at the un-augmented view. Both subsets returned
# by random_split wrap the *same* dataset object, so assigning
# `val_split.dataset.transform = transform_val` would also switch off the
# augmentation for the training subset.
val_dataset = Subset(full_val_dataset, val_split.indices)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)

print(f"Train set size: {len(train_dataset)}")
print(f"Validation set size: {len(val_dataset)}")

## Training

### Load Teacher model

Imported from: https://huggingface.co/edadaltocg/resnet18_cifar10

In [None]:
import timm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the architecture first and move it to `device` only after the layers
# below have been swapped in; layers created after a `.to(device)` call would
# otherwise be left on the CPU.
teacher_model = timm.create_model("resnet18", pretrained=False)

# override teacher_model: 3x3 stride-1 stem, no max-pool, 10-way classifier
teacher_model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
teacher_model.maxpool = nn.Identity()  # type: ignore
teacher_model.fc = nn.Linear(512, 10)

teacher_model = teacher_model.to(device)
# The teacher is only used for inference, so keep it in eval mode.
teacher_model.eval()

# Kept as the last expression so the notebook shows the key-matching result.
teacher_model.load_state_dict(
    torch.hub.load_state_dict_from_url(
        "https://huggingface.co/edadaltocg/resnet18_cifar10/resolve/main/pytorch_model.bin",
        map_location=device,
        file_name="resnet18_cifar10.pth",
    )
)

## Dynamic Temperature Knowledge Distillation (DTKD)
Reference from: https://arxiv.org/abs/2404.12711

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

def train(teacher_model, student_model, train_loader, val_loader, epochs=50, alpha=3.0, beta=1.0, gamma=1.0, reference_temperature=4, lr=0.001):
    """Train ``student_model`` by distilling ``teacher_model`` with a DTKD loss.

    Every step minimises ``alpha * dynamic-temperature KL + beta *
    fixed-temperature KL + gamma * cross-entropy``. After each epoch the
    student is evaluated on ``val_loader``; whenever the validation accuracy
    improves, its ``state_dict`` is saved to ``./output/best_model.pth``.

    Args:
        teacher_model: Teacher network; run in eval mode under ``no_grad``.
        student_model: Network to optimise (updated in place).
        train_loader: Sized iterable of ``(inputs, labels)`` training batches.
        val_loader: Sized iterable of ``(inputs, labels)`` validation batches.
        epochs: Number of epochs; also the length of the cosine LR schedule.
        alpha: Weight of the dynamic-temperature KL term.
        beta: Weight of the fixed-temperature KL term.
        gamma: Weight of the cross-entropy term.
        reference_temperature: Base temperature for both KL terms.
        lr: Initial learning rate for AdamW.

    Returns:
        The best validation accuracy (in percent) seen during training;
        ``0`` if no epoch produced an accuracy above zero.
    """
    # Define the DTKD loss function
    def dtkd_loss(student_logits, teacher_logits, labels, reference_temperature, alpha, beta, gamma):
        # Guards the temperature denominator against an exact zero
        eps = 1e-6

        # Per-sample maximum logit of each network
        teacher_max, _ = teacher_logits.max(dim=1, keepdim=True)
        student_max, _ = student_logits.max(dim=1, keepdim=True)

        # Split twice the reference temperature between teacher and student
        # in proportion to their maximum logits
        T_tea = (2 * teacher_max / (teacher_max + student_max + eps)) * reference_temperature
        T_stu = (2 * student_max / (teacher_max + student_max + eps)) * reference_temperature

        # KL divergence between the dynamically softened distributions
        teacher_soft_dynamic = F.softmax(teacher_logits / T_tea, dim=1)
        student_soft_dynamic = F.log_softmax(student_logits / T_stu, dim=1)
        # NOTE(review): the rescaling uses the batch means of the two
        # temperatures; confirm against the DTKD reference whether a
        # per-sample T_tea * T_stu weighting is intended instead.
        dtkd_kl_loss = F.kl_div(student_soft_dynamic, teacher_soft_dynamic, reduction='batchmean') * T_tea.mean() * T_stu.mean()

        # Traditional KD (fixed temperature)
        teacher_soft_fixed = F.softmax(teacher_logits / reference_temperature, dim=1)
        student_soft_fixed = F.log_softmax(student_logits / reference_temperature, dim=1)
        kl_loss = F.kl_div(student_soft_fixed, teacher_soft_fixed, reduction='batchmean') * (reference_temperature ** 2)

        # Cross-entropy loss against the ground-truth labels
        ce_loss = F.cross_entropy(student_logits, labels)

        # Combine the losses
        return alpha * dtkd_kl_loss + beta * kl_loss + gamma * ce_loss

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Move both models before the optimizer is created, so the optimizer is
    # built from parameters that already live on the training device.
    teacher_model.to(device).eval()
    student_model.to(device)

    optimizer = torch.optim.AdamW(
        student_model.parameters(),
        lr=lr,
        weight_decay=0.01
    )

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    save_dir = "./output"
    os.makedirs(save_dir, exist_ok=True)
    best_model_path = os.path.join(save_dir, "best_model.pth")

    best_acc = 0
    for epoch in range(epochs):
        # ---- training pass ----
        student_model.train()
        running_loss = 0.0
        correct_train = 0
        total_train = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            with torch.no_grad():
                teacher_logits = teacher_model(inputs)  # Teacher logits

            student_logits = student_model(inputs)  # Student logits
            loss = dtkd_loss(student_logits, teacher_logits, labels, reference_temperature, alpha, beta, gamma)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = student_logits.max(1)
            total_train += labels.size(0)
            correct_train += predicted.eq(labels).sum().item()

        train_loss = running_loss / len(train_loader)
        train_acc = 100.0 * correct_train / total_train

        # ---- validation pass (plain cross-entropy, no teacher) ----
        student_model.eval()
        val_loss = 0.0
        correct_val = 0
        total_val = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = student_model(inputs)
                loss = F.cross_entropy(outputs, labels)
                val_loss += loss.item()
                _, predicted = outputs.max(1)
                total_val += labels.size(0)
                correct_val += predicted.eq(labels).sum().item()

        val_loss /= len(val_loader)
        val_acc = 100.0 * correct_val / total_val

        print(f"Epoch {epoch + 1}/{epochs} ", end="")
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% ", end="")
        print(f"val Loss: {val_loss:.4f}, val Acc: {val_acc:.2f}% ")

        scheduler.step()

        # Keep only the checkpoint with the best validation accuracy
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(student_model.state_dict(), best_model_path)

    return best_acc

In [None]:
# Create a fresh student and distil the teacher into it for 200 epochs.
model = Resnet8v1EEMBC()
train(
    teacher_model=teacher_model,
    student_model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=200,
    lr=0.01,
)

## Pruning

Reference from: https://github.com/VainF/Torch-Pruning

In [None]:
!pip install torch-pruning --upgrade

In [None]:
import torch_pruning as tp
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The pruner and the op counter run forward passes on random data. Eval mode
# keeps those passes from updating the BatchNorm running statistics, and the
# explicit move makes sure model and example input are on the same device.
model.to(device).eval()

imp = tp.importance.BNScaleImportance()

# Leave the 10-way classifier untouched so the number of outputs is preserved.
ignored_layers = [
    m
    for m in model.modules()
    if isinstance(m, torch.nn.Linear) and m.out_features == 10
]

# One sample is enough for tracing, and it means the MACs reported below are
# for a single image rather than for a 128-image batch.
example_inputs = torch.randn(1, 3, 32, 32).to(device)
pruner = tp.pruner.MetaPruner(
    model,
    example_inputs,
    importance=imp,
    # Per-stage pruning ratios: deeper stages are pruned more aggressively.
    pruning_ratio_dict={
        model.first_stack: 0.2,
        model.second_stack: 0.3,
        model.third_stack: 0.4,
    },
    ignored_layers=ignored_layers,
    round_to=4,
)
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
pruner.step()
macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
print(f"MACs: {base_macs/1e9} G -> {macs/1e9} G, #Params: {base_nparams/1e6} M -> {nparams/1e6} M")

In [None]:
# Distil the teacher into the current `model` again, at a lower learning rate.
train(
    teacher_model=teacher_model,
    student_model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=200,
    lr=0.001,
)

In [None]:
def finetune(model, train_loader, val_loader, epochs=50, lr=0.001):
    """Fine-tune ``model`` with plain cross-entropy (no teacher).

    Uses AdamW with a cosine learning-rate schedule. After each epoch the
    model is evaluated on ``val_loader``; whenever the validation accuracy
    improves, its ``state_dict`` is saved to ``./output/best_model.pth``.

    Args:
        model: Network to optimise (updated in place).
        train_loader: Sized iterable of ``(inputs, labels)`` training batches.
        val_loader: Sized iterable of ``(inputs, labels)`` validation batches.
        epochs: Number of epochs; also the length of the cosine LR schedule.
        lr: Initial learning rate for AdamW.

    Returns:
        The best validation accuracy (in percent) seen during fine-tuning;
        ``0`` if no epoch produced an accuracy above zero.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Move the model before the optimizer is created, so the optimizer is
    # built from parameters that already live on the training device.
    model.to(device)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=0.01  # Regularization term for AdamW
    )

    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # Directory to save training outputs
    save_dir = "./output"
    os.makedirs(save_dir, exist_ok=True)
    best_model_path = os.path.join(save_dir, "best_model.pth")

    best_acc = 0
    for epoch in range(epochs):
        # ---- training pass ----
        model.train()
        running_loss = 0.0
        correct_train = 0
        total_train = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            logits = model(inputs)
            loss = F.cross_entropy(logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = logits.max(1)
            total_train += labels.size(0)
            correct_train += predicted.eq(labels).sum().item()

        train_loss = running_loss / len(train_loader)
        train_acc = 100.0 * correct_train / total_train

        # ---- validation pass ----
        model.eval()
        val_loss = 0.0
        correct_val = 0
        total_val = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = F.cross_entropy(outputs, labels)
                val_loss += loss.item()
                _, predicted = outputs.max(1)
                total_val += labels.size(0)
                correct_val += predicted.eq(labels).sum().item()

        val_loss /= len(val_loader)
        val_acc = 100.0 * correct_val / total_val

        print(f"Epoch {epoch + 1}/{epochs} ", end="")
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% ", end="")
        print(f"val Loss: {val_loss:.4f}, val Acc: {val_acc:.2f}% ")

        scheduler.step()

        # Save the best model
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), best_model_path)

    return best_acc

In [None]:
# Cross-entropy-only fine-tuning pass on the current `model`.
finetune(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=50,
    lr=0.001,
)

## Quantization Aware Training and Convert to tflite

Reference from: https://github.com/alibaba/TinyNeuralNetwork

In [None]:
!pip install git+https://github.com/alibaba/TinyNeuralNetwork.git

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import datasets, transforms
from tinynn.util.train_util import DLContext, get_device, train
from tinynn.util.cifar10 import get_dataloader, train_one_epoch, train_one_epoch_distill, validate
from tinynn.graph.quantization.quantizer import QATQuantizer
from tinynn.converter import TFLiteConverter


def quantization(model, train_loader, val_loader, *, criterion=None, tflite_path='output/qat_model.tflite'):
    """Run quantization-aware training on ``model`` and export an int8 TFLite file.

    Steps: validate the float model, rewrite it with ``QATQuantizer``, train
    the rewritten model for 5 epochs with SGD and a cosine LR schedule,
    convert it to a quantized model on the CPU and write it with
    ``TFLiteConverter``.

    Args:
        model: Float model to quantize.
        train_loader: Training batches handed to the tinynn training context.
        val_loader: Validation batches handed to the tinynn training context.
        criterion: Loss used during QAT. Defaults to ``nn.CrossEntropyLoss()``,
            which expects raw logits; pass a suitable loss if ``model``
            already ends in a softmax.
        tflite_path: Destination of the converted model; its parent directory
            is created if missing.
    """
    # Imported locally under an alias: this notebook also defines its own
    # `train` function, so the bare global name is ambiguous and depends on
    # which cell was executed last.
    from tinynn.util.train_util import train as run_qat_training

    if criterion is None:
        criterion = nn.CrossEntropyLoss()

    device = get_device()
    model.to(device=device)

    # Provide a dummy input for the model
    dummy_input = torch.rand((1, 3, 32, 32))

    # Context used only to report the accuracy of the float model
    context = DLContext()
    context.device = device
    context.train_loader, context.val_loader = train_loader, val_loader

    print("Validation accuracy of the original model")
    validate(model, context)

    print("Start preparing the model for quantization")
    config = {
        'backend': "qnnpack",
        'force_overwrite': True,
        'asymmetric': True,
        'per_tensor': False,
        'set_quantizable_op_stats': True
    }
    quantizer = QATQuantizer(model, dummy_input, work_dir='out', config=config)
    qat_model = quantizer.quantize()

    print("Start quantization-aware training")
    qat_model.to(device=device)

    # Fresh context carrying the QAT training configuration
    context = DLContext()
    context.device = device
    context.train_loader, context.val_loader = train_loader, val_loader
    context.max_epoch = 5
    context.criterion = criterion
    context.optimizer = torch.optim.SGD(
        qat_model.parameters(),
        lr=0.001,
        momentum=0.9,
        weight_decay=0.0005,
        nesterov=True
    )
    context.scheduler = CosineAnnealingLR(context.optimizer, T_max=context.max_epoch + 1, eta_min=0)

    # Perform QAT training
    run_qat_training(qat_model, context, train_one_epoch, validate, qat=True)

    print("Start converting the model to TFLite")
    # The converter writes straight to `tflite_path`; make sure the directory
    # exists even if no earlier cell created it.
    output_dir = os.path.dirname(tflite_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with torch.no_grad():
        qat_model.eval()
        qat_model.to('cpu')
        # Tensor.to returns a new tensor, so the result must be re-bound.
        dummy_input = dummy_input.to('cpu')
        qat_model = quantizer.convert(qat_model)
        torch.backends.quantized.engine = 'qnnpack'

        # Convert the model to TFLite format
        converter = TFLiteConverter(
            qat_model, dummy_input, tflite_path=tflite_path,
            quantize_target_type='int8', fuse_quant_dequant=True,
            rewrite_quantizable=True, tflite_micro_rewrite=True
        )
        converter.convert()

    print("Quantization completed and model saved as TFLite.")

In [None]:
class ModelWithSoftmax(nn.Module):
    """Wrap ``base_model`` and apply a softmax over dimension 1 to its output."""

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return ``softmax(base_model(x), dim=1)``."""
        raw_output = self.base_model(x)
        normalized = self.softmax(raw_output)
        return normalized

# Append a softmax so the model handed to `quantization` outputs
# probabilities. Wrap only once: re-running this cell must not stack a second
# softmax on an already wrapped model. The check is by attribute rather than
# isinstance because re-running the cell also re-creates the wrapper class.
if not isinstance(getattr(model, "softmax", None), nn.Softmax):
    model = ModelWithSoftmax(model)
# NOTE(review): `quantization` trains with nn.CrossEntropyLoss, which expects
# raw logits, while this wrapper outputs probabilities; confirm this is intended.
quantization(model, train_loader, val_loader)