# Train a JointVAE model

In [2]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [3]:
import os

import numpy as np
import torch

# CUDA reads CUDA_DEVICE_ORDER / CUDA_VISIBLE_DEVICES only once, when the CUDA runtime is
# first initialised, and torch.cuda.is_available() can trigger that initialisation.
# Select the GPU *before* probing it; setting these variables afterwards has no effect.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # change to your device

use_cuda = torch.cuda.is_available()
use_cuda

True

In [4]:
import os

# Pin the GPU by PCI bus order (see issue #152); change "1" to your device.
# These variables only take effect if they are set before CUDA is first
# initialised in this kernel (torch.cuda.is_available() can trigger that).
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "1",
})

#### Home data (currently empty)

#### Load the list of loadable dress images

In [10]:
# One image path per row; the CSV's first row is a header and is skipped.
DRESS_PATHS_CSV = 'data/loadable_women_primary_dress.csv'
loadable_dresses = list(
    np.loadtxt(DRESS_PATHS_CSV, delimiter=',', skiprows=1, dtype='str')
)

In [6]:
# How many image paths were loaded from the CSV.
n_loadable = len(loadable_dresses)
print(n_loadable)

102125


In [7]:
# Peek at the first five image paths.
loadable_dresses[0:5]

['/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/6/4/6418008_9882769.jpg',
 '/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/6/6/6627534_9864695.jpg',
 '/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/6/7/6772508_9949243.jpg',
 '/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/6/7/6758001_9588597.jpg',
 '/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/6/6/6637527_9387666.jpg']

In [8]:
# Image paths known to be bad; they are excluded from both train and test sets.
bad_data = ['/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/7/2/723739_1342692.jpg']

# Rebuild the list instead of deleting by index: this drops *every* occurrence of a
# bad path (list.index + del only removed the first one) and is safe to re-run.
_bad_set = set(bad_data)
loadable_dresses = [path for path in loadable_dresses if path not in _bad_set]

print(len(loadable_dresses))

102124


#### Split into train and test set

In [11]:
# Ordered (unshuffled) split: the first N_TRAIN paths are used for training, the
# following paths form the test set, and the final N_TAIL_DROPPED paths are used by neither.
N_TRAIN = 88_800
N_TAIL_DROPPED = 124

# Fail fast instead of silently producing smaller splits than intended.
assert len(loadable_dresses) >= N_TRAIN + N_TAIL_DROPPED, "not enough image paths for this split"
# Guard against re-running the CSV-loading cell without re-running the bad-file filter.
assert not set(bad_data).intersection(loadable_dresses), \
    "bad_data paths present: re-run the bad-file filtering cell after loading the CSV"

image_paths_train = loadable_dresses[:N_TRAIN]
# Explicit end index (instead of -N_TAIL_DROPPED) so that N_TAIL_DROPPED = 0 keeps the whole tail.
image_paths_test = loadable_dresses[N_TRAIN:len(loadable_dresses) - N_TAIL_DROPPED]

print(f"Number of train image paths: {len(image_paths_train):,d}")
print(f"Number of test image paths: {len(image_paths_test):,d}")
print()
print("Sample paths:")
print(image_paths_train[0])
print(image_paths_train[-1])
print(image_paths_test[0])
print(image_paths_test[-1])

Number of train image paths: 88,800
Number of test image paths: 13,201

Sample paths:
/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/6/4/6418008_9882769.jpg
/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/5/0/505051_864215.jpg
/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/5/2/529420_884192.jpg
/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/3/4/341494_806784.jpg


In [12]:
from torchvision import transforms
from utils.dataloader_tools import get_imagelist_dataloader, ImageListDataset, ImgDsetCut5from256

BATCH_SIZE = 50

# CenterCrop((256, 256)) already yields a 256x256 image, so the Resize((256, 256))
# that used to follow it was a no-op and has been dropped.
composed = transforms.Compose([transforms.CenterCrop((256, 256)), transforms.ToTensor()])

