In [1]:
from google.colab import drive
# Mount Google Drive at /content/drive/ so that the project files used below
# live on Drive and file changes persist beyond the current Colab runtime.
drive.mount('/content/drive/')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [2]:
# Work from the project folder on the mounted Drive so that the relative paths
# used later (requirements.txt, the data directory, the local .py modules) resolve.
%cd /content/drive/MyDrive/cpen455-project/

/content/drive/MyDrive/cpen455-project


In [3]:
!pip install -r requirements.txt

Collecting bidict (from -r requirements.txt (line 1))
  Downloading bidict-0.23.1-py3-none-any.whl.metadata (8.7 kB)
Collecting pytorch-fid (from -r requirements.txt (line 7))
  Downloading pytorch_fid-0.3.0-py3-none-any.whl.metadata (5.3 kB)
Collecting dotenv (from -r requirements.txt (line 11))
  Downloading dotenv-0.9.9-py2.py3-none-any.whl.metadata (279 bytes)
Collecting python-dotenv (from dotenv->-r requirements.txt (line 11))
  Downloading python_dotenv-1.1.0-py3-none-any.whl.metadata (24 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.0.1->pytorch-fid->-r requirements.txt (line 7))
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.0.1->pytorch-fid->-r requirements.txt (line 7))
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.0.1->pytorch-fid->-r r

In [4]:
import time
import os
import torch
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, transforms
import wandb
from tqdm import tqdm
from pprint import pprint
import argparse
from pytorch_fid.fid_score import calculate_fid_given_paths
import logging
import pandas as pd
import numpy as np

import dotenv
dotenv.load_dotenv()

True

In [5]:
import importlib
import utils, model, dataset, generation_evaluation

# The project modules are edited while this notebook is open, so refresh each
# of them in the running kernel first; otherwise the star imports below would
# re-export the definitions from the previously imported (stale) version.
for _project_module in (utils, model, dataset, generation_evaluation):
    importlib.reload(_project_module)

# Pull the (freshly reloaded) public names into the notebook namespace; later
# cells call them unqualified.
from utils import *
from model import *
from dataset import *
# generation_evaluation is accessed through its module name instead.
import generation_evaluation

In [6]:
import logging
# Timestamp every log line and show warnings and above only, which keeps the
# logging.debug(...) calls made during training silent.
logging.basicConfig(level=logging.WARNING, format='%(asctime)s %(message)s')

In [8]:
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class Config:
    """Hyper-parameters and paths for one PixelCNN training run.

    Every field has a default, so ``Config()`` is valid and individual
    settings can be overridden by keyword. The instance is used in place of
    an argparse namespace (attributes are read as ``args.<field>``).
    """
    # Enable Weights & Biases tracking (run creation and metric/sample logging).
    en_wandb: bool = False
    # Free-form label appended to the W&B job name.
    tag: str = "default"
    # Sample images and compute FID every this many epochs.
    sampling_interval: int = 5
    # Root directory of the dataset; FID reference images are read from
    # "<data_dir>/test".
    data_dir: str = "data"
    # Directory for model checkpoints; it is created when the run is set up.
    save_dir: str = "models"
    # Directory handed to the sampler and used as the "generated" side of FID.
    sample_dir: str = "samples"
    # Dataset selector; matched by substring ("mnist", "cifar", "cpen455").
    dataset: str = "cpen455"
    # Write a checkpoint every this many epochs.
    save_interval: int = 10
    # Path of a saved state dict to load before training; None trains from scratch.
    load_params: Optional[str] = None
    # Image shape. Element 0 is the number of input channels; the remaining two
    # are presumably height and width (the loaders resize to 32x32) - confirm.
    obs: Tuple[int, int, int] = (3, 32, 32)
    # PixelCNN constructor arguments.
    nr_resnet: int = 1
    nr_filters: int = 40
    # Also passed to the sampling function.
    nr_logistic_mix: int = 5
    # Adam learning rate.
    lr: float = 0.0002
    # Multiplicative learning-rate decay (StepLR gamma).
    lr_decay: float = 0.999995
    # Batch size of the train/test/validation loaders.
    batch_size: int = 64
    # Batch size used when sampling images.
    sample_batch_size: int = 32
    # Number of training epochs.
    max_epochs: int = 5000
    # Seed for the torch and numpy random number generators.
    seed: int = 1

