# Train a JointVAE model

In [13]:
# Notebook setup: reload edited modules automatically (autoreload mode 2)
# and render matplotlib figures inline.
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [14]:
import os

# Select the GPU *before* anything queries CUDA. CUDA reads
# CUDA_DEVICE_ORDER / CUDA_VISIBLE_DEVICES when it is first initialised in the
# process, and torch.cuda.is_available() below can trigger that initialisation.
# Setting the variables only in the later device cell would therefore be too
# late to influence which GPU is used. Keep these values in sync with that cell.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # change to your device

import torch
from viz.visualize import Visualizer

# Flag reused below when building the model and the trainer.
use_cuda = torch.cuda.is_available()
use_cuda

True

In [15]:
# Re-display the CUDA availability flag computed in the previous cell.
use_cuda

True

In [16]:
import os
# Pin this process to a single GPU, with devices enumerated in PCI bus order.
# NOTE(review): CUDA reads these variables when it is first initialised in the
# process. Make sure they already hold the intended values before the first
# CUDA query (e.g. torch.cuda.is_available()); otherwise the assignment here
# has no effect on device selection.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # change to your device

#### Prepare data list

In [17]:
# List the split-definition files present in ./data (IPython shell escape).
!ls data

dress_dresslen_train_test_splits.json	dress_sleeve_train_test_splits.json
dress_sleevelen_train_test_splits.json


In [18]:
import json

# Parse the dress-length train/test split definitions into a plain dict.
with open("./data/dress_dresslen_train_test_splits.json", "r") as infile:
    data_dict = json.loads(infile.read())

In [19]:
# Show the available entries: X/y, train/test, numbered 1 through 10.
data_dict.keys()

dict_keys(['X_train_1', 'y_train_1', 'X_test_1', 'y_test_1', 'X_train_2', 'y_train_2', 'X_test_2', 'y_test_2', 'X_train_3', 'y_train_3', 'X_test_3', 'y_test_3', 'X_train_4', 'y_train_4', 'X_test_4', 'y_test_4', 'X_train_5', 'y_train_5', 'X_test_5', 'y_test_5', 'X_train_6', 'y_train_6', 'X_test_6', 'y_test_6', 'X_train_7', 'y_train_7', 'X_test_7', 'y_test_7', 'X_train_8', 'y_train_8', 'X_test_8', 'y_test_8', 'X_train_9', 'y_train_9', 'X_test_9', 'y_test_9', 'X_train_10', 'y_train_10', 'X_test_10', 'y_test_10'])

In [20]:
# Peek at the first five entries of 'X_train_1': image paths relative to the data root.
data_dict['X_train_1'][:5]

['/2/8/2893552_3773662.jpg',
 '/2/9/2982376_3889235.jpg',
 '/2/7/2783355_3578973.jpg',
 '/2/9/2974380_3918638.jpg',
 '/2/8/2886740_3675612.jpg']

#### Create list of image paths

In [21]:
# Record the Python version for this run.
# NOTE: this reports the `python` found on PATH, which is not necessarily the kernel's interpreter.
!python -V

Python 3.6.3 :: Anaconda, Inc.


In [22]:
root_data_dir = "/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables"

# Absolute image paths, pooled over every numbered split found in data_dict.
image_paths_train = []
image_paths_test = []

for split_name, rel_paths in data_dict.items():
    # Only X_train* / X_test* entries are collected; every other key is skipped.
    if 'X_train' in split_name:
        destination = image_paths_train
    elif 'X_test' in split_name:
        destination = image_paths_test
    else:
        continue
    # Plain concatenation: the root directory is prepended to each entry as-is.
    destination += [root_data_dir + rel_path for rel_path in rel_paths]

# Summary counts, then the first and last path of each list as a sanity check.
print(f"Number of train image paths: {len(image_paths_train):,d}")
print(f"Number of test image paths: {len(image_paths_test):,d}")
print()
print("Sample paths:")
print(image_paths_train[0], image_paths_train[-1], sep="\n")
print(image_paths_test[0], image_paths_test[-1], sep="\n")

Number of train image paths: 167,742
Number of test image paths: 18,638

Sample paths:
/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/2/8/2893552_3773662.jpg
/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/2/6/2683298_3676896.jpg
/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/2/4/2431229_3158108.jpg
/home/jovyan/vmldata/raw_source_data/v20180810_all_wearables/2/5/2569101_3223742.jpg


