In [None]:
import torch
from torch import nn
import torchvision
import presets
import numpy as np
import calibration_utils

from torchvision.transforms.functional import InterpolationMode
from sklearn.metrics import  confusion_matrix
from matplotlib import pyplot as plt

from Calibrator import Calibrator

In [None]:
device = 'cpu'
MODEL = 'mobilenet_v3_large'
NUM_CLASSES = 2
PATH = 'output/model_8.pth'

# Build the architecture without pretrained weights; the weights come from PATH.
model = torchvision.models.get_model(MODEL, weights=None, num_classes=NUM_CLASSES)
# map_location: the checkpoint may have been written by a GPU run; remap all
# tensors onto `device` so loading also works on a CPU-only machine.
# weights_only=False: besides tensors, the checkpoint holds a training `args`
# object (read below), which the restricted weights-only unpickler (the default
# since PyTorch 2.6) rejects for non-allowlisted types. Full unpickling can run
# arbitrary code, so only load checkpoints you produced yourself.
state_dict = torch.load(PATH, map_location=device, weights_only=False)
model.load_state_dict(state_dict['model'])
model.to(device)
# Inference only from here on: put Dropout/BatchNorm-style layers into their
# evaluation behaviour so outputs are deterministic.
model.eval()

# Hyper-parameters recorded at training time; reused so evaluation
# preprocessing matches what the model was trained with.
args = state_dict["args"]
traindir, valdir = "data/train", "data/val"
interpolation = InterpolationMode(args.interpolation)

In [None]:
# Both splits use the same evaluation preset (built from the training-time
# args), so the training images are not passed through a separate train preset.
preprocessing = presets.ClassificationPresetEval(
    crop_size=args.val_crop_size,
    resize_size=args.val_resize_size,
    interpolation=interpolation,
    backend=args.backend,
    use_v2=args.use_v2,
)


def make_eval_loader(root):
    """Build a deterministic, in-order DataLoader over an ImageFolder at ``root``.

    Args:
        root: directory laid out as ``root/<class_name>/<image>`` for ImageFolder.

    Returns:
        A ``DataLoader`` whose ``.dataset`` is the ImageFolder and whose
        ``.sampler`` is a ``SequentialSampler`` (fixed sample order).
    """
    dataset = torchvision.datasets.ImageFolder(root, preprocessing)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=torch.utils.data.SequentialSampler(dataset),
        num_workers=args.workers,
        # Pinned host memory only speeds up host->CUDA copies; on CPU it just
        # wastes page-locked memory (and warns when no accelerator exists).
        pin_memory=torch.device(device).type == "cuda",
    )


data_loader_train = make_eval_loader(traindir)
dataset_train, train_sampler = data_loader_train.dataset, data_loader_train.sampler

data_loader_val = make_eval_loader(valdir)
dataset_val, val_sampler = data_loader_val.dataset, data_loader_val.sampler

criterion = nn.CrossEntropyLoss()
calibrator = Calibrator()

In [None]:
def train(model, data_loader, calibrator, device=None):
    """Run ``model`` over every batch and feed its outputs to ``calibrator``.

    Despite the name, ``model`` itself is not optimised: the forward passes run
    in evaluation mode without gradient tracking. Only
    ``calibrator.update(output, target)`` is called, once per batch.

    Args:
        model: the ``nn.Module`` to evaluate. It is switched to eval mode for
            the duration of the call, and its previous train/eval mode is
            restored afterwards (even if an exception is raised).
        data_loader: iterable yielding ``(image, target)`` tensor pairs.
        calibrator: object exposing ``update(output, target)``.
        device: where each batch is moved before the forward pass. Defaults to
            the device of the model's first parameter, or CPU if the model has
            no parameters.

    Returns:
        None. Any results accumulate inside ``calibrator``.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    was_training = model.training
    # Layers such as Dropout/BatchNorm behave differently in train mode; without
    # this the logits handed to the calibrator would be stochastic.
    model.eval()
    try:
        with torch.inference_mode():
            for image, target in data_loader:
                image = image.to(device, non_blocking=True)
                target = target.to(device, non_blocking=True)
                output = model(image)
                # NOTE(review): tensors created under inference_mode cannot later
                # be saved for backward. If Calibrator fits parameters with
                # autograd on stored outputs, clone them or use torch.no_grad()
                # instead -- confirm against Calibrator.
                calibrator.update(output, target)
    finally:
        model.train(was_training)
    
            