In [9]:
args = Config(
    # run identity and output locations
    tag="Middle_f80_l5_r5",
    sample_dir="samples_f80_l5_r5",
    save_dir="models_f80_l5_r5",
    en_wandb=True,
    # model size
    nr_filters=80,
    nr_resnet=5,
    nr_logistic_mix=5,
    # schedule
    batch_size=32,
    sample_batch_size=16,
    max_epochs=50,
    sampling_interval=2,
    save_interval=10,
)

# Show the effective configuration and make sure the checkpoint directory exists.
pprint(vars(args))
check_dir_and_create(args.save_dir)

# Seed the torch and numpy generators for reproducibility.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

# The run name records whether training resumes from saved parameters or
# starts from scratch; model_path is the matching sub-directory of save_dir.
run_kind = 'from_scratch' if args.load_params is None else 'load_model'
model_name = 'pcnn_' + args.dataset + "_" + run_kind
model_path = args.save_dir + '/' + model_name + '/'

job_name = "".join(("PCNN_Training_", "dataset:", args.dataset, "_", args.tag))

if args.en_wandb:
    # Open a new W&B run for this notebook session. `entity=` (user or team
    # name) and `group=` can be added here to target a specific account/group.
    wandb.init(project="CPEN455HW", name=job_name)
    # Record the start time and the full configuration alongside the run.
    wandb.config.current_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
    wandb.config.update(args)

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Options shared by every DataLoader. Two worker processes are used here; if
# the dataset triggers pickling errors on your machine, set num_workers to 0
# (load in the main process). On ubuntu/linux/colab it can be raised further
# when data loading is the bottleneck.
kwargs = dict(num_workers=2, pin_memory=True, drop_last=True)

