In [21]:
import torch.utils.data
import torch.nn.functional as F
import torch.nn as nn
import torchvision

import src.datasets.PACS_dataset
from src.utils.init_functions import init_object
from src.swad.swa_utils import AveragedModel

In [22]:
def check(model: nn.Module, loader, device=None) -> float:
    """Compute the top-1 accuracy of ``model`` over every batch of ``loader``.

    Args:
        model: Classifier mapping a batch of images to per-class logits. It is
            switched to eval mode and left in that mode.
        loader: Iterable yielding dict batches with ``"image"`` and ``"label"``
            tensor entries.
        device: Device the batches are moved to before the forward pass.
            ``None`` (the default) uses the device of the model's first
            parameter and falls back to ``"cuda"`` when the model has none.

    Returns:
        Fraction of the evaluated samples whose arg-max prediction equals the
        label.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else "cuda"

    model.eval()
    correct = 0
    seen = 0
    with torch.inference_mode():
        for batch in loader:
            images = batch["image"].to(device).float()
            labels = batch["label"].to(device).long()
            # Softmax is monotonic, so the arg-max of the raw logits already
            # is the predicted class; normalising first is wasted work.
            predictions = model(images).argmax(dim=-1)
            correct += (predictions == labels).sum().item()
            # Count what was actually evaluated: with drop_last=True or a
            # subset sampler this differs from len(loader.dataset).
            seen += labels.shape[0]

    return correct / seen

In [23]:
# Declarative description of the evaluation dataset: the class name to look up
# plus the keyword arguments it is built with. The transform entries are
# themselves {"name", "kwargs"} descriptions.
test_dataset = dict(
    name="PACS_dataset",
    kwargs=dict(
        domain_list=["art_painting", "photo", "sketch", "cartoon"],
        transforms=[
            dict(name="ToTensor", kwargs=dict()),
            dict(
                name="Normalize",
                kwargs=dict(mean=[0.5] * 3, std=[0.5] * 3),
            ),
        ],
    ),
)

# `domains` refers to the full domain list; `test_domain` is an index into it.
domains = test_dataset["kwargs"]["domain_list"]
test_domain, batch_size = 0, 64

In [24]:
# Restrict the config to the domain selected by `test_domain`, turn the
# transform descriptions into a torchvision pipeline, then build the dataset
# and its loader.
# NOTE: `test_dataset` is indexed as a dict here and rebound to the dataset
# object below, so the config cell must be re-run before re-running this one.
test_dataset["kwargs"].update(
    {
        "dataset_type": ["train", "test"],
        "domain_list": [domains[test_domain]],
        "augmentations": None,
        "transforms": torchvision.transforms.Compose(
            [
                init_object(torchvision.transforms, transform_spec)
                for transform_spec in test_dataset["kwargs"]["transforms"]
            ]
        ),
    }
)
test_dataset = init_object(src.datasets.PACS_dataset, test_dataset)

# Fixed order (no shuffling) for evaluation.
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

In [26]:
# Evaluate the teacher checkpoint "best_1": a ResNet-50 whose final layer is
# replaced by a 7-output linear head.
model = torchvision.models.resnet50()
model.fc = nn.Linear(2048, 7)
model.to('cuda')
# torch.load unpickles the file, so only point it at trusted checkpoints.
# map_location="cpu" makes loading independent of the device the checkpoint
# was saved from; load_state_dict then copies the tensors into the CUDA model.
checkpoint = torch.load(
    "saved/teacher_baseline/checkpoint_name_resnet50_test_domain_art_painting_best_1.pth",
    map_location="cpu",
)
model.load_state_dict(checkpoint["model"])
check(model, test_loader)

0.8857421875

In [27]:
# Evaluate the teacher checkpoint "best_2": a ResNet-50 whose final layer is
# replaced by a 7-output linear head.
model = torchvision.models.resnet50()
model.fc = nn.Linear(2048, 7)
model.to('cuda')
# torch.load unpickles the file, so only point it at trusted checkpoints.
# map_location="cpu" makes loading independent of the device the checkpoint
# was saved from; load_state_dict then copies the tensors into the CUDA model.
checkpoint = torch.load(
    "saved/teacher_baseline/checkpoint_name_resnet50_test_domain_art_painting_best_2.pth",
    map_location="cpu",
)
model.load_state_dict(checkpoint["model"])
check(model, test_loader)

0.8857421875

In [6]:
# Walk the distillation checkpoints from iteration 50000 down to 100 (step
# 100), feed each one into an AveragedModel and record the accuracy of the
# averaged network after every step. `res` maps the loop index to that accuracy.
# NOTE(review): if AveragedModel does not also average BatchNorm buffers, the
# running statistics should be recomputed before evaluating - confirm against
# src.swad.swa_utils.
res = {}
average_model = None  # created from the first checkpoint that is loaded
for i in range(500, 0, -1):
    # A fresh network per checkpoint: AveragedModel may keep a reference to the
    # model it is given, so the instance is not reused across iterations.
    model = torchvision.models.resnet18()
    model.fc = nn.Linear(512, 7)
    model.to('cuda')
    # torch.load unpickles the file, so only point it at trusted checkpoints.
    # map_location="cpu" makes loading independent of the device the
    # checkpoint was saved from.
    checkpoint = torch.load(
        f"saved/distillation_baseline/checkpoint_name_resnet18_test_domain_art_painting_iter_{i * 100}.pth",
        map_location="cpu",
    )
    model.load_state_dict(checkpoint["model"])
    # Sentinel instead of comparing against the range start, so changing the
    # range cannot leave `average_model` undefined.
    if average_model is None:
        average_model = AveragedModel(model)
    else:
        average_model.update_parameters(model)
    res[i] = check(average_model.model, test_loader)
    print(i, res[i])

500 0.7431640625
2
499 0.7578125
3
498 0.7607421875
4
497 0.7587890625
5
496 0.763671875
6


KeyboardInterrupt: 