# Train a JointVAE model on 128×128 dress images

In [2]:
# Automatically reload edited project modules before each cell executes.
# NOTE(review): instances created before a reload keep their old state; if a reloaded class
# expects new attributes, existing instances can fail (likely the cause of the AttributeError
# in the reconstructions cell) — restart the kernel and re-run if that happens.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook.
%matplotlib inline

In [3]:
import os

import numpy as np
import torch

# Choose the GPU *before* the first CUDA call. torch.cuda.is_available() below can
# initialise CUDA, after which changes to CUDA_VISIBLE_DEVICES have no effect.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # change to your device

# Seed the RNGs used by data shuffling and training so a Restart & Run All is reproducible.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

use_cuda = torch.cuda.is_available()
use_cuda

True

In [4]:
import os
# Select which GPU(s) CUDA may use, numbered in PCI bus order.
# NOTE(review): CUDA reads these variables only when it is first initialised in the process.
# The previous cell already called torch.cuda.is_available(), which can initialise CUDA, so
# setting them here may have no effect; set them before any CUDA call (or in the shell).
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # change to your device

#### Prepare data list

In [5]:
# Inspect the data directory: split definitions (JSON) and the CSV of loadable dress image paths.
!ls data

dress_dresslen_train_test_splits.json	dress_sleeve_train_test_splits.json
dress_sleevelen_train_test_splits.json	loadable_women_primary_dress.csv


#### Create list of image paths

In [6]:
# The CSV holds one image path per row, after a single header row.
DRESS_LIST_CSV = 'data/loadable_women_primary_dress.csv'
dress_path_array = np.loadtxt(DRESS_LIST_CSV, delimiter=',', skiprows=1, dtype='str')
loadable_dresses = list(dress_path_array)

In [7]:
# Number of image paths read from the CSV.
n_loadable = len(loadable_dresses)
print(n_loadable)

102125


In [8]:
# Peek at a few entries to sanity-check the path format.
first_paths = loadable_dresses[:5]
first_paths

['/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/6/4/6418008_9882769.jpg',
 '/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/6/6/6627534_9864695.jpg',
 '/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/6/7/6772508_9949243.jpg',
 '/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/6/7/6758001_9588597.jpg',
 '/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/6/6/6637527_9387666.jpg']

In [9]:
# Known-bad images to exclude before splitting.
bad_data = ['/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/7/2/723739_1342692.jpg']

for bad_path in bad_data:
    if bad_path in loadable_dresses:
        # list.remove drops only the first matching entry.
        loadable_dresses.remove(bad_path)

print(len(loadable_dresses))

102124


#### Split into train and test set

In [10]:
# Split by position in the CSV (no shuffling): the first N_TRAIN paths form the training set;
# the remainder, minus the last N_TAIL_DROPPED paths, forms the test set.
# NOTE(review): the reason for discarding the final 124 paths is not recorded — confirm.
N_TRAIN = 88_800
N_TAIL_DROPPED = 124

image_paths_train = loadable_dresses[:N_TRAIN]
image_paths_test = loadable_dresses[N_TRAIN:-N_TAIL_DROPPED]

for split_name, split_paths in (("train", image_paths_train), ("test", image_paths_test)):
    print(f"Number of {split_name} image paths: {len(split_paths):,d}")
print()
print("Sample paths:")
for sample_path in (image_paths_train[0], image_paths_train[-1],
                    image_paths_test[0], image_paths_test[-1]):
    print(sample_path)

Number of train image paths: 88,800
Number of test image paths: 13,200

Sample paths:
/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/6/4/6418008_9882769.jpg
/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/5/2/529420_884192.jpg
/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/5/2/527981_1034607.jpg
/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/3/4/341494_806784.jpg


In [11]:
from torchvision import transforms
from utils.dataloader_tools import get_imagelist_dataloader, ImageListDataset

BATCH_SIZE = 64
IMG_SIDE = 128

# NOTE(review): Resize runs after a CenterCrop to the same size, so it is effectively a no-op.
# If the intent was "shrink, then crop", the order should be Resize -> CenterCrop. Confirm.
composed = transforms.Compose([
    transforms.CenterCrop((IMG_SIDE, IMG_SIDE)),
    transforms.Resize((IMG_SIDE, IMG_SIDE)),
    transforms.ToTensor(),
])

# Options shared by both splits; convert_rgb is meant for cv2-loaded images and is disabled here.
dataset_kwargs = dict(cut_from='top', cut_amount=128, transform=composed, convert_rgb=False)
train_dataset = ImageListDataset(image_paths_train, **dataset_kwargs)
test_dataset = ImageListDataset(image_paths_test, **dataset_kwargs)

