# Train a JointVAE model

In [1]:
# Reload edited project modules (utils, jointvae, viz) automatically before each cell runs.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook.
%matplotlib inline

In [2]:
import os

# CUDA reads these variables when it is first initialised, and
# torch.cuda.is_available() can trigger that initialisation. Set them *before*
# importing torch so the GPU selection reliably takes effect.
# NOTE(review): if torch was already imported in this kernel, restart it first.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # change to your device

import numpy as np
import torch

# True when at least one visible CUDA device is usable.
use_cuda = torch.cuda.is_available()
use_cuda

True

In [3]:
import os

# Number GPUs by PCI bus ID and expose only device "1" to CUDA.
# NOTE(review): CUDA reads these when it first initialises; if that has already
# happened (e.g. via torch.cuda.is_available() in an earlier cell), setting them
# here may have no effect.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",  # see issue #152
    "CUDA_VISIBLE_DEVICES": "1",  # change to your device
})

#### Prepare data list

In [4]:
# List the third-opinion source directory; it should hold the .h5 label table and its data dictionary.
!ls ~/vmldata/raw_source_data/v20181105_third_opinion_v01

third_opinion_v01_data_dictionary.txt  third_opinion_v01.h5


In [5]:
# Confirm the .h5 label table exists and show its fully expanded path (used verbatim in the next cell).
!ls ~/vmldata/raw_source_data/v20181105_third_opinion_v01/third_opinion_v01.h5

/home/jovyan/vmldata/raw_source_data/v20181105_third_opinion_v01/third_opinion_v01.h5


#### Create list of image paths

In [6]:
import pandas as pd

# Label table: one row per image with its relative `path` plus DSL_label / DL_label codes.
source_dir = '/home/jovyan/vmldata/raw_source_data/v20181105_third_opinion_v01'
path = source_dir + '/third_opinion_v01.h5'
third = pd.read_hdf(path)

In [7]:
# Peek at the first five rows of the label table.
third.head(n=5)

Unnamed: 0,path,DSL_label,DL_label
0,/1/5/1533245_2188111.jpg,2,1
1,/2/4/2403009_2991265.jpg,2,1
2,/2/8/2814416_3549033.jpg,0,3
3,/2/6/2671073_3328473.jpg,3,0
4,/2/8/2825395_3667369.jpg,2,1


In [8]:
# Echo the data dictionary so the meaning of the label codes is visible inline.
labs = ('/home/jovyan/vmldata/raw_source_data/v20181105_third_opinion_v01/'
        'third_opinion_v01_data_dictionary.txt')
with open(labs) as labels_dict:
    for line in labels_dict.readlines():
        print(line.strip())

third_opinion_v01.h5 has three columns:

path:       path to the file, should be attempted from the largest dataset as of Nov 5, 2018.
ie: /vmldata/raw_source_data/v20180810_all_wearables

DSL_label:  Dress Sleeve Length
Value:  Meaning:       Definition:
0:      No sleeve    - armless dress with nothing sticking out that appears to curve down towards shoulder
1:      short sleeve - sleeveless with shoulder that curves down towards arm to elbow length
2:      3/4 sleeve   - elbow length to approximately wrist length
3:      long         - from wrist length to anything longer

DL_label:   Dress Length
Value:  Meaning:       Definition:
0:      Short        - Shorter than clearly above the knee area
1:      KneeLength   - Hem of dress is at knee height (all of knee cap/tendon area)
2:      Midi         - Below knee lower tendon to approximately top of ankle area
3:      Long         - Anything touching the ankles or longer
4:      HiLo         - Normally any dress (regardless of length) 

In [9]:
# Label codes (from the data dictionary above):
#   DL_label  (dress length):  0 Short, 1 KneeLength, 2 Midi, 3 Long, 4 HiLo
#   DSL_label (sleeve length): 0 none, 1 short, 2 3/4, 3 long
longish = third.query('0 < DL_label < 4')        # knee-length, midi or long
ls = longish[longish['DSL_label'] > 1]           # ...of those, 3/4 or long sleeves
midiknee = third.query('DL_label in [1, 2]')     # knee-length or midi only

In [10]:
# Row counts: full table, longish, longish-with-sleeves, knee/midi.
print(*(len(frame) for frame in (third, longish, ls, midiknee)))

10775 6350 3961 4596


