# MLDL2 Homework 4

In [1]:
import torch
import torch.nn as nn
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from torchvision import datasets
import torch.nn.functional as F
from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


# 1. Load the CIFAR-100 Datasets


In [2]:
BATCH_SIZE = 64
VAL_SPLIT_RATIO = 0.2  # You can modify it

cifar100_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276))
    ])

cifar100_train_dataset = datasets.CIFAR100(root="./data/", train=True, download=True, transform=cifar100_transform)

num_train = len(cifar100_train_dataset)
indices = torch.randperm(num_train)

val_split = int(num_train * VAL_SPLIT_RATIO)
train_indices = indices[val_split:]
val_indices = indices[:val_split]

#Do not change below code
cifar100_val_dataset = torch.utils.data.Subset(cifar100_train_dataset, val_indices)
cifar100_train_dataset = torch.utils.data.Subset(cifar100_train_dataset, train_indices)

cifar100_train_loader = torch.utils.data.DataLoader(dataset=cifar100_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
cifar100_val_loader = torch.utils.data.DataLoader(dataset=cifar100_val_dataset, batch_size=BATCH_SIZE, shuffle=True)

Files already downloaded and verified


In [None]:
# Number of samples in the dataset

print("cifar100 train dataset size : ", len(cifar100_train_dataset))
print("cifar100 validation dataset size : ", len(cifar100_val_dataset))

## CIFAR-100 Visualization

In [None]:
# Plot the training images and labels

cifar100_denormalize = transforms.Normalize(mean=[-0.507/0.267, -0.487/0.256, -0.441/0.276], std=[1/0.267, 1/0.256, 1/0.276])
to_pil_image = transforms.functional.to_pil_image

images, labels = next(iter(cifar100_train_loader))

fig, ax = plt.subplots(1, 4, figsize=(16, 4))
ax[0].imshow(to_pil_image(cifar100_denormalize(images[0])))
ax[1].imshow(to_pil_image(cifar100_denormalize(images[1])))
ax[2].imshow(to_pil_image(cifar100_denormalize(images[2])))
ax[3].imshow(to_pil_image(cifar100_denormalize(images[3])))
plt.show()

print(labels[:4])

# 2. Load Pretrained Model

Information of this pretrained model is here: https://huggingface.co/edadaltocg/resnet50_cifar100

In [None]:
!pip install detectors

In [3]:
# Do not modify the code
import detectors
import timm

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

teacher = timm.create_model("resnet50_cifar100", pretrained=True)
teacher.to(device)
teacher.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): ReLU(inplace=True)
  (maxpool): Identity()
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act2): ReLU(inplace=True)
      (aa): Identity()
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act3): ReLU(inplace=True)
      (downsample): Sequential(
    

# 3. Define the Student Model Architecture

Here we define the model. Below is very simple model with CNN. You can customize your own model and note that you are not limited to use any methods. **But you are not allowed to use pretrained weight**

In [4]:
import torch
import torch.nn as nn
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.models as torch_models
import torch.nn.functional as F
import torch.optim as optim


class ConvNetMaker(nn.Module):
	"""
	Creates a simple (plane) convolutional neural network
	"""
	def __init__(self, layers):
		"""
		Makes a cnn using the provided list of layers specification
		The details of this list is available in the paper
		:param layers: a list of strings, representing layers like ["CB32", "CB32", "FC10"]
		"""
		super(ConvNetMaker, self).__init__()
		self.conv_layers = []
		self.fc_layers = []
		h, w, d = 32, 32, 3
		previous_layer_filter_count = 3
		previous_layer_size = h * w * d
		num_fc_layers_remained = len([1 for l in layers if l.startswith('FC')])
		for layer in layers:
			if layer.startswith('Conv'):
				filter_count = int(layer[4:])
				self.conv_layers += [nn.Conv2d(previous_layer_filter_count, filter_count, kernel_size=3, padding=1),
                                        nn.BatchNorm2d(filter_count), nn.ReLU(inplace=True)]
				previous_layer_filter_count = filter_count
				d = filter_count
				previous_layer_size = h * w * d
			elif layer.startswith('MaxPool'):
				self.conv_layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
				h, w = int(h / 2.0), int(w / 2.0)
				previous_layer_size = h * w * d
			elif layer.startswith('FC'):
				num_fc_layers_remained -= 1
				current_layer_size = int(layer[2:])
				if num_fc_layers_remained == 0:
					self.fc_layers += [nn.Linear(previous_layer_size, current_layer_size)]
				else:
					self.fc_layers += [nn.Linear(previous_layer_size, current_layer_size), nn.ReLU(inplace=True)]
				previous_layer_size = current_layer_size
		
		conv_layers = self.conv_layers
		fc_layers = self.fc_layers
		self.conv_layers = nn.Sequential(*conv_layers)
		self.fc_layers = nn.Sequential(*fc_layers)
	
	def forward(self, x):
		x = self.conv_layers(x)
		x = x.view(x.size(0), -1)
		x = self.fc_layers(x)
		return x



plane_cifar10_book = {
	'2': ['Conv16', 'MaxPool', 'Conv16', 'MaxPool', 'FC10'],
	'4': ['Conv16', 'Conv16', 'MaxPool', 'Conv32', 'Conv32', 'MaxPool', 'FC10'],
	'6': ['Conv16', 'Conv16', 'MaxPool', 'Conv32', 'Conv32', 'MaxPool', 'Conv64', 'Conv64', 'MaxPool', 'FC10'],
	'8': ['Conv16', 'Conv16', 'MaxPool', 'Conv32', 'Conv32', 'MaxPool', 'Conv64', 'Conv64', 'MaxPool',
            'Conv128', 'Conv128','MaxPool', 'FC64', 'FC10'],
	'10': ['Conv32', 'Conv32', 'MaxPool', 'Conv64', 'Conv64', 'MaxPool', 'Conv128', 'Conv128', 'MaxPool',
            'Conv256', 'Conv256', 'Conv256', 'Conv256' , 'MaxPool', 'FC128' ,'FC10'],
}


plane_cifar100_book = {
	'2': ['Conv32', 'MaxPool', 'Conv32', 'MaxPool', 'FC100'],
	'4': ['Conv32', 'Conv32', 'MaxPool', 'Conv64', 'Conv64', 'MaxPool', 'FC100'],
	'6': ['Conv32', 'Conv32', 'MaxPool', 'Conv64', 'Conv64', 'MaxPool','Conv128', 'Conv128' ,'FC100'],
	'8': ['Conv32', 'Conv32', 'MaxPool', 'Conv64', 'Conv64', 'MaxPool', 'Conv128', 'Conv128', 'MaxPool',
            'Conv256', 'Conv256','MaxPool', 'FC64', 'FC100'],
	'10': ['Conv32', 'Conv32', 'MaxPool', 'Conv64', 'Conv64', 'MaxPool', 'Conv128', 'Conv128', 'MaxPool',
            'Conv256', 'Conv256', 'Conv256', 'Conv256' , 'MaxPool', 'FC512', 'FC100'],
}

In [5]:
###############################################
# 3. ResNet for CIFAR 정의
###############################################
import math
import torchvision

def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)

