# Train a JointVAE model

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import numpy as np
import torch

use_cuda = torch.cuda.is_available()
use_cuda

True

In [3]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # change to your device

#### Prepare data list

In [4]:
!ls data

dress_dresslen_train_test_splits.json	loadable_women_primary_dress.csv
dress_sleeve_train_test_splits.json	rayimages.ipynb
dress_sleevelen_train_test_splits.json


#### dress sleeve data

In [5]:
import json
with open("./data/dress_sleeve_train_test_splits.json", "r") as infile:
    data_dict = json.load(infile)
    
data_dict.keys()

dict_keys(['X_train_1', 'y_train_1', 'X_test_1', 'y_test_1', 'X_train_2', 'y_train_2', 'X_test_2', 'y_test_2', 'X_train_3', 'y_train_3', 'X_test_3', 'y_test_3', 'X_train_4', 'y_train_4', 'X_test_4', 'y_test_4', 'X_train_5', 'y_train_5', 'X_test_5', 'y_test_5', 'X_train_6', 'y_train_6', 'X_test_6', 'y_test_6', 'X_train_7', 'y_train_7', 'X_test_7', 'y_test_7', 'X_train_8', 'y_train_8', 'X_test_8', 'y_test_8', 'X_train_9', 'y_train_9', 'X_test_9', 'y_test_9', 'X_train_10', 'y_train_10', 'X_test_10', 'y_test_10'])

### Modify paths for home

In [6]:
!ls /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810

1  2  3  4  5  6  7  8	9


In [7]:
image_paths_train = []
image_paths_test = []

#root_data_dir = "/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables"
root_data_dir = "/workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810"

for key, val in data_dict.items():
    if 'X_train' in key:
        image_paths_train.extend([root_data_dir + imgpath for imgpath in val])
    elif 'X_test' in key:
        image_paths_test.extend([root_data_dir + imgpath for imgpath in val])

#### Get ALL the filenames actually there

In [8]:
import glob

all_filenames = []
for filename in glob.iglob(root_data_dir + '**/*/*/*', recursive=True):
     all_filenames.append(filename)

all_filenames[:5]

['/workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/5/5/5594593_9244920.jpg',
 '/workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/5/5/5508248_9177849.jpg',
 '/workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/5/5/5520209_8572840.jpg',
 '/workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/5/5/5524461_9299582.jpg',
 '/workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/5/5/5548707_8853300.jpg']

#### Resolve conflicts

In [9]:
print(len(image_paths_train), len(image_paths_test))

207603 23067


In [10]:
both = image_paths_train + image_paths_test
print(len(both))

230670


In [11]:
numpy_diff1=np.setdiff1d(both,all_filenames)
#diffs = list(set(both) ^ set(all_filenames))

In [12]:
len(numpy_diff1)

659

In [13]:
print(numpy_diff1[0] in both)
print(numpy_diff1[0] in all_filenames)

True
False


In [14]:
def filter_list(full_list, excludes):
    s = set(excludes)
    return (x for x in full_list if x not in s)
#filtered_list = list(filter_list(full_list, excludes))

In [15]:
cleaned = list(filter_list(both, numpy_diff1))

In [16]:
print(len(cleaned))

224080


In [17]:
image_paths_train = cleaned[:190000]
image_paths_test = cleaned[190000:]

print(f"Number of train image paths: {len(image_paths_train):,d}")
print(f"Number of test image paths: {len(image_paths_test):,d}")
print()
print("Sample paths:")
print(image_paths_train[0])
print(image_paths_train[-1])
print(image_paths_test[0])
print(image_paths_test[-1])

Number of train image paths: 190,000
Number of test image paths: 34,080

Sample paths:
/workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/8/2893552_3773662.jpg
/workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/6/2694064_3474675.jpg
/workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/9/2980076_3864391.jpg
/workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/3/0/3058817_8191784.jpg


### Not sure this is actually bad data - maybe something wrong with path?

In [18]:
#!ls /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/2/

#### loadable dresses data

### Create data loaders

In [20]:
from torchvision import transforms
from utils.dataloader_tools import get_imagelist_dataloader, ImageListDataset

BATCH_SIZE = 512