train_loader, test_loader = (
    get_imagelist_dataloader(batch_size=BATCH_SIZE, dataset_object=split_dataset)
    for split_dataset in (train_dataset, test_dataset)
)

### Define latent distribution of the model

In [12]:
# Joint latent distribution: 10 continuous (Gaussian) dimensions plus a single
# 10-dimensional discrete (Gumbel-Softmax) variable.
latent_spec = dict(
    cont=10,    # number of continuous latent dimensions
    disc=[10],  # one entry per discrete latent: its dimensionality
)

### Build a model

In [13]:
# List the available model definitions; models_128_xstyle_int_nd is the one imported below.
!ls jointvae

decoder_scratchpad.py  models_128_v4.py			models_64_xstyle.py
encoder_scratchpad.py  models_128_xstyle_int_debug.py	models.py
__init__.py	       models_128_xstyle_int_nd.py	__pycache__
models_128_v1.py       models_64_xstyle_finished_v1.py	training_debug.py
models_128_v2.py       models_64_xstyle_int_debug.py	training.py
models_128_v3.py       models_64_xstyle_int_nd.py


In [14]:
from jointvae.models_128_xstyle_int_nd import VAE

# Expected to match the tensors produced by the data transforms: 3 channels, 128x128.
IMG_SHAPE = (3, 128, 128)
model = VAE(latent_spec=latent_spec, img_size=IMG_SHAPE, use_cuda=use_cuda)

In [15]:
#print(model)

### Train the model

In [16]:
from torch import optim

LEARNING_RATE = 3e-4  # originally 5e-4

# Adam with the AMSGrad variant enabled.
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, amsgrad=True)

In [17]:
from jointvae.training import Trainer

# Capacity schedule: [start capacity, target capacity, iterations to reach target, gamma].
# Both the continuous and the discrete channels start at a capacity of 0.0 and are
# raised to 5.0 over 25,000 iterations with a gamma of 30.0.
CAPACITY_SCHEDULE = [0.0, 5.0, 25000, 30.0]
cont_capacity = list(CAPACITY_SCHEDULE)  # continuous channels
disc_capacity = list(CAPACITY_SCHEDULE)  # discrete channels

trainer = Trainer(
    model,
    optimizer,
    cont_capacity=cont_capacity,
    disc_capacity=disc_capacity,
    use_cuda=use_cuda,
)

#### Initialize visualizer

In [18]:
from viz.visualize import Visualizer

viz = Visualizer(model)
# Have the visualizer return image tensors (instead of only saving files) so the cells below can plot them.
viz.save_images = False

In [19]:
# Proper training needs at least ~100 epochs; fewer is fine for a quick demo.
NUM_EPOCHS = 150

trainer.train(train_loader, epochs=NUM_EPOCHS, save_training_gif=None)

0/88800	Loss: 11433.810
3200/88800	Loss: 10633.689
6400/88800	Loss: 7550.398
9600/88800	Loss: 7073.791
12800/88800	Loss: 6881.928
16000/88800	Loss: 6806.055
19200/88800	Loss: 6758.876
22400/88800	Loss: 6770.698
25600/88800	Loss: 6679.411
28800/88800	Loss: 6699.721
32000/88800	Loss: 6664.820
35200/88800	Loss: 6573.270
38400/88800	Loss: 6563.874
41600/88800	Loss: 6535.553
44800/88800	Loss: 6511.978
48000/88800	Loss: 6422.692
51200/88800	Loss: 6335.790
54400/88800	Loss: 6217.512
57600/88800	Loss: 6204.785
60800/88800	Loss: 6135.138
64000/88800	Loss: 6064.215
67200/88800	Loss: 5996.090
70400/88800	Loss: 5988.914
73600/88800	Loss: 5946.626
76800/88800	Loss: 5956.838
80000/88800	Loss: 5947.336
83200/88800	Loss: 5914.032
86400/88800	Loss: 5946.711
Epoch: 1 Average loss: 6572.67
0/88800	Loss: 5977.394
3200/88800	Loss: 5890.867
6400/88800	Loss: 5869.566
9600/88800	Loss: 5917.747
12800/88800	Loss: 5859.091
16000/88800	Loss: 5884.583
19200/88800	Loss: 5898.271
22400/88800	Loss: 5873.001
25600/888

