# Basic Classifying VAE for MNIST Database

In [1]:
from sklearn.datasets import fetch_mldata
from sklearn.model_selection import train_test_split
from sklearn import preprocessing
import torch
import pprint
import numpy as np
import numpy.random as random
import datetime
from src.pytorch_cl_vae.model import ClVaeModel
import matplotlib.pyplot as plt
from collections import defaultdict

## 1 - Specify parameters for the VAE and training

In [2]:
# Hyper-parameters for the model and the training loop. The whole dict is later
# passed as ClVaeModel(**params), so presumably every key must be accepted by
# its constructor; the data cell adds 'classes_dim' and 'original_dim'.
params = dict(
    batch_size=100,                   # samples per training step
    num_epochs=50,                    # passes over the training subset
    latent_dim=2,                     # size of the latent code z
    encoder_hidden_size=512,
    decoder_hidden_size=512,
    classifier_hidden_size=512,
    vae_learning_rate=0.0001,
    classifier_learning_rate=0.0001,
    log_dir='../data/logs',
    model_dir='../data/models',       # checkpoints are written here
    data_dir='../data',               # dataset download/cache location
)

## 2 - Fetch MNIST

In [3]:
# NOTE: `fetch_mldata` (imported in the first cell) downloaded from mldata.org,
# which is permanently offline, and was removed in scikit-learn 0.22. The same
# 70,000-sample, 784-feature MNIST data is published on OpenML as `mnist_784`.
from sklearn.datasets import fetch_openml

mnist = fetch_openml('mnist_784', version=1, data_home=params['data_dir'])
# Depending on the scikit-learn version, data/target come back as numpy arrays
# or pandas objects, and the targets are strings ('0'..'9'): normalise both so
# the rest of the notebook gets a float ndarray in [0, 1] and integer labels.
mnist.data = np.asarray(mnist.data, dtype=np.float64) / 255
mnist.target = np.asarray(mnist.target).astype(np.int64)

num_samples, input_dim = mnist.data.shape
num_classes = len(np.unique(mnist.target))
# One-hot encoder for the labels; fitted on all labels so every class is known.
lb = preprocessing.LabelBinarizer()
lb.fit(mnist.target)
# Data-derived model parameters consumed by ClVaeModel(**params).
params['classes_dim'] = [num_classes]
params['original_dim'] = input_dim
print('MNIST db has been successfully loaded, cached under: "{}"'.format(params['data_dir']))

# split data to train and test subsets (fixed random_state -> reproducible split)
X_train, X_test, y_train, y_test = train_test_split(
    mnist.data, mnist.target, test_size=0.1, random_state=0)
assert X_train.shape[0] + X_test.shape[0] == num_samples
print("| Train subset shape:{} | Test subset shape:{} |".format(X_train.shape, X_test.shape))

MNIST db has been successfully loaded, stored in the: "../data/mldata"
| Train subset shape:(63000, 784) | Test subset shape:(7000, 784) |


## 3 - Create Model

In [4]:
# Initialize ClVaeModel
model = ClVaeModel(**params)
print("Model successfully initialized with params: ")
pprint.PrettyPrinter(indent=4).pprint(params)

train_losses = []
train_accuracies = []

Model successfully initialized with params: 
{   'batch_size': 100,
    'classes_dim': [10],
    'classifier_hidden_size': 512,
    'classifier_learning_rate': 0.0001,
    'data_dir': '../data',
    'decoder_hidden_size': 512,
    'encoder_hidden_size': 512,
    'latent_dim': 2,
    'log_dir': '../data/logs',
    'model_dir': '../data/models',
    'num_epochs': 50,
    'original_dim': 784,
    'vae_learning_rate': 0.0001}


## 4 - Train

In [None]:
save_each_steps = 500  # write a checkpoint every this many training steps
shuffle_seed = 0       # seeds the per-epoch shuffling -> reproducible batch order

# Train loop. Each epoch walks a fresh random permutation of the training set,
# so every sample is seen exactly once per epoch (drawing indices with
# replacement would repeat some samples, skip others, and could even put the
# same sample twice in one batch). A trailing partial batch, if any, is dropped.
batch_size = params['batch_size']
num_train = X_train.shape[0]
num_batches = num_train // batch_size
shuffler = random.RandomState(shuffle_seed)
step_losses = None  # defined even if no step runs (e.g. num_epochs == 0)
train_step_i = 0
for epoch in range(params['num_epochs']):
    print('\nepoch {} out of {}'.format(epoch + 1, params['num_epochs']))
    epoch_order = shuffler.permutation(num_train)
    for batch_i in range(num_batches):
        idx = epoch_order[batch_i * batch_size:(batch_i + 1) * batch_size]
        x_batch = torch.from_numpy(X_train[idx]).float()
        # y_batch is a list holding one one-hot tensor, matching the single
        # entry of params['classes_dim'] (presumably one per classifier head).
        y_batch = [torch.from_numpy(lb.transform(y_train[idx])).float()]
        step_losses, step_accuracies = model.train_step(x_batch, y_batch)

        train_losses.append(step_losses)
        train_accuracies.append(step_accuracies)
        train_step_i += 1

        # '\r' keeps overwriting one status line; a newline is kept every 100 steps.
        print("\r|train step: {} | rec loss: {:.4f} | z_dkl loss: {:.4f} | class loss: {:.4f}"
              " | w_dkl loss: {:.4f} | class_accuracy: {:.4f} |".format(
                  train_step_i, *step_losses, *step_accuracies), end='')
        if train_step_i % 100 == 0:
            print()
        if train_step_i % save_each_steps == 0:
            dt = datetime.datetime.now().strftime("%m_%d_%Y_%I_%M_%p")
            # The timestamp has only minute resolution; adding the step keeps two
            # checkpoints written within the same minute from overwriting each other.
            fname = params['model_dir'] + '/cl_vae_mnist_{}_step{}.pt'.format(dt, train_step_i)
            model.save_ckpt(fname)