class BasicBlock(nn.Module):
    expansion=1
    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
    def forward(self, x):
        residual = x
        out = self.conv1(x); out = self.bn1(out); out = self.relu(out)
        out = self.conv2(out); out = self.bn2(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual; out = self.relu(out)
        return out

class ResNet_Cifar(nn.Module):
    def __init__(self, block, layers, num_classes=10):
        super(ResNet_Cifar, self).__init__()
        self.inplanes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 16, layers[0])
        self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
        self.avgpool = nn.AvgPool2d(8, stride=1)
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0]*m.kernel_size[1]*m.out_channels
                m.weight.data.normal_(0, math.sqrt(2./n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1); m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample=None
        if stride!=1 or self.inplanes!=planes*block.expansion:
            downsample=nn.Sequential(
                nn.Conv2d(self.inplanes, planes*block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes*block.expansion)
            )
        layers=[]
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes=planes*block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self,x):
        x=self.conv1(x); x=self.bn1(x); x=self.relu(x)
        x=self.layer1(x); x=self.layer2(x); x=self.layer3(x)
        x=self.avgpool(x); x=x.view(x.size(0),-1); x=self.fc(x)
        return x

def resnet14_cifar(**kwargs):
    model = ResNet_Cifar(BasicBlock, [2,2,2], **kwargs)
    return model

def resnet8_cifar(**kwargs):
    model = ResNet_Cifar(BasicBlock, [1,1,1], **kwargs)
    return model

def resnet20_cifar(**kwargs):
    model = ResNet_Cifar(BasicBlock, [3,3,3], **kwargs)
    return model

def resnet18_cifar(**kwargs):
    # just for placeholder, using torchvision resnet18 (not really for cifar)
    model = torchvision.models.resnet18(pretrained=False)
    # modify last layer for correct num_classes if needed
    num_classes = kwargs.get('num_classes', 10)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

def resnet26_cifar(**kwargs):
    model = ResNet_Cifar(BasicBlock, [4,4,4], **kwargs)
    return model

def resnet32_cifar(**kwargs):
    model = ResNet_Cifar(BasicBlock, [5,5,5], **kwargs)
    return model