# conv_rgb: set True for the cv2-loaded images (the ones kept in this dir).
conv_rgb = False
# Passed through to the datasets below (previously defined but ignored in favour of a literal True).
error_handling = True

train_dataset = ImgDsetCut5from256(image_paths_train, transform=composed,
                                   convert_rgb=conv_rgb, error_handling=error_handling)
test_dataset = ImgDsetCut5from256(image_paths_test, transform=composed,
                                  convert_rgb=conv_rgb, error_handling=error_handling)

train_loader = get_imagelist_dataloader(batch_size=BATCH_SIZE, dataset_object=train_dataset)
test_loader = get_imagelist_dataloader(batch_size=BATCH_SIZE, dataset_object=test_dataset)

### Define latent distribution of the model

In [13]:
# Joint latent distribution: 10 continuous (Gaussian) dimensions plus a single
# Gumbel-Softmax (discrete) variable with 10 categories.
latent_spec = dict(cont=10, disc=[10])

### Build a model

In [14]:
from jointvae.models_256_branching1 import VAE

model = VAE(latent_spec=latent_spec, img_size=(3, 256, 256), use_cuda=use_cuda)
# Move the model to the GPU *before* the optimizer is built from model.parameters(),
# as the torch.optim docs recommend; a later .cuda() call on it is then harmless.
if use_cuda:
    model = model.cuda()

In [15]:
# Uncomment to print the model's module structure.
#print(model)

### Train the model

In [16]:
from torch import optim

LEARNING_RATE = 1e-4  # originally 5e-4
# Adam with the AMSGrad variant enabled (added for this run).
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, amsgrad=True)

In [17]:
from jointvae.training import Trainer
# Alternative for debugging: from jointvae.training_debug import Trainer

# Capacity schedules, each [min_capacity, max_capacity, num_iters, gamma]:
# start at a capacity of 0.0 and increase it to 5 over 30000 iterations,
# with a gamma of 30.0. (The old comments said 25000 iterations.)
# Continuous channels
cont_capacity = [0.0, 5, 30000, 30.0]
# Discrete channels
disc_capacity = [0.0, 5, 30000, 30.0]

# Build a trainer
trainer = Trainer(model, optimizer,
                  cont_capacity=cont_capacity,
                  disc_capacity=disc_capacity,
                  use_cuda=use_cuda)

In [20]:
# Leftover interactive sanity check (always True); safe to delete.
isinstance([], list)

True

#### Initialize visualizer

In [59]:
from viz.visualize import Visualizer

viz = Visualizer(model)
# Make the visualizer's helpers return image tensors (needed by the plotting cells below).
viz.save_images = False

In [None]:
# Train for EPOCHS epochs on the training split. This is a long GPU run; the trainer
# prints the running loss during each epoch and the epoch average (see the log below).
EPOCHS = 600
trainer.train(train_loader, epochs=EPOCHS, save_training_gif=None)

0/98000	Loss: 45938.223
5000/98000	Loss: 45836.388
10000/98000	Loss: 44186.887
15000/98000	Loss: 39192.616
20000/98000	Loss: 35767.140
25000/98000	Loss: 33118.496
30000/98000	Loss: 29994.918
35000/98000	Loss: 25210.262
40000/98000	Loss: 23122.057
45000/98000	Loss: 22706.840
50000/98000	Loss: 22346.867
file /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/8/2884734_3674822.jpg does not have 3 channels
Replacing with previous image
file /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/8/2801945_3529818.jpg does not have 3 channels
Replacing with previous image
55000/98000	Loss: 21667.467
60000/98000	Loss: 21695.529
65000/98000	Loss: 21556.159
70000/98000	Loss: 21330.863
75000/98000	Loss: 21218.270
80000/98000	Loss: 20893.011
85000/98000	Loss: 20793.962
90000/98000	Loss: 20709.121
95000/98000	Loss: 20658.434
Epoch: 1 Average loss: 26784.89
0/98000	Loss: 19312.150
5000/98000	Loss: 20610.697
10000/98000	Loss: 20387.597
15000/98000	Loss: 20268.18