48000/88800	Loss: 5535.333
51200/88800	Loss: 5498.031
54400/88800	Loss: 5497.198
57600/88800	Loss: 5478.769
60800/88800	Loss: 5513.482
64000/88800	Loss: 5544.861
67200/88800	Loss: 5496.678
70400/88800	Loss: 5507.302
73600/88800	Loss: 5518.891
76800/88800	Loss: 5529.847
80000/88800	Loss: 5552.939
83200/88800	Loss: 5535.823
86400/88800	Loss: 5483.239
Epoch: 11 Average loss: 5525.45
0/88800	Loss: 5489.413
3200/88800	Loss: 5549.880
6400/88800	Loss: 5493.918
9600/88800	Loss: 5513.130
12800/88800	Loss: 5495.422
16000/88800	Loss: 5511.846
19200/88800	Loss: 5478.972
22400/88800	Loss: 5512.640
25600/88800	Loss: 5513.185
28800/88800	Loss: 5493.357
32000/88800	Loss: 5514.269
35200/88800	Loss: 5512.717
38400/88800	Loss: 5512.430
41600/88800	Loss: 5516.633
44800/88800	Loss: 5496.525
48000/88800	Loss: 5506.286
51200/88800	Loss: 5437.195
54400/88800	Loss: 5481.016
57600/88800	Loss: 5435.950
60800/88800	Loss: 5523.652
64000/88800	Loss: 5517.812
67200/88800	Loss: 5482.331
70400/88800	Loss: 5506.591
736

0/88800	Loss: 5489.768
3200/88800	Loss: 5413.013
6400/88800	Loss: 5399.377
9600/88800	Loss: 5388.072
12800/88800	Loss: 5389.467
16000/88800	Loss: 5394.927
19200/88800	Loss: 5447.896
22400/88800	Loss: 5383.110
25600/88800	Loss: 5403.263
28800/88800	Loss: 5391.687
32000/88800	Loss: 5398.467
35200/88800	Loss: 5394.406
38400/88800	Loss: 5385.727
41600/88800	Loss: 5406.673
44800/88800	Loss: 5379.280
48000/88800	Loss: 5402.166
51200/88800	Loss: 5438.039
54400/88800	Loss: 5415.748
57600/88800	Loss: 5423.251
60800/88800	Loss: 5396.529
64000/88800	Loss: 5368.751
67200/88800	Loss: 5354.843
70400/88800	Loss: 5427.904
73600/88800	Loss: 5359.076
76800/88800	Loss: 5413.992
80000/88800	Loss: 5395.009
83200/88800	Loss: 5396.696
86400/88800	Loss: 5357.860
Epoch: 22 Average loss: 5399.84
0/88800	Loss: 5353.375
3200/88800	Loss: 5376.517
6400/88800	Loss: 5405.794
9600/88800	Loss: 5394.818
12800/88800	Loss: 5397.762
16000/88800	Loss: 5387.730
19200/88800	Loss: 5405.991
22400/88800	Loss: 5372.112
25600/8880

48000/88800	Loss: 5378.557
51200/88800	Loss: 5337.331
54400/88800	Loss: 5396.829
57600/88800	Loss: 5379.719
60800/88800	Loss: 5400.147
64000/88800	Loss: 5401.213
67200/88800	Loss: 5345.615
70400/88800	Loss: 5345.777
73600/88800	Loss: 5384.959
76800/88800	Loss: 5395.301
80000/88800	Loss: 5374.313
83200/88800	Loss: 5357.535
86400/88800	Loss: 5394.592
Epoch: 32 Average loss: 5382.30
0/88800	Loss: 5294.605
3200/88800	Loss: 5373.092
6400/88800	Loss: 5364.199
9600/88800	Loss: 5323.983
12800/88800	Loss: 5389.501
16000/88800	Loss: 5370.531
19200/88800	Loss: 5413.781
22400/88800	Loss: 5365.653
25600/88800	Loss: 5399.290
28800/88800	Loss: 5346.460
32000/88800	Loss: 5374.028
35200/88800	Loss: 5371.023
38400/88800	Loss: 5383.812
41600/88800	Loss: 5373.953
44800/88800	Loss: 5406.013
48000/88800	Loss: 5401.960
51200/88800	Loss: 5393.664
54400/88800	Loss: 5317.509
57600/88800	Loss: 5385.149
60800/88800	Loss: 5385.530
64000/88800	Loss: 5405.014
67200/88800	Loss: 5384.501
70400/88800	Loss: 5375.628
736