In [23]:
#from utils.dataloaders import get_mnist_dataloaders, get_fashion_mnist_dataloaders
#train_loader, test_loader = get_mnist_dataloaders(batch_size=64)
#train_loader, test_loader = get_fashion_mnist_dataloaders(batch_size=64)

In [24]:
from torchvision import transforms
from utils.dataloaders_custom import get_imagelist_dataloader, ImageListDataset

# Preprocessing applied to every image: resize to 64x64, then convert to a tensor.
# The 64x64 target has to agree with the img_size passed to the VAE further down
# (the commented-out line is the 260x260 variant).
#composed = transforms.Compose([transforms.Resize((260,260)), transforms.ToTensor()])
composed = transforms.Compose([transforms.Resize((64,64)), transforms.ToTensor()])

# Wrap the train/test path lists in datasets that share the same transform.
train_dataset = ImageListDataset(image_paths_train, transform=composed)
test_dataset = ImageListDataset(image_paths_test, transform=composed)

# Both loaders use batches of 20 images.
# NOTE(review): whether get_imagelist_dataloader shuffles is not visible here — confirm.
train_loader = get_imagelist_dataloader(batch_size=20, dataset_object=train_dataset)
test_loader = get_imagelist_dataloader(batch_size=20, dataset_object=test_dataset)

### Define latent distribution of the model

In [25]:
# The latent code is the joint of 10 continuous (Gaussian) dimensions
# and a single 10-dimensional discrete (Gumbel-Softmax) variable.
latent_spec = dict(cont=10, disc=[10])

### Build a model

In [26]:
#from jointvae.models_v1 import VAE
from jointvae.models import VAE

# Build the VAE for 3-channel 64x64 inputs, matching the Resize((64,64)) transform
# used for the data loaders; the commented-out line is the 260x260 variant.
#model = VAE(latent_spec=latent_spec, img_size=(3, 260, 260), use_cuda=use_cuda)
model = VAE(latent_spec=latent_spec, img_size=(3, 64, 64), use_cuda=use_cuda)

In [27]:
# Print the layer structure to sanity-check the architecture.
print(model)