30000/98000	Loss: 18992.454
file /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/8/2884734_3674822.jpg does not have 3 channels
Replacing with previous image
35000/98000	Loss: 19114.485
40000/98000	Loss: 18999.400
45000/98000	Loss: 18849.209
50000/98000	Loss: 18828.645
55000/98000	Loss: 18847.261
60000/98000	Loss: 18984.152
65000/98000	Loss: 18976.299
70000/98000	Loss: 19091.572
75000/98000	Loss: 18691.995
80000/98000	Loss: 18956.541
85000/98000	Loss: 18985.167
file /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/8/2801945_3529818.jpg does not have 3 channels
Replacing with previous image
90000/98000	Loss: 18888.997
95000/98000	Loss: 18650.322
Epoch: 10 Average loss: 18919.96
0/98000	Loss: 19419.498
5000/98000	Loss: 18927.541
10000/98000	Loss: 18908.474
15000/98000	Loss: 18747.000
20000/98000	Loss: 19003.965
file /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/8/2884734_3674822.jpg does not have 3 channels
Replacin

30000/98000	Loss: 18744.958
35000/98000	Loss: 18501.176
40000/98000	Loss: 18670.724
45000/98000	Loss: 18549.386
50000/98000	Loss: 18602.382
55000/98000	Loss: 18755.100
60000/98000	Loss: 18751.868
file /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/8/2884734_3674822.jpg does not have 3 channels
Replacing with previous image
65000/98000	Loss: 18482.421
70000/98000	Loss: 18680.482
75000/98000	Loss: 18635.898
80000/98000	Loss: 18612.858
85000/98000	Loss: 18486.575
90000/98000	Loss: 18579.675
95000/98000	Loss: 18794.670
Epoch: 19 Average loss: 18657.44
0/98000	Loss: 19082.020
5000/98000	Loss: 18471.628
10000/98000	Loss: 18628.692
15000/98000	Loss: 18394.853
20000/98000	Loss: 18719.681
25000/98000	Loss: 18737.864
30000/98000	Loss: 18637.780
file /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/8/2884734_3674822.jpg does not have 3 channels
Replacing with previous image
35000/98000	Loss: 18627.413
40000/98000	Loss: 18814.244
file /workspace/mnt/

55000/98000	Loss: 18440.558
60000/98000	Loss: 18451.028
65000/98000	Loss: 18418.848
70000/98000	Loss: 18519.268
75000/98000	Loss: 18458.613
80000/98000	Loss: 18508.934
85000/98000	Loss: 18433.143
90000/98000	Loss: 18450.806
file /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/8/2801945_3529818.jpg does not have 3 channels
Replacing with previous image
95000/98000	Loss: 18517.157
Epoch: 28 Average loss: 18492.49
0/98000	Loss: 17637.447
file /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/8/2884734_3674822.jpg does not have 3 channels
Replacing with previous image
5000/98000	Loss: 18459.344
file /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/8/2801945_3529818.jpg does not have 3 channels
Replacing with previous image
10000/98000	Loss: 18422.599
15000/98000	Loss: 18445.589
20000/98000	Loss: 18580.428
25000/98000	Loss: 18415.066
30000/98000	Loss: 18519.242
35000/98000	Loss: 18498.239
40000/98000	Loss: 18310.714
45000/

