# Train a JointVAE model

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import numpy as np
import torch
from viz.visualize import Visualizer
use_cuda = torch.cuda.is_available()
use_cuda

True

In [3]:
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # change to your device

#### Prepare data list

In [4]:
!ls data

dress_dresslen_train_test_splits.json	dress_sleeve_train_test_splits.json
dress_sleevelen_train_test_splits.json	loadable_women_primary_dress.csv


#### Create list of image paths

In [5]:
loadable_dresses = list(np.loadtxt('data/loadable_women_primary_dress.csv',delimiter=',',skiprows=1,dtype='str'))

In [None]:
print(len(loadable_dresses))

In [8]:
loadable_dresses[:5]

['/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/6/4/6418008_9882769.jpg',
 '/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/6/6/6627534_9864695.jpg',
 '/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/6/7/6772508_9949243.jpg',
 '/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/6/7/6758001_9588597.jpg',
 '/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/6/6/6637527_9387666.jpg']

In [9]:
bad_data = ['/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/7/2/723739_1342692.jpg']
for i, bad in enumerate(bad_data):
    if bad in loadable_dresses:
        del loadable_dresses[i]
        
print(len(loadable_dresses))

102124


#### Split into train and test set

In [12]:
#image_paths_train = loadable_dresses[:int(len(loadable_dresses)*0.87)]
#image_paths_test = loadable_dresses[int(len(loadable_dresses)*0.87):]
image_paths_train = loadable_dresses[:88800]
image_paths_test = loadable_dresses[88800:-124]


print(f"Number of train image paths: {len(image_paths_train):,d}")
print(f"Number of test image paths: {len(image_paths_test):,d}")
print()
print("Sample paths:")
print(image_paths_train[0])
print(image_paths_train[-1])
print(image_paths_test[0])
print(image_paths_test[-1])

Number of train image paths: 88,800
Number of test image paths: 13,200

Sample paths:
/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/6/6/6627534_9864695.jpg
/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/5/2/529420_884192.jpg
/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/5/2/527981_1034607.jpg
/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/3/4/341494_806784.jpg


In [13]:
#from utils.dataloaders import get_mnist_dataloaders, get_fashion_mnist_dataloaders
#train_loader, test_loader = get_mnist_dataloaders(batch_size=64)
#train_loader, test_loader = get_fashion_mnist_dataloaders(batch_size=64)

In [14]:
from torchvision import transforms
#from utils.dataloaders_custom import get_imagelist_dataloader, ImageListDataset
from utils.dataloader_tools import get_imagelist_dataloader, ImageListDataset

BATCH_SIZE = 200

composed = transforms.Compose([transforms.CenterCrop((90,90)),transforms.Resize((64,64)),transforms.ToTensor()])

train_dataset = ImageListDataset(image_paths_train, cut_from='top', cut_amount=90, transform=composed)
test_dataset = ImageListDataset(image_paths_test, cut_from='top', cut_amount=90, transform=composed)

train_loader = get_imagelist_dataloader(batch_size=BATCH_SIZE, dataset_object=train_dataset)
test_loader = get_imagelist_dataloader(batch_size=BATCH_SIZE, dataset_object=test_dataset)

### Define latent distribution of the model

In [15]:
# Latent distribution will be joint distribution of 10 gaussian normal distributions
# and one 10 dimensional Gumbel Softmax distribution
latent_spec = {'cont': 5,
               'disc': [5]}

### Build a model

In [16]:
#from jointvae.models_v1 import VAE
#from jointvae.models import VAE
from jointvae.models_64_xstyle_int_nd import VAE

#model = VAE(latent_spec=latent_spec, img_size=(3, 260, 260), use_cuda=use_cuda)
model = VAE(latent_spec=latent_spec, img_size=(3, 64, 64), use_cuda=use_cuda)

In [17]:
#print(model)

### Train the model

In [18]:
from torch import optim

# Build optimizer
optimizer = optim.Adam(model.parameters(), lr=8e-4, amsgrad=True) # added amsgrad # orig lr 5e-4

In [19]:
from jointvae.training import Trainer

# Define the capacities
# Continuous channels
cont_capacity = [0.0, 5.0, 25000, 20.0]  # Starting at a capacity of 0.0, increase this to 5.0
                                         # over 25000 iterations with a gamma of 30.0
# Discrete channels
disc_capacity = [0.0, 5.0, 25000, 20.0]  # Starting at a capacity of 0.0, increase this to 5.0
                                         # over 25000 iterations with a gamma of 30.0

# Build a trainer
trainer = Trainer(model, optimizer,
                  cont_capacity=cont_capacity,
                  disc_capacity=disc_capacity,
                 use_cuda=use_cuda)

#### Initialize visualizer

In [20]:
# Build a visualizer which will be passed to trainer to visualize progress during training
viz = Visualizer(model)

In [None]:
# Train model for 10 epochs
# Note this should really be a 100 epochs and trained on a GPU, but this is just to demo

trainer.train(train_loader, epochs=60, save_training_gif=None)

data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
0/88800	Loss: 2835.160
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size

data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
featu

data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
featu

data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
featu

data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
30000/88800	Loss: 2063.940
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.

data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
featu

data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
featu

data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
featu

data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
featu

data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])
recon_batch shape in _train_iteration:  torch.Size([200, 3, 64, 64])
data shape in _train_iteration:  torch.Size([200, 3, 64, 64])
features shape:  torch.Size([200, 64, 4, 4])
featu

In [22]:
print('hi')

hi


### Visualize

In [23]:
# Plot reconstructions
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Get a batch of data
for batch, labels in test_loader:
    break