def resnet34_cifar(**kwargs):
    model = torchvision.models.resnet34(pretrained=False)
    num_classes = kwargs.get('num_classes', 10)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

resnet_book = {
    '8': resnet8_cifar,
    '14': resnet14_cifar,
    '20': resnet20_cifar,
    '18': resnet18_cifar,
    '26': resnet26_cifar,
    '32': resnet32_cifar,
    '34': resnet34_cifar,
}

In [6]:
def is_resnet(name):
	"""
	Simply checks if name represents a resnet, by convention, all resnet names start with 'resnet'
	:param name:
	:return:
	"""
	name = name.lower()
	return name.startswith('resnet')


def create_cnn_model(name, dataset="cifar100", use_cuda=False):
	"""
	Create a student for training, given student name and dataset
	:param name: name of the student. e.g., resnet110, resnet32, plane2, plane10, ...
	:param dataset: the dataset which is used to determine last layer's output size. Options are cifar10 and cifar100.
	:return: a pytorch student for neural network
	"""
	num_classes = 100 if dataset == 'cifar100' else 10
	model = None
	if is_resnet(name):
		resnet_size = name[6:]
		resnet_model = resnet_book.get(resnet_size)(num_classes=num_classes)
		model = resnet_model
		
	else:
		plane_size = name[5:]
		model_spec = plane_cifar10_book.get(plane_size) if num_classes == 10 else plane_cifar100_book.get(plane_size)
		plane_model = ConvNetMaker(model_spec)
		model = plane_model

	# copy to cuda if activated
	if use_cuda:
		model = model.cuda()
		
	return model


# 4. Implement the Distillation Process

Here, you will implement distillation using a pretrained teacher model.

**The code below is just a sample training code, and does not implement the distillation method.**

Please make sure to implement the distillation method according to your own understanding.

You can change loss function, optimizer, number of epoch.

# Single Model for ResNet : Training & Validation and Test Processing

In [None]:
import os
import copy
import torch
import argparse
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import torchvision
import numpy as np
from torch.utils.data import Dataset, DataLoader

os.environ["OMP_NUM_THREADS"] = '4'
os.environ["OMP_THREAD_LIMIT"] = '4'
os.environ["MKL_NUM_THREADS"] = '4'
os.environ["NUMEXPR_NUM_THREADS"] = '4'
os.environ["OMP_NUM_THREADS"] = '4'
os.environ["PAPERLESS_AVX2_AVAILABLE"] = "false"
os.environ["OCR_THREADS"] = '4'


def str2bool(v):
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    else:
        return False

def load_checkpoint(model, checkpoint_path):
    model_ckp = torch.load(checkpoint_path)
    model.load_state_dict(model_ckp['model_state_dict'])
    return model

class LabelSmoothingCrossEntropy(torch.nn.Module):
    """
    라벨 스무딩을 적용한 CrossEntropyLoss.
    smoothing: 스무딩 정도 (0이면 기본 CE와 동일)
    """
    def __init__(self, smoothing=0.0):
        super(LabelSmoothingCrossEntropy, self).__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):
        log_probs = F.log_softmax(pred, dim=-1)
        n_classes = pred.size(-1)
        with torch.no_grad():
            true_dist = torch.zeros_like(log_probs)
            true_dist.fill_(self.smoothing / (n_classes - 1))
            true_dist.scatter_(1, target.unsqueeze(1), 1.0 - self.smoothing)
        return (-true_dist * log_probs).sum(dim=-1).mean()