80000/98000	Loss: 18426.882
85000/98000	Loss: 18392.845
90000/98000	Loss: 18356.397
file /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/8/2884734_3674822.jpg does not have 3 channels
Replacing with previous image
95000/98000	Loss: 18439.591
Epoch: 37 Average loss: 18399.84
0/98000	Loss: 18235.398
file /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/8/2801945_3529818.jpg does not have 3 channels
Replacing with previous image
5000/98000	Loss: 18282.535
10000/98000	Loss: 18410.909
15000/98000	Loss: 18391.657
20000/98000	Loss: 18522.023
25000/98000	Loss: 18367.687
30000/98000	Loss: 18529.380
35000/98000	Loss: 18252.039
file /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/8/2884734_3674822.jpg does not have 3 channels
Replacing with previous image
40000/98000	Loss: 18655.215
45000/98000	Loss: 18307.260
50000/98000	Loss: 18462.158
55000/98000	Loss: 18382.383
60000/98000	Loss: 18337.846
65000/98000	Loss: 18110.212
70000/

90000/98000	Loss: 18359.954
95000/98000	Loss: 18436.953
Epoch: 46 Average loss: 18338.69
0/98000	Loss: 18154.814
file /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/8/2884734_3674822.jpg does not have 3 channels
Replacing with previous image
5000/98000	Loss: 18363.793
10000/98000	Loss: 18172.671
15000/98000	Loss: 18523.835
20000/98000	Loss: 18275.852
25000/98000	Loss: 18266.622
30000/98000	Loss: 18325.628
35000/98000	Loss: 18422.053
file /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/8/2801945_3529818.jpg does not have 3 channels
Replacing with previous image
40000/98000	Loss: 18334.951
45000/98000	Loss: 18361.792
50000/98000	Loss: 18232.788
55000/98000	Loss: 18493.833
60000/98000	Loss: 18333.154
65000/98000	Loss: 18336.574
70000/98000	Loss: 18246.020
75000/98000	Loss: 18418.041
80000/98000	Loss: 18245.817
85000/98000	Loss: 18426.751
90000/98000	Loss: 18309.248
95000/98000	Loss: 18378.088
Epoch: 47 Average loss: 18334.53
0/98000	Loss: 1

10000/98000	Loss: 18318.048
15000/98000	Loss: 18370.125
20000/98000	Loss: 18219.907
25000/98000	Loss: 18115.613
file /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/8/2801945_3529818.jpg does not have 3 channels
Replacing with previous image
30000/98000	Loss: 18315.011
35000/98000	Loss: 18261.255
40000/98000	Loss: 18368.589
45000/98000	Loss: 18279.568
50000/98000	Loss: 18293.174
55000/98000	Loss: 18455.125
60000/98000	Loss: 18309.238
65000/98000	Loss: 18267.792
70000/98000	Loss: 18371.693
75000/98000	Loss: 18208.657
80000/98000	Loss: 18211.103
file /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/8/2884734_3674822.jpg does not have 3 channels
Replacing with previous image
85000/98000	Loss: 18256.970
90000/98000	Loss: 18325.550
95000/98000	Loss: 18262.106
Epoch: 56 Average loss: 18289.32
0/98000	Loss: 18001.643
5000/98000	Loss: 18198.772
10000/98000	Loss: 18231.879
15000/98000	Loss: 18382.100
20000/98000	Loss: 18272.821
25000/98000	Loss: 18

10000/98000	Loss: 18427.402
15000/98000	Loss: 18161.666
20000/98000	Loss: 18306.520
25000/98000	Loss: 18253.398
30000/98000	Loss: 18234.945
35000/98000	Loss: 18315.061
40000/98000	Loss: 18221.312
45000/98000	Loss: 18248.450
50000/98000	Loss: 18036.279
55000/98000	Loss: 18275.621
60000/98000	Loss: 18240.495
65000/98000	Loss: 18314.093
70000/98000	Loss: 18272.684
75000/98000	Loss: 18160.271
file /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/8/2801945_3529818.jpg does not have 3 channels
Replacing with previous image
80000/98000	Loss: 18280.154
85000/98000	Loss: 18309.724
90000/98000	Loss: 18252.292
95000/98000	Loss: 18276.277
Epoch: 65 Average loss: 18252.63
0/98000	Loss: 17756.936
5000/98000	Loss: 18185.813
10000/98000	Loss: 18224.214
15000/98000	Loss: 18409.942
20000/98000	Loss: 18272.788
file /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/8/2884734_3674822.jpg does not have 3 channels
Replacing with previous image
25000/98000	Loss: 18

