In [1]:
# Experiment settings, passed via environment variables. load_config() in the
# setup cell logs these same keys back (DATASET_NAME, MANIFOLD_D, MODEL_NAME,
# TRAIN_EPOCHS, OOD_K). Keep comments on their own lines: a trailing comment
# after a %env assignment would become part of the value.
# NOTE(review): OOD_K=1 appears to hold out one class as OOD (the setup output
# lists a single OOD label) — confirm against load_config/init_labels.
%env DATASET_NAME=AMRB2_species
%env MANIFOLD_D=512
%env MODEL_NAME=resnet
%env TRAIN_EPOCHS=50
%env OOD_K=1

env: DATASET_NAME=AMRB2_species
env: MANIFOLD_D=512
env: MODEL_NAME=resnet
env: TRAIN_EPOCHS=50
env: OOD_K=1


In [2]:
import logging
import os
import random

import numpy as np
import torch
import sklearn.metrics
import wandb
import wandb.plot
from matplotlib import pyplot as plt
from PIL import Image
from tqdm import tqdm

from config import Config, load_config
from datasets import init_dataloaders, init_labels, init_shape
from models import get_model_optimizer_and_step
from models.common import gen_topk_accs, load_model_state, save_model_state

In [3]:
# Seed every RNG the pipeline may touch (Python, NumPy, PyTorch) so that runs
# are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Build the experiment config; its INFO log echoes the %env settings above.
config = load_config()

# initialize data attributes and loaders
init_labels(config)
init_shape(config)
init_dataloaders(config)

config.print_labels()

# Start the wandb run that the training cell logs metrics and artifacts to.
wandb.init(
    project="uq_ood",
    name=config.run_name,
    config=config.run_config,
)

INFO:root:LOG_LEVEL=INFO
INFO:root:OOD_K=1
INFO:root:DATA_DIR=/home/pjaya001/datasets
INFO:root:DATASET_NAME=AMRB2_species
INFO:root:MODEL_NAME=resnet
INFO:root:EXPERIMENT_BASE=/home/pjaya001/experiments/ood_flows
INFO:root:MANIFOLD_D=512
INFO:root:BATCH_SIZE=32
INFO:root:OPTIM_LR=0.001
INFO:root:OPTIM_M=0.8
INFO:root:TRAIN_EPOCHS=50
INFO:root:EXC_RESUME=1
INFO:root:Using device: cuda
INFO:root:Dataset file: /home/pjaya001/datasets/AMRB2/ctr_1_fit_f32.imag.npz
INFO:root:[preparation] loaded target info
INFO:root:[preparation] performed train/test split
INFO:root:Prepared datasets in 17.561773538589478 s


Performing ind/ood split


INFO:root:Labels (train, test): ['Acinetobacter', 'E_coli', 'K_pneumoniae', 'S_aureus']
INFO:root:Labels (ood): ['B_subtilis']


