# Punk Generation with a Labeled Conditional VAE (CVAE)

In [1]:
import os
import random
import sys

import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import transforms

%matplotlib inline


In [2]:
# Load pallets library
# Make the local `pallets` package (one directory above the notebook) importable.
# Guarded so re-running this cell does not keep appending duplicates to sys.path.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), '..'))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)
from pallets import images as I, datasets as DS, models as M, logging as L


# Settings

In [3]:
# Name under which the model and its loss histories are saved with M.save / loaded with M.load.
SAVE_NAME = "cvae.naive.labeled"

In [4]:
# Hardware / logging
USE_GPU = True
# NOTE(review): LOG_LEVEL is not used anywhere in this notebook — L.init_logger is
# called without a level. Wire it through if init_logger accepts one.
LOG_LEVEL = 'INFO'

# Data / training
TEST_SIZE = 1000  # passed as test_size when the dataset is built
EPOCHS = 50
LR = 1e-03
BATCH_SIZE = 32

# Reproducibility: seed torch (model init, SubsetRandomSampler shuffling,
# torch.randn when sampling) and Python's `random` (used to pick labels).
SEED = 0
torch.manual_seed(SEED)
random.seed(SEED)


In [5]:
# To GPU, or not to GPU
# Select the torch device used for training and sampling. require_gpu=USE_GPU is
# passed through (presumably errors or warns when no GPU is present — TODO confirm
# in pallets.models).
device = M.get_device(require_gpu=USE_GPU)

# Logging
# Initialise the `pallets` logger. It is the cell's last expression, so its repr is shown.
# NOTE(review): LOG_LEVEL from the settings cell is not passed here.
L.init_logger(notebook=True)


<Logger pallets (INFO)>

## Prepare Datasets

In [6]:
# Build the one-hot dataset once and cache it to disk; later runs load the cache.
# NOTE(review): the cache is not keyed on TEST_SIZE. Delete the file after changing
# TEST_SIZE, or the previously built split will presumably be reused.
DATASET_CACHE = '../artifacts/onehot_ds_cpu.pt'

if os.path.exists(DATASET_CACHE):
    # The cache is a pickled dataset *object*, not a tensor state dict, so
    # weights_only=False is required (torch>=2.6 defaults to True and rejects it).
    # This unpickles arbitrary objects — only load a file you produced yourself.
    dataset = torch.load(DATASET_CACHE, weights_only=False)
else:
    all_colors = I.get_punk_colors()
    mapper = DS.ColorOneHotMapper(all_colors)
    dataset = DS.OneHotCPunksDataset(mapper, test_size=TEST_SIZE)
    # Alternative (unused) constructor that also takes `device`:
    # dataset = DS.FastOneHotCPunksDataset(device, mapper, test_size=TEST_SIZE)
    os.makedirs(os.path.dirname(DATASET_CACHE), exist_ok=True)
    torch.save(dataset, DATASET_CACHE)

In [7]:
# Each split draws only its own indices (in random order) from the shared dataset.
train_sampler, test_sampler = (
    SubsetRandomSampler(dataset.train_idx),
    SubsetRandomSampler(dataset.test_idx),
)

train_loader, test_loader = (
    DataLoader(dataset, batch_size=BATCH_SIZE, sampler=split_sampler)
    for split_sampler in (train_sampler, test_sampler)
)


## Labeled Naive CVAE

In [8]:
# Model dimensions.
# NOTE(review): 24 * 24 * 222 presumably means a 24x24 image with 222 one-hot colour
# channels, and 92 the number of label classes. Confirm against the dataset/mapper
# rather than hard-coding.
input_dim, hidden_dim, latent_dim, classes_dim = 24 * 24 * 222, 576, 32, 92

model = M.cvae.LabeledCVAE(input_dim, hidden_dim, latent_dim, classes_dim)
criterion = M.cvae.Loss()


In [9]:

# Train the labeled CVAE; M.cvae.train returns the train and test loss histories.
# NOTE(review): if this cell is interrupted (KeyboardInterrupt), train_losses and
# test_losses are never bound, and the save cell raises NameError.
train_losses, test_losses = M.cvae.train(
    device,
    model,
    criterion,
    train_loader,
    test_loader,
    learn_rate=LR,
    epochs=EPOCHS,
    conditional_loss=True,
)


INFO | model: pallets.models.cvae.LabeledCVAE
INFO | criterion: pallets.models.cvae.Loss
INFO | learn rate: 0.001
INFO | epochs: 50
INFO | epoch 1 (  0%) loss: 1423525.750000
INFO | epoch 1 ( 35%) loss: 89680.947391
INFO | epoch 1 ( 70%) loss: 50595.481766
INFO | epoch 1 (100%) loss: 38352.575289
INFO | epoch 1 (test) loss: 6885.266197
INFO | epoch 2 (  0%) loss: 7326.567383


KeyboardInterrupt: 

In [None]:
# Persist the model together with both loss histories under SAVE_NAME.
# NOTE(review): needs train_losses / test_losses, which are only bound when the
# training cell runs to completion.
M.save(SAVE_NAME, model, train_losses, test_losses)

# Model Output to Image

In [None]:
# Reload a previously saved model instead of retraining (currently disabled).
# NOTE(review): the recorded AttributeError shows the saved file references a class
# named `LabeledNaiveCVAE`, which pallets.models.cvae no longer defines. Re-save with
# the current LabeledCVAE before re-enabling these lines.
# model, train_losses, test_losses = M.load(SAVE_NAME, device)
# model = model.to(device)