0/88800	Loss: 5228.882
3200/88800	Loss: 5411.351
6400/88800	Loss: 5401.722
9600/88800	Loss: 5346.475
12800/88800	Loss: 5369.532
16000/88800	Loss: 5384.727
19200/88800	Loss: 5342.320
22400/88800	Loss: 5338.340
25600/88800	Loss: 5385.168
28800/88800	Loss: 5372.803
32000/88800	Loss: 5344.543
35200/88800	Loss: 5366.819
38400/88800	Loss: 5353.181
41600/88800	Loss: 5348.500
44800/88800	Loss: 5380.757
48000/88800	Loss: 5337.602
51200/88800	Loss: 5336.931
54400/88800	Loss: 5385.557
57600/88800	Loss: 5366.687
60800/88800	Loss: 5354.022
64000/88800	Loss: 5366.603
67200/88800	Loss: 5361.871
70400/88800	Loss: 5381.995
73600/88800	Loss: 5370.319
76800/88800	Loss: 5384.885
80000/88800	Loss: 5357.361
83200/88800	Loss: 5389.455
86400/88800	Loss: 5365.790
Epoch: 43 Average loss: 5368.86
0/88800	Loss: 5422.070
3200/88800	Loss: 5437.728
6400/88800	Loss: 5406.865
9600/88800	Loss: 5321.546
12800/88800	Loss: 5378.778
16000/88800	Loss: 5391.958
19200/88800	Loss: 5353.849
22400/88800	Loss: 5378.484
25600/8880

48000/88800	Loss: 5366.199
51200/88800	Loss: 5341.942
54400/88800	Loss: 5320.141
57600/88800	Loss: 5356.890
60800/88800	Loss: 5348.833
64000/88800	Loss: 5371.995
67200/88800	Loss: 5346.173
70400/88800	Loss: 5389.704
73600/88800	Loss: 5354.755
76800/88800	Loss: 5379.103
80000/88800	Loss: 5336.153
83200/88800	Loss: 5347.570
86400/88800	Loss: 5375.398
Epoch: 53 Average loss: 5361.73
0/88800	Loss: 5407.084
3200/88800	Loss: 5394.751
6400/88800	Loss: 5413.090
9600/88800	Loss: 5375.836
12800/88800	Loss: 5373.496
16000/88800	Loss: 5332.226
19200/88800	Loss: 5362.668
22400/88800	Loss: 5282.189
25600/88800	Loss: 5359.977
28800/88800	Loss: 5347.648
32000/88800	Loss: 5353.391
48000/88800	Loss: 5353.276
51200/88800	Loss: 5405.942
54400/88800	Loss: 5364.697
57600/88800	Loss: 5338.459
60800/88800	Loss: 5327.289
64000/88800	Loss: 5338.621
67200/88800	Loss: 5349.163
70400/88800	Loss: 5354.422
73600/88800	Loss: 5356.868
76800/88800	Loss: 5361.518
80000/88800	Loss: 5317.538
83200/88800	Loss: 5357.355
864

48000/88800	Loss: 5332.537
51200/88800	Loss: 5343.247
54400/88800	Loss: 5338.893
57600/88800	Loss: 5330.571
60800/88800	Loss: 5364.001
64000/88800	Loss: 5382.452
67200/88800	Loss: 5349.575
70400/88800	Loss: 5354.338
73600/88800	Loss: 5319.843
76800/88800	Loss: 5383.228
80000/88800	Loss: 5358.496
83200/88800	Loss: 5349.392
86400/88800	Loss: 5366.733
Epoch: 64 Average loss: 5353.72
0/88800	Loss: 5450.370
3200/88800	Loss: 5346.094
6400/88800	Loss: 5316.447
9600/88800	Loss: 5369.782
12800/88800	Loss: 5302.724
16000/88800	Loss: 5358.702
19200/88800	Loss: 5341.956
22400/88800	Loss: 5325.157
25600/88800	Loss: 5348.939
28800/88800	Loss: 5300.593
32000/88800	Loss: 5400.498
35200/88800	Loss: 5382.213
38400/88800	Loss: 5370.907
41600/88800	Loss: 5355.135
44800/88800	Loss: 5375.895
48000/88800	Loss: 5352.733
51200/88800	Loss: 5346.360
54400/88800	Loss: 5348.276
57600/88800	Loss: 5350.106
60800/88800	Loss: 5371.231
64000/88800	Loss: 5332.130
67200/88800	Loss: 5350.174
70400/88800	Loss: 5323.647
736

0/88800	Loss: 5242.719
3200/88800	Loss: 5342.817
6400/88800	Loss: 5333.101
9600/88800	Loss: 5352.634
12800/88800	Loss: 5335.346
16000/88800	Loss: 5359.129
19200/88800	Loss: 5344.385
22400/88800	Loss: 5344.364
25600/88800	Loss: 5340.089
28800/88800	Loss: 5323.556
32000/88800	Loss: 5309.434
35200/88800	Loss: 5381.083
38400/88800	Loss: 5362.251
41600/88800	Loss: 5333.041
44800/88800	Loss: 5362.475
48000/88800	Loss: 5310.977
51200/88800	Loss: 5395.142
54400/88800	Loss: 5363.600
57600/88800	Loss: 5356.528
60800/88800	Loss: 5330.789
64000/88800	Loss: 5370.723
67200/88800	Loss: 5322.192
70400/88800	Loss: 5329.305
73600/88800	Loss: 5337.247
76800/88800	Loss: 5341.082
80000/88800	Loss: 5366.508
83200/88800	Loss: 5323.649
86400/88800	Loss: 5366.132
Epoch: 75 Average loss: 5347.55
0/88800	Loss: 5182.687
3200/88800	Loss: 5353.141
6400/88800	Loss: 5412.255
9600/88800	Loss: 5367.709
12800/88800	Loss: 5328.940
16000/88800	Loss: 5327.088
19200/88800	Loss: 5349.444
22400/88800	Loss: 5329.559
25600/8880

