In [1]:
import sys
import os 

# Make the project root (the notebook's parent directory) both importable and the
# working directory, so the `src.*` imports and the relative "saved/..." paths used
# throughout this notebook resolve.
# The root is resolved to an absolute path exactly once. This keeps the sys.path entry
# stable across the chdir, and re-running this cell is a no-op instead of climbing
# one more directory each time.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.abspath("../")
    sys.path.append(PROJECT_ROOT)
    os.chdir(PROJECT_ROOT)

os.makedirs("saved/swad_distillation/", exist_ok=True)

import torch.utils.data
import torch.nn.functional as F
import torch.nn as nn
import torchvision

from copy import deepcopy

import numpy as np
import src.datasets.PACS_dataset
from src.utils.init_functions import init_object
from src.swad.swa_utils import AveragedModel

In [2]:
# check accuracy function
def check(model, loader, device="cuda"):
    """Return the top-1 accuracy of `model` over every sample yielded by `loader`.

    Args:
        model: module mapping a batch of images to class scores; the prediction is
            the argmax over the last dimension (assumed (batch, num_classes) logits).
        loader: iterable of dict batches with "image" and "label" entries.
        device: device the batches are moved to; must match the model's device.
            Defaults to "cuda", as before.

    Returns:
        Fraction of correctly classified samples (float in [0, 1]). The denominator is
        the number of samples actually evaluated, so loaders with `drop_last=True`
        are scored correctly.

    Raises:
        ValueError: if the loader yields no samples (accuracy is undefined).

    Side effect: leaves `model` in eval mode.
    """
    model.eval()
    correct, seen = 0, 0
    with torch.inference_mode():
        for batch in loader:
            images = batch["image"].to(device).float()
            labels = batch["label"].to(device).long()
            logits = model(images)
            # softmax is monotonic, so argmax over the raw logits yields the same class.
            preds = logits.argmax(dim=-1)
            correct += (preds == labels).sum().item()
            seen += labels.numel()
    if seen == 0:
        raise ValueError("check(): loader produced no samples; accuracy is undefined")
    return correct / seen

In [3]:
# Dataset config for PACS. The transforms are declared as {"name", "kwargs"} configs
# and instantiated into a single torchvision Compose pipeline; `domains` is the same
# list object stored under dataset["kwargs"]["domain_list"].
transform_configs = [
    {"name": "ToTensor", "kwargs": {}},
    {
        "name": "Normalize",
        "kwargs": {"mean": [0.5, 0.5, 0.5], "std": [0.5, 0.5, 0.5]},
    },
]

domains = ["art_painting", "photo", "sketch", "cartoon"]

dataset = {
    "name": "PACS_dataset",
    "kwargs": {
        "domain_list": domains,
        "transforms": torchvision.transforms.Compose(
            [init_object(torchvision.transforms, cfg) for cfg in transform_configs]
        ),
    },
}

In [4]:
# Run configuration.
batch_size = 64                      # evaluation batch size for both loaders
name = "resnet50"                    # architecture; selects the branch in load_model
num_iter = 50                        # checkpoints to average: iterations 100, 200, ..., num_iter * 100
run_id = "swad_teacher_baseline"     # experiment id used in checkpoint and result paths

# Per-run directory for the accuracy ranking files.
results_dir = f"saved/swad_distillation/{run_id}/"
if not os.path.exists(results_dir):
    os.makedirs(results_dir)

In [5]:
def create_loaders(test_domain):
    """Build the (validation, test) DataLoaders for one leave-one-domain-out split.

    Validation uses the "test" split of every domain except ``domains[test_domain]``.
    Test uses both the "train" and "test" splits of the held-out domain. Each config is
    a deep copy of the module-level ``dataset`` config with augmentations set to None.
    Both loaders are unshuffled, use the module-level ``batch_size``, pin memory and
    run 8 worker processes.

    Args:
        test_domain: index into ``domains`` of the held-out domain.

    Returns:
        Tuple ``(val_loader, test_loader)``.
    """
    held_out = domains[test_domain]

    val_cfg, test_cfg = deepcopy(dataset), deepcopy(dataset)
    val_kwargs, test_kwargs = val_cfg["kwargs"], test_cfg["kwargs"]

    val_kwargs["dataset_type"] = ["test"]
    test_kwargs["dataset_type"] = ["train", "test"]

    val_kwargs["domain_list"] = [d for d in val_kwargs["domain_list"] if d != held_out]
    test_kwargs["domain_list"] = [held_out]

    # Evaluation only: no augmentation on either split.
    test_kwargs["augmentations"] = None
    val_kwargs["augmentations"] = None

    val_set = init_object(src.datasets.PACS_dataset, val_cfg)
    test_set = init_object(src.datasets.PACS_dataset, test_cfg)

    def _eval_loader(ds):
        return torch.utils.data.DataLoader(
            ds, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=8
        )

    val_loader = _eval_loader(val_set)
    test_loader = _eval_loader(test_set)
    return val_loader, test_loader

