In [28]:
import papermill as pm
import torch
import mlflow
from datasets import TestImageNetDataset
import tqdm

from torch.utils.data import DataLoader

from torchvision.models import ResNet50_Weights

In [29]:
# Default Parameters
run_id = "5363b9a6f6954354b7bb68d535b5ea88"

test_data_path = "/home/vincent/projects/MLRC_2023/data/ImageNet/ILSVRC/Data/CLS-LOC/val/"
test_data_labels_path = "/home/vincent/projects/MLRC_2023/data/ImageNet/LOC_val_solution.csv"

label_mapping_path = "/home/vincent/projects/MLRC_2023/data/ImageNet/LOC_synset_mapping.txt"

resnet50_weights = ResNet50_Weights.DEFAULT

preprocess = resnet50_weights.transforms()
device = "cuda" if torch.cuda.is_available() else "cpu"

In [30]:
# load in saved model
with mlflow.start_run(run_id=run_id) as run:
    teacher_model_uri = f"runs:/{run.info.run_id}/teacher"
    teacher = mlflow.pytorch.load_model(teacher_model_uri)
    
    student_model_uri = f"runs:/{run.info.run_id}/student"
    student = mlflow.pytorch.load_model(student_model_uri)
    
    independent_student_model_uri = f"runs:/{run.info.run_id}/independent_student"
    independent_student = mlflow.pytorch.load_model(independent_student_model_uri)

In [31]:

# label_mapping = {}
# with open(label_mapping_path) as f:
#     reader = csv.reader(f)
#     for mapping in reader:
#         mapping = mapping[0].split() + mapping[1:]
#         id = mapping[0]
#         label = ', '.join(mapping[1:])
#         label_mapping[id] = label
        
                

In [39]:
test_dataset = TestImageNetDataset(test_data_path, test_data_labels_path, label_mapping_path, preprocess)
test_dataloader = DataLoader(test_dataset, batch_size=80, shuffle=False)

In [43]:
def evaluate(student, teacher, independent_student, test_dataloader, device):
    """
    - teacher: The pretrained model used to help the student model learn
    - student: The smaller, untrained model that uses the teacher's output as an additional label
    - independent_student: The smaller, untrained model that doesn't use the teacher's output as an additional label
    - test_dataloader: Dataloader for test set
    - device: Device to run training
    """
    teacher.eval()
    teacher.to(device)
    student.train()
    student.to(device)
    independent_student.eval()
    independent_student.to(device)
    
    teacher_correct_predictions = 0
    student_correct_predictions = 0
    independent_student_correct_predictions = 0
    total_samples = 0
    with torch.no_grad():
        
        for inputs, labels in tqdm.tqdm(test_dataloader):
            inputs, labels = inputs.to(device), labels.to(device)

            teacher_predictions = torch.argmax(teacher(inputs), axis=1)
            student_predictions = torch.argmax(student(inputs), axis=1)
            independent_student_predictions = torch.argmax(independent_student(inputs), axis=1)

            teacher_correct_predictions += (teacher_predictions == labels).sum().item()
            student_correct_predictions += (student_predictions == labels).sum().item()
            independent_student_correct_predictions += (independent_student_predictions == labels).sum().item()

            total_samples += labels.size().numel()


    teacher_accuracy = teacher_correct_predictions / total_samples
    student_accuracy = student_correct_predictions / total_samples
    independent_student_accuracy = independent_student_correct_predictions / total_samples

    print("Teacher Accuracy: ", teacher_accuracy)
    print("Student Accuracy: ", student_accuracy)
    print("Independent Student Accuracy: ", independent_student_accuracy)
    
    with mlflow.start_run(run_id=run_id) as run:
        mlflow.log_metric("teacher_accuracy", teacher_accuracy)
        mlflow.log_metric("student_accuracy", student_accuracy)
        mlflow.log_metric("independent_student_accuracy", independent_student_accuracy)


evaluate(student, teacher, independent_student, test_dataloader, device)

100%|█████████████████████████████████████████| 625/625 [05:17<00:00,  1.97it/s]

Teacher Accuracy:  0.80852
Student Accuracy:  0.00076
Independent Student Accuracy:  0.00084



