# Train a JointVAE model

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import torch
from viz.visualize import Visualizer
# True when PyTorch reports a usable CUDA device; reused below when building the model and trainer.
use_cuda = torch.cuda.is_available()
use_cuda

True

In [3]:
import os

# Pin GPU selection for this kernel:
#   CUDA_DEVICE_ORDER    - order devices by PCI bus id (see issue #152)
#   CUDA_VISIBLE_DEVICES - expose only device "0"; change to your device
# NOTE(review): these variables are normally read when CUDA is first initialised, and an
# earlier cell already calls torch.cuda.is_available() -- confirm they still take effect here.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

#### Prepare data list

In [4]:
!ls data

dress_dresslen_train_test_splits.json	dress_sleeve_train_test_splits.json
dress_sleevelen_train_test_splits.json


In [5]:
import json

# Load the dress_dresslen train/test split file into a dict.
splits_path = "./data/dress_dresslen_train_test_splits.json"
with open(splits_path, "r") as splits_file:
    data_dict = json.load(splits_file)

In [6]:
data_dict.keys()

dict_keys(['X_train_1', 'y_train_1', 'X_test_1', 'y_test_1', 'X_train_2', 'y_train_2', 'X_test_2', 'y_test_2', 'X_train_3', 'y_train_3', 'X_test_3', 'y_test_3', 'X_train_4', 'y_train_4', 'X_test_4', 'y_test_4', 'X_train_5', 'y_train_5', 'X_test_5', 'y_test_5', 'X_train_6', 'y_train_6', 'X_test_6', 'y_test_6', 'X_train_7', 'y_train_7', 'X_test_7', 'y_test_7', 'X_train_8', 'y_train_8', 'X_test_8', 'y_test_8', 'X_train_9', 'y_train_9', 'X_test_9', 'y_test_9', 'X_train_10', 'y_train_10', 'X_test_10', 'y_test_10'])

In [7]:
data_dict['X_train_1'][:5]

['/2/8/2893552_3773662.jpg',
 '/2/9/2982376_3889235.jpg',
 '/2/7/2783355_3578973.jpg',
 '/2/9/2974380_3918638.jpg',
 '/2/8/2886740_3675612.jpg']

#### Create list of image paths

In [8]:
!python -V

Python 3.6.5


In [9]:
root_data_dir = "/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables"

# Collect full image paths across every X_train_* / X_test_* entry of data_dict.
# The sampled entries start with "/", so they are simply appended to the root directory.
image_paths_train = []
image_paths_test = []

for split_name, rel_paths in data_dict.items():
    if 'X_train' in split_name:
        bucket = image_paths_train
    elif 'X_test' in split_name:
        bucket = image_paths_test
    else:
        continue  # remaining keys (e.g. y_train_*, y_test_*) are not used here
    bucket.extend(root_data_dir + rel_path for rel_path in rel_paths)

# Report the totals, then the first and last path of each list as a sanity check.
print(f"Number of train image paths: {len(image_paths_train):,d}")
print(f"Number of test image paths: {len(image_paths_test):,d}")
print()
print("Sample paths:")
for path_list in (image_paths_train, image_paths_test):
    print(path_list[0])
    print(path_list[-1])

Number of train image paths: 167,742
Number of test image paths: 18,638

Sample paths:
/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/2/8/2893552_3773662.jpg
/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/2/6/2683298_3676896.jpg
/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/2/4/2431229_3158108.jpg
/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/2/5/2569101_3223742.jpg


In [10]:
#from utils.dataloaders import get_mnist_dataloaders, get_fashion_mnist_dataloaders
#train_loader, test_loader = get_mnist_dataloaders(batch_size=64)
#train_loader, test_loader = get_fashion_mnist_dataloaders(batch_size=64)

In [11]:
from torchvision import transforms
from utils.dataloaders_custom import get_imagelist_dataloader, ImageListDataset

# Preprocessing: resize every image to 64x64, then convert it to a tensor.
# (An earlier variant resized to 260x260 instead.)
composed = transforms.Compose([
    transforms.Resize((64,64)),
    transforms.ToTensor(),
])

# One dataset per split, both sharing the same preprocessing.
train_dataset, test_dataset = (
    ImageListDataset(paths, transform=composed)
    for paths in (image_paths_train, image_paths_test)
)

# Both loaders use batch_size=20.
train_loader, test_loader = (
    get_imagelist_dataloader(batch_size=20, dataset_object=dataset)
    for dataset in (train_dataset, test_dataset)
)

### Define latent distribution of the model

In [12]:
# Joint latent distribution:
#   'cont' - 10 Gaussian (normal) latent dimensions
#   'disc' - one 10-dimensional Gumbel-Softmax variable
latent_spec = dict(cont=10, disc=[10])

### Build a model

In [13]:
# VAE comes from jointvae.models (jointvae.models_v1 was used in an earlier variant).
from jointvae.models import VAE

