<a href="https://colab.research.google.com/github/sfarrukhm/making_models_small/blob/main/omitting_a_class_KD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
import torchvision

In [2]:
# Preprocessing for CIFAR-10: convert each image to a tensor, then normalise
# every RGB channel with (x - 0.5) / 0.5.
BATCH_SIZE = 64  # mini-batch size shared by the train and test loaders

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Load the full CIFAR-10 train/test splits (downloaded into ./data on first use).
# For a quick smoke run, wrap either dataset in
# torch.utils.data.Subset(dataset, range(n)) to keep only its first n samples.
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Data loaders: only the training split is shuffled, so evaluation order stays fixed.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:22<00:00, 7.50MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [3]:
# Class held out of the student's training data; index 3 is CIFAR-10's 'cat'.
class_to_omit = 3  # 'Cat' class

# Walk the training split once and keep every (image, label) pair whose label
# is NOT the omitted class.
kept_pairs = [(img, label) for img, label in train_dataset if label != class_to_omit]
filtered_images = [img for img, _ in kept_pairs]
filtered_labels = [label for _, label in kept_pairs]

# Materialise the surviving samples as tensors and wrap them in a shuffled
# training loader.
image_tensor = torch.stack(filtered_images)
label_tensor = torch.tensor(filtered_labels)
filtered_dataset = torch.utils.data.TensorDataset(image_tensor, label_tensor)
filtered_data_loader = torch.utils.data.DataLoader(
    filtered_dataset, batch_size=64, shuffle=True, num_workers=2
)
len(filtered_dataset)

45000

In [5]:
# Build a cat-only evaluation set from the TEST split: the complement of the
# training filter, keeping only samples of the class that was held out.
class_to_omit = 3  # 'Cat' class

# Separate lists are used here so the training-side filtered_images /
# filtered_labels lists are not silently overwritten by this cell.
cat_images = []
cat_labels = []

for img, label in test_dataset:
    if label == class_to_omit:  # keep ONLY images of the omitted class
        cat_images.append(img)
        cat_labels.append(label)

# Wrap the cat-only samples in a dataset and loader. This loader is used for
# evaluation only, so it is not shuffled (matching test_loader).
cat_only_dataset = torch.utils.data.TensorDataset(torch.stack(cat_images), torch.tensor(cat_labels))
cat_only_data_loader = torch.utils.data.DataLoader(cat_only_dataset, batch_size=64, shuffle=False, num_workers=2)
len(cat_only_dataset)

1000

In [6]:
# Teacher network: the larger CNN whose logits supply the soft targets.
class TeacherModel(nn.Module):
    """Four-convolution CNN classifier used as the distillation teacher.

    Two 2x2 max-pools halve the spatial size twice, and the classifier expects
    2048 flattened features (32 channels x 8 x 8), so inputs must be
    3 x 32 x 32 images.

    Args:
        num_classes: Number of output logits produced by the final linear layer.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Layers are listed in execution order; their positions determine the
        # state_dict keys (features.0, features.2, ...), so keep the order stable.
        conv_layers = [
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.features = nn.Sequential(*conv_layers)

        head_layers = [
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for an image batch `x`."""
        feature_maps = self.features(x)
        flattened = torch.flatten(feature_maps, 1)  # keep the batch dim, merge the rest
        return self.classifier(flattened)


# Student network: a much lighter CNN than the teacher.
class StudentModel(nn.Module):
    """Two-convolution CNN classifier used as the distillation student.

    The classifier expects 640 flattened features (10 channels x 8 x 8 after
    two 2x2 max-pools), so inputs must be 3 x 32 x 32 images.

    Args:
        num_classes: Number of output logits produced by the final linear layer.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Positions in these lists determine the state_dict keys; keep the order stable.
        conv_layers = [
            nn.Conv2d(3, 10, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(10, 10, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.features = nn.Sequential(*conv_layers)

        head_layers = [
            nn.Linear(640, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for an image batch `x`."""
        feature_maps = self.features(x)
        flattened = torch.flatten(feature_maps, 1)  # keep the batch dim, merge the rest
        return self.classifier(flattened)


In [7]:
# Shared setup for training the teacher and student with plain cross-entropy,
# before any distillation.
from collections import defaultdict

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = 'cpu'