Performed ind/ood split
315072 210048 14784


ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33myasith[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
def _confusion_matrix_image(y_true, y_pred, labels, fp: str):
    """Plot a confusion matrix over ``labels``, save it to ``fp`` and return it as a ``wandb.Image``.

    ``labels[i]`` is the display name of class index ``i``; the matrix always
    has ``len(labels)`` rows/columns, even for classes absent from ``y_true``.
    ``ConfusionMatrixDisplay.plot()`` opens a brand-new figure on each call, so
    the figure is closed after saving. Otherwise two figures leak per epoch and
    matplotlib starts warning once more than 20 figures are open.
    """
    cm = sklearn.metrics.confusion_matrix(
        y_true, y_pred, labels=list(range(len(labels)))
    )
    disp = sklearn.metrics.ConfusionMatrixDisplay(cm, display_labels=labels)
    disp.plot()
    try:
        disp.figure_.savefig(fp)
    finally:
        plt.close(disp.figure_)
    return wandb.Image(Image.open(fp))


def _sample_strip(split_stats: dict, join) -> tuple:
    """Unpack ``split_stats["samples"]`` into one row of the sample figures.

    ``samples`` is iterated as ``(x, y, x_out, y_out)`` tuples. Inputs and
    outputs are each concatenated along axis 1, and ``join`` converts the
    target index sequences into comma-separated label strings.

    Returns:
        ``(x_row, y_label_str, x_out_row, y_out_label_str)``.
    """
    xs, ys, x_outs, y_outs = zip(*split_stats["samples"])
    return (
        np.concatenate(xs, axis=1),
        join(ys),
        np.concatenate(x_outs, axis=1),
        join(y_outs),
    )


def epoch_stats_to_wandb(
    stats: dict[str, dict],
    config: Config,
    step: int,
) -> dict:
    """Convert one epoch's per-split statistics into a flat dict for ``wandb.log``.

    Args:
        stats: maps split name -> stats dict. Recognized splits are ``"trn"``,
            ``"tid"`` (InD test) and ``"tod"`` (OOD test); any may be missing,
            and the test splits may be empty. Each split dict is read for
            ``"loss"``, ``"y_true"``, ``"y_prob"`` and ``"samples"``, plus
            optional extras (``"u_norm"``, ``"v_norm"``, ``"z_norm"``,
            ``"z_nll"``, ``"y_ucty"``, or anything else, which is logged as-is).
        config: provides the InD/OOD label lists and ``experiment_path``,
            where confusion-matrix PNGs are written.
        step: epoch number; used in PNG file names and the progress line.

    Returns:
        Dict of scalar metrics (``trn/loss``, ``trn/acc``, ``tid/loss``,
        ``tid/acc``, ``tod/loss``; ``None`` when unavailable), confusion-matrix
        and sample images, histograms of the extra statistics, and, when both
        test splits have ``v_norm``, OOD-detection ROC/PR curves and AUROC.
    """
    # Defaults for metrics whose split is missing or empty.
    metrics = {
        "trn/loss": None,
        "trn/acc": dict(acc1=None, acc2=None, acc3=None),
        "tid/loss": None,
        "tid/acc": dict(acc1=None, acc2=None, acc3=None),
        "tod/loss": None,
    }

    ind_labels, ood_labels = config.get_ind_labels(), config.get_ood_labels()
    Ki, Ko = len(ind_labels), len(ood_labels)
    # Class index i is displayed as perm_labels[i]: InD classes, then OOD ones.
    perm_labels = ind_labels + ood_labels

    # Which splits carry data (test splits may be present but empty).
    has_trn = "trn" in stats
    has_tid = "tid" in stats and len(stats["tid"]["y_true"]) > 0
    has_tod = "tod" in stats and len(stats["tod"]["y_true"]) > 0

    figs = {}

    # --- classification metrics and confusion matrices ---
    if has_trn:
        trn = stats["trn"]
        metrics["trn/loss"] = trn["loss"]
        metrics["trn/acc"] = gen_topk_accs(trn["y_true"], trn["y_prob"], Ki)
        figs["trn/cm"] = _confusion_matrix_image(
            trn["y_true"],
            trn["y_prob"].argmax(-1),
            ind_labels,
            os.path.join(config.experiment_path, f"cm_trn_e{step}.png"),
        )

    y_true_parts, y_prob_parts = [], []
    if has_tid:
        tid = stats["tid"]
        metrics["tid/loss"] = tid["loss"]
        metrics["tid/acc"] = gen_topk_accs(tid["y_true"], tid["y_prob"], Ki)
        y_true_parts.append(tid["y_true"])
        y_prob_parts.append(tid["y_prob"])
    if has_tod:
        tod = stats["tod"]
        metrics["tod/loss"] = tod["loss"]
        y_true_parts.append(tod["y_true"])
        y_prob_parts.append(tod["y_prob"])

    # np.concatenate([]) raises, so the combined test matrix is only built
    # when at least one test split has data.
    if y_true_parts:
        y_true_tst = np.concatenate(y_true_parts, axis=0)
        y_prob_tst = np.concatenate(y_prob_parts, axis=0)
        # zero-pad Ko extra columns so y_prob spans all Ki + Ko label indices
        y_prob_tst = np.pad(y_prob_tst, pad_width=((0, 0), (0, Ko)))
        figs["tst/cm"] = _confusion_matrix_image(
            y_true_tst,
            y_prob_tst.argmax(-1),
            perm_labels,
            os.path.join(config.experiment_path, f"cm_tst_e{step}.png"),
        )

    tqdm.write(f"Epoch {step}: {metrics}")

    # --- sample input/output figures: one row per split with data ---
    def join(indices) -> str:
        # class indices -> comma-separated class names
        return ",".join(map(perm_labels.__getitem__, indices))

    def enumerate_rows(values) -> str:
        # ["a", "b"] -> "R1=a, R2=b"
        return ", ".join(f"R{i+1}={v}" for i, v in enumerate(values))

    sample_rows = [
        name
        for name, active in (("trn", has_trn), ("tid", has_tid), ("tod", has_tod))
        if active
    ]
    if sample_rows:
        strips = [_sample_strip(stats[name], join) for name in sample_rows]
        x_rows, y_lbls, x_out_rows, y_out_lbls = zip(*strips)
        fig_a_data = np.concatenate(x_rows, axis=0)
        fig_b_data = np.concatenate(x_out_rows, axis=0)

        row_lbl = enumerate_rows(sample_rows)
        fig_a_cap = f"sample inputs - rows: [{row_lbl}] - targets: [{enumerate_rows(y_lbls)}]"
        fig_b_cap = f"sample output - rows: [{row_lbl}] - targets: [{enumerate_rows(y_out_lbls)}]"

        figs["samples/input"] = wandb.Image(fig_a_data, caption=fig_a_cap)
        figs["samples/output"] = wandb.Image(fig_b_data, caption=fig_b_cap)

    # --- remaining per-split statistics ---
    # Keys already consumed above. "y_ucty" is deliberately NOT listed: it is
    # logged as a bounded histogram in the loop below.
    consumed_keys = {"y_true", "y_prob", "samples"}

    for prefix in ("trn", "tid", "tod"):
        if prefix not in stats:
            continue
        prefix_stats: dict = stats[prefix]
        for key in sorted(set(prefix_stats) - consumed_keys):
            val = prefix_stats[key]
            if key in ("u_norm", "v_norm", "z_norm", "z_nll"):
                # Unbounded values: squash with tanh so a fixed 100-bin [0, 1]
                # histogram can be used (negative values fall outside it).
                val = np.tanh(val)
                if key == "v_norm":
                    # also log the squashed per-sample values themselves
                    figs[f"{prefix}/{key}"] = val
                hist = np.histogram(val, bins=100, range=(0.0, 1.0))
                figs[f"{prefix}/{key}_hist"] = wandb.Histogram(np_histogram=hist)
            elif key == "y_ucty":
                # bounded values: histogram directly
                hist = np.histogram(val, bins=100, range=(0.0, 1.0))
                figs[f"{prefix}/{key}_hist"] = wandb.Histogram(np_histogram=hist)
            else:
                # log everything else as-is
                figs[f"{prefix}/{key}"] = val

    # --- OOD detection, scored by v_norm (target: InD = 0, OoD = 1) ---
    tid_v_norm = stats.get("tid", {}).get("v_norm")
    tod_v_norm = stats.get("tod", {}).get("v_norm")
    if (
        tid_v_norm is not None
        and tod_v_norm is not None
        and len(tid_v_norm) > 0
        and len(tod_v_norm) > 0
    ):
        labels = ["InD", "OoD"]
        # Raw (un-squashed) norms are used here: tanh saturates to exactly 1.0
        # for large inputs, which would create ties and distort ROC/AUROC.
        values = np.concatenate([tid_v_norm, tod_v_norm], axis=0)
        values_2d = np.stack([1.0 - values, values], axis=-1)
        target = np.concatenate(
            [np.zeros(tid_v_norm.shape[0]), np.ones(tod_v_norm.shape[0])], axis=0
        )
        figs["ood_detection/roc"] = wandb.plot.roc_curve(target, values_2d, labels)
        figs["ood_detection/pr"] = wandb.plot.pr_curve(target, values_2d, labels)
        figs["ood_detection/auroc"] = sklearn.metrics.roc_auc_score(target, values)

    # figs is merged last, so its entries win on key collisions (as before).
    return {**metrics, **figs}

In [5]:
# Both loaders must have been created by init_dataloaders() in the setup cell.
assert config.train_loader
assert config.test_loader

model, optim, step = get_model_optimizer_and_step(config)

# Restore a saved model/optimizer checkpoint if one is present, then move the
# model (as float32) to the configured device.
load_model_state(model, config)
model = model.float().to(config.device)

wandb.watch(model, log_freq=100)

logging.info("Started Train/Test")

# One model artifact collects the checkpoint file written at every epoch.
artifact = wandb.Artifact(f"{config.run_name}-{config.model_name}", type="model")

for epoch in range(1, config.train_epochs + 1):
    # Optimization pass over the training split (the only call given `optim`).
    trn_stats = step(
        prefix="train", model=model, epoch=epoch, config=config, optim=optim
    )
    # Evaluation on the in-distribution and out-of-distribution test splits.
    tid_stats = step(prefix="test_ind", model=model, epoch=epoch, config=config)
    tod_stats = step(prefix="test_ood", model=model, epoch=epoch, config=config)

    epoch_stats: dict = {"trn": trn_stats, "tid": tid_stats, "tod": tod_stats}
    wandb.log(epoch_stats_to_wandb(epoch_stats, config, epoch), step=epoch)

    # Persist this epoch's model/optimizer state and attach the model file
    # to the artifact.
    save_model_state(model, config, epoch)
    fp, model_name = config.experiment_path, config.model_name
    artifact.add_file(os.path.join(fp, f"{model_name}_model_e{epoch}.pth"))

artifact.save()

INFO:root:Started Train/Test


[train] Epoch 1: 100%|██████████| 9332/9332, Loss(agg)=0.2839, Loss(y)=0.2813, Loss(x)=0.0026
[test_ind] Epoch 1: 100%|██████████| 6221/6221, Loss(agg)=0.6182, Loss(y)=0.6165, Loss(x)=0.0017
[test_ood] Epoch 1: 100%|██████████| 1319/1319, Loss(agg)=0.5874, Loss(y)=0.5863, Loss(x)=0.0011
Epoch 1: {'trn/loss': {'agg': 0.5089272060589222, 'y': 0.49223520335942134, 'x': 0.01669200236335348}, 'trn/acc': {'acc1': 0.6373801168024004, 'acc2': 0.8847246035147879, 'acc3': 0.9678860372910416}, 'tid/loss': {'agg': 0.6628909130472785, 'y': 0.6614432970564612, 'x': 0.0014476162296268634}, 'tid/acc': {'acc1': 0.4873312168461662, 'acc2': 0.7707362160424369, 'acc3': 0.9694934496061727}, 'tod/loss': {'agg': 0.5752481530795411, 'y': 0.5741055700433715, 'x': 0.0011425825406764045}}


INFO:root:saving model state - e1


[train] Epoch 2: 100%|██████████| 9332/9332, Loss(agg)=0.3273, Loss(y)=0.3266, Loss(x)=0.0007
[test_ind] Epoch 2: 100%|██████████| 6221/6221, Loss(agg)=0.6140, Loss(y)=0.6125, Loss(x)=0.0015
[test_ood] Epoch 2: 100%|██████████| 1319/1319, Loss(agg)=0.5713, Loss(y)=0.5705, Loss(x)=0.0008


INFO:root:saving model state - e2


Epoch 2: {'trn/loss': {'agg': 0.4327755451547297, 'y': 0.4313237969082949, 'x': 0.0014517484181854733}, 'trn/acc': {'acc1': 0.6884409826403772, 'acc2': 0.904696206600943, 'acc3': 0.9678324582083154}, 'tid/loss': {'agg': 0.596541619625645, 'y': 0.595324418045972, 'x': 0.0012172016298526076}, 'tid/acc': {'acc1': 0.5629169345764347, 'acc2': 0.8235914643947918, 'acc3': 0.9662936023147404}, 'tod/loss': {'agg': 0.5793790769812011, 'y': 0.5785486880263385, 'x': 0.0008303885256527013}}
[train] Epoch 3: 100%|██████████| 9332/9332, Loss(agg)=0.3191, Loss(y)=0.3185, Loss(x)=0.0006
[test_ind] Epoch 3: 100%|██████████| 6221/6221, Loss(agg)=0.6702, Loss(y)=0.6697, Loss(x)=0.0005
[test_ood] Epoch 3: 100%|██████████| 1319/1319, Loss(agg)=0.5420, Loss(y)=0.5416, Loss(x)=0.0004


INFO:root:saving model state - e3


Epoch 3: {'trn/loss': {'agg': 0.40159052087132174, 'y': 0.4008466579256347, 'x': 0.0007438632570655915}, 'trn/acc': {'acc1': 0.7144167916845263, 'acc2': 0.9129272931847406, 'acc3': 0.9672866213030433}, 'tid/loss': {'agg': 0.6033940304467732, 'y': 0.602901456517396, 'x': 0.0004925736531177291}, 'tid/acc': {'acc1': 0.5534178588651342, 'acc2': 0.7936123613566951, 'acc3': 0.9679462706960296}, 'tod/loss': {'agg': 0.5363841451777934, 'y': 0.5359489729089932, 'x': 0.0004351722383504512}}
[train] Epoch 4: 100%|██████████| 9332/9332, Loss(agg)=0.3937, Loss(y)=0.3933, Loss(x)=0.0004
[test_ind] Epoch 4: 100%|██████████| 6221/6221, Loss(agg)=0.8559, Loss(y)=0.8555, Loss(x)=0.0004
[test_ood] Epoch 4: 100%|██████████| 1319/1319, Loss(agg)=0.5565, Loss(y)=0.5561, Loss(x)=0.0004


INFO:root:saving model state - e4


Epoch 4: {'trn/loss': {'agg': 0.38451546302605766, 'y': 0.3839958166281233, 'x': 0.0005196465447524119}, 'trn/acc': {'acc1': 0.7269074153450493, 'acc2': 0.917029441705958, 'acc3': 0.9669350085726532}, 'tid/loss': {'agg': 0.5920119660639568, 'y': 0.5915652115524936, 'x': 0.0004467549323762805}, 'tid/acc': {'acc1': 0.5708185982960938, 'acc2': 0.812826515029738, 'acc3': 0.9672530541713551}, 'tod/loss': {'agg': 0.5716724056832442, 'y': 0.5712953805923462, 'x': 0.00037702593111385075}}
[train] Epoch 5: 100%|██████████| 9332/9332, Loss(agg)=0.4116, Loss(y)=0.4112, Loss(x)=0.0004
[test_ind] Epoch 5: 100%|██████████| 6221/6221, Loss(agg)=0.5565, Loss(y)=0.5562, Loss(x)=0.0003
[test_ood] Epoch 5: 100%|██████████| 1319/1319, Loss(agg)=0.6428, Loss(y)=0.6425, Loss(x)=0.0003


INFO:root:saving model state - e5


Epoch 5: {'trn/loss': {'agg': 0.37186045339750184, 'y': 0.37142533816211964, 'x': 0.0004351151860753702}, 'trn/acc': {'acc1': 0.7373251982426061, 'acc2': 0.9194103621945993, 'acc3': 0.9664293559794256}, 'tid/loss': {'agg': 0.5950941982385795, 'y': 0.5947192109890104, 'x': 0.0003749871476097675}, 'tid/acc': {'acc1': 0.5639366661308471, 'acc2': 0.814318437550233, 'acc3': 0.9679864571612281}, 'tod/loss': {'agg': 0.648531327742894, 'y': 0.6481818005476512, 'x': 0.00034952691228039295}}
[train] Epoch 6: 100%|██████████| 9332/9332, Loss(agg)=0.3167, Loss(y)=0.3137, Loss(x)=0.0030
[test_ind] Epoch 6: 100%|██████████| 6221/6221, Loss(agg)=0.4900, Loss(y)=0.4875, Loss(x)=0.0026
[test_ood] Epoch 6: 100%|██████████| 1319/1319, Loss(agg)=0.5946, Loss(y)=0.5912, Loss(x)=0.0034


INFO:root:saving model state - e6


Epoch 6: {'trn/loss': {'agg': 0.36175090954764766, 'y': 0.3613804698627533, 'x': 0.0003704395602923835}, 'trn/acc': {'acc1': 0.7448698028289755, 'acc2': 0.921690821903129, 'acc3': 0.9663054543506215}, 'tid/loss': {'agg': 0.605748526235808, 'y': 0.6028454313178937, 'x': 0.0029030947761424213}, 'tid/acc': {'acc1': 0.5664935299791031, 'acc2': 0.8178548464877029, 'acc3': 0.9656204790226651}, 'tod/loss': {'agg': 0.6214265493934492, 'y': 0.6179663022419465, 'x': 0.0034602462473643874}}
[train] Epoch 7: 100%|██████████| 9332/9332, Loss(agg)=0.4596, Loss(y)=0.4594, Loss(x)=0.0002
[test_ind] Epoch 7: 100%|██████████| 6221/6221, Loss(agg)=0.7095, Loss(y)=0.7092, Loss(x)=0.0002
[test_ood] Epoch 7: 100%|██████████| 1319/1319, Loss(agg)=0.6571, Loss(y)=0.6568, Loss(x)=0.0003


INFO:root:saving model state - e7


Epoch 7: {'trn/loss': {'agg': 0.3534244723630805, 'y': 0.35308481871353625, 'x': 0.0003396535919665547}, 'trn/acc': {'acc1': 0.7514064509215602, 'acc2': 0.9232379179168453, 'acc3': 0.9660074207029575}, 'tid/loss': {'agg': 0.5972797637956546, 'y': 0.5970156753710891, 'x': 0.0002640882725154092}, 'tid/acc': {'acc1': 0.5671716765793281, 'acc2': 0.82173786368751, 'acc3': 0.960029537051921}, 'tod/loss': {'agg': 0.6140702450429064, 'y': 0.613789917032677, 'x': 0.0002803269805407139}}
[train] Epoch 8: 100%|██████████| 9332/9332, Loss(agg)=0.2492, Loss(y)=0.2490, Loss(x)=0.0003
[test_ind] Epoch 8: 100%|██████████| 6221/6221, Loss(agg)=0.7164, Loss(y)=0.7162, Loss(x)=0.0002
[test_ood] Epoch 8: 100%|██████████| 1319/1319, Loss(agg)=0.5540, Loss(y)=0.5538, Loss(x)=0.0002


INFO:root:saving model state - e8


Epoch 8: {'trn/loss': {'agg': 0.3461995392285712, 'y': 0.34588992515115, 'x': 0.00030961429069439536}, 'trn/acc': {'acc1': 0.7570590441491641, 'acc2': 0.9252236926703815, 'acc3': 0.9661313223317617}, 'tid/loss': {'agg': 0.5927531681697366, 'y': 0.5925098095191597, 'x': 0.00024335927262046989}, 'tid/acc': {'acc1': 0.5727274553930236, 'acc2': 0.8220844719498472, 'acc3': 0.9618077881369554}, 'tod/loss': {'agg': 0.549491623570469, 'y': 0.5492516792661046, 'x': 0.00023994353539652847}}
[train] Epoch 9: 100%|██████████| 9332/9332, Loss(agg)=0.3386, Loss(y)=0.3383, Loss(x)=0.0003
[test_ind] Epoch 9: 100%|██████████| 6221/6221, Loss(agg)=0.6967, Loss(y)=0.6964, Loss(x)=0.0003
[test_ood] Epoch 9: 100%|██████████| 1319/1319, Loss(agg)=0.5801, Loss(y)=0.5799, Loss(x)=0.0002


INFO:root:saving model state - e9


Epoch 9: {'trn/loss': {'agg': 0.33996096864563047, 'y': 0.3396623508443444, 'x': 0.00029861779608876277}, 'trn/acc': {'acc1': 0.7621992873981998, 'acc2': 0.9265397288898414, 'acc3': 0.9660777432490356}, 'tid/loss': {'agg': 0.6113059418537938, 'y': 0.6110651925735545, 'x': 0.0002407489862348995}, 'tid/acc': {'acc1': 0.5690303005947597, 'acc2': 0.8196783073460858, 'acc3': 0.9611045249959813}, 'tod/loss': {'agg': 0.5768623283363093, 'y': 0.5766517774366446, 'x': 0.00021055049236551058}}
[train] Epoch 10: 100%|██████████| 9332/9332, Loss(agg)=0.3218, Loss(y)=0.3214, Loss(x)=0.0004
[test_ind] Epoch 10: 100%|██████████| 6221/6221, Loss(agg)=0.4633, Loss(y)=0.4631, Loss(x)=0.0003
[test_ood] Epoch 10: 100%|██████████| 1319/1319, Loss(agg)=0.5259, Loss(y)=0.5257, Loss(x)=0.0002


INFO:root:saving model state - e10


Epoch 10: {'trn/loss': {'agg': 0.3333250547360439, 'y': 0.3330369456114301, 'x': 0.00028810918247839176}, 'trn/acc': {'acc1': 0.7673529254179169, 'acc2': 0.9271190527218174, 'acc3': 0.9654615837976854}, 'tid/loss': {'agg': 0.5959059651968008, 'y': 0.5956385669059988, 'x': 0.00026739848647510273}, 'tid/acc': {'acc1': 0.5802021379199486, 'acc2': 0.8226571290789262, 'acc3': 0.9585124979906767}, 'tod/loss': {'agg': 0.5609277870458035, 'y': 0.5607444072104115, 'x': 0.0001833779013883732}}
[train] Epoch 11: 100%|██████████| 9332/9332, Loss(agg)=0.3991, Loss(y)=0.3988, Loss(x)=0.0003
[test_ind] Epoch 11: 100%|██████████| 6221/6221, Loss(agg)=0.6643, Loss(y)=0.6641, Loss(x)=0.0002
[test_ood] Epoch 11: 100%|██████████| 1319/1319, Loss(agg)=0.8564, Loss(y)=0.8561, Loss(x)=0.0003


  fig, ax = plt.subplots()
INFO:root:saving model state - e11


Epoch 11: {'trn/loss': {'agg': 0.4455713402580484, 'y': 0.4452174962033185, 'x': 0.0003538439669880351}, 'trn/acc': {'acc1': 0.7580904414916416, 'acc2': 0.8821795970852979, 'acc3': 0.9296540130732962}, 'tid/loss': {'agg': 0.7488356322035723, 'y': 0.7485613772396182, 'x': 0.0002742544002889929}, 'tid/acc': {'acc1': 0.5614903150618872, 'acc2': 0.8458849059636714, 'acc3': 0.9176780260408295}, 'tod/loss': {'agg': 0.8173045577019611, 'y': 0.8169782488431055, 'x': 0.0003263104867393651}}
[train] Epoch 12: 100%|██████████| 9332/9332, Loss(agg)=0.3882, Loss(y)=0.3879, Loss(x)=0.0003
[test_ind] Epoch 12:  79%|███████▉  | 4920/6221, Loss(agg)=0.5247, Loss(y)=0.5243, Loss(x)=0.0004

wandb: Network error (ReadTimeout), entering retry loop.


[test_ind] Epoch 12: 100%|██████████| 6221/6221, Loss(agg)=0.8020, Loss(y)=0.8018, Loss(x)=0.0002
[test_ood] Epoch 12: 100%|██████████| 1319/1319, Loss(agg)=0.9099, Loss(y)=0.9095, Loss(x)=0.0004


INFO:root:saving model state - e12


Epoch 12: {'trn/loss': {'agg': 0.47655646630873816, 'y': 0.4762570605486746, 'x': 0.00029940649771400635}, 'trn/acc': {'acc1': 0.7600360319331333, 'acc2': 0.8687781290184312, 'acc3': 0.9183086423060437}, 'tid/loss': {'agg': 0.783184092396652, 'y': 0.7829153745758095, 'x': 0.0002687171428088815}, 'tid/acc': {'acc1': 0.573938072657129, 'acc2': 0.8476330171998071, 'acc3': 0.9209331297219097}, 'tod/loss': {'agg': 0.948495608780581, 'y': 0.9481395196517007, 'x': 0.00035608897766841883}}
[train] Epoch 13: 100%|██████████| 9332/9332, Loss(agg)=0.4522, Loss(y)=0.4518, Loss(x)=0.0004
[test_ind] Epoch 13: 100%|██████████| 6221/6221, Loss(agg)=0.9265, Loss(y)=0.9263, Loss(x)=0.0003
[test_ood] Epoch 13: 100%|██████████| 1319/1319, Loss(agg)=0.8033, Loss(y)=0.8030, Loss(x)=0.0003


INFO:root:saving model state - e13


Epoch 13: {'trn/loss': {'agg': 0.4997899272314249, 'y': 0.49948375699406933, 'x': 0.0003061700132932326}, 'trn/acc': {'acc1': 0.765782388555508, 'acc2': 0.8682255947278182, 'acc3': 0.9127665559365624}, 'tid/loss': {'agg': 0.8768758040362856, 'y': 0.8766100360465345, 'x': 0.0002657672204107185}, 'tid/acc': {'acc1': 0.5504942935219418, 'acc2': 0.8190252772866099, 'acc3': 0.8994484407651503}, 'tod/loss': {'agg': 0.8801727280913202, 'y': 0.8798707731751622, 'x': 0.0003019567211691666}}
[train] Epoch 14: 100%|██████████| 9332/9332, Loss(agg)=0.4551, Loss(y)=0.4549, Loss(x)=0.0002
[test_ind] Epoch 14: 100%|██████████| 6221/6221, Loss(agg)=0.7307, Loss(y)=0.7304, Loss(x)=0.0003
[test_ood] Epoch 14: 100%|██████████| 1319/1319, Loss(agg)=1.0118, Loss(y)=1.0116, Loss(x)=0.0002


INFO:root:saving model state - e14


Epoch 14: {'trn/loss': {'agg': 0.5182435482060174, 'y': 0.5179392320678385, 'x': 0.00030431656190120585}, 'trn/acc': {'acc1': 0.7673328332618946, 'acc2': 0.8637517413201886, 'acc3': 0.9074488319759966}, 'tid/loss': {'agg': 0.8583806712582187, 'y': 0.8581434599965394, 'x': 0.0002372119942856056}, 'tid/acc': {'acc1': 0.5894952579971066, 'acc2': 0.8043873573380486, 'acc3': 0.8767179713872367}, 'tod/loss': {'agg': 0.8804031529600101, 'y': 0.8802242503191505, 'x': 0.00017890300063068263}}
[train] Epoch 15: 100%|██████████| 9332/9332, Loss(agg)=0.6120, Loss(y)=0.6118, Loss(x)=0.0002
[test_ind] Epoch 15: 100%|██████████| 6221/6221, Loss(agg)=0.9523, Loss(y)=0.9521, Loss(x)=0.0003
[test_ood] Epoch 15: 100%|██████████| 1319/1319, Loss(agg)=0.7800, Loss(y)=0.7796, Loss(x)=0.0004


INFO:root:saving model state - e15


Epoch 15: {'trn/loss': {'agg': 0.5331423102075252, 'y': 0.5328426230275677, 'x': 0.0002996881057849502}, 'trn/acc': {'acc1': 0.7662980872267466, 'acc2': 0.8585110372910416, 'acc3': 0.902492766823832}, 'tid/loss': {'agg': 0.956522321216241, 'y': 0.9562586486636586, 'x': 0.0002636723747758795}, 'tid/acc': {'acc1': 0.5634845683973637, 'acc2': 0.7903622809837647, 'acc3': 0.851711943417457}, 'tod/loss': {'agg': 0.7787345433614759, 'y': 0.7784317885048919, 'x': 0.00030275406798927173}}
[train] Epoch 16: 100%|██████████| 9332/9332, Loss(agg)=0.8610, Loss(y)=0.8608, Loss(x)=0.0003
[test_ind] Epoch 16: 100%|██████████| 6221/6221, Loss(agg)=1.0330, Loss(y)=1.0327, Loss(x)=0.0003
[test_ood] Epoch 16: 100%|██████████| 1319/1319, Loss(agg)=0.7887, Loss(y)=0.7885, Loss(x)=0.0003


INFO:root:saving model state - e16


Epoch 16: {'trn/loss': {'agg': 0.5465370679407533, 'y': 0.5462361047953763, 'x': 0.0003009636326166011}, 'trn/acc': {'acc1': 0.7633880732961852, 'acc2': 0.8539534665666524, 'acc3': 0.8977644127732534}, 'tid/loss': {'agg': 0.9529543197139669, 'y': 0.9526726629952107, 'x': 0.00028165760958785685}, 'tid/acc': {'acc1': 0.5842710175212988, 'acc2': 0.8010418341102716, 'acc3': 0.8721316910464556}, 'tod/loss': {'agg': 0.8457501293855513, 'y': 0.8454821178911309, 'x': 0.00026801178131042117}}
[train] Epoch 17: 100%|██████████| 9332/9332, Loss(agg)=0.5435, Loss(y)=0.5433, Loss(x)=0.0002
[test_ind] Epoch 17: 100%|██████████| 6221/6221, Loss(agg)=0.9607, Loss(y)=0.9605, Loss(x)=0.0002
[test_ood] Epoch 17: 100%|██████████| 1319/1319, Loss(agg)=1.1140, Loss(y)=1.1137, Loss(x)=0.0003


INFO:root:saving model state - e17


Epoch 17: {'trn/loss': {'agg': 0.5593168215129158, 'y': 0.5590093381418673, 'x': 0.00030748365078467924}, 'trn/acc': {'acc1': 0.7594165237891127, 'acc2': 0.8495767252464638, 'acc3': 0.8932101907415345}, 'tid/loss': {'agg': 0.962295627569083, 'y': 0.9620377839023594, 'x': 0.00025784611498682957}, 'tid/acc': {'acc1': 0.569186023147404, 'acc2': 0.7837867706156566, 'acc3': 0.8457743931843755}, 'tod/loss': {'agg': 1.0525761284640198, 'y': 1.0522926436670808, 'x': 0.0002834819807518144}}
[train] Epoch 18: 100%|██████████| 9332/9332, Loss(agg)=0.5202, Loss(y)=0.5199, Loss(x)=0.0003
[test_ind] Epoch 18: 100%|██████████| 6221/6221, Loss(agg)=0.8238, Loss(y)=0.8235, Loss(x)=0.0004
[test_ood] Epoch 18: 100%|██████████| 1319/1319, Loss(agg)=1.1788, Loss(y)=1.1785, Loss(x)=0.0003
Epoch 18: {'trn/loss': {'agg': 0.570751490110537, 'y': 0.5704509900524363, 'x': 0.00030050076681271173}, 'trn/acc': {'acc1': 0.7565199046292328, 'acc2': 0.846067295327904, 'acc3': 0.8893926810972996}, 'tid/loss': {'agg': 0

INFO:root:saving model state - e18


[train] Epoch 19: 100%|██████████| 9332/9332, Loss(agg)=0.4927, Loss(y)=0.4924, Loss(x)=0.0003
[test_ind] Epoch 19: 100%|██████████| 6221/6221, Loss(agg)=0.8654, Loss(y)=0.8652, Loss(x)=0.0002
[test_ood] Epoch 19: 100%|██████████| 1319/1319, Loss(agg)=0.7967, Loss(y)=0.7965, Loss(x)=0.0002


INFO:root:saving model state - e19


Epoch 19: {'trn/loss': {'agg': 0.5800243166179117, 'y': 0.5797318088244431, 'x': 0.00029250786468120984}, 'trn/acc': {'acc1': 0.7522871570938705, 'acc2': 0.8433012751821689, 'acc3': 0.8861645413630519}, 'tid/loss': {'agg': 1.0075495593121648, 'y': 1.007308538886363, 'x': 0.00024102170844306355}, 'tid/acc': {'acc1': 0.5778411830895355, 'acc2': 0.7955814981514227, 'acc3': 0.8610151101109147}, 'tod/loss': {'agg': 0.92033099685319, 'y': 0.9201409337283808, 'x': 0.00019006591505899538}}
[train] Epoch 20: 100%|██████████| 9332/9332, Loss(agg)=0.4908, Loss(y)=0.4906, Loss(x)=0.0002
[test_ind] Epoch 20: 100%|██████████| 6221/6221, Loss(agg)=1.0918, Loss(y)=1.0916, Loss(x)=0.0002
[test_ood] Epoch 20: 100%|██████████| 1319/1319, Loss(agg)=0.7463, Loss(y)=0.7461, Loss(x)=0.0002


INFO:root:saving model state - e20


Epoch 20: {'trn/loss': {'agg': 0.590707225850402, 'y': 0.5904134072641216, 'x': 0.0002938195782463546}, 'trn/acc': {'acc1': 0.7464235962280326, 'acc2': 0.838422229961423, 'acc3': 0.8804215334333476}, 'tid/loss': {'agg': 0.9987240667570022, 'y': 0.9984819518084542, 'x': 0.00024211639787761796}, 'tid/acc': {'acc1': 0.5651422600868028, 'acc2': 0.7824405240315062, 'acc3': 0.8434686947436103}, 'tod/loss': {'agg': 0.8793088520811397, 'y': 0.8791205286527782, 'x': 0.00018832073274883707}}
[train] Epoch 21: 100%|██████████| 9332/9332, Loss(agg)=0.4227, Loss(y)=0.4220, Loss(x)=0.0007
[test_ind] Epoch 21: 100%|██████████| 6221/6221, Loss(agg)=1.0116, Loss(y)=1.0111, Loss(x)=0.0005
[test_ood] Epoch 21: 100%|██████████| 1319/1319, Loss(agg)=0.7139, Loss(y)=0.7136, Loss(x)=0.0003


INFO:root:saving model state - e21


Epoch 21: {'trn/loss': {'agg': 0.5887625518590659, 'y': 0.5884679204249811, 'x': 0.00029463179351645573}, 'trn/acc': {'acc1': 0.7440091888126875, 'acc2': 0.8369789434204886, 'acc3': 0.8787672792541792}, 'tid/loss': {'agg': 1.0148211401809242, 'y': 1.0143298840018795, 'x': 0.0004912578210695288}, 'tid/acc': {'acc1': 0.5713510689599742, 'acc2': 0.7732127069602958, 'acc3': 0.848180557788137}, 'tod/loss': {'agg': 0.6926551107387817, 'y': 0.6923410278364417, 'x': 0.0003140826623382328}}
[train] Epoch 22: 100%|██████████| 9332/9332, Loss(agg)=0.5285, Loss(y)=0.5282, Loss(x)=0.0003
[test_ind] Epoch 22: 100%|██████████| 6221/6221, Loss(agg)=1.5883, Loss(y)=1.5880, Loss(x)=0.0003
[test_ood] Epoch 22: 100%|██████████| 1319/1319, Loss(agg)=0.6462, Loss(y)=0.6460, Loss(x)=0.0002


INFO:root:saving model state - e22


Epoch 22: {'trn/loss': {'agg': 0.5862991629057415, 'y': 0.5860082403055292, 'x': 0.0002909228160755785}, 'trn/acc': {'acc1': 0.7471636573081869, 'acc2': 0.8403611230175739, 'acc3': 0.882698644449207}, 'tid/loss': {'agg': 1.0053768618191923, 'y': 1.0051061844178055, 'x': 0.00027067987763805616}, 'tid/acc': {'acc1': 0.5843212506027969, 'acc2': 0.8219438193216525, 'acc3': 0.8870860794084552}, 'tod/loss': {'agg': 0.7276727281979669, 'y': 0.7274631591255327, 'x': 0.00020956643358391227}}
[train] Epoch 23: 100%|██████████| 9332/9332, Loss(agg)=0.8524, Loss(y)=0.8520, Loss(x)=0.0004
[test_ind] Epoch 23: 100%|██████████| 6221/6221, Loss(agg)=0.6615, Loss(y)=0.6612, Loss(x)=0.0003
[test_ood] Epoch 23: 100%|██████████| 1319/1319, Loss(agg)=1.0232, Loss(y)=1.0229, Loss(x)=0.0004


INFO:root:saving model state - e23


Epoch 23: {'trn/loss': {'agg': 0.5859898361419333, 'y': 0.5856966271204495, 'x': 0.00029320933565368557}, 'trn/acc': {'acc1': 0.7436776682383197, 'acc2': 0.8375281290184312, 'acc3': 0.8797919792113159}, 'tid/loss': {'agg': 1.0040123918071435, 'y': 1.0037333789678118, 'x': 0.00027901484318246884}, 'tid/acc': {'acc1': 0.5986828886031185, 'acc2': 0.8230740636553608, 'acc3': 0.8895826635589134}, 'tod/loss': {'agg': 0.935049146481587, 'y': 0.934722593917008, 'x': 0.00032655436236390505}}
[train] Epoch 24: 100%|██████████| 9332/9332, Loss(agg)=0.7839, Loss(y)=0.7836, Loss(x)=0.0003
[test_ind] Epoch 24: 100%|██████████| 6221/6221, Loss(agg)=1.0755, Loss(y)=1.0752, Loss(x)=0.0003
[test_ood] Epoch 24: 100%|██████████| 1319/1319, Loss(agg)=1.0530, Loss(y)=1.0528, Loss(x)=0.0002


INFO:root:saving model state - e24


Epoch 24: {'trn/loss': {'agg': 0.583788717937633, 'y': 0.5834941220191534, 'x': 0.0002945964509125562}, 'trn/acc': {'acc1': 0.7441297417488213, 'acc2': 0.8385896645949421, 'acc3': 0.8805219942134591}, 'tid/loss': {'agg': 1.032541246414951, 'y': 1.032237295283514, 'x': 0.0003039535343814824}, 'tid/acc': {'acc1': 0.5755606011895193, 'acc2': 0.8023378476129239, 'acc3': 0.8675705272464234}, 'tod/loss': {'agg': 0.9212335451612697, 'y': 0.9210205956604374, 'x': 0.00021294756334238292}}
[train] Epoch 25: 100%|██████████| 9332/9332, Loss(agg)=0.4115, Loss(y)=0.4112, Loss(x)=0.0003
[test_ind] Epoch 25: 100%|██████████| 6221/6221, Loss(agg)=0.7458, Loss(y)=0.7455, Loss(x)=0.0003
[test_ood] Epoch 25: 100%|██████████| 1319/1319, Loss(agg)=1.0837, Loss(y)=1.0833, Loss(x)=0.0004


INFO:root:saving model state - e25


Epoch 25: {'trn/loss': {'agg': 0.5830497770026115, 'y': 0.5827637969157341, 'x': 0.00028598043135019983}, 'trn/acc': {'acc1': 0.743091647021003, 'acc2': 0.8371832136733819, 'acc3': 0.8791121945992285}, 'tid/loss': {'agg': 1.0194303490613132, 'y': 1.0190799942959008, 'x': 0.0003503560464440494}, 'tid/acc': {'acc1': 0.563399172158817, 'acc2': 0.7646881530300594, 'acc3': 0.8281526281948239}, 'tod/loss': {'agg': 0.9537140174830294, 'y': 0.9533630570197665, 'x': 0.00035096141742277604}}
[train] Epoch 26: 100%|██████████| 9332/9332, Loss(agg)=0.5572, Loss(y)=0.5569, Loss(x)=0.0003
[test_ind] Epoch 26: 100%|██████████| 6221/6221, Loss(agg)=0.9938, Loss(y)=0.9935, Loss(x)=0.0003
[test_ood] Epoch 26: 100%|██████████| 1319/1319, Loss(agg)=1.3224, Loss(y)=1.3222, Loss(x)=0.0002


INFO:root:saving model state - e26


Epoch 26: {'trn/loss': {'agg': 0.5816104370179321, 'y': 0.5813131433345962, 'x': 0.00029729363112558226}, 'trn/acc': {'acc1': 0.7418559794256322, 'acc2': 0.8379935972996142, 'acc3': 0.8797417488212602}, 'tid/loss': {'agg': 0.9998837496532618, 'y': 0.9996117070889745, 'x': 0.0002720452953759933}, 'tid/acc': {'acc1': 0.57364169747629, 'acc2': 0.7837114209934094, 'acc3': 0.8572375823822537}, 'tod/loss': {'agg': 1.3030860487758615, 'y': 1.3028494013449385, 'x': 0.0002366478072319263}}
[train] Epoch 27: 100%|██████████| 9332/9332, Loss(agg)=0.5428, Loss(y)=0.5424, Loss(x)=0.0003
[test_ind] Epoch 27: 100%|██████████| 6221/6221, Loss(agg)=1.0851, Loss(y)=1.0848, Loss(x)=0.0003
[test_ood] Epoch 27: 100%|██████████| 1319/1319, Loss(agg)=0.9734, Loss(y)=0.9732, Loss(x)=0.0002


INFO:root:saving model state - e27


Epoch 27: {'trn/loss': {'agg': 0.5811529571978488, 'y': 0.5808666272737977, 'x': 0.0002863300047259765}, 'trn/acc': {'acc1': 0.7399070402914703, 'acc2': 0.8370794042006001, 'acc3': 0.8785630090012859}, 'tid/loss': {'agg': 1.0185797554880185, 'y': 1.0183098464126354, 'x': 0.0002699087020717867}, 'tid/acc': {'acc1': 0.582146158173927, 'acc2': 0.7970231875904196, 'acc3': 0.8640290950008037}, 'tod/loss': {'agg': 0.9650474439704843, 'y': 0.9648204373987992, 'x': 0.00022700395939016956}}
[train] Epoch 28: 100%|██████████| 9332/9332, Loss(agg)=0.4388, Loss(y)=0.4386, Loss(x)=0.0002
[test_ind] Epoch 28: 100%|██████████| 6221/6221, Loss(agg)=0.9273, Loss(y)=0.9270, Loss(x)=0.0002
[test_ood] Epoch 28: 100%|██████████| 1319/1319, Loss(agg)=1.0451, Loss(y)=1.0449, Loss(x)=0.0002


INFO:root:saving model state - e28


Epoch 28: {'trn/loss': {'agg': 0.5806179157639081, 'y': 0.5803291514528651, 'x': 0.0002887641581824166}, 'trn/acc': {'acc1': 0.7416885447921132, 'acc2': 0.8381375910844406, 'acc3': 0.8803913951993142}, 'tid/loss': {'agg': 1.006877509090529, 'y': 1.006612782941296, 'x': 0.0002647282920931134}, 'tid/acc': {'acc1': 0.5821913679472753, 'acc2': 0.7973547259283074, 'acc3': 0.861195949204308}, 'tod/loss': {'agg': 0.9344836768098994, 'y': 0.9342245867615672, 'x': 0.00025908952817028507}}
[train] Epoch 29: 100%|██████████| 9332/9332, Loss(agg)=0.4403, Loss(y)=0.4399, Loss(x)=0.0003
[test_ind] Epoch 29: 100%|██████████| 6221/6221, Loss(agg)=0.7780, Loss(y)=0.7777, Loss(x)=0.0003
[test_ood] Epoch 29: 100%|██████████| 1319/1319, Loss(agg)=0.6109, Loss(y)=0.6107, Loss(x)=0.0002


INFO:root:saving model state - e29


Epoch 29: {'trn/loss': {'agg': 0.5791990410646706, 'y': 0.5789102137414519, 'x': 0.00028882816211660826}, 'trn/acc': {'acc1': 0.7420769931418774, 'acc2': 0.8377826296613802, 'acc3': 0.8795106890270039}, 'tid/loss': {'agg': 1.0095168301634152, 'y': 1.0091857952744683, 'x': 0.0003310365523165436}, 'tid/acc': {'acc1': 0.5609025880083588, 'acc2': 0.7669586883137759, 'acc3': 0.8188293682687671}, 'tod/loss': {'agg': 0.603608781185179, 'y': 0.6033820598933441, 'x': 0.0002267189600973244}}
[train] Epoch 30: 100%|██████████| 9332/9332, Loss(agg)=0.5147, Loss(y)=0.5146, Loss(x)=0.0002
[test_ind] Epoch 30: 100%|██████████| 6221/6221, Loss(agg)=0.6661, Loss(y)=0.6659, Loss(x)=0.0002
[test_ood] Epoch 30: 100%|██████████| 1319/1319, Loss(agg)=0.8599, Loss(y)=0.8596, Loss(x)=0.0003


INFO:root:saving model state - e30


Epoch 30: {'trn/loss': {'agg': 0.5775262204314627, 'y': 0.5772392627632623, 'x': 0.00028695842112479956}, 'trn/acc': {'acc1': 0.7408681150878696, 'acc2': 0.8369421078011144, 'acc3': 0.878589798542649}, 'tid/loss': {'agg': 1.0031825660787352, 'y': 1.0029211687257944, 'x': 0.0002613982736364373}, 'tid/acc': {'acc1': 0.587445748271982, 'acc2': 0.7984548304131168, 'acc3': 0.8712074023468895}, 'tod/loss': {'agg': 0.9347474830813983, 'y': 0.9344508445470779, 'x': 0.00029663705010218806}}
[train] Epoch 31: 100%|██████████| 9332/9332, Loss(agg)=0.4604, Loss(y)=0.4599, Loss(x)=0.0004
[test_ind] Epoch 31: 100%|██████████| 6221/6221, Loss(agg)=1.1814, Loss(y)=1.1811, Loss(x)=0.0003
[test_ood] Epoch 31: 100%|██████████| 1319/1319, Loss(agg)=1.0166, Loss(y)=1.0163, Loss(x)=0.0003


INFO:root:saving model state - e31


Epoch 31: {'trn/loss': {'agg': 0.5777175985308635, 'y': 0.5774224359227565, 'x': 0.0002951634084653703}, 'trn/acc': {'acc1': 0.7405165023574797, 'acc2': 0.8371999571367338, 'acc3': 0.8780506590227175}, 'tid/loss': {'agg': 1.055361295705477, 'y': 1.0549630384806594, 'x': 0.0003982610876740161}, 'tid/acc': {'acc1': 0.5596316910464556, 'acc2': 0.771796134062048, 'acc3': 0.8299610191287574}, 'tod/loss': {'agg': 1.1259680270154517, 'y': 1.1256799392848416, 'x': 0.0002880886633669671}}
[train] Epoch 32: 100%|██████████| 9332/9332, Loss(agg)=0.4887, Loss(y)=0.4885, Loss(x)=0.0002
[test_ind] Epoch 32: 100%|██████████| 6221/6221, Loss(agg)=0.9538, Loss(y)=0.9534, Loss(x)=0.0004
[test_ood] Epoch 32: 100%|██████████| 1319/1319, Loss(agg)=1.1012, Loss(y)=1.1009, Loss(x)=0.0003


INFO:root:saving model state - e32


Epoch 32: {'trn/loss': {'agg': 0.5762610076050415, 'y': 0.5759799741975008, 'x': 0.0002810334859014598}, 'trn/acc': {'acc1': 0.7418526307329618, 'acc2': 0.836955502571796, 'acc3': 0.8783453439777111}, 'tid/loss': {'agg': 1.0407355583867657, 'y': 1.0404302260774525, 'x': 0.0003053335216485238}, 'tid/acc': {'acc1': 0.5713359990355248, 'acc2': 0.7856102314740395, 'acc3': 0.8441418180356856}, 'tod/loss': {'agg': 1.051305094526246, 'y': 1.0510722348145953, 'x': 0.00023285790456535182}}
[train] Epoch 33: 100%|██████████| 9332/9332, Loss(agg)=0.4619, Loss(y)=0.4616, Loss(x)=0.0002
[test_ind] Epoch 33: 100%|██████████| 6221/6221, Loss(agg)=1.2283, Loss(y)=1.2280, Loss(x)=0.0003
[test_ood] Epoch 33: 100%|██████████| 1319/1319, Loss(agg)=1.0961, Loss(y)=1.0958, Loss(x)=0.0002


INFO:root:saving model state - e33


Epoch 33: {'trn/loss': {'agg': 0.576837272837956, 'y': 0.5765497322871712, 'x': 0.000287540987145781}, 'trn/acc': {'acc1': 0.7412297738962709, 'acc2': 0.8362857640377197, 'acc3': 0.8774110587226747}, 'tid/loss': {'agg': 1.0287272120102715, 'y': 1.0283742733485368, 'x': 0.0003529390602878706}, 'tid/acc': {'acc1': 0.5813374055618068, 'acc2': 0.8065574264587687, 'acc3': 0.8794657209451857}, 'tod/loss': {'agg': 1.051310122826499, 'y': 1.0510660336720754, 'x': 0.0002440866918290363}}
[train] Epoch 34: 100%|██████████| 9332/9332, Loss(agg)=0.6240, Loss(y)=0.6238, Loss(x)=0.0002
[test_ind] Epoch 34: 100%|██████████| 6221/6221, Loss(agg)=1.0464, Loss(y)=1.0461, Loss(x)=0.0003
[test_ood] Epoch 34: 100%|██████████| 1319/1319, Loss(agg)=0.9295, Loss(y)=0.9293, Loss(x)=0.0002


INFO:root:saving model state - e34


Epoch 34: {'trn/loss': {'agg': 0.5748074823839568, 'y': 0.5745320194143366, 'x': 0.000275463207219395}, 'trn/acc': {'acc1': 0.7418090977282469, 'acc2': 0.8374511090870125, 'acc3': 0.8784926864552078}, 'tid/loss': {'agg': 1.030103637937027, 'y': 1.0298158583524546, 'x': 0.00028778290159517074}, 'tid/acc': {'acc1': 0.5644189037132293, 'acc2': 0.7696009484005787, 'acc3': 0.8367676016717569}, 'tod/loss': {'agg': 0.7693140903054267, 'y': 0.769082711785015, 'x': 0.00023137819031825383}}
[train] Epoch 35: 100%|██████████| 9332/9332, Loss(agg)=0.5528, Loss(y)=0.5526, Loss(x)=0.0002
[test_ind] Epoch 35: 100%|██████████| 6221/6221, Loss(agg)=1.3520, Loss(y)=1.3518, Loss(x)=0.0002
[test_ood] Epoch 35: 100%|██████████| 1319/1319, Loss(agg)=1.2540, Loss(y)=1.2537, Loss(x)=0.0002


INFO:root:saving model state - e35


Epoch 35: {'trn/loss': {'agg': 0.5735329993410883, 'y': 0.5732526559489661, 'x': 0.0002803434452603577}, 'trn/acc': {'acc1': 0.7438015698671239, 'acc2': 0.8387437044577797, 'acc3': 0.879959413844835}, 'tid/loss': {'agg': 1.1083971480237884, 'y': 1.1081628107771653, 'x': 0.00023434028195933894}, 'tid/acc': {'acc1': 0.588259524192252, 'acc2': 0.8052815061887156, 'acc3': 0.8663498633660184}, 'tod/loss': {'agg': 1.2414537059076813, 'y': 1.2412076784499764, 'x': 0.0002460296913964492}}
[train] Epoch 36: 100%|██████████| 9332/9332, Loss(agg)=0.5238, Loss(y)=0.5235, Loss(x)=0.0003
[test_ind] Epoch 36: 100%|██████████| 6221/6221, Loss(agg)=1.3216, Loss(y)=1.3212, Loss(x)=0.0004
[test_ood] Epoch 36: 100%|██████████| 1319/1319, Loss(agg)=0.8963, Loss(y)=0.8961, Loss(x)=0.0003


INFO:root:saving model state - e36


Epoch 36: {'trn/loss': {'agg': 0.5728636245441652, 'y': 0.5725832529816308, 'x': 0.00028037217570289096}, 'trn/acc': {'acc1': 0.7444244267038148, 'acc2': 0.8395005090012859, 'acc3': 0.8800766180882983}, 'tid/loss': {'agg': 1.0088699507504448, 'y': 1.0083792425154336, 'x': 0.0004907098626782673}, 'tid/acc': {'acc1': 0.5863908535605208, 'acc2': 0.8050956437871725, 'acc3': 0.8733824947757596}, 'tod/loss': {'agg': 0.8096639908923625, 'y': 0.8094023106177708, 'x': 0.00026167957633149464}}
[train] Epoch 37: 100%|██████████| 9332/9332, Loss(agg)=0.6028, Loss(y)=0.6026, Loss(x)=0.0003
[test_ind] Epoch 37: 100%|██████████| 6221/6221, Loss(agg)=1.1700, Loss(y)=1.1698, Loss(x)=0.0003
[test_ood] Epoch 37: 100%|██████████| 1319/1319, Loss(agg)=1.2920, Loss(y)=1.2913, Loss(x)=0.0007


INFO:root:saving model state - e37


Epoch 37: {'trn/loss': {'agg': 0.5722553094156786, 'y': 0.5719715231000432, 'x': 0.00028378651509893476}, 'trn/acc': {'acc1': 0.7425223692670382, 'acc2': 0.837534826403772, 'acc3': 0.8783252518216889}, 'tid/loss': {'agg': 0.9839010436882487, 'y': 0.9835757224329069, 'x': 0.0003253245146457563}, 'tid/acc': {'acc1': 0.596819241279537, 'acc2': 0.8064268204468735, 'acc3': 0.8761503375663077}, 'tod/loss': {'agg': 1.122543664752217, 'y': 1.1219183352067672, 'x': 0.0006253268156473195}}
[train] Epoch 38: 100%|██████████| 9332/9332, Loss(agg)=0.5402, Loss(y)=0.5399, Loss(x)=0.0003
[test_ind] Epoch 38: 100%|██████████| 6221/6221, Loss(agg)=0.8263, Loss(y)=0.8261, Loss(x)=0.0002
[test_ood] Epoch 38: 100%|██████████| 1319/1319, Loss(agg)=1.0429, Loss(y)=1.0427, Loss(x)=0.0002


INFO:root:saving model state - e38


Epoch 38: {'trn/loss': {'agg': 0.5712911577227031, 'y': 0.5710077833973488, 'x': 0.00028337398003169047}, 'trn/acc': {'acc1': 0.7418961637376769, 'acc2': 0.8382347031718816, 'acc3': 0.8789682008144021}, 'tid/loss': {'agg': 1.0465221109736715, 'y': 1.046276775668852, 'x': 0.00024533687798889304}, 'tid/acc': {'acc1': 0.5695326314097412, 'acc2': 0.7799891496543964, 'acc3': 0.8450309435782029}, 'tod/loss': {'agg': 1.0332671802906126, 'y': 1.03304299502051, 'x': 0.0002241876919870706}}
[train] Epoch 39: 100%|██████████| 9332/9332, Loss(agg)=0.6486, Loss(y)=0.6484, Loss(x)=0.0002
[test_ind] Epoch 39: 100%|██████████| 6221/6221, Loss(agg)=1.2940, Loss(y)=1.2935, Loss(x)=0.0005
[test_ood] Epoch 39: 100%|██████████| 1319/1319, Loss(agg)=0.7760, Loss(y)=0.7757, Loss(x)=0.0003


INFO:root:saving model state - e39


Epoch 39: {'trn/loss': {'agg': 0.5720522064432993, 'y': 0.571766279583523, 'x': 0.00028592750223904045}, 'trn/acc': {'acc1': 0.7400208958422632, 'acc2': 0.8373472996142306, 'acc3': 0.8785529629232748}, 'tid/loss': {'agg': 0.9885035778736114, 'y': 0.9881433998440416, 'x': 0.00036018044063803846}, 'tid/acc': {'acc1': 0.5775297379842469, 'acc2': 0.7886644028291272, 'acc3': 0.851576314097412}, 'tod/loss': {'agg': 0.6063512695658108, 'y': 0.6060098102365686, 'x': 0.0003414581628382867}}
[train] Epoch 40: 100%|██████████| 9332/9332, Loss(agg)=0.7204, Loss(y)=0.7201, Loss(x)=0.0003
[test_ind] Epoch 40: 100%|██████████| 6221/6221, Loss(agg)=1.5513, Loss(y)=1.5511, Loss(x)=0.0002
[test_ood] Epoch 40: 100%|██████████| 1319/1319, Loss(agg)=0.6899, Loss(y)=0.6897, Loss(x)=0.0003


INFO:root:saving model state - e40


Epoch 40: {'trn/loss': {'agg': 0.5698894178578642, 'y': 0.5696088489635253, 'x': 0.00028056913809759}, 'trn/acc': {'acc1': 0.7430447653236176, 'acc2': 0.8373472996142306, 'acc3': 0.8780607051007286}, 'tid/loss': {'agg': 1.0389621967844709, 'y': 1.0386773131702745, 'x': 0.00028488493440918274}, 'tid/acc': {'acc1': 0.56929151261855, 'acc2': 0.785705674328886, 'acc3': 0.8472512457804211}, 'tod/loss': {'agg': 0.7799374772302123, 'y': 0.7796534996894926, 'x': 0.000283977529323071}}
[train] Epoch 41: 100%|██████████| 9332/9332, Loss(agg)=0.5999, Loss(y)=0.5996, Loss(x)=0.0002
[test_ind] Epoch 41: 100%|██████████| 6221/6221, Loss(agg)=1.1477, Loss(y)=1.1475, Loss(x)=0.0002
[test_ood] Epoch 41: 100%|██████████| 1319/1319, Loss(agg)=0.8276, Loss(y)=0.8274, Loss(x)=0.0002


INFO:root:saving model state - e41


Epoch 41: {'trn/loss': {'agg': 0.5714722828331982, 'y': 0.5711932249201171, 'x': 0.00027905883888813165}, 'trn/acc': {'acc1': 0.7388455047149592, 'acc2': 0.8360279147021002, 'acc3': 0.8770862355336476}, 'tid/loss': {'agg': 1.0208790345104575, 'y': 1.0206426797314025, 'x': 0.00023635769289538888}, 'tid/acc': {'acc1': 0.5655893345121363, 'acc2': 0.7726752129882656, 'acc3': 0.8296696672560682}, 'tod/loss': {'agg': 0.7432243117244972, 'y': 0.7430185997368622, 'x': 0.00020571064816371604}}
[train] Epoch 42: 100%|██████████| 9332/9332, Loss(agg)=0.6721, Loss(y)=0.6718, Loss(x)=0.0004
[test_ind] Epoch 42: 100%|██████████| 6221/6221, Loss(agg)=1.0064, Loss(y)=1.0062, Loss(x)=0.0002
[test_ood] Epoch 42: 100%|██████████| 1319/1319, Loss(agg)=1.1406, Loss(y)=1.1404, Loss(x)=0.0002


INFO:root:saving model state - e42


Epoch 42: {'trn/loss': {'agg': 0.5698450397854959, 'y': 0.5695648916381856, 'x': 0.00028014858560493426}, 'trn/acc': {'acc1': 0.7412465173596228, 'acc2': 0.8383083744106301, 'acc3': 0.8791892145306472}, 'tid/loss': {'agg': 0.9773206169573445, 'y': 0.9770317571212234, 'x': 0.00028886302139301956}, 'tid/acc': {'acc1': 0.5769922440122167, 'acc2': 0.8034881851792316, 'acc3': 0.87364370679955}, 'tod/loss': {'agg': 1.1728680192565628, 'y': 1.1726507844574259, 'x': 0.0002172354271200643}}
[train] Epoch 43: 100%|██████████| 9332/9332, Loss(agg)=0.4884, Loss(y)=0.4881, Loss(x)=0.0002
[test_ind] Epoch 43: 100%|██████████| 6221/6221, Loss(agg)=1.0062, Loss(y)=1.0060, Loss(x)=0.0003
[test_ood] Epoch 43: 100%|██████████| 1319/1319, Loss(agg)=0.6971, Loss(y)=0.6969, Loss(x)=0.0003


INFO:root:saving model state - e43


Epoch 43: {'trn/loss': {'agg': 0.5695391774538424, 'y': 0.5692653414157395, 'x': 0.0002738362608476575}, 'trn/acc': {'acc1': 0.7427132447492499, 'acc2': 0.8378563009001286, 'acc3': 0.878519475996571}, 'tid/loss': {'agg': 0.9993853972351373, 'y': 0.9990769309440699, 'x': 0.0003084690012533843}, 'tid/acc': {'acc1': 0.5911479263783957, 'acc2': 0.7984045973316187, 'acc3': 0.8640793280823019}, 'tod/loss': {'agg': 0.7446612379964868, 'y': 0.7444072924752051, 'x': 0.0002539450943886943}}
[train] Epoch 44:  84%|████████▍ | 7856/9332, Loss(agg)=0.4963, Loss(y)=0.4960, Loss(x)=0.0003