# ViT out-of-distribution (OOD) detection in PyTorch

In [1]:
import sys
sys.path.append('../')
from utils import roc_btw_arr

In [2]:
device = 'cuda:0'

## Dataset

In [3]:
from datasets import load_dataset

# In-distribution data: CIFAR-10. 10% of the training split is held out for
# validation; a fixed seed keeps that split identical across Restart & Run All.
train_ds, test_ds = load_dataset('cifar10', split=['train', 'test'])
splits = train_ds.train_test_split(test_size=0.1, seed=42)
train_ds = splits['train']
val_ds = splits['test']

# OOD evaluation set.
cifar100_ds = load_dataset('cifar100', split='test')

# Full (un-split) CIFAR-10 training set, used later to fit the Mahalanobis statistics.
cifar10_train_ds = load_dataset('cifar10', split='train')

# Second OOD evaluation set.
svhn_ds = load_dataset('svhn', 'cropped_digits', split='test')

Found cached dataset cifar10 (/root/.cache/huggingface/datasets/cifar10/plain_text/1.0.0/447d6ec4733dddd1ce3bb577c7166b986eaa4c538dcd9e805ba61f35674a9de4)


  0%|          | 0/2 [00:00<?, ?it/s]

Found cached dataset cifar100 (/root/.cache/huggingface/datasets/cifar100/cifar100/1.0.0/f365c8b725c23e8f0f8d725c3641234d9331cd2f62919d1381d1baa5b3ba3142)
Found cached dataset cifar10 (/root/.cache/huggingface/datasets/cifar10/plain_text/1.0.0/447d6ec4733dddd1ce3bb577c7166b986eaa4c538dcd9e805ba61f35674a9de4)
Found cached dataset svhn (/root/.cache/huggingface/datasets/svhn/cropped_digits/1.0.0/8e83fcbe6f6078438cd826c7acd29e3de8ee7db44657b535f6d3453235529a31)


In [4]:
# Map class index <-> class name using the dataset's own label feature.
id2label = dict(enumerate(train_ds.features['label'].names))
label2id = {name: idx for idx, name in id2label.items()}
id2label

{0: 'airplane',
 1: 'automobile',
 2: 'bird',
 3: 'cat',
 4: 'deer',
 5: 'dog',
 6: 'frog',
 7: 'horse',
 8: 'ship',
 9: 'truck'}

In [5]:
from transformers import ViTFeatureExtractor

model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)

In [6]:
from torchvision.transforms import (CenterCrop, 
                                    Compose, 
                                    Normalize, 
                                    RandomHorizontalFlip,
                                    RandomResizedCrop, 
                                    Resize, 
                                    ToTensor)

normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)

# Older transformers releases expose `size` as a plain int; newer ones use a
# dict ({"height": H, "width": W} or {"shortest_edge": S}). torchvision needs an
# int or an (H, W) tuple, so normalize it once here. For square CIFAR images
# both forms yield the same output resolution.
image_size = feature_extractor.size
if isinstance(image_size, dict):
    if "height" in image_size:
        image_size = (image_size["height"], image_size["width"])
    else:
        image_size = image_size["shortest_edge"]

_train_transforms = Compose(
        [
            RandomResizedCrop(image_size),
            RandomHorizontalFlip(),
            ToTensor(),
            normalize,
        ]
    )

_val_transforms = Compose(
        [
            Resize(image_size),
            CenterCrop(image_size),
            ToTensor(),
            normalize,
        ]
    )

def train_transforms(examples):
    """Training-time preprocessing for a batch of CIFAR examples.

    Converts every image in ``examples['img']`` to RGB, runs it through
    ``_train_transforms``, stores the results in ``examples['pixel_values']``
    and returns the same (mutated) dict.
    """
    rgb_images = (image.convert("RGB") for image in examples['img'])
    examples['pixel_values'] = list(map(_train_transforms, rgb_images))
    return examples

def val_transforms(examples):
    """Deterministic (eval-time) preprocessing for a batch of CIFAR examples.

    Each image in ``examples['img']`` is RGB-converted and passed through
    ``_val_transforms``; the list of results is written to
    ``examples['pixel_values']`` and the dict is returned.
    """
    pixel_values = []
    for img in examples['img']:
        pixel_values.append(_val_transforms(img.convert("RGB")))
    examples['pixel_values'] = pixel_values
    return examples