# Mapping whose missing keys start out as empty lists.
log_dict = defaultdict(list)

def train(model, train_loader, num_epochs, learning_rate,device, save_model_path=None):
    """Fit `model` with Adam and cross-entropy, printing the mean loss per epoch.

    Args:
        model: Network to optimise; moved to `device` and put in training mode.
        train_loader: Sized iterable yielding `(images, labels)` batches.
        num_epochs: Number of full passes over `train_loader`.
        learning_rate: Learning rate passed to the Adam optimiser.
        device: Device on which the model and every batch are placed.
        save_model_path: If not None, the model's state_dict is saved to this
            path once training has finished.

    Returns:
        None. The model is updated in place.
    """
    criterion = torch.nn.CrossEntropyLoss()
    adam = torch.optim.Adam(model.parameters(), lr=learning_rate)

    model.to(device)
    model.train()

    for epoch_idx in range(num_epochs):
        epoch_loss = 0
        for batch_images, batch_labels in train_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            # Standard step: clear stale grads, forward, backward, update.
            adam.zero_grad()
            batch_loss = criterion(model(batch_images), batch_labels)
            batch_loss.backward()
            adam.step()

            epoch_loss += batch_loss.item()

        # Report the average per-batch loss for this epoch.
        print(f"Epoch {epoch_idx+1}/{num_epochs}, Loss: {epoch_loss / len(train_loader)}")

    if save_model_path is None:
        return
    torch.save(model.state_dict(), save_model_path)

def test(model, test_loader, device,bias_adjustment_value=None,omitted_class_index=None):
    model.to(device)
    model.eval()

    correct_predictions=0
    total=0
    with torch.no_grad():
        for images, labels in test_loader:
            images,labels=images.to(device), labels.to(device)
            outputs=model(images)
            if bias_adjustment_value is not None and omitted_class_index is not None:
                outputs[:,omitted_class_index]+=bias_adjustment_value
            _, predicted = torch.max(outputs, 1)

            total+=labels.size(0)
            correct_predictions+=(labels==predicted).sum()

    accuracy=100*correct_predictions/total
    print(f"Total correct predictions: {correct_predictions}")
    print(f"Total labels: {total}")
    print(f"Test Accuracy: {accuracy}")

    return correct_predictions, total, accuracy


In [7]:
# Load the pre-trained teacher weights from the checkpoint stored on Google Drive.
# NOTE(review): the path assumes Drive is already mounted at /content/drive;
# no mount call is visible in this notebook -- confirm before a fresh run.
teacher_model=TeacherModel()
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state_dict needs, so a tampered checkpoint cannot execute code.
teacher_state_dict=torch.load("/content/drive/MyDrive/deep_generative_models/knowledge_distillation/cifar_teacher_v1.pt",map_location=device,weights_only=True)
teacher_model.load_state_dict(teacher_state_dict)

  teacher_state_dict=torch.load("/content/drive/MyDrive/deep_generative_models/knowledge_distillation/cifar_teacher_v1.pt",map_location=device)


<All keys matched successfully>

