# Basic Classifying VAE for MNIST Database

In [1]:
from sklearn.datasets import fetch_mldata
from sklearn.model_selection import train_test_split
from sklearn import preprocessing
import torch
import pprint
import numpy as np
import numpy.random as random
import datetime
from src.pytorch_cl_vae.model import ClVaeModel
import matplotlib.pyplot as plt
from collections import defaultdict

## 1 - Specify parameters for to the VAE and training

In [2]:
params = {
    'batch_size': 2048,
    'num_epochs': 300,
    'latent_dim': 2,
    'encoder_hidden_size': 256,
    'decoder_hidden_size': 256,
    'classifier_hidden_size': 256,
    'vae_learning_rate': 0.001,
    'classifier_learning_rate': 0.001,
    'log_dir': '../data/logs',
    'model_dir': '../data/models',
    'data_dir': '../data'
}

## 2 - Fetch MNIST

In [3]:
mnist = fetch_mldata('MNIST original', data_home=params['data_dir'])
mnist.data = mnist.data / 255
num_samples, input_dim = mnist.data.shape
num_classes = len(np.unique(mnist.target))
lb = preprocessing.LabelBinarizer()
lb.fit(mnist.target)
params['classes_dim'] = [num_classes]
params['original_dim'] = input_dim
print('MNIST db has been successfully loaded, stored in the: "{}"'.format(params['data_dir'] + '/mldata'))
# split data to train and test subsets
X_train, X_test, y_train, y_test = train_test_split(mnist.data, mnist.target, test_size=0.1, random_state=0)
print("| Train subset shape:{} | Test subset shape:{} |".format(X_train.shape, X_test.shape))

MNIST db has been successfully loaded, stored in the: "../data/mldata"
| Train subset shape:(63000, 784) | Test subset shape:(7000, 784) |


## 3 - Create Model

In [4]:
# Initialize ClVaeModel
model = ClVaeModel(**params)
print("Model successfully initialized with params: ")
pprint.PrettyPrinter(indent=4).pprint(params)

train_losses = []
train_accuracies = []

Model successfully initialized with params: 
{   'batch_size': 2048,
    'classes_dim': [10],
    'classifier_hidden_size': 512,
    'classifier_learning_rate': 0.0001,
    'data_dir': '../data',
    'decoder_hidden_size': 512,
    'encoder_hidden_size': 512,
    'latent_dim': 2,
    'log_dir': '../data/logs',
    'model_dir': '../data/models',
    'num_epochs': 100,
    'original_dim': 784,
    'vae_learning_rate': 0.0001}


## 4 - Train

In [5]:
save_each_steps = 500

