In [None]:
import math
import os
import pickle as pkl
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torchvision.utils as tv_utils

from oil.architectures.img_gen.resnetgan import Generator
from upcycle import cuda
from upcycle.plotting.credible_regions import get_arm, draw_arm_comparison

# Global plotting theme: white background with grid lines, fonts scaled up 25%.
sns.set(font_scale=1.25, style='whitegrid')

In [None]:
# Root directory that holds one sub-directory per experiment arm.
exp_dir = '../data/experiments/image_generation'

# Plot label -> arm directory name (a sub-directory of `exp_dir`).
arms = dict([
    ('CIFAR-10', 'sngan_cifar10_v0.0.5'),
    ('CIFAR-100', 'sngan_cifar100_v0.0.5'),
])

# Window argument forwarded to `get_arm`.
# NOTE(review): presumably a smoothing width where 1 means no smoothing -- confirm.
window = 1

In [None]:
# Sample-quality curves versus training step: FID in the left panel, IS in the right panel.
fig = plt.figure(figsize=(10, 4))
panels = []
for position, metric in enumerate(('fid_score', 'is_score'), start=1):
    # Each panel is created immediately before it is drawn into.
    panel = fig.add_subplot(1, 2, position)
    panels.append(draw_arm_comparison(panel, exp_dir, arms, 'train_metrics', 'step', metric))
ax_1, ax_2 = panels

# Both panels plot the same `arms`, so a single legend suffices.
ax_2.legend()
plt.tight_layout()

In [None]:
# Training losses versus training step: discriminator loss (left) and generator loss (right).
loss_metrics = ('disc_loss', 'gen_loss')
fig = plt.figure(figsize=(10, 4))
panels = []
for position, metric in enumerate(loss_metrics, start=1):
    # Each panel is created immediately before it is drawn into.
    panel = fig.add_subplot(1, len(loss_metrics), position)
    panels.append(draw_arm_comparison(panel, exp_dir, arms, 'train_metrics', 'step', metric))

# All panels plot the same `arms`, so one legend (on the last panel) suffices.
panels[-1].legend()
plt.tight_layout()

In [None]:
# Build the SNGAN generator, load its checkpoint for `epoch`, and report FID / IS at that epoch.
gen_kwargs = dict(img_channels=3, z_dim=128, k=256)
gen_net = cuda.try_cuda(Generator(**gen_kwargs))

epoch = 500
arm_name = arms['CIFAR-10']  # 'sngan_cifar10_v0.0.5'
arm_dir = Path(exp_dir) / arm_name

# rglob order is filesystem-dependent; sort so the same checkpoint is picked on every re-run
# when several runs under this arm saved the same epoch.
ckpt_files = sorted(arm_dir.rglob(f'*generator_{epoch}.ckpt'))
if not ckpt_files:
    raise FileNotFoundError(f'No generator checkpoint for epoch {epoch} under {arm_dir}')
ckpt_path = ckpt_files[0]
print(f'Loading {ckpt_path}')

# SECURITY: pickle.load can execute arbitrary code -- only load checkpoints you produced yourself.
with open(ckpt_path, 'rb') as f:
    state_dict = pkl.load(f)
gen_net.load_state_dict(state_dict)

# NOTE(review): the second element returned by get_arm is treated as the metric curve, and
# `epoch - 1` assumes one row per epoch starting at epoch 1 -- confirm. If the arm holds several
# runs, the curve may be aggregated while the checkpoint above comes from a single run.
_, fid_curve, _, _ = get_arm(exp_dir, arm_name, 'train_metrics.csv', 'step', 'fid_score', window)
_, is_curve, _, _ = get_arm(exp_dir, arm_name, 'train_metrics.csv', 'step', 'is_score', window)
print(f'FID Score: {fid_curve[epoch - 1]:0.4f}')
print(f'IS Score: {is_curve[epoch - 1]:0.4f}')

In [None]:
# Draw a num_rows x num_rows grid of samples from the loaded generator.
num_rows = 8

# Fixed seed so the displayed grid is reproducible across Restart & Run All.
# NOTE(review): assumes gen_net.sample draws its latent noise from torch's RNG -- confirm.
torch.manual_seed(0)
with torch.no_grad():
    fake_images = gen_net.sample(num_rows ** 2).cpu()

img_grid = tv_utils.make_grid(fake_images, nrow=num_rows, normalize=True)
# make_grid yields a channels-first (C, H, W) tensor; imshow expects an (H, W, C) array.
img_grid = img_grid.permute(1, 2, 0).numpy()

fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(img_grid)
ax.set_axis_off()
ax.set_title(f'{arm_name} samples, epoch {epoch}');