{'batch_size': 32,
 'data_dir': 'data',
 'dataset': 'cpen455',
 'en_wandb': True,
 'load_params': None,
 'lr': 0.0002,
 'lr_decay': 0.999995,
 'max_epochs': 50,
 'nr_filters': 80,
 'nr_logistic_mix': 5,
 'nr_resnet': 5,
 'obs': (3, 32, 32),
 'sample_batch_size': 16,
 'sample_dir': 'samples_f80_l5_r5',
 'sampling_interval': 2,
 'save_dir': 'models_f80_l5_r5',
 'save_interval': 10,
 'seed': 1,
 'tag': 'Middle_f80_l5_r5'}


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mparshan-pjavanrood[0m ([33mparshan-pjavanrood-university-of-british-columbia[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [10]:
# set data
def _make_loader(ds):
    """Wrap a dataset in a shuffled DataLoader using the notebook-wide batch size and loader options."""
    return torch.utils.data.DataLoader(ds, batch_size=args.batch_size, shuffle=True, **kwargs)


if "mnist" in args.dataset:
    ds_transforms = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor(), rescaling, replicate_color_channel])
    train_loader = _make_loader(datasets.MNIST(args.data_dir, download=True, train=True, transform=ds_transforms))
    test_loader = _make_loader(datasets.MNIST(args.data_dir, train=False, transform=ds_transforms))
    # Only train/test splits are loaded here; reuse the test loader so the
    # training cell, which always evaluates on val_loader, does not hit a NameError.
    val_loader = test_loader

elif "cifar" in args.dataset:
    ds_transforms = transforms.Compose([transforms.ToTensor(), rescaling])
    cifar_classes = {"cifar10": datasets.CIFAR10, "cifar100": datasets.CIFAR100}
    if args.dataset not in cifar_classes:
        # Literal braces must be doubled inside a str.format template; the
        # unescaped form raised KeyError instead of this message.
        raise ValueError('{} dataset not in {{cifar10, cifar100}}'.format(args.dataset))
    cifar_cls = cifar_classes[args.dataset]
    train_loader = _make_loader(cifar_cls(args.data_dir, train=True, download=True, transform=ds_transforms))
    test_loader = _make_loader(cifar_cls(args.data_dir, train=False, transform=ds_transforms))
    # Same fallback as for MNIST: no separate validation split is loaded here.
    val_loader = test_loader

elif "cpen455" in args.dataset:
    ds_transforms = transforms.Compose([transforms.Resize((32, 32)), rescaling])
    # One loader per split of the course dataset.
    train_loader, test_loader, val_loader = (
        _make_loader(CPEN455Dataset(root_dir=args.data_dir, mode=split, transform=ds_transforms))
        for split in ('train', 'test', 'validation')
    )
else:
    raise ValueError('{} dataset not in {{mnist, cifar, cpen455}}'.format(args.dataset))

In [11]:
def train_or_test(model, data_loader, optimizer, loss_op, device, args, epoch, mode = 'training'):
    """Run one full pass over ``data_loader`` and log the average bits-per-dimension.

    Args:
        model: network called as ``model(model_input, label)``.
        data_loader: iterable yielding ``(model_input, label)`` pairs.
        optimizer: optimizer stepped after every batch; used only when
            ``mode == 'training'``.
        loss_op: callable ``loss_op(model_input, model_output)`` returning a
            scalar tensor.
        device: device the input batch is moved to.
        args: configuration object; ``batch_size``, ``obs`` and ``en_wandb``
            are read.
        epoch: current epoch index, logged next to the metric.
        mode: ``'training'`` puts the model in train mode and updates the
            weights; any other value runs a gradient-free evaluation pass. The
            string is also the prefix of the logged metric names.

    Returns:
        None. When ``args.en_wandb`` is set, ``"<mode>-Average-BPD"`` and
        ``"<mode>-epoch"`` are logged to the active W&B run.
    """
    logging.debug('mode: {}'.format(mode))
    is_training = mode == 'training'
    if is_training:
        model.train()
    else:
        model.eval()

    # Normaliser that turns a per-batch loss into bits per dimension
    # (presumably the loss is a summed negative log-likelihood in nats - confirm).
    # It uses the nominal batch size, i.e. it assumes every batch is full.
    deno =  args.batch_size * np.prod(args.obs) * np.log(2.)
    loss_tracker = mean_tracker()

    # Build autograd graphs only while training; evaluation passes do not call
    # backward(), so tracking gradients there only costs memory and time.
    with torch.set_grad_enabled(is_training):
        for batch_idx, item in enumerate(tqdm(data_loader)):
            logging.debug('batch_idx: {}'.format(batch_idx))
            model_input, label = item
            model_input = model_input.to(device)
            model_output = model(model_input, label)

            loss = loss_op(model_input, model_output)
            loss_tracker.update(loss.item()/deno)
            if is_training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    if args.en_wandb:
        # The run is expected to be initialised already by the setup code, so
        # no per-pass wandb.login() call is needed here.
        wandb.log({mode + "-Average-BPD" : loss_tracker.get_mean()})
        wandb.log({mode + "-epoch": epoch})


In [None]:
args.obs = (3, 32, 32)
input_channels = args.obs[0]

# Thin wrappers so the loop below is independent of the concrete loss/sampler.
loss_op   = lambda real, fake : discretized_mix_logistic_loss(real, fake)
sample_op = lambda x : sample_from_discretized_mix_logistic(x, args.nr_logistic_mix)

model = PixelCNN(nr_resnet=args.nr_resnet, nr_filters=args.nr_filters,
            input_channels=input_channels, nr_logistic_mix=args.nr_logistic_mix)
model = model.to(device)

if args.load_params:
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only runtime.
    model.load_state_dict(torch.load(args.load_params, map_location=device))
    print('model parameters loaded')

optimizer = optim.Adam(model.parameters(), lr=args.lr)
scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=args.lr_decay)

# Held-out splits evaluated after every training epoch, as (metric prefix, loader).
eval_passes = (('test', test_loader), ('val', val_loader))

for epoch in tqdm(range(args.max_epochs)):
    train_or_test(model = model,
                  data_loader = train_loader,
                  optimizer = optimizer,
                  loss_op = loss_op,
                  device = device,
                  args = args,
                  epoch = epoch,
                  mode = 'training')

    # decrease learning rate
    scheduler.step()

    for eval_mode, eval_loader in eval_passes:
        train_or_test(model = model,
                      data_loader = eval_loader,
                      optimizer = optimizer,
                      loss_op = loss_op,
                      device = device,
                      args = args,
                      epoch = epoch,
                      mode = eval_mode)

    if epoch % args.sampling_interval == 0:
        print('......sampling......')
        sampled_images = generation_evaluation.my_sample(model, args.sample_dir, args.sample_batch_size, args.obs, sample_op)

        # FID between the freshly generated samples and the test images.
        gen_data_dir = args.sample_dir
        ref_data_dir = args.data_dir +'/test'
        paths = [gen_data_dir, ref_data_dir]
        # Reset every time so a failed computation can neither leave the name
        # unbound nor silently re-log the value of an earlier epoch.
        fid_score = None
        try:
            fid_score = calculate_fid_given_paths(paths, 32, device, dims=192)
            print("Dimension {:d} works! fid score: {}".format(192, fid_score))
        except Exception as err:
            # Best effort: keep training, but report why FID failed.
            # KeyboardInterrupt is deliberately not caught.
            print("Dimension {:d} fails!".format(192))
            logging.warning('FID computation failed: %s', err)

        if args.en_wandb:
            # Log all sample grids (and the FID, when available) as one step.
            sample_result = {f"samples_{label}": wandb.Image(img, caption="epoch {} - label {}".format(epoch, label))
                             for label, img in sampled_images.items()}
            if fid_score is not None:
                sample_result["FID"] = fid_score
            wandb.log(sample_result)

    if (epoch + 1) % args.save_interval == 0:
        # Save under the configured checkpoint directory rather than a
        # hard-coded "models" folder, so runs with different save_dir
        # settings do not overwrite each other's checkpoints.
        os.makedirs(args.save_dir, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(args.save_dir, '{}_{}.pth'.format(model_name, epoch)))


  WeightNorm.apply(module, name, dim)
  0%|          | 0/50 [00:00<?, ?it/s]
  0%|          | 0/129 [00:00<?, ?it/s][A
  1%|          | 1/129 [00:26<55:39, 26.09s/it][A
  2%|▏         | 2/129 [00:26<23:27, 11.09s/it][A
  2%|▏         | 3/129 [00:34<20:21,  9.70s/it][A
  3%|▎         | 4/129 [00:35<12:50,  6.17s/it][A
  4%|▍         | 5/129 [00:44<14:51,  7.19s/it][A
  5%|▍         | 6/129 [00:45<10:15,  5.00s/it][A
  5%|▌         | 7/129 [00:54<12:47,  6.29s/it][A
  6%|▌         | 8/129 [00:55<09:23,  4.65s/it][A
  7%|▋         | 9/129 [01:04<12:05,  6.04s/it][A
  8%|▊         | 10/129 [01:05<08:49,  4.45s/it][A
  9%|▊         | 11/129 [01:14<11:42,  5.95s/it][A
  9%|▉         | 12/129 [01:15<08:31,  4.37s/it][A
 10%|█         | 13/129 [01:24<11:13,  5.80s/it][A
 11%|█         | 14/129 [01:25<08:17,  4.33s/it][A
 12%|█▏        | 15/129 [01:35<11:19,  5.96s/it][A
 12%|█▏        | 16/129 [01:36<08:35,  4.56s/it][A
 13%|█▎        | 17/129 [01:44<10:15,  5.50s/it][A
 14%|

In [None]:
# Run the project's classification evaluation script in a subprocess.
# NOTE(review): "-m test" and "-b 1" presumably select the test split and a
# batch size of 1 - confirm against the script's argument parser.
! python3 classification_evaluation.py -m test -b 1

In [7]:
# Run generation_evaluation.py as a standalone script; per the output below it
# goes through the class labels and reports the FID of the generated images
# against the reference images.
! python3 generation_evaluation.py

  WeightNorm.apply(module, name, dim)
Label: Class0
Label: Class1
Label: Class2
Label: Class3
#generated images: 100, #reference images: 519
Downloading: "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/pt_inception-2015-12-05-6726825d.pth
100% 91.2M/91.2M [00:00<00:00, 99.2MB/s]
100% 1/1 [00:00<00:00,  1.70it/s]
100% 5/5 [00:08<00:00,  1.65s/it]
Dimension 192 works! fid score: 25.811592364799804
Average fid score: 25.811592364799804


In [None]:
# Launch the W&B sweep script; per the log below, the sweep agent starts runs
# with chosen values of nr_filters, nr_logistic_mix and nr_resnet.
! python3 wandb_sweep.py

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mparshan-pjavanrood[0m ([33mparshan-pjavanrood-university-of-british-columbia[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Agent Starting Run: t984sodv with config:
[34m[1mwandb[0m: 	nr_filters: 80
[34m[1mwandb[0m: 	nr_logistic_mix: 5
[34m[1mwandb[0m: 	nr_resnet: 5
[34m[1mwandb[0m: Tracking run with wandb version 0.19.9
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/content/drive/MyDrive/cpen455-project/wandb/run-20250406_053238-t984sodv[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33msummer-sweep-13[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/parshan-pjavanrood-university-of-british-columbia/CPEN455HW[0m
[34m[1mwandb[0m: 🧹 View sweep a