40000/98000	Loss: 18210.874
45000/98000	Loss: 18132.997
50000/98000	Loss: 18118.721
55000/98000	Loss: 18345.347
60000/98000	Loss: 18207.705
file /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/8/2801945_3529818.jpg does not have 3 channels
Replacing with previous image
65000/98000	Loss: 18175.275
70000/98000	Loss: 18130.231
75000/98000	Loss: 18131.301
80000/98000	Loss: 18346.791
85000/98000	Loss: 18230.458
90000/98000	Loss: 18227.330
95000/98000	Loss: 18278.863
Epoch: 74 Average loss: 18216.14
0/98000	Loss: 18984.467
5000/98000	Loss: 18135.622
10000/98000	Loss: 18128.019
15000/98000	Loss: 18119.814
20000/98000	Loss: 18237.028
25000/98000	Loss: 18337.047
30000/98000	Loss: 18111.120
35000/98000	Loss: 18177.138
40000/98000	Loss: 18220.698
file /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/8/2801945_3529818.jpg does not have 3 channels
Replacing with previous image
45000/98000	Loss: 18222.173
50000/98000	Loss: 18164.004
file /workspace/mnt/

55000/98000	Loss: 18089.506
60000/98000	Loss: 18111.077
65000/98000	Loss: 18157.629
70000/98000	Loss: 18268.898
75000/98000	Loss: 18028.392
80000/98000	Loss: 18232.928
85000/98000	Loss: 18330.457
90000/98000	Loss: 18221.235
95000/98000	Loss: 18181.354
Epoch: 83 Average loss: 18187.24
0/98000	Loss: 17965.225
5000/98000	Loss: 18209.794
10000/98000	Loss: 18173.091
15000/98000	Loss: 18098.752
20000/98000	Loss: 18201.275
25000/98000	Loss: 18175.952
30000/98000	Loss: 18231.273
35000/98000	Loss: 18204.550
40000/98000	Loss: 18112.534
45000/98000	Loss: 18052.188
50000/98000	Loss: 18269.677
55000/98000	Loss: 18230.912
60000/98000	Loss: 18233.881
65000/98000	Loss: 18146.235
70000/98000	Loss: 18136.263
file /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/8/2884734_3674822.jpg does not have 3 channels
Replacing with previous image
file /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/8/2801945_3529818.jpg does not have 3 channels
Replacing with previou

80000/98000	Loss: 18089.022
85000/98000	Loss: 18231.581
90000/98000	Loss: 18213.641
95000/98000	Loss: 18139.175
Epoch: 92 Average loss: 18145.11
0/98000	Loss: 18597.836
5000/98000	Loss: 18138.862
file /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/8/2801945_3529818.jpg does not have 3 channels
Replacing with previous image
10000/98000	Loss: 18218.223
15000/98000	Loss: 18190.076
20000/98000	Loss: 18209.427
25000/98000	Loss: 18039.026
30000/98000	Loss: 17966.117
35000/98000	Loss: 18223.237
40000/98000	Loss: 18012.171
45000/98000	Loss: 18168.417
50000/98000	Loss: 18139.918
55000/98000	Loss: 18143.427
60000/98000	Loss: 18242.872
65000/98000	Loss: 18187.156
70000/98000	Loss: 18187.722
75000/98000	Loss: 18138.701
80000/98000	Loss: 18074.227
file /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/8/2884734_3674822.jpg does not have 3 channels
Replacing with previous image
85000/98000	Loss: 18162.046
90000/98000	Loss: 18105.873
95000/98000	Loss: 18