48000/88800	Loss: 5333.228
51200/88800	Loss: 5353.731
54400/88800	Loss: 5333.493
57600/88800	Loss: 5349.228
60800/88800	Loss: 5310.190
64000/88800	Loss: 5304.732
67200/88800	Loss: 5315.559
70400/88800	Loss: 5336.008
73600/88800	Loss: 5358.639
76800/88800	Loss: 5336.760
80000/88800	Loss: 5327.087
83200/88800	Loss: 5370.964
86400/88800	Loss: 5349.748
Epoch: 85 Average loss: 5343.91
0/88800	Loss: 5544.281
3200/88800	Loss: 5330.093
6400/88800	Loss: 5318.239
9600/88800	Loss: 5316.202
12800/88800	Loss: 5348.148
16000/88800	Loss: 5352.627
19200/88800	Loss: 5315.973
22400/88800	Loss: 5349.214
25600/88800	Loss: 5308.656
28800/88800	Loss: 5358.636
32000/88800	Loss: 5379.210
35200/88800	Loss: 5344.135
38400/88800	Loss: 5313.132
41600/88800	Loss: 5341.653
44800/88800	Loss: 5336.238
48000/88800	Loss: 5363.424
51200/88800	Loss: 5320.129
54400/88800	Loss: 5342.234
57600/88800	Loss: 5316.141
60800/88800	Loss: 5347.027
64000/88800	Loss: 5374.916
67200/88800	Loss: 5308.277
70400/88800	Loss: 5338.101
736

0/88800	Loss: 5519.539
3200/88800	Loss: 5340.968
6400/88800	Loss: 5344.765
9600/88800	Loss: 5334.276
12800/88800	Loss: 5323.898
16000/88800	Loss: 5315.916
19200/88800	Loss: 5324.202
22400/88800	Loss: 5381.001
25600/88800	Loss: 5314.939
28800/88800	Loss: 5313.137
32000/88800	Loss: 5344.729
35200/88800	Loss: 5343.219
38400/88800	Loss: 5356.769
41600/88800	Loss: 5325.000
44800/88800	Loss: 5310.414
48000/88800	Loss: 5337.636
51200/88800	Loss: 5355.212
54400/88800	Loss: 5342.859
57600/88800	Loss: 5307.640
60800/88800	Loss: 5352.467
64000/88800	Loss: 5360.574
67200/88800	Loss: 5339.115
70400/88800	Loss: 5299.497
73600/88800	Loss: 5332.536
76800/88800	Loss: 5316.329
80000/88800	Loss: 5373.697
83200/88800	Loss: 5358.924
86400/88800	Loss: 5346.303
Epoch: 96 Average loss: 5338.48
0/88800	Loss: 5347.262
3200/88800	Loss: 5305.453
6400/88800	Loss: 5308.920
9600/88800	Loss: 5374.367
12800/88800	Loss: 5330.031
16000/88800	Loss: 5358.514
19200/88800	Loss: 5357.053
22400/88800	Loss: 5329.095
25600/8880

48000/88800	Loss: 5341.101
51200/88800	Loss: 5286.689
54400/88800	Loss: 5367.908
57600/88800	Loss: 5344.688
60800/88800	Loss: 5385.551
64000/88800	Loss: 5322.651
67200/88800	Loss: 5334.025
70400/88800	Loss: 5289.842
73600/88800	Loss: 5306.209
76800/88800	Loss: 5324.103
80000/88800	Loss: 5322.823
83200/88800	Loss: 5338.577
86400/88800	Loss: 5305.740
Epoch: 106 Average loss: 5334.61
0/88800	Loss: 5351.843
19200/88800	Loss: 5333.986
22400/88800	Loss: 5300.410
25600/88800	Loss: 5324.292
28800/88800	Loss: 5357.484
32000/88800	Loss: 5355.992
35200/88800	Loss: 5354.856
38400/88800	Loss: 5366.813
41600/88800	Loss: 5342.406
44800/88800	Loss: 5317.268
48000/88800	Loss: 5315.185
51200/88800	Loss: 5280.863
54400/88800	Loss: 5334.861
57600/88800	Loss: 5339.290
60800/88800	Loss: 5279.512
64000/88800	Loss: 5285.740
67200/88800	Loss: 5343.482
70400/88800	Loss: 5355.848
73600/88800	Loss: 5349.215
76800/88800	Loss: 5350.946
80000/88800	Loss: 5363.813
83200/88800	Loss: 5315.977
86400/88800	Loss: 5316.588