def val_transforms_svhn(examples):
    """Eval-time preprocessing for SVHN, whose image column is ``image`` (not ``img``).

    Same pipeline as the CIFAR eval path: RGB-convert, apply ``_val_transforms``,
    store in ``examples['pixel_values']``, return the dict.
    """
    examples['pixel_values'] = list(
        map(lambda picture: _val_transforms(picture.convert("RGB")), examples['image'])
    )
    return examples

In [7]:
# Set the transforms
train_ds.set_transform(train_transforms)

val_ds.set_transform(val_transforms)
test_ds.set_transform(val_transforms)

cifar100_ds.set_transform(val_transforms)
cifar10_train_ds.set_transform(val_transforms)

svhn_ds.set_transform(val_transforms_svhn)

In [8]:
from transformers import ViTForImageClassification

# Load the same checkpoint the feature extractor came from; the classification
# head is newly initialized and sized to the CIFAR-10 label set.
model = ViTForImageClassification.from_pretrained(model_name_or_path,
                                                  num_labels=len(id2label),
                                                  id2label=id2label,
                                                  label2id=label2id)

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
from transformers import TrainingArguments, Trainer

metric_name = "accuracy"
num_train_epoch = 15
# Derive the run / output-directory name from the epoch count so they cannot drift apart.
run_name = f"finetune-cifar-10-epoch{num_train_epoch}"

args = TrainingArguments(
    run_name,
    save_strategy="epoch",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=num_train_epoch,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    logging_dir='logs',
    # Keep the raw 'img' column so the on-the-fly transforms can read it.
    remove_unused_columns=False,
)

2022-10-20 03:56:09.816124: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [10]:
from datasets import load_metric
# import evaluate
import numpy as np

metric = load_metric("accuracy")

def compute_metrics(eval_pred):
    """Accuracy callback for ``Trainer``.

    ``eval_pred`` unpacks to (logits, labels); logits are reduced to class
    predictions with argmax over axis 1 and scored with the global ``metric``.
    """
    logits, references = eval_pred
    predicted_classes = np.argmax(logits, axis=1)
    return metric.compute(predictions=predicted_classes, references=references)

  metric = load_metric("accuracy")


In [11]:
from torch.utils.data import DataLoader
import torch

def collate_fn(examples):
    """Batch examples that carry ``pixel_values`` tensors and an integer ``label``.

    Returns ``{"pixel_values": [B, ...] stacked tensor, "labels": [B] tensor}``.
    """
    batch_pixels = torch.stack([ex["pixel_values"] for ex in examples])
    batch_labels = torch.tensor([ex["label"] for ex in examples])
    return dict(pixel_values=batch_pixels, labels=batch_labels)

def collate_fn_cifar100(examples):
    """Batch CIFAR-100 examples, using the ``fine_label`` column as ``labels``."""
    stacked = torch.stack(tuple(sample["pixel_values"] for sample in examples))
    fine = torch.tensor(list(sample["fine_label"] for sample in examples))
    batch = {"pixel_values": stacked}
    batch["labels"] = fine
    return batch

# Training split (random augmentation via train_transforms).
train_dataloader = DataLoader(train_ds, collate_fn=collate_fn, batch_size=4)
# In-distribution test set and CIFAR-100 OOD set (CIFAR-100 labels live in 'fine_label').
test_dataloader = DataLoader(test_ds, collate_fn=collate_fn, batch_size=4)
cifar100_dataloader = DataLoader(cifar100_ds, collate_fn=collate_fn_cifar100, batch_size=4)
# Un-split CIFAR-10 train set with eval transforms, iterated in order (no shuffle)
# so extracted features line up with dataset labels.
cifar10_train_dataloader = DataLoader(cifar10_train_ds, collate_fn=collate_fn, batch_size=4)

svhn_dataloader = DataLoader(svhn_ds, collate_fn=collate_fn, batch_size=64)

In [12]:
import torch

trainer = Trainer(
    model,
    args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    tokenizer=feature_extractor,
)

In [13]:
trainer.train()

