In [1]:
import sys
sys.path.append("..")

import torch.utils.data
import torch.nn.functional as F
import torch.nn as nn
import torchvision

from copy import deepcopy

import numpy as np
import src.datasets.PACS_dataset
from src.utils.init_functions import init_object
from src.swad.swa_utils import AveragedModel

In [2]:
# check accuracy function
def check(model, loader, device="cuda"):
    """Return the top-1 accuracy of ``model`` over the samples yielded by ``loader``.

    Args:
        model: classifier that maps a float image batch to per-class logits.
            It is switched to eval mode and left in eval mode.
        loader: iterable of batches; each batch is a mapping with "image" and
            "label" entries. ``loader.dataset`` must support ``len()``.
        device: device the batches are moved to; must match where ``model``
            lives. Defaults to "cuda".

    Returns:
        Number of correctly classified samples divided by ``len(loader.dataset)``.

    Raises:
        ValueError: if ``loader.dataset`` is empty (accuracy is undefined).
    """
    num_samples = len(loader.dataset)
    if num_samples == 0:
        raise ValueError("check() got an empty dataset; accuracy is undefined")
    model.eval()
    correct = 0
    with torch.inference_mode():
        for batch in loader:
            images = batch["image"].to(device).float()
            labels = batch["label"].to(device).long()
            logits = model(images)
            # softmax is monotonic, so the argmax of the raw logits is the same class
            preds = logits.argmax(dim=-1)
            correct += (preds == labels).sum().item()
    return correct / num_samples

In [3]:
# Dataset config: all four PACS domains. Images go through ToTensor and then a
# per-channel Normalize with mean 0.5 and std 0.5. Each transform config is
# instantiated via init_object and wrapped in a torchvision Compose in place.
dataset = {
    "name": "PACS_dataset",
    "kwargs": {
        "domain_list": ["art_painting", "photo", "sketch", "cartoon"],
        "transforms": torchvision.transforms.Compose([
            init_object(torchvision.transforms, transform_config)
            for transform_config in (
                {"name": "ToTensor", "kwargs": {}},
                {
                    "name": "Normalize",
                    "kwargs": {"mean": [0.5, 0.5, 0.5], "std": [0.5, 0.5, 0.5]},
                },
            )
        ]),
    },
}

domains = dataset["kwargs"]["domain_list"]
batch_size = 64

In [None]:
# For each held-out PACS domain:
#   1. evaluate the ResNet-50 teacher;
#   2. starting from the last ResNet-18 student checkpoint and walking backwards,
#      keep a running weight average of checkpoints k..NUM_CHECKPOINTS;
#   3. record the test/val accuracy of every such tail average;
#   4. write both rankings to text files.
NUM_CHECKPOINTS = 500    # student checkpoints 1..NUM_CHECKPOINTS
CHECKPOINT_STEP = 100    # checkpoint k is saved at iteration k * CHECKPOINT_STEP
NUM_CLASSES = 7
NUM_WORKERS = 8

for test_domain, test_domain_name in enumerate(domains):
    print("TEST DOMAIN: ", test_domain_name)
    # make val and test loaders:
    #   val  = "test" split of the remaining domains
    #   test = "train" + "test" splits of the held-out domain
    # Neither uses augmentations.
    val_dataset, test_dataset = deepcopy(dataset), deepcopy(dataset)
    val_dataset["kwargs"]["dataset_type"] = ["test"]
    test_dataset["kwargs"]["dataset_type"] = ["train", "test"]

    val_dataset["kwargs"]["domain_list"] = [
        domain for domain in domains if domain != test_domain_name]
    test_dataset["kwargs"]["domain_list"] = [test_domain_name]
    test_dataset["kwargs"]["augmentations"] = None
    val_dataset["kwargs"]["augmentations"] = None

    val_dataset = init_object(src.datasets.PACS_dataset, val_dataset)
    test_dataset = init_object(src.datasets.PACS_dataset, test_dataset)

    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=NUM_WORKERS)
    val_loader = torch.utils.data.DataLoader(val_dataset, **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset, **loader_kwargs)

    # load teacher
    model = torchvision.models.resnet50()
    model.fc = nn.Linear(2048, NUM_CLASSES)
    checkpoint = torch.load(f"saved/swad_teacher_baseline/checkpoint_name_resnet50_test_domain_{test_domain_name}_best.pth")
    model.to('cuda')
    model.load_state_dict(checkpoint["model"])
    print("TEACHER TEST: ", check(model, test_loader))
    print("TEACHER VAL: ", check(model, val_loader))
    print("START ANALYSIS")

    # res_*[k - 1] = accuracy of the average of checkpoints k..NUM_CHECKPOINTS.
    # NaN marks averages that were never evaluated, so they cannot be mistaken
    # for a genuine 0% accuracy in the saved rankings.
    res_test = np.full(NUM_CHECKPOINTS, np.nan)
    res_val = np.full(NUM_CHECKPOINTS, np.nan)
    average_model = None  # reset per test domain
    # Walk back from the last checkpoint down to and including checkpoint 1.
    for i in range(NUM_CHECKPOINTS, 0, -1):
        ckpt_path = (
            "saved/swad_distillation_baseline/checkpoint_name_resnet18_test_domain_"
            f"{test_domain_name}_iter_{i * CHECKPOINT_STEP}.pth")
        try:
            checkpoint = torch.load(ckpt_path)
        except FileNotFoundError:
            # Stop averaging here but still save everything evaluated so far.
            print("MISSING CHECKPOINT, STOPPING: ", ckpt_path)
            break
        model = torchvision.models.resnet18()
        model.fc = nn.Linear(512, NUM_CLASSES)
        model.to('cuda')
        model.load_state_dict(checkpoint["model"])
        if average_model is None:
            average_model = AveragedModel(model)
        else:
            average_model.update_parameters(model)
        # NOTE(review): if update_parameters averages only parameters (as in the
        # SWAD reference implementation), BatchNorm running statistics are not
        # recomputed for the averaged weights. Confirm whether a BN-update pass
        # over training data is needed before evaluation.
        res_test[i - 1] = check(average_model.model, test_loader)
        res_val[i - 1] = check(average_model.model, val_loader)
        if i % 10 == 0:
            print("IND: ", i, "TEST: ", res_test[i - 1], "VAL :", res_val[i - 1])

    # save results: one "<first checkpoint index k> <accuracy>" line per
    # evaluated average, best first
    for split, res in (("test", res_test), ("val", res_val)):
        with open(f"{split}_swad_analysis_test_domain_{test_domain_name}.txt", "w") as f:
            for k in (-res).argsort():
                if not np.isnan(res[k]):
                    print(k + 1, res[k], file=f)