In [1]:

from BaseVAEs.models.disent.model.ae import EncoderConv64, DecoderConv64, AutoEncoder
from BaseVAEs.models.disent.frameworks.vae.weaklysupervised import AdaVae, AdaCatVae
from BaseVAEs.models.disent.frameworks.vae.unsupervised import BetaVae

In [2]:
import torch
import numpy as np
import time
import matplotlib
matplotlib.use('Agg')
import sys
import os
from torch.utils.tensorboard import SummaryWriter
from rtpt.rtpt import RTPT
from torch.optim import lr_scheduler
from torch.optim import Adam

In [3]:
import BaseVAEs.utils_disent as utils
import BaseVAEs.data as data
from BaseVAEs.args import parse_args_as_dict


In [7]:
# Command-line style options for this run, written as (flag, value) pairs and
# flattened into the argv-like token list handed to the argument parser.
sys_argv = [
    token
    for flag_and_value in (
        ("--save-step", "20"),
        ("--print-step", "1"),
        ("--learning-rate", "0.0001"),
        ("--batch-size", "128"),
        ("--epochs", "2000"),
        ("--exp-name", "unsup-betavae-0-ecr"),
        ("--n-groups", "3"),
        ("--n-protos", "6"),
        ("--seed", "0"),
        ("--dataset", "ecr"),
        ("--initials", "WS"),
        ("--lr-scheduler-warmup-steps", "1000"),
        ("--data-dir", "Data"),
        ("--results-dir", "experiments/BaseVAEs/runs/"),
        ("--n-workers", "0"),
    )
    for token in flag_and_value
]

# Parse the token list into the run configuration used by the cells below.
config = parse_args_as_dict(sys_argv)

Device name: cuda:0


In [10]:
# Show the resolved configuration: a header line followed by one
# tab-indented "key: value" line per entry.
print("Config:", *["\t{}: {}".format(k, v) for k, v in config.items()], sep="\n")

Config:
	device: cuda:0
	device_ids: [0]
	save_step: 20
	print_step: 1
	display_step: 1
	lambda_recon_proto: 1.0
	train_protos: False
	freeze_enc: False
	learning_rate: 0.0001
	lr_scheduler: False
	lr_scheduler_warmup_steps: 1000
	batch_size: 128
	epochs: 2000
	n_workers: 0
	prototype_vectors: [2, 2]
	n_groups: 3
	n_protos: 6
	proto_dim: 32
	extra_mlp_dim: 4
	extra_softmax: False
	multiheads: False
	beta: 1.0
	lin_enc_size: 512
	temperature: 1.0
	exp_name: unsup-betavae-0-ecr
	results_dir: experiments/BaseVAEs/runs/unsup-betavae-0-ecr
	model_dir: experiments/BaseVAEs/runs/unsup-betavae-0-ecr\states
	img_dir: experiments/BaseVAEs/runs/unsup-betavae-0-ecr\imgs
	data_dir: Data
	seed: 0
	dataset: ecr
	initials: WS
	fpath_load_pretrained: None
	ckpt_fp: None
	test: False
	img_shape: (3, 64, 64)


In [5]:

