# Toy Data: 8-Modes Mixture Model

## Introduction

This notebook demonstrates how to use FrEIA to build reversible (invertible) architectures.

The example builds on the toy experiment from the paper, but the code below trains on image sequences generated by the local `moving_fashion_mnist` module.

In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [2]:
%matplotlib notebook

import torch
import torch.optim
import moving_fashion_mnist

import numpy as np, os
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
from tqdm import tqdm
from time import time as t_

import FrEIA

from FrEIA.framework import InputNode, OutputNode, Node, ReversibleGraphNet
from FrEIA.modules import (rev_gru_layer, F_GRU, rev_multiplicative_layer,
                           F_fully_connected, F_GRU_BN, F_cgate,
                           glow_gru_coupling_layer, glow_gru_cgate_coupling_layer,
                           glow_gru_residual_coupling_layer)


import data

device = 'cuda' if torch.cuda.is_available() else 'cpu'

print(device)

cuda


## Setting up the data

We generate the data with the local `moving_fashion_mnist` module, restricted to the two classes listed in `input_labels` (8 and 9). As the output below shows, the dataset contains 12000 sequences of `T` frames; each frame is a flattened image with 400 values ($20 \times 20$ pixels) and carries a 2-dimensional one-hot label. The first 100 sequences are later held out as a test split.

In the forward process our model is supposed to predict the label of each frame from its pixel values. In the reverse direction the model should allow us to sample frames conditioned on a label.
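
The shapes printed by the next cell show one one-hot label vector per sample and time step. As a rough illustration of that layout, the NumPy sketch below builds such a tensor from integer class ids. It assumes the class id is simply repeated over all `T` steps; the helper `onehot_per_frame` is hypothetical and not part of `moving_fashion_mnist`.

```python
import numpy as np

def onehot_per_frame(labels, classes, T):
    """Hypothetical helper: one-hot encode class ids and repeat them over T steps."""
    classes = list(classes)
    # position of every label inside `classes`, e.g. 8 -> 0 and 9 -> 1
    idx = np.array([classes.index(int(label)) for label in labels])
    onehot = np.eye(len(classes), dtype=np.float32)[idx]   # shape (N, n_classes)
    return np.repeat(onehot[:, None, :], T, axis=1)        # shape (N, T, n_classes)

labels = np.array([8, 9, 9, 8])
y = onehot_per_frame(labels, classes=[8, 9], T=3)
print(y.shape)   # (4, 3, 2)
print(y[1, 0])   # [0. 1.]
```

The last axis plays the role of `ndim_y` in the rest of the notebook.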

In [3]:
batch_size, T = 256,1
input_labels = [8,9]
# test_split = 100
# tot_dataset_size=2**10

pos, labels, labels_onehot = moving_fashion_mnist.generate(T, input_labels)
# pred_c = labels_onehot.argmax(dim=2)

2
(12000,)


N x T x image_size^2:  (12000, 1, 400)
label.shape:  (12000,)
label_onehot.shape:  (12000, 1, 2)


## Setting up the model

Our model consists of three invertible blocks, each a GRU-based residual coupling layer (`glow_gru_residual_coupling_layer`) with sub-networks of class `F_GRU_BN`. Each input frame has `ndim_x` = 400 dimensions. The latent variable $z$ has 8 dimensions and is concatenated with the predicted labels $y$, encoded as a 2-dimensional one-hot vector. Input and output are zero-padded to `ndim_tot` = 512 dimensions.
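
To make the zero-padding concrete, the sketch below packs arrays the way the training code later does: inputs as `[x, padding]` and outputs as `[z, padding, y]`. It uses NumPy instead of torch and takes the dimension values from the code cell that follows; the helper names `pad_input` and `pack_output` are made up for illustration.

```python
import numpy as np

# Values taken from the model cell below.
ndim_tot, ndim_x, ndim_y, ndim_z = 512, 400, 2, 8

def pad_input(x):
    """Hypothetical helper: [x, zero padding] with total width ndim_tot."""
    pad = np.zeros(x.shape[:-1] + (ndim_tot - ndim_x,), dtype=x.dtype)
    return np.concatenate([x, pad], axis=-1)

def pack_output(z, y):
    """Hypothetical helper: [z, zero padding, y] with total width ndim_tot."""
    pad = np.zeros(z.shape[:-1] + (ndim_tot - ndim_y - ndim_z,), dtype=z.dtype)
    return np.concatenate([z, pad, y], axis=-1)

x = np.random.rand(4, 6, ndim_x).astype(np.float32)    # (batch, T, pixels)
z = np.random.randn(4, 6, ndim_z).astype(np.float32)   # latent sample
y = np.tile(np.array([1.0, 0.0], dtype=np.float32), (4, 6, 1))

x_pad, yz = pad_input(x), pack_output(z, y)
print(x_pad.shape, yz.shape)   # (4, 6, 512) (4, 6, 512)
```

With this layout the latent part is read from the first `ndim_z` entries and the label from the last `ndim_y` entries, which matches the slicing used in the training loop.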

In [4]:
ndim_tot = 512
ndim_x = pos.data.numpy().shape[2]
ndim_y = labels_onehot.data.numpy().shape[2]
ndim_z = 8

# print("ndim_x: ", ndim_x)
# print("ndim_y: ", ndim_y)
# print("ndim_z: ", ndim_z)

P_Z = multivariate_normal(mean=np.zeros(ndim_z))

inp = InputNode(ndim_tot, name='input')

# t1 = Node([inp.out0], rev_gru_layer,
#           {'F_class': F_GRU, 'clamp': 2.0,
#            'F_args': {'dropout': 0.0}})

# t2 = Node([t1.out0], rev_gru_layer,
#           {'F_class': F_GRU, 'clamp': 2.0,
#            'F_args': {'dropout': 0.0}})

# t3 = Node([t2.out0], rev_gru_layer,
#           {'F_class': F_GRU, 'clamp': 2.0,
#            'F_args': {'dropout': 0.0}})

# t1 = Node([inp.out0], glow_gru_coupling_layer,
#           {'F_class': F_GRU, 'clamp': 2.0,
#            'F_args': {'dropout': 0.0}})

# t2 = Node([t1.out0], glow_gru_coupling_layer,
#           {'F_class': F_GRU, 'clamp': 2.0,
#            'F_args': {'dropout': 0.0}})

t1 = Node([inp.out0], glow_gru_residual_coupling_layer,
          {'F_class': F_GRU_BN,'clamp': 2.0,
           'F_args': {}})

t2 = Node([t1.out0], glow_gru_residual_coupling_layer,
          {'F_class': F_GRU_BN, 'clamp': 2.0,
           'F_args': {}})

t3 = Node([t2.out0], glow_gru_residual_coupling_layer,
          {'F_class': F_GRU_BN, 'clamp': 2.0,
           'F_args': {}})

outp = OutputNode([t3.out0], name='output')
# outp = OutputNode([t2.out0], name='output')
nodes = [inp, t1, t2, t3, outp]
# nodes = [inp, t1, t2,  outp]
model = ReversibleGraphNet(nodes)
model = model.to(device)

Node 6a71d0 has following input dimensions:
	 Output #0 of node input: (512,)

Node 6a72b0 has following input dimensions:
	 Output #0 of node 6a71d0: (512,)

Node 69eb38 has following input dimensions:
	 Output #0 of node 6a72b0: (512,)

Node output has following input dimensions:
	 Output #0 of node 69eb38: (512,)



## Training the model

We will train our model using three main losses. In the forward direction we apply an MSE loss to the assigned label and a distributional loss to the latent variable $z$.
We make use of the reversibility of our model and apply a third loss that matches the distribution of samples from our dataset to the distribution of backward predictions of our model. The training code below additionally adds an MSE reconstruction term in the backward direction and subtracts a term proportional to the standard deviation of the reconstructed output.
You can find more information on the losses in the [paper](https://arxiv.org/abs/1808.04730).
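
The distributional losses are implemented below as `MMD_multiscale`. The following NumPy sketch mirrors that computation (a biased MMD estimate with a sum of inverse multiquadratic kernels), assuming equal-sized batches of flat vectors. The name `mmd_multiscale` and the demo data are illustrative only.

```python
import numpy as np

def mmd_multiscale(x, y, scales=(0.2, 0.5, 0.9, 1.3)):
    """Biased MMD^2 estimate using kernels k(d) = a^2 / (a^2 + d) for several a."""
    def sq_dists(p, q):
        # pairwise squared Euclidean distances, shape (len(p), len(q))
        return (p**2).sum(1)[:, None] + (q**2).sum(1)[None, :] - 2.0 * p @ q.T

    def kernel(d, a):
        return a**2 / (a**2 + d)

    dxx, dyy, dxy = sq_dists(x, x), sq_dists(y, y), sq_dists(x, y)
    total = 0.0
    for a in scales:
        total += kernel(dxx, a).mean() + kernel(dyy, a).mean() - 2.0 * kernel(dxy, a).mean()
    return total

rng = np.random.default_rng(0)
a_samples = rng.normal(size=(200, 2))
b_samples = rng.normal(size=(200, 2))          # same distribution
c_samples = rng.normal(size=(200, 2)) + 3.0    # shifted distribution

mmd_same = mmd_multiscale(a_samples, b_samples)
mmd_shifted = mmd_multiscale(a_samples, c_samples)
print(mmd_same < mmd_shifted)   # True
```

The value is close to zero when both batches come from the same distribution and grows as they move apart, which is what makes it usable as a training loss.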


In [5]:
# Training parameters
n_epochs = 1500
meta_epoch = 12
n_its_per_epoch = 10

batch_size_train, T_train = 512,6
# input_labels = [5,7]
test_split = 100

pos, labels, labels_onehot = moving_fashion_mnist.generate(T_train, input_labels)
c = labels_onehot

lr = 1e-4
gamma = 0.01**(1./120)
l2_reg = 2e-5

y_noise_scale = 3e-2
zeros_noise_scale = 3e-2
# y_noise_scale = 0
# zeros_noise_scale = 0

# relative weighting of losses:
lambd_predict = 300.
lambd_latent = 300.
lambd_rev = 400.

# pad_x = torch.zeros(batch_size_train, ndim_tot - ndim_x)
# pad_yz = torch.zeros(batch_size_train, ndim_tot - ndim_y - ndim_z)

# optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.8, 0.8),
#                              eps=1e-04, weight_decay=l2_reg)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999),
                             eps=1e-08, weight_decay=l2_reg)

scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                            step_size=meta_epoch,
                                            gamma=gamma)


def MMD_multiscale(x, y):
    xx, yy, zz = torch.mm(x,x.t()), torch.mm(y,y.t()), torch.mm(x,y.t())

    rx = (xx.diag().unsqueeze(0).expand_as(xx))
    ry = (yy.diag().unsqueeze(0).expand_as(yy))

    dxx = rx.t() + rx - 2.*xx
    dyy = ry.t() + ry - 2.*yy
    dxy = rx.t() + ry - 2.*zz

    XX, YY, XY = (torch.zeros(xx.shape).to(device),
                  torch.zeros(xx.shape).to(device),
                  torch.zeros(xx.shape).to(device))

    for a in [0.2, 0.5, 0.9, 1.3]:
        XX += a**2 * (a**2 + dxx)**-1
        YY += a**2 * (a**2 + dyy)**-1
        XY += a**2 * (a**2 + dxy)**-1

    return torch.mean(XX + YY - 2.*XY)

def fit(input, target):
#     return torch.mean(torch.abs(input - target))
#     return torch.nn.CrossEntropyLoss(input, target)
    return torch.mean((input - target)**2)



loss_backward = MMD_multiscale
loss_latent = MMD_multiscale
loss_fit = fit

test_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(pos[:test_split],
                                   labels_onehot[:test_split]),
    batch_size=batch_size, shuffle=True, drop_last=True)

train_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(pos[test_split:], labels_onehot[test_split:]),
    batch_size=batch_size_train, shuffle=True, drop_last=True)

2
(12000,)
N x T x image_size^2:  (12000, 6, 400)
label.shape:  (12000,)
label_onehot.shape:  (12000, 6, 2)


We can now define our training method. Note how we simply use the model for the forward pass, backpropagate the forward losses, and then switch to backward training by passing `rev=True`; the gradients of both directions accumulate before a single optimizer step. Randomness in the samples generated by backward prediction is achieved by drawing $z$ from a standard normal distribution with `ndim_z` dimensions (8 in this notebook).
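
Why one set of weights can serve both directions is easiest to see on a single coupling block. The sketch below is a generic affine coupling layer in NumPy, not the GRU-based layer used in this notebook: the class `AffineCoupling` and its fixed random linear maps are assumptions made for illustration, and only the `rev` flag imitates the FrEIA call convention.

```python
import numpy as np

class AffineCoupling:
    """Toy invertible block: keep the first half u1, transform the second half u2
    with a scale and shift computed from u1. Here the scale and shift maps are
    fixed random linear maps; in a real model they are learned sub-networks."""

    def __init__(self, dim, seed=0):
        rng = np.random.default_rng(seed)
        self.h = dim // 2
        self.Ws = 0.1 * rng.normal(size=(self.h, dim - self.h))
        self.Wt = 0.1 * rng.normal(size=(self.h, dim - self.h))

    def __call__(self, u, rev=False):
        u1, u2 = u[:, :self.h], u[:, self.h:]
        s, t = np.tanh(u1 @ self.Ws), u1 @ self.Wt
        if rev:
            v2 = (u2 - t) * np.exp(-s)   # exact inverse of the forward transform
        else:
            v2 = u2 * np.exp(s) + t
        return np.concatenate([u1, v2], axis=1)

block = AffineCoupling(dim=6)
u = np.random.default_rng(1).normal(size=(4, 6))
v = block(u)                    # forward direction
u_back = block(v, rev=True)     # backward direction, same parameters
print(np.allclose(u, u_back))   # True
```

Because `u1` passes through unchanged, the scale and shift can be recomputed in the backward direction, so the inverse needs no extra parameters.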

In [6]:
def train(i_epoch=0):
    model.train()

    l_tot = 0
    batch_idx = 0
    
    t_start = t_()
    
    loss_factor = 600**(float(i_epoch) / 300) / 600
    loss_factor2 = 600**(float(i_epoch) / 300) / 600
    if loss_factor > 1:
        loss_factor = 1
        
    if loss_factor2 > 1:
        loss_factor2 = 1
        
    loss = 0
    for x, y in train_loader:
        batch_idx += 1
        if batch_idx > n_its_per_epoch:
            break

        x, y = x.to(device), y.to(device)
        y_clean = y.clone()
#         pad_x = zeros_noise_scale * torch.randn(batch_size_train, T_train, ndim_tot - ndim_x, device=device)
        pad_x = torch.zeros(batch_size_train, T_train, ndim_tot - ndim_x, device=device)
#         pad_yz = zeros_noise_scale * torch.randn(batch_size_train, T_train, ndim_tot - ndim_y - ndim_z, device=device)
        pad_yz = torch.zeros(batch_size_train, T_train, ndim_tot - ndim_y - ndim_z, device=device)

#         y += y_noise_scale * torch.randn(batch_size_train, T_train, ndim_y, dtype=torch.float, device=device)

        x, y = (torch.cat((x, pad_x),  dim=2),
                torch.cat((torch.randn(batch_size_train, T_train, ndim_z, device=device), pad_yz, y), dim=2))

        optimizer.zero_grad()
        
        # Forward step:
        output = model(x)
        l = 0
        for t in range(T_train):

            # L_y( y, y_gt)
            l += lambd_predict * loss_fit(output[:, t, ndim_z:], y[:, t, ndim_z:])

            # Shorten output, and remove gradients wrt y, for latent loss
            y_short = torch.cat((y[:, t, :ndim_z], y[:, t, -ndim_y:]), dim=1)

            output_block_grad = torch.cat((output[:, t, :ndim_z],
                                           output[:, t, -ndim_y:].data), dim=1)

            # L_z( [y,z], [y_gt, z_sample] )
            l += lambd_latent * loss_latent(output_block_grad, y_short)
            
        l_tot += l.data.item()

        l.backward()

        # Backward step:
#         pad_yz = zeros_noise_scale * torch.randn(batch_size_train, T_train, ndim_tot - ndim_y - ndim_z, device=device)
        pad_yz = torch.zeros(batch_size_train, T_train, ndim_tot - ndim_y - ndim_z, device=device)
        y = y_clean # + y_noise_scale * torch.randn(batch_size_train, T_train, ndim_y, device=device)

#         orig_z_perturbed = (output.data[:, :, :ndim_z] + y_noise_scale *
#                             torch.randn(batch_size_train, T_train, ndim_z, device=device))
        
        orig_z_perturbed = output.data[:, :, :ndim_z]
        
        # perturbed [z=f_z(x), pad, y_gt]
        y_rev = torch.cat((orig_z_perturbed, pad_yz, y), dim=2)
        # perturbed [z_sample, pad, y_gt]
        y_rev_rand = torch.cat((torch.randn(batch_size_train, T_train, ndim_z, device=device), pad_yz, y), dim=2)
        
        output_rev = model(y_rev, rev=True)
        output_rev_rand = model(y_rev_rand, rev=True)

        # L_x (MMD)
        l_rev = 0
        for t in range(T_train):
            l_rev += lambd_rev * loss_factor * loss_backward(output_rev_rand[:, t, :ndim_x], x[:, t, :ndim_x])

            # MSE fit loss against X -> [Y,Z]+perturb -> X_inv and original X
            # MSE( g([output_z_pert, pad, y_gt_pert]), [x, pad] )
            l_rev += 0.50 * lambd_predict * loss_fit(output_rev[:,t,:], x[:,t,:])
    #         l_rev += 0.50 * lambd_predict * loss_backward(output_rev, x)
            
            '''reducing sample noise'''
            l_rev -= loss_factor2 * lambd_rev * output_rev.std()
        
        l_tot += l_rev.data.item()
        l_rev.backward()

        for p in model.parameters():
            if p.grad is not None:
                p.grad.data.clamp_(-15.00, 15.00)

        optimizer.step()

#     print('%.1f\t%.5f' % (
#                              float(batch_idx)/(time()-t_start),
#                              l_tot / batch_idx,
#                            ), flush=True)
        loss += l_tot
#     print('loss:\t%.5f' % (loss), flush=True)

    return l_tot / batch_idx
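
The weighting factors `loss_factor` and `loss_factor2` in `train` follow an exponential warm-up that is capped at 1. A small standalone sketch of that schedule, using the constants 600 and 300 from the cell above (the function wrapper and its parameter names are only for illustration):

```python
def loss_factor(i_epoch, base=600.0, ramp_epochs=300):
    """Exponential warm-up from 1/base at epoch 0 to 1 at ramp_epochs, capped at 1."""
    return min(1.0, base ** (float(i_epoch) / ramp_epochs) / base)

for epoch in (0, 150, 300, 1000):
    print(epoch, round(loss_factor(epoch), 4))
# 0 0.0017
# 150 0.0408
# 300 1.0
# 1000 1.0
```

The backward MMD term and the standard-deviation term therefore contribute very little at the start of training and reach full weight from epoch 300 onward.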

In [7]:
def log_e(s):
    '''log of the nonlinear function e'''
    return 2.0 * 0.636 * np.arctan(s)

def e(s):
    '''nonlinear function e'''
    return np.exp(2.0 * 0.636 * np.arctan(s))
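
The constant 0.636 is close to $2/\pi$, so `log_e` stays strictly inside $(-2, 2)$ no matter how large its argument gets, and `e` is bounded accordingly. The short check below restates `log_e` with the factor 2.0 exposed as a `clamp` parameter, which is a naming choice made here, and evaluates it on a few extreme inputs.

```python
import numpy as np

def log_e(s, clamp=2.0):
    """Soft clamp: the arctan squashes s so the result stays within (-clamp, clamp)."""
    return clamp * 0.636 * np.arctan(s)

s = np.array([-1e6, -1.0, 0.0, 1.0, 1e6])
values = log_e(s)
print(values)
print(np.all(np.abs(values) < 2.0))   # True
print(np.exp(-2.0), np.exp(2.0))      # bounds on e(s)
```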

In [8]:
batch_size, T = 100,8
# input_labels = [5,7]
test_split = 100
N_samp = 100

pos_long, labels_long, labels_onehot_long = moving_fashion_mnist.generate(T, input_labels)
c_long = labels_onehot_long

test_long_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(pos_long, labels_onehot_long),
    batch_size=100, shuffle=True, drop_last=True)

2
(12000,)
N x T x image_size^2:  (12000, 8, 400)
label.shape:  (12000,)
label_onehot.shape:  (12000, 8, 2)


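The custom weight initialization is commented out in the cell below, so the model keeps its default parameter initialization. Every 50 epochs the following loop computes label predictions and backward-predicted samples and saves the selected samples as images, so you can see the model getting better during training. Model weights are saved every 100 epochs.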

In [9]:
save_path='%s_vs_%s_final_fashion/'%(input_labels[0], input_labels[1])
# if os.path.exists(save_path):
#     files = os.listdir(os.path.join(save_path,'weights'))
#     if files:
        
#         max_iter= max([int((file.split('_')[1]).split('.')[0])for file in files])
#         cpt = torch.load(os.path.join(save_path,'weights','epoch_{:d}'.format(max_iter) + '.pth'))
#         model.load_state_dict(cpt['state_dict'])
#         optimizer = cpt['op'] 
#         i_epoch = cpt['epoch'] 
#         scheduler = cpt['scheduler']

In [10]:
# for mod_list in model.children():
#     for block in mod_list.children():
#         for coeff in block.children():
#             coeff.fc3.weight.data = 0.01*torch.randn(coeff.fc3.weight.shape)
#             torch.nn.init.xavier_uniform_(coeff.conv3.weight)


model = model.to(device)
        
nrows = 15
# fig, axes = plt.subplots(nrows=nrows, ncols=T, figsize=(16,16))
# plt.subplots_adjust(wspace=0.000001, hspace=T/8)
fig, axes = plt.subplots(nrows=1, ncols=T, figsize=(8,8))
# adjust spacing on the figure we just created, not on an implicit empty one
fig.subplots_adjust(wspace=0.000001, hspace=T/12)
# First set up the test samples that we use consistently to see the change over training.
N_samp = 100

# x_samps, y_gt = test_long_loader.dataset.tensors
# the loader is shuffled, so a single batch is enough for a fixed sample set
x, y = next(iter(test_long_loader))
x_samps = x[:N_samp]
y_gt = y[:N_samp]
    
y_samps = y_gt

# ----------------swap the y value at timepoint T_train ----------------------
# y_samps[:,:2,:] = (1 - y_samps[:,:2,:])
# y_samps[:,T_train - 1,:] = (1 - y_samps[:,T_train - 1,:])
y_samps[:,2:,:] = (1 - y_samps[:,2:,:])
# y_samps[:,0,:] = (1 - y_samps[:,0,:])
# ----------------linear gradual changes of y----------------------
# for t in range(T):
#     d = (2 * y_samps[:,t,:] - 1)/(T - 1)
#     y_samps[:,t,:] -=d*t
    
y_samps_label = y_samps.data
# ---------------No perturbation----------------------
# y_samps += y_noise_scale * torch.randn(N_samp, T, ndim_y)
# ---------------Having same z across time----------------------
# z_samps = torch.randn(N_samp, T, ndim_z)
z_samps = torch.randn(N_samp,ndim_z)
z_samps = z_samps.repeat(1,T).reshape((N_samp, T, ndim_z))

yz_pads =  torch.zeros(N_samp, T, ndim_tot - ndim_y - ndim_z) # * zeros_noise_scale
y_samps = torch.cat([z_samps, yz_pads,  y_samps], dim=2)
y_samps = y_samps.to(device)

N_density_set = 500
x_samps_density = torch.cat([x for x,y in test_long_loader], dim=0)[:N_density_set]
x_pads = torch.zeros(N_density_set, T, ndim_tot - ndim_x)
x_samps_padded_density = torch.cat((x_samps_density, x_pads), dim=2)

img_size = 20
import os, time 

try:
    t_start = t_()
    for i_epoch in tqdm(range(n_epochs), ascii=True, ncols=80):

        # Initially, the l2 reg. on x and z can give huge gradients, set
        # the lr lower for this
        if i_epoch < 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr * 1e-2
        
        train(i_epoch)
        # PyTorch >= 1.1 expects scheduler.step() after the optimizer updates
        scheduler.step()
        
        if i_epoch % 50 != 0 and i_epoch != n_epochs-1:
            continue
        if i_epoch == 0:
            continue
        rev_x = model(y_samps, rev=True)
        
        x_pads = torch.zeros(N_samp, T, ndim_tot - ndim_x)
        x_samps_padded = torch.cat((x_samps, x_pads), dim=2)
        forward_z_pad_y = model(x_samps_padded.to(device))
                                   
        pred_c = forward_z_pad_y.data[:, :, -ndim_y:].argmax(dim=2)
        pred_z = forward_z_pad_y.data[:, :, :ndim_z]
        P_Z_z = P_Z.pdf(pred_z)
        # Inverse Jacobian J_X = dx / d[y,z] given forwarded [y,z] = f(x)
        # This is for computing the densities of given X
#         log_det_J_X_yz = model.jacobian(forward_z_pad_y, rev=True).data       
#         P_X_x = e(log_e(P_Z_z) - log_det_J_X_yz) # P_X_x.shape = [N, T]
#         P_X_x = P_X_x.data.numpy()

#         Forward Jacobian
        log_det_J_yz_X = model.jacobian(x_samps_padded.to(device)).data
        P_X_x = e(log_e(P_Z_z) + log_det_J_yz_X)
        P_X_x = P_X_x.data.numpy()
        
        
        forward_z_pad_y_density = model(x_samps_padded_density.to(device))
        pred_z_density = forward_z_pad_y_density.data[:, :, :ndim_z]
        P_Z_z_density = P_Z.pdf(pred_z_density)
         # [X density, clean zero pad] -> J_yz_X
        log_det_J_yz_X_density = model.jacobian(x_samps_padded_density.to(device)).data
        P_X_x_density = e(log_e(P_Z_z_density) + log_det_J_yz_X_density) # P_X_x.shape = [N, T]
        P_X_x_density = P_X_x_density.data.numpy()
        
        rev_x_only = torch.tensor(rev_x[:,:,:ndim_x], dtype=torch.float)
        rev_x = rev_x.cpu().data.numpy()
        P_Z_z_samps = P_Z.pdf(z_samps)
        rev_x_clean = torch.cat((rev_x_only, x_pads), dim=2)
        # [y input, z sample] -> [X sample, clean zero pad] -> J_yz_X
        log_det_J_yz_X_samps = model.jacobian(torch.tensor(rev_x_clean).cuda()).data
        P_X_x_samps = e(log_e(P_Z_z_samps) + log_det_J_yz_X_samps) # P_X_x.shape = [N, T]
        # [y input, z sample]  -> J_X_yz
#         log_det_J_X_yz_samps = model.jacobian(y_samps, rev=True).data
#         P_X_x_samps = e(log_e(P_Z_z_samps) - log_det_J_X_yz_samps) # P_X_x.shape = [N, T]
        P_X_x_samps = P_X_x_samps.data.numpy()
        

        P_X_x_full = np.concatenate((P_X_x_density, P_X_x, P_X_x_samps), axis=0)
        P_X_min = np.min(P_X_x_full, axis=0)
        P_X_max = np.max(P_X_x_full, axis=0)
        P_X_x = (P_X_x - P_X_min) / (P_X_max - P_X_min)
        P_X_x_samps = (P_X_x_samps - P_X_min) / (P_X_max - P_X_min)
        
        
        pred_diff = torch.zeros(T)

#         for k in range(nrows):
#             for t in range(T):
#                 axes[k][t].clear(); axes[k][t].set_xticks([]); axes[k][t].set_yticks([])
#                 im_squeezed = x_samps[k,t,:].data.numpy()
#                 im = 255*im_squeezed.reshape(20, -1)
#                 axes[k][t].imshow(im, cmap='gray')
#                 y_pred = pred_c[k,t].cuda()
#                 P_X_x_kt = P_X_x[k][t]
#                 P_X_str = str("{:.2f}".format(P_X_x_kt) )
#                 axes[k][t].set_title(str(y_pred.cpu().data.numpy()) + " " + P_X_str, fontsize=8)

#                 axes[k][t+T].clear(); axes[k][t+T].set_xticks([]); axes[k][t+T].set_yticks([])
#                 im_squeezed = rev_x_only[k,t,:].data.numpy()
#                 im = im_squeezed.reshape(20, -1)
# #                 if not np.all(im > 0):
# #                     print(im, k , t)
# #                     im[im < 0] = 0
#                 im = 255 * (im - np.min(im))/(np.max(im) - np.min(im))
#                 axes[k][t+T].imshow(im, cmap='gray')
# #                 if t == 3:
# #                     print(im)
#                 y_samp_gt = y_samps_label[k,t,:].argmax(dim=0)
# #                 y_samp_gt = torch.max(y_samps_label[k,t,:])
# #                 P_X_x_samps_kt = P_X_x_samps[k][t]
# #                 P_X_x_samps_str = "{:.2f}".format(P_X_x_samps_kt)
#                 title = '{:d} {:.2f}'.format(int(y_samp_gt.cpu().data.numpy()), P_X_x_samps[k][t])
#                 axes[k][t+T].set_title(title, fontsize=8)
#         fig.canvas.draw()

#         ''' save output plots'''
        
        # make sure the output directories exist (also creates save_path)
        for sub_dir in ('imgs', 'pngs'):
            os.makedirs(os.path.join(save_path, sub_dir), exist_ok=True)
        K = np.where(np.greater_equal(P_X_x_samps[:,0]+0.2, P_X_x_samps[:,1]) \
                     & np.greater(P_X_x_samps[:,1]+0.2 , P_X_x_samps[:,2])\
                     & np.greater(P_X_x_samps[:,3]+0.2 , P_X_x_samps[:,2])\
                     & np.greater(P_X_x_samps[:,4]+0.2 , P_X_x_samps[:,3]) \
                     & np.greater(P_X_x_samps[:,5]+0.2 , P_X_x_samps[:,4]))[0]
#         K = np.where(np.greater(P_X_x_samps[:,1] , P_X_x_samps[:,2])\
#                      & np.greater(P_X_x_samps[:,3] , P_X_x_samps[:,2])\
#                      & np.greater(P_X_x_samps[:,4] , P_X_x_samps[:,3]) \
#                      & np.greater(P_X_x_samps[:,5] , P_X_x_samps[:,4]))[0]
       
        np.savez(os.path.join(save_path,'imgs','moving_mnist_%s_vs_%s_%s'%(input_labels[0],input_labels[1], time.strftime("%Y_%m_%d_%H_%M_%S"))), rev_x_only[K].data.numpy())
        for i,k in enumerate(K):
            for t in range(T):
                im_squeezed = rev_x_only[k,t,:].data.numpy()
                im = im_squeezed.reshape(img_size, -1)
        #                 if not np.all(im > 0):
        #                     print(im, k , t)
        #                     im[im < 0] = 0
#                 im = 255 * (im - np.min(im))/(np.max(im) - np.min(im))
                im[im < 0] = 0
                axes[t].clear(); axes[t].set_xticks([]); axes[t].set_yticks([])
                axes[t].imshow(im, cmap='gray')
                y_samp_gt = y_samps_label[k,t,:].argmax(dim=0)
                title = '[{:d} {:.2f}]'.format(int(y_samp_gt.cpu().data.numpy()), P_X_x_samps[k][t])
                axes[t].set_title(title, fontsize=8, fontweight='bold')
#             fig.canvas.draw()
            fig.savefig(os.path.join(save_path,'pngs','moving_mnist_%s_vs_%s_%s_%s_%s.png'%(input_labels[0],input_labels[1], i_epoch,k + 1, time.strftime("%Y_%m_%d_%H_%M_%S"))), dpi=600, bbox_inches = 'tight') 
            
        if i_epoch % 100 != 0 and i_epoch != n_epochs-1:
            continue
            
        ''' save model weights'''
        
        os.makedirs(os.path.join(save_path, 'weights'), exist_ok=True)
        fn = os.path.join(save_path, 'weights', 'epoch_{:d}'.format(i_epoch) + '.pth')
        from collections import OrderedDict as OD
        od = OD()
        od['state_dict'] = model.state_dict()
        od['op'] = optimizer
        od['epoch'] = i_epoch
        od['scheduler'] = scheduler
        torch.save(od, fn)
            
except KeyboardInterrupt:
    pass
finally:
    print(f"\n\nTraining took {(t_()-t_start)/60:.2f} minutes\n")

100%|#####################################| 1500/1500 [3:00:59<00:00, 43.16s/it]



Training took 180.99 minutes






# element-wise conditions need `&` and explicit pairwise comparisons
K = np.where((np.sum(P_X_x_samps, axis=1) > 0.45 * T)
             & (P_X_x_samps[:, 0] >= P_X_x_samps[:, 1])
             & (P_X_x_samps[:, 1] >= P_X_x_samps[:, 2]))[0]
K

In [11]:
# P_X_x_samps[np.greater(P_X_x_samps[:,1] , P_X_x_samps[:,2]) & np.greater(P_X_x_samps[:,3] , P_X_x_samps[:,2])\
#            & np.greater(P_X_x_samps[:,4] , P_X_x_samps[:,3]) & np.greater(P_X_x_samps[:,5] , P_X_x_samps[:,4])\
#            & np.greater(P_X_x_samps[:,0] , 0.02)]
np.where(np.greater(P_X_x_samps[:,1] , P_X_x_samps[:,2]) & np.greater(P_X_x_samps[:,3] , P_X_x_samps[:,2])\
           & np.greater(P_X_x_samps[:,4] , P_X_x_samps[:,3]) & np.greater(P_X_x_samps[:,5] , P_X_x_samps[:,4])\
           & np.greater(P_X_x_samps[:,0] , 0.01))[0]

array([51, 91, 93])

In [12]:
# Left T columns: input frames with predicted label and normalised density.
# Right T columns: backward samples with conditioning label and density.
# The grid therefore needs 2*T columns; with only T columns, axes[k][t+T]
# raises an IndexError.
fig, axes = plt.subplots(nrows=nrows, ncols=2*T, figsize=(16,16))
fig.subplots_adjust(wspace=0.00001, hspace=T/12)

for k in range(nrows):
    for t in range(T):
        axes[k][t].clear(); axes[k][t].set_xticks([]); axes[k][t].set_yticks([])
        im_squeezed = x_samps[k,t,:].data.numpy()
        im = 255*im_squeezed.reshape(img_size, -1)
        axes[k][t].imshow(im, cmap='gray')
        y_pred = int(pred_c[k,t])
        P_X_str = "{:.2f}".format(P_X_x[k][t])
        axes[k][t].set_title(str(y_pred) + " " + P_X_str, fontsize=8)

        axes[k][t+T].clear(); axes[k][t+T].set_xticks([]); axes[k][t+T].set_yticks([])
        im_squeezed = rev_x_only[k,t,:].data.numpy()
        im = im_squeezed.reshape(img_size, -1)
        # rescale each sample to the full grey range for display
        im = 255 * (im - np.min(im))/(np.max(im) - np.min(im))
        axes[k][t+T].imshow(im, cmap='gray')
        y_samp_gt = y_samps_label[k,t,:].argmax(dim=0)
        title = '[{:d} {:.2f}]'.format(int(y_samp_gt.cpu().data.numpy()), P_X_x_samps[k][t])
        axes[k][t+T].set_title(title, fontsize=8, fontweight='bold')
fig.canvas.draw()