composed = transforms.Compose([transforms.CenterCrop((256,256)),transforms.Resize((256,256)),transforms.ToTensor()])

# convert rgb is for the cv2 loaded images that I've got in this dir
train_dataset = ImageListDataset(image_paths_train, cut_from='top', cut_amount=256, transform=composed, 
                                 convert_rgb=False, error_handling=True)
test_dataset = ImageListDataset(image_paths_test, cut_from='top', cut_amount=256, transform=composed, 
                                convert_rgb=False, error_handling=True)

train_loader = get_imagelist_dataloader(batch_size=BATCH_SIZE, dataset_object=train_dataset)
test_loader = get_imagelist_dataloader(batch_size=BATCH_SIZE, dataset_object=test_dataset)

### Define latent distribution of the model

In [21]:
# Latent distribution will be joint distribution of 10 gaussian normal distributions
# and one 10 dimensional Gumbel Softmax distribution
latent_spec = {'cont': 20, 'disc': [10, 10, 10]}

### Build a model

In [22]:
from jointvae.models_256_convjump2bn import VAE

model = VAE(latent_spec=latent_spec, hidden_dim=1024, img_size=(3, 256, 256), use_cuda=use_cuda)

In [23]:
#print(model)

### Train the model

In [24]:
from torch import optim

# Build optimizer
optimizer = optim.Adam(model.parameters(), lr=8e-4, amsgrad=True) # added amsgrad # orig lr 5e-4

In [25]:
from jointvae.training import Trainer
#from jointvae.training_debug import Trainer

# Define the capacities
# Continuous channels
cont_capacity = [0.0, 7, 40000, 36.0]  # Starting at a capacity of 0.0, increase this to 5.0
                                         # over 25000 iterations with a gamma of 30.0
# Discrete channels
disc_capacity = [0.0, 7, 40000, 36.0]  # Starting at a capacity of 0.0, increase this to 5.0
                                         # over 25000 iterations with a gamma of 30.0

# Build a trainer
trainer = Trainer(model, optimizer,
                  cont_capacity=cont_capacity,
                  disc_capacity=disc_capacity,
                 use_cuda=use_cuda)

#### Initialize visualizer

In [26]:
#from viz.visualize import Visualizer
from viz.visualize import Visualizer

viz = Visualizer(model)
viz.save_images = False # needed to add this so it returns a tensor

In [None]:
# Note this should be at least 100 epochs for proper training or more but can be less to demo

trainer.train(train_loader, epochs=152, save_training_gif=None)

0/190000	Loss: 45369.492
25600/190000	Loss: 36886.453


In [28]:
#!ls /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/1/9/

In [None]:
print('Done Training')

In [None]:
print("Testing system")

In [None]:
save_image_path = "sample_images/home/256"
model_name = "cj2_e152_b512_c20d10-10-10_cmax7_gam36"

### Visualize

In [None]:
# Plot reconstructions
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Get a batch of data
for batch, labels in test_loader:
    break

# Reconstruct data using Joint-VAE model
recon = viz.reconstructions(batch)

plt.figure(figsize=(26,26))
plt.imshow(np.rot90(np.transpose(recon.numpy(),(2,1,0)),k=3));
plt.savefig(save_image_path + model_name + "_reconstructions.png",dpi=200)

In [None]:
# Plot samples
samples = viz.samples()

plt.figure(figsize=(26,26))
plt.imshow(np.rot90(np.transpose(samples.numpy(),(2,1,0)),k=3));
plt.savefig(save_image_path + model_name + "_samples.png",dpi=200)

### Note on Traversals

### Traverses all latent dimensions one by one and plots a grid of images where each row corresponds to a latent traversal of one latent dimension
* size: Number of samples for each latent traversal.

In [None]:
print(viz.model.latent_cont_dim, viz.model.num_disc_latents)

In [None]:
# Plot all traversals
traversals = viz.all_latent_traversals(size=20)

plt.figure(figsize=(20,20))
plt.imshow(np.rot90(np.transpose(traversals.numpy(),(2,1,0)),k=3));
plt.savefig(save_image_path + model_name + "_all_traversals_n20.png",dpi=200)