0/98000	Loss: 17579.402
5000/98000	Loss: 18029.641
10000/98000	Loss: 18104.120
15000/98000	Loss: 17998.125
20000/98000	Loss: 18130.054
25000/98000	Loss: 18064.003
30000/98000	Loss: 18129.693
35000/98000	Loss: 18110.434
40000/98000	Loss: 18102.514
45000/98000	Loss: 18040.540
50000/98000	Loss: 18088.388
55000/98000	Loss: 18013.211
60000/98000	Loss: 18246.037
65000/98000	Loss: 18098.905
70000/98000	Loss: 17980.997
75000/98000	Loss: 18133.390
80000/98000	Loss: 18244.126
85000/98000	Loss: 17956.985
90000/98000	Loss: 18106.430
file /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/8/2801945_3529818.jpg does not have 3 channels
Replacing with previous image
95000/98000	Loss: 18136.901
file /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/8/2884734_3674822.jpg does not have 3 channels
Replacing with previous image
Epoch: 102 Average loss: 18090.11
0/98000	Loss: 18581.422
5000/98000	Loss: 18102.276
10000/98000	Loss: 18033.973
15000/98000	Loss: 18167.

25000/98000	Loss: 18023.643
30000/98000	Loss: 17928.967
35000/98000	Loss: 18109.307
40000/98000	Loss: 17870.173
file /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/8/2884734_3674822.jpg does not have 3 channels
Replacing with previous image
45000/98000	Loss: 18128.595
50000/98000	Loss: 17833.398
55000/98000	Loss: 17956.557
60000/98000	Loss: 17977.067
file /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/8/2801945_3529818.jpg does not have 3 channels
Replacing with previous image
65000/98000	Loss: 17920.282
70000/98000	Loss: 18021.016
75000/98000	Loss: 17884.346
80000/98000	Loss: 18033.052
85000/98000	Loss: 17958.498
90000/98000	Loss: 17881.776
95000/98000	Loss: 17816.068
Epoch: 111 Average loss: 17979.97
0/98000	Loss: 18254.568
5000/98000	Loss: 18062.315
10000/98000	Loss: 17914.279
15000/98000	Loss: 18122.294
file /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/8/2884734_3674822.jpg does not have 3 channels
Replaci

file /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/8/2801945_3529818.jpg does not have 3 channels
Replacing with previous image
20000/98000	Loss: 17798.651
25000/98000	Loss: 17979.833
30000/98000	Loss: 17907.165
35000/98000	Loss: 18017.109
40000/98000	Loss: 17893.726
45000/98000	Loss: 17878.372
50000/98000	Loss: 17999.187
55000/98000	Loss: 17744.926
60000/98000	Loss: 17906.216
65000/98000	Loss: 17916.723
70000/98000	Loss: 17910.798
75000/98000	Loss: 17903.535
80000/98000	Loss: 18040.579
85000/98000	Loss: 17891.526
90000/98000	Loss: 17977.209
95000/98000	Loss: 17923.134
Epoch: 120 Average loss: 17921.82
0/98000	Loss: 17980.713
5000/98000	Loss: 17758.101
10000/98000	Loss: 17951.469
15000/98000	Loss: 17680.592
20000/98000	Loss: 17967.378
25000/98000	Loss: 17894.138
30000/98000	Loss: 17993.741
35000/98000	Loss: 17994.271
40000/98000	Loss: 17745.185
45000/98000	Loss: 18083.802
50000/98000	Loss: 18117.558
55000/98000	Loss: 17929.488
60000/98000	Loss: 17957.569
65000

45000/98000	Loss: 17934.491
50000/98000	Loss: 17725.454
55000/98000	Loss: 18016.353
60000/98000	Loss: 17763.342
65000/98000	Loss: 18031.033
70000/98000	Loss: 17890.720
file /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/8/2801945_3529818.jpg does not have 3 channels
Replacing with previous image
75000/98000	Loss: 17800.514
80000/98000	Loss: 17993.048
85000/98000	Loss: 17821.690
90000/98000	Loss: 17853.480
95000/98000	Loss: 17791.325
Epoch: 129 Average loss: 17882.25
0/98000	Loss: 16979.531
5000/98000	Loss: 17836.823
10000/98000	Loss: 17941.625
15000/98000	Loss: 17935.413
20000/98000	Loss: 17858.334
file /workspace/mnt/crucial_2TB/111_Extra_Data/Macys/all_wearables20180810/2/8/2884734_3674822.jpg does not have 3 channels
Replacing with previous image
25000/98000	Loss: 18021.012
30000/98000	Loss: 18062.624
35000/98000	Loss: 17997.493
40000/98000	Loss: 17822.199
45000/98000	Loss: 17854.643
50000/98000	Loss: 18008.941
55000/98000	Loss: 17614.081
file /workspace/mnt