print('*****Finished with the final loss: ', step_losses)


epoch 1 out of 50
|train step: 1 | rec loss: 54457.4727 | z_dkl loss: 0.8549 | class loss: 2304.1892 | w_dkl loss: 6.0656 | class_accuracy: 0.1100 ||train step: 2 | rec loss: 54269.2461 | z_dkl loss: 1.1503 | class loss: 2308.1626 | w_dkl loss: 4.7844 | class_accuracy: 0.0900 ||train step: 3 | rec loss: 54094.4023 | z_dkl loss: 1.6776 | class loss: 2315.0479 | w_dkl loss: 4.4916 | class_accuracy: 0.1000 ||train step: 4 | rec loss: 53887.4922 | z_dkl loss: 2.9424 | class loss: 2312.0674 | w_dkl loss: 4.0443 | class_accuracy: 0.0400 ||train step: 5 | rec loss: 53725.0586 | z_dkl loss: 3.8920 | class loss: 2304.4199 | w_dkl loss: 3.5206 | class_accuracy: 0.1000 ||train step: 6 | rec loss: 53584.9766 | z_dkl loss: 6.5561 | class loss: 2314.7642 | w_dkl loss: 3.9260 | class_accuracy: 0.0600 |



|train step: 100 | rec loss: 23058.5098 | z_dkl loss: 3321.5752 | class loss: 2228.2136 | w_dkl loss: 33.1309 | class_accuracy: 0.2200 |
|train step: 200 | rec loss: 22057.6641 | z_dkl loss: 1979.1127 | class loss: 2201.6150 | w_dkl loss: 64.0202 | class_accuracy: 0.3100 |
|train step: 300 | rec loss: 20551.3203 | z_dkl loss: 1529.8494 | class loss: 2170.4097 | w_dkl loss: 111.6069 | class_accuracy: 0.4300 |
|train step: 400 | rec loss: 19169.2305 | z_dkl loss: 1454.5453 | class loss: 2154.4675 | w_dkl loss: 137.0563 | class_accuracy: 0.4000 |
|train step: 500 | rec loss: 18671.7109 | z_dkl loss: 1189.8763 | class loss: 2072.7827 | w_dkl loss: 185.5578 | class_accuracy: 0.5600 |
|train step: 600 | rec loss: 18743.2188 | z_dkl loss: 1073.1919 | class loss: 2070.5984 | w_dkl loss: 213.8250 | class_accuracy: 0.5400 |
|train step: 630 | rec loss: 17764.0254 | z_dkl loss: 1005.8671 | class loss: 2041.4550 | w_dkl loss: 230.3179 | class_accuracy: 0.6000 |
epoch 2 out of 50
|train step: 700 |

|train step: 5100 | rec loss: 13257.0986 | z_dkl loss: 431.3175 | class loss: 2039.2487 | w_dkl loss: 759.4855 | class_accuracy: 0.6700 |
|train step: 5200 | rec loss: 13080.8379 | z_dkl loss: 444.3137 | class loss: 2002.8621 | w_dkl loss: 779.9443 | class_accuracy: 0.6600 |
|train step: 5300 | rec loss: 13357.3916 | z_dkl loss: 415.3503 | class loss: 1978.1544 | w_dkl loss: 816.4803 | class_accuracy: 0.7200 |
|train step: 5400 | rec loss: 13612.3711 | z_dkl loss: 415.2617 | class loss: 2016.5980 | w_dkl loss: 811.3365 | class_accuracy: 0.6900 |
|train step: 5500 | rec loss: 12916.9844 | z_dkl loss: 436.4155 | class loss: 2029.8784 | w_dkl loss: 772.7612 | class_accuracy: 0.6200 |
|train step: 5600 | rec loss: 12700.3701 | z_dkl loss: 437.4919 | class loss: 1977.6785 | w_dkl loss: 790.9957 | class_accuracy: 0.7600 |
|train step: 5670 | rec loss: 12934.1699 | z_dkl loss: 411.9921 | class loss: 2025.5322 | w_dkl loss: 797.0225 | class_accuracy: 0.6900 |
epoch 10 out of 50
|train step: 57

## 5 - Plot the training losses

In [None]:
%matplotlib inline
losses = defaultdict(list)
losses_names = train_losses[0]._fields
print(losses_names)
step_loss = train_losses[0]
print(*step_loss)
for i, loss_name in enumerate(losses_names):
    losses[loss_name] = [l[i] for l in train_losses]
    plt.figure()
    plt.title(loss_name)
    plt.plot(losses[loss_name])
    plt.legend()
plt.show()

## 6 - Test

In [7]:
# Evaluate on the held-out split. The one-hot targets get their own name so
# `y_test` keeps the raw labels and this cell stays re-runnable: overwriting
# `y_test` made a second run call lb.transform on already-binarized targets,
# which LabelBinarizer rejects.
y_test_onehot = lb.transform(y_test)
test_losses, test_accuracies = model.test(
    torch.from_numpy(X_test).float(), [torch.from_numpy(y_test_onehot).float()])
printer = pprint.PrettyPrinter(indent=4)
printer.pprint(test_losses)
printer.pprint(test_accuracies)

Losses(rec_loss=tensor(0.2667), z_dkl_loss=tensor(0.0654), class_loss_0=tensor(2.3070), w_dkl_loss_0=tensor(0.1143))


## 7 - Examples

In [None]:
# TODO: add examples (not implemented yet).

