In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [None]:
# Train the ResNet32 teacher model on CIFAR-10

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

from google.colab import drive

# Mount Google Drive; the trained teacher weights are written there at the
# end of this cell.
drive.mount('/content/drive')


# Input pipelines. Training images get random crop + horizontal flip
# augmentation; test images are only converted and normalised.
# NOTE(review): the mean/std values below look like the commonly quoted
# CIFAR-100 statistics rather than CIFAR-10's -- confirm. They are left as-is
# because every cell in this notebook uses the same numbers.
channel_mean = (0.507, 0.487, 0.441)
channel_std = (0.267, 0.256, 0.276)

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
])

# CIFAR10 train / test splits (downloaded into ./data on first use).
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)

# Both loaders share batch size and worker count; only the training loader
# shuffles.
loader_options = dict(batch_size=128, num_workers=2)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_options)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **loader_options)

class BasicBlock(nn.Module):
    """Residual block made of two 3x3 conv + BatchNorm pairs.

    Args:
        in_planes: number of input channels.
        planes: number of channels produced by both convolutions.
        stride: stride of the first convolution (and of the projection
            shortcut when one is needed).

    The shortcut is an empty ``nn.Sequential`` (identity) unless the stride is
    not 1 or the channel count changes, in which case it is a 1x1 conv
    followed by BatchNorm.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        shape_preserved = stride == 1 and in_planes == out_planes
        if shape_preserved:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        out = nn.functional.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        return nn.functional.relu(out)

class ResNet32(nn.Module):
    """CIFAR-style residual classifier used as the teacher network.

    Layout: a 3x3 stem convolution (16 channels), three stages of
    ``BasicBlock``s with 16 / 32 / 64 channels (the last two downsample with
    stride 2), an 8x8 average pool and a linear classifier.

    Note: each stage stacks two blocks, i.e. 1 + 3*2*2 + 1 = 14 conv/linear
    layers (shortcut projections not counted), despite the class name.

    Args:
        num_classes: size of the classifier output.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Running channel count; _make_layer reads and updates it.
        self.in_planes = 16

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(16, 2, stride=1)
        self.layer2 = self._make_layer(32, 2, stride=2)
        self.layer3 = self._make_layer(64, 2, stride=2)
        self.linear = nn.Linear(64, num_classes)

    def _make_layer(self, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``, the rest use 1."""
        blocks = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            blocks.append(BasicBlock(self.in_planes, planes, block_stride))
            self.in_planes = planes * BasicBlock.expansion
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Map a batch of images to class logits."""
        out = nn.functional.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3):
            out = stage(out)
        # 8x8 pooling window: collapses the 8x8 map produced for 32x32 inputs.
        out = nn.functional.avg_pool2d(out, 8)
        out = out.view(out.size(0), -1)
        return self.linear(out)