***** Running training *****
  Num examples = 45000
  Num Epochs = 15
  Instantaneous batch size per device = 32
  Total train batch size (w. parallel, distributed & accumulation) = 128
  Gradient Accumulation steps = 1
  Total optimization steps = 5280
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mswyoon[0m. Use [1m`wandb login --relogin`[0m to force relogin




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.361212,0.9794
2,0.884300,0.200316,0.9836
3,0.374100,0.140144,0.987
4,0.374100,0.106923,0.9884
5,0.288100,0.083366,0.9894


***** Running Evaluation *****
  Num examples = 5000
  Batch size = 128
Saving model checkpoint to finetune-cifar-10-epoch15/checkpoint-352
Configuration saved in finetune-cifar-10-epoch15/checkpoint-352/config.json
Model weights saved in finetune-cifar-10-epoch15/checkpoint-352/pytorch_model.bin
Feature extractor saved in finetune-cifar-10-epoch15/checkpoint-352/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 5000
  Batch size = 128
Saving model checkpoint to finetune-cifar-10-epoch15/checkpoint-704
Configuration saved in finetune-cifar-10-epoch15/checkpoint-704/config.json
Model weights saved in finetune-cifar-10-epoch15/checkpoint-704/pytorch_model.bin
Feature extractor saved in finetune-cifar-10-epoch15/checkpoint-704/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 5000
  Batch size = 128
Saving model checkpoint to finetune-cifar-10-epoch15/checkpoint-1056
Configuration saved in finetune-cifar-10-epoch15/checkpoint-1056/config.json


## Classification output

In [None]:
outputs = trainer.predict(test_ds)

In [None]:

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

y_true = outputs.label_ids
y_pred = outputs.predictions.argmax(1)

labels = train_ds.features['label'].names
cm = confusion_matrix(y_true, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
disp.plot(xticks_rotation=45)

# Evaluate

In [None]:
from tqdm.notebook import tqdm

In [None]:
def get_logit(model, dl, device=None):
    """Run ``model`` over every batch of ``dl`` and collect classification logits.

    Args:
        model: classifier whose forward takes ``pixel_values`` (plus
            ``output_hidden_states``) and returns an object with ``.logits``.
        dl: iterable of batches; each batch is a dict with a ``'pixel_values'`` tensor.
        device: device to move inputs to. Defaults to the device of the model's
            parameters, so inputs always land where the model lives (instead of
            assuming ``.cuda()``).

    Returns:
        CPU tensor of logits with all batches concatenated along dim 0.
    """
    if device is None:
        device = next(model.parameters()).device
    l_pred = []
    with torch.no_grad():
        for xx in tqdm(dl):
            pred = model(xx['pixel_values'].to(device), output_hidden_states=False)
            l_pred.append(pred.logits.detach().cpu())
    return torch.cat(l_pred)

def get_prelogit(model, dl, device=None):
    """Extract the pre-logit feature for every example in ``dl``.

    The feature is the first token (``[:, 0, :]``) of the last hidden state,
    passed through ``model.vit.layernorm``.

    Args:
        model: ViT classifier that returns ``hidden_states`` when
            ``output_hidden_states=True`` and exposes ``model.vit.layernorm``.
        dl: iterable of batches; each batch is a dict with a ``'pixel_values'`` tensor.
        device: device to move inputs to. Defaults to the device of the model's
            parameters.

    Returns:
        CPU tensor of features, all batches concatenated along dim 0.
    """
    if device is None:
        device = next(model.parameters()).device
    l_pred = []
    with torch.no_grad():
        for xx in tqdm(dl):
            out = model(xx['pixel_values'].to(device), output_hidden_states=True)
            out = out.hidden_states[-1][:, 0, :]
            out = model.vit.layernorm(out)
            l_pred.append(out.detach().cpu())
    return torch.cat(l_pred)

In [None]:
cifar10_pred = get_logit(model, test_dataloader)
cifar100_pred = get_logit(model, cifar100_dataloader)

In [None]:
cifar100_pred.shape, cifar10_pred.shape

In [None]:
from torch.nn.functional import softmax

In [None]:
cifar100_msp = torch.max(softmax(cifar100_pred, dim=1),dim=1).values
cifar10_msp = torch.max(softmax(cifar10_pred, dim=1),dim=1).values

In [None]:
cifar100_msp.shape

In [None]:
roc_btw_arr(cifar10_msp, cifar100_msp)

### Mahalanobis distance

In [None]:
prelogit = get_prelogit(model, cifar10_train_dataloader)

In [None]:
torch.save(prelogit, f'{run_name}/prelogit.pkl')

In [None]:
prelogit

In [None]:
# Read the label column from an unformatted view of the dataset. Indexing the
# formatted dataset row by row would run `val_transforms` (resize + ToTensor)
# on every one of the training images just to read their labels. Row order is
# the same one cifar10_train_dataloader iterates (it does not shuffle), so the
# labels line up with `prelogit`.
cifar10_train_label = torch.tensor(cifar10_train_ds.with_format(None)['label'])

In [None]:
cifar10_train_label.shape

In [None]:
'''mahalanobis statistics computation'''
# Class-conditional means plus one shared (pooled) covariance, fitted on the
# CIFAR-10 training prelogits.
num_classes = len(id2label)
l_mean = []
l_outer = []
for k in range(num_classes):
    subset_x = prelogit[cifar10_train_label == k]
    subset_mean = torch.mean(subset_x, dim=0, keepdim=True)   # [1, D]
    v = subset_x - subset_mean
    subset_outer = v.T.mm(v)                                   # [D, D] per-class scatter matrix
    l_mean.append(subset_mean)
    l_outer.append(subset_outer)
# Summed per-class scatter / total sample count = pooled covariance.
pooled_cov = torch.sum(torch.stack(l_outer), dim=0) / len(prelogit)
all_means = torch.stack(l_mean, dim=-1)                        # [1, D, K], the layout forward_maha takes
invcov = torch.linalg.inv(pooled_cov)

In [None]:
'''relative mahalanobis statistics'''
whole_mean = torch.mean(prelogit, dim=0, keepdim=True)
v = prelogit - whole_mean
whole_cov = v.T.mm(v) / len(prelogit)
whole_invcov = torch.linalg.inv(whole_cov)

In [None]:
'''save'''
torch.save({'all_means': all_means,
            'invcov': invcov,
            'whole_mean': whole_mean,
            'whole_invcov': whole_invcov}, f'{run_name}/maha-statistic.pkl')

In [None]:
pooled_cov

In [None]:
pooled_cov[0,:10]

In [None]:
whole_cov[0,:10]

In [1]:
import torch

In [57]:
# NOTE(review): scratch sanity check for `forward_maha` on a 2-D toy example.
# It rebinds the module-level names `xx`, `mean` and, critically, `invcov`,
# overwriting the inverse pooled covariance computed from CIFAR-10 prelogits
# above. If this cell runs before the Mahalanobis scoring cells below, they
# would be called with this 2x2 matrix instead. Rename these (e.g. `toy_invcov`)
# or move the experiment out of this notebook.
xx = torch.tensor([[1,2], [3,4.]])
# xx = torch.tensor([[1,2],])
mean = torch.tensor([[0.1,0.]])
invcov = torch.tensor([[1, 0.1], [0., 1.]])

In [58]:
zz = xx - mean
zz.mm(invcov).mm(zz.T)

tensor([[ 4.9900, 10.9700],
        [11.1900, 25.5700]])

In [61]:
forward_maha(xx[[0]], mean.unsqueeze(-1), invcov)

tensor([4.9900])

In [62]:
forward_maha(xx[[1]], mean.unsqueeze(-1), invcov)

tensor([25.5700])

In [47]:
z = xx[:1].unsqueeze(-1) - mean
op1 = torch.einsum('ijk,jl->ilk', z, invcov)
op2 = torch.einsum('ijk,ijk->ik', op1, z)

In [48]:
op1

tensor([[[0.9000, 1.0000],
         [1.9000, 2.0000]]])

In [44]:
op1

tensor([[[0.9000, 1.0000],
         [1.9000, 2.0000]],

        [[2.9000, 3.0000],
         [3.9000, 4.0000]]])

In [49]:
op2

tensor([[4.4200, 5.0000]])

In [45]:
op2

tensor([[ 4.4200,  5.0000],
        [23.6200, 25.0000]])

In [30]:
op1

tensor([[[0.9000, 0.8000],
         [3.7000, 3.4000]],

        [[2.9000, 2.8000],
         [9.7000, 9.4000]]])

In [31]:
op2

tensor([[ 7.8400,  6.7600],
        [46.2400, 43.5600]])

In [23]:
torch.bmm(z, invcov.unsqueeze(0))

RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [2, 2] but got: [1, 2].

In [20]:
op1.shape

torch.Size([2, 2, 2])

In [4]:
def forward_maha(xx, mean, invcov):
    """Minimum squared Mahalanobis distance from each row of ``xx`` to K means.

    xx: [B, D]
    mean: [1, D, K]
    invcov: [D, D]
    Returns a [B] tensor: min over k of (x - mu_k)^T invcov (x - mu_k).
    """
    centered = xx.unsqueeze(-1) - mean                                  # [B, D, K]
    weighted = torch.einsum('ijk,jl->ilk', centered, invcov)            # (x - mu)^T invcov
    sq_dist = torch.einsum('ijk,ijk->ik', weighted, centered)           # [B, K]
    return sq_dist.min(dim=1).values

In [3]:
def forward_maha(xx, mean, invcov):
    """Score each sample by its squared Mahalanobis distance to the nearest mean.

    Shapes: ``xx`` is [B, D], ``mean`` is [1, D, K] (one column per class),
    ``invcov`` is the [D, D] inverse covariance shared by all classes.
    The result has shape [B].
    """
    # [B, D, 1] broadcast against [1, D, K] gives one residual per class.
    delta = torch.sub(xx.unsqueeze(-1), mean)
    left = torch.einsum('ijk,jl->ilk', delta, invcov)
    per_class = torch.einsum('ijk,ijk->ik', left, delta)
    nearest, _ = torch.min(per_class, dim=1)
    return nearest
    
def forward_maha_dl(model, dataloader, mean, invcov, device=None):
    """Mahalanobis OOD score for every example in ``dataloader``.

    The feature is the layer-normed first token of the last hidden state (the
    same pre-logit ``get_prelogit`` extracts). It is scored with
    ``forward_maha``: the minimum squared Mahalanobis distance to the class means.

    Args:
        model: ViT classifier returning ``hidden_states`` and exposing
            ``model.vit.layernorm``.
        dataloader: iterable of dicts with a ``'pixel_values'`` tensor.
        mean: class means, [1, D, K].
        invcov: inverse pooled covariance, [D, D].
        device: where to send inputs; defaults to the model's parameter device.

    Returns:
        1-D CPU tensor of scores; larger means farther from the nearest class mean.
    """
    if device is None:
        device = next(model.parameters()).device
    l_score = []
    # Inference only: without no_grad every batch would also build an autograd graph.
    with torch.no_grad():
        for xx in tqdm(dataloader):
            out = model(xx['pixel_values'].to(device), output_hidden_states=True)
            feats = model.vit.layernorm(out.hidden_states[-1][:, 0, :])
            l_score.append(forward_maha(feats.detach().cpu(), mean, invcov))
    return torch.cat(l_score)
# forward_maha(subset_x, all_means, invcov)

cifar10_maha = forward_maha_dl(model, test_dataloader, all_means, invcov)
cifar100_maha = forward_maha_dl(model, cifar100_dataloader, all_means, invcov)

NameError: name 'model' is not defined

In [None]:
roc_btw_arr(cifar100_maha, cifar10_maha)

In [None]:
svhn_maha = forward_maha_dl(model, svhn_dataloader, all_means, invcov)

In [None]:
roc_btw_arr(svhn_maha, cifar10_maha)

In [None]:
rank = np.searchsorted(np.sort(cifar10_maha), svhn_maha)
rank.min()

In [None]:
def forward_rel_maha_dl(model, dataloader, mean, invcov, whole_mean, whole_invcov, device=None):
    """Relative Mahalanobis OOD score for every example in ``dataloader``.

    Score = min_k MD_k(x) - MD_0(x). MD_k uses the class means ``mean`` and the
    pooled ``invcov``. MD_0 uses a single Gaussian fitted to all training
    features (``whole_mean``, ``whole_invcov``). MD_0 does not depend on k, so
    subtracting it after the min equals taking the min of the per-class differences.

    Args:
        model: ViT classifier returning ``hidden_states`` and exposing
            ``model.vit.layernorm``.
        dataloader: iterable of dicts with a ``'pixel_values'`` tensor.
        mean: class means, [1, D, K].
        invcov: inverse pooled covariance, [D, D].
        whole_mean: global mean, [1, D, 1].
        whole_invcov: inverse global covariance, [D, D].
        device: where to send inputs; defaults to the model's parameter device.

    Returns:
        1-D CPU tensor of relative Mahalanobis scores.
    """
    if device is None:
        device = next(model.parameters()).device
    l_score = []
    # Inference only: skip autograd bookkeeping for every batch.
    with torch.no_grad():
        for xx in tqdm(dataloader):
            out = model(xx['pixel_values'].to(device), output_hidden_states=True)
            feats = model.vit.layernorm(out.hidden_states[-1][:, 0, :]).detach().cpu()
            maha = forward_maha(feats, mean, invcov)
            rel_maha = forward_maha(feats, whole_mean, whole_invcov)
            l_score.append(maha - rel_maha)
    return torch.cat(l_score)

In [None]:
cifar10_rel_maha = forward_rel_maha_dl(model, test_dataloader, all_means, invcov, whole_mean.unsqueeze(-1), whole_invcov)
cifar100_rel_maha = forward_rel_maha_dl(model, cifar100_dataloader, all_means, invcov, whole_mean.unsqueeze(-1), whole_invcov)

In [None]:
roc_btw_arr(cifar100_rel_maha, cifar10_rel_maha)