class TrainManager(object):
    def __init__(self, student, teacher=None, train_loader=None, val_loader=None, train_config={}):
        self.student = student
        self.teacher = teacher
        self.have_teacher = bool(self.teacher)
        self.device = train_config['device']
        self.name = train_config['name']
        self.optimizer = optim.SGD(self.student.parameters(),
                                   lr=train_config['learning_rate'],
                                   momentum=train_config['momentum'],
                                   weight_decay=train_config['weight_decay'])
        if self.have_teacher:
            self.teacher.eval()
            self.teacher.train(mode=False)
            
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = train_config

        self.best_acc = 0.0
        self.best_model_path = None

        self.label_smoothing = train_config.get('label_smoothing', 0.0)
    
    def train(self):
        lambda_ = self.config['lambda_student']
        T = self.config['T_student']
        epochs = self.config['epochs']
        trial_id = self.config['trial_id']

        criterion = LabelSmoothingCrossEntropy(smoothing=self.label_smoothing)
        
        print("---------- Start Training ----------")
        for epoch in range(epochs):
            
            if epoch >= 1 :
                print("========== Next Epoch ==========")
                
            self.student.train()
            self.adjust_learning_rate(self.optimizer, epoch)

            train_loop = tqdm(self.train_loader, desc=f"Epoch {epoch+1}")
            
            running_loss = 0.0
            loss_SL_total = 0.0
            loss_KD_total = 0.0
            total_batches = 0
            
            for batch_idx, (data, target) in enumerate(train_loop):
                data = data.to(self.device)
                target = target.to(self.device)
                self.optimizer.zero_grad()
                output = self.student(data)
                
                # Classification Loss (with label smoothing)
                loss_SL = criterion(output, target)
                loss = loss_SL
                kd_loss_value = 0.0
                
                if self.have_teacher:
                    with torch.no_grad():
                        teacher_outputs = self.teacher(data)
                    # KD Loss
                    loss_KD = F.kl_div(F.log_softmax(output / T, dim=1),
                                       F.softmax(teacher_outputs / T, dim=1),
                                       reduction='batchmean')
                    loss = (1 - lambda_) * loss_SL + lambda_ * (T * T) * loss_KD
                    kd_loss_value = loss_KD.item()
                    
                loss.backward()
                self.optimizer.step()
                
                train_loop.set_postfix(loss=loss.item())
                running_loss += loss.item()
                loss_SL_total += loss_SL.item()
                loss_KD_total += kd_loss_value
                total_batches += 1
            
            avg_loss = running_loss / total_batches if total_batches > 0 else 0.0
            avg_ce_loss = loss_SL_total / total_batches if total_batches > 0 else 0.0
            avg_kd_loss = loss_KD_total / total_batches if (total_batches > 0 and self.have_teacher) else 0.0
            
            train_acc = self.evaluate_accuracy(self.train_loader)
            val_acc = self.validate(step=epoch)
            print(f"Epoch [{epoch+1}/{epochs}] - Avg Total Loss: {avg_loss:.4f} | CE Loss: {avg_ce_loss:.4f} | KD Loss: {avg_kd_loss:.4f} | Train ACC: {train_acc:.2f}% | Val ACC: {val_acc:.2f}%")

            if val_acc > self.best_acc:
                self.best_acc = val_acc
                best_name = '{}_{}_best.pth.tar'.format(self.name, trial_id)
                self.save(epoch, name=best_name)
                self.best_model_path = best_name
        
        print("===== Training Complete =====")
        print(f"Best Validation Accuracy: {self.best_acc:.2f}%")
        if self.best_model_path is not None:
            print(f"Best model saved at: {self.best_model_path}")
        return self.best_acc
    
    def validate(self, step=0):
        self.student.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in self.val_loader:
                images = images.to(self.device)
                labels = labels.to(self.device)
                outputs = self.student(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        acc = 100 * correct / total
        return acc

    def evaluate_accuracy(self, loader):
        self.student.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in loader:
                images = images.to(self.device)
                labels = labels.to(self.device)
                outputs = self.student(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        acc = 100 * correct / total
        return acc
    
    def save(self, epoch, name=None):
        if name is None:
            name = '{}_{}_epoch{}.pth.tar'.format(self.name, self.config['trial_id'], epoch)
        torch.save({
            'model_state_dict': self.student.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'epoch': epoch,
        }, name)
    
    def adjust_learning_rate(self, optimizer, epoch):
        epochs = self.config['epochs']
        models_are_plane = self.config['is_plane']
        
        if models_are_plane:
            lr = 0.01
        else:
            if epoch < int(epoch/2.0):
                lr = 0.1
            elif epoch < int(epochs*3/4.0):
                lr = 0.1 * 0.1
            else:
                lr = 0.1 * 0.01
        
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

class TestDataset(Dataset):
    def __init__(self, images, transform=None):
        self.images = images
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform:
            image = self.transform(image)
        return image

def test(model, test_loader, device):
    model.eval()
    test_predictions = []
    with torch.inference_mode():
        for i, data in enumerate(tqdm(test_loader, desc="Testing")):
            data = data.float().to(device)
            output = model(data)
            test_predictions.append(output.cpu())
    return torch.cat(test_predictions, dim=0)

def main(args):
    if args.cuda and torch.cuda.is_available():
        device = f"cuda:{args.gpu_device}"
    else:
        device = "cpu"
    print(f"Using device: {device}")
    
    student_model = create_cnn_model(args.student, args.dataset, use_cuda=(device.startswith('cuda')))
    student_model.to(device)
    teacher_model = teacher.to(device) if teacher else None
    if teacher_model:
        teacher_model.eval()

    trial_id = "manual_trial"

    train_config = {
        'epochs': args.epochs,
        'learning_rate': args.learning_rate,
        'momentum': args.momentum,
        'weight_decay': args.weight_decay,
        'device': device,
        'is_plane': not is_resnet(args.student),
        'trial_id': trial_id,
        'T_student': args.T_student,
        'lambda_student': args.lambda_student,
        'name': args.student,
        'label_smoothing': args.label_smoothing
    }

    print("=========== Training Student Model (with Label Smoothing)! ===========")
    train_loader = cifar100_train_loader
    val_loader = cifar100_val_loader
    student_trainer = TrainManager(student_model, teacher=teacher_model, train_loader=train_loader, val_loader=val_loader, train_config=train_config)
    best_student_acc = student_trainer.train()
    print("Best Student Accuracy:", best_student_acc)

    if student_trainer.best_model_path is not None:
        print("Loading best model for test set predictions...")
        load_checkpoint(student_model, student_trainer.best_model_path)

    images = np.load("./cifar100_test_images.npy")
    images = torch.tensor(images, dtype=torch.float32)

    test_dataset = TestDataset(images)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

    predictions = test(student_model, test_loader, device)

    model_name = train_config['name']
    np.save(f'./Test_results_{model_name}.npy', predictions.numpy())
    print(f"Test results saved to ./Test_results_{model_name}.npy")

if __name__ == "__main__":
    def parse_arguments():
        parser = argparse.ArgumentParser(description='TA Knowledge Distillation Code with Label Smoothing')
        parser.add_argument('--epochs', default=100, type=int, help='number of total epochs to run')
        parser.add_argument('--dataset', default='cifar100', type=str, help='dataset. can be either cifar10 or cifar100')
        parser.add_argument('--batch-size', default=128, type=int, help='batch_size')
        parser.add_argument('--learning-rate', default=0.1, type=float, help='initial learning rate')
        parser.add_argument('--momentum', default=0.9, type=float, help='SGD momentum')
        parser.add_argument('--weight-decay', default=1e-4, type=float, help='SGD weight decay (default: 1e-4)')
        parser.add_argument('--teacher', default='resnet50', type=str, help='teacher model name')
        parser.add_argument('--student', '--model', default='resnet34', type=str, help='student model name')
        parser.add_argument('--teacher-checkpoint', default='', type=str, help='optional pretrained checkpoint for teacher')
        parser.add_argument('--cuda', default=True, type=str2bool, help='whether or not use cuda(train on GPU)')
        parser.add_argument('--gpu-device', default=0, type=int, help='Which GPU device to use (e.g., 0, 1, 2, ...)')
        parser.add_argument('--dataset-dir', default='./data', type=str, help='dataset directory')
        parser.add_argument('--T-student', default=4, type=float, help='Temperature for knowledge distillation')
        parser.add_argument('--lambda-student', default=0.5, type=float, help='Lambda for balancing KD loss and CE loss')
        parser.add_argument('--label-smoothing', default=0.1, type=float, help='Label smoothing factor (0 means no smoothing)')
        return parser

    parser = parse_arguments()
    args, unknown = parser.parse_known_args()
    print(args)
    main(args)


# Ensemble for ResNet-N : Training & Validation and Test Processing

In [None]:
import os
import copy
import torch
import argparse
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import numpy as np
from torch.utils.data import Dataset, DataLoader
import random
import shutil
import logging
import matplotlib.pyplot as plt

def str2bool(v):
    return v.lower() in ('yes','true','t','y','1')

def load_checkpoint(model, checkpoint_path):
    model_ckp = torch.load(checkpoint_path)
    model.load_state_dict(model_ckp['model_state_dict'])
    return model

class LabelSmoothingCrossEntropy(nn.Module):
    def __init__(self, smoothing=0.0):
        super(LabelSmoothingCrossEntropy, self).__init__()
        self.smoothing = smoothing
    def forward(self, pred, target):
        log_probs = F.log_softmax(pred, dim=-1)
        n_classes = pred.size(-1)
        with torch.no_grad():
            true_dist = torch.zeros_like(log_probs)
            true_dist.fill_(self.smoothing / (n_classes - 1))
            true_dist.scatter_(1, target.unsqueeze(1), 1.0 - self.smoothing)
        return (-true_dist * log_probs).sum(dim=-1).mean()

class TestDataset(Dataset):
    def __init__(self, images, transform=None):
        self.images = images
        self.transform = transform
    def __len__(self):
        return len(self.images)
    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform:
            image = self.transform(image)
        return image 

def ensemble_inference_no_label(model_list, test_loader, device, save_path="./Test_ensemble_results.npy"):
    for m in model_list:
        m.eval()
    all_preds = []
    with torch.inference_mode():
        for images in tqdm(test_loader, desc="Ensemble Inference(no-label)"):
            images = images.to(device)
            prob_list = []
            for m in model_list:
                outputs = m(images)
                probs = F.softmax(outputs, dim=1) 
                prob_list.append(probs)
            avg_prob = torch.mean(torch.stack(prob_list, dim=0), dim=0)
            _, preds = torch.max(avg_prob, 1)
            all_preds.append(preds.cpu())
    all_preds = torch.cat(all_preds, dim=0).numpy()
    np.save(save_path, all_preds)
    print(f"[Ensemble Inference] Test predictions saved to {save_path} (shape={all_preds.shape})")

def ensemble_inference_val(model_list, val_loader, device):
    for m in model_list:
        m.eval()
    running_corrects = 0
    total = 0
    with torch.inference_mode():
        for images, labels in tqdm(val_loader, desc="Ensemble Inference(val)"):
            images, labels = images.to(device), labels.to(device)
            prob_list = []
            for m in model_list:
                outputs = m(images)
                probs = F.softmax(outputs, dim=1)
                prob_list.append(probs)
            avg_prob = torch.mean(torch.stack(prob_list, dim=0), dim=0)
            _, preds = torch.max(avg_prob, 1)
            running_corrects += (preds == labels).sum().item()
            total += labels.size(0)
    val_acc = (running_corrects / total) if total > 0 else 0.0
    return val_acc

class EnsembleTrainer:
    def __init__(self, args, device):
        self.args = args
        self.device = device
        self.n_ens = args.n_ens
        self.label_smoothing = args.label_smoothing
        self.T = args.T_student
        self.lambda_ = args.lambda_student
        self.epochs = args.epochs
        self.dataset = args.dataset
        self.student_name = args.student
        self.learning_rate = args.learning_rate
        self.momentum = args.momentum
        self.weight_decay = args.weight_decay
        self.save_dir = args.save_dir

        self.teacher_model = None
        if args.teacher and args.teacher_checkpoint:
            self.teacher_model = create_cnn_model(args.teacher, self.dataset, use_cuda=(self.device.startswith('cuda')))
            self.teacher_model = load_checkpoint(self.teacher_model, args.teacher_checkpoint)
            self.teacher_model.eval()

        self.train_loader = cifar100_train_loader
        self.val_loader = cifar100_val_loader

        images = np.load("./cifar100_test_images.npy") 
        images = torch.tensor(images, dtype=torch.float32)
        test_dataset = TestDataset(images)
        self.test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

        self.models = [create_cnn_model(self.student_name, self.dataset, use_cuda=(self.device.startswith('cuda'))) for _ in range(self.n_ens)]
        self.optimizers = [optim.SGD(model.parameters(), lr=self.learning_rate, momentum=self.momentum, weight_decay=self.weight_decay) 
                           for model in self.models]

        self.train_losses = {}
        self.val_accs = {}
        # best model info
        self.best_val_accs = [0.0]*self.n_ens
        self.best_epochs = [0]*self.n_ens

    def adjust_learning_rate(self, optimizer, epoch):
        epochs = self.epochs
        models_are_plane = not is_resnet(self.student_name)
        if models_are_plane:
            lr = 0.01
        else:
            if epoch < int(epoch/2.0):
                lr = 0.1
            elif epoch < int(epochs*3/4.0):
                lr = 0.1 * 0.1
            else:
                lr = 0.1 * 0.01
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    def train_one_epoch(self, model, optimizer, epoch_idx, model_idx):
        criterion = LabelSmoothingCrossEntropy(smoothing=self.label_smoothing)
        running_loss = 0.0
        model.train()
        pbar = tqdm(self.train_loader, desc=f"[Train] Model{model_idx+1} Epoch{epoch_idx+1}")
        have_teacher = (self.teacher_model is not None)
        for data, target in pbar:
            data, target = data.to(self.device), target.to(self.device)
            optimizer.zero_grad()
            output = model(data)
            loss_SL = criterion(output, target)
            loss = loss_SL
            if have_teacher:
                with torch.no_grad():
                    teacher_out = self.teacher_model(data)
                loss_KD = F.kl_div(F.log_softmax(output / self.T, dim=1),
                                   F.softmax(teacher_out / self.T, dim=1),
                                   reduction='batchmean')
                loss = (1 - self.lambda_) * loss_SL + self.lambda_ * (self.T * self.T) * loss_KD
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        return running_loss / len(self.train_loader)

    def eval_accuracy(self, model, loader, desc=""):
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            pbar = tqdm(loader, desc=desc)
            for data, target in pbar:
                data, target = data.to(self.device), target.to(self.device)
                output = model(data)
                _, pred = torch.max(output, 1)
                correct += (pred == target).sum().item()
                total += target.size(0)
        return correct / total if total > 0 else 0.0

    def save_best_model(self, model, model_idx, epoch, val_acc):
        best_name = os.path.join(self.save_dir, f"Resnet18_model_{model_idx+1}_best.pth.tar")
        torch.save({
            'model_state_dict': model.state_dict(),
            'epoch': epoch,
            'val_acc': val_acc
        }, best_name)
        print(f"Best model of Resnet18_model{model_idx+1} updated at epoch {epoch+1} with val_acc={val_acc*100:.2f}% saved at {best_name}")

    def train_ensemble(self):
        for idx, (model, optimizer) in enumerate(zip(self.models, self.optimizers)):
            print(f"====== Starting training for Model {idx+1}/{self.n_ens} ======")
            train_loss_list = []
            val_acc_list = []
            best_val_acc = 0.0
            best_epoch = 0

            for epoch in range(self.epochs):
                self.adjust_learning_rate(optimizer, epoch)
                epoch_loss = self.train_one_epoch(model, optimizer, epoch, idx)
                train_acc = self.eval_accuracy(model, self.train_loader, desc=f"[Train-ACC] Model{idx+1} Epoch{epoch+1}")
                val_acc = self.eval_accuracy(model, self.val_loader, desc=f"[Val-ACC] Model{idx+1} Epoch{epoch+1}")

                train_loss_list.append(epoch_loss)
                val_acc_list.append(val_acc)
                print(f"[Model : {idx+1}, Epoch : {epoch+1}/{self.epochs}] TrainLoss={epoch_loss:.4f} TrainAcc={train_acc*100:.2f}% ValAcc={val_acc*100:.2f}%")
                print('-----------------------------')

                if val_acc > best_val_acc:
                    best_val_acc = val_acc
                    best_epoch = epoch
                    self.save_best_model(model, idx, epoch, val_acc)

            self.train_losses[idx] = train_loss_list
            self.val_accs[idx] = val_acc_list
            self.best_val_accs[idx] = best_val_acc
            self.best_epochs[idx] = best_epoch

    def plot_curve(self, save_dir):
        title = 'Train Loss Curve'
        dpi = 80
        width, height = 1200, 800
        figsize = width/float(dpi), height/float(dpi)

        fig = plt.figure(figsize=figsize)
        x_axis = np.arange(self.epochs)
        if len(self.train_losses) > 0:
            y_max = max([max(v) for v in self.train_losses.values()])
            plt.ylim(0, y_max*1.1)
        plt.xlim(0, self.epochs)
        plt.grid()
        plt.title(title, fontsize=20)
        plt.xlabel('Epoch', fontsize=16)
        plt.ylabel('Train Loss', fontsize=16)

        for e, losses in self.train_losses.items():
            plt.plot(x_axis, losses, label=f'Model-{e+1}', lw=2)
        plt.legend()
        fig.savefig(os.path.join(save_dir, "Resnet18_train_loss_curve.png"), dpi=dpi, bbox_inches='tight')
        plt.close(fig)

        title = 'Val Acc Curve'
        fig = plt.figure(figsize=figsize)
        if len(self.val_accs) > 0:
            y_max = max([max(v) for v in self.val_accs.values()])
            plt.ylim(0, y_max*1.1)
        plt.xlim(0, self.epochs)
        plt.grid()
        plt.title(title, fontsize=20)
        plt.xlabel('Epoch', fontsize=16)
        plt.ylabel('Val Acc', fontsize=16)
        for e, accs in self.val_accs.items():
            plt.plot(x_axis, accs, label=f'Model-{e+1}', lw=2)
        plt.legend()
        fig.savefig(os.path.join(save_dir, "Resnet18_val_acc_curve.png"), dpi=dpi, bbox_inches='tight')
        plt.close(fig)

    def load_best_models(self):
        for idx, model in enumerate(self.models):
            best_path = os.path.join(self.save_dir, f"Resnet18_model_{idx+1}_best.pth.tar")
            if os.path.exists(best_path):
                checkpoint = torch.load(best_path)
                model.load_state_dict(checkpoint['model_state_dict'])
                print(f"Resnet18_model{idx+1} best model loaded from epoch {checkpoint['epoch']+1} with val_acc={checkpoint['val_acc']*100:.2f}%")
            else:
                print(f"[Warning] Best model file not found for Model{idx+1}. Using last trained parameters.")

if __name__ == "__main__":
    def parse_arguments():
        parser = argparse.ArgumentParser(description='Ensemble KD with Label Smoothing on CIFAR100')
        parser.add_argument('--epochs', default=50, type=int)
        parser.add_argument('--dataset', default='cifar100', type=str)
        parser.add_argument('--batch-size', default=128, type=int)
        parser.add_argument('--learning-rate', default=0.1, type=float)
        parser.add_argument('--momentum', default=0.9, type=float)
        parser.add_argument('--weight-decay', default=1e-4, type=float)
        parser.add_argument('--teacher', default='resnet50', type=str, help='teacher model name')
        parser.add_argument('--teacher-checkpoint', default='', type=str, help='teacher checkpoint path')
        parser.add_argument('--cuda', default=True, type=str2bool)
        parser.add_argument('--gpu-device', default=0, type=int)
        parser.add_argument('--T-student', default=4, type=float)
        parser.add_argument('--lambda-student', default=0.5, type=float)
        parser.add_argument('--label-smoothing', default=0.1, type=float)
        parser.add_argument('--student', default='resnet18', type=str, help='student model name')
        parser.add_argument('--n_ens', default=10, type=int, help='number of ensemble models')
        parser.add_argument('--save-dir', default='.', type=str)
        return parser

    parser = parse_arguments()
    args, unknown = parser.parse_known_args()
    print(args)
    if args.cuda and torch.cuda.is_available():
        device = f"cuda:{args.gpu_device}"
        torch.cuda.set_device(args.gpu_device)
    else:
        device = "cpu"
    print(f"Using device: {device}")

    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

    save_dir = args.save_dir
    os.makedirs(save_dir, exist_ok=True)

    trainer = EnsembleTrainer(args, device)
    trainer.train_ensemble()
    trainer.plot_curve(save_dir)

    trainer.load_best_models()

    val_acc = ensemble_inference_val(trainer.models, trainer.val_loader, device)
    print(f"Ensemble Validation Accuracy with best models: {val_acc*100:.2f}%")

    ensemble_path = os.path.join(save_dir, f"ensemble_predictions_{trainer.student_name}.npy")
    ensemble_inference_no_label(trainer.models, trainer.test_loader, device, save_path=ensemble_path)
    print(f"Ensemble test predictions saved: {ensemble_path}")

    for i in range(trainer.n_ens):
        print(f"Model {i+1}: Best Val Acc={trainer.best_val_accs[i]*100:.2f}% at epoch {trainer.best_epochs[i]+1}")


[Train-ACC] Model8 Epoch48: 100%|██████████| 625/625 [00:09<00:00, 65.30it/s]
[Val-ACC] Model8 Epoch48: 100%|██████████| 157/157 [00:02<00:00, 61.66it/s]


[Model : 8, Epoch : 48/50] TrainLoss=0.8082 TrainAcc=99.98% ValAcc=42.66%
-----------------------------


[Train] Model8 Epoch49: 100%|██████████| 625/625 [00:19<00:00, 32.89it/s]
[Train-ACC] Model8 Epoch49: 100%|██████████| 625/625 [00:09<00:00, 65.13it/s]
[Val-ACC] Model8 Epoch49: 100%|██████████| 157/157 [00:02<00:00, 61.51it/s]


[Model : 8, Epoch : 49/50] TrainLoss=0.8084 TrainAcc=99.98% ValAcc=42.64%
-----------------------------


[Train] Model8 Epoch50: 100%|██████████| 625/625 [00:17<00:00, 34.83it/s]
[Train-ACC] Model8 Epoch50: 100%|██████████| 625/625 [00:09<00:00, 65.43it/s]
[Val-ACC] Model8 Epoch50: 100%|██████████| 157/157 [00:02<00:00, 59.47it/s]


[Model : 8, Epoch : 50/50] TrainLoss=0.8078 TrainAcc=99.98% ValAcc=42.73%
-----------------------------
Best model of Resnet18_model8 updated at epoch 50 with val_acc=42.73% saved at ./Resnet18_model_8_best.pth.tar


[Train] Model9 Epoch1:   0%|          | 0/625 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 1; 7.79 GiB total capacity; 2.26 GiB already allocated; 3.38 MiB free; 2.36 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

# Test and Submit
# Do not modify the cell below!!!!

(if you have problem with test dataset path or model name, you can modify them only)

In [9]:
class TestDataset(Dataset):
    def __init__(self, images, transform=None):
        self.images = images
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform:
            image = self.transform(image)

        return image

# You can modify the path of test images.
images = np.load("./cifar100_test_images.npy")  # shape: (10000, 3, 32, 32)
images = torch.tensor(images, dtype=torch.float32)

test_dataset = TestDataset(images)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

In [None]:
for images in test_loader:
    print(images.shape)
    break

In [None]:
def test(model, test_loader):
  model.eval()
  test_predictions = []

  with torch.inference_mode():
      for i, data in enumerate(tqdm(test_loader)):
          data = data.float().to(device)
          output = model(data)
          test_predictions.append(output.cpu())

  return torch.cat(test_predictions, dim=0)

In [None]:
# Save test output npy file
predictions = test(student_model, test_loader)
np.save('./Test_results', predictions.numpy())