In [11]:
# First five relative image paths of the sleeved subset (positional, via .iloc).
ls['path'].iloc[:5].tolist()

['/1/5/1533245_2188111.jpg',
 '/2/4/2403009_2991265.jpg',
 '/2/8/2825395_3667369.jpg',
 '/2/9/2979794_3871671.jpg',
 '/2/5/2596552_3418185.jpg']

#### Split into train and test set

In [14]:
rootpath = '/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables'
datachoice = midiknee
# The table's relative paths start with '/', so plain concatenation is correct here
# (os.path.join would discard `rootpath` when the second argument is absolute).
ttdata = [rootpath + x for x in datachoice['path']]

# Positional (unshuffled) split. Rows [TRAIN_END, TEST_START) and the final
# TEST_TAIL_DROP rows are intentionally left out of both sets.
# NOTE(review): record why that gap and the dropped tail are excluded.
TRAIN_END = 3800
TEST_START = 3910
TEST_TAIL_DROP = 6
assert TRAIN_END <= TEST_START, "train and test ranges overlap"

image_paths_train = ttdata[:TRAIN_END]
# `len(...) - k` instead of `-k`, so TEST_TAIL_DROP = 0 keeps the tail rather than yielding [].
image_paths_test = ttdata[TEST_START:len(ttdata) - TEST_TAIL_DROP]
assert image_paths_test, "empty test set; check the split boundaries against len(ttdata)"

print(f"Number of train image paths: {len(image_paths_train):,d}")
print(f"Number of test image paths: {len(image_paths_test):,d}")
print()
print("Sample paths:")
print(image_paths_train[0])
print(image_paths_train[-1])
print(image_paths_test[0])
print(image_paths_test[-1])

Number of train image paths: 3,800
Number of test image paths: 680

Sample paths:
/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/1/5/1533245_2188111.jpg
/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/4/4/4491038_8306711.jpg
/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/2/3/2346358_2947133.jpg
/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/5/1/5164455_8194878.jpg


In [15]:
# Free the label tables; only the path lists are used from here on.
# `datachoice` aliases one of these frames, so it must be dropped as well,
# otherwise that frame stays referenced and is never released.
del third, longish, ls, midiknee, datachoice

In [16]:
from torchvision import transforms
from utils.dataloader_tools import get_imagelist_dataloader, ImageListDataset

BATCH_SIZE = 5