# Train loop
train_step_i = 0
for epoch in range(params['num_epochs']):
    print('\nepoch {} out of {}'.format(epoch + 1, params['num_epochs']))
    for i in range(X_train.shape[0] // params['batch_size']):
        # Sample batch
        idx = random.choice(np.arange(0, X_train.shape[0]), params['batch_size'])
        x_batch = torch.from_numpy(X_train[idx]).float()
        y_batch = lb.transform(y_train[idx])
        y_batch = [torch.from_numpy(y_batch).float()]
        step_losses, step_accuracies = model.train_step(x_batch, y_batch)

#         step_losses = [loss.sum().detach().numpy() for loss in step_losses]
        # step_losses = Losses(*step_losses)
        # step_accuracies = Accuracies(*step_accuracies)

        train_losses.append(step_losses)
        train_accuracies.append(step_accuracies)

        train_step_i += 1

        print("\r|train step: {} | rec loss: {:.4f} | z_dkl loss: {:.4f} | class loss: {:.4f}"
              " | w_dkl loss: {:.4f} | class_accuracy: {:.4f} |".format(
            train_step_i, *step_losses, *step_accuracies
            ), end='')
        if train_step_i % 100 == 0:
            print()
        if train_step_i % save_each_steps == 0:
            dt = str(datetime.datetime.now().strftime("%m_%d_%Y_%I_%M_%p"))
            fname = params['model_dir'] + '/cl_vae_mnist_{}.pt'.format(dt)
            model.save_ckpt(fname)
print('*****Finished with the final loss: ', step_losses)


epoch 1 out of 100
|train step: 30 | rec loss: 0.6270 | z_dkl loss: 0.0001 | class loss: 2.2528 | w_dkl loss: 0.0000 | class_accuracy: 0.2153 |
epoch 2 out of 100
|train step: 60 | rec loss: 0.5634 | z_dkl loss: 0.0002 | class loss: 2.1307 | w_dkl loss: 0.0000 | class_accuracy: 0.4287 |
epoch 3 out of 100
|train step: 90 | rec loss: 0.5076 | z_dkl loss: 0.0006 | class loss: 1.9839 | w_dkl loss: 0.0000 | class_accuracy: 0.6016 |
epoch 4 out of 100
|train step: 100 | rec loss: 0.4911 | z_dkl loss: 0.0006 | class loss: 1.9392 | w_dkl loss: 0.0000 | class_accuracy: 0.6558 |
|train step: 120 | rec loss: 0.4566 | z_dkl loss: 0.0008 | class loss: 1.8722 | w_dkl loss: 0.0000 | class_accuracy: 0.6973 |
epoch 5 out of 100
|train step: 150 | rec loss: 0.4141 | z_dkl loss: 0.0001 | class loss: 1.7944 | w_dkl loss: 0.0000 | class_accuracy: 0.7690 |
epoch 6 out of 100
|train step: 180 | rec loss: 0.3733 | z_dkl loss: 0.0040 | class loss: 1.7489 | w_dkl loss: 0.0000 | class_accuracy: 0.7881 |
epoch 

|train step: 1440 | rec loss: 0.2082 | z_dkl loss: 0.0068 | class loss: 1.5347 | w_dkl loss: 0.0000 | class_accuracy: 0.9429 |
epoch 49 out of 100
|train step: 1470 | rec loss: 0.2100 | z_dkl loss: 0.0032 | class loss: 1.5365 | w_dkl loss: 0.0000 | class_accuracy: 0.9375 |
epoch 50 out of 100
|train step: 1500 | rec loss: 0.2117 | z_dkl loss: 0.0015 | class loss: 1.5361 | w_dkl loss: 0.0000 | class_accuracy: 0.9355 |

epoch 51 out of 100
|train step: 1530 | rec loss: 0.2073 | z_dkl loss: 0.0016 | class loss: 1.5304 | w_dkl loss: nan | class_accuracy: 0.9419 |5 |
epoch 52 out of 100
|train step: 1560 | rec loss: 0.2068 | z_dkl loss: 0.0097 | class loss: 1.5348 | w_dkl loss: nan | class_accuracy: 0.9370 |8 |
epoch 53 out of 100
|train step: 1590 | rec loss: 0.2068 | z_dkl loss: 0.0041 | class loss: nan | w_dkl loss: nan | class_accuracy: 0.0996 |4 |5 |
epoch 54 out of 100
|train step: 1600 | rec loss: 0.2074 | z_dkl loss: 0.0010 | class loss: nan | w_dkl loss: nan | class_accuracy: 0.095

|train step: 2910 | rec loss: 0.1941 | z_dkl loss: 0.0043 | class loss: nan | w_dkl loss: nan | class_accuracy: 0.0972 |
epoch 98 out of 100
|train step: 2940 | rec loss: 0.1912 | z_dkl loss: 0.0038 | class loss: nan | w_dkl loss: nan | class_accuracy: 0.0947 |
epoch 99 out of 100
|train step: 2970 | rec loss: 0.1907 | z_dkl loss: 0.0032 | class loss: nan | w_dkl loss: nan | class_accuracy: 0.0996 |
epoch 100 out of 100
|train step: 3000 | rec loss: 0.1913 | z_dkl loss: 0.0102 | class loss: nan | w_dkl loss: nan | class_accuracy: 0.1123 |
*****Finished with the final loss:  Losses(rec_loss=tensor(0.1913, grad_fn=<BinaryCrossEntropyBackward>), z_dkl_loss=tensor(0.0102, grad_fn=<MulBackward>), class_loss_0=tensor(nan, grad_fn=<NllLossBackward>), w_dkl_loss_0=tensor(nan, grad_fn=<MulBackward>))


## 5 - Show losses graph

In [None]:
%matplotlib inline
losses = defaultdict(list)
losses_names = train_losses[0]._fields
print(losses_names)
step_loss = train_losses[0]
print(*step_loss)
for i, loss_name in enumerate(losses_names):
    losses[loss_name] = [l[i] for l in train_losses]
    plt.figure()
    plt.title(loss_name)
    plt.plot(losses[loss_name])
    plt.legend()
plt.show()

## 6 - Test

In [None]:
y_test = lb.transform(y_test)
losses, acc = model.test(torch.from_numpy(X_test).float(), [torch.from_numpy(y_test).float()])
pprint.PrettyPrinter(indent=4).pprint(losses)

## 7 - Examples

In [None]:
#TODO

