In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as TF
import torchvision.transforms.v2 as TF2
from tqdm.auto import tqdm
import numpy as np
from sklearn.metrics import roc_auc_score, roc_curve

In [2]:
def mean(L):
    if len(L) == 0: return None
    return sum(L) / len(L)

In [3]:
tfm = TF.Compose([TF.ToTensor(), TF.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))])
ds_trn = torchvision.datasets.ImageFolder('./Cifar10/train', transform=tfm)
ds_test = torchvision.datasets.ImageFolder('./Cifar10/test', transform=tfm)

In [4]:
assert ds_trn.classes == ds_test.classes
num_classes = len(ds_trn.classes)

In [5]:
dl_trn = torch.utils.data.DataLoader(ds_trn, batch_size=32, shuffle=True, drop_last=True)
dl_test = torch.utils.data.DataLoader(ds_test, batch_size=64, shuffle=False, drop_last=False)

In [6]:
device = torch.device('cuda:0')

In [7]:
loss_fn = nn.CrossEntropyLoss()

In [8]:
model = torchvision.models.resnet18(weights='DEFAULT')
model.fc = nn.Linear(model.fc.in_features, num_classes, bias=True)
model = model.to(device)

In [9]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [10]:
n_epochs = 6
for epoch in tqdm(range(n_epochs), desc='Epoch'):
    model.train()
    losses_trn = []
    for inp, tgt in tqdm(dl_trn, desc='Train', leave=False):
        optimizer.zero_grad()
        out = model(inp.to(device))
        loss = loss_fn(out, tgt.to(device))
        loss.backward()
        optimizer.step()
        losses_trn.append(loss.item())
    
    model.eval()
    losses_test = []
    tgts, preds = [], []
    for inp, tgt in tqdm(dl_test, desc='Test', leave=False):
        with torch.no_grad():
            out = model(inp.to(device))
            loss = loss_fn(out, tgt.to(device))
            tgts.append(tgt)
            preds.append(out.argmax(dim=1).detach().cpu())
        losses_test.append(loss.item())
    tgts = torch.cat(tgts, dim=0)
    preds = torch.cat(preds, dim=0)
    accs = torch.stack([(preds == tgts)[tgts == c].float().mean() for c in range(num_classes)])
    b_acc = accs.mean()
    
    print(epoch, mean(losses_trn), mean(losses_test), b_acc)

Epoch:   0%|          | 0/6 [00:00<?, ?it/s]

Train:   0%|          | 0/1562 [00:00<?, ?it/s]

Test:   0%|          | 0/157 [00:00<?, ?it/s]

0 1.002741907931931 0.6946572527574126 tensor(0.7629)


Train:   0%|          | 0/1562 [00:00<?, ?it/s]

Test:   0%|          | 0/157 [00:00<?, ?it/s]

1 0.6286239525000356 0.5774387950730172 tensor(0.8024)


Train:   0%|          | 0/1562 [00:00<?, ?it/s]

Test:   0%|          | 0/157 [00:00<?, ?it/s]

2 0.47966351742628444 0.5630277306505829 tensor(0.8080)


Train:   0%|          | 0/1562 [00:00<?, ?it/s]

Test:   0%|          | 0/157 [00:00<?, ?it/s]

3 0.3628172308430743 0.5416375523444953 tensor(0.8211)


Train:   0%|          | 0/1562 [00:00<?, ?it/s]

Test:   0%|          | 0/157 [00:00<?, ?it/s]

4 0.2878967738434882 0.5725828984360786 tensor(0.8129)


Train:   0%|          | 0/1562 [00:00<?, ?it/s]

Test:   0%|          | 0/157 [00:00<?, ?it/s]

5 0.22513691635227615 0.5744384775163641 tensor(0.8277)


In [11]:
def collect_preds(dl, model):
    model.eval()
    outs, tgts = [], []
    for inp, tgt in dl:
        with torch.no_grad():
            out = model(inp.to(device))
        outs.append(out.detach().cpu())
        tgts.append(tgt)
    tgts = torch.cat(tgts, dim=0)
    outs = torch.cat(outs, dim=0)
    return outs, tgts

In [12]:
# blur_tfm = TF.GaussianBlur(kernel_size=3)# TF.Compose([TF.ToTensor(), TF.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))])
# noise_tfm = TF.Compose([TF.ToTensor(), TF.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)), 
#                         TF2.GaussianNoise(mean=0., sigma=0.01), TF.ToPILImage()])
blur_tfm = TF.Compose([TF.GaussianBlur(kernel_size=3, sigma=(1., 2.)), 
                       TF.ToTensor(), TF.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))])
noise_tfm = TF.Compose([TF.ToTensor(), TF.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)), 
                        TF2.GaussianNoise(mean=0., sigma=0.01)])
ds_blur  = torchvision.datasets.ImageFolder('./Cifar10/test', transform=blur_tfm)
ds_noise = torchvision.datasets.ImageFolder('./Cifar10/test', transform=noise_tfm)
dl_blur = torch.utils.data.DataLoader(ds_blur, batch_size=64, shuffle=False, drop_last=False)
dl_noise = torch.utils.data.DataLoader(ds_noise, batch_size=64, shuffle=False, drop_last=False)

In [13]:
outs_orig, tgts_orig = collect_preds(dl_test, model)
outs_blur, tgts_blur = collect_preds(dl_blur, model)
outs_noise, tgts_noise = collect_preds(dl_noise, model)

In [14]:
assert (tgts_orig == tgts_blur).all()
assert (tgts_orig == tgts_noise).all()

In [15]:
tgts = tgts_orig

In [16]:
def compute_conf_score(outs):
    probs = outs.softmax(dim=1)
    return probs.max(dim=1).values

def compute_neg_entropy_score(outs):
    probs = outs.softmax(dim=1)
    entropies = probs * probs.log()
    return entropies.sum(dim=1)

In [17]:
def compute_OOD_det_score(outs_ID, outs_OOD, score_fn):
    scores_ID, scores_OOD = score_fn(outs_ID), score_fn(outs_OOD)
    scores_ID, scores_OOD = scores_ID.numpy(), scores_OOD.numpy()
    labels_ID = np.ones((len(scores_ID),), dtype=int)
    labels_OOD = np.zeros((len(scores_OOD),), dtype=int)
    scores = np.concatenate((scores_ID, scores_OOD), axis=0)
    labels = np.concatenate((labels_ID, labels_OOD), axis=0)
    return labels, scores

In [18]:
outs_OOD = {'noise': outs_noise, 'blur': outs_blur}
score_fns = {'conf': compute_conf_score, 'negH': compute_neg_entropy_score}

In [19]:
for corr_name, outs in outs_OOD.items():
    for score_name, score_fn in score_fns.items():
        rocauc = roc_auc_score(*compute_OOD_det_score(outs_orig, outs, score_fn))
        print(corr_name, score_name, f'{rocauc*100:.2f}%')

noise conf 72.20%
noise negH 73.41%
blur conf 70.91%
blur negH 71.51%


In [157]:
# модель с lr=1e-5 и более низкой ID точностью
for corr_name, outs in outs_OOD.items():
    for score_name, score_fn in score_fns.items():
        rocauc = roc_auc_score(*compute_OOD_det_score(outs_orig, outs, score_fn))
        print(corr_name, score_name, f'{rocauc*100:.2f}%')

noise conf 69.91%
noise negH 71.19%
blur conf 67.79%
blur negH 68.33%