# NOTE(review): Resize((256, 256)) straight after CenterCrop((256, 256)) asks for the
# size the crop already produced, so it looks like a no-op. Confirm whether
# Resize-then-CenterCrop was intended.
composed = transforms.Compose([
    transforms.CenterCrop((256, 256)),
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

# convert rgb is for the cv2 loaded images that I've got in this dir
conv_rgb = False
error_handling = True

# Identical dataset options for train and test. `error_handling` is now actually
# passed through (previously a literal True was used and this flag was ignored).
dataset_kwargs = dict(cut_from='top', cut_amount=256, transform=composed,
                      convert_rgb=conv_rgb, error_handling=error_handling)
train_dataset = ImageListDataset(image_paths_train, **dataset_kwargs)
test_dataset = ImageListDataset(image_paths_test, **dataset_kwargs)

train_loader = get_imagelist_dataloader(batch_size=BATCH_SIZE, dataset_object=train_dataset)
test_loader = get_imagelist_dataloader(batch_size=BATCH_SIZE, dataset_object=test_dataset)

### Define latent distribution of the model

In [17]:
# Latent distribution: a joint distribution of 5 continuous (Gaussian) latent
# dimensions and one 5-dimensional discrete (Gumbel-Softmax) variable.
# Each entry of 'disc' is the size of one discrete variable.
latent_spec = {'cont': 5,
               'disc': [5]}

### Build a model

In [18]:
from jointvae.models_256_xstyle_int_nd import VAE

# 3-channel 256x256 input images, matching the data transforms above.
IMG_SIZE = (3, 256, 256)
model = VAE(img_size=IMG_SIZE, latent_spec=latent_spec, use_cuda=use_cuda)

In [19]:
#print(model)

### Train the model

In [20]:
from torch import optim

# Adam with the AMSGrad variant; learning rate lowered from the original 5e-4 to 1e-4.
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-4,
    amsgrad=True,
)

In [21]:
from jointvae.training import Trainer

# Capacity schedules, each given as [start, end, n_iterations, gamma]: start at a
# capacity of 0.0, increase it to 5.0 over 25000 iterations, with a gamma of 30.0.
cont_capacity = [0.0, 5.0, 25000, 30.0]  # continuous channels
disc_capacity = [0.0, 5.0, 25000, 30.0]  # discrete channels

trainer = Trainer(
    model,
    optimizer,
    cont_capacity=cont_capacity,
    disc_capacity=disc_capacity,
    use_cuda=use_cuda,
)

RuntimeError: CUDA error: out of memory

#### Initialize visualizer

In [24]:
from viz.visualize import Visualizer

viz = Visualizer(model)
# Turn off writing to disk so the plotting helpers return image tensors instead.
viz.save_images = False

In [25]:
# Train the model for EPOCHS epochs. An earlier note here suggested ~100 epochs
# on a GPU for a full-quality run.
EPOCHS = 40
trainer.train(train_loader, epochs=EPOCHS, save_training_gif=None)

NameError: name 'trainer' is not defined

In [None]:
# Leftover liveness check after the long training cell; safe to delete.
print("hi")

In [None]:
# Grab the first test batch and check what the visualiser returns for it.
batch, labels = next(iter(test_loader))
print("batch: ", type(batch), batch.shape)
type(viz.reconstructions(batch))

### Visualize

In [None]:
# Plot reconstructions of one test batch.
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Take the first batch from the test loader.
batch, labels = next(iter(test_loader))

# Reconstruct data using Joint-VAE model
recon = viz.reconstructions(batch)

# Move channels last for imshow: (C, H, W) -> (H, W, C).
# The previous np.rot90(np.transpose(x, (2, 1, 0)), k=3) equals
# np.flip(np.transpose(x, (1, 2, 0)), axis=1), so for a channel-first input it
# displayed the images mirrored left-to-right.
# NOTE(review): assumes viz returns a channel-first (C, H, W) CPU tensor.
plt.figure(figsize=(26, 26))
plt.imshow(np.transpose(recon.numpy(), (1, 2, 0)));
plt.savefig("sample_images/home/256/xint256_e600_b100_c10d10_cap48_30k_gam32_reconstructions.png", dpi=200)

In [None]:
# Plot images generated by the visualiser's sampler.
samples = viz.samples()

# Channels last for imshow; the former transpose(2,1,0)+rot90(k=3) mirrored the
# image left-to-right. NOTE(review): assumes a channel-first (C, H, W) tensor.
plt.figure(figsize=(26, 26))
plt.imshow(np.transpose(samples.numpy(), (1, 2, 0)));
plt.savefig("sample_images/home/256/xint256_e600_b100_c10d10_cap48_30k_gam32_samples.png", dpi=200)

#### Latent traversals
Traverse all latent dimensions one by one and plot a grid of images in which each row is the traversal of a single latent dimension.

In [None]:
# Plot all traversals: each row is the traversal of one latent dimension.
n_steps = 10  # traversal size passed to the visualiser; also recorded in the file name
traversals = viz.all_latent_traversals(size=n_steps)

# Channels last for imshow; the former transpose(2,1,0)+rot90(k=3) mirrored the
# image left-to-right. NOTE(review): assumes a channel-first (C, H, W) tensor.
plt.figure(figsize=(20, 20))
plt.imshow(np.transpose(traversals.numpy(), (1, 2, 0)));
# The file name previously said "n20" although size=10 was used.
plt.savefig(f"sample_images/home/256/xint256_e600_b100_c10d10_cap48_30k_gam32_all_traversals_n{n_steps}.png", dpi=200)

In [None]:
# Plot a traversal grid: continuous latent 2 on axis 1, discrete latent 0 on axis 0.
traversals = viz.latent_traversal_grid(cont_idx=2, cont_axis=1, disc_idx=0, disc_axis=0, size=(10, 10))

# Channels last for imshow; the former transpose(2,1,0)+rot90(k=3) mirrored the
# image left-to-right. NOTE(review): assumes a channel-first (C, H, W) tensor.
plt.figure(figsize=(20, 20))
plt.imshow(np.transpose(traversals.numpy(), (1, 2, 0)));
plt.savefig("sample_images/home/256/xint256_e600_b100_c10d10_cap48_30k_gam32_traversals2100.png", dpi=200)

In [None]:
# Plot a traversal grid: continuous latent 1 on axis 1, discrete latent 0 on axis 0.
traversals = viz.latent_traversal_grid(cont_idx=1, cont_axis=1, disc_idx=0, disc_axis=0, size=(10, 10))

# A stray single-channel imshow that was immediately drawn over has been removed.
# Channels last for imshow; the former transpose(2,1,0)+rot90(k=3) mirrored the
# image left-to-right. NOTE(review): assumes a channel-first (C, H, W) tensor.
plt.figure(figsize=(20, 20))
plt.imshow(np.transpose(traversals.numpy(), (1, 2, 0)));
plt.savefig("sample_images/home/256/xint256_e600_b100_c10d10_cap48_30k_gam32_traversals1100.png", dpi=200)

In [None]:
# Plot a traversal grid for the last continuous latent dimension (axis 1)
# against discrete latent 0 (axis 0).
# The original hard-coded cont_idx=9 assumed at least 10 continuous dimensions;
# it is out of range for latent_spec['cont'] as configured above.
cont_idx = latent_spec['cont'] - 1
traversals = viz.latent_traversal_grid(cont_idx=cont_idx, cont_axis=1, disc_idx=0, disc_axis=0, size=(10, 10))

# Channels last for imshow; the former transpose(2,1,0)+rot90(k=3) mirrored the
# image left-to-right. NOTE(review): assumes a channel-first (C, H, W) tensor.
plt.figure(figsize=(20, 20))
plt.imshow(np.transpose(traversals.numpy(), (1, 2, 0)));
plt.savefig(f"sample_images/home/256/xint256_e600_b100_c10d10_cap48_30k_gam32_traversals{cont_idx}100.png", dpi=200)

In [None]:
# List the working directory, e.g. to check that the sample images were written.
!ls

### Save Model

In [None]:
# Checkpoint name derived from the hyper-parameters configured above, so the tag
# cannot drift from the run. The previous hard-coded
# "xint256_e600_b100_c10d10_cap48_30k_gam32.pth" described a different setup.
# The epoch count is not stored in a variable here, so it is not part of the name.
n_cont = latent_spec.get('cont', 0)
n_disc = sum(latent_spec.get('disc', []))
cap_max, cap_iters, gamma = cont_capacity[1], cont_capacity[2], cont_capacity[3]
model_name = (f"xint256_b{BATCH_SIZE}_c{n_cont}d{n_disc}"
              f"_cap{cap_max:g}_{cap_iters / 1000:g}k_gam{gamma:g}.pth")

In [None]:
import os

# Save the state dict into trained_models/. The original
# "trained_models" + "statedict_" concatenation lacked a path separator, producing
# a file named "trained_modelsstatedict_..." in the working directory.
os.makedirs("trained_models", exist_ok=True)
torch.save(model.state_dict(), os.path.join("trained_models", "statedict_" + model_name))  # save state dict
# Full-model pickling is disabled; the "Restore Full Model" cell needs it re-enabled.
#torch.save(model, model_name) # save full model

In [None]:
# Report which checkpoint name this run used.
print(f"Done training:  {model_name}")

#### Restore Model from State Dict

In [None]:
import os

# Rebuild the architecture with the same input size it was trained on. The
# previous (3, 64, 64) would presumably yield layer shapes that do not match the
# saved parameters.
sd_model = VAE(latent_spec=latent_spec, img_size=(3, 256, 256))
# Load from the same path the save cell writes to. map_location="cpu" lets a
# GPU-trained checkpoint load on a CPU-only machine.
sd_model.load_state_dict(
    torch.load(os.path.join("trained_models", "statedict_" + model_name), map_location="cpu"))

#### Restore Full Model
* Note: in this case the serialized data is bound to the specific classes and the exact directory structure used.

In [None]:
# Unpickle a whole model object from `model_name`.
# NOTE(review): the full-model torch.save in the save cell is commented out, so this
# file only exists if it was written elsewhere. torch.load unpickles, so only load
# trusted files.
full_model = torch.load(f=model_name)

In [None]:
# Show the class of the unpickled full model.
full_model.__class__

In [None]:
# Show the class of the model rebuilt from the state dict.
sd_model.__class__

In [None]:
# List the working directory to check which model files are present.
!ls