In [1]:
import torch
from torch import nn
import torchvision
import presets
import numpy as np
import calibration_utils

from torchvision.transforms.functional import InterpolationMode
from sklearn.metrics import  confusion_matrix
from matplotlib import pyplot as plt

from Calibrator import Calibrator

In [2]:
def load_model(
        device: str='cpu',
        MODEL: str='mobilenet_v3_large',
        NUM_CLASSES: int=2,
        PATH: str='output/model_4.pth'
        ):
    model = torchvision.models.get_model(MODEL, weights=None, num_classes=NUM_CLASSES)
    state_dict = torch.load(PATH)
    model.load_state_dict(state_dict['model'])
    model.to(device)
    return model, state_dict

In [3]:
model, state_dict = load_model()
args = state_dict["args"]
traindir, valdir = "data/train", "data/val"
interpolation = InterpolationMode(args.interpolation)

In [4]:
preprocessing = presets.ClassificationPresetEval(
                crop_size=args.val_crop_size,
                resize_size=args.val_resize_size,
                interpolation=interpolation,
                backend=args.backend,
                use_v2=args.use_v2,
            )

dataset_train = torchvision.datasets.ImageFolder(
    traindir,
    preprocessing,
)
train_sampler = torch.utils.data.SequentialSampler(dataset_train)
data_loader_train = torch.utils.data.DataLoader(
        dataset_train, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True
    )

dataset_val = torchvision.datasets.ImageFolder(
    valdir,
    preprocessing,
)
val_sampler = torch.utils.data.SequentialSampler(dataset_val)
data_loader_val = torch.utils.data.DataLoader(
        dataset_val, batch_size=args.batch_size, sampler=val_sampler, num_workers=args.workers, pin_memory=True
    )

criterion = nn.CrossEntropyLoss()
calibrator = Calibrator()

In [5]:
def train(model, data_loader, calibrator, device: str="cpu"):
    with torch.inference_mode():
        for image, target in data_loader:
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(image)
            calibrator.update(output, target)
    return calibrator.train_calibrator()
            
def eval(model, data_loader, calibrator, device: str="cpu"):
    with torch.inference_mode():
        for image, target in data_loader:
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(image)
            calibrator.update(output, target)
    return calibrator.eval_calibrator()


In [6]:
n_epochs = 4
for epoch in range(8, 10):
    model, _ = load_model(PATH=f'output/model_{epoch}.pth')
    print("start train")
    calibrated_predictions = train(model, data_loader_val, calibrator)
    print("start eval")
    eval(model, data_loader_val, calibrator)
    print("epoch end")

start train
Preprocess collected samples.
Fit platt_scaler
Fit beta_calibrator
platt_scaler predicts calibrated probabilities.
beta_calibrator predicts calibrated probabilities.
start eval
platt_scaler predicts calibrated probabilities.
beta_calibrator predicts calibrated probabilities.
epoch end
start train
Preprocess collected samples.
Fit platt_scaler
Fit beta_calibrator
platt_scaler predicts calibrated probabilities.
beta_calibrator predicts calibrated probabilities.
start eval
platt_scaler predicts calibrated probabilities.
beta_calibrator predicts calibrated probabilities.
epoch end


In [8]:
calibrated_predictions.items()

dict_items([('platt_scaler', array([[0.8225828 , 0.1774172 ],
       [0.82197365, 0.17802635],
       [0.68430432, 0.31569568],
       [0.8226076 , 0.1773924 ],
       [0.81827925, 0.18172075],
       [0.69787469, 0.30212531],
       [0.8225497 , 0.1774503 ],
       [0.79813318, 0.20186682],
       [0.82256979, 0.17743021],
       [0.81139327, 0.18860673],
       [0.8226057 , 0.1773943 ],
       [0.82259936, 0.17740064],
       [0.82230486, 0.17769514],
       [0.82253642, 0.17746358],
       [0.82220935, 0.17779065],
       [0.76989346, 0.23010654],
       [0.79593162, 0.20406838],
       [0.82260758, 0.17739242],
       [0.82260678, 0.17739322],
       [0.82228795, 0.17771205],
       [0.81501987, 0.18498013],
       [0.82087729, 0.17912271],
       [0.82254313, 0.17745687],
       [0.82240143, 0.17759857],
       [0.8226032 , 0.1773968 ],
       [0.71431716, 0.28568284],
       [0.82260503, 0.17739497],
       [0.68336522, 0.31663478],
       [0.82260812, 0.17739188],
       [0.8224