def train(model, data_loader, log_samples, optimizer, scheduler, writer, config):
    """Run the training loop for ``config['epochs']`` epochs.

    Each batch is unpacked into an image pair plus three label entries; the two
    image tensors are moved to ``config['device']``, concatenated along dim 0 and
    passed to ``model.compute_training_loss``. The returned ``'train_loss'`` is
    back-propagated and ``optimizer`` is stepped.

    Learning-rate handling:
      * During the first ``config['lr_scheduler_warmup_steps']`` optimisation
        steps the learning rate of the first parameter group is ramped linearly
        up to ``config['learning_rate']``.
      * Once warm-up is over, and only if ``config['lr_scheduler']`` is truthy
        and a scheduler was supplied, ``scheduler.step()`` is called once per batch.

    Per-epoch averages of the losses are written to ``writer`` and printed, and a
    checkpoint (model/optimizer state, epoch, config) is stored in
    ``config['model_dir']`` on the first epoch, every ``config['save_step']``
    epochs and on the last epoch.

    Args:
        model: Object exposing ``compute_training_loss(batch, batch_idx=...)``
            and ``state_dict()``.
        data_loader: Sized iterable of 4-tuples whose first element holds the
            two image tensors at index 0 and 1.
        log_samples: Samples forwarded to ``utils.plot_examples`` at save time.
        optimizer: Optimizer whose ``param_groups[0]['lr']`` is driven during warm-up.
        scheduler: Learning-rate scheduler, or ``None``.
        writer: Object exposing ``add_scalar`` (e.g. a TensorBoard writer).
        config: Mapping with the keys read below (``epochs``, ``device``,
            ``learning_rate``, ``lr_scheduler``, ``lr_scheduler_warmup_steps``,
            ``display_step``, ``print_step``, ``save_step``, ``model_dir``,
            ``img_dir``, ``initials``).

    Raises:
        ValueError: If ``data_loader`` contains no batches.
    """
    num_batches = len(data_loader)
    if num_batches == 0:
        # Without this guard the epoch averages below would divide by zero.
        raise ValueError('data_loader yields no batches; nothing to train on')

    rtpt = RTPT(name_initials=config['initials'], experiment_name='XIC_PrototypeDL', max_iterations=config['epochs'])
    rtpt.start()

    warmup_total = config['lr_scheduler_warmup_steps']
    use_scheduler = bool(config['lr_scheduler']) and scheduler is not None

    # Number of optimisation steps started so far (counted across epochs).
    global_step = 0

    for e in range(0, config['epochs']):
        start = time.time()
        loss_dict = {'z_recon_loss': 0, 'loss': 0, 'kld': 0, 'elbo': 0}

        for i, batch in enumerate(data_loader):

            # manual lr warmup: linear ramp, reaching the target lr on the last warm-up step
            if global_step < warmup_total:
                optimizer.param_groups[0]['lr'] = config['learning_rate'] * (global_step + 1) / warmup_total
            global_step += 1

            # the label entries of the batch are not needed for this unsupervised objective
            imgs, _labels_one_hot, _labels_id, _shared_labels = batch

            imgs0 = imgs[0].to(config['device'])
            imgs1 = imgs[1].to(config['device'])
            imgs = torch.cat((imgs0, imgs1), dim=0)

            # from disent repo: x_targ is if augmentation is applied, otherwise x_targ is x
            batch = {'x': (imgs,), 'x_targ': (imgs,)}
            batch_loss_dict = model.compute_training_loss(batch, batch_idx=i)

            loss = batch_loss_dict['train_loss']
            recon_loss = batch_loss_dict['recon_loss']
            kl_reg_loss = batch_loss_dict['kl_reg_loss']
            elbo = batch_loss_dict['elbo']

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Warm-up is measured in optimisation steps, so the hand-over to the
            # scheduler must be decided on the step counter, not the epoch index.
            if use_scheduler and global_step > warmup_total:
                scheduler.step()

            loss_dict['z_recon_loss'] += recon_loss.item()
            loss_dict['kld'] += kl_reg_loss.item()
            loss_dict['loss'] += loss.item()
            loss_dict['elbo'] += elbo.item()

        # turn the running sums into per-batch averages for this epoch
        for key in loss_dict.keys():
            loss_dict[key] /= num_batches

        rtpt.step(subtitle=f'loss={loss_dict["loss"]:2.2f}')

        is_last_epoch = e == config['epochs'] - 1

        if (e + 1) % config['display_step'] == 0 or is_last_epoch:
            cur_lr = optimizer.param_groups[0]["lr"]
            writer.add_scalar("lr", cur_lr, global_step=e)
            for key in loss_dict.keys():
                writer.add_scalar(f'train/{key}', loss_dict[key], global_step=e)

        if (e + 1) % config['print_step'] == 0 or is_last_epoch:
            # first line: loss of the last batch of the epoch; second line: epoch averages
            print(f'epoch {e} - loss {loss.item():2.4f} - time/epoch {(time.time() - start):2.2f}')
            loss_summary = ''
            for key in loss_dict.keys():
                loss_summary += f'{key} {loss_dict[key]:2.4f} '
            print(loss_summary)

        if (e + 1) % config['save_step'] == 0 or is_last_epoch or e == 0:
            state = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'ep': e,
                'config': config
            }
            # make sure the checkpoint directory exists before writing into it
            os.makedirs(config['model_dir'], exist_ok=True)
            torch.save(state, os.path.join(config['model_dir'], '%05d.pth' % (e)))

            # plot a few samples with their reconstructions
            utils.plot_examples(log_samples, model, writer, config, step=e)

            print(f'SAVED - epoch {e} - imgs @ {config["img_dir"]} - model @ {config["model_dir"]}')


