In [11]:
# Run configuration passed through environment variables; presumably read by
# `load_config()` in a later cell (NOTE(review): confirm which keys it consumes).
# NOTE(review): `%env` does not expand `$HOME` here -- the echoed values below keep
# it literally -- so DATA_DIR / EXPERIMENT_BASE only resolve if the consumer calls
# os.path.expandvars (or similar). Confirm, or set already-expanded paths.
%env DATA_DIR=$HOME/datasets
%env EXPERIMENT_BASE=$HOME/experiments/ood_flows
%env LOG_LEVEL=INFO
# Optimisation / training hyper-parameters (string values; parsing is up to the consumer).
%env BATCH_SIZE=64
%env OPTIM_LR=0.001
%env OPTIM_M=0.8
%env TRAIN_EPOCHS=100
%env EXC_RESUME=1
# Dataset / model selection (presumably used by get_data / get_model via the config).
%env DATASET_NAME=AMRB2_species
%env MANIFOLD_D=512
%env MODEL_NAME=resnet

env: DATA_DIR=$HOME/datasets
env: EXPERIMENT_BASE=$HOME/experiments/ood_flows
env: LOG_LEVEL=INFO
env: BATCH_SIZE=64
env: OPTIM_LR=0.001
env: OPTIM_M=0.8
env: TRAIN_EPOCHS=100
env: EXC_RESUME=1
env: DATASET_NAME=AMRB2_species
env: MANIFOLD_D=512
env: MODEL_NAME=resnet


In [12]:
import lightning.pytorch as pl
import numpy as np
import torch
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers.wandb import WandbLogger

from config import Config, load_config
from datasets import get_data
from models import get_model

In [13]:
# Seed Python's `random`, NumPy and torch (CPU + CUDA) in a single call so that any
# stochastic data handling or initialisation is reproducible across kernel restarts.
SEED = 42
pl.seed_everything(SEED)
# Trade some float32 matmul precision for speed on GPUs that support it.
torch.set_float32_matmul_precision('medium')

config = load_config()

# `get_data` returns nothing used here; it appears to attach the dataset attributes
# (labels, datamodule) to `config` in place -- the assert below checks the datamodule.
get_data(config)
config.print_labels()

assert config.datamodule, "get_data(config) did not set config.datamodule"

In [14]:
# W&B artifact (alias `:best`) holding the checkpoint to evaluate; the call
# downloads it and returns the local directory containing `model.ckpt`.
WANDB_ARTIFACT = "yasith/uq_project/model-dfriexbx:best"
artifact_dir = WandbLogger.download_artifact(artifact=WANDB_ARTIFACT)

[34m[1mwandb[0m: Downloading large artifact model-dfriexbx:best, 58.83MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.3


In [15]:
# Instantiate the architecture selected by `config`. The pretrained model is loaded
# separately from the checkpoint in the next cell; this instance is not evaluated.
model_randinit = get_model(config)

In [16]:
from pathlib import Path

# `load_from_checkpoint` is a classmethod that builds a *new* model from the checkpoint;
# the randomly initialised instance is only needed to know which class to load.
# Call it on the class: newer Lightning 2.x releases raise TypeError when it is
# invoked on an instance.
checkpoint_path = Path(artifact_dir) / "model.ckpt"
model_pretrain = type(model_randinit).load_from_checkpoint(checkpoint_path, config=config)
model_pretrain.eval()  # inference behaviour for layers such as dropout / batch-norm
del model_randinit

In [19]:
# Build only the test split; `get_data(config)` is expected to have populated the datamodule.
datamodule = config.datamodule
assert datamodule
datamodule.setup("test")

In [20]:
# Test-split loader; iterated as (x, y) batches in the evaluation cell below.
test_loader = config.datamodule.test_dataloader()

In [21]:
from torchmetrics import Accuracy

# Multiclass accuracy over the labels from `config.get_ind_labels()` (presumably the
# in-distribution classes), placed on the GPU alongside the model outputs.
num_ind_classes = len(config.get_ind_labels())
accuracy = Accuracy(task="multiclass", num_classes=num_ind_classes).cuda()

In [22]:
from models.common import edl_probs
from tqdm import tqdm

# Evaluate the pretrained model on the test split. `classifier_loss` selects how logits
# are turned into class probabilities (pY) and a per-sample uncertainty score (uY);
# presumably it must match the head the checkpoint was trained with -- confirm.
classifier_loss = "edl"
accuracy.reset()

# inference_mode: no autograd graph is recorded, saving memory and time during eval.
with torch.inference_mode():
    for x, y in tqdm(test_loader):
        x = x.cuda().float()
        y = y.cuda().long()

        # The model returns a 3-tuple; only the logits are needed for accuracy.
        z, logits, x_pred = model_pretrain(x)

        if classifier_loss == "edl":
            pY, uY = edl_probs(logits)
        elif classifier_loss == "crossent":
            pY = logits.softmax(-1)
            # Tensor.max(dim) returns (values, indices); uncertainty = 1 - top probability.
            uY = 1.0 - pY.max(-1).values
        elif classifier_loss == "margin":
            pY = logits.sigmoid()
            uY = 1.0 - pY.max(-1).values
        else:
            raise ValueError(f"unknown classifier_loss: {classifier_loss!r}")
        accuracy.update(pY, y)

accuracy.compute()

100%|██████████| 3375/3375 [00:23<00:00, 143.57it/s]

tensor(0.6077, device='cuda:0')