In [None]:
def train_knowledge_distillation(teacher,student, train_loader, num_epochs,
                                 learning_rate, temperature, soft_training_loss_weight,
                                 ce_loss_weight, device):
    """Train `student` on a weighted sum of a distillation loss and cross-entropy.

    For each batch the loss is
    `soft_training_loss_weight * KL(teacher || student) * temperature**2
    + ce_loss_weight * CE(student_logits, labels)`, where both distributions
    in the KL term are softmaxes of the logits divided by `temperature`, and
    the KL sum is averaged over the batch dimension.

    Args:
        teacher: Frozen reference model; set to eval mode and run without grad.
        student: Model being optimised; set to training mode.
        train_loader: Sized iterable yielding `(images, labels)` batches.
        num_epochs: Number of full passes over `train_loader`.
        learning_rate: Learning rate passed to the Adam optimiser.
        temperature: Divisor applied to both sets of logits before softmax.
        soft_training_loss_weight: Weight of the KL (soft-target) term.
        ce_loss_weight: Weight of the hard-label cross-entropy term.
        device: Device on which both models and every batch are placed.

    Returns:
        None. `student` is updated in place; the mean loss is printed per epoch.
    """
    teacher.eval()
    teacher.to(device)
    student.train()
    student.to(device)
    optimizer = torch.optim.Adam(student.parameters(), lr=learning_rate)

    loss_fn=torch.nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        running_loss=0
        for images, labels in train_loader:
            images,labels=images.to(device), labels.to(device)

            optimizer.zero_grad()
            # The teacher only provides targets, so no gradients are tracked for it.
            with torch.no_grad():
                teacher_logits=teacher(images)

            student_logits=student(images)

            # Work in log-space. softmax(...).log() yields -inf whenever a
            # probability underflows to 0, and 0 * -inf then turns the whole
            # loss into NaN; log_softmax stays finite.
            log_soft_targets=torch.log_softmax(teacher_logits/temperature, dim=-1)
            log_soft_probs=torch.log_softmax(student_logits/temperature, dim=-1)

            # KL(teacher || student), summed over classes and averaged over the
            # batch, rescaled by temperature**2.
            kl_div_loss=torch.nn.functional.kl_div(
                log_soft_probs, log_soft_targets,
                reduction="batchmean", log_target=True,
            )*temperature**2

            # Classification loss against the hard labels.
            ce_loss=loss_fn(student_logits, labels)

            # Weighted sum of the two losses.
            loss=soft_training_loss_weight*kl_div_loss + ce_loss_weight*ce_loss

            loss.backward()

            optimizer.step()

            running_loss+=loss.item()
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss / len(train_loader)}")


# Distil the teacher into a fresh student using the cat-free training loader.
# Seed BEFORE building the student so that its weight initialisation is
# reproducible too, not only the randomness consumed afterwards.
torch.manual_seed(2342)
student_model=StudentModel()
train_knowledge_distillation(teacher=teacher_model, student=student_model, train_loader=filtered_data_loader, num_epochs=10, learning_rate=0.001, temperature=7,
                             soft_training_loss_weight=0.1, ce_loss_weight=0.9, device=device)


# Persist the distilled student's weights so later cells can reload them.
torch.save(student_model.state_dict(),"/content/drive/MyDrive/deep_generative_models/knowledge_distillation/cifar_student_trained_with_teacher_minus1class.pt")




In [9]:
# Load the distilled STUDENT weights saved by the distillation run.
# NOTE(review): the path assumes Drive is already mounted at /content/drive;
# no mount call is visible in this notebook -- confirm before a fresh run.
student_model=StudentModel()
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state_dict needs, so a tampered checkpoint cannot execute code.
student_state_dict=torch.load("/content/drive/MyDrive/deep_generative_models/knowledge_distillation/cifar_student_trained_with_teacher_minus1class.pt",map_location=device,weights_only=True)
student_model.load_state_dict(student_state_dict)

  student_state_dict=torch.load("/content/drive/MyDrive/deep_generative_models/knowledge_distillation/cifar_student_trained_with_teacher_minus1class.pt",map_location=device)


<All keys matched successfully>

In [10]:
# Baseline: student accuracy on the complete test split, logits left unadjusted.
test_student = test(model=student_model, test_loader=test_loader, device=device)

Total correct predictions: 6415
Total labels: 10000
Test Accuracy: 64.1500015258789


In [11]:
# Student accuracy on the complete test split, with a constant added to the
# omitted class's logit before the arg-max. The index comes from class_to_omit
# so that evaluation cannot drift from the class actually removed in training.
# NOTE(review): 3.5 is an unexplained, hand-picked offset -- document how it was chosen.
test_student = test(student_model, test_loader, device,bias_adjustment_value=3.5,omitted_class_index=class_to_omit)

Total correct predictions: 6539
Total labels: 10000
Test Accuracy: 65.38999938964844


In [12]:
# Student accuracy on the cat-only test subset, using the same logit offset.
# The index comes from class_to_omit rather than a second hard-coded 3.
# NOTE(review): 3.5 is an unexplained, hand-picked offset -- document how it was chosen.
test_student = test(student_model, cat_only_data_loader, device,bias_adjustment_value=3.5,omitted_class_index=class_to_omit)

Total correct predictions: 323
Total labels: 1000
Test Accuracy: 32.30000305175781
