In [None]:
%env DATA_DIR=$HOME/datasets
%env EXPERIMENT_BASE=$HOME/experiments/ood_flows
%env LOG_LEVEL=INFO
%env BATCH_SIZE=64
%env OPTIM_LR=0.001
%env OPTIM_M=0.8
%env TRAIN_EPOCHS=100
%env EXC_RESUME=1
%env DATASET_NAME=AMRB2_species
%env MANIFOLD_D=512
%env MODEL_NAME=resnet

In [None]:
import lightning.pytorch as pl
import numpy as np
import torch
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers.wandb import WandbLogger

from config import Config, load_config
from datasets import get_data
from models import get_model

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# initialize the RNG deterministically
np.random.seed(42)
torch.manual_seed(42)
torch.set_float32_matmul_precision('medium')

config = load_config()

# initialize data attributes and loaders
get_data(config)
config.print_labels()

assert config.datamodule

usage: ipykernel_launcher.py [-h] [--ood_k OOD_K] [--data_dir DATA_DIR]
                             [--dataset_name DATASET_NAME]
                             [--model_name MODEL_NAME]
                             [--experiment_base EXPERIMENT_BASE]
                             [--manifold_d MANIFOLD_D]
                             [--batch_size BATCH_SIZE] [--optim_lr OPTIM_LR]
                             [--optim_m OPTIM_M] [--train_epochs TRAIN_EPOCHS]
                             [--checkpoint_metric CHECKPOINT_METRIC]
                             [--image_size IMAGE_SIZE] [--scale SCALE SCALE]
                             [--train_supervised TRAIN_SUPERVISED]
                             [--temperature TEMPERATURE]
                             [--rgb_gaussian_blur_p RGB_GAUSSIAN_BLUR_P]
                             [--rgb_jitter_d RGB_JITTER_D]
                             [--rgb_jitter_p RGB_JITTER_P]
                             [--rgb_contrast RGB_CONTRAST]
                  

SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [4]:
model_randinit = get_model(config)

In [5]:
from pathlib import Path
artifact_dir = WandbLogger.download_artifact(artifact="yasith/uq_project/model-33t7jols:best")
model_pretrain = model_randinit.load_from_checkpoint(Path(artifact_dir) / "model.ckpt", config=config)
model_pretrain.eval()
del model_randinit

[34m[1mwandb[0m: Downloading large artifact model-33t7jols:best, 58.83MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:2.6


In [6]:
assert config.datamodule
config.datamodule.setup("test")

Performing ind/ood split


100%|██████████| 259184/259184 [00:08<00:00, 30466.55it/s]


Train - OK


100%|██████████| 64795/64795 [00:02<00:00, 30501.12it/s]


Val - OK


100%|██████████| 215989/215989 [00:06<00:00, 30970.87it/s]

Test - OK
Performed ind/ood split
259184 64795 215989 0





In [7]:
test_loader = config.datamodule.test_dataloader()

In [8]:
from torchmetrics import Accuracy
accuracy = Accuracy(task="multiclass", num_classes=len(config.get_ind_labels())).cuda()

In [9]:
from models.common import edl_probs
from tqdm.auto import tqdm

# accuracy.reset()
classifier_loss = "edl"
for batch_idx, batch in enumerate(tqdm(test_loader)):
    x, y = batch
    print(x.size(), y.size())
#     x = x.cuda().float()
#     y = y.cuda().long()
    
#     z, logits, x_pred = model_pretrain(x)
    
#     # classifier loss
#     if classifier_loss == "edl":
#         pY, uY = edl_probs(logits)
#     elif classifier_loss == "crossent":
#         pY = logits.softmax(-1)
#         uY = 1.0 - pY.max(-1)
#     elif classifier_loss == "margin":
#         pY = logits.sigmoid()
#         uY = 1.0 - pY.max(-1)
#     else:
#         raise ValueError(classifier_loss)
#     accuracy.update(pY, y)
    
# print(accuracy.compute())

  2%|▏         | 84/3375 [00:01<00:34, 96.43it/s] 

torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size

  7%|▋         | 252/3375 [00:01<00:10, 306.17it/s]

torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size

 13%|█▎        | 424/3375 [00:01<00:05, 504.91it/s]

torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size

 18%|█▊        | 602/3375 [00:01<00:04, 662.00it/s]

torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size

 23%|██▎       | 774/3375 [00:02<00:03, 747.74it/s]

torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size

 28%|██▊       | 942/3375 [00:02<00:03, 785.96it/s]

torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size

 33%|███▎      | 1110/3375 [00:02<00:02, 803.60it/s]

torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size

 38%|███▊      | 1279/3375 [00:02<00:02, 811.57it/s]

torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size

 43%|████▎     | 1448/3375 [00:02<00:02, 826.94it/s]

torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size

 48%|████▊     | 1616/3375 [00:03<00:02, 828.18it/s]

torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size

 53%|█████▎    | 1785/3375 [00:03<00:01, 817.48it/s]

torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size

 58%|█████▊    | 1959/3375 [00:03<00:01, 842.58it/s]

torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size

 63%|██████▎   | 2133/3375 [00:03<00:01, 853.57it/s]

torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size

 68%|██████▊   | 2308/3375 [00:03<00:01, 857.26it/s]

torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size

 73%|███████▎  | 2478/3375 [00:04<00:01, 832.76it/s]

torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size

 78%|███████▊  | 2646/3375 [00:04<00:00, 830.36it/s]

torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size

 83%|████████▎ | 2816/3375 [00:04<00:00, 837.69it/s]

torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size

 88%|████████▊ | 2986/3375 [00:04<00:00, 829.55it/s]

torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size

 93%|█████████▎| 3154/3375 [00:04<00:00, 822.03it/s]

torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size

 98%|█████████▊| 3319/3375 [00:05<00:00, 806.74it/s]

torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size([64, 1, 40, 40]) torch.Size([64])
torch.Size

100%|██████████| 3375/3375 [00:05<00:00, 650.54it/s]