VAE(
  (img_to_features): Sequential(
    (0): Conv2d(3, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (5): ReLU()
    (6): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (7): ReLU()
  )
  (features_to_hidden): Sequential(
    (0): Linear(in_features=1024, out_features=256, bias=True)
    (1): ReLU()
  )
  (fc_mean): Linear(in_features=256, out_features=10, bias=True)
  (fc_log_var): Linear(in_features=256, out_features=10, bias=True)
  (fc_alphas): ModuleList(
    (0): Linear(in_features=256, out_features=10, bias=True)
  )
  (latent_to_features): Sequential(
    (0): Linear(in_features=20, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=1024, bias=True)
    (3): ReLU()
  )
  (features_to_img): Sequential(
    (0): ConvTranspose

### Train the model

In [28]:
from torch import optim

# Build optimizer: Adam over all model parameters, with AMSGrad enabled and
# the learning rate lowered to 3e-4 (it was 5e-4 originally).
optimizer = optim.Adam(model.parameters(), lr=3e-4, amsgrad=True) # added amsgrad # orig lr 5e-4

In [29]:
from jointvae.training import Trainer

# Capacity schedules for the two kinds of latent channels. With the values
# below, each one starts at a capacity of 0.0 and is increased to 5.0 over
# 25000 iterations, with a gamma of 30.0.
cont_capacity = [0.0, 5.0, 25000, 30.0]  # continuous channels
disc_capacity = [0.0, 5.0, 25000, 30.0]  # discrete channels

# Build a trainer from the model, the optimizer and both schedules.
trainer = Trainer(
    model,
    optimizer,
    cont_capacity=cont_capacity,
    disc_capacity=disc_capacity,
    use_cuda=use_cuda,
)

#### Initialize visualizer

In [30]:
# Build a visualizer bound to the model. It is handed to trainer.train() below
# (through save_training_gif) to visualize progress during training, and is
# reused afterwards by the plotting cells.
viz = Visualizer(model)

In [None]:
# Train the model for 200 epochs, passing the visualizer so a training GIF is saved.
# This is a long run and should be executed on a GPU.
# NOTE(review): the progress lines printed by this call exceeded the notebook's
# IOPub message rate limit (see the captured output below).

trainer.train(train_loader, epochs=200, save_training_gif=('./training_rd_308_64_200e_v1.gif', viz))

0/167742	Loss: 2776.650
1000/167742	Loss: 2394.200
2000/167742	Loss: 1490.457
3000/167742	Loss: 1388.712
4000/167742	Loss: 1377.599
5000/167742	Loss: 1347.805
6000/167742	Loss: 1359.986
7000/167742	Loss: 1322.137
8000/167742	Loss: 1328.284
9000/167742	Loss: 1309.065
10000/167742	Loss: 1338.004
11000/167742	Loss: 1298.892
12000/167742	Loss: 1302.248
13000/167742	Loss: 1292.436
14000/167742	Loss: 1288.961
15000/167742	Loss: 1279.462
16000/167742	Loss: 1264.345
17000/167742	Loss: 1272.750
18000/167742	Loss: 1268.310
19000/167742	Loss: 1248.455
20000/167742	Loss: 1261.820
21000/167742	Loss: 1268.755
22000/167742	Loss: 1255.233
23000/167742	Loss: 1243.324
24000/167742	Loss: 1244.678
25000/167742	Loss: 1249.922
26000/167742	Loss: 1242.560
27000/167742	Loss: 1236.872
28000/167742	Loss: 1252.023
29000/167742	Loss: 1244.712
30000/167742	Loss: 1212.802
31000/167742	Loss: 1219.299
32000/167742	Loss: 1233.910
33000/167742	Loss: 1215.865
34000/167742	Loss: 1229.535
35000/167742	Loss: 1236.695
36000

122000/167742	Loss: 1141.774
123000/167742	Loss: 1140.258
124000/167742	Loss: 1138.254
125000/167742	Loss: 1138.963
126000/167742	Loss: 1141.148
127000/167742	Loss: 1125.395
128000/167742	Loss: 1141.781
129000/167742	Loss: 1138.013
130000/167742	Loss: 1144.048
131000/167742	Loss: 1143.607
132000/167742	Loss: 1153.787
133000/167742	Loss: 1146.007
134000/167742	Loss: 1133.059
135000/167742	Loss: 1147.595
136000/167742	Loss: 1149.698
137000/167742	Loss: 1150.166
138000/167742	Loss: 1134.337
139000/167742	Loss: 1137.348
140000/167742	Loss: 1142.850
141000/167742	Loss: 1141.274
142000/167742	Loss: 1138.851
143000/167742	Loss: 1147.465
144000/167742	Loss: 1120.345
145000/167742	Loss: 1127.850
146000/167742	Loss: 1149.447
147000/167742	Loss: 1152.066
148000/167742	Loss: 1148.424
149000/167742	Loss: 1149.769
150000/167742	Loss: 1144.669
151000/167742	Loss: 1138.698
152000/167742	Loss: 1135.808
153000/167742	Loss: 1145.737
154000/167742	Loss: 1127.891
155000/167742	Loss: 1147.361
156000/167742	

74000/167742	Loss: 1127.604
75000/167742	Loss: 1102.840
76000/167742	Loss: 1110.059
77000/167742	Loss: 1125.361
78000/167742	Loss: 1120.768
79000/167742	Loss: 1114.697
80000/167742	Loss: 1118.496
81000/167742	Loss: 1106.070
82000/167742	Loss: 1133.506
83000/167742	Loss: 1118.939
84000/167742	Loss: 1123.652
85000/167742	Loss: 1129.381
86000/167742	Loss: 1126.298
87000/167742	Loss: 1114.553
88000/167742	Loss: 1101.963
89000/167742	Loss: 1111.856
90000/167742	Loss: 1107.500
91000/167742	Loss: 1136.283
92000/167742	Loss: 1117.913
93000/167742	Loss: 1118.486
94000/167742	Loss: 1115.669
95000/167742	Loss: 1110.490
96000/167742	Loss: 1122.332
97000/167742	Loss: 1111.769
98000/167742	Loss: 1119.796
99000/167742	Loss: 1114.584
100000/167742	Loss: 1135.358
101000/167742	Loss: 1117.688
102000/167742	Loss: 1122.146
103000/167742	Loss: 1111.575
104000/167742	Loss: 1129.973
105000/167742	Loss: 1129.955
106000/167742	Loss: 1126.129
107000/167742	Loss: 1115.044
108000/167742	Loss: 1119.180
109000/1677

25000/167742	Loss: 1105.798
26000/167742	Loss: 1116.422
27000/167742	Loss: 1112.671
28000/167742	Loss: 1130.253
29000/167742	Loss: 1109.935
30000/167742	Loss: 1106.061
31000/167742	Loss: 1123.800
32000/167742	Loss: 1114.503
33000/167742	Loss: 1108.395
34000/167742	Loss: 1131.950
35000/167742	Loss: 1117.048
36000/167742	Loss: 1117.967
37000/167742	Loss: 1094.239
38000/167742	Loss: 1108.166
39000/167742	Loss: 1105.263
40000/167742	Loss: 1114.625
41000/167742	Loss: 1119.204
42000/167742	Loss: 1132.214
43000/167742	Loss: 1101.805
44000/167742	Loss: 1107.661
45000/167742	Loss: 1118.891
46000/167742	Loss: 1128.171
47000/167742	Loss: 1097.130
48000/167742	Loss: 1105.160
49000/167742	Loss: 1097.473
50000/167742	Loss: 1108.472
51000/167742	Loss: 1115.602
52000/167742	Loss: 1103.519
53000/167742	Loss: 1130.415
54000/167742	Loss: 1106.352
55000/167742	Loss: 1121.881
56000/167742	Loss: 1114.410
57000/167742	Loss: 1116.020
58000/167742	Loss: 1101.765
59000/167742	Loss: 1117.470
60000/167742	Loss: 1

145000/167742	Loss: 1111.655
146000/167742	Loss: 1105.109
147000/167742	Loss: 1105.817
148000/167742	Loss: 1103.402
149000/167742	Loss: 1117.752
150000/167742	Loss: 1102.602
151000/167742	Loss: 1107.561
152000/167742	Loss: 1107.306
153000/167742	Loss: 1112.975
154000/167742	Loss: 1097.184
155000/167742	Loss: 1107.800
156000/167742	Loss: 1116.139
157000/167742	Loss: 1121.572
158000/167742	Loss: 1099.544
159000/167742	Loss: 1103.702
160000/167742	Loss: 1111.635
161000/167742	Loss: 1096.803
162000/167742	Loss: 1105.956
163000/167742	Loss: 1122.083
164000/167742	Loss: 1120.696
165000/167742	Loss: 1095.268
166000/167742	Loss: 1111.881
167000/167742	Loss: 1095.050
Epoch: 7 Average loss: 1110.29
0/167742	Loss: 1110.327
1000/167742	Loss: 1104.304
2000/167742	Loss: 1115.880
3000/167742	Loss: 1106.099
4000/167742	Loss: 1105.108
5000/167742	Loss: 1107.442
6000/167742	Loss: 1109.707
7000/167742	Loss: 1094.505
8000/167742	Loss: 1101.306
9000/167742	Loss: 1102.716
10000/167742	Loss: 1122.575
11000/1

98000/167742	Loss: 1098.519
99000/167742	Loss: 1098.455
100000/167742	Loss: 1115.109
101000/167742	Loss: 1107.962
102000/167742	Loss: 1114.675
103000/167742	Loss: 1107.044
104000/167742	Loss: 1102.678
105000/167742	Loss: 1100.052
106000/167742	Loss: 1103.616
107000/167742	Loss: 1096.937
108000/167742	Loss: 1095.577
109000/167742	Loss: 1106.442
110000/167742	Loss: 1115.480
111000/167742	Loss: 1094.946
112000/167742	Loss: 1097.791
113000/167742	Loss: 1102.268
114000/167742	Loss: 1102.790
115000/167742	Loss: 1099.757
116000/167742	Loss: 1103.093
117000/167742	Loss: 1086.206
118000/167742	Loss: 1116.478
119000/167742	Loss: 1114.836
120000/167742	Loss: 1093.280
121000/167742	Loss: 1090.854
122000/167742	Loss: 1082.849
123000/167742	Loss: 1099.129
124000/167742	Loss: 1100.675
125000/167742	Loss: 1101.228
126000/167742	Loss: 1109.369
127000/167742	Loss: 1105.126
128000/167742	Loss: 1103.702
129000/167742	Loss: 1112.174
130000/167742	Loss: 1104.817
131000/167742	Loss: 1092.111
132000/167742	Lo

49000/167742	Loss: 1107.097
50000/167742	Loss: 1115.129
51000/167742	Loss: 1084.913
52000/167742	Loss: 1120.376
53000/167742	Loss: 1099.549
54000/167742	Loss: 1095.549
55000/167742	Loss: 1113.982
56000/167742	Loss: 1097.934
57000/167742	Loss: 1099.048
58000/167742	Loss: 1118.971
59000/167742	Loss: 1089.446
60000/167742	Loss: 1108.064
61000/167742	Loss: 1089.272
62000/167742	Loss: 1095.721
63000/167742	Loss: 1092.871
64000/167742	Loss: 1092.123
65000/167742	Loss: 1100.823
66000/167742	Loss: 1097.186
67000/167742	Loss: 1099.917
68000/167742	Loss: 1103.493
69000/167742	Loss: 1093.612
70000/167742	Loss: 1109.070
71000/167742	Loss: 1102.289
72000/167742	Loss: 1089.283
73000/167742	Loss: 1097.527
74000/167742	Loss: 1106.840
75000/167742	Loss: 1109.960
76000/167742	Loss: 1097.458
77000/167742	Loss: 1091.871
78000/167742	Loss: 1099.904
79000/167742	Loss: 1098.180
80000/167742	Loss: 1103.083
81000/167742	Loss: 1083.046
82000/167742	Loss: 1098.317
83000/167742	Loss: 1095.085
84000/167742	Loss: 1

1000/167742	Loss: 1096.720
2000/167742	Loss: 1096.839
3000/167742	Loss: 1094.025
4000/167742	Loss: 1091.384
5000/167742	Loss: 1101.878
6000/167742	Loss: 1089.034
7000/167742	Loss: 1107.386
8000/167742	Loss: 1089.084
9000/167742	Loss: 1110.923
10000/167742	Loss: 1096.184
11000/167742	Loss: 1103.053
12000/167742	Loss: 1085.305
13000/167742	Loss: 1094.085
14000/167742	Loss: 1096.766
15000/167742	Loss: 1100.333
16000/167742	Loss: 1087.867
17000/167742	Loss: 1086.594
18000/167742	Loss: 1102.433
19000/167742	Loss: 1103.796
20000/167742	Loss: 1094.831
21000/167742	Loss: 1098.516
22000/167742	Loss: 1095.794
23000/167742	Loss: 1100.225
24000/167742	Loss: 1079.781
25000/167742	Loss: 1095.494
26000/167742	Loss: 1101.523
27000/167742	Loss: 1093.209
28000/167742	Loss: 1085.938
29000/167742	Loss: 1106.989
30000/167742	Loss: 1096.040
31000/167742	Loss: 1093.166
32000/167742	Loss: 1079.943
33000/167742	Loss: 1097.139
34000/167742	Loss: 1111.413
35000/167742	Loss: 1112.370
36000/167742	Loss: 1111.002
3

123000/167742	Loss: 1091.680
124000/167742	Loss: 1088.631
125000/167742	Loss: 1106.971
126000/167742	Loss: 1102.425
127000/167742	Loss: 1103.474
128000/167742	Loss: 1097.000
129000/167742	Loss: 1091.150
130000/167742	Loss: 1093.323
131000/167742	Loss: 1099.504
132000/167742	Loss: 1091.580
133000/167742	Loss: 1102.167
134000/167742	Loss: 1092.989
135000/167742	Loss: 1109.572
136000/167742	Loss: 1089.730
137000/167742	Loss: 1096.125
138000/167742	Loss: 1103.345
139000/167742	Loss: 1094.840
140000/167742	Loss: 1085.076
141000/167742	Loss: 1100.857
142000/167742	Loss: 1089.209
143000/167742	Loss: 1109.164
144000/167742	Loss: 1095.887
145000/167742	Loss: 1105.013
146000/167742	Loss: 1104.031
147000/167742	Loss: 1090.349
148000/167742	Loss: 1095.186
149000/167742	Loss: 1103.025
150000/167742	Loss: 1096.677
151000/167742	Loss: 1098.308
152000/167742	Loss: 1092.172
153000/167742	Loss: 1103.953
154000/167742	Loss: 1103.441
155000/167742	Loss: 1096.253
156000/167742	Loss: 1099.465
157000/167742	

75000/167742	Loss: 1088.801
76000/167742	Loss: 1104.360
77000/167742	Loss: 1091.656
78000/167742	Loss: 1110.553
79000/167742	Loss: 1085.623
80000/167742	Loss: 1103.378
81000/167742	Loss: 1104.978
82000/167742	Loss: 1098.791
83000/167742	Loss: 1096.561
84000/167742	Loss: 1093.522
85000/167742	Loss: 1096.628
86000/167742	Loss: 1091.045
87000/167742	Loss: 1092.312
88000/167742	Loss: 1089.281
89000/167742	Loss: 1119.655
90000/167742	Loss: 1097.396
91000/167742	Loss: 1099.984
92000/167742	Loss: 1102.687
93000/167742	Loss: 1098.679
94000/167742	Loss: 1095.215
95000/167742	Loss: 1095.159
96000/167742	Loss: 1082.935
97000/167742	Loss: 1088.247
98000/167742	Loss: 1087.646
99000/167742	Loss: 1101.009
100000/167742	Loss: 1082.443
101000/167742	Loss: 1100.962
102000/167742	Loss: 1087.825
103000/167742	Loss: 1091.426
104000/167742	Loss: 1084.190
105000/167742	Loss: 1096.399
106000/167742	Loss: 1111.059
107000/167742	Loss: 1084.089
108000/167742	Loss: 1094.230
109000/167742	Loss: 1089.451
110000/167

26000/167742	Loss: 1097.344
27000/167742	Loss: 1084.010
28000/167742	Loss: 1104.090
29000/167742	Loss: 1090.429
30000/167742	Loss: 1093.561
31000/167742	Loss: 1102.346
32000/167742	Loss: 1097.768
33000/167742	Loss: 1095.365
34000/167742	Loss: 1096.838
35000/167742	Loss: 1098.230
36000/167742	Loss: 1093.344
37000/167742	Loss: 1108.634
38000/167742	Loss: 1096.965
39000/167742	Loss: 1093.586
40000/167742	Loss: 1096.268
41000/167742	Loss: 1083.576
42000/167742	Loss: 1089.848
43000/167742	Loss: 1094.100
44000/167742	Loss: 1090.226
45000/167742	Loss: 1092.986
46000/167742	Loss: 1084.696
47000/167742	Loss: 1088.351
48000/167742	Loss: 1097.087
49000/167742	Loss: 1097.263
50000/167742	Loss: 1106.987
51000/167742	Loss: 1094.706
52000/167742	Loss: 1096.302
53000/167742	Loss: 1102.962
54000/167742	Loss: 1093.085
55000/167742	Loss: 1096.966
56000/167742	Loss: 1088.057
57000/167742	Loss: 1096.026
58000/167742	Loss: 1101.437
59000/167742	Loss: 1103.599
60000/167742	Loss: 1094.611
61000/167742	Loss: 1

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



27000/167742	Loss: 1087.157
28000/167742	Loss: 1092.367
29000/167742	Loss: 1074.937
30000/167742	Loss: 1087.013
31000/167742	Loss: 1096.829
32000/167742	Loss: 1086.423
33000/167742	Loss: 1093.094
34000/167742	Loss: 1096.672
35000/167742	Loss: 1089.502
36000/167742	Loss: 1098.509
37000/167742	Loss: 1088.533
38000/167742	Loss: 1076.668
39000/167742	Loss: 1109.909
40000/167742	Loss: 1090.229
41000/167742	Loss: 1087.649
42000/167742	Loss: 1091.458
43000/167742	Loss: 1077.457
44000/167742	Loss: 1088.997
45000/167742	Loss: 1097.685
46000/167742	Loss: 1095.167
47000/167742	Loss: 1091.063
48000/167742	Loss: 1086.609
49000/167742	Loss: 1095.066
50000/167742	Loss: 1087.203
51000/167742	Loss: 1082.062
52000/167742	Loss: 1097.592
53000/167742	Loss: 1105.743
54000/167742	Loss: 1090.510
55000/167742	Loss: 1096.406
56000/167742	Loss: 1086.790
57000/167742	Loss: 1084.927
58000/167742	Loss: 1094.324
59000/167742	Loss: 1081.432
60000/167742	Loss: 1092.269
61000/167742	Loss: 1086.557
62000/167742	Loss: 1

### Visualize

In [None]:
# Plot reconstructions
%matplotlib inline
import matplotlib.pyplot as plt

# Get a batch of data
for batch, labels in test_loader:
    break

# Reconstruct data using Joint-VAE model
recon = viz.reconstructions(batch)

plt.figure(figsize=(8,8))
plt.imshow(recon.numpy()[0, :, :], cmap='gray');

In [None]:
# Get samples from the visualizer and show them in an 8x8-inch figure.
samples = viz.samples()
sample_grid = samples.numpy()[0, :, :]  # index 0 along the first axis only

# NOTE(review): a single slice rendered in grayscale; for a 3-channel model this
# is presumably one colour channel only — confirm against the Visualizer.
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(sample_grid, cmap='gray');

In [None]:
# Get the traversals of all latent dimensions (size=10) and show them.
traversals = viz.all_latent_traversals(size=10)
traversal_grid = traversals.numpy()[0, :, :]  # index 0 along the first axis only

# NOTE(review): a single slice rendered in grayscale; for a 3-channel model this
# is presumably one colour channel only — confirm against the Visualizer.
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(traversal_grid, cmap='gray');

In [None]:
# 10x10 traversal grid pairing continuous latent index 2 with discrete latent index 0.
traversals = viz.latent_traversal_grid(
    cont_idx=2,
    cont_axis=1,
    disc_idx=0,
    disc_axis=0,
    size=(10, 10),
)
traversal_grid = traversals.numpy()[0, :, :]  # index 0 along the first axis only

# NOTE(review): a single slice rendered in grayscale; for a 3-channel model this
# is presumably one colour channel only — confirm against the Visualizer.
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(traversal_grid, cmap='gray');

In [None]:
# 10x10 traversal grid pairing continuous latent index 1 with discrete latent index 0.
traversals = viz.latent_traversal_grid(
    cont_idx=1,
    cont_axis=1,
    disc_idx=0,
    disc_axis=0,
    size=(10, 10),
)
traversal_grid = traversals.numpy()[0, :, :]  # index 0 along the first axis only

# NOTE(review): a single slice rendered in grayscale; for a 3-channel model this
# is presumably one colour channel only — confirm against the Visualizer.
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(traversal_grid, cmap='gray');

In [None]:
# 10x10 traversal grid pairing continuous latent index 9 with discrete latent index 0.
traversals = viz.latent_traversal_grid(
    cont_idx=9,
    cont_axis=1,
    disc_idx=0,
    disc_axis=0,
    size=(10, 10),
)
traversal_grid = traversals.numpy()[0, :, :]  # index 0 along the first axis only

# NOTE(review): a single slice rendered in grayscale; for a 3-channel model this
# is presumably one colour channel only — confirm against the Visualizer.
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(traversal_grid, cmap='gray');

### Save Model

In [None]:
# Base file name for the saved model; the state-dict file reuses it with a "statedict_" prefix.
# NOTE(review): the name says "fmnist", but this notebook trains on the dress image
# lists at 64x64 — consider a name that reflects the actual data and run.
model_name = "jvae_fmnist_oct292018.pth"

In [None]:
# Persist the trained model in two forms: the parameters only (state dict),
# and the whole model object.
torch.save(model.state_dict(),"statedict_" + model_name) # save state dict
torch.save(model, model_name) # save full model

#### Restore Model from State Dict

In [None]:
# Rebuild the architecture with the same configuration it was trained with
# (same latent_spec, 3-channel 64x64 input) so that the saved parameter shapes
# line up, then load the weights. img_size must match the training-time value:
# a network built for (1, 32, 32) inputs has differently shaped layers and
# cannot accept this state dict.
sd_model = VAE(latent_spec=latent_spec, img_size=(3, 64, 64))
sd_model.load_state_dict(torch.load("statedict_" + model_name))

#### Restore Full Model
* Note that in this case the serialized data is bound to the specific classes and the exact directory structure used.

In [None]:
# Load the full model object written by torch.save(model, model_name).
# NOTE: this unpickles the file, so only load files you trust.
full_model = torch.load(model_name)

In [None]:
# Check the class of the object restored from the full-model file.
type(full_model)

In [None]:
# Check the class of the model rebuilt from the state dict (constructed as a VAE above).
type(sd_model)