# Teacher

## Dataset

In [None]:
#Load your dataset
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import transforms

# Channel statistics expected by torchvision's ImageNet-pretrained checkpoints;
# the teacher's ResNet-50 backbone is loaded with pretrained weights, so its
# inputs have to be normalized the same way.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Resize to the 224x224 input size used by the model summaries below, convert
# to a float tensor and normalize per channel.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# CIFAR-10 train / test splits, downloaded into ./data on first use.
training_data = datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=transform,
)

test_data = datasets.CIFAR10(
    root="data",
    train=False,
    download=True,
    transform=transform,
)

# NOTE(review): batch_size=1 is kept from the original notebook. A larger batch
# would speed training up considerably, but first confirm that every metric
# computed downstream is averaged per sample rather than per batch.
train_dataloader = DataLoader(training_data, batch_size=1, shuffle=True)
# Evaluation metrics do not depend on sample order, so the test set is not shuffled.
test_dataloader = DataLoader(test_data, batch_size=1, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:03<00:00, 43388233.41it/s]


Extracting data/cifar-10-python.tar.gz to data
Files already downloaded and verified


## TeacherModel

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models

class TeacherModel(nn.Module):
    """Pretrained ResNet-50 followed by a linear classification head.

    The torchvision ResNet-50 is kept whole (including its own 1000-way
    output layer); ``fc`` then maps those 1000 values to ``num_classes``.

    Args:
        num_classes: Number of output classes produced by ``fc``.
        num_input_channels: Channels of the input images. With the default
            of 3 the pretrained stem convolution is kept. For any other
            value the stem is replaced by a newly initialised convolution
            with that many input channels (so its weights are not pretrained).
    """

    def __init__(self, num_classes, num_input_channels=3):
        super().__init__()
        self.resnet50 = models.resnet50(pretrained=True)

        # The pretrained stem only accepts 3-channel input; rebuild it with
        # the same geometry whenever a different channel count is requested.
        if num_input_channels != 3:
            self.resnet50.conv1 = nn.Conv2d(
                num_input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
            )

        # Head on top of the backbone's 1000 outputs.
        self.fc = nn.Linear(1000, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for images ``x``."""
        x = self.resnet50(x)
        x = self.fc(x)
        return x

In [None]:
import copy
# Teacher used for transfer learning: 10 output classes, default 3-channel input.
teacher_model_transfer = TeacherModel(num_classes=10)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:01<00:00, 92.4MB/s]


In [None]:
from torchsummary import summary
# torchsummary feeds a dummy batch through the model on `device`, which
# defaults to CUDA. Pick the device explicitly so this cell also runs on a
# machine without a GPU.
summary_device = "cuda" if torch.cuda.is_available() else "cpu"
summary(teacher_model_transfer.to(summary_device), (3, 224, 224), device=summary_device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]           4,096
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]          16,384
      BatchNorm2d-12          [-1, 256, 56, 56]             512
           Conv2d-13          [-1, 256, 56, 56]          16,384
      BatchNorm2d-14          [-1, 256,

In [None]:
# Freeze every parameter currently registered on the teacher.
for param in teacher_model_transfer.parameters():
    param.requires_grad_(False)

# Replace the head with a new 10-way linear layer of the same input width;
# being newly created, its parameters require gradients.
modelOutputFeats = teacher_model_transfer.fc.in_features
teacher_model_transfer.fc = nn.Linear(in_features=modelOutputFeats, out_features=10)

In [None]:
# Classification objective shared by the training loops below.
lossFunc = nn.CrossEntropyLoss()
# Adam is given the teacher head's parameters only.
opt = optim.Adam(params=teacher_model_transfer.fc.parameters(), lr=1e-3)

In [None]:
from torch.optim import lr_scheduler
# Multiply the learning rate of `opt` by 0.1 after every 7 scheduler steps.
step_lr_scheduler = lr_scheduler.StepLR(optimizer=opt, step_size=7, gamma=0.1)

In [None]:
# Use the first GPU when one is visible, otherwise stay on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Last expression of the cell: moves the teacher and displays its structure.
teacher_model_transfer.to(device)

TeacherModel(
  (resnet50): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
    

In [None]:
import time

# Loaders looked up by phase name inside train_model.
dataloaders = dict(train=train_dataloader, val=test_dataloader)

# Per-epoch histories that train_model appends to.
training_loss, training_acc = [], []
val_loss, val_acc = [], []

def train_model(model, criterion, optimizer, scheduler, num_epochs):
    """Train ``model`` and return it with its best validation weights loaded.

    Every epoch runs a ``'train'`` phase followed by a ``'val'`` phase over
    the module-level ``dataloaders`` dict, moving batches to the module-level
    ``device``. Per-epoch metrics are appended to the module-level lists
    ``training_loss``, ``training_acc``, ``val_loss`` and ``val_acc``.

    Args:
        model: Network to optimise; updated in place.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: LR scheduler stepped once per epoch, after the train phase.
        num_epochs: Number of epochs to run.

    Returns:
        ``model``, after loading the weights from the epoch with the highest
        validation accuracy (or the initial weights if no epoch beat 0.0).
    """
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            samples_seen = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Track gradients only during the training phase.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                # The loss is weighted by the batch size, so the epoch totals
                # below are sums over samples, not over batches.
                batch_size = inputs.size(0)
                running_loss += loss.item() * batch_size
                running_corrects += torch.sum(preds == labels.data)
                samples_seen += batch_size

            if phase == 'train':
                scheduler.step()

            # Normalise by the number of samples actually processed. Dividing
            # by len(dataloader) (the number of batches) is only correct when
            # the batch size is 1.
            epoch_loss = running_loss / samples_seen
            epoch_acc = running_corrects.double() / samples_seen

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            if phase == 'train':
                training_loss.append(epoch_loss)
                training_acc.append(epoch_acc)
            else:
                val_loss.append(epoch_loss)
                val_acc.append(epoch_acc)

            # Keep a copy of the weights that scored best on validation.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [None]:
# Fine-tune the teacher for two epochs; the returned model is bound to model_T.
model_T = train_model(
    model=teacher_model_transfer,
    criterion=lossFunc,
    optimizer=opt,
    scheduler=step_lr_scheduler,
    num_epochs=2,
)

KeyboardInterrupt: ignored

In [None]:
# List every state-dict entry of the teacher together with its tensor size.
for param_tensor, tensor_value in teacher_model_transfer.state_dict().items():
    print(param_tensor, "\t", tensor_value.size())

In [None]:
# Move each recorded accuracy to the CPU and convert it to NumPy.
teacher_training_acc_cpu_T = list(map(lambda acc: acc.cpu().numpy(), training_acc))
teacher_val_acc_cpu_T = list(map(lambda acc: acc.cpu().numpy(), val_acc))

In [None]:
import matplotlib.pyplot as plt
plt.figure(1)

# Accuracy history (top panel). The CPU copies of the accuracies are named
# teacher_training_acc_cpu_T / teacher_val_acc_cpu_T.
plt.subplot(211)
plt.plot(teacher_training_acc_cpu_T)
plt.plot(teacher_val_acc_cpu_T)
plt.title('teacher model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'valid'], loc='lower right')

# Loss history (bottom panel).
plt.subplot(212)
plt.plot(training_loss)
plt.plot(val_loss)
plt.title('teacher model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'valid'], loc='upper right')

# Keep the stacked panels' titles and axis labels from overlapping.
plt.tight_layout()
plt.show()

# Student

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class BottleneckLayer(nn.Module):
    """Residual bottleneck block: 1x1 conv, 3x3 conv (carries the stride), 1x1 conv.

    The block emits ``out_channels * expansion`` channels. If the input cannot
    be added to that output unchanged (stride other than 1, or a different
    channel count), the shortcut goes through a 1x1 convolution plus batch
    norm; otherwise the shortcut is an empty ``nn.Sequential`` (identity).
    """

    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        expanded_channels = out_channels * self.expansion

        # Main branch.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, expanded_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(expanded_channels)

        # Shortcut branch.
        shapes_match = stride == 1 and in_channels == expanded_channels
        if shapes_match:
            self.downsample = nn.Sequential()
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, expanded_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(expanded_channels),
            )

    def forward(self, x):
        """Apply the three conv/BN stages, add the shortcut, then ReLU."""
        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out = F.relu(self.bn2(self.conv2(out)), inplace=True)
        out = self.bn3(self.conv3(out))

        out += self.downsample(x)
        return F.relu(out, inplace=True)



class StudentModel(nn.Module):
    """Small ResNet-style classifier built from one ``BottleneckLayer`` per stage.

    Layout: 7x7 stride-2 convolution, batch norm, ReLU and 3x3 max-pool,
    then four stages (``layer1`` .. ``layer4``), global average pooling and a
    linear classifier.

    Args:
        num_classes: Size of the output of ``fc``.
    """

    def __init__(self, num_classes=1000):
        super().__init__()

        # Stem.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Stages. Each stage's input width equals the previous stage's
        # out_channels * BottleneckLayer.expansion.
        self.layer1 = self._make_layer(64, 64, 1, stride=1)
        self.layer2 = self._make_layer(256, 128, 1, stride=2)
        self.layer3 = self._make_layer(512, 256, 1, stride=2)
        self.layer4 = self._make_layer(1024, 512, 1, stride=2)

        # Classifier.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * BottleneckLayer.expansion, num_classes)

    def _make_layer(self, input_channels, out_channels, blocks, stride):
        """Build one stage of ``blocks`` bottleneck layers.

        The first layer maps ``input_channels`` to
        ``out_channels * expansion`` and carries ``stride``; the remaining
        ``blocks - 1`` layers keep that width with the default stride.
        """
        layers = [BottleneckLayer(input_channels, out_channels, stride)]
        for _ in range(1, blocks):
            layers.append(BottleneckLayer(out_channels * BottleneckLayer.expansion, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for images ``x``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x


In [None]:
# Student network built with the constructor's default of 1000 output classes.
Student_model = StudentModel(num_classes=1000)

In [None]:
# Swap the student's classifier for a 10-way layer with the same input width.
modelOutputFeats_1 = Student_model.fc.in_features
Student_model.fc = nn.Linear(in_features=modelOutputFeats_1, out_features=10)

In [None]:
# Use the first GPU when one is visible, otherwise stay on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Last expression of the cell: moves the student and displays its structure.
Student_model.to(device)

StudentModel(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BottleneckLayer(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      

In [None]:
# torchsummary feeds a dummy batch through the model on `device`, which
# defaults to CUDA. Pick the device explicitly so this cell also runs on a
# machine without a GPU.
summary_device = "cuda" if torch.cuda.is_available() else "cpu"
summary(Student_model.to(summary_device), (3, 224, 224), device=summary_device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]           4,096
       BatchNorm2d-6           [-1, 64, 56, 56]             128
            Conv2d-7           [-1, 64, 56, 56]          36,864
       BatchNorm2d-8           [-1, 64, 56, 56]             128
            Conv2d-9          [-1, 256, 56, 56]          16,384
      BatchNorm2d-10          [-1, 256, 56, 56]             512
           Conv2d-11          [-1, 256, 56, 56]          16,384
      BatchNorm2d-12          [-1, 256, 56, 56]             512
  BottleneckLayer-13          [-1, 256, 56, 56]               0
           Conv2d-14          [-1, 128,

# Student baseline (trained without the teacher)

In [None]:
# Baseline: train the student on the hard labels only (no teacher involved).
num_epochs = 5

# The student needs its own optimizer. `opt` was built from the teacher head's
# parameters, so stepping it would leave every student weight unchanged.
student_opt = torch.optim.Adam(Student_model.parameters(), lr=0.001)

for epoch in range(num_epochs):
    Student_model.train()  # Set the model to training mode
    running_loss = 0.0

    for inputs, labels in train_dataloader:
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the gradients
        student_opt.zero_grad()

        # Forward pass
        outputs = Student_model(inputs)

        # Compute the loss
        loss = lossFunc(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        student_opt.step()

        # Accumulate the per-batch loss
        running_loss += loss.item()

    # Print the average per-batch loss for the epoch
    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {running_loss / len(train_dataloader)}')

    # Evaluation on the test set
    Student_model.eval()  # Set the model to evaluation mode
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = Student_model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total
    print(f'Test Accuracy: {accuracy}')

Epoch 1/5, Loss: 2.3246826475071907
Test Accuracy: 0.0957
Epoch 2/5, Loss: 2.3246826475071907
Test Accuracy: 0.0945
Epoch 3/5, Loss: 2.3246826475071907
Test Accuracy: 0.0981
Epoch 4/5, Loss: 2.3246826475071907
Test Accuracy: 0.096


# Student KD (knowledge distillation from the teacher)

In [None]:
# Distil from the teacher whose head was fine-tuned earlier in this notebook.
# A freshly constructed TeacherModel would have a randomly initialised `fc`
# layer, so its soft targets would carry no information about the classes.
teacher_model = teacher_model_transfer

In [None]:
# Use the first GPU when one is visible, otherwise stay on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# Last expression of the cell: moves the teacher and displays its structure.
teacher_model.to(device)

TeacherModel(
  (resnet50): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
    

In [None]:
def kd_loss(student_logits, teacher_logits, temperature=1.0, alpha=0.5):
    """Weighted distillation term between student and teacher logits.

    Both logit tensors are divided by ``temperature`` and turned into
    distributions over ``dim=1``. The result is the batch-mean KL divergence
    of the student's log-probabilities from the teacher's probabilities,
    multiplied by ``temperature ** 2`` and by ``alpha``.
    """
    teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
    student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
    divergence = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
    return divergence * (temperature ** 2) * alpha

ce_loss = nn.CrossEntropyLoss()

optimizer = optim.Adam(Student_model.parameters(), lr=0.001)

num_epochs = 10
temperature = 5.0
alpha = 0.5

# The teacher only supplies targets: keep it in evaluation mode so its
# batch-norm layers use their stored statistics instead of per-batch ones.
teacher_model.eval()

for epoch in range(num_epochs):
    Student_model.train()
    train_loss_sum = 0.0
    train_batches = 0

    for inputs, labels in train_dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        # Forward pass. No gradients are needed for the teacher, because the
        # optimizer only holds the student's parameters.
        with torch.no_grad():
            teacher_logits = teacher_model(inputs)
        student_logits = Student_model(inputs)

        # Total loss: hard-label cross-entropy plus the distillation term.
        loss = ce_loss(student_logits, labels) + kd_loss(student_logits, teacher_logits, temperature, alpha)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        train_loss_sum += loss.item()
        train_batches += 1

    # Validation loop. The accumulators use kd_-prefixed names so they do not
    # rebind `val_loss`, the list that train_model appends to.
    Student_model.eval()
    kd_val_loss_sum = 0.0
    kd_val_correct = 0
    kd_val_total = 0
    with torch.no_grad():
        for val_inputs, val_labels in test_dataloader:
            # Evaluate the validation batch itself, not the last training batch.
            val_inputs, val_labels = val_inputs.to(device), val_labels.to(device)
            val_student_logits = Student_model(val_inputs)

            # Validation loss (hard-label cross-entropy only)
            kd_val_loss_sum += ce_loss(val_student_logits, val_labels).item()

            # Validation accuracy
            predicted = val_student_logits.argmax(dim=1)
            kd_val_total += val_labels.size(0)
            kd_val_correct += (predicted == val_labels).sum().item()

    # max(..., 1) guards against division by zero when a loader is empty.
    avg_train_loss = train_loss_sum / max(train_batches, 1)
    avg_val_loss = kd_val_loss_sum / max(len(test_dataloader), 1)
    val_accuracy = kd_val_correct / max(kd_val_total, 1)

    # "Loss" is the mean training loss of the epoch; the other two values are
    # validation metrics.
    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {avg_train_loss}, Epoch Loss: {avg_val_loss}, Epoch Accuracy: {val_accuracy}')

Epoch 1/10, Loss: 1.6822870969772339, Epoch Loss: 1.860161304473877, Epoch Accuracy: 0.0
Epoch 2/10, Loss: 2.03898024559021, Epoch Loss: 1.6525788307189941, Epoch Accuracy: 0.0
Epoch 3/10, Loss: 1.1226084232330322, Epoch Loss: 0.08560886234045029, Epoch Accuracy: 1.0
Epoch 4/10, Loss: 1.6610631942749023, Epoch Loss: 1.2347629070281982, Epoch Accuracy: 1.0
Epoch 5/10, Loss: 1.2530322074890137, Epoch Loss: 2.03286075592041, Epoch Accuracy: 0.0


In [None]:
import matplotlib.pyplot as plt
plt.figure(1)

# Accuracy history (top panel). The CPU copies of the accuracies are named
# teacher_training_acc_cpu_T / teacher_val_acc_cpu_T.
plt.subplot(211)
plt.plot(teacher_training_acc_cpu_T)
plt.plot(teacher_val_acc_cpu_T)
plt.title('teacher model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'valid'], loc='lower right')

# Loss history (bottom panel).
# NOTE(review): this expects `training_loss` and `val_loss` to still be the
# per-epoch lists filled by train_model; confirm no other cell has rebound them.
plt.subplot(212)
plt.plot(training_loss)
plt.plot(val_loss)
plt.title('teacher model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'valid'], loc='upper right')

# Keep the stacked panels' titles and axis labels from overlapping.
plt.tight_layout()
plt.show()