# Define training parameters
num_epochs = 65
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
teacher = ResNet32().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(teacher.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
# Drop the learning rate by 10x at 50% and 75% of the run. The milestones are
# derived from num_epochs because the earlier fixed values (150, 250) were
# never reached in a 65-epoch run, so the LR stayed at 0.1 throughout.
scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[num_epochs // 2, (3 * num_epochs) // 4], gamma=0.1)

# Train the model
for epoch in range(num_epochs):
    # --- one optimisation pass over the training set ---
    teacher.train()
    running_loss = 0.0
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = teacher(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Sum of per-batch mean losses; averaged over batches when printed.
        running_loss += loss.item()

    # --- accuracy on the test loader ---
    teacher.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            # The model already lives on `device`, so its output does too.
            predicted = teacher(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Advance the LR schedule once per epoch, after the optimizer updates.
    scheduler.step()

    print('Epoch %d, Loss: %.3f, Test Accuracy: %.3f %%' % (epoch+1, running_loss/len(trainloader), 100*correct/total))

# Persist the teacher weights; later cells load them for distillation.
torch.save(teacher.state_dict(), '/content/drive/MyDrive/Teacher32_weights.pth')

Mounted at /content/drive
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:05<00:00, 31357336.51it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Epoch 1, Loss: 1.597, Test Accuracy: 46.860 %
Epoch 2, Loss: 1.116, Test Accuracy: 62.940 %
Epoch 3, Loss: 0.930, Test Accuracy: 62.430 %
Epoch 4, Loss: 0.805, Test Accuracy: 68.720 %
Epoch 5, Loss: 0.726, Test Accuracy: 69.620 %
Epoch 6, Loss: 0.673, Test Accuracy: 71.570 %
Epoch 7, Loss: 0.648, Test Accuracy: 73.570 %
Epoch 8, Loss: 0.619, Test Accuracy: 71.430 %
Epoch 9, Loss: 0.604, Test Accuracy: 72.740 %
Epoch 10, Loss: 0.588, Test Accuracy: 77.090 %
Epoch 11, Loss: 0.576, Test Accuracy: 70.140 %
Epoch 12, Loss: 0.564, Test Accuracy: 73.970 %
Epoch 13, Loss: 0.555, Test Accuracy: 76.020 %
Epoch 14, Loss: 0.551, Test Accuracy: 76.890 %
Epoch 15, Loss: 0.544, Test Accuracy: 74.760 %
Epoch 16, Loss: 0.544, Test Accuracy: 76.850 %
Epoch 17, Loss: 0.537, Test Accuracy: 79.170 %
Epoch 18, Loss: 0.529, Test Accuracy: 80.870 %
Epoch 19, Loss: 0.524, Test Accuracy: 77.020 %
Epoch 20, Loss: 0.517, Test

In [None]:
# Defining the teacher and TA models

class BasicBlock_32(nn.Module):
    """Residual block used by the teacher ``ResNet32`` in this cell.

    Two 3x3 conv + BatchNorm pairs with a skip connection. The skip path is an
    empty ``nn.Sequential`` (identity) when the block keeps stride 1 and the
    channel count; otherwise it is a 1x1 conv + BatchNorm projection.

    Args:
        in_planes: number of input channels.
        planes: number of output channels.
        stride: stride of the first convolution / projection shortcut.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        projected_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        projection = []
        if stride != 1 or in_planes != projected_planes:
            projection = [
                nn.Conv2d(in_planes, projected_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(projected_planes),
            ]
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        """Apply both conv/BN pairs, add the shortcut and finish with ReLU."""
        hidden = nn.functional.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(hidden))
        out += self.shortcut(x)
        return nn.functional.relu(out)

class ResNet32(nn.Module):
    """Teacher network definition used when loading the trained teacher weights.

    Layout: 3x3 stem convolution (16 channels), three stages of
    ``BasicBlock_32`` with 16 / 32 / 64 channels (two blocks each, the last two
    stages downsample with stride 2), an 8x8 average pool and a linear
    classifier.

    The classifier is registered under the name ``linear``. That is the name
    ``forward`` uses and the name under which the teacher-training cell stores
    it in the checkpoint (``linear.weight`` / ``linear.bias``). Previously the
    layer was registered as ``fc``, so ``forward`` raised ``AttributeError`` and
    a non-strict checkpoint load left the classifier untrained. ``fc`` is kept
    as a read-only alias for code that accesses ``teacher.fc``.

    Args:
        num_classes: size of the classifier output.
    """

    def __init__(self, num_classes=10):
        super(ResNet32, self).__init__()
        # Running channel count; _make_layer reads and updates it.
        self.in_planes = 16

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(16, 2, stride=1)
        self.layer2 = self._make_layer(32, 2, stride=2)
        self.layer3 = self._make_layer(64, 2, stride=2)
        self.linear = nn.Linear(64, num_classes)

    @property
    def fc(self):
        """Alias of ``self.linear`` (the classifier head); adds no state_dict keys."""
        return self.linear

    def _make_layer(self, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``, the rest use 1."""
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(BasicBlock_32(self.in_planes, planes, stride))
            self.in_planes = planes * BasicBlock_32.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of images to class logits."""
        out = nn.functional.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # 8x8 pooling window: collapses the 8x8 map produced for 32x32 inputs.
        out = nn.functional.avg_pool2d(out, 8)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

class BasicBlock(nn.Module):
    """Residual block used by the ``ResNet20`` / ``ResNet8`` models.

    Two 3x3 conv + BatchNorm pairs with a skip connection and a shared
    in-place ReLU module. The skip path is a 1x1 conv + BatchNorm projection
    when the stride is not 1 or the channel count changes, and an empty
    ``nn.Sequential`` (identity) otherwise.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        stride: stride of the first convolution / projection shortcut.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.stride = stride

        shape_preserved = stride == 1 and in_channels == out_channels
        if shape_preserved:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # The block input is never modified in place; only `out` is.
        out += self.shortcut(x)
        return self.relu(out)

class ResNet20(nn.Module):
    """Teaching-assistant (TA) network: residual classifier with three blocks per stage.

    Layout: 3x3 stem convolution (16 channels), three stages of ``BasicBlock``
    with 16 / 32 / 64 channels (the last two downsample with stride 2), global
    average pooling and a linear classifier -- 1 + 3*3*2 + 1 = 20 conv/linear
    layers, shortcut projections not counted.

    Args:
        num_classes: size of the classifier output.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Running channel count; make_layer reads and updates it.
        self.in_channels = 16

        # Stem
        self.conv = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)

        # Residual stages
        self.layer1 = self.make_layer(BasicBlock, 16, 3, stride=1)
        self.layer2 = self.make_layer(BasicBlock, 32, 3, stride=2)
        self.layer3 = self.make_layer(BasicBlock, 64, 3, stride=2)

        # Head
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, num_classes)

    def make_layer(self, block, out_channels, num_blocks, stride):
        """Stack blocks into one stage; only the first block uses ``stride``."""
        stage = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            stage.append(block(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels
        return nn.Sequential(*stage)

    def forward(self, x):
        """Map a batch of images to class logits."""
        out = self.relu(self.bn(self.conv(x)))

        for stage in (self.layer1, self.layer2, self.layer3):
            out = stage(out)

        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        return self.fc(out)



In [None]:
from torch.utils.data import random_split

# Load the datasets
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276))
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276))
])

# Load the CIFAR10 training images with the augmenting transform
cifar10_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)

# Calculate the lengths of the training and held-out (test) parts: 80% / 20%
train_length = int(len(cifar10_dataset) * 0.8)
test_length = len(cifar10_dataset) - train_length

# Randomly split the indices into a training part and a held-out part
train_dataset, held_out_split = random_split(cifar10_dataset, [train_length, test_length])

# Build the test set from a second view of the same images that uses
# test_transform, restricted to the held-out indices.
# Replacing the Subset by its underlying `.dataset` and overwriting
# `.transform` (as done before) made the "test" set the entire 50k images,
# training images included, and also removed the augmentation from
# train_dataset, because both subsets share one underlying dataset object.
cifar10_eval_dataset = datasets.CIFAR10(root='./data', train=True, download=False, transform=test_transform)
test_dataset = torch.utils.data.Subset(cifar10_eval_dataset, held_out_split.indices)

# Create the DataLoaders
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:01<00:00, 93087032.51it/s] 


Extracting ./data/cifar-10-python.tar.gz to ./data


In [None]:
# Define the models, loss function, optimizer and LR schedule for teacher -> TA distillation
teacher = ResNet32()
TA = ResNet20()

from google.colab import drive

drive.mount('/content/drive')

# Load pre-trained teacher model weights.
# map_location='cpu': the checkpoint is written from a model that may live on
# a CUDA device, while this teacher is kept on the CPU.
teacher_state = torch.load('/content/drive/MyDrive/Teacher32_weights.pth', map_location='cpu')
# strict=False tolerates key mismatches, so report them instead of silently
# continuing with partly uninitialised teacher weights.
load_result = teacher.load_state_dict(teacher_state, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print('WARNING: teacher checkpoint does not fully match ResNet32.',
          'Missing keys:', load_result.missing_keys,
          'Unexpected keys:', load_result.unexpected_keys)

# KL divergence between softened student and teacher distributions;
# only the TA's parameters are optimised.
criterion = nn.KLDivLoss(reduction='batchmean')
optimizer = optim.SGD(TA.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

Mounted at /content/drive


In [None]:
# Distil the teacher into the TA model.
# The loss is KL(softmax(teacher / T) || softmax(TA / T)) scaled by T^2.

temperature = 3.3
num_epochs = 40


# The teacher only supplies soft targets, so it stays in eval mode.
teacher.eval()

for epoch in range(num_epochs):
    # ---- training pass ----
    TA.train()
    train_loss = 0.0
    train_correct = 0

    for images, labels in train_loader:
        optimizer.zero_grad()

        # Temperature-scaled logits; no gradient is tracked for the teacher.
        with torch.no_grad():
            teacher_output = teacher(images) / temperature
        TA_output = TA(images) / temperature

        # Distillation loss between the softened distributions
        soft_targets = nn.functional.softmax(teacher_output, dim=1)
        log_probs = nn.functional.log_softmax(TA_output, dim=1)
        loss = criterion(log_probs, soft_targets)*(temperature**2)

        loss.backward()
        optimizer.step()

        # Accumulate the sample-weighted loss and the number of correct
        # predictions (dividing by the temperature does not change the argmax).
        train_loss += loss.item() * images.size(0)
        predicted = TA_output.argmax(dim=1)
        train_correct += (predicted == labels).sum().item()

    scheduler.step()

    # ---- evaluation pass (plain cross-entropy on unscaled logits) ----
    TA.eval()
    test_loss = 0.0
    test_correct = 0

    with torch.no_grad():
        for images, labels in test_loader:
            TA_output = TA(images)
            loss = nn.functional.cross_entropy(TA_output, labels)
            test_loss += loss.item() * images.size(0)
            predicted = TA_output.argmax(dim=1)
            test_correct += (predicted == labels).sum().item()

    # Per-epoch summary: losses are per-sample averages, accuracies in percent.
    n_train = len(train_dataset)
    n_test = len(test_dataset)
    print('Epoch [{}/{}], Train Loss: {:.4f}, Train Acc: {:.2f}%, Test Loss: {:.4f}, Test Acc: {:.2f}%'
          .format(epoch+1, num_epochs, train_loss/n_train, 100*train_correct/n_train,
                  test_loss/n_test, 100*test_correct/n_test))

# Save the TA model weights
torch.save(TA.state_dict(), '/content/drive/MyDrive/TA_ResNet20_weights.pth')


Epoch [1/40], Train Loss: 3.8500, Train Acc: 37.95%, Test Loss: 2.0088, Test Acc: 47.20%
Epoch [2/40], Train Loss: 1.8279, Train Acc: 58.86%, Test Loss: 2.0090, Test Acc: 51.79%
Epoch [3/40], Train Loss: 1.1542, Train Acc: 66.69%, Test Loss: 0.8964, Test Acc: 71.28%
Epoch [4/40], Train Loss: 0.8615, Train Acc: 70.16%, Test Loss: 1.2325, Test Acc: 65.73%
Epoch [5/40], Train Loss: 0.7031, Train Acc: 71.94%, Test Loss: 1.0002, Test Acc: 69.37%
Epoch [6/40], Train Loss: 0.5994, Train Acc: 73.59%, Test Loss: 0.8580, Test Acc: 73.20%
Epoch [7/40], Train Loss: 0.5387, Train Acc: 73.95%, Test Loss: 0.8401, Test Acc: 73.16%
Epoch [8/40], Train Loss: 0.4815, Train Acc: 74.39%, Test Loss: 0.7369, Test Acc: 75.64%
Epoch [9/40], Train Loss: 0.4548, Train Acc: 74.91%, Test Loss: 0.7956, Test Acc: 74.84%
Epoch [10/40], Train Loss: 0.4122, Train Acc: 75.19%, Test Loss: 0.8029, Test Acc: 75.01%
Epoch [11/40], Train Loss: 0.4010, Train Acc: 75.32%, Test Loss: 0.9584, Test Acc: 71.82%
Epoch [12/40], Trai

In [None]:

# Make sure Drive is mounted before reading the checkpoint.
drive.mount('/content/drive')

# Restore the distilled TA weights saved by the TA training cell.
# The second argument of load_state_dict is `strict`; False tolerates
# missing / unexpected keys.
ta_state = torch.load('/content/drive/MyDrive/TA_ResNet20_weights.pth')
TA.load_state_dict(ta_state, strict=False)

# The TA now acts as a fixed teacher for the student.
TA.eval()

class ResNet8(nn.Module):
    """Student network: residual classifier with a single block per stage.

    Layout: 3x3 stem convolution (16 channels), three one-block stages of
    ``BasicBlock`` with 16 / 32 / 64 channels (the last two downsample with
    stride 2), global average pooling and a linear classifier --
    1 + 3*1*2 + 1 = 8 conv/linear layers, shortcut projections not counted.

    Args:
        num_classes: size of the classifier output.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Running channel count; make_layer reads and updates it.
        self.in_channels = 16

        # Stem
        self.conv = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)

        # Residual stages
        self.layer1 = self.make_layer(BasicBlock, 16, 1, stride=1)
        self.layer2 = self.make_layer(BasicBlock, 32, 1, stride=2)
        self.layer3 = self.make_layer(BasicBlock, 64, 1, stride=2)

        # Head
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, num_classes)

    def make_layer(self, block, out_channels, num_blocks, stride):
        """Stack blocks into one stage; only the first block uses ``stride``."""
        stage = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            stage.append(block(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels
        return nn.Sequential(*stage)

    def forward(self, x):
        """Map a batch of images to class logits."""
        out = self.relu(self.bn(self.conv(x)))

        for stage in (self.layer1, self.layer2, self.layer3):
            out = stage(out)

        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        return self.fc(out)

# Create the ResNet8 student that the following training cell optimises.
student = ResNet8()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Train the student model using knowledge distillation without replacing the Fully Connected Layer(FC)
# Loss = alpha * hard-label cross-entropy + (1 - alpha) * T^2 * KL(TA || student) on softened logits

criterion = nn.KLDivLoss(reduction='batchmean')
optimizer = optim.SGD(student.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

temperature = 3.3
num_epochs = 40
alpha = 0.5 # weight for the hard-label (negative log likelihood) loss

# The TA only supplies soft targets, so it stays in eval mode.
TA.eval()

for epoch in range(num_epochs):
    student.train()
    train_loss = 0.0
    train_correct = 0

    for i, (images, labels) in enumerate(train_loader):
        optimizer.zero_grad()

        # Forward pass for TA (no gradient) and student; both are raw logits
        with torch.no_grad():
            TA_output = TA(images)
        student_output = student(images)

        # Distillation loss on the temperature-softened distributions
        distillation_loss = criterion(
            nn.functional.log_softmax(student_output / temperature, dim=1),
            nn.functional.softmax(TA_output / temperature, dim=1))*(temperature**2)

        # Hard-label loss on the student's own unscaled logits.
        # It was previously computed from TA_output, which carries no gradient
        # with respect to the student, so this term never trained the student.
        nll_loss = nn.functional.cross_entropy(student_output, labels)

        # Combine the two losses using the weight alpha
        loss = alpha * nll_loss + (1 - alpha) * distillation_loss

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Accumulate the sample-weighted loss and correct predictions
        train_loss += loss.item() * images.size(0)
        predicted = student_output.argmax(dim=1)
        train_correct += (predicted == labels).sum().item()

    scheduler.step()

    # Evaluate the student model on the test set
    student.eval()
    test_loss = 0.0
    test_correct = 0

    with torch.no_grad():
        for images, labels in test_loader:
            student_output = student(images)
            loss = nn.functional.cross_entropy(student_output, labels)
            test_loss += loss.item() * images.size(0)
            predicted = student_output.argmax(dim=1)
            test_correct += (predicted == labels).sum().item()

    # Print epoch number, training loss, training accuracy, test loss, and test accuracy
    print('Epoch [{}/{}], Train Loss: {:.4f}, Train Acc: {:.2f}%, Test Loss: {:.4f}, Test Acc: {:.2f}%'
          .format(epoch+1, num_epochs, train_loss/len(train_dataset), 100*train_correct/len(train_dataset),
                  test_loss/len(test_dataset), 100*test_correct/len(test_dataset)))


# Save the student model weights
torch.save(student.state_dict(), '/content/drive/MyDrive/student_weights.pth')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Epoch [1/40], Train Loss: 3.8130, Train Acc: 36.85%, Test Loss: 2.1777, Test Acc: 39.68%
Epoch [2/40], Train Loss: 1.9900, Train Acc: 55.54%, Test Loss: 1.5895, Test Acc: 52.48%
Epoch [3/40], Train Loss: 1.3029, Train Acc: 62.96%, Test Loss: 1.1428, Test Acc: 64.46%
Epoch [4/40], Train Loss: 1.0240, Train Acc: 66.47%, Test Loss: 0.9865, Test Acc: 68.52%
Epoch [5/40], Train Loss: 0.8850, Train Acc: 68.00%, Test Loss: 0.8778, Test Acc: 71.50%
Epoch [6/40], Train Loss: 0.7480, Train Acc: 69.98%, Test Loss: 0.8265, Test Acc: 72.14%
Epoch [7/40], Train Loss: 0.6973, Train Acc: 70.38%, Test Loss: 1.0993, Test Acc: 66.70%
Epoch [8/40], Train Loss: 0.6545, Train Acc: 71.06%, Test Loss: 0.8087, Test Acc: 73.51%
Epoch [9/40], Train Loss: 0.6275, Train Acc: 71.12%, Test Loss: 1.3002, Test Acc: 63.15%
Epoch [10/40], Train Loss: 0.6197, Train Acc: 71.22%, Test Loss: 0.742

In [None]:
# Replace the student's fc layer by the teacher's fc layer.
# Both reference models are only read from here on, so keep them in eval mode.
teacher.eval()

drive.mount('/content/drive')

# Restore the trained TA weights; the second argument of load_state_dict is
# `strict`, and False tolerates missing / unexpected keys.
ta_state = torch.load('/content/drive/MyDrive/TA_ResNet20_weights.pth')
TA.load_state_dict(ta_state, strict=False)
TA.eval()

class ResNet8(nn.Module):
    """Student network whose classifier starts from the teacher's classifier.

    Same layout as the plain student (3x3 stem, three one-block ``BasicBlock``
    stages with 16 / 32 / 64 channels, global average pooling), but the final
    ``Linear(64, 10)`` layer is initialised with the values of the global
    ``teacher.fc`` layer.

    The teacher's weights are copied, not shared. Assigning ``.data`` directly
    (as done before) made student and teacher use the same tensors, so every
    optimizer step on the student also changed the teacher's classifier.

    Args:
        num_classes: accepted for interface compatibility; the classifier
            size is fixed to the teacher's (64 -> 10).
    """

    def __init__(self, num_classes=10):
        super(ResNet8, self).__init__()
        # Running channel count; make_layer reads and updates it.
        self.in_channels = 16
        self.conv = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self.make_layer(BasicBlock, 16, 1, stride=1)
        self.layer2 = self.make_layer(BasicBlock, 32, 1, stride=2)
        self.layer3 = self.make_layer(BasicBlock, 64, 1, stride=2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Replace the student's fully connected layer by a copy of the
        # teacher's linear layer (64, 10)
        student_fc_layer = nn.Linear(64, 10)
        with torch.no_grad():
            student_fc_layer.weight.copy_(teacher.fc.weight)
            student_fc_layer.bias.copy_(teacher.fc.bias)
        self.fc = student_fc_layer

    def make_layer(self, block, out_channels, num_blocks, stride):
        """Stack blocks into one stage; only the first block uses ``stride``."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of images to class logits."""
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)

        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)

        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

# Create the student; this ResNet8 variant reads the global `teacher` to initialise its fc layer.
student = ResNet8()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Train the student after replacing its fc layer by the teacher's.
# Loss = alpha * hard-label cross-entropy + (1 - alpha) * T^2 * KL(TA || student) on softened logits

drive.mount('/content/drive')

# Load trained TA model weights.
# The file name matches the one written by the TA training cell
# ('TA_ResNet20_weights.pth'); the earlier spelling without the final "s"
# pointed at a file this notebook never creates.
TA.load_state_dict(torch.load('/content/drive/MyDrive/TA_ResNet20_weights.pth'), False)

criterion = nn.KLDivLoss(reduction='batchmean')
optimizer = optim.SGD(student.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

temperature = 3.16
num_epochs = 40
alpha = 0.5  # weight for the hard-label loss

# The TA only supplies soft targets, so it stays in eval mode.
TA.eval()


for epoch in range(num_epochs):
    student.train()
    train_loss = 0.0
    train_correct = 0

    for i, (images, labels) in enumerate(train_loader):
        optimizer.zero_grad()

        # Forward pass for TA (no gradient) and student; both are raw logits
        with torch.no_grad():
            TA_output = TA(images)
        student_output = student(images)

        # Distillation loss on the temperature-softened distributions
        distillation_loss = criterion(
            nn.functional.log_softmax(student_output / temperature, dim=1),
            nn.functional.softmax(TA_output / temperature, dim=1))*(temperature**2)

        # Hard-label loss on the student's own unscaled logits.
        # It was previously computed from TA_output, which carries no gradient
        # with respect to the student, so this term never trained the student.
        nll_loss = nn.functional.cross_entropy(student_output, labels)

        # Combine the two losses using the weight alpha
        loss = alpha * nll_loss + (1 - alpha) * distillation_loss

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Accumulate the sample-weighted loss and correct predictions
        train_loss += loss.item() * images.size(0)
        predicted = student_output.argmax(dim=1)
        train_correct += (predicted == labels).sum().item()

    scheduler.step()

    # Evaluate the student model on the test set
    student.eval()
    test_loss = 0.0
    test_correct = 0

    with torch.no_grad():
        for images, labels in test_loader:
            student_output = student(images)
            loss = nn.functional.cross_entropy(student_output, labels)
            test_loss += loss.item() * images.size(0)
            predicted = student_output.argmax(dim=1)
            test_correct += (predicted == labels).sum().item()

    # Print epoch number, training loss, training accuracy, test loss, and test accuracy
    print('Epoch [{}/{}], Train Loss: {:.4f}, Train Acc: {:.2f}%, Test Loss: {:.4f}, Test Acc: {:.2f}%'
          .format(epoch+1, num_epochs, train_loss/len(train_dataset), 100*train_correct/len(train_dataset),
                  test_loss/len(test_dataset), 100*test_correct/len(test_dataset)))




# Save the student model weights
# NOTE(review): this is the same path the plain-student training cell writes
# to, so that checkpoint is overwritten -- confirm this is intended.
torch.save(student.state_dict(), '/content/drive/MyDrive/student_weights.pth')



Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Epoch [1/40], Train Loss: 3.9222, Train Acc: 35.15%, Test Loss: 1.9644, Test Acc: 43.66%
Epoch [2/40], Train Loss: 2.2709, Train Acc: 52.52%, Test Loss: 1.3271, Test Acc: 57.64%
Epoch [3/40], Train Loss: 1.4431, Train Acc: 61.16%, Test Loss: 1.0183, Test Acc: 66.31%
Epoch [4/40], Train Loss: 1.0658, Train Acc: 65.51%, Test Loss: 1.2976, Test Acc: 63.54%
Epoch [5/40], Train Loss: 0.8624, Train Acc: 68.13%, Test Loss: 1.0951, Test Acc: 67.26%
Epoch [6/40], Train Loss: 0.7765, Train Acc: 69.19%, Test Loss: 1.0218, Test Acc: 66.55%
Epoch [7/40], Train Loss: 0.7023, Train Acc: 70.16%, Test Loss: 0.8953, Test Acc: 71.17%
Epoch [8/40], Train Loss: 0.6618, Train Acc: 70.59%, Test Loss: 0.9419, Test Acc: 69.03%
Epoch [9/40], Train Loss: 0.6407, Train Acc: 70.95%, Test Loss: 0.8062, Test Acc: 72.91%
Epoch [10/40], Train Loss: 0.5980, Train Acc: 71.81%, Test Loss: 1.124