In [1]:
# Mount Google Drive at /content/drive/ so that file changes made from this
# notebook persist in your Drive.
from google.colab import drive

drive.mount('/content/drive/')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [2]:
# Run from the project root on Drive so relative paths used below
# (requirements.txt, the data/ directory, model/sample output dirs) resolve.
%cd /content/drive/MyDrive/cpen455-project/

/content/drive/MyDrive/cpen455-project


In [3]:
!pip install -r requirements.txt

Collecting bidict (from -r requirements.txt (line 1))
  Downloading bidict-0.23.1-py3-none-any.whl.metadata (8.7 kB)
Collecting pytorch-fid (from -r requirements.txt (line 7))
  Downloading pytorch_fid-0.3.0-py3-none-any.whl.metadata (5.3 kB)
Collecting dotenv (from -r requirements.txt (line 11))
  Downloading dotenv-0.9.9-py2.py3-none-any.whl.metadata (279 bytes)
Collecting python-dotenv (from dotenv->-r requirements.txt (line 11))
  Downloading python_dotenv-1.1.0-py3-none-any.whl.metadata (24 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.0.1->pytorch-fid->-r requirements.txt (line 7))
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.0.1->pytorch-fid->-r requirements.txt (line 7))
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.0.1->pytorch-fid->-r r

In [4]:
import time
import os
import torch
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, transforms
import wandb
from tqdm import tqdm
from pprint import pprint
import argparse
from pytorch_fid.fid_score import calculate_fid_given_paths
import logging
import pandas as pd
import numpy as np

import dotenv
# Load variables from a local .env file into the process environment (python-dotenv).
dotenv.load_dotenv()

True

In [5]:
import importlib
import utils, model, dataset, generation_evaluation, classification_evaluation

# Re-execute each project module so edits to its .py file take effect without
# restarting the kernel; the star-imports below then re-bind the fresh names.
for _project_module in (utils, model, dataset, generation_evaluation, classification_evaluation):
    importlib.reload(_project_module)
del _project_module

from utils import *
from model import *
from dataset import *
import generation_evaluation
import classification_evaluation

In [None]:
import logging

# Configure the root logger: WARNING and above only, each line prefixed with a timestamp.
logging.basicConfig(level=logging.WARNING, format='%(asctime)s %(message)s')

In [None]:
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class Config:
    """Hyperparameters and paths for a PixelCNN training run.

    Instances are read as ``args.<field>`` throughout the notebook and are
    passed to ``wandb.config.update`` when wandb logging is enabled.
    """

    en_wandb: bool = False            # log metrics and samples to Weights & Biases
    tag: str = "default"              # suffix appended to the wandb run name
    sampling_interval: int = 5        # sample images + compute FID every N epochs
    data_dir: str = "data"            # dataset root directory
    save_dir: str = "models"          # checkpoint directory (created by the setup cell)
    sample_dir: str = "samples"       # passed to my_sample; also the generated-image dir for FID
    dataset: str = "cpen455"          # data-loading branch, matched by substring (mnist / cifar / cpen455)
    save_interval: int = 10           # checkpoint every N epochs
    load_params: Optional[str] = None # optional state_dict path to initialise the model from
    obs: Tuple[int, int, int] = (3, 32, 32)  # image shape (C, H, W); obs[0] is the input channel count
    nr_resnet: int = 1                # PixelCNN constructor arguments
    nr_filters: int = 40
    nr_logistic_mix: int = 5
    lr: float = 0.0002                # Adam learning rate
    lr_decay: float = 0.999995        # StepLR gamma, applied on every scheduler.step()
    batch_size: int = 64
    sample_batch_size: int = 32
    base_epoch: int = 0               # NOTE(review): not read by any visible cell
    max_epochs: int = 5000
    seed: int = 1                     # seed for the torch and numpy RNGs

In [None]:
args = Config(
    batch_size=16,
    sample_batch_size=16,
    sampling_interval=5,
    save_interval=10,
    max_epochs=50,
    en_wandb=True,
    nr_filters=40,
    nr_logistic_mix=5,
    nr_resnet=5,
    tag="Middle_f40_l5_r5",
    sample_dir="samples_f40_l5_r5",
    save_dir="models_f40_l5_r5",
)

pprint(args.__dict__)
check_dir_and_create(args.save_dir)

# reproducibility
torch.manual_seed(args.seed)
np.random.seed(args.seed)

# `model_name` is reused for checkpoint file names by the training cells.
model_name = 'pcnn_' + args.dataset + "_" + ('load_model' if args.load_params is not None else 'from_scratch')
# NOTE(review): `model_path` is not read by any visible cell.
model_path = args.save_dir + '/' + model_name + '/'

job_name = "PCNN_Training_" + "dataset:" + args.dataset + "_" + args.tag

if args.en_wandb:
    # start a new wandb run to track this script
    wandb.init(
        # set entity to specify your username or team name
        # entity="qihangz-work",
        # set the wandb project where this run will be logged
        project="CPEN455HW",
        # group=Group Name
        name=job_name,
    )
    wandb.config.current_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
    wandb.config.update(args)

# set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# DataLoader options shared by every loader built below.
# - num_workers=2 speeds up loading on ubuntu/linux/colab; set it to 0 if you hit
#   pickling errors with the dataset on other machines.
# - pin_memory only helps host->GPU copies, so enable it only when CUDA is in use
#   (on CPU-only runtimes PyTorch warns and ignores it).
# - drop_last makes every batch exactly `args.batch_size` samples.
kwargs = {'num_workers': 2, 'pin_memory': device.type == 'cuda', 'drop_last': True}