In [None]:
print("hi")  # leftover debug print; safe to delete

In [None]:
# Take the first batch from the test loader (labels are not used in this cell).
batch, labels = next(iter(test_loader))
print("batch: ",type(batch),batch.shape)
type(viz.reconstructions(batch))

### Visualize

In [None]:
# Plot reconstructions of one test batch.
import os

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Get a batch of data
batch, labels = next(iter(test_loader))

# Reconstruct data using Joint-VAE model
recon = viz.reconstructions(batch)

out_dir = "sample_images/home/256"
os.makedirs(out_dir, exist_ok=True)  # savefig does not create missing folders

plt.figure(figsize=(26,26))
# The grid is channels-first (as the original transpose implied); imshow wants (H, W, C).
# The old np.rot90(np.transpose(x, (2, 1, 0)), k=3) also gave (H, W, C) but mirrored
# the image left-to-right, so every grid was displayed flipped.
plt.imshow(np.transpose(recon.numpy(), (1, 2, 0)));
plt.savefig(os.path.join(out_dir, "xint256_e600_b100_c10d10_cap48_30k_gam32_reconstructions.png"), dpi=200)

In [None]:
# Plot samples generated by the visualizer.
import os

samples = viz.samples()

out_dir = "sample_images/home/256"
os.makedirs(out_dir, exist_ok=True)  # savefig does not create missing folders

plt.figure(figsize=(26,26))
# Channels-first grid -> (H, W, C) for imshow; the old rot90/transpose combo mirrored it.
plt.imshow(np.transpose(samples.numpy(), (1, 2, 0)));
plt.savefig(os.path.join(out_dir, "xint256_e600_b100_c10d10_cap48_30k_gam32_samples.png"), dpi=200)

#### Latent traversals: each latent dimension is traversed in turn, and each row of the plotted grid shows the traversal of one latent dimension

In [None]:
# Plot all traversals
import os

traversals = viz.all_latent_traversals(size=10)

out_dir = "sample_images/home/256"
os.makedirs(out_dir, exist_ok=True)  # savefig does not create missing folders

plt.figure(figsize=(20,20))
# Channels-first grid -> (H, W, C). The old rot90/transpose combo mirrored the grid,
# which reversed the left-to-right order of every traversal row.
plt.imshow(np.transpose(traversals.numpy(), (1, 2, 0)));
plt.savefig(os.path.join(out_dir, "xint256_e600_b100_c10d10_cap48_30k_gam32_all_traversals_n20.png"), dpi=200)

In [None]:
# Plot a grid of traversals: continuous dim 2 along axis 1, discrete dim 0 along axis 0.
import os

traversals = viz.latent_traversal_grid(cont_idx=2, cont_axis=1, disc_idx=0, disc_axis=0, size=(10, 10))

out_dir = "sample_images/home/256"
os.makedirs(out_dir, exist_ok=True)  # savefig does not create missing folders

plt.figure(figsize=(20,20))
# Channels-first grid -> (H, W, C); the old rot90/transpose combo mirrored the grid.
plt.imshow(np.transpose(traversals.numpy(), (1, 2, 0)));
plt.savefig(os.path.join(out_dir, "xint256_e600_b100_c10d10_cap48_30k_gam32_traversals2100.png"), dpi=200)

In [None]:
# Plot a grid of traversals: continuous dim 1 along axis 1, discrete dim 0 along axis 0.
import os