16000/88800	Loss: 5333.855
19200/88800	Loss: 5345.670
22400/88800	Loss: 5322.736
25600/88800	Loss: 5377.337
28800/88800	Loss: 5357.715
32000/88800	Loss: 5317.743
35200/88800	Loss: 5338.020
38400/88800	Loss: 5327.151
41600/88800	Loss: 5302.781
44800/88800	Loss: 5312.041
48000/88800	Loss: 5330.548
51200/88800	Loss: 5365.857
54400/88800	Loss: 5322.622
57600/88800	Loss: 5317.677
60800/88800	Loss: 5332.847
64000/88800	Loss: 5340.017
67200/88800	Loss: 5326.597
70400/88800	Loss: 5332.160
73600/88800	Loss: 5288.420
76800/88800	Loss: 5343.619
80000/88800	Loss: 5323.612
83200/88800	Loss: 5328.595
86400/88800	Loss: 5326.372
Epoch: 117 Average loss: 5332.32
0/88800	Loss: 5162.479
3200/88800	Loss: 5348.217
6400/88800	Loss: 5333.669
9600/88800	Loss: 5317.936
12800/88800	Loss: 5343.440
16000/88800	Loss: 5322.856
19200/88800	Loss: 5331.748
22400/88800	Loss: 5336.023
25600/88800	Loss: 5314.711
28800/88800	Loss: 5317.540
32000/88800	Loss: 5318.757
35200/88800	Loss: 5352.502
38400/88800	Loss: 5277.377
41

60800/88800	Loss: 5317.772
64000/88800	Loss: 5328.944
67200/88800	Loss: 5333.219
70400/88800	Loss: 5328.492
73600/88800	Loss: 5336.569
76800/88800	Loss: 5287.605
80000/88800	Loss: 5356.226
83200/88800	Loss: 5329.838
86400/88800	Loss: 5279.659
Epoch: 127 Average loss: 5329.03
0/88800	Loss: 5131.483
3200/88800	Loss: 5326.691
6400/88800	Loss: 5353.637
9600/88800	Loss: 5359.841
12800/88800	Loss: 5283.090
16000/88800	Loss: 5315.089
19200/88800	Loss: 5313.402
22400/88800	Loss: 5322.940
25600/88800	Loss: 5369.965
28800/88800	Loss: 5302.253
32000/88800	Loss: 5329.659
35200/88800	Loss: 5301.200
38400/88800	Loss: 5320.189
41600/88800	Loss: 5302.033
44800/88800	Loss: 5312.138
48000/88800	Loss: 5366.781
51200/88800	Loss: 5294.322
54400/88800	Loss: 5337.808
57600/88800	Loss: 5337.079
60800/88800	Loss: 5304.888
64000/88800	Loss: 5367.520
67200/88800	Loss: 5296.543
70400/88800	Loss: 5331.962
73600/88800	Loss: 5311.708
76800/88800	Loss: 5330.016
80000/88800	Loss: 5352.362
83200/88800	Loss: 5376.841
86

12800/88800	Loss: 5349.191
16000/88800	Loss: 5341.161
19200/88800	Loss: 5300.591
22400/88800	Loss: 5330.924
25600/88800	Loss: 5353.420
28800/88800	Loss: 5328.855
32000/88800	Loss: 5315.295
35200/88800	Loss: 5320.942
38400/88800	Loss: 5326.124
41600/88800	Loss: 5304.211
44800/88800	Loss: 5306.424
48000/88800	Loss: 5296.278
51200/88800	Loss: 5328.784
54400/88800	Loss: 5317.922
57600/88800	Loss: 5344.784
60800/88800	Loss: 5311.235
64000/88800	Loss: 5323.120
67200/88800	Loss: 5299.582
70400/88800	Loss: 5362.076
73600/88800	Loss: 5338.688
76800/88800	Loss: 5340.090
80000/88800	Loss: 5307.502
83200/88800	Loss: 5323.227
86400/88800	Loss: 5311.406
Epoch: 138 Average loss: 5326.80
0/88800	Loss: 5200.038
3200/88800	Loss: 5344.568
6400/88800	Loss: 5279.110
9600/88800	Loss: 5291.429
12800/88800	Loss: 5301.892
16000/88800	Loss: 5313.756
19200/88800	Loss: 5304.182
22400/88800	Loss: 5342.583
25600/88800	Loss: 5350.564
28800/88800	Loss: 5311.670
32000/88800	Loss: 5321.870
35200/88800	Loss: 5324.179
38

