This notebook contains training and evaluation code for the St-RKM model on the 3Dshapes dataset (https://github.com/deepmind/3d-shapes). Play with the existing code and the various tuning parameters. As an exercise, inspect the influence of the number of Gaussian mixture components on the generation quality of the model.

### Install and activate python packages

In [None]:
# Install the notebook's dependencies from requirements.txt (IPython shell escape).
# NOTE(review): the message below is printed unconditionally; pip's exit status is not checked.
!pip install -r requirements.txt
print('Successful install.')

### Import python packages

In [None]:
import logging
import argparse
import time
import os
import h5py
import numpy as np
import torch
import torch.nn as nn
from scipy import stats
import matplotlib.pyplot as plt
from tqdm import tqdm
import skimage
import scipy
import stiefel_optimizer
from skimage.transform import resize
from datetime import datetime
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
%matplotlib inline

### Create directories to save checkpoints and training logs

In [None]:
class create_dirs:
    """Manage the on-disk folders and file names used by one training run.

    Three per-dataset folders are used: ``cp/<name>`` (checkpoints),
    ``log/<name>`` (log files) and ``out/<name>`` (final trained models).

    Args:
        name: dataset / experiment name, used as the sub-folder name.
        ct: run tag (e.g. a timestamp) embedded in the file names.
    """

    def __init__(self, name, ct):
        self.name = name
        self.ct = ct
        # File name of the best-so-far checkpoint inside cp/<name>/
        self.dircp = 'checkpoint.pth_{}.tar'.format(ct)
        # File name of the final trained model inside out/<name>/
        self.dirout = '{}_Trained_rkm_{}.tar'.format(name, ct)

    def create(self):
        """Create cp/<name>, log/<name> and out/<name> when they do not exist yet."""
        for root in ('cp', 'log', 'out'):
            folder = '{}/{}'.format(root, self.name)
            if not os.path.exists(folder):
                os.makedirs(folder)

    def save_checkpoint(self, state, is_best):
        """Write ``state`` to cp/<name>/<dircp>, but only when ``is_best`` is truthy."""
        if not is_best:
            return
        torch.save(state, 'cp/{}/{}'.format(self.name, self.dircp))


### Misc utilities

In [None]:
def convert_to_imshow_format(image):
    """Convert a CHW image array into the layout expected by ``plt.imshow``.

    Single-channel images are returned as a 2-D (H, W) array. Multi-channel
    images are transposed to (H, W, C); if any pixel value is negative the
    image is assumed to be in the [-1, 1] range and is mapped back to [0, 1].

    Args:
        image: numpy array of shape (C, H, W).

    Returns:
        numpy array of shape (H, W) when C == 1, otherwise (H, W, C).
    """
    if image.shape[0] == 1:
        return image[0, :, :]
    # Test the pixel values directly. The previous ``np.any(np.where(image < 0))``
    # looked at the *indices* of negative pixels, so a lone negative pixel at
    # position (0, 0, 0) went unnoticed.
    if np.any(image < 0):
        # first convert back to [0,1] range from [-1,1] range
        image = image / 2 + 0.5
    return image.transpose(1, 2, 0)
    
class Resize:
    """Resize an (H, W[, C]) numpy image with ``skimage.transform.resize``.

    Args:
        size: target spatial size, either an int (square output) or a
            length-2 iterable ``(height, width)``.

    Raises:
        TypeError: if ``size`` is neither an int nor a length-2 iterable.
    """

    def __init__(self, size):
        # ``Iterable`` used to be referenced without being imported, so every
        # non-int ``size`` raised NameError instead of being accepted.
        from collections.abc import Iterable

        if isinstance(size, int):
            self._size = (size, size)
        elif isinstance(size, Iterable) and len(size) == 2:
            self._size = tuple(size)
        else:
            raise TypeError('size must be an int or a length-2 iterable, got {!r}'.format(size))

    def __call__(self, img: np.ndarray):
        """Return ``img`` resized to ``self._size`` as a float32 array."""
        resize_image = skimage.transform.resize(img, self._size)
        # Convert to float32 regardless of the float dtype resize returned
        return skimage.util.img_as_float32(resize_image)
    
class Lin_View(nn.Module):
    """Unflatten a linear layer's output into a (N, c, a, b) map for conv layers.

    The input is viewed as ``(x.size(0), c, a, b)``. If that view is invalid
    (e.g. presumably a single unbatched vector), it falls back to a batch of
    one: ``(1, c, a, b)``.
    """

    def __init__(self, c, a, b):
        super(Lin_View, self).__init__()
        self.c, self.a, self.b = c, a, b

    def forward(self, x):
        try:
            return x.view(x.size(0), self.c, self.a, self.b)
        except RuntimeError:
            # Shape mismatch from ``view``: treat the input as a single sample.
            # Catching only RuntimeError (instead of a bare except) keeps
            # KeyboardInterrupt and unrelated errors from being swallowed.
            return x.view(1, self.c, self.a, self.b)

### Define encoder ($\mathbf{\phi}_{\theta}$) / decoder ($\mathbf{\psi}_{\xi}$) networks

In [None]:
class Net1(nn.Module):
    """Encoder phi_theta: image batch -> feature vectors of size ``args.x_fdim2``.

    Three convolutions (each followed by LeakyReLU(0.2)), a flatten, then two
    linear layers with a LeakyReLU in between.

    Args:
        nChannels: number of input image channels.
        args: namespace providing ``capacity``, ``x_fdim1`` and ``x_fdim2``.
        cnn_kwargs: tuple ``(kwargs_conv1_2, kwargs_conv3, side)``; ``side`` is
            the spatial size of the last convolution's output.
    """

    def __init__(self, nChannels, args, cnn_kwargs):
        super(Net1, self).__init__()
        self.args = args
        cap = args.capacity
        side = cnn_kwargs[2]

        def leaky():
            return nn.LeakyReLU(negative_slope=0.2)

        # One Sequential named ``main`` with unchanged layer positions, so the
        # state_dict keys (main.0.weight, ...) still match saved checkpoints.
        self.main = nn.Sequential(*[
            nn.Conv2d(nChannels, cap, **cnn_kwargs[0]), leaky(),
            nn.Conv2d(cap, 2 * cap, **cnn_kwargs[0]), leaky(),
            nn.Conv2d(2 * cap, 4 * cap, **cnn_kwargs[1]), leaky(),
            nn.Flatten(),
            nn.Linear(4 * cap * side ** 2, args.x_fdim1), leaky(),
            nn.Linear(args.x_fdim1, args.x_fdim2),
        ])

    def forward(self, x):
        """Encode a batch of images into feature vectors."""
        return self.main(x)


class Net3(nn.Module):
    """Decoder psi_xi: feature vectors of size ``args.x_fdim2`` -> images in [0, 1].

    Two linear layers, an unflatten to a (4*capacity, side, side) map, then
    three transposed convolutions ending in a sigmoid.

    Args:
        nChannels: number of output image channels.
        args: namespace providing ``capacity``, ``x_fdim1`` and ``x_fdim2``.
        cnn_kwargs: tuple ``(kwargs_deconv2_3, kwargs_deconv1, side)``, the same
            tuple the encoder receives.
    """

    def __init__(self, nChannels, args, cnn_kwargs):
        super(Net3, self).__init__()
        self.args = args
        cap = args.capacity
        side = cnn_kwargs[2]

        def leaky():
            return nn.LeakyReLU(negative_slope=0.2)

        # Layer positions are unchanged so state_dict keys match saved checkpoints.
        self.main = nn.Sequential(*[
            nn.Linear(args.x_fdim2, args.x_fdim1), leaky(),
            nn.Linear(args.x_fdim1, 4 * cap * side ** 2), leaky(),
            Lin_View(4 * cap, side, side),  # back to a feature map
            nn.ConvTranspose2d(4 * cap, 2 * cap, **cnn_kwargs[1]), leaky(),
            nn.ConvTranspose2d(2 * cap, cap, **cnn_kwargs[0]), leaky(),
            nn.ConvTranspose2d(cap, nChannels, **cnn_kwargs[0]),
            nn.Sigmoid(),
        ])

    def forward(self, x):
        """Decode a batch of feature vectors into images."""
        return self.main(x)

### Define Stiefel-RKM Model
$\DeclareMathOperator{\tr}{\mathrm{Tr}}$
$\DeclareMathOperator{\St}{St}$

Below we present the Stiefel-Restricted Kernel Machine model as defined in 'Disentangled Representation Learning and Generation with Manifold Optimization' (https://arxiv.org/abs/2006.07046). 

The objective function is
\begin{equation}
\min_{\substack{ U\in \St(\ell,m)\\\mathbf{\theta}, \mathbf{\xi}}} \tr\left( C_{\mathbf{\theta}} - \mathbb{P}_U C_{\mathbf{\theta}} \mathbb{P}_U\right) + \lambda \frac{1}{n}\sum_{i=1}^{n}\left\{ L_{\mathbf{\xi},U}(\mathbf{x}_i,\mathbf{\phi}_{\mathbf{\theta}}(\mathbf{x}_i))\right\},
\end{equation}
where $U = [\mathbf{u}_1 \dots \mathbf{u}_m]$ is the interconnection matrix belonging to the Stiefel manifold $\St(\ell,m)$, that is, the set of $\ell\times m$ matrices with orthonormal columns ($\ell\geq m$), $\mathbb{P}_U = U U^\top$ is the projector and $C_{\mathbf{\theta}}$ is the covariance matrix of the centered feature vectors $\mathbf{\phi}_{\mathbf{\theta}}(\mathbf{x}_i)$.

Here the proposed loss function is
\begin{equation}L^{(\sigma)}_{\mathbf{\xi},U}(\mathbf{x},\mathbf{z}) =\mathbb{E}_{\mathbf{\epsilon}\sim\mathcal{N}(0,\mathbb{I}_m)} \left\|\mathbf{x} - \mathbf{\psi}_{\mathbf{\xi}}\big(\mathbb{P}_U\mathbf{z}+\sigma  U\mathbf{\epsilon}\big)\right\|_2^2, \end{equation}

with $\mathbf{z} = \mathbf{\phi}_{\mathbf{\theta}}(\mathbf{x})$; the loss is deterministic for $\sigma = 0$. The noise term $\sigma U\mathbf{\epsilon}$ promotes a _smoother_ decoder network. Another loss option implemented below is the _split AE loss_

\begin{equation}
L^{(\sigma),sl}_{\mathbf{\xi},U}(\mathbf{x},\mathbf{z}) =L^{(0)}_{\mathbf{\xi},U}(\mathbf{x},\mathbf{z})+\mathbb{E}_{\mathbf{\epsilon}\sim\mathcal{N}(0,\mathbb{I}_m)} \left\|\mathbf{\psi}_{\mathbf{\xi}}\big(\mathbb{P}_U\mathbf{z}\big) - \mathbf{\psi}_{\mathbf{\xi}}\big(\mathbb{P}_U\mathbf{z}+\sigma  U\mathbf{\epsilon}\big)\right\|_2^2.
\end{equation}

In [None]:
class RKM_Stiefel(nn.Module):
    """Stiefel Restricted Kernel Machine: encoder, decoder and interconnection matrix.

    ``manifold_param`` has shape (h_dim, x_fdim2), i.e. it stores U^T. It is
    initialised with ``nn.init.orthogonal_``; keeping it on the Stiefel manifold
    during training is left to the optimiser built for this parameter group.

    Args:
        ipVec_dim: number of input values per sample (C*H*W).
        args: namespace with ``h_dim``, ``x_fdim1``, ``x_fdim2``, ``capacity``,
            ``loss``, ``noise_level`` and ``c_accu``.
        nChannels: number of image channels.
        recon_loss: reconstruction criterion. ``None`` (default) creates a fresh
            ``nn.MSELoss(reduction='sum')`` for this model.
        ngpus: number of GPUs (stored only).
    """

    def __init__(self, ipVec_dim, args, nChannels=1, recon_loss=None, ngpus=1):
        super(RKM_Stiefel, self).__init__()
        self.ipVec_dim = ipVec_dim
        self.ngpus = ngpus
        self.args = args
        self.nChannels = nChannels
        # A per-model criterion instead of one module object created at function
        # definition time and shared by every model built with the default.
        self.recon_loss = nn.MSELoss(reduction='sum') if recon_loss is None else recon_loss

        # Initialize manifold parameter (U^T, shape (h_dim, x_fdim2))
        self.manifold_param = nn.Parameter(
            nn.init.orthogonal_(torch.empty(self.args.h_dim, self.args.x_fdim2)))

        # Settings for Conv layers: inputs with at most 28*28*3 values end in a 5x5
        # map via a final 3x3/stride-1 conv; larger inputs use three 4x4/stride-2
        # convs (8x8 final map, which assumes 64x64 inputs).
        base = dict(kernel_size=4, stride=2, padding=1)
        if self.ipVec_dim <= 28 * 28 * 3:
            self.cnn_kwargs = base, dict(kernel_size=3, stride=1), 5
        else:
            self.cnn_kwargs = base, base, 8

        self.encoder = Net1(self.nChannels, self.args, self.cnn_kwargs)
        self.decoder = Net3(self.nChannels, self.args, self.cnn_kwargs)

    def _recon(self, a, b):
        """Reconstruction loss between two batches compared as flat vectors."""
        return self.recon_loss(a.view(-1, self.ipVec_dim), b.view(-1, self.ipVec_dim))

    def _decode_noisy(self, z):
        """Decode P_U z + sigma * U eps (row-vector form), eps ~ N(0, I_h_dim)."""
        u_t = self.manifold_param
        coords = torch.mm(z, u_t.t())  # (N, h_dim)
        # Draw the noise directly on the features' device (no host-to-device copy,
        # and no mismatch if args.proc differs from where the batch lives).
        noise = torch.randn(coords.shape, device=coords.device, dtype=coords.dtype)
        return self.decoder(torch.mm(coords + self.args.noise_level * noise, u_t))

    def forward(self, x):
        """Compute the St-RKM objective on a mini-batch.

        Returns:
            (total, f1, f2) where f1 = Tr(C - U U^T C) / N is the KPCA term,
            f2 the weighted reconstruction term and total = f1 + f2.

        Raises:
            ValueError: if ``args.loss`` is not 'splitloss', 'noisyU' or 'deterministic'.
        """
        loss_type = self.args.loss
        if loss_type not in ('splitloss', 'noisyU', 'deterministic'):
            # Previously an unknown value fell through every branch and crashed
            # later with UnboundLocalError on ``f2``.
            raise ValueError("Unknown loss type {!r}; expected 'splitloss', 'noisyU' "
                             "or 'deterministic'".format(loss_type))

        n = x.size(0)
        u_t = self.manifold_param
        op1 = self.encoder(x)  # features
        op1 = op1 - torch.mean(op1, dim=0)  # feature centering
        C = torch.mm(op1.t(), op1)  # Covariance matrix (unnormalised)

        # Various types of losses as described in the paper
        if loss_type == 'splitloss':
            x_tilde1 = self._decode_noisy(op1)
            x_tilde2 = self.decoder(torch.mm(torch.mm(op1, u_t.t()), u_t))
            f2 = self.args.c_accu * 0.5 * (self._recon(x_tilde2, x)
                                           + self._recon(x_tilde2, x_tilde1)) / n
        elif loss_type == 'noisyU':
            x_tilde = self._decode_noisy(op1)
            f2 = self.args.c_accu * 0.5 * self._recon(x_tilde, x) / n
        else:  # 'deterministic'
            x_tilde = self.decoder(torch.mm(op1, torch.mm(u_t.t(), u_t)))
            f2 = self.args.c_accu * 0.5 * self._recon(x_tilde, x) / n

        f1 = torch.trace(C - torch.mm(torch.mm(u_t.t(), u_t), C)) / n  # KPCA
        return f1 + f2, f1, f2

# Accumulate trainable parameters in 2 groups. 1. Manifold_params 2. Network param
def param_state(model):
    """Split a model's parameters into (Stiefel group, Euclidean group).

    Returns:
        (param_g, param_e1): ``param_g`` holds the parameter named
        'manifold_param' (whether or not it requires grad); ``param_e1`` holds
        every other parameter that requires grad.
    """
    stiefel_group, euclid_group = [], []
    for pname, p in model.named_parameters():
        if pname == 'manifold_param':
            stiefel_group.append(p)
        elif p.requires_grad:
            euclid_group.append(p)
    return stiefel_group, euclid_group

def stiefel_opti(stief_param, lrg=1e-4):
    """Build the Cayley-Adam optimiser (``stiefel_optimizer.AdamG``) for Stiefel parameters.

    Args:
        stief_param: list of parameters to optimise on the Stiefel manifold.
        lrg: learning rate.

    Returns:
        ``stiefel_optimizer.AdamG`` with a single parameter group.
    """
    group = dict(params=stief_param, lr=lrg, momentum=0.9, weight_decay=0.0005, stiefel=True)
    return stiefel_optimizer.AdamG([group])  # CayleyAdam

def final_compute(model, args, ct, device=None):
    """Re-compute U by PCA over the whole dataset and write it into the model.

    Since some datasets could exceed the GPU memory limits, the encoder
    features of each (unshuffled) mini-batch are saved in ``oti/`` on disk and
    retrieved afterwards. The centred features are decomposed with SVD, the
    leading ``args.h_dim`` left singular vectors form U, and
    ``model.manifold_param`` is overwritten in place with U^T.

    Args:
        model: RKM_Stiefel model (uses ``encoder`` and ``manifold_param``).
        args: namespace used to build the data loader; needs ``h_dim`` and
            ``proc``. ``args.shuffle`` is set to False while the loader is
            built and restored afterwards.
        ct: run tag used in the temporary file names.
        device: torch device for the computation; defaults to
            ``torch.device(args.proc)`` (the device the notebook uses for the model).

    Returns:
        (h, u): latent codes of all samples, shape (N, h_dim), and U, shape
        (x_fdim2, h_dim).
    """
    if device is None:
        device = torch.device(args.proc)
    os.makedirs('oti', exist_ok=True)

    def _tmp_path(i):
        return 'oti/oti{}_checkpoint.pth_{}.tar'.format(i, ct)

    # Keep dataset order so h[i] corresponds to dataset[i]; restore the caller's flag.
    old_shuffle = args.shuffle
    args.shuffle = False
    try:
        loader, _, _ = get_3dshapes_dataloader(args)
    finally:
        args.shuffle = old_shuffle

    # Compute feature-vectors (no autograd graph is needed for PCA)
    with torch.no_grad():
        for i, sample_batch in enumerate(tqdm(loader)):
            torch.save({'oti': model.encoder(sample_batch[0].to(device))}, _tmp_path(i))

    # Load feature-vectors, deleting each temporary file, and concatenate once
    # (repeated torch.cat copied the growing tensor every iteration).
    feats = []
    for i in range(len(loader)):
        feats.append(torch.load(_tmp_path(i), map_location=device)['oti'])
        os.remove(_tmp_path(i))
    # os.removedirs only removes *empty* directories and used to fail here because
    # the temporary files were never deleted.
    try:
        os.rmdir('oti')
    except OSError:
        pass  # best effort: the folder still holds unrelated files (e.g. another run)

    ot = torch.cat(feats, dim=0).to(device)
    ot = ot - torch.mean(ot, dim=0)  # Centering
    u, _, _ = torch.svd(torch.mm(ot.t(), ot))
    u = u[:, :args.h_dim]
    with torch.no_grad():
        # copy_ writes every entry in place. masked_scatter_ consumes its source
        # sequentially, so it misplaced values whenever some entries already matched.
        model.manifold_param.copy_(u.t())
    return torch.mm(ot, u), u

### Define dataloader

In [None]:
def get_3dshapes_dataloader(args, path_to_data='3dshapes'):
    """Build a DataLoader over 3dshapes with images rescaled to (28,28,3).

    Downloads ``<path_to_data>/3dshapes.h5`` first if it is missing.

    Args:
        args: namespace with ``mb_size``, ``shuffle`` and ``workers``.
        path_to_data: folder that holds (or will hold) ``3dshapes.h5``.

    Returns:
        (loader, ipVec_dim, nChannels): the DataLoader, the number of values per
        image (C*H*W) and the number of channels C.
    """
    import urllib.request

    name = '{}/3dshapes.h5'.format(path_to_data)
    if not os.path.exists(name):
        print('Data at the given path doesn\'t exist. Downloading now...')
        # Honour ``path_to_data`` (the old shell command always used ./3dshapes) and
        # let download errors surface instead of printing 'Done.' regardless.
        # Downloading to a temporary name avoids treating a partial file as complete.
        os.makedirs(path_to_data, exist_ok=True)
        partial = name + '.part'
        urllib.request.urlretrieve('https://storage.googleapis.com/3d-shapes/3dshapes.h5', partial)
        os.replace(partial, name)
        print('Done.')

    transform = transforms.Compose([Resize(28), transforms.ToTensor()])
    print('Loading data...')
    d3shapes_data = d3shapesDataset(name, transform=transform)
    d3shapes_loader = DataLoader(d3shapes_data, batch_size=args.mb_size,
                                 shuffle=args.shuffle, pin_memory=True, num_workers=args.workers)
    # Read the image shape from a single sample instead of starting every loader
    # worker just to fetch one batch.
    c, x, y = d3shapes_data[0][0].size()
    return d3shapes_loader, c * x * y, c


class d3shapesDataset(Dataset):
    """3dshapes dataset: images together with their ground-truth factor labels."""

    # Generative factors of 3dshapes and the number of values each one takes
    lat_names = ('floor_hue', 'wall_hue', 'object_hue', 'scale', 'shape', 'orientation')
    lat_sizes = np.array([10, 10, 10, 8, 4, 15])

    def __init__(self, path_to_data, subsample=1, transform=None):
        """
        Parameters
        ----------
        path_to_data : str
            Path to the ``3dshapes.h5`` file.
        subsample : int
            Only load every |subsample| number of images.
        transform : callable, optional
            Applied to each image (after division by 255) in ``__getitem__``.
        """
        # Slicing an h5py dataset copies the selection into a numpy array, so the
        # file can be closed right away instead of leaking the handle.
        with h5py.File(path_to_data, 'r') as dataset:
            self.imgs = dataset['images'][::subsample]
            self.lat_val = dataset['labels'][::subsample]
        self.transform = transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        """Return ``(image, labels)``; the image is divided by 255 (assumes 8-bit pixels)."""
        sample = self.imgs[idx] / 255
        if self.transform:
            sample = self.transform(sample)
        return sample, self.lat_val[idx]

### Training and hyper-parameter settings

In [None]:
def _str2bool(value):
    """Parse a command-line boolean; ``type=bool`` would turn the string 'False' into True."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got {!r}'.format(value))


# Model Settings ================================
parser = argparse.ArgumentParser(description='St-RKM Model', formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument('--dataset_name', type=str, default='3dshapes')
parser.add_argument('--h_dim', type=int, default=6, help='Dim of latent vector')
parser.add_argument('--capacity', type=int, default=48, help='Conv_filters of network')
parser.add_argument('--mb_size', type=int, default=256, help='Mini-batch size')
parser.add_argument('--x_fdim1', type=int, default=256, help='Input x_fdim1')
parser.add_argument('--x_fdim2', type=int, default=50, help='Input x_fdim2')
parser.add_argument('--c_accu', type=float, default=1, help='Input weight on recons_error')
parser.add_argument('--noise_level', type=float, default=1e-3, help='Noise-level')
parser.add_argument('--loss', type=str, default='deterministic',
                    choices=('deterministic', 'noisyU', 'splitloss'),
                    help='loss type: deterministic/noisyU/splitloss')

# Training Settings =============================
# The device defaults to CUDA when available, otherwise CPU; override with --proc.
parser.add_argument('--lr', type=float, default=2e-4, help='Input learning rate for ADAM optimizer')
parser.add_argument('--lrg', type=float, default=1e-4, help='Input learning rate for Cayley_ADAM optimizer')
parser.add_argument('--max_epochs', type=int, default=1, help='Input max_epoch')
parser.add_argument('--proc', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
                    help='device type: cuda or cpu')
parser.add_argument('--workers', type=int, default=16, help='Number of workers for dataloader')
parser.add_argument('--shuffle', type=_str2bool, default=True, help='shuffle dataset: True/False')

opt = parser.parse_args(args=[])

In [None]:
device = torch.device(opt.proc)

# Free PyTorch's cached GPU memory, if a GPU is present
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# Run tag (timestamp) shared by the checkpoint, log and output file names
ct = time.strftime("%Y%m%d-%H%M")
dirs = create_dirs(name=opt.dataset_name, ct=ct)
dirs.create()

# Log both to log/<dataset>/<dataset>_<ct>.log and to the notebook output
log_file = 'log/{}/{}_{}.log'.format(opt.dataset_name, opt.dataset_name, ct)
# noinspection PyArgumentList
logging.basicConfig(level=logging.INFO,
                    format='%(message)s',
                    handlers=[logging.FileHandler(log_file), logging.StreamHandler()])

### Load and visualize training data

In [None]:
xtrain, ipVec_dim, nChannels = get_3dshapes_dataloader(args=opt)

# Show a 5x5 grid of randomly chosen training images (row-major order)
perm1 = torch.randperm(len(xtrain.dataset))
fig, ax = plt.subplots(5, 5)
for idx, axis in enumerate(ax.flat):
    sample = xtrain.dataset[perm1[idx]][0]
    axis.imshow(convert_to_imshow_format(sample.numpy()))
plt.suptitle('Ground Truth Data')
plt.setp(ax, xticks=[], yticks=[])
plt.show()

### Initialize the model and optimizer

In [None]:
ngpus = torch.cuda.device_count()

rkm = RKM_Stiefel(ipVec_dim=ipVec_dim, args=opt, nChannels=nChannels, ngpus=ngpus)
rkm = rkm.to(device)
for entry in (rkm, opt):
    logging.info(entry)
logging.info('\nN: {}, mb_size: {}'.format(len(xtrain.dataset), opt.mb_size))
logging.info('We are using {} GPU(s)!'.format(ngpus))

# Two parameter groups: 'manifold_param' goes to the Cayley-Adam (Stiefel)
# optimiser, the remaining trainable network weights go to plain Adam.
param_g, param_e1 = param_state(rkm)
optimizer1 = stiefel_opti(param_g, opt.lrg)
optimizer2 = torch.optim.Adam(param_e1, lr=opt.lr, weight_decay=0)

### Train the model
Note: since training can take a long time and may be impractical on a laptop, you can skip the following 2 cells and go straight to downloading/evaluating the pre-trained model.

In [None]:
start = datetime.now()
Loss_stk = np.empty(shape=[0, 3])  # one row per epoch: [total, KPCA, recon]
cost, l_cost = np.inf, np.inf  # current epoch cost / lowest cost so far
is_best = False
t = 1
while cost > 1e-10 and t <= opt.max_epochs:  # run epochs until convergence or cut-off
    # Per-epoch *sums* of the mini-batch losses (despite the avg_ prefix)
    avg_loss, avg_f1, avg_f2 = 0, 0, 0

    for sample_batched in tqdm(xtrain, desc="Epoch {}/{}".format(t, opt.max_epochs)):
        loss, f1, f2 = rkm(sample_batched[0].to(device))

        optimizer1.zero_grad()
        optimizer2.zero_grad()
        loss.backward()
        optimizer2.step()
        optimizer1.step()

        avg_loss += loss.item()
        avg_f1 += f1.item()
        avg_f2 += f2.item()
    cost = avg_loss

    logging.info('Epoch {}/{}, Loss: [{}], Kpca: [{}], Recon: [{}]'.format(t, opt.max_epochs, cost, avg_f1, avg_f2))
    # Record this epoch before checkpointing so the saved history includes it
    Loss_stk = np.append(Loss_stk, [[cost, avg_f1, avg_f2]], axis=0)

    # Remember lowest cost and save checkpoint (only the best one is written)
    is_best = cost < l_cost
    l_cost = min(cost, l_cost)
    dirs.save_checkpoint({
        'epochs': t,
        'rkm_state_dict': rkm.state_dict(),
        'optimizer1': optimizer1.state_dict(),
        'optimizer2': optimizer2.state_dict(),
        'Loss_stk': Loss_stk,
    }, is_best)
    t += 1

logging.info('Finished Training. Lowest cost: {}'
             '\nLoading best checkpoint [{}] & computing sub-space...'.format(l_cost, dirs.dircp))

ckpt_path = 'cp/{}/{}'.format(opt.dataset_name, dirs.dircp)
# The checkpoint holds a numpy array (Loss_stk), which newer torch.load rejects
# under weights_only=True; this is our own file, so full unpickling is fine.
try:
    sd_mdl = torch.load(ckpt_path, map_location=device, weights_only=False)
except TypeError:  # torch < 1.13 has no ``weights_only`` argument
    sd_mdl = torch.load(ckpt_path, map_location=device)
rkm.load_state_dict(sd_mdl['rkm_state_dict'])

# Compute on the same device the model was trained on
h, U = final_compute(model=rkm, args=opt, ct=ct, device=device)
logging.info("\nTraining complete in: " + str(datetime.now() - start))

### Save model and tensors

In [None]:
# Bundle the pickled model, its state dict, optimiser states, loss history,
# settings and the final latent codes / U into out/<dataset>/<dirout>.
out_path = 'out/{}/{}'.format(opt.dataset_name, dirs.dirout)
snapshot = {
    'rkm': rkm,
    'rkm_state_dict': rkm.state_dict(),
    'optimizer1': optimizer1.state_dict(),
    'optimizer2': optimizer2.state_dict(),
    'Loss_stk': Loss_stk,
    'opt': opt,
    'h': h,
    'U': U,
}
torch.save(snapshot, out_path)
logging.info('\nSaved File: {}'.format(dirs.dirout))

# Evaluate Model

### Download pre-trained model

In [None]:
from sklearn.mixture import GaussianMixture as GMM
import matplotlib.pyplot as plt
import urllib.request

pretrained_path = 'out/3dshapes/3dshapes_Trained_rkm.tar'
if not os.path.exists(pretrained_path):
    print('Pre-trained model at given path doesn\'t exist. Downloading now...')
    # Make sure the target folder exists, let download errors surface (the shell
    # command printed 'Done.' even on failure) and download to a temporary name so
    # a partial file is never mistaken for a complete one.
    os.makedirs(os.path.dirname(pretrained_path), exist_ok=True)
    partial = pretrained_path + '.part'
    urllib.request.urlretrieve(
        'https://www.dropbox.com/s/chwzwaodljq2bn9/3dshapes_Trained_rkm.tar?dl=1', partial)
    os.replace(partial, pretrained_path)
    print('Done.')

In [None]:
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--filename', type=str, default='3dshapes_Trained_rkm', help='Enter Filename')
parser.add_argument('--dataset_name', default='3dshapes', type=str, help='Enter dataset name')
opt_gen = parser.parse_args(args=[])

model_path = 'out/{}/{}.tar'.format(opt_gen.dataset_name, opt_gen.filename)
# SECURITY NOTE: the file contains a pickled model and argparse namespace, so it
# must be fully unpickled (weights_only=False on newer torch). Only load
# checkpoints from a source you trust.
try:
    sd_mdl = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=False)
except TypeError:  # torch < 1.13 has no ``weights_only`` argument
    sd_mdl = torch.load(model_path, map_location=lambda storage, loc: storage)

rkm = sd_mdl['rkm']
rkm.load_state_dict(sd_mdl['rkm_state_dict'])
h = sd_mdl['h']
U = sd_mdl['U']
opt = sd_mdl['opt']

# Data-loader settings. NOTE: they only take effect if get_3dshapes_dataloader(opt)
# is called again; the plots below reuse ``xtrain`` built in an earlier cell.
opt.mb_size = 500
opt.workers = 16
opt.shuffle = True
# Merge the evaluation options into the saved settings, opt_gen taking precedence.
# Both namespaces define 'dataset_name', so Namespace(**vars(opt), **vars(opt_gen))
# raised "got multiple values for keyword argument".
opt = argparse.Namespace(**{**vars(opt), **vars(opt_gen)})

# Image side length (images are square), read from a single dataset sample (C, H, W)
WH = xtrain.dataset[0][0].shape[2]


In [None]:
with torch.no_grad():
    # Visualize correlatedness of latent variables
    cov = torch.mm(torch.t(h), h)
    print('Cov_mat:\n {}'.format(cov))
    plt.figure()
    plt.imshow(cov.detach().cpu().numpy())
    plt.title('$Cov(H^{T}H)$')
    plt.show()

    # Visualize quality of reconstructed samples
    perm1 = torch.randperm(len(xtrain.dataset))
    m = 5
    fig2, ax = plt.subplots(m, m)
    it = 0
    for i in range(m):
        for j in range(m):
            ax[i, j].imshow(convert_to_imshow_format(xtrain.dataset[perm1[it]][0].numpy()))
            it += 1
    plt.suptitle('Ground Truth')
    plt.setp(ax, xticks=[], yticks=[])
    plt.show()

    fig1, ax = plt.subplots(m, m)
    # .cpu() so this also works when h / the model live on a GPU
    x_gen = rkm.decoder(torch.mm(h[perm1[:m * m], :], U.t()).float()).detach().cpu().numpy().reshape(-1, nChannels, WH, WH)
    it = 0
    for i in range(m):
        for j in range(m):
            ax[i, j].imshow(convert_to_imshow_format(x_gen[it, :, :, :]))
            it += 1
    plt.suptitle('Reconstructed samples')
    plt.setp(ax, xticks=[], yticks=[])
    plt.show()

    # Random samples from distribution over H ============================================================
    print('Generating random images')
    n_components = 1  # Number of components for the GMM
    n_samples = 30  # Number of samples for the GMM

    gmm = GMM(n_components=n_components, covariance_type='full').fit(h.detach().cpu().numpy())
    z = torch.FloatTensor(gmm.sample(n_samples)[0])
    # GaussianMixture.sample returns the draws grouped by component; shuffle them so
    # the plotted subset is not biased towards the first components.
    z = z[torch.randperm(z.shape[0])]

    x_gen = rkm.decoder(torch.mm(z, U.t())).detach().cpu().numpy().reshape(-1, nChannels, WH, WH)

    m = 5  # Parameter for plotting
    fig1, ax = plt.subplots(m, m)
    for it, axis in enumerate(ax.flat):
        # With fewer samples than grid cells, leave the remaining axes empty
        if it < x_gen.shape[0]:
            axis.imshow(convert_to_imshow_format(x_gen[it, :, :, :]))
    plt.suptitle('Random generation')
    plt.setp(ax, xticks=[], yticks=[])
    plt.show()

    m = 5  # Number of steps
    # squeeze=False keeps ``ax`` 2-D even when opt.h_dim == 1
    fig2, ax = plt.subplots(opt.h_dim, m, squeeze=False)

    # Interpolation along principal components ================
    for i in range(opt.h_dim):
        dim = i
        mul_off = 0.5  # (for no-offset, set multiplier to 0)

        # Traversal range along the current latent dimension
        lambd = torch.linspace(-2, 2, steps=m)

        uvec = torch.zeros(h.shape[1])
        uvec[dim] = 1  # unit vector
        yoff = mul_off * torch.ones(h.shape[1]).float()
        yoff[dim] = 0

        yop = yoff.repeat(lambd.size(0), 1) + torch.mm(torch.diag(lambd),
                                                       uvec.repeat(lambd.size(0), 1))  # Traversal vectors
        x_gen = rkm.decoder(torch.mm(yop, U.t()).float()).detach().cpu().numpy().reshape(-1, nChannels, WH, WH)

        # Save Images in the directory
        out_dir = 'Traversal_imgs/{}/{}/{}'.format(opt.dataset_name, opt.filename, dim)
        os.makedirs(out_dir, exist_ok=True)

        for j in range(x_gen.shape[0]):
            img = convert_to_imshow_format(x_gen[j, :, :, :])
            # scipy.misc.imsave was removed from SciPy; matplotlib writes the PNG
            # (cmap only matters for single-channel images).
            plt.imsave('{}/{}im{}.png'.format(out_dir, dim, j), img, cmap='gray')
            ax[i, j].imshow(img)

    plt.suptitle('Interpolation')
    plt.setp(ax, xticks=[], yticks=[])
    plt.show()

    print('Traversal Images saved in: Traversal_imgs/{}/{}/'.format(opt.dataset_name, opt.filename))