def main(config):
    """Set up data, logging, model and optimisation, then start training.

    Builds the training data loader and the test samples via the ``data``
    module, creates a TensorBoard ``SummaryWriter`` in ``config['results_dir']``,
    instantiates a ``BetaVae`` with a 64x64 convolutional encoder/decoder whose
    latent size is ``config['n_groups']``, and hands everything to ``train``.
    The writer is closed when training ends, including when it is interrupted
    or fails, so that pending events are flushed.

    Args:
        config: Run configuration mapping; reads ``results_dir``, ``n_groups``,
            ``device``, ``learning_rate``, ``lr_scheduler``, ``epochs`` and
            ``lr_scheduler_warmup_steps`` here and is passed on to the data
            helpers and to ``train``.
    """
    # get train data
    _data_loader = data.get_dataloader(config)

    # get test set samples
    test_set = data.get_test_set(_data_loader, config)

    # create tb writer
    writer = SummaryWriter(log_dir=config['results_dir'])

    try:
        # model setup
        # NOTE(review): beta is fixed to 4 here while config carries its own 'beta'
        # entry -- confirm which value is intended.
        # NOTE(review): the optimizer factory (lr=1e-3) is handed to the framework;
        # the optimizer actually stepped by train() is the one created below.
        _model = BetaVae(make_optimizer_fn=lambda params: Adam(params, lr=1e-3),
                     make_model_fn=lambda: AutoEncoder(
                         encoder=EncoderConv64(x_shape=(3, 64, 64), z_size=config['n_groups'], z_multiplier=2),
                         decoder=DecoderConv64(x_shape=(3, 64, 64), z_size=config['n_groups']),
                     ),
                     cfg=BetaVae.cfg(beta=4))

        _model = _model.to(config['device'])

        # optimizer setup
        optimizer = torch.optim.Adam(_model.parameters(), lr=config['learning_rate'])

        # learning rate scheduler (only built when enabled in the config)
        scheduler = None
        if config['lr_scheduler']:
            # TODO: try LambdaLR
            num_steps = len(_data_loader) * config['epochs']
            num_steps += config['lr_scheduler_warmup_steps']
            scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps, eta_min=2e-5)

        # start training
        train(_model, _data_loader, test_set, optimizer, scheduler, writer, config)
    finally:
        # flush and release the event file even on KeyboardInterrupt / errors
        writer.close()


In [8]:
# Kick off the full training run with the configuration parsed above.
main(config)

Getting dataloader for ecr
Loading data...
Dataset: ecr
root path: Data\ECR
root path: c:\Users\yuviu\Desktop\Uni Work\Thesis\XIConceptLearning\experiments\Data\ECR
Loading Data\ECR\train_ecr_pairs.npy
 Config num_workers: 0
Loading test set...
y_set shape: (32, 10)
	y_set: [[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0], [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0], [0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.

  x_set = torch.Tensor(x_set)


epoch 0 - loss 403.2914 - time/epoch 17.14
z_recon_loss 454.2154 loss 454.2250 kld 0.0096 elbo -454.2178 
SAVED - epoch 0 - imgs @ experiments/BaseVAEs/runs/unsup-betavae-0-ecr\imgs - model @ experiments/BaseVAEs/runs/unsup-betavae-0-ecr\states
epoch 1 - loss 487.9590 - time/epoch 6.55
z_recon_loss 455.9960 loss 456.0010 kld 0.0050 elbo -455.9972 
epoch 2 - loss 413.6573 - time/epoch 5.81
z_recon_loss 453.8225 loss 453.8234 kld 0.0009 elbo -453.8227 
epoch 3 - loss 320.8172 - time/epoch 6.39
z_recon_loss 450.7908 loss 450.7927 kld 0.0020 elbo -450.7912 
epoch 4 - loss 513.6982 - time/epoch 6.13
z_recon_loss 452.5975 loss 452.6313 kld 0.0338 elbo -452.6060 
epoch 5 - loss 356.8102 - time/epoch 5.91
z_recon_loss 441.8451 loss 442.1337 kld 0.2886 elbo -441.9172 
epoch 6 - loss 501.0178 - time/epoch 6.03
z_recon_loss 421.7423 loss 427.2255 kld 5.4832 elbo -423.1131 
epoch 7 - loss 344.7471 - time/epoch 7.48
z_recon_loss 391.8606 loss 401.1198 kld 9.2592 elbo -394.1754 
epoch 8 - loss 398.3

KeyboardInterrupt: 

In [19]:
def load_pretrained(model, ckpt):
    """Restore a model from a checkpoint mapping and return it.

    The weights come from ``ckpt['model']`` (via ``model.load_state_dict``);
    afterwards ``ckpt['model_misc']['prototypes']`` is stored on the model as
    ``proto_dict`` and ``ckpt['model_misc']['softmax_temp']`` as ``softmax_temp``.

    Args:
        model: Object exposing ``load_state_dict``; it is modified in place.
        ckpt: Mapping with a ``'model'`` entry and a ``'model_misc'`` entry.

    Returns:
        The same ``model`` instance that was passed in.
    """
    model.load_state_dict(ckpt['model'])

    # extra, non-state-dict attributes: (attribute on the model, key in the checkpoint)
    misc = ckpt['model_misc']
    for attr_name, misc_key in (('proto_dict', 'prototypes'), ('softmax_temp', 'softmax_temp')):
        setattr(model, attr_name, misc[misc_key])

    return model

from ProtoLearning.models.icsn import iCSN



In [5]:
import os
print(os.getcwd())

# Move into the `experiments` sub-directory of the current working directory.
# The guard keeps the cell safe to re-run: a second plain chdir('experiments')
# would look for experiments/experiments and fail.
if os.path.basename(os.getcwd()) != 'experiments':
    os.chdir('experiments')
print(os.getcwd())

c:\Users\yuviu\Desktop\Uni Work\Thesis\XIConceptLearning
c:\Users\yuviu\Desktop\Uni Work\Thesis\XIConceptLearning\experiments