57600/88800	Loss: 5334.445
60800/88800	Loss: 5340.371
64000/88800	Loss: 5359.004
67200/88800	Loss: 5297.739
70400/88800	Loss: 5302.557
73600/88800	Loss: 5360.360
76800/88800	Loss: 5336.861
80000/88800	Loss: 5347.243
83200/88800	Loss: 5295.576
86400/88800	Loss: 5305.719
Epoch: 148 Average loss: 5323.44
0/88800	Loss: 5349.219
3200/88800	Loss: 5322.082
6400/88800	Loss: 5388.283
9600/88800	Loss: 5324.732
12800/88800	Loss: 5306.743
16000/88800	Loss: 5350.773
19200/88800	Loss: 5321.825
22400/88800	Loss: 5314.068
25600/88800	Loss: 5321.574
28800/88800	Loss: 5322.400
32000/88800	Loss: 5300.290
35200/88800	Loss: 5331.171
38400/88800	Loss: 5284.447
41600/88800	Loss: 5318.600
44800/88800	Loss: 5354.820
48000/88800	Loss: 5310.639
51200/88800	Loss: 5300.702
54400/88800	Loss: 5302.878
57600/88800	Loss: 5282.761
60800/88800	Loss: 5324.508
64000/88800	Loss: 5335.447
67200/88800	Loss: 5333.039
70400/88800	Loss: 5334.488
73600/88800	Loss: 5297.864
76800/88800	Loss: 5326.481
80000/88800	Loss: 5341.376
83

In [20]:
# Status line marking that the (long) training cell above has finished.
print("Training finished.")

hi


### Visualize

In [21]:
# Plot reconstructions of one test batch.
import os

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# NOTE(review): the saved run raised AttributeError ('ImageListDataset' object has no attribute
# 'error_handling') while iterating test_loader — likely %autoreload 2 reloaded
# utils.dataloader_tools after the datasets were built. Re-create the datasets or restart the kernel.

# Take a single batch from the test loader; labels are not used.
batch, labels = next(iter(test_loader))

# Reconstruct data using the Joint-VAE model
recon = viz.reconstructions(batch)

# The grid is channels-first (C, H, W); imshow expects (H, W, C). The earlier
# np.rot90(np.transpose(x, (2, 1, 0)), k=3) produced a left-right mirrored image.
recon_img = np.transpose(recon.numpy(), (1, 2, 0))

out_path = "sample_images/306/128/xint128_e150_b64_c10d10_gam30_reconstructions.png"
os.makedirs(os.path.dirname(out_path), exist_ok=True)  # savefig does not create directories

fig, ax = plt.subplots(figsize=(26, 26))
ax.imshow(recon_img)
ax.axis("off")
fig.savefig(out_path, dpi=200)

AttributeError: 'ImageListDataset' object has no attribute 'error_handling'

In [None]:
# Plot samples
samples = viz.samples()

plt.figure(figsize=(26,26))
plt.imshow(np.rot90(np.transpose(samples.numpy(),(2,1,0)),k=3));
plt.savefig("sample_images/306/128/xint128_e150_b64_c10d10_gam30_samples.png",dpi=200)

#### Latent traversals
Traverse each latent dimension in turn and plot a grid of images in which each row is the traversal of one latent dimension.

In [None]:
# Plot all traversals: one row per latent dimension.
import os

traversals = viz.all_latent_traversals(size=10)

# The grid is channels-first (C, H, W); imshow expects (H, W, C). The earlier
# np.rot90(np.transpose(x, (2, 1, 0)), k=3) produced a left-right mirrored image,
# which also reversed the traversal direction along each row.
traversals_img = np.transpose(traversals.numpy(), (1, 2, 0))

# NOTE(review): the file name says "n20" but size=10 — confirm which is intended.
out_path = "sample_images/306/128/xint128_e150_b64_c10d10_gam30_all_traversals_n20.png"
os.makedirs(os.path.dirname(out_path), exist_ok=True)  # savefig does not create directories

fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(traversals_img)
ax.axis("off")
fig.savefig(out_path, dpi=200)

In [None]:
# Plot a traversal grid over continuous latent 2 and discrete latent 0.
import os

traversals = viz.latent_traversal_grid(cont_idx=2, cont_axis=1, disc_idx=0, disc_axis=0, size=(10, 10))

# The grid is channels-first (C, H, W); imshow expects (H, W, C). The earlier
# np.rot90(np.transpose(x, (2, 1, 0)), k=3) produced a left-right mirrored image.
traversals_img = np.transpose(traversals.numpy(), (1, 2, 0))