traversals = viz.latent_traversal_grid(cont_idx=1, cont_axis=1, disc_idx=0, disc_axis=0, size=(10, 10))

out_dir = "sample_images/home/256"
os.makedirs(out_dir, exist_ok=True)  # savefig does not create missing folders

plt.figure(figsize=(20,20))
# (A stray single-channel imshow that was immediately drawn over has been removed.)
# Channels-first grid -> (H, W, C); the old rot90/transpose combo mirrored the grid.
plt.imshow(np.transpose(traversals.numpy(), (1, 2, 0)));
plt.savefig(os.path.join(out_dir, "xint256_e600_b100_c10d10_cap48_30k_gam32_traversals1100.png"), dpi=200)

In [None]:
# Plot a grid of traversals: continuous dim 9 along axis 1, discrete dim 0 along axis 0.
import os

traversals = viz.latent_traversal_grid(cont_idx=9, cont_axis=1, disc_idx=0, disc_axis=0, size=(10, 10))

out_dir = "sample_images/home/256"
os.makedirs(out_dir, exist_ok=True)  # savefig does not create missing folders

plt.figure(figsize=(20,20))
# Channels-first grid -> (H, W, C); the old rot90/transpose combo mirrored the grid.
plt.imshow(np.transpose(traversals.numpy(), (1, 2, 0)));
plt.savefig(os.path.join(out_dir, "xint256_e600_b100_c10d10_cap48_30k_gam32_traversals9100.png"), dpi=200)

In [None]:
# Shell: list the current working directory (leftover interactive check).
!ls

### Save Model

In [None]:
# Derive the checkpoint name from the run configuration so it cannot drift out of sync.
# The old hard-coded "xint256_e600_b100_c10d10_cap48_30k_gam32.pth" claimed batch size 100
# and gamma 32, while this run uses BATCH_SIZE = 50 and gamma 30.0.
model_name = (
    f"xint256_e600_b{BATCH_SIZE}"
    f"_c{latent_spec['cont']}d{'-'.join(str(d) for d in latent_spec['disc'])}"
    f"_cap{cont_capacity[1]:g}_{cont_capacity[2] / 1000:g}k_gam{cont_capacity[3]:g}.pth"
)

In [None]:
import os

# Save only the weights (state dict) under trained_models/. The old string concatenation
# lacked a path separator and wrote "trained_modelsstatedict_..." into the current directory.
os.makedirs("trained_models", exist_ok=True)
torch.save(model.state_dict(), os.path.join("trained_models", "statedict_" + model_name))
# Uncomment to also pickle the full model (required by the "Restore Full Model" cells):
#torch.save(model, model_name) # save full model

In [None]:
# Report which checkpoint name this run used.
print("Done training: ", model_name, sep=" ")

#### Restore Model from State Dict

In [None]:
import os

# Rebuild the architecture with the same image size it was trained with: (3, 256, 256),
# not (3, 64, 64), whose layer shapes presumably would not match the saved weights.
sd_model = VAE(latent_spec=latent_spec, img_size=(3, 256, 256))
# Load from trained_models/ (where the weights are meant to be saved). map_location="cpu"
# lets a GPU-trained checkpoint load on a machine without CUDA.
state_dict_path = os.path.join("trained_models", "statedict_" + model_name)
sd_model.load_state_dict(torch.load(state_dict_path, map_location="cpu"))

#### Restore Full Model
* Note: in this case the serialized data is bound to the specific classes and the exact directory structure used.

In [None]:
# The file only exists if the full-model torch.save line in the Save Model section was
# uncommented. torch.load unpickles arbitrary objects, so only load files you trust.
full_model_path = model_name
full_model = torch.load(full_model_path)

In [None]:
# Class of the unpickled full model.
full_model.__class__

In [None]:
# Class of the model restored from the state dict.
sd_model.__class__

In [None]:
# Shell: list the current working directory (leftover interactive check).
!ls