In [None]:
# Plot a grid of some traversals
traversals = viz.latent_traversal_grid(cont_idx=2, cont_axis=1, disc_idx=0, disc_axis=0, size=(10, 10))

plt.figure(figsize=(20,20))
plt.imshow(np.rot90(np.transpose(traversals.numpy(),(2,1,0)),k=3));
plt.savefig(save_image_path + model_name + "_traversals2100.png",dpi=200)

In [None]:
# Plot a grid of some traversals
traversals = viz.latent_traversal_grid(cont_idx=1, cont_axis=1, disc_idx=0, disc_axis=0, size=(10, 10))

plt.figure(figsize=(20,20))
plt.imshow(traversals.numpy()[0, :, :]);
plt.imshow(np.rot90(np.transpose(traversals.numpy(),(2,1,0)),k=3));
plt.savefig(save_image_path + model_name + "_traversals1100.png",dpi=200)

In [None]:
# Plot a grid of some traversals
traversals = viz.latent_traversal_grid(cont_idx=9, cont_axis=1, disc_idx=0, disc_axis=0, size=(10, 10))

plt.figure(figsize=(20,20))
plt.imshow(np.rot90(np.transpose(traversals.numpy(),(2,1,0)),k=3));
plt.savefig(save_image_path + model_name + "_traversals9100.png",dpi=200)

In [None]:
# Plot a grid of some traversals
traversals = viz.latent_traversal_grid(cont_idx=2, cont_axis=1, disc_idx=0, disc_axis=1, size=(10, 10))

plt.figure(figsize=(20,20))
plt.imshow(np.rot90(np.transpose(traversals.numpy(),(2,1,0)),k=3));
plt.savefig(save_image_path + model_name + "_traversals2101.png",dpi=200)

In [None]:
# Plot a grid of some traversals
traversals = viz.latent_traversal_grid(cont_idx=2, cont_axis=0, disc_idx=0, disc_axis=1, size=(10, 10))

plt.figure(figsize=(20,20))
plt.imshow(np.rot90(np.transpose(traversals.numpy(),(2,1,0)),k=3));
plt.savefig(save_image_path + model_name + "_traversals2001.png",dpi=200)

In [None]:
# Plot a grid of some traversals
traversals = viz.latent_traversal_grid(cont_idx=3, cont_axis=0, disc_idx=0, disc_axis=0, size=(10, 10))

plt.figure(figsize=(20,20))
plt.imshow(np.rot90(np.transpose(traversals.numpy(),(2,1,0)),k=3));
plt.savefig(save_image_path + model_name + "_traversals3000.png",dpi=200)

In [None]:
# Plot a grid of some traversals
traversals = viz.latent_traversal_grid(cont_idx=12, cont_axis=0, disc_idx=0, disc_axis=0, size=(10, 10))

plt.figure(figsize=(20,20))
plt.imshow(np.rot90(np.transpose(traversals.numpy(),(2,1,0)),k=3));
plt.savefig(save_image_path + model_name + "_traversals3000.png",dpi=200)

In [None]:
# Plot a grid of some traversals
traversals = viz2.latent_traversal_grid2(cont_idx=3, cont_axis=0, disc_idx=0, disc_axis=0, size=(20, 20),first_n=10)

plt.figure(figsize=(20,20))
plt.imshow(np.rot90(np.transpose(traversals.numpy(),(2,1,0)),k=3));
#plt.savefig("sample_images/306/256/cj1_256_e50maybe_b64_c10d10-10-10_gam30_traversals3000.png",dpi=200)

In [None]:
!ls

### Save Model

In [None]:
torch.save(model.state_dict(),"trained_models/" + "statedict_" + model_name) # save state dict
#torch.save(model, model_name) # save full model

In [None]:
print("Done training: ",model_name)

#### Restore Model from State Dict

In [None]:
sd_model = VAE(latent_spec=latent_spec, img_size=(3, 64, 64))
sd_model.load_state_dict(torch.load("statedict_" + model_name))

#### Restore Full Model
* Note in this case the serialized data is bound to the specific classes and exact directory strucutre used.

In [None]:
full_model = torch.load(model_name)

In [None]:
type(full_model)

In [None]:
type(sd_model)

In [None]:
!ls