AttributeError: Can't get attribute 'LabeledNaiveCVAE' on <module 'pallets.models.cvae' from '/home/jmsdnns/ML/pallets/pallets/models/cvae.py'>

In [None]:
import random
import json

LABELS_PATH = "../artifacts/pallets_labels.json"

# Use a context manager so the file handle is closed deterministically. JSON is
# UTF-8 by spec, so don't rely on the platform's default encoding.
with open(LABELS_PATH, encoding="utf-8") as labels_file:
    raw_labels = json.load(labels_file)

# Label names, in the key order of record "0".
label_keys = list(raw_labels["0"])

def rand_label():
    """Pick a uniformly random label vector from the dataset.

    Returns:
        (features, enabled_names): the label tensor moved to ``device``, and the
        names from ``label_keys`` whose entry in that vector equals 1.
    """
    # randrange is the idiomatic uniform integer draw in [0, n).
    label_idx = random.randrange(len(dataset._labels))
    features = dataset._labels[label_idx]
    # NOTE(review): assumes the label vector's order matches label_keys (the key
    # order of record "0" in the labels JSON) — confirm.
    enabled_names = [name for name, flag in zip(label_keys, features) if flag.item() == 1]
    return features.to(device), enabled_names


label_keys

In [None]:
# Generate new image

def rand_punk():
    """Decode a random latent vector, conditioned on a random label, into an image.

    Prints the enabled feature names and the decoded image's shape.

    Returns:
        The image produced by ``DS.one_hot_to_rgba`` from the decoder output (CPU tensor input).
    """
    z = torch.randn(1, latent_dim, device=device)
    features, names = rand_label()
    print(f"Features: {', '.join(names)}")

    model.eval()
    with torch.no_grad():
        generated_image = model.decoder(z, features.unsqueeze(0))

    # Move to CPU before converting, as reconstruct_punk does, so one_hot_to_rgba
    # sees a CPU tensor in both code paths.
    decoded_one_hot = generated_image[0].cpu()
    # Drop the last classes_dim entries (presumably the appended label vector —
    # confirm in LabeledCVAE) and fold the rest back into (channels, 24, 24).
    # -1 lets view infer the channel count instead of hard-coding it.
    decoded_one_hot = decoded_one_hot[:-classes_dim].view(-1, 24, 24)
    decoded = DS.one_hot_to_rgba(decoded_one_hot, dataset.mapper)
    print(f"Shape: {decoded.shape}")
    return decoded


### Five Random Samples

In [None]:
# Sample and display several random punks. One loop replaces five copy-pasted
# cells, so the display code lives in a single place.
N_RANDOM_PUNKS = 5

for _ in range(N_RANDOM_PUNKS):
    decoded = rand_punk()
    fig, ax = plt.subplots()
    ax.imshow(transforms.functional.to_pil_image(decoded))
    ax.axis('off')
    plt.show()


## Reconstruction

In [None]:
def reconstruct_punk(idx):
    """Encode and decode punk ``idx``, conditioned on its own labels.

    Args:
        idx: punk index, used both for ``I.get_punk_tensor`` and ``dataset[idx]``.

    Returns:
        (punk, recon_punk, enabled_features): the original image tensor, the
        reconstruction converted back with ``DS.one_hot_to_rgba``, and the label
        names set to 1 for this punk.
    """
    punk = I.get_punk_tensor(idx)
    punk_batch = DS.rgba_to_one_hot(punk, dataset.mapper).unsqueeze(0).to(device)

    # NOTE(review): assumes dataset index `idx` refers to the same punk as
    # I.get_punk_tensor(idx) — confirm.
    _, labels = dataset[idx]
    label_batch = labels.unsqueeze(0).to(device)
    enabled_features = [k for k, v in zip(label_keys, labels) if v.item() == 1]

    model.eval()
    with torch.no_grad():
        # Call the module rather than .forward() so any registered hooks run.
        reconstructed, _mu, _logvar = model(punk_batch, label_batch)

    recon_punk = reconstructed[0].cpu()
    # Drop the trailing classes_dim entries and infer the channel count via -1.
    recon_punk = recon_punk[:-classes_dim].view(-1, 24, 24)
    recon_punk = DS.one_hot_to_rgba(recon_punk, dataset.mapper)

    return punk, recon_punk, enabled_features


def draw_two(img1, img2):
    """Show two images side by side in one 8x4-inch figure, with axes hidden."""
    scale = 2
    _, axes = plt.subplots(1, 2, figsize=(4 * scale, 2 * scale))
    for ax, img in zip(axes, (img1, img2)):
        ax.imshow(transforms.functional.to_pil_image(img))
        ax.axis('off')
    plt.show()


### Five Reconstructions

In [None]:
# Punk indices to reconstruct. One loop replaces five copy-pasted cells.
RECON_INDICES = [1000, 2001, 5000, 8000, 1337]

for punk_idx in RECON_INDICES:
    punk, recon_punk, features = reconstruct_punk(punk_idx)
    print(f"Features: {', '.join(features)}")
    draw_two(punk, recon_punk)

# Train & Test Loss

In [None]:
# Plot the loss histories returned by training.
fig, ax = plt.subplots()
ax.set_title("Train & Test loss")
ax.plot(train_losses, label='train loss')
ax.plot(test_losses, label='test loss')
ax.set_ylabel('loss')
ax.legend()
# Actually call show(): a bare `plt.show` only displays the function's repr.
plt.show()


In [None]:
# Display the raw training-loss history returned by M.cvae.train.
train_losses


In [None]:
# Display the raw test-loss history returned by M.cvae.train.
test_losses