out_path = "sample_images/306/128/xint128_e150_b64_c10d10_gam30_traversals2100.png"
os.makedirs(os.path.dirname(out_path), exist_ok=True)  # savefig does not create directories

fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(traversals_img)
ax.axis("off")
fig.savefig(out_path, dpi=200)

In [None]:
# Plot a traversal grid over continuous latent 1 and discrete latent 0.
import os

traversals = viz.latent_traversal_grid(cont_idx=1, cont_axis=1, disc_idx=0, disc_axis=0, size=(10, 10))

# The grid is channels-first (C, H, W); imshow expects (H, W, C). The earlier
# np.rot90(np.transpose(x, (2, 1, 0)), k=3) produced a left-right mirrored image.
# (A stray imshow of channel 0 alone, which was immediately drawn over, has been removed.)
traversals_img = np.transpose(traversals.numpy(), (1, 2, 0))

out_path = "sample_images/306/128/xint128_e150_b64_c10d10_gam30_traversals1100.png"
os.makedirs(os.path.dirname(out_path), exist_ok=True)  # savefig does not create directories

fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(traversals_img)
ax.axis("off")
fig.savefig(out_path, dpi=200)

In [None]:
# Plot a traversal grid over continuous latent 9 and discrete latent 0.
import os

traversals = viz.latent_traversal_grid(cont_idx=9, cont_axis=1, disc_idx=0, disc_axis=0, size=(10, 10))

# The grid is channels-first (C, H, W); imshow expects (H, W, C). The earlier
# np.rot90(np.transpose(x, (2, 1, 0)), k=3) produced a left-right mirrored image.
traversals_img = np.transpose(traversals.numpy(), (1, 2, 0))

out_path = "sample_images/306/128/xint128_e150_b64_c10d10_gam30_traversals9100.png"
os.makedirs(os.path.dirname(out_path), exist_ok=True)  # savefig does not create directories

fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(traversals_img)
ax.axis("off")
fig.savefig(out_path, dpi=200)

In [None]:
# List the current working directory.
!ls

### Save Model

In [None]:
# Relative checkpoint name: "<run dir>/<file>.pth".
model_name = "/".join(["306", "128", "xint128_e150_b64_c10d10_gam30.pth"])

In [None]:
import os

# Save only the weights (state dict). model_name is "<run dir>/<file>.pth", so the target is
# trained_models/<run dir>/statedict_<file>.pth. The previous concatenation
# "trained_models" + "statedict_" + model_name had no path separator and put the prefix on
# the directory, giving "trained_modelsstatedict_306/128/...".
statedict_path = os.path.join(
    "trained_models",
    os.path.dirname(model_name),
    "statedict_" + os.path.basename(model_name),
)
os.makedirs(os.path.dirname(statedict_path), exist_ok=True)  # torch.save does not create directories
torch.save(model.state_dict(), statedict_path)
print(f"Saved state dict to {statedict_path}")

# To also save the full pickled model (required by the "Restore Full Model" cell):
#torch.save(model, model_name) # save full model

In [None]:
# Same text as print("Done training: ", model_name): label, two spaces, then the name.
print(f"Done training:  {model_name}")

#### Restore Model from State Dict

In [None]:
import os

# Rebuild the architecture and load the saved weights into it. img_size must match the
# training configuration (3, 128, 128); the previous (3, 64, 64) did not.
sd_model = VAE(latent_spec=latent_spec, img_size=(3, 128, 128))

# Weights live at trained_models/<run dir>/statedict_<file>.pth, where model_name is
# "<run dir>/<file>.pth".
statedict_path = os.path.join(
    "trained_models",
    os.path.dirname(model_name),
    "statedict_" + os.path.basename(model_name),
)
# map_location="cpu" lets weights saved from a GPU model load on a CPU-only machine.
sd_model.load_state_dict(torch.load(statedict_path, map_location="cpu"))

#### Restore Full Model
* Note: in this case the serialized data is bound to the specific classes and the exact directory structure used.

In [None]:
import os

# Restoring the full pickled model requires torch.save(model, model_name), which this notebook
# leaves commented out, so the file may not exist; skip gracefully instead of crashing.
# NOTE: torch.load unpickles arbitrary objects — only load checkpoints from a trusted source.
if os.path.exists(model_name):
    full_model = torch.load(model_name, map_location="cpu")
else:
    full_model = None
    print(f"No full model found at {model_name!r}; skipping full-model restore.")

In [None]:
# Class of the object restored by the full-model load.
type(full_model)

In [None]:
# Class of the model rebuilt from its state dict (the VAE class imported above).
type(sd_model)

In [None]:
# List the current working directory.
!ls