Train a ResNet-50 on the Animals dataset and measure its accuracy.

In [None]:
# Experiment configuration: number of output classes and ResNet training epochs.
NUM_CLASSES, NUM_EPOCHS = 90, 50
# Stage switches: False loads an existing checkpoint, True (re)trains that stage.
TRAIN_RESNET, DISTILL_RESNET = False, True
# Checkpoints for both the teacher and the student live in this directory.
os.makedirs('ASSN3/models', exist_ok=True)

In [None]:
def prepare_datasets(root='data/Animals_data/animals/animals',
                     train_fraction=0.8, seed=42):
    """Load the image-folder dataset at ``root`` and split it into train/test.

    Every image goes through Resize(224) -> CenterCrop(224) -> ToTensor ->
    Normalize with the commonly used ImageNet channel mean/std. The same
    transform is applied to both splits (there is no training augmentation).

    Args:
        root: directory handed to ``datasets.ImageFolder``.
        train_fraction: fraction of samples placed in the training split; the
            count is truncated with ``int`` and the remainder forms the test
            split. Must lie in [0, 1].
        seed: value passed to ``torch.manual_seed`` right before splitting.
            NOTE: this reseeds torch's *global* RNG, so it also affects any
            randomness drawn afterwards (model init, DataLoader shuffling).

    Returns:
        ``(train_dataset, test_dataset)`` as produced by ``random_split``.

    Raises:
        ValueError: if ``train_fraction`` is outside [0, 1].
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(f"train_fraction must be in [0, 1], got {train_fraction}")

    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    classifier_dataset = datasets.ImageFolder(root=root, transform=transform)

    train_size = int(train_fraction * len(classifier_dataset))
    test_size = len(classifier_dataset) - train_size

    # Global reseed kept (rather than a private Generator) so the split and
    # everything that follows stay reproducible exactly as before.
    torch.manual_seed(seed)

    train_dataset, test_dataset = random_split(
        classifier_dataset, [train_size, test_size])

    return train_dataset, test_dataset


train_dataset, test_dataset = prepare_datasets()
# Training batches are shuffled and loaded by two worker processes;
# evaluation batches keep a fixed order and load in the main process.
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=64,
    shuffle=False,
)

In [None]:
def test(model):
    """Evaluate ``model`` on the global ``test_loader``.

    Puts the model in eval mode (it is left in eval mode on return) and runs
    inference over every batch of ``test_loader``, moving tensors to the
    global ``device``.

    Returns:
        ``(top1_accuracy, top5_accuracy, f1)``. The two accuracies are in
        percent. Top-5 uses ``min(5, n_outputs)`` predictions. ``f1`` is the
        harmonic mean of the macro-averaged precision and the macro-averaged
        recall over ``NUM_CLASSES`` classes. It is not the mean of per-class
        F1 scores. Classes that are never seen and never predicted contribute
        0 to both averages. An empty loader yields ``(0.0, 0.0, 0.0)``.
    """
    model.eval()
    top1_correct = 0
    top5_correct = 0
    total = 0
    true_positives = torch.zeros(NUM_CLASSES, device=device)
    false_positives = torch.zeros(NUM_CLASSES, device=device)
    false_negatives = torch.zeros(NUM_CLASSES, device=device)

    with torch.inference_mode():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # topk raises if k exceeds the number of outputs, so cap it.
            k = min(5, outputs.size(1))
            _, predicted = outputs.topk(k, 1, largest=True, sorted=True)
            top1 = predicted[:, 0]
            top1_correct += (top1 == labels).sum().item()
            # topk indices in a row are distinct, so each sample matches at
            # most once and the plain sum counts samples with a top-k hit.
            top5_correct += (predicted == labels.unsqueeze(1)).sum().item()
            total += labels.size(0)

            # Per-class confusion counts for the top-1 prediction, vectorised
            # instead of a per-sample Python loop (which syncs every element).
            hit = top1 == labels
            true_positives += torch.bincount(labels[hit], minlength=NUM_CLASSES)
            false_positives += torch.bincount(top1[~hit], minlength=NUM_CLASSES)
            false_negatives += torch.bincount(labels[~hit], minlength=NUM_CLASSES)

    # Per-class precision/recall; the epsilon avoids 0/0 for absent classes.
    precision_per_class = true_positives / \
        (true_positives + false_positives + 1e-10)
    recall_per_class = true_positives / \
        (true_positives + false_negatives + 1e-10)

    # Macro averages over all NUM_CLASSES classes.
    precision = precision_per_class.mean().item()
    recall = recall_per_class.mean().item()
    f1 = 2 * (precision * recall) / (precision + recall + 1e-10)

    if total == 0:
        return 0.0, 0.0, f1
    top1_accuracy = 100 * top1_correct / total
    top5_accuracy = 100 * top5_correct / total
    return top1_accuracy, top5_accuracy, f1


def train_classifier(num_epochs):
    """Train a ResNet-50 from scratch on the global ``train_loader``.

    Uses cross-entropy loss and SGD (lr=0.001, momentum=0.9). After every
    epoch the model is evaluated with ``test`` on the global ``test_loader``.
    Progress is printed every 5 epochs, starting with the first.

    Args:
        num_epochs: number of full passes over ``train_loader``.

    Returns:
        The trained model, left in eval mode by the final ``test`` call.
        Per-epoch loss and accuracies are plotted via ``plot_metrics``.
    """
    criterion = nn.CrossEntropyLoss()
    model = resnet50(num_classes=NUM_CLASSES).to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    train_losses = []
    top1_accuracies = []
    top5_accuracies = []

    for epoch in trange(num_epochs, desc="Training Progress"):
        # test() switches the model to eval mode, so re-enable training mode
        # at the start of every epoch.
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()  # Accumulate loss for this epoch

        # Validation
        top1_accuracy, top5_accuracy, _ = test(model)
        top1_accuracies.append(top1_accuracy)
        top5_accuracies.append(top5_accuracy)
        # Store average (per-batch) loss for this epoch
        train_losses.append(running_loss / len(train_loader))
        if epoch % 5 == 0:
            # Implicit concatenation instead of a newline inside the
            # replacement field, which is a SyntaxError before Python 3.12.
            print(f'Epoch [{epoch + 1}/{num_epochs}] '
                  f'Top-1: {top1_accuracy:.2f}%, Top-5: {top5_accuracy:.2f}%')

    print('Finished Training')
    plot_metrics(train_losses, top1_accuracies, top5_accuracies)
    return model


def plot_metrics(train_losses, top1_accuracies, top5_accuracies):
    """Show a 1x3 figure of per-epoch loss, top-1 accuracy and top-5 accuracy.

    The x axis runs from 1 to ``len(train_losses)``. All three sequences are
    expected to have that same length. Both accuracy panels are fixed to a
    0-100 y range.
    """
    epochs = range(1, len(train_losses) + 1)
    # (series, colour, legend label, title, y-axis label, fixed 0-100 range?)
    panels = (
        (train_losses, 'b', 'Training Loss',
         'Training Loss per Epoch', 'Loss', False),
        (top1_accuracies, 'g', 'Top 1 Accuracy',
         'Top 1 Accuracy per Epoch', 'Accuracy (%)', True),
        (top5_accuracies, 'g', 'Top 5 Accuracy',
         'Top 5 Accuracy per Epoch', 'Accuracy (%)', True),
    )

    plt.figure(figsize=(16, 5))
    for slot, (series, colour, label, title, y_label, percent) in enumerate(panels, start=1):
        plt.subplot(1, 3, slot)
        plt.plot(epochs, series, marker='o', color=colour, label=label)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        if percent:
            plt.ylim(0, 100)
        plt.grid()
        plt.legend()

    plt.tight_layout()
    plt.show()

In [None]:
resnet_model_path = "ASSN3/models/resnet_classifier.pth"
if TRAIN_RESNET:
    resnet_classifier = train_classifier(num_epochs=NUM_EPOCHS)
    # Save the trained weights so later runs can skip training.
    torch.save(resnet_classifier.state_dict(), resnet_model_path)
else:
    resnet_classifier = resnet50(num_classes=NUM_CLASSES).to(device)
    # map_location=device lets a checkpoint saved from a GPU run load on a
    # CPU-only machine instead of failing inside torch.load.
    resnet_classifier.load_state_dict(
        torch.load(resnet_model_path, map_location=device, weights_only=True))

top1_accuracy, top5_accuracy, f1 = test(resnet_classifier)
print(f"Top-1: {top1_accuracy:.2f}%, Top-5: {top5_accuracy:.2f}%, F1: {f1}")

Distill the ResNet above into a small MLP (using a KL-divergence distillation loss on the temperature-scaled logits) and measure its test accuracy.

In [None]:
# Distillation settings; NUM_EPOCHS is re-bound here for the student run.
NUM_EPOCHS, TEMPERATURE = 50, 2.0
mlp_resnet_model_path = "ASSN3/models/mlp_resnet.pth"


In [None]:
class StudentMLP(nn.Module):
    """Small fully connected student network that returns raw logits.

    Each sample is flattened (all dimensions after the first) to a vector of
    length ``input_size``. It is then mapped through
    512 -> ReLU -> 256 -> ReLU -> ``num_classes`` units.
    """

    def __init__(self, input_size, num_classes):
        super().__init__()
        # Attribute names (fc1, relu1, ...) define the state_dict keys, so
        # they are kept stable for saved checkpoints.
        self.fc1 = nn.Linear(input_size, 512)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(512, 256)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(256, num_classes)

    def forward(self, x):
        hidden = torch.flatten(x, 1)
        for stage in (self.fc1, self.relu1, self.fc2, self.relu2, self.fc3):
            hidden = stage(hidden)
        return hidden

In [None]:
def kld_loss(student_logits, teacher_logits, temperature):
    """Distillation loss KL(teacher || student) on temperature-softened logits.

    Both logit tensors are divided by ``temperature`` and normalised over
    dim 1 (presumably ``(batch, num_classes)`` — confirm for other callers).
    The batch-mean KL divergence from the teacher distribution to the student
    distribution is then multiplied by ``temperature ** 2``. That scaling
    keeps gradient magnitudes comparable across temperatures.
    """
    # kl_div(input, target) computes sum(target * (log(target) - input)), so
    # ``input`` must be the *student* log-probabilities and ``target`` the
    # *teacher* probabilities. Swapping them yields the reverse divergence
    # KL(student || teacher), which is not the distillation objective.
    student_log_probs = nn.functional.log_softmax(
        student_logits / temperature, dim=1)
    teacher_probs = nn.functional.softmax(teacher_logits / temperature, dim=1)
    return nn.functional.kl_div(
        student_log_probs, teacher_probs, reduction='batchmean') * (temperature ** 2)

In [None]:

def train_student(teacher_model, supervision_loss=False, temperature=5.0, epochs=10):
    """Distil ``teacher_model`` into a freshly initialised ``StudentMLP``.

    Args:
        teacher_model: fixed network providing soft targets. It is put in
            eval mode and queried under ``torch.no_grad()``.
        supervision_loss: if True, add the hard-label cross-entropy to the
            distillation loss as an unweighted sum.
        temperature: softmax temperature forwarded to ``kld_loss``.
        epochs: number of passes over the global ``train_loader``.

    Returns:
        The trained student, left in eval mode by the final ``test`` call.
        Per-epoch loss and test accuracies are plotted via ``plot_metrics``.
    """
    student_model = StudentMLP(input_size=224*224*3,
                               num_classes=NUM_CLASSES).to(device)

    train_losses = []
    top1_accuracies = []
    top5_accuracies = []

    optimizer = optim.Adam(student_model.parameters(),
                           lr=0.00005, weight_decay=1e-4)
    teacher_model.eval()  # Teacher model is fixed
    cse = nn.CrossEntropyLoss()

    for epoch in trange(epochs):
        # test() at the end of each epoch calls model.eval() on the student,
        # so training mode has to be restored at the start of every epoch.
        student_model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            # Teacher logits are targets only; no graph is needed.
            with torch.no_grad():
                teacher_logits = teacher_model(images)

            student_logits = student_model(images)

            # Distillation loss, optionally plus hard-label supervision.
            loss = kld_loss(student_logits, teacher_logits, temperature)
            if supervision_loss:
                loss = loss + cse(student_logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        top1_accuracy, top5_accuracy, _ = test(student_model)
        top1_accuracies.append(top1_accuracy)
        top5_accuracies.append(top5_accuracy)
        # Store average (per-batch) loss for this epoch
        train_losses.append(running_loss / len(train_loader))
        if epoch % 5 == 0:
            print(f'Epoch [{epoch+1}/{epochs}] Top-1: {top1_accuracy:.2f}%, Top-5: {top5_accuracy:.2f}%')

    print('Finished distillation')
    plot_metrics(train_losses, top1_accuracies, top5_accuracies)
    return student_model


if DISTILL_RESNET:
    mlp_resnet = train_student(
        resnet_classifier, supervision_loss=True, temperature=TEMPERATURE, epochs=NUM_EPOCHS)
    torch.save(mlp_resnet.state_dict(), mlp_resnet_model_path)
else:
    mlp_resnet = StudentMLP(input_size=224*224*3,
                            num_classes=NUM_CLASSES).to(device)
    # map_location=device lets a checkpoint saved from a GPU run load on a
    # CPU-only machine instead of failing inside torch.load.
    mlp_resnet.load_state_dict(
        torch.load(mlp_resnet_model_path, map_location=device, weights_only=True))

top1_accuracy, top5_accuracy, f1 = test(mlp_resnet)
print(f"Top-1: {top1_accuracy:.2f}%, Top-5: {top5_accuracy:.2f}%, F1: {f1}")