# Reconstruct data using Joint-VAE model
recon = viz.reconstructions(batch)

plt.figure(figsize=(26,26))
#np.transpose(recon.numpy(), (2,1,0))
plt.imshow(np.rot90(np.transpose(recon.numpy(),(2,1,0)),k=3));
plt.savefig("sample_images/306/64/croptotop_64x64_e60_b200_c5d5_gam20_reconstructions.png",dpi=200)

features shape:  torch.Size([200, 64, 4, 4])
features view shape:  torch.Size([200, 1024])


AttributeError: 'NoneType' object has no attribute 'numpy'

<Figure size 1872x1872 with 0 Axes>

In [None]:
# Plot samples
samples = viz.samples()

plt.figure(figsize=(26,26))
#plt.imshow(samples.numpy()[0, :, :],cmap='gray');
plt.imshow(np.rot90(np.transpose(samples.numpy(),(2,1,0)),k=3));
plt.savefig("sample_images/306/64/croptotop_64x64_e60_b200_c5d5_gam20_samples.png",dpi=200)

#### Traverses all latent dimensions one by one and plots a grid of images where each row corresponds to a latent traversal of one latent dimension

In [None]:
# Plot all traversals
traversals = viz.all_latent_traversals(size=20)

plt.figure(figsize=(26,26))
plt.imshow(np.rot90(np.transpose(traversals.numpy(),(2,1,0)),k=3));
plt.savefig("sample_images/306/64/croptotop_64x64_e60_b200_c5d5_gam20_all_traversals_n20.png",dpi=200)

In [None]:
# Plot a grid of some traversals
traversals = viz.latent_traversal_grid(cont_idx=2, cont_axis=1, disc_idx=0, disc_axis=0, size=(10, 10))

plt.figure(figsize=(20,20))
plt.imshow(np.rot90(np.transpose(traversals.numpy(),(2,1,0)),k=3));
plt.savefig("sample_images/306/64/croptotop_64x64_e60_b200_c5d5_gam20.png_traversals2100.png",dpi=200)

In [None]:
# Plot a grid of some traversals
traversals = viz.latent_traversal_grid(cont_idx=1, cont_axis=1, disc_idx=0, disc_axis=0, size=(10, 10))

plt.figure(figsize=(16,16))
plt.imshow(traversals.numpy()[0, :, :]);
plt.imshow(np.rot90(np.transpose(traversals.numpy(),(2,1,0)),k=3));
plt.savefig("sample_images/306/64/croptotop_64x64_e60_b200_c5d5_gam20_traversals1100.png",dpi=200)

In [None]:
# Plot a grid of some traversals
traversals = viz.latent_traversal_grid(cont_idx=9, cont_axis=1, disc_idx=0, disc_axis=0, size=(10, 10))

plt.figure(figsize=(16,16))
#plt.imshow(traversals.numpy()[0, :, :], cmap='gray');
plt.imshow(np.rot90(np.transpose(traversals.numpy(),(2,1,0)),k=3));
plt.savefig("sample_images/306/64/croptotop_64x64_e60_b200_c5d5_gam20_traversals9100.png",dpi=200)

In [34]:
!ls

data
dataloading_pytorch_test.ipynb
git_dump_all_versions_of_a_file.sh
imgs
jointvae
jvae_fmnist_oct292018.pth
latent_traversals.py
load_model.ipynb
main.py
model_size_debug.txt
__pycache__
RandomUtilsandTests.ipynb
rd64x64_e200_b20_308.pth
README.md
realdata64x64_e10_b20.pth
requirements.txt
sample_images
statedict_jvae_fmnist_oct292018.pth
statedict_rd64x64_e200_b20_308.pth
statedict_realdata64x64_e10_b20.pth
trained_models
training.gif
TrainingNotebooks
training_rd1.gif
training_rd_308_64_200e_v1.gif
train_model.ipynb
train_model_loadable_64_croptotop.ipynb
train_model_loadable_64_xstyle_croptotop.ipynb
train_model_old.ipynb
train_model_realdata_305_64.ipynb
train_model_realdata_306_260_v1.ipynb
train_model_realdata_308_64.ipynb
train_model_realdata_64_v1.ipynb
train_model_realdata_64_v2.ipynb
train_model_realdata_64_v3.ipynb
train_rd_305_old.ipynb
utils
viz


### Save Model

In [None]:
model_name = "croptotop_64x64_e60_b200_c5d5_gam20.pth"

In [None]:
torch.save(model.state_dict(),"trained_models" + "statedict_" + model_name) # save state dict
#torch.save(model, model_name) # save full model

In [None]:
print("Done training: ",model_name)

#### Restore Model from State Dict

In [28]:
sd_model = VAE(latent_spec=latent_spec, img_size=(3, 64, 64))
sd_model.load_state_dict(torch.load("statedict_" + model_name))

#### Restore Full Model
* Note in this case the serialized data is bound to the specific classes and exact directory strucutre used.

In [None]:
full_model = torch.load(model_name)

In [None]:
type(full_model)

In [None]:
type(sd_model)

In [36]:
!ls

data				requirements.txt
dataloading_pytorch_test.ipynb	statedict_jvae_fmnist_oct292018.pth
imgs				trained_models
jointvae			training.gif
jvae_fmnist_oct292018.pth	training_rd1.gif
latent_traversals.py		train_model.ipynb
load_model.ipynb		train_model_realdata_306_260_v1.ipynb
main.py				train_model_realdata_64_v1.ipynb
__pycache__			utils
RandomUtilsandTests.ipynb	viz
README.md