def load_model(name: str, iter: int, test_domain: int, num_classes: int = 7, device: str = "cuda"):
    """Build a ResNet classifier and load the weights saved at a given training iteration.

    The checkpoint is read from
    ``saved/{run_id}/checkpoint_name_{name}_test_domain_{domains[test_domain]}_iter_{iter}.pth``
    (``run_id`` and ``domains`` are module-level globals) and must contain the model's
    state dict under the "model" key.

    Args:
        name: "resnet18" or "resnet50". The historical misspelling "renset18" is still
            accepted as an alias for "resnet18"; the given spelling is used in the path.
        iter: training iteration encoded in the checkpoint filename.
        test_domain: index into ``domains`` of the held-out domain for this run.
        num_classes: output size of the replaced ``fc`` head (default 7, as before).
        device: device for the model and the loaded tensors (default "cuda").

    Returns:
        The model on ``device`` with the checkpoint weights loaded.

    Raises:
        ValueError: if ``name`` is not a supported architecture.
    """
    # The original compared against "renset18" only, so "resnet18" was unreachable.
    if name in ("resnet18", "renset18"):
        model = torchvision.models.resnet18()
        model.fc = nn.Linear(512, num_classes)
    elif name == "resnet50":
        model = torchvision.models.resnet50()
        model.fc = nn.Linear(2048, num_classes)
    else:
        raise ValueError(f"Unsupported model name {name!r}; expected 'resnet18' or 'resnet50'")
    model.to(device)
    checkpoint = torch.load(
        f"saved/{run_id}/checkpoint_name_{name}_test_domain_{domains[test_domain]}_iter_{iter}.pth",
        map_location=device,
    )
    model.load_state_dict(checkpoint["model"])
    return model

In [6]:
def save_ranking(path, scores):
    """Write one '<1-based checkpoint index> <accuracy>' line per entry of `scores`, best first."""
    with open(path, "w") as f:
        for idx in (-scores).argsort():
            print(idx + 1, scores[idx], file=f)


# Print progress about five times per domain. The step is an integer so `i % log_every`
# is exact for any num_iter; `num_iter / 5` is a float that may never divide i.
log_every = max(1, num_iter // 5)

for test_domain in range(len(domains)):
    print("TEST DOMAIN: ", domains[test_domain])
    val_loader, test_loader = create_loaders(test_domain)
    # res_*[i - 1] is the accuracy of the average of checkpoints i..num_iter
    # (checkpoint i is the one saved at iteration i * 100).
    res_test, res_val = np.zeros(num_iter), np.zeros(num_iter)
    best_accuracy = 0
    # Walk backwards from the last checkpoint, folding each earlier one into the average.
    for i in range(num_iter, 0, -1):
        model = load_model(name, i * 100, test_domain)
        if i == num_iter:
            averaged_model = AveragedModel(model)
        else:
            averaged_model.update_parameters(model)
        res_test[i - 1] = check(averaged_model.model, test_loader)
        res_val[i - 1] = check(averaged_model.model, val_loader)
        # NOTE(review): the "best" average is selected by held-out *test-domain* accuracy;
        # if an oracle teacher is not intended, select on res_val instead.
        if res_test[i - 1] > best_accuracy:
            # Track the best so far; otherwise every average overwrites *_best.pth and
            # the file ends up holding the last one evaluated, not the best.
            best_accuracy = res_test[i - 1]
            state = {
                "name": name,
                "model": averaged_model.model.state_dict(),
            }
            path = f'saved/{run_id}/checkpoint_name_{state["name"]}_test_domain_{domains[test_domain]}_best.pth'
            torch.save(state, path)
        if i % log_every == 0:
            print("IND: ", i, "TEST: ", res_test[i - 1], "VAL :", res_val[i - 1])
    # Save per-domain rankings of the tail averages.
    save_ranking(
        f"saved/swad_distillation/{run_id}/test_model_name_{name}_test_domain_{domains[test_domain]}.txt",
        res_test,
    )
    save_ranking(
        f"saved/swad_distillation/{run_id}/val_model_name_{name}_test_domain_{domains[test_domain]}.txt",
        res_val,
    )

TEST DOMAIN:  art_painting
IND:  50 TEST:  0.888671875 VAL : 0.9761904761904762
IND:  40 TEST:  0.92041015625 VAL : 0.9774436090225563
IND:  30 TEST:  0.919921875 VAL : 0.9786967418546366
IND:  20 TEST:  0.919921875 VAL : 0.9805764411027569
IND:  10 TEST:  0.9267578125 VAL : 0.9799498746867168
TEST DOMAIN:  photo
IND:  50 TEST:  0.9754491017964072 VAL : 0.9676646706586827
IND:  40 TEST:  0.9832335329341317 VAL : 0.9724550898203593
IND:  30 TEST:  0.9880239520958084 VAL : 0.9736526946107784
IND:  20 TEST:  0.9874251497005988 VAL : 0.974251497005988
IND:  10 TEST:  0.9880239520958084 VAL : 0.9730538922155688
TEST DOMAIN:  sketch
IND:  50 TEST:  0.8302367014507508 VAL : 0.977850697292863
IND:  40 TEST:  0.8327818783405446 VAL : 0.9770303527481542
IND:  30 TEST:  0.8345635021634004 VAL : 0.9811320754716981
IND:  20 TEST:  0.8327818783405446 VAL : 0.9827727645611156
IND:  10 TEST:  0.827691524560957 VAL : 0.9835931091058244
TEST DOMAIN:  cartoon
IND:  50 TEST:  0.8246587030716723 VAL : 0.97