{'batch_size': 16,
 'data_dir': 'data',
 'dataset': 'cpen455',
 'en_wandb': True,
 'load_params': None,
 'lr': 0.0002,
 'lr_decay': 0.999995,
 'max_epochs': 50,
 'nr_filters': 40,
 'nr_logistic_mix': 5,
 'nr_resnet': 5,
 'obs': (3, 32, 32),
 'sample_batch_size': 16,
 'sample_dir': 'samples_f40_l5_r5',
 'sampling_interval': 5,
 'save_dir': 'models_f40_l5_r5',
 'save_interval': 10,
 'seed': 1,
 'tag': 'Middle_f40_l5_r5'}


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mparshan-pjavanrood[0m ([33mparshan-pjavanrood-university-of-british-columbia[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
# set data
# Builds `train_loader` and `test_loader` for the dataset named by `args.dataset`;
# the cpen455 branch also builds `val_loader`.
# NOTE(review): the training cells always pass `val_loader`, so only the cpen455
# branch currently supports the full training loop.


def _make_loader(ds):
    """Wrap `ds` in a shuffled DataLoader using `args.batch_size` and the shared `kwargs`."""
    return torch.utils.data.DataLoader(ds, batch_size=args.batch_size, shuffle=True, **kwargs)


if "mnist" in args.dataset:
    ds_transforms = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor(), rescaling, replicate_color_channel])
    train_loader = _make_loader(datasets.MNIST(args.data_dir, download=True, train=True, transform=ds_transforms))
    test_loader = _make_loader(datasets.MNIST(args.data_dir, train=False, transform=ds_transforms))

elif "cifar" in args.dataset:
    ds_transforms = transforms.Compose([transforms.ToTensor(), rescaling])
    cifar_classes = {"cifar10": datasets.CIFAR10, "cifar100": datasets.CIFAR100}
    if args.dataset not in cifar_classes:
        # Doubled braces render a literal "{...}"; single braces would make
        # str.format raise KeyError instead of this message.
        raise Exception('{} dataset not in {{cifar10, cifar100}}'.format(args.dataset))
    cifar_cls = cifar_classes[args.dataset]
    train_loader = _make_loader(cifar_cls(args.data_dir, train=True, download=True, transform=ds_transforms))
    test_loader = _make_loader(cifar_cls(args.data_dir, train=False, transform=ds_transforms))

elif "cpen455" in args.dataset:
    ds_transforms = transforms.Compose([transforms.Resize((32, 32)), rescaling])
    train_loader = _make_loader(CPEN455Dataset(root_dir=args.data_dir, mode='train', transform=ds_transforms))
    test_loader = _make_loader(CPEN455Dataset(root_dir=args.data_dir, mode='test', transform=ds_transforms))
    val_loader = _make_loader(CPEN455Dataset(root_dir=args.data_dir, mode='validation', transform=ds_transforms))

else:
    raise Exception('{} dataset not in {{mnist, cifar, cpen455}}'.format(args.dataset))

In [None]:
def train_or_test(model, data_loader, optimizer, loss_op, device, args, epoch, mode = 'training'):
    """Run `model` once over every batch of `data_loader` and log the average loss.

    Args:
        model: network called as ``model(images, labels)``.
        data_loader: yields ``(images, labels)`` batches.
        optimizer: stepped once per batch, only when ``mode == 'training'``.
        loss_op: ``loss_op(images, model_output)`` returning a scalar loss tensor.
        device: device the images are moved to.
        args: run config; reads ``obs`` and ``en_wandb``.
        epoch: epoch index, used only for logging.
        mode: ``'training'`` updates the weights; ``'val'`` additionally measures
            classification accuracy via ``classification_evaluation.get_label``;
            any other value (e.g. ``'test'``) only evaluates.

    Returns:
        Mean per-batch loss in bits per dimension (existing callers ignore it).
    """
    logging.debug('mode: {}'.format(mode))
    is_training = mode == 'training'
    if is_training:
        model.train()
    else:
        model.eval()

    # Per-image part of the BPD normaliser: number of sub-pixels times ln(2).
    dims_per_image = np.prod(args.obs) * np.log(2.)
    loss_tracker = mean_tracker()
    n_correct, n_seen = 0, 0

    # Only build the autograd graph when we are going to backpropagate.
    with torch.set_grad_enabled(is_training):
        for batch_idx, (model_input, label) in enumerate(tqdm(data_loader, desc=mode, leave=False)):
            logging.debug('batch_idx: {}'.format(batch_idx))
            model_input = model_input.to(device)
            model_output = model(model_input, label)

            loss = loss_op(model_input, model_output)
            # Normalise by the actual batch size so a short batch is not mis-scaled.
            loss_tracker.update(loss.item() / (model_input.shape[0] * dims_per_image))
            if is_training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            elif mode == 'val':
                pred_labels = classification_evaluation.get_label(model, model_input, device)
                org_labels = torch.tensor([my_bidict[l] for l in label], device=device)
                n_correct += (org_labels == pred_labels).sum().item()
                n_seen += len(label)

    mean_bpd = loss_tracker.get_mean()
    if args.en_wandb:
        # One wandb.log call per pass keeps BPD, epoch and accuracy on the same step
        # (every wandb.log call advances the step counter).
        metrics = {mode + "-Average-BPD": mean_bpd, mode + "-epoch": epoch}
        if n_seen > 0:
            metrics[mode + "-Classification-Acc"] = n_correct / n_seen
        wandb.log(metrics)
    return mean_bpd


In [None]:
# Observation shape (C, H, W); the model's input channel count is the first entry.
args.obs = (3, 32, 32)
input_channels = args.obs[0]


def loss_op(real, fake):
    """Loss between input images `real` and model output `fake`; delegates to discretized_mix_logistic_loss."""
    return discretized_mix_logistic_loss(real, fake)


def sample_op(x):
    """Sample from model output `x`; `args.nr_logistic_mix` is read at call time."""
    return sample_from_discretized_mix_logistic(x, args.nr_logistic_mix)

In [None]:
## Start from scratch
# Build a PixelCNN from `args` (optionally initialised from `args.load_params`) and
# train it for `args.max_epochs` epochs. Each epoch: one training pass, an LR
# scheduler step, then test and validation passes. Every `args.sampling_interval`
# epochs: draw samples and compute FID. Every `args.save_interval` epochs: save a
# checkpoint into `args.save_dir`.

FID_DIMS = 192       # Inception feature dimensionality passed to pytorch-fid
FID_BATCH_SIZE = 32  # batch size pytorch-fid uses when extracting features


def _compute_fid():
    """Return the FID between `args.sample_dir` and `<args.data_dir>/test`, or None on failure."""
    paths = [args.sample_dir, os.path.join(args.data_dir, 'test')]
    try:
        score = calculate_fid_given_paths(paths, FID_BATCH_SIZE, device, dims=FID_DIMS)
    except Exception as exc:  # a bare except would also swallow KeyboardInterrupt
        print("Dimension {:d} fails! ({})".format(FID_DIMS, exc))
        return None
    print("Dimension {:d} works! fid score: {}".format(FID_DIMS, score))
    return score


def _sample_and_log(model, epoch):
    """Generate samples with generation_evaluation.my_sample, compute FID, and log both to wandb if enabled."""
    print('......sampling......')
    sampled_images = generation_evaluation.my_sample(model, args.sample_dir, args.sample_batch_size, args.obs, sample_op)
    fid_score = _compute_fid()
    if args.en_wandb:
        # One wandb.log call keeps all samples and the FID on the same step.
        payload = {
            f"samples_{label}": wandb.Image(img, caption="epoch {} - label {}".format(epoch, label))
            for label, img in sampled_images.items()
        }
        if fid_score is not None:
            payload["FID"] = fid_score
        wandb.log(payload)


def _save_checkpoint(model, epoch):
    """Save `model`'s weights to `<args.save_dir>/<model_name>_<epoch>.pth`."""
    os.makedirs(args.save_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(args.save_dir, '{}_{}.pth'.format(model_name, epoch)))


model = PixelCNN(nr_resnet=args.nr_resnet, nr_filters=args.nr_filters,
            input_channels=input_channels, nr_logistic_mix=args.nr_logistic_mix)
model = model.to(device)

if args.load_params:
    # map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
    model.load_state_dict(torch.load(args.load_params, map_location=device))
    print('model parameters loaded')

optimizer = optim.Adam(model.parameters(), lr=args.lr)
# NOTE(review): scheduler.step() runs once per epoch, so with gamma=args.lr_decay
# (0.999995 in this run's config) the LR barely changes over 50 epochs — confirm
# whether per-iteration decay was intended.
scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=args.lr_decay)

for epoch in tqdm(range(args.max_epochs)):
    train_or_test(model = model,
                  data_loader = train_loader,
                  optimizer = optimizer,
                  loss_op = loss_op,
                  device = device,
                  args = args,
                  epoch = epoch,
                  mode = 'training')

    # decrease learning rate
    scheduler.step()

    # Evaluation passes (no weight updates); 'val' also measures classification accuracy.
    for eval_mode, eval_loader in (('test', test_loader), ('val', val_loader)):
        train_or_test(model = model,
                      data_loader = eval_loader,
                      optimizer = optimizer,
                      loss_op = loss_op,
                      device = device,
                      args = args,
                      epoch = epoch,
                      mode = eval_mode)

    if epoch % args.sampling_interval == 0:
        _sample_and_log(model, epoch)

    if (epoch + 1) % args.save_interval == 0:
        _save_checkpoint(model, epoch)


  WeightNorm.apply(module, name, dim)
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:00<02:46,  1.55it/s][A
  1%|          | 2/259 [00:00<01:42,  2.50it/s][A
  1%|          | 3/259 [00:01<01:19,  3.23it/s][A
  2%|▏         | 4/259 [00:01<01:54,  2.23it/s][A
  2%|▏         | 5/259 [00:01<01:32,  2.74it/s][A
  2%|▏         | 6/259 [00:02<01:17,  3.28it/s][A
  3%|▎         | 7/259 [00:02<01:08,  3.70it/s][A
  3%|▎         | 8/259 [00:02<01:01,  4.09it/s][A
  3%|▎         | 9/259 [00:02<00:56,  4.40it/s][A
  4%|▍         | 10/259 [00:02<00:55,  4.47it/s][A
  4%|▍         | 11/259 [00:03<00:53,  4.65it/s][A
  5%|▍         | 12/259 [00:03<00:51,  4.81it/s][A
  5%|▌         | 13/259 [00:03<00:50,  4.85it/s][A
  5%|▌         | 14/259 [00:03<00:50,  4.90it/s][A
  6%|▌         | 15/259 [00:03<00:50,  4.86it/s][A
  6%|▌         | 16/259 [00:04<00:49,  4.92it/s][A
  7%|▋         | 17/259 [00:04<00:48,  4.98it/s][A
  7%|

......sampling......
Label: Class0
Label: Class1
Label: Class2
Label: Class3


Downloading: "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/pt_inception-2015-12-05-6726825d.pth

  0%|          | 0.00/91.2M [00:00<?, ?B/s][A
 20%|██        | 18.4M/91.2M [00:00<00:00, 192MB/s][A
 52%|█████▏    | 47.2M/91.2M [00:00<00:00, 257MB/s][A
100%|██████████| 91.2M/91.2M [00:00<00:00, 227MB/s]

  0%|          | 0/2 [00:00<?, ?it/s][A
100%|██████████| 2/2 [00:00<00:00,  4.73it/s]

  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:00<00:03,  4.12it/s][A
 12%|█▏        | 2/17 [00:01<00:09,  1.58it/s][A
 24%|██▎       | 4/17 [00:01<00:03,  3.44it/s][A
 35%|███▌      | 6/17 [00:01<00:02,  5.14it/s][A
 47%|████▋     | 8/17 [00:01<00:01,  6.73it/s][A
 59%|█████▉    | 10/17 [00:01<00:00,  8.01it/s][A
 71%|███████   | 12/17 [00:01<00:00,  9.00it/s][A
 82%|████████▏ | 14/17 [00:02<00:00,  9.72it/s][A
100%|██████████| 17/17 [00:02<00:00,  7.12it/s]


Dimension 192 works! fid score: 28.855192436091542


  2%|▏         | 1/50 [06:25<5:15:00, 385.73s/it]
  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:00<02:17,  1.87it/s][A
  1%|          | 2/259 [00:00<01:39,  2.59it/s][A
  1%|          | 3/259 [00:01<01:24,  3.03it/s][A
  2%|▏         | 4/259 [00:01<01:16,  3.32it/s][A
  2%|▏         | 5/259 [00:01<01:13,  3.45it/s][A
  2%|▏         | 6/259 [00:01<01:11,  3.54it/s][A
  3%|▎         | 7/259 [00:02<01:11,  3.53it/s][A
  3%|▎         | 8/259 [00:02<01:11,  3.53it/s][A
  3%|▎         | 9/259 [00:02<01:09,  3.60it/s][A
  4%|▍         | 10/259 [00:02<01:06,  3.72it/s][A
  4%|▍         | 11/259 [00:03<01:08,  3.64it/s][A
  5%|▍         | 12/259 [00:03<01:05,  3.75it/s][A
  5%|▌         | 13/259 [00:03<01:05,  3.76it/s][A
  5%|▌         | 14/259 [00:04<01:05,  3.73it/s][A
  6%|▌         | 15/259 [00:04<01:06,  3.65it/s][A
  6%|▌         | 16/259 [00:04<01:04,  3.78it/s][A
  7%|▋         | 17/259 [00:04<00:58,  4.11it/s][A
  7%|▋         | 18/259 [00:04<0

......sampling......
Label: Class0
Label: Class1
Label: Class2
Label: Class3



  0%|          | 0/2 [00:00<?, ?it/s][A
 50%|█████     | 1/2 [00:00<00:00,  6.41it/s][A
100%|██████████| 2/2 [00:00<00:00,  6.67it/s]

  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:19,  1.25s/it][A
 18%|█▊        | 3/17 [00:01<00:05,  2.55it/s][A
 29%|██▉       | 5/17 [00:01<00:02,  4.22it/s][A
 41%|████      | 7/17 [00:01<00:01,  5.77it/s][A
 53%|█████▎    | 9/17 [00:01<00:01,  7.14it/s][A
 65%|██████▍   | 11/17 [00:02<00:00,  8.24it/s][A
 76%|███████▋  | 13/17 [00:02<00:00,  9.22it/s][A
 88%|████████▊ | 15/17 [00:02<00:00,  9.83it/s][A
100%|██████████| 17/17 [00:02<00:00,  6.48it/s]
 12%|█▏        | 6/50 [17:11<2:15:25, 184.66s/it]

Dimension 192 works! fid score: 50.69169767812336



  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:00<02:00,  2.14it/s][A
  1%|          | 2/259 [00:00<01:24,  3.05it/s][A
  1%|          | 3/259 [00:00<01:08,  3.74it/s][A
  2%|▏         | 4/259 [00:01<01:02,  4.10it/s][A
  2%|▏         | 5/259 [00:01<00:58,  4.36it/s][A
  2%|▏         | 6/259 [00:01<00:55,  4.59it/s][A
  3%|▎         | 7/259 [00:01<00:54,  4.64it/s][A
  3%|▎         | 8/259 [00:01<00:58,  4.31it/s][A
  3%|▎         | 9/259 [00:02<00:57,  4.32it/s][A
  4%|▍         | 10/259 [00:02<00:54,  4.54it/s][A
  4%|▍         | 11/259 [00:02<00:52,  4.70it/s][A
  5%|▍         | 12/259 [00:02<00:52,  4.72it/s][A
  5%|▌         | 13/259 [00:03<00:51,  4.78it/s][A
  5%|▌         | 14/259 [00:03<00:52,  4.67it/s][A
  6%|▌         | 15/259 [00:03<00:51,  4.76it/s][A
  6%|▌         | 16/259 [00:03<00:50,  4.83it/s][A
  7%|▋         | 17/259 [00:03<00:50,  4.77it/s][A
  7%|▋         | 18/259 [00:04<00:50,  4.80it/s][A
  7%|▋         | 19/259 [00:0

......sampling......
Label: Class0
Label: Class1
Label: Class2
Label: Class3



  0%|          | 0/2 [00:00<?, ?it/s][A
100%|██████████| 2/2 [00:00<00:00,  6.90it/s]

  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:21,  1.32s/it][A
 12%|█▏        | 2/17 [00:01<00:09,  1.66it/s][A
 24%|██▎       | 4/17 [00:01<00:03,  3.57it/s][A
 35%|███▌      | 6/17 [00:01<00:02,  5.26it/s][A
 47%|████▋     | 8/17 [00:01<00:01,  6.73it/s][A
 59%|█████▉    | 10/17 [00:02<00:00,  7.98it/s][A
 71%|███████   | 12/17 [00:02<00:00,  8.95it/s][A
 82%|████████▏ | 14/17 [00:02<00:00,  9.60it/s][A
100%|██████████| 17/17 [00:02<00:00,  6.33it/s]
 22%|██▏       | 11/50 [28:03<1:53:51, 175.15s/it]

Dimension 192 works! fid score: 47.95119107638472



  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:00<01:58,  2.18it/s][A
  1%|          | 2/259 [00:00<01:21,  3.15it/s][A
  1%|          | 3/259 [00:00<01:08,  3.74it/s][A
  2%|▏         | 4/259 [00:01<01:00,  4.19it/s][A
  2%|▏         | 5/259 [00:01<01:02,  4.09it/s][A
  2%|▏         | 6/259 [00:01<00:57,  4.37it/s][A
  3%|▎         | 7/259 [00:01<00:55,  4.57it/s][A
  3%|▎         | 8/259 [00:01<00:54,  4.60it/s][A
  3%|▎         | 9/259 [00:02<00:53,  4.67it/s][A
  4%|▍         | 10/259 [00:02<00:52,  4.73it/s][A
  4%|▍         | 11/259 [00:02<00:51,  4.81it/s][A
  5%|▍         | 12/259 [00:02<00:50,  4.87it/s][A
  5%|▌         | 13/259 [00:02<00:50,  4.83it/s][A
  5%|▌         | 14/259 [00:03<00:50,  4.87it/s][A
  6%|▌         | 15/259 [00:03<00:49,  4.90it/s][A
  6%|▌         | 16/259 [00:03<00:48,  4.96it/s][A
  7%|▋         | 17/259 [00:03<00:48,  4.96it/s][A
  7%|▋         | 18/259 [00:04<00:53,  4.50it/s][A
  7%|▋         | 19/259 [00:0

......sampling......
Label: Class0
Label: Class1
Label: Class2
Label: Class3



  0%|          | 0/2 [00:00<?, ?it/s][A
 50%|█████     | 1/2 [00:00<00:00,  4.40it/s][A
100%|██████████| 2/2 [00:00<00:00,  4.84it/s]

  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:21,  1.31s/it][A
 18%|█▊        | 3/17 [00:01<00:05,  2.46it/s][A
 29%|██▉       | 5/17 [00:01<00:02,  4.10it/s][A
 41%|████      | 7/17 [00:01<00:01,  5.61it/s][A
 53%|█████▎    | 9/17 [00:02<00:01,  6.96it/s][A
 65%|██████▍   | 11/17 [00:02<00:00,  7.93it/s][A
 76%|███████▋  | 13/17 [00:02<00:00,  8.20it/s][A
 82%|████████▏ | 14/17 [00:02<00:00,  8.35it/s][A
100%|██████████| 17/17 [00:02<00:00,  6.09it/s]
 32%|███▏      | 16/50 [38:57<1:38:46, 174.31s/it]

Dimension 192 works! fid score: 25.618099123873783



  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:00<02:02,  2.11it/s][A
  1%|          | 2/259 [00:00<01:24,  3.05it/s][A
  1%|          | 3/259 [00:00<01:09,  3.69it/s][A
  2%|▏         | 4/259 [00:01<01:01,  4.12it/s][A
  2%|▏         | 5/259 [00:01<00:58,  4.32it/s][A
  2%|▏         | 6/259 [00:01<00:55,  4.55it/s][A
  3%|▎         | 7/259 [00:01<00:54,  4.62it/s][A
  3%|▎         | 8/259 [00:01<00:54,  4.64it/s][A
  3%|▎         | 9/259 [00:02<00:52,  4.73it/s][A
  4%|▍         | 10/259 [00:02<00:51,  4.80it/s][A
  4%|▍         | 11/259 [00:02<00:51,  4.83it/s][A
  5%|▍         | 12/259 [00:02<00:51,  4.79it/s][A
  5%|▌         | 13/259 [00:02<00:51,  4.80it/s][A
  5%|▌         | 14/259 [00:03<00:50,  4.85it/s][A
  6%|▌         | 15/259 [00:03<00:50,  4.83it/s][A
  6%|▌         | 16/259 [00:03<00:49,  4.92it/s][A
  7%|▋         | 17/259 [00:03<00:50,  4.80it/s][A
  7%|▋         | 18/259 [00:03<00:50,  4.81it/s][A
  7%|▋         | 19/259 [00:0

......sampling......
Label: Class0
Label: Class1
Label: Class2
Label: Class3



  0%|          | 0/2 [00:00<?, ?it/s][A
100%|██████████| 2/2 [00:00<00:00,  6.72it/s]

  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:20,  1.26s/it][A
 18%|█▊        | 3/17 [00:01<00:05,  2.57it/s][A
 29%|██▉       | 5/17 [00:01<00:02,  4.29it/s][A
 41%|████      | 7/17 [00:01<00:01,  5.82it/s][A
 53%|█████▎    | 9/17 [00:01<00:01,  7.17it/s][A
 65%|██████▍   | 11/17 [00:02<00:00,  8.26it/s][A
 76%|███████▋  | 13/17 [00:02<00:00,  9.06it/s][A
 88%|████████▊ | 15/17 [00:02<00:00,  9.75it/s][A
100%|██████████| 17/17 [00:02<00:00,  6.47it/s]
 42%|████▏     | 21/50 [49:56<1:24:24, 174.65s/it]

Dimension 192 works! fid score: 37.76128004041195



  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:00<01:58,  2.17it/s][A
  1%|          | 2/259 [00:00<01:21,  3.14it/s][A
  1%|          | 3/259 [00:00<01:09,  3.70it/s][A
  2%|▏         | 4/259 [00:01<01:01,  4.15it/s][A
  2%|▏         | 5/259 [00:01<00:58,  4.34it/s][A
  2%|▏         | 6/259 [00:01<00:55,  4.53it/s][A
  3%|▎         | 7/259 [00:01<00:54,  4.60it/s][A
  3%|▎         | 8/259 [00:01<00:53,  4.68it/s][A
  3%|▎         | 9/259 [00:02<00:52,  4.73it/s][A
  4%|▍         | 10/259 [00:02<00:52,  4.75it/s][A
  4%|▍         | 11/259 [00:02<00:51,  4.78it/s][A
  5%|▍         | 12/259 [00:02<00:51,  4.79it/s][A
  5%|▌         | 13/259 [00:02<00:51,  4.80it/s][A
  5%|▌         | 14/259 [00:03<00:50,  4.88it/s][A
  6%|▌         | 15/259 [00:03<00:56,  4.28it/s][A
  6%|▌         | 16/259 [00:03<00:59,  4.11it/s][A
  7%|▋         | 17/259 [00:04<01:01,  3.94it/s][A
  7%|▋         | 18/259 [00:04<01:01,  3.94it/s][A
  7%|▋         | 19/259 [00:0

......sampling......
Label: Class0
Label: Class1
Label: Class2
Label: Class3



  0%|          | 0/2 [00:00<?, ?it/s][A
100%|██████████| 2/2 [00:00<00:00,  6.73it/s]

  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:18,  1.16s/it][A
 12%|█▏        | 2/17 [00:01<00:08,  1.85it/s][A
 18%|█▊        | 3/17 [00:01<00:04,  2.94it/s][A
 24%|██▎       | 4/17 [00:01<00:03,  4.07it/s][A
 35%|███▌      | 6/17 [00:01<00:01,  6.20it/s][A
 41%|████      | 7/17 [00:01<00:01,  6.91it/s][A
 53%|█████▎    | 9/17 [00:01<00:00,  8.21it/s][A
 65%|██████▍   | 11/17 [00:02<00:00,  9.28it/s][A
 76%|███████▋  | 13/17 [00:02<00:00, 10.06it/s][A
100%|██████████| 17/17 [00:02<00:00,  6.57it/s]
 52%|█████▏    | 26/50 [1:00:53<1:09:44, 174.34s/it]

Dimension 192 works! fid score: 45.656611259220234



  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:00<02:05,  2.06it/s][A
  1%|          | 2/259 [00:00<01:24,  3.04it/s][A
  1%|          | 3/259 [00:00<01:15,  3.39it/s][A
  2%|▏         | 4/259 [00:01<01:13,  3.47it/s][A
  2%|▏         | 5/259 [00:01<01:10,  3.58it/s][A
  2%|▏         | 6/259 [00:01<01:09,  3.66it/s][A
  3%|▎         | 7/259 [00:02<01:09,  3.64it/s][A
  3%|▎         | 8/259 [00:02<01:10,  3.57it/s][A
  3%|▎         | 9/259 [00:02<01:09,  3.58it/s][A
  4%|▍         | 10/259 [00:02<01:08,  3.62it/s][A
  4%|▍         | 11/259 [00:03<01:08,  3.59it/s][A
  5%|▍         | 12/259 [00:03<01:08,  3.58it/s][A
  5%|▌         | 13/259 [00:03<01:10,  3.50it/s][A
  5%|▌         | 14/259 [00:04<01:09,  3.54it/s][A
  6%|▌         | 15/259 [00:04<01:09,  3.52it/s][A
  6%|▌         | 16/259 [00:04<01:10,  3.43it/s][A
  7%|▋         | 17/259 [00:04<01:10,  3.41it/s][A
  7%|▋         | 18/259 [00:05<01:08,  3.54it/s][A
  7%|▋         | 19/259 [00:0

......sampling......
Label: Class0
Label: Class1
Label: Class2
Label: Class3



  0%|          | 0/2 [00:00<?, ?it/s][A
 50%|█████     | 1/2 [00:00<00:00,  4.88it/s][A
100%|██████████| 2/2 [00:00<00:00,  4.47it/s]

  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:24,  1.52s/it][A
 18%|█▊        | 3/17 [00:01<00:06,  2.17it/s][A
 29%|██▉       | 5/17 [00:01<00:03,  3.67it/s][A
 41%|████      | 7/17 [00:02<00:01,  5.15it/s][A
 53%|█████▎    | 9/17 [00:02<00:01,  6.45it/s][A
 65%|██████▍   | 11/17 [00:02<00:00,  7.64it/s][A
 76%|███████▋  | 13/17 [00:02<00:00,  8.57it/s][A
 88%|████████▊ | 15/17 [00:02<00:00,  9.37it/s][A
100%|██████████| 17/17 [00:02<00:00,  5.81it/s]
 62%|██████▏   | 31/50 [1:11:56<55:32, 175.39s/it]

Dimension 192 works! fid score: 34.168649044347056



  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:00<02:06,  2.03it/s][A
  1%|          | 2/259 [00:00<01:28,  2.90it/s][A
  1%|          | 3/259 [00:00<01:12,  3.53it/s][A
  2%|▏         | 4/259 [00:01<01:04,  3.96it/s][A
  2%|▏         | 5/259 [00:01<01:00,  4.18it/s][A
  2%|▏         | 6/259 [00:01<00:57,  4.40it/s][A
  3%|▎         | 7/259 [00:01<00:56,  4.48it/s][A
  3%|▎         | 8/259 [00:01<00:55,  4.56it/s][A
  3%|▎         | 9/259 [00:02<00:54,  4.61it/s][A
  4%|▍         | 10/259 [00:02<00:53,  4.69it/s][A
  4%|▍         | 11/259 [00:02<00:52,  4.73it/s][A
  5%|▍         | 12/259 [00:02<00:52,  4.68it/s][A
  5%|▌         | 13/259 [00:03<00:52,  4.72it/s][A
  5%|▌         | 14/259 [00:03<00:52,  4.69it/s][A
  6%|▌         | 15/259 [00:03<00:51,  4.78it/s][A
  6%|▌         | 16/259 [00:03<00:50,  4.86it/s][A
  7%|▋         | 17/259 [00:03<00:51,  4.73it/s][A
  7%|▋         | 18/259 [00:04<00:50,  4.80it/s][A
  7%|▋         | 19/259 [00:0

......sampling......
Label: Class0
Label: Class1
Label: Class2
Label: Class3



  0%|          | 0/2 [00:00<?, ?it/s][A
100%|██████████| 2/2 [00:00<00:00,  6.65it/s]

  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:21,  1.35s/it][A
 12%|█▏        | 2/17 [00:01<00:09,  1.62it/s][A
 24%|██▎       | 4/17 [00:01<00:03,  3.51it/s][A
 35%|███▌      | 6/17 [00:01<00:02,  5.17it/s][A
 47%|████▋     | 8/17 [00:01<00:01,  6.68it/s][A
 59%|█████▉    | 10/17 [00:02<00:00,  7.74it/s][A
 71%|███████   | 12/17 [00:02<00:00,  8.68it/s][A
 82%|████████▏ | 14/17 [00:02<00:00,  9.39it/s][A
100%|██████████| 17/17 [00:02<00:00,  6.16it/s]
 72%|███████▏  | 36/50 [1:23:05<41:15, 176.80s/it]

Dimension 192 works! fid score: 45.9892818955644



  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:00<03:03,  1.40it/s][A
  1%|          | 2/259 [00:01<02:03,  2.07it/s][A
  1%|          | 3/259 [00:01<01:40,  2.54it/s][A
  2%|▏         | 4/259 [00:01<01:27,  2.93it/s][A
  2%|▏         | 5/259 [00:01<01:21,  3.11it/s][A
  2%|▏         | 6/259 [00:02<01:16,  3.29it/s][A
  3%|▎         | 7/259 [00:02<01:13,  3.44it/s][A
  3%|▎         | 8/259 [00:02<01:10,  3.54it/s][A
  3%|▎         | 9/259 [00:02<01:09,  3.59it/s][A
  4%|▍         | 10/259 [00:03<01:09,  3.56it/s][A
  4%|▍         | 11/259 [00:03<01:09,  3.55it/s][A
  5%|▍         | 12/259 [00:03<01:12,  3.41it/s][A
  5%|▌         | 13/259 [00:04<01:11,  3.44it/s][A
  5%|▌         | 14/259 [00:04<01:04,  3.79it/s][A
  6%|▌         | 15/259 [00:04<01:00,  4.02it/s][A
  6%|▌         | 16/259 [00:04<00:58,  4.19it/s][A
  7%|▋         | 17/259 [00:04<00:55,  4.34it/s][A
  7%|▋         | 18/259 [00:05<00:54,  4.42it/s][A
  7%|▋         | 19/259 [00:0

......sampling......
Label: Class0
Label: Class1
Label: Class2
Label: Class3



  0%|          | 0/2 [00:00<?, ?it/s][A
 50%|█████     | 1/2 [00:00<00:00,  6.11it/s][A
100%|██████████| 2/2 [00:00<00:00,  5.98it/s]

  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:22,  1.41s/it][A
 12%|█▏        | 2/17 [00:01<00:09,  1.55it/s][A
 24%|██▎       | 4/17 [00:01<00:03,  3.35it/s][A
 35%|███▌      | 6/17 [00:01<00:02,  4.94it/s][A
 47%|████▋     | 8/17 [00:02<00:01,  6.36it/s][A
 59%|█████▉    | 10/17 [00:02<00:00,  7.53it/s][A
 71%|███████   | 12/17 [00:02<00:00,  8.50it/s][A
 82%|████████▏ | 14/17 [00:02<00:00,  9.17it/s][A
100%|██████████| 17/17 [00:02<00:00,  5.96it/s]
 82%|████████▏ | 41/50 [1:34:22<26:45, 178.43s/it]

Dimension 192 works! fid score: 27.35354534483524



  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:00<02:06,  2.04it/s][A
  1%|          | 2/259 [00:00<01:26,  2.97it/s][A
  1%|          | 3/259 [00:00<01:13,  3.51it/s][A
  2%|▏         | 4/259 [00:01<01:05,  3.88it/s][A
  2%|▏         | 5/259 [00:01<01:02,  4.09it/s][A
  2%|▏         | 6/259 [00:01<00:59,  4.27it/s][A
  3%|▎         | 7/259 [00:01<00:57,  4.40it/s][A
  3%|▎         | 8/259 [00:02<00:56,  4.44it/s][A
  3%|▎         | 9/259 [00:02<00:55,  4.50it/s][A
  4%|▍         | 10/259 [00:02<00:55,  4.51it/s][A
  4%|▍         | 11/259 [00:02<00:59,  4.16it/s][A
  5%|▍         | 12/259 [00:03<01:03,  3.88it/s][A
  5%|▌         | 13/259 [00:03<01:06,  3.69it/s][A
  5%|▌         | 14/259 [00:03<01:05,  3.76it/s][A
  6%|▌         | 15/259 [00:03<01:07,  3.63it/s][A
  6%|▌         | 16/259 [00:04<01:10,  3.45it/s][A
  7%|▋         | 17/259 [00:04<01:11,  3.40it/s][A
  7%|▋         | 18/259 [00:04<01:09,  3.46it/s][A
  7%|▋         | 19/259 [00:0

......sampling......
Label: Class0
Label: Class1
Label: Class2
Label: Class3



  0%|          | 0/2 [00:00<?, ?it/s][A
 50%|█████     | 1/2 [00:00<00:00,  4.81it/s][A
100%|██████████| 2/2 [00:00<00:00,  4.65it/s]

  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:21,  1.35s/it][A
 12%|█▏        | 2/17 [00:01<00:09,  1.61it/s][A
 18%|█▊        | 3/17 [00:01<00:05,  2.58it/s][A
 29%|██▉       | 5/17 [00:01<00:02,  4.59it/s][A
 41%|████      | 7/17 [00:01<00:01,  6.18it/s][A
 53%|█████▎    | 9/17 [00:02<00:01,  7.55it/s][A
 65%|██████▍   | 11/17 [00:02<00:00,  8.61it/s][A
 76%|███████▋  | 13/17 [00:02<00:00,  9.38it/s][A
 88%|████████▊ | 15/17 [00:02<00:00,  9.75it/s][A
100%|██████████| 17/17 [00:02<00:00,  6.06it/s]
 92%|█████████▏| 46/50 [1:45:33<11:48, 177.22s/it]

Dimension 192 works! fid score: 32.563606188866316



  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:00<02:11,  1.96it/s][A
  1%|          | 2/259 [00:00<01:30,  2.83it/s][A
  1%|          | 3/259 [00:00<01:14,  3.44it/s][A
  2%|▏         | 4/259 [00:01<01:05,  3.89it/s][A
  2%|▏         | 5/259 [00:01<01:01,  4.10it/s][A
  2%|▏         | 6/259 [00:01<01:00,  4.20it/s][A
  3%|▎         | 7/259 [00:01<00:58,  4.32it/s][A
  3%|▎         | 8/259 [00:02<00:57,  4.37it/s][A
  3%|▎         | 9/259 [00:02<00:56,  4.45it/s][A
  4%|▍         | 10/259 [00:02<01:00,  4.12it/s][A
  4%|▍         | 11/259 [00:02<00:57,  4.31it/s][A
  5%|▍         | 12/259 [00:02<00:56,  4.38it/s][A
  5%|▌         | 13/259 [00:03<00:55,  4.46it/s][A
  5%|▌         | 14/259 [00:03<00:54,  4.52it/s][A
  6%|▌         | 15/259 [00:03<00:53,  4.57it/s][A
  6%|▌         | 16/259 [00:03<00:52,  4.65it/s][A
  7%|▋         | 17/259 [00:04<00:52,  4.65it/s][A
  7%|▋         | 18/259 [00:04<00:52,  4.63it/s][A
  7%|▋         | 19/259 [00:0

In [None]:
### CONTINUE TRAINING
# Resume training a PixelCNN from a saved checkpoint for another `args.max_epochs`
# epochs. Epochs are numbered from `base_epoch`, so sampling intervals and checkpoint
# filenames continue from the previous run.
#
# Relies on names defined in earlier cells: PixelCNN, args, device, loss_op, sample_op,
# train_loader, test_loader, val_loader, train_or_test, model_name.
#
# NOTE(review): the optimizer and scheduler are created fresh, so the learning rate
# restarts at `args.lr` rather than continuing the StepLR decay of the first run
# (which would be args.lr * args.lr_decay ** base_epoch). Confirm this warm restart
# is intended.
# NOTE(review): binding `model` here shadows the project module `model` imported in an
# earlier cell, so `importlib.reload(model)` will fail once this cell has run.

base_epoch = 50                      # epochs already completed by the checkpoint below
# NOTE(review): the directory name says f40 but the file (and nr_filters below) say f80
# — confirm this is the intended checkpoint.
CHECKPOINT_PATH = 'models_backup_f40_l5_r5/pcnn_cpen455_f80_l5_r5.pth'
FID_BATCH_SIZE = 32                  # batch size passed to calculate_fid_given_paths
FID_DIMS = 192                       # Inception feature dimensionality used for FID

model = PixelCNN(nr_resnet=5, nr_filters=80,
            input_channels=3, nr_logistic_mix=5)
# The model is still on CPU here, so load the weights onto CPU too. This also works
# when the checkpoint was saved from a GPU that is not currently available.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location='cpu'))
model = model.to(device)

if args.load_params:
    # Optional override: replaces the weights loaded from CHECKPOINT_PATH above.
    model.load_state_dict(torch.load(args.load_params, map_location=device))
    print('model parameters loaded')

optimizer = optim.Adam(model.parameters(), lr=args.lr)
scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=args.lr_decay)


for epoch in tqdm(range(base_epoch, base_epoch + args.max_epochs)):
    train_or_test(model = model,
                  data_loader = train_loader,
                  optimizer = optimizer,
                  loss_op = loss_op,
                  device = device,
                  args = args,
                  epoch = epoch,
                  mode = 'training')

    # decrease learning rate (StepLR with step_size=1: once per epoch, after training)
    scheduler.step()
    train_or_test(model = model,
                  data_loader = test_loader,
                  optimizer = optimizer,
                  loss_op = loss_op,
                  device = device,
                  args = args,
                  epoch = epoch,
                  mode = 'test')

    train_or_test(model = model,
                  data_loader = val_loader,
                  optimizer = optimizer,
                  loss_op = loss_op,
                  device = device,
                  args = args,
                  epoch = epoch,
                  mode = 'val')

    if epoch % args.sampling_interval == 0:
        print('......sampling......')
        sampled_images = generation_evaluation.my_sample(model, args.sample_dir, args.sample_batch_size, args.obs, sample_op)
        sample_result = {label: wandb.Image(img, caption="epoch {} - label {}".format(epoch, label)) for label, img in sampled_images.items()}

        gen_data_dir = args.sample_dir
        ref_data_dir = args.data_dir +'/test'
        paths = [gen_data_dir, ref_data_dir]

        # Reset every sampling round. Otherwise a failed FID computation would either
        # raise NameError below (first round) or be logged with a stale score from an
        # earlier epoch.
        fid_score = None
        try:
            fid_score = calculate_fid_given_paths(paths, FID_BATCH_SIZE, device, dims=FID_DIMS)
            print("Dimension {:d} works! fid score: {}".format(FID_DIMS, fid_score))
        except Exception as exc:  # best-effort metric: report and keep training
            print("Dimension {:d} fails!".format(FID_DIMS))
            print("  reason: {!r}".format(exc))

        if args.en_wandb:
            # One log call per sampling round, so all class samples and the FID share
            # a single wandb step. FID is omitted when it could not be computed.
            log_payload = {f"samples_{label}": img for label, img in sample_result.items()}
            if fid_score is not None:
                log_payload["FID"] = fid_score
            wandb.log(log_payload)

    if (epoch + 1) % args.save_interval == 0:
        os.makedirs("models", exist_ok=True)
        torch.save(model.state_dict(), 'models/{}_{}.pth'.format(model_name, epoch))


  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:01<08:02,  1.87s/it][A
  1%|          | 2/259 [00:02<04:17,  1.00s/it][A
  1%|          | 3/259 [00:02<02:41,  1.59it/s][A
  2%|▏         | 4/259 [00:02<01:56,  2.19it/s][A
  2%|▏         | 5/259 [00:02<01:31,  2.79it/s][A
  2%|▏         | 6/259 [00:03<01:15,  3.35it/s][A
  3%|▎         | 7/259 [00:03<01:05,  3.82it/s][A
  3%|▎         | 8/259 [00:03<01:00,  4.18it/s][A
  3%|▎         | 9/259 [00:03<00:55,  4.54it/s][A
  4%|▍         | 10/259 [00:03<00:52,  4.73it/s][A
  4%|▍         | 11/259 [00:03<00:50,  4.87it/s][A
  5%|▍         | 12/259 [00:04<00:49,  5.04it/s][A
  5%|▌         | 13/259 [00:04<00:47,  5.16it/s][A
  5%|▌         | 14/259 [00:04<00:48,  5.08it/s][A
  6%|▌         | 15/259 [00:04<00:47,  5.10it/s][A
  6%|▌         | 16/259 [00:04<00:47,  5.16it/s][A
  7%|▋         | 17/259 [00:05<00:46,  5.19it/s][A
  7%|▋         | 18/259 [00:05<00:45,  5.31

......sampling......
Label: Class0
Label: Class1
Label: Class2
Label: Class3



  0%|          | 0/2 [00:00<?, ?it/s][A
100%|██████████| 2/2 [00:00<00:00,  6.91it/s]

  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:20,  1.27s/it][A
 18%|█▊        | 3/17 [00:01<00:05,  2.55it/s][A
 29%|██▉       | 5/17 [00:01<00:02,  4.30it/s][A
 41%|████      | 7/17 [00:01<00:01,  5.91it/s][A
 53%|█████▎    | 9/17 [00:01<00:01,  7.38it/s][A
 65%|██████▍   | 11/17 [00:02<00:00,  8.60it/s][A
 76%|███████▋  | 13/17 [00:02<00:00,  9.27it/s][A
 88%|████████▊ | 15/17 [00:02<00:00,  9.78it/s][A
100%|██████████| 17/17 [00:02<00:00,  6.54it/s]
  2%|▏         | 1/50 [05:41<4:38:58, 341.61s/it]

Dimension 192 works! fid score: 40.82594718093446



  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:00<01:48,  2.39it/s][A
  1%|          | 2/259 [00:00<01:13,  3.50it/s][A
  1%|          | 3/259 [00:00<01:03,  4.06it/s][A
  2%|▏         | 4/259 [00:00<00:56,  4.50it/s][A
  2%|▏         | 5/259 [00:01<00:53,  4.74it/s][A
  2%|▏         | 6/259 [00:01<00:52,  4.77it/s][A
  3%|▎         | 7/259 [00:01<00:51,  4.90it/s][A
  3%|▎         | 8/259 [00:01<00:49,  5.05it/s][A
  3%|▎         | 9/259 [00:01<00:49,  5.09it/s][A
  4%|▍         | 10/259 [00:02<00:48,  5.15it/s][A
  4%|▍         | 11/259 [00:02<00:49,  5.05it/s][A
  5%|▍         | 12/259 [00:02<00:47,  5.17it/s][A
  5%|▌         | 13/259 [00:02<00:48,  5.09it/s][A
  5%|▌         | 14/259 [00:02<00:47,  5.14it/s][A
  6%|▌         | 15/259 [00:03<00:47,  5.19it/s][A
  6%|▌         | 16/259 [00:03<00:47,  5.10it/s][A
  7%|▋         | 17/259 [00:03<00:53,  4.53it/s][A
  7%|▋         | 18/259 [00:03<00:54,  4.40it/s][A
  7%|▋         | 19/259 [00:0

......sampling......
Label: Class0
Label: Class1
Label: Class2
Label: Class3



  0%|          | 0/2 [00:00<?, ?it/s][A
 50%|█████     | 1/2 [00:00<00:00,  5.98it/s][A
100%|██████████| 2/2 [00:00<00:00,  6.43it/s]

  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:21,  1.34s/it][A
 12%|█▏        | 2/17 [00:01<00:09,  1.63it/s][A
 24%|██▎       | 4/17 [00:01<00:03,  3.54it/s][A
 35%|███▌      | 6/17 [00:01<00:02,  5.28it/s][A
 47%|████▋     | 8/17 [00:01<00:01,  6.77it/s][A
 59%|█████▉    | 10/17 [00:02<00:00,  8.05it/s][A
 71%|███████   | 12/17 [00:02<00:00,  9.07it/s][A
 82%|████████▏ | 14/17 [00:02<00:00,  9.72it/s][A
100%|██████████| 17/17 [00:02<00:00,  6.33it/s]
 12%|█▏        | 6/50 [16:21<2:12:47, 181.08s/it]

Dimension 192 works! fid score: 28.98568260160679



  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:00<01:50,  2.33it/s][A
  1%|          | 2/259 [00:00<01:16,  3.35it/s][A
  1%|          | 3/259 [00:00<01:04,  3.96it/s][A
  2%|▏         | 4/259 [00:01<01:00,  4.23it/s][A
  2%|▏         | 5/259 [00:01<00:56,  4.48it/s][A
  2%|▏         | 6/259 [00:01<01:01,  4.15it/s][A
  3%|▎         | 7/259 [00:01<01:01,  4.10it/s][A
  3%|▎         | 8/259 [00:02<01:01,  4.11it/s][A
  3%|▎         | 9/259 [00:02<01:01,  4.04it/s][A
  4%|▍         | 10/259 [00:02<01:04,  3.84it/s][A
  4%|▍         | 11/259 [00:02<01:04,  3.85it/s][A
  5%|▍         | 12/259 [00:03<01:04,  3.81it/s][A
  5%|▌         | 13/259 [00:03<01:05,  3.78it/s][A
  5%|▌         | 14/259 [00:03<01:02,  3.91it/s][A
  6%|▌         | 15/259 [00:03<01:01,  3.95it/s][A
  6%|▌         | 16/259 [00:04<01:03,  3.83it/s][A
  7%|▋         | 17/259 [00:04<01:01,  3.93it/s][A
  7%|▋         | 18/259 [00:04<01:03,  3.81it/s][A
  7%|▋         | 19/259 [00:0

......sampling......
Label: Class0
Label: Class1
Label: Class2
Label: Class3



  0%|          | 0/2 [00:00<?, ?it/s][A
 50%|█████     | 1/2 [00:00<00:00,  4.37it/s][A
100%|██████████| 2/2 [00:00<00:00,  5.01it/s]

  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:30,  1.90s/it][A
 12%|█▏        | 2/17 [00:02<00:13,  1.14it/s][A
 24%|██▎       | 4/17 [00:02<00:05,  2.57it/s][A
 29%|██▉       | 5/17 [00:02<00:03,  3.28it/s][A
 35%|███▌      | 6/17 [00:02<00:02,  4.03it/s][A
 41%|████      | 7/17 [00:02<00:02,  4.86it/s][A
 47%|████▋     | 8/17 [00:02<00:01,  5.51it/s][A
 53%|█████▎    | 9/17 [00:02<00:01,  6.19it/s][A
 59%|█████▉    | 10/17 [00:02<00:01,  6.73it/s][A
 65%|██████▍   | 11/17 [00:03<00:00,  7.10it/s][A
 71%|███████   | 12/17 [00:03<00:00,  7.52it/s][A
 76%|███████▋  | 13/17 [00:03<00:00,  7.62it/s][A
 82%|████████▏ | 14/17 [00:03<00:00,  7.62it/s][A
 88%|████████▊ | 15/17 [00:03<00:00,  8.01it/s][A
100%|██████████| 17/17 [00:03<00:00,  4.54it/s]
 22%|██▏       | 11/50 [27:05<1:52:41, 173.37s/it]

Dimension 192 works! fid score: 33.32409948194672



  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:00<02:54,  1.48it/s][A
  1%|          | 2/259 [00:00<01:58,  2.17it/s][A
  1%|          | 3/259 [00:01<01:29,  2.87it/s][A
  2%|▏         | 4/259 [00:01<01:14,  3.40it/s][A
  2%|▏         | 5/259 [00:01<01:06,  3.82it/s][A
  2%|▏         | 6/259 [00:01<01:00,  4.17it/s][A
  3%|▎         | 7/259 [00:02<00:59,  4.23it/s][A
  3%|▎         | 8/259 [00:02<00:57,  4.38it/s][A
  3%|▎         | 9/259 [00:02<00:56,  4.42it/s][A
  4%|▍         | 10/259 [00:02<00:55,  4.52it/s][A
  4%|▍         | 11/259 [00:02<00:53,  4.62it/s][A
  5%|▍         | 12/259 [00:03<00:53,  4.59it/s][A
  5%|▌         | 13/259 [00:03<00:52,  4.67it/s][A
  5%|▌         | 14/259 [00:03<00:52,  4.71it/s][A
  6%|▌         | 15/259 [00:03<00:51,  4.77it/s][A
  6%|▌         | 16/259 [00:03<00:50,  4.77it/s][A
  7%|▋         | 17/259 [00:04<00:51,  4.73it/s][A
  7%|▋         | 18/259 [00:04<00:49,  4.84it/s][A
  7%|▋         | 19/259 [00:0

......sampling......
Label: Class0
Label: Class1
Label: Class2
Label: Class3



  0%|          | 0/2 [00:00<?, ?it/s][A
100%|██████████| 2/2 [00:00<00:00,  6.89it/s]

  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:21,  1.31s/it][A
 12%|█▏        | 2/17 [00:01<00:09,  1.66it/s][A
 24%|██▎       | 4/17 [00:01<00:03,  3.61it/s][A
 29%|██▉       | 5/17 [00:01<00:02,  4.41it/s][A
 41%|████      | 7/17 [00:01<00:01,  6.24it/s][A
 53%|█████▎    | 9/17 [00:02<00:01,  7.62it/s][A
 65%|██████▍   | 11/17 [00:02<00:00,  8.77it/s][A
 76%|███████▋  | 13/17 [00:02<00:00,  9.53it/s][A
100%|██████████| 17/17 [00:02<00:00,  6.31it/s]
 32%|███▏      | 16/50 [37:53<1:37:50, 172.66s/it]

Dimension 192 works! fid score: 47.89482647506567



  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:00<01:55,  2.23it/s][A
  1%|          | 2/259 [00:00<01:18,  3.28it/s][A
  1%|          | 3/259 [00:00<01:07,  3.80it/s][A
  2%|▏         | 4/259 [00:01<01:00,  4.21it/s][A
  2%|▏         | 5/259 [00:01<00:57,  4.43it/s][A
  2%|▏         | 6/259 [00:01<00:54,  4.62it/s][A
  3%|▎         | 7/259 [00:01<00:53,  4.69it/s][A
  3%|▎         | 8/259 [00:01<00:53,  4.68it/s][A
  3%|▎         | 9/259 [00:02<00:52,  4.76it/s][A
  4%|▍         | 10/259 [00:02<00:51,  4.87it/s][A
  4%|▍         | 11/259 [00:02<00:51,  4.85it/s][A
  5%|▍         | 12/259 [00:02<00:50,  4.91it/s][A
  5%|▌         | 13/259 [00:02<00:50,  4.92it/s][A
  5%|▌         | 14/259 [00:03<00:49,  4.91it/s][A
  6%|▌         | 15/259 [00:03<00:49,  4.93it/s][A
  6%|▌         | 16/259 [00:03<00:49,  4.93it/s][A
  7%|▋         | 17/259 [00:03<00:48,  4.95it/s][A
  7%|▋         | 18/259 [00:03<00:49,  4.87it/s][A
  7%|▋         | 19/259 [00:0

......sampling......
Label: Class0
Label: Class1
Label: Class2
Label: Class3



  0%|          | 0/2 [00:00<?, ?it/s][A
100%|██████████| 2/2 [00:00<00:00,  7.06it/s]

  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:00<00:03,  4.20it/s][A
 12%|█▏        | 2/17 [00:01<00:09,  1.54it/s][A
 24%|██▎       | 4/17 [00:01<00:03,  3.38it/s][A
 29%|██▉       | 5/17 [00:01<00:02,  4.16it/s][A
 35%|███▌      | 6/17 [00:01<00:02,  4.71it/s][A
 41%|████      | 7/17 [00:01<00:01,  5.38it/s][A
 47%|████▋     | 8/17 [00:01<00:01,  5.96it/s][A
 53%|█████▎    | 9/17 [00:01<00:01,  6.50it/s][A
 59%|█████▉    | 10/17 [00:02<00:00,  7.09it/s][A
 65%|██████▍   | 11/17 [00:02<00:00,  7.54it/s][A
 71%|███████   | 12/17 [00:02<00:00,  7.52it/s][A
 76%|███████▋  | 13/17 [00:02<00:00,  7.52it/s][A
 88%|████████▊ | 15/17 [00:02<00:00,  8.33it/s][A
100%|██████████| 17/17 [00:02<00:00,  5.89it/s]
 42%|████▏     | 21/50 [48:43<1:23:35, 172.94s/it]

Dimension 192 works! fid score: 29.678635148529736



  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:00<02:48,  1.53it/s][A
  1%|          | 2/259 [00:00<01:55,  2.23it/s][A
  1%|          | 3/259 [00:01<01:32,  2.76it/s][A
  2%|▏         | 4/259 [00:01<01:22,  3.07it/s][A
  2%|▏         | 5/259 [00:01<01:19,  3.18it/s][A
  2%|▏         | 6/259 [00:02<01:21,  3.11it/s][A
  3%|▎         | 7/259 [00:02<01:17,  3.26it/s][A
  3%|▎         | 8/259 [00:02<01:08,  3.68it/s][A
  3%|▎         | 9/259 [00:02<01:02,  3.98it/s][A
  4%|▍         | 10/259 [00:03<00:58,  4.23it/s][A
  4%|▍         | 11/259 [00:03<00:57,  4.35it/s][A
  5%|▍         | 12/259 [00:03<00:55,  4.47it/s][A
  5%|▌         | 13/259 [00:03<00:53,  4.59it/s][A
  5%|▌         | 14/259 [00:03<00:51,  4.75it/s][A
  6%|▌         | 15/259 [00:04<00:51,  4.76it/s][A
  6%|▌         | 16/259 [00:04<00:50,  4.78it/s][A
  7%|▋         | 17/259 [00:04<00:50,  4.79it/s][A
  7%|▋         | 18/259 [00:04<00:50,  4.76it/s][A
  7%|▋         | 19/259 [00:0

......sampling......
Label: Class0
Label: Class1
Label: Class2
Label: Class3



  0%|          | 0/2 [00:00<?, ?it/s][A
100%|██████████| 2/2 [00:00<00:00,  6.64it/s]

  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:00<00:03,  4.17it/s][A
 12%|█▏        | 2/17 [00:01<00:08,  1.68it/s][A
 24%|██▎       | 4/17 [00:01<00:03,  3.62it/s][A
 35%|███▌      | 6/17 [00:01<00:02,  5.36it/s][A
 47%|████▋     | 8/17 [00:01<00:01,  6.84it/s][A
 59%|█████▉    | 10/17 [00:01<00:00,  8.13it/s][A
 71%|███████   | 12/17 [00:01<00:00,  8.99it/s][A
 82%|████████▏ | 14/17 [00:02<00:00,  9.61it/s][A
100%|██████████| 17/17 [00:02<00:00,  7.24it/s]
 52%|█████▏    | 26/50 [59:34<1:09:05, 172.72s/it]

Dimension 192 works! fid score: 29.497205628123737



  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:00<01:57,  2.19it/s][A
  1%|          | 2/259 [00:00<01:20,  3.21it/s][A
  1%|          | 3/259 [00:00<01:08,  3.76it/s][A
  2%|▏         | 4/259 [00:01<01:08,  3.74it/s][A
  2%|▏         | 5/259 [00:01<01:05,  3.90it/s][A
  2%|▏         | 6/259 [00:01<01:00,  4.20it/s][A
  3%|▎         | 7/259 [00:01<00:57,  4.38it/s][A
  3%|▎         | 8/259 [00:01<00:55,  4.55it/s][A
  3%|▎         | 9/259 [00:02<00:53,  4.64it/s][A
  4%|▍         | 10/259 [00:02<00:53,  4.66it/s][A
  4%|▍         | 11/259 [00:02<00:51,  4.78it/s][A
  5%|▍         | 12/259 [00:02<00:50,  4.85it/s][A
  5%|▌         | 13/259 [00:03<00:50,  4.90it/s][A
  5%|▌         | 14/259 [00:03<00:50,  4.85it/s][A
  6%|▌         | 15/259 [00:03<00:50,  4.84it/s][A
  6%|▌         | 16/259 [00:03<00:50,  4.84it/s][A
  7%|▋         | 17/259 [00:03<00:49,  4.89it/s][A
  7%|▋         | 18/259 [00:04<00:49,  4.92it/s][A
  7%|▋         | 19/259 [00:0

......sampling......
Label: Class0
Label: Class1
Label: Class2
Label: Class3



  0%|          | 0/2 [00:00<?, ?it/s][A
100%|██████████| 2/2 [00:00<00:00,  6.61it/s]

  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:20,  1.27s/it][A
 12%|█▏        | 2/17 [00:01<00:08,  1.68it/s][A
 18%|█▊        | 3/17 [00:01<00:05,  2.69it/s][A
 29%|██▉       | 5/17 [00:01<00:02,  4.77it/s][A
 41%|████      | 7/17 [00:01<00:01,  6.48it/s][A
 53%|█████▎    | 9/17 [00:02<00:01,  7.75it/s][A
 65%|██████▍   | 11/17 [00:02<00:00,  8.74it/s][A
 76%|███████▋  | 13/17 [00:02<00:00,  9.38it/s][A
100%|██████████| 17/17 [00:02<00:00,  6.34it/s]
 62%|██████▏   | 31/50 [1:10:26<54:48, 173.10s/it]

Dimension 192 works! fid score: 24.537619638178054



  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:00<01:59,  2.16it/s][A
  1%|          | 2/259 [00:00<01:21,  3.14it/s][A
  1%|          | 3/259 [00:00<01:15,  3.40it/s][A
  2%|▏         | 4/259 [00:01<01:12,  3.51it/s][A
  2%|▏         | 5/259 [00:01<01:10,  3.58it/s][A
  2%|▏         | 6/259 [00:01<01:07,  3.72it/s][A
  3%|▎         | 7/259 [00:02<01:08,  3.68it/s][A
  3%|▎         | 8/259 [00:02<01:09,  3.62it/s][A
  3%|▎         | 9/259 [00:02<01:10,  3.56it/s][A
  4%|▍         | 10/259 [00:02<01:10,  3.55it/s][A
  4%|▍         | 11/259 [00:03<01:09,  3.58it/s][A
  5%|▍         | 12/259 [00:03<01:09,  3.57it/s][A
  5%|▌         | 13/259 [00:03<01:08,  3.60it/s][A
  5%|▌         | 14/259 [00:03<01:06,  3.68it/s][A
  6%|▌         | 15/259 [00:04<01:06,  3.69it/s][A
  6%|▌         | 16/259 [00:04<01:07,  3.59it/s][A
  7%|▋         | 17/259 [00:04<01:09,  3.47it/s][A
  7%|▋         | 18/259 [00:05<01:08,  3.53it/s][A
  7%|▋         | 19/259 [00:0

......sampling......
Label: Class0
Label: Class1
Label: Class2
Label: Class3



  0%|          | 0/2 [00:00<?, ?it/s][A
 50%|█████     | 1/2 [00:00<00:00,  4.40it/s][A
100%|██████████| 2/2 [00:00<00:00,  4.72it/s]

  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:21,  1.33s/it][A
 18%|█▊        | 3/17 [00:01<00:05,  2.40it/s][A
 24%|██▎       | 4/17 [00:01<00:04,  3.11it/s][A
 29%|██▉       | 5/17 [00:01<00:03,  3.77it/s][A
 35%|███▌      | 6/17 [00:01<00:02,  4.54it/s][A
 41%|████      | 7/17 [00:02<00:01,  5.33it/s][A
 47%|████▋     | 8/17 [00:02<00:01,  5.96it/s][A
 53%|█████▎    | 9/17 [00:02<00:01,  6.40it/s][A
 59%|█████▉    | 10/17 [00:02<00:01,  6.73it/s][A
 65%|██████▍   | 11/17 [00:02<00:00,  7.01it/s][A
 71%|███████   | 12/17 [00:02<00:00,  6.90it/s][A
 82%|████████▏ | 14/17 [00:02<00:00,  8.46it/s][A
100%|██████████| 17/17 [00:03<00:00,  5.43it/s]
 72%|███████▏  | 36/50 [1:21:19<40:19, 172.84s/it]

Dimension 192 works! fid score: 48.51310775390226



  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:00<02:02,  2.10it/s][A
  1%|          | 2/259 [00:00<01:22,  3.11it/s][A
  1%|          | 3/259 [00:00<01:09,  3.70it/s][A
  2%|▏         | 4/259 [00:01<01:02,  4.06it/s][A
  2%|▏         | 5/259 [00:01<00:59,  4.29it/s][A
  2%|▏         | 6/259 [00:01<00:58,  4.36it/s][A
  3%|▎         | 7/259 [00:01<00:55,  4.51it/s][A
  3%|▎         | 8/259 [00:01<00:54,  4.63it/s][A
  3%|▎         | 9/259 [00:02<00:53,  4.65it/s][A
  4%|▍         | 10/259 [00:02<00:53,  4.66it/s][A
  4%|▍         | 11/259 [00:02<00:53,  4.59it/s][A
  5%|▍         | 12/259 [00:02<00:52,  4.67it/s][A
  5%|▌         | 13/259 [00:03<00:53,  4.62it/s][A
  5%|▌         | 14/259 [00:03<00:52,  4.69it/s][A
  6%|▌         | 15/259 [00:03<00:50,  4.78it/s][A
  6%|▌         | 16/259 [00:03<00:50,  4.79it/s][A
  7%|▋         | 17/259 [00:03<00:50,  4.79it/s][A
  7%|▋         | 18/259 [00:04<00:50,  4.81it/s][A
  7%|▋         | 19/259 [00:0

......sampling......
Label: Class0
Label: Class1
Label: Class2
Label: Class3



  0%|          | 0/2 [00:00<?, ?it/s][A
 50%|█████     | 1/2 [00:00<00:00,  5.67it/s][A
100%|██████████| 2/2 [00:00<00:00,  5.57it/s]

  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:25,  1.60s/it][A
 12%|█▏        | 2/17 [00:01<00:11,  1.35it/s][A
 18%|█▊        | 3/17 [00:01<00:06,  2.21it/s][A
 24%|██▎       | 4/17 [00:01<00:04,  3.11it/s][A
 29%|██▉       | 5/17 [00:02<00:02,  4.02it/s][A
 35%|███▌      | 6/17 [00:02<00:02,  4.69it/s][A
 41%|████      | 7/17 [00:02<00:01,  5.61it/s][A
 47%|████▋     | 8/17 [00:02<00:01,  6.37it/s][A
 53%|█████▎    | 9/17 [00:02<00:01,  6.37it/s][A
 59%|█████▉    | 10/17 [00:02<00:01,  6.80it/s][A
 65%|██████▍   | 11/17 [00:02<00:00,  6.77it/s][A
 71%|███████   | 12/17 [00:02<00:00,  7.46it/s][A
 76%|███████▋  | 13/17 [00:03<00:00,  7.85it/s][A
 88%|████████▊ | 15/17 [00:03<00:00,  8.53it/s][A
100%|██████████| 17/17 [00:03<00:00,  4.83it/s]
 82%|████████▏ | 41/50 [1:32:06<25:44, 171.59s/it]

Dimension 192 works! fid score: 21.196628410721786



  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:00<03:08,  1.37it/s][A
  1%|          | 2/259 [00:00<01:50,  2.32it/s][A
  1%|          | 3/259 [00:01<01:24,  3.03it/s][A
  2%|▏         | 4/259 [00:01<01:11,  3.55it/s][A
  2%|▏         | 5/259 [00:01<01:04,  3.93it/s][A
  2%|▏         | 6/259 [00:01<01:00,  4.20it/s][A
  3%|▎         | 7/259 [00:02<00:58,  4.29it/s][A
  3%|▎         | 8/259 [00:02<00:56,  4.48it/s][A
  3%|▎         | 9/259 [00:02<00:54,  4.58it/s][A
  4%|▍         | 10/259 [00:02<00:52,  4.72it/s][A
  4%|▍         | 11/259 [00:02<00:51,  4.80it/s][A
  5%|▍         | 12/259 [00:03<00:51,  4.79it/s][A
  5%|▌         | 13/259 [00:03<00:51,  4.79it/s][A
  5%|▌         | 14/259 [00:03<00:50,  4.83it/s][A
  6%|▌         | 15/259 [00:03<00:50,  4.84it/s][A
  6%|▌         | 16/259 [00:03<00:49,  4.92it/s][A
  7%|▋         | 17/259 [00:04<00:51,  4.69it/s][A
  7%|▋         | 18/259 [00:04<00:51,  4.70it/s][A
  7%|▋         | 19/259 [00:0

......sampling......
Label: Class0
Label: Class1
Label: Class2
Label: Class3



  0%|          | 0/2 [00:00<?, ?it/s][A
 50%|█████     | 1/2 [00:00<00:00,  6.28it/s][A
100%|██████████| 2/2 [00:00<00:00,  6.37it/s]

  0%|          | 0/17 [00:00<?, ?it/s][A
  6%|▌         | 1/17 [00:01<00:19,  1.20s/it][A
 12%|█▏        | 2/17 [00:01<00:08,  1.69it/s][A
 18%|█▊        | 3/17 [00:01<00:05,  2.63it/s][A
 24%|██▎       | 4/17 [00:01<00:03,  3.46it/s][A
 29%|██▉       | 5/17 [00:01<00:02,  4.48it/s][A
 35%|███▌      | 6/17 [00:01<00:02,  5.12it/s][A
 41%|████      | 7/17 [00:02<00:01,  5.66it/s][A
 47%|████▋     | 8/17 [00:02<00:01,  5.70it/s][A
 53%|█████▎    | 9/17 [00:02<00:01,  6.04it/s][A
 59%|█████▉    | 10/17 [00:02<00:01,  6.43it/s][A
 65%|██████▍   | 11/17 [00:02<00:00,  6.53it/s][A
 71%|███████   | 12/17 [00:02<00:00,  6.75it/s][A
 76%|███████▋  | 13/17 [00:02<00:00,  6.87it/s][A
 82%|████████▏ | 14/17 [00:03<00:00,  7.30it/s][A
 88%|████████▊ | 15/17 [00:03<00:00,  7.19it/s][A
100%|██████████| 17/17 [00:03<00:00,  4.97it/s]
 92%|█████████▏|

Dimension 192 works! fid score: 22.199479562968904



  0%|          | 0/259 [00:00<?, ?it/s][A
  0%|          | 1/259 [00:00<02:50,  1.51it/s][A
  1%|          | 2/259 [00:00<01:55,  2.23it/s][A
  1%|          | 3/259 [00:01<01:37,  2.63it/s][A
  2%|▏         | 4/259 [00:01<01:28,  2.88it/s][A
  2%|▏         | 5/259 [00:01<01:20,  3.14it/s][A
  2%|▏         | 6/259 [00:02<01:17,  3.26it/s][A
  3%|▎         | 7/259 [00:02<01:08,  3.66it/s][A
  3%|▎         | 8/259 [00:02<01:03,  3.98it/s][A
  3%|▎         | 9/259 [00:02<01:00,  4.11it/s][A
  4%|▍         | 10/259 [00:02<00:58,  4.29it/s][A
  4%|▍         | 11/259 [00:03<00:55,  4.44it/s][A
  5%|▍         | 12/259 [00:03<00:54,  4.56it/s][A
  5%|▌         | 13/259 [00:03<00:52,  4.66it/s][A
  5%|▌         | 14/259 [00:03<00:52,  4.64it/s][A
  6%|▌         | 15/259 [00:04<00:54,  4.50it/s][A
  6%|▌         | 16/259 [00:04<00:52,  4.65it/s][A
  7%|▋         | 17/259 [00:04<00:51,  4.72it/s][A
  7%|▋         | 18/259 [00:04<00:51,  4.65it/s][A
  7%|▋         | 19/259 [00:0

In [16]:
# Run the project's classification script on the test split with batch size 16
# (the script echoes its parsed arguments, then prints the predicted labels).
! python3 classification_evaluation.py -b 16 -m test

{'batch_size': 16, 'data_dir': 'data', 'mode': 'test'}
  WeightNorm.apply(module, name, dim)
100% 33/33 [00:17<00:00,  1.89it/s]
[2, 1, 2, 0, 2, 3, 2, 0, 0, 1, 0, 1, 3, 1, 0, 1, 1, 1, 1, 2, 2, 1, 3, 3, 3, 3, 0, 2, 2, 3, 1, 1, 0, 1, 0, 2, 1, 0, 3, 3, 3, 3, 1, 0, 1, 1, 3, 1, 2, 3, 2, 3, 1, 1, 0, 2, 1, 3, 0, 1, 2, 0, 0, 2, 0, 0, 3, 0, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 3, 1, 2, 2, 0, 3, 0, 1, 0, 0, 1, 1, 1, 2, 0, 3, 3, 2, 0, 3, 0, 2, 3, 0, 2, 0, 0, 0, 3, 2, 2, 0, 1, 2, 0, 0, 3, 1, 1, 0, 1, 2, 3, 3, 0, 0, 1, 2, 3, 0, 2, 2, 2, 0, 2, 0, 3, 0, 0, 2, 1, 3, 2, 2, 3, 2, 2, 3, 1, 0, 2, 2, 1, 3, 0, 2, 3, 1, 3, 1, 3, 0, 3, 0, 3, 0, 0, 0, 3, 2, 3, 0, 2, 1, 2, 1, 1, 0, 3, 3, 0, 2, 2, 2, 3, 1, 1, 3, 2, 3, 3, 2, 2, 2, 0, 2, 0, 3, 1, 0, 1, 1, 0, 0, 2, 1, 3, 2, 1, 1, 3, 0, 0, 3, 3, 0, 0, 2, 0, 3, 2, 0, 1, 1, 2, 0, 1, 1, 2, 3, 3, 1, 3, 2, 2, 0, 0, 3, 1, 2, 3, 2, 1, 2, 3, 2, 2, 0, 1, 2, 3, 3, 3, 3, 0, 1, 2, 2, 1, 1, 0, 1, 0, 1, 0, 1, 2, 2, 2, 2, 2, 2, 3, 3, 0, 3, 3, 2, 1, 0, 0, 1, 1, 0, 2, 3, 1, 1, 0, 0, 1, 

In [None]:
# NOTE(review): stray scratch cell — it only evaluates the literal 0 and has no effect;
# safe to delete.
0

In [18]:
# Generate per-class samples with the project script and report their FID against the
# reference images.
! python3 generation_evaluation.py

  WeightNorm.apply(module, name, dim)
Label: Class0
Label: Class1
Label: Class2
Label: Class3
#generated images: 100, #reference images: 519
Downloading: "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/pt_inception-2015-12-05-6726825d.pth
100% 91.2M/91.2M [00:00<00:00, 126MB/s]
100% 1/1 [00:01<00:00,  1.03s/it]
100% 5/5 [00:02<00:00,  1.70it/s]
Dimension 192 works! fid score: 25.917448519048907
Average fid score: 25.917448519048907


In [None]:
# Start a wandb sweep agent. Each run's hyperparameters (e.g. nr_filters,
# nr_logistic_mix, nr_resnet) come from the sweep configuration.
! python3 wandb_sweep.py

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mparshan-pjavanrood[0m ([33mparshan-pjavanrood-university-of-british-columbia[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Agent Starting Run: t984sodv with config:
[34m[1mwandb[0m: 	nr_filters: 80
[34m[1mwandb[0m: 	nr_logistic_mix: 5
[34m[1mwandb[0m: 	nr_resnet: 5
[34m[1mwandb[0m: Tracking run with wandb version 0.19.9
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/content/drive/MyDrive/cpen455-project/wandb/run-20250406_053238-t984sodv[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33msummer-sweep-13[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/parshan-pjavanrood-university-of-british-columbia/CPEN455HW[0m
[34m[1mwandb[0m: 🧹 View sweep a

In [None]:
!pip install numba

from numba import cuda
device = cuda.get_current_device()
device.reset()