# img_size is (3, 64, 64) to go with the 64x64 Resize transform used for the data loaders;
# an earlier variant used (3, 260, 260).
model = VAE(
    latent_spec=latent_spec,
    img_size=(3, 64, 64),
    use_cuda=use_cuda,
)

In [14]:
print(model)

VAE(
  (img_to_features): Sequential(
    (0): Conv2d(3, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (5): ReLU()
    (6): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (7): ReLU()
  )
  (features_to_hidden): Sequential(
    (0): Linear(in_features=1024, out_features=256, bias=True)
    (1): ReLU()
  )
  (fc_mean): Linear(in_features=256, out_features=10, bias=True)
  (fc_log_var): Linear(in_features=256, out_features=10, bias=True)
  (fc_alphas): ModuleList(
    (0): Linear(in_features=256, out_features=10, bias=True)
  )
  (latent_to_features): Sequential(
    (0): Linear(in_features=20, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=1024, bias=True)
    (3): ReLU()
  )
  (features_to_img): Sequential(
    (0): ConvTranspose

### Train the model

In [15]:
from torch import optim

# Adam over all model parameters, with amsgrad switched on and the
# learning rate lowered to 3e-4 (originally 5e-4).
optimizer = optim.Adam(
    model.parameters(),
    lr=3e-4,
    amsgrad=True,
)

In [16]:
from jointvae.training import Trainer

# Capacity schedules: start at a capacity of 0.0 and increase it to 5.0
# over 25000 iterations, with a gamma of 30.0.
cont_capacity = [0.0, 5.0, 25000, 30.0]  # continuous channels
disc_capacity = [0.0, 5.0, 25000, 30.0]  # discrete channels

# Trainer ties together the model, the optimizer and both capacity schedules.
trainer = Trainer(
    model,
    optimizer,
    cont_capacity=cont_capacity,
    disc_capacity=disc_capacity,
    use_cuda=use_cuda,
)

#### Initialize visualizer

In [17]:
# Build a visualizer which will be passed to trainer to visualize progress during training
viz = Visualizer(model)

In [None]:
# Demo-length run: train for 10 epochs only. A proper run should use about 100 epochs.
# The visualizer is handed to the trainer together with the path of the training GIF.
trainer.train(
    train_loader,
    epochs=10,
    save_training_gif=('./training_rd1.gif', viz),
)

0/167742	Loss: 2836.089
1000/167742	Loss: 2435.052
2000/167742	Loss: 1539.812
3000/167742	Loss: 1419.947
4000/167742	Loss: 1387.930
5000/167742	Loss: 1355.027
6000/167742	Loss: 1350.809
7000/167742	Loss: 1339.796
8000/167742	Loss: 1367.289
9000/167742	Loss: 1345.060
10000/167742	Loss: 1336.565
11000/167742	Loss: 1332.512
12000/167742	Loss: 1315.887
13000/167742	Loss: 1327.997
14000/167742	Loss: 1331.076
15000/167742	Loss: 1310.714
16000/167742	Loss: 1320.737
17000/167742	Loss: 1313.769
18000/167742	Loss: 1289.622
19000/167742	Loss: 1284.920
20000/167742	Loss: 1285.379
21000/167742	Loss: 1294.884
22000/167742	Loss: 1280.471
23000/167742	Loss: 1280.193
24000/167742	Loss: 1271.653
25000/167742	Loss: 1266.280
26000/167742	Loss: 1287.138
27000/167742	Loss: 1248.366
28000/167742	Loss: 1252.842
29000/167742	Loss: 1242.234
30000/167742	Loss: 1246.277
31000/167742	Loss: 1262.468
32000/167742	Loss: 1247.468
33000/167742	Loss: 1248.933
34000/167742	Loss: 1236.049
35000/167742	Loss: 1232.997
36000

122000/167742	Loss: 1154.110
123000/167742	Loss: 1159.725
124000/167742	Loss: 1156.037
125000/167742	Loss: 1145.645
126000/167742	Loss: 1151.117
127000/167742	Loss: 1158.272
128000/167742	Loss: 1131.610
129000/167742	Loss: 1144.172
130000/167742	Loss: 1144.308
131000/167742	Loss: 1144.943
132000/167742	Loss: 1150.766
133000/167742	Loss: 1156.026
134000/167742	Loss: 1159.668
135000/167742	Loss: 1140.138
136000/167742	Loss: 1152.661
137000/167742	Loss: 1150.571
138000/167742	Loss: 1133.499
139000/167742	Loss: 1160.535
140000/167742	Loss: 1142.486
141000/167742	Loss: 1145.103
142000/167742	Loss: 1146.185
143000/167742	Loss: 1152.941
144000/167742	Loss: 1160.805
145000/167742	Loss: 1132.364
146000/167742	Loss: 1140.348
147000/167742	Loss: 1144.176
148000/167742	Loss: 1148.573
149000/167742	Loss: 1140.276
150000/167742	Loss: 1153.005
151000/167742	Loss: 1144.512
152000/167742	Loss: 1146.774
153000/167742	Loss: 1133.058
154000/167742	Loss: 1142.521
155000/167742	Loss: 1146.512
156000/167742	

### Visualize

In [None]:
# Plot reconstructions
%matplotlib inline
import matplotlib.pyplot as plt

# Get a batch of data
for batch, labels in test_loader:
    break

# Reconstruct data using Joint-VAE model
recon = viz.reconstructions(batch)

plt.figure(figsize=(8,8))
plt.imshow(recon.numpy()[0, :, :], cmap='gray');

In [None]:
# Plot samples
samples = viz.samples()

# The model is built for 3-channel images, so show the grid in colour instead of
# only channel 0 in grayscale. Assumes samples is channels-first (C, H, W) -- TODO confirm.
samples_img = samples.numpy()
plt.figure(figsize=(8,8))
if samples_img.shape[0] == 3:
    plt.imshow(samples_img.transpose(1, 2, 0))  # matplotlib expects (H, W, 3)
else:
    plt.imshow(samples_img[0, :, :], cmap='gray')

In [None]:
# Plot all traversals
traversals = viz.all_latent_traversals(size=10)

# The model is built for 3-channel images, so show the grid in colour instead of
# only channel 0 in grayscale. Assumes traversals is channels-first (C, H, W) -- TODO confirm.
traversals_img = traversals.numpy()
plt.figure(figsize=(8,8))
if traversals_img.shape[0] == 3:
    plt.imshow(traversals_img.transpose(1, 2, 0))  # matplotlib expects (H, W, 3)
else:
    plt.imshow(traversals_img[0, :, :], cmap='gray')

In [None]:
# Plot a grid of traversals: continuous latent 2 along axis 1, discrete latent 0 along axis 0
traversals = viz.latent_traversal_grid(cont_idx=2, cont_axis=1, disc_idx=0, disc_axis=0, size=(10, 10))

# The model is built for 3-channel images, so show the grid in colour instead of
# only channel 0 in grayscale. Assumes traversals is channels-first (C, H, W) -- TODO confirm.
traversals_img = traversals.numpy()
plt.figure(figsize=(8,8))
if traversals_img.shape[0] == 3:
    plt.imshow(traversals_img.transpose(1, 2, 0))  # matplotlib expects (H, W, 3)
else:
    plt.imshow(traversals_img[0, :, :], cmap='gray')

In [None]:
# Plot a grid of traversals: continuous latent 1 along axis 1, discrete latent 0 along axis 0
traversals = viz.latent_traversal_grid(cont_idx=1, cont_axis=1, disc_idx=0, disc_axis=0, size=(10, 10))

# The model is built for 3-channel images, so show the grid in colour instead of
# only channel 0 in grayscale. Assumes traversals is channels-first (C, H, W) -- TODO confirm.
traversals_img = traversals.numpy()
plt.figure(figsize=(8,8))
if traversals_img.shape[0] == 3:
    plt.imshow(traversals_img.transpose(1, 2, 0))  # matplotlib expects (H, W, 3)
else:
    plt.imshow(traversals_img[0, :, :], cmap='gray')

In [None]:
# Plot a grid of traversals: continuous latent 9 along axis 1, discrete latent 0 along axis 0
traversals = viz.latent_traversal_grid(cont_idx=9, cont_axis=1, disc_idx=0, disc_axis=0, size=(10, 10))

# The model is built for 3-channel images, so show the grid in colour instead of
# only channel 0 in grayscale. Assumes traversals is channels-first (C, H, W) -- TODO confirm.
traversals_img = traversals.numpy()
plt.figure(figsize=(8,8))
if traversals_img.shape[0] == 3:
    plt.imshow(traversals_img.transpose(1, 2, 0))  # matplotlib expects (H, W, 3)
else:
    plt.imshow(traversals_img[0, :, :], cmap='gray')

### Save Model

In [None]:
# Base file name for the saved model; the state-dict file adds a "statedict_" prefix to it.
# NOTE(review): the name says "fmnist", but this notebook loads dress image splits -- consider renaming.
model_name = "jvae_fmnist_oct292018.pth"

In [None]:
# Persist the trained model in two forms.
state_dict_path = "statedict_" + model_name
torch.save(model.state_dict(), state_dict_path)  # state dict only
torch.save(model, model_name)                    # full model object

#### Restore Model from State Dict

In [None]:
# Rebuild the architecture with the same img_size the saved model was created with, (3, 64, 64).
# A different img_size yields layers whose shapes do not match the saved weights, and
# load_state_dict() then fails with a size-mismatch error.
sd_model = VAE(latent_spec=latent_spec, img_size=(3, 64, 64))
sd_model.load_state_dict(torch.load("statedict_" + model_name))

#### Restore Full Model
* Note: in this case the serialized data is bound to the specific classes and the exact directory structure used when the model was saved.

In [None]:
# Load the full model object written by torch.save(model, model_name). The serialized data is
# bound to the classes and directory structure used at save time, so those must still be available.
full_model = torch.load(model_name)

In [None]:
type(full_model)

In [None]:
type(sd_model)