# VAE 구현


## 외부 파일 가져오기 & requirements 설치

In [1]:
from google.colab import drive
drive.mount("/content/drive")
import os
import sys
from datetime import datetime

drive_project_root = "/content/drive/MyDrive/#fastcampus"
sys.path.append(drive_project_root)
!pip install -r "/content/drive/MyDrive/#fastcampus/requirements.txt"

Mounted at /content/drive
Collecting pytorch-lightning==1.3.8
  Downloading pytorch_lightning-1.3.8-py3-none-any.whl (813 kB)
[K     |████████████████████████████████| 813 kB 13.4 MB/s 
[?25hCollecting torch-optimizer==0.1.0
  Downloading torch_optimizer-0.1.0-py3-none-any.whl (72 kB)
[K     |████████████████████████████████| 72 kB 1.2 MB/s 
[?25hCollecting hydra-core==1.1
  Downloading hydra_core-1.1.0-py3-none-any.whl (144 kB)
[K     |████████████████████████████████| 144 kB 69.3 MB/s 
[?25hCollecting wandb==0.11.1
  Downloading wandb-0.11.1-py2.py3-none-any.whl (1.7 MB)
[K     |████████████████████████████████| 1.7 MB 26.9 MB/s 
Collecting efficientnet_pytorch==0.7.1
  Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB)
Collecting tensorflow-addons
  Downloading tensorflow_addons-0.14.0-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.1 MB)
[K     |████████████████████████████████| 1.1 MB 63.6 MB/s 
[?25hCollecting future>=0.17.1
  Downloading future-0.18.2.ta

In [2]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)


Sat Nov  6 16:03:05 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   35C    P0    27W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [3]:
from typing import List
from typing import Dict
from typing import Union
from typing import Any
from typing import Optional
from typing import Iterable
from typing import Callable
from abc import abstractmethod
from abc import ABC
from datetime import datetime
from functools import partial
from collections import Counter
from collections import OrderedDict
import random
import math
from functools import partial

import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl

# data & models
from torchvision.datasets import FashionMNIST
from torchvision import transforms
import torchvision.utils as vutils

# For configuration
from omegaconf import DictConfig
from omegaconf import OmegaConf
import hydra
from hydra.core.config_store import ConfigStore

# For logger
from torch.utils.tensorboard import SummaryWriter
import wandb
os.environ["WANDB_START_METHOD"]="thread"

In [4]:
from data_utils import dataset_split
from config_utils import flatten_dict
from config_utils import register_config
from config_utils import configure_optimizers_from_cfg
from config_utils import get_loggers
from config_utils import get_callbacks

# Base 모델 정의

In [None]:
# Define Model
class BaseGenerativeModel(pl.LightningModule):
    def __init__(self, cfg: DictConfig):
        pl.LightningModule.__init__(self)
        self.cfg = cfg
    
    @abstractmethod
    def forward(self, x):
        raise NotImplementedError()
    
    def sample_generate(self, *args, **kwargs):
        raise NotImplementedError()
    
    def loss_function(self, *args, **kwargs):
        pass
        
    def configure_optimizers(self):
        self._optimizers, self._schedulers = configure_optimizers_from_cfg(
            self.cfg, self
        )
        return self._optimizers, self._schedulers
    
    def training_step(self, batch, batch_idx):
        pass

    def validation_step(self, batch, batch_idx):
        pass
    

class VanillaVAE(BaseGenerativeModel):
    def __init__(self, cfg: DictConfig):
        super().__init__(cfg)
        self.latent_dim = cfg.model.latent_dim

        # define posterior (encoder) modules
        posterior_mlp_modules_list = []
        prev_dim = cfg.model.posterior.hidden_dims[0]

        for h_dim in cfg.model.posterior.hidden_dims[1:]:
            posterior_mlp_modules_list.append(nn.Linear(prev_dim, h_dim))
            prev_dim = h_dim

        self.posterior_mlp_modules = nn.Sequential(
            *posterior_mlp_modules_list
        )

        # define latent encode
        self.posterior_mu = nn.Linear(
            cfg.model.posterior.hidden_dims[-1], self.latent_dim
        )
        self.posterior_log_var = nn.Linear(
            cfg.model.posterior.hidden_dims[-1], self.latent_dim
        )

        # define prior (decoder) modules
        prior_mlp_modules_list = []
        prev_dim = self.latent_dim
        for h_dim in cfg.model.prior.hidden_dims:
            prior_mlp_modules_list.append(nn.Linear(prev_dim, h_dim))
            prev_dim = h_dim
        self.prior_mlp_modules = nn.Sequential(
            *prior_mlp_modules_list
        )

    def encode(self, input):
        hidden = self.posterior_mlp_modules(input)
        mu = self.posterior_mu(hidden)
        log_var = self.posterior_log_var(hidden)
        return mu, log_var
    
    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        epsilon = torch.randn_like(std)
        return mu + (epsilon * std)
    
    def decode(self, z):
        result = self.prior_mlp_modules(z)
        return torch.sigmoid(result)
    
    def forward(self, x): # B, 1, 28, 28
        input = x.view(-1, self.cfg.model.posterior.hidden_dims[0])
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var  # return recons, mu, log_var
    
    def sample_generate(
        self,
        num_samples: int = 64, # for training (=batch_size) or random sampling
        z: Optional[torch.Tensor] = None, # for manual sample generation
    ):
        if z is None:
            z = torch.randn(num_samples, self.latent_dim)
        else:
            num_samples = z.shape[0]
        assert z.shape[-1] == self.latent_dim
        z = z.to(self.device)
        samples = self.decode(z)
        return samples.view(num_samples, self.cfg.data.C, self.cfg.data.H, self.cfg.data.W)
    
    def loss_function(
        self,
        recons,
        real_img,
        mu,
        log_var,
        kld_weight,
        mode="train"
    ) -> dict:
        assert mode in ["train", "val", "test"]
        
        # reconstruction loss
        recons_loss = F.binary_cross_entropy(
            recons,
            real_img.view(-1, self.cfg.model.prior.hidden_dims[-1]),
            reduction="sum"
        )

        # kld loss
        kld_loss = 0.5 * torch.sum(
            mu.pow(2) + log_var.exp() - log_var - 1
        )

        # summation
        loss = recons_loss + kld_weight * kld_loss
        loss_result = {
            f"{mode}_loss": loss / self.cfg.data[f"num_{mode}_imgs"],
            f"{mode}_reconstruction_loss": recons_loss / self.cfg.data[f"num_{mode}_imgs"],
            f"{mode}_kld_loss": kld_loss / self.cfg.data[f"num_{mode}_imgs"],
        }
        return loss_result
    
    def training_step(self, batch, batch_idx):
        real_img, labels = batch
        recons, mu, log_var = self.forward(real_img)
        loss_results = self.loss_function(
            recons,
            real_img, 
            mu,
            log_var,
            kld_weight=self.cfg.model.kld_weight,
            mode="train"
        )
        loss_results["loss"] = loss_results["train_loss"]
        self.log_dict(loss_results)
        return loss_results
    
    def validation_step(self, batch, batch_idx):
        real_img, labels = batch
        recons, mu, log_var = self.forward(real_img)
        loss_results = self.loss_function(
            recons,
            real_img, 
            mu,
            log_var,
            kld_weight=self.cfg.model.kld_weight,
            mode="val"
        )
        self.log_dict(loss_results)

        # random sample_generate
        sample_gens = self.sample_generate(real_img.shape[0])

        self.logger.experiment[0].log({
            "inputs": wandb.Image(real_img),
            "recons": wandb.Image(recons.view(-1, self.cfg.data.C, self.cfg.data.H, self.cfg.data.W)),
            "sample_gens": wandb.Image(sample_gens),
        })

        return loss_results
    


## Configuration Def



In [None]:
# data configs
data_fashion_mnist_cfg = {
    "name": "fashion_mnist",
    "data_root": os.path.join(os.getcwd(), "data"),
    "transforms": [
        {
            "name": "ToTensor",
            "kwargs": {}
        }
    ],
    "W": 28,
    "H": 28,
    "C": 1,
    "n_class": 10,
}

# model configs
model_mnist_vanilla_vae_cfg = {
    "name": "VanillaVAE",
    "latent_dim": 4,
    "posterior": {
        "hidden_dims": [28*28, 512, 256],
    }, # encoder
    "prior": {
        "hidden_dims": [256, 512, 28*28],
    }, # decoder
    "kld_weight": 1,
}


# optimizer configs
opt_cfg = {
    "optimizers": [
        {
            "name": "RAdam",
            "kwargs": {
                "lr": 1e-3,
            }
        }
    ],
    "lr_schedulers": [
        {
            "name": None,
            "kwargs": {
                "warmup_end_steps": 1000
            }
        },
    ]
}

_merged_cfg_presets = {
    "vanilla_vae_fashion_mnist": {
        "opt": opt_cfg,
        "data": data_fashion_mnist_cfg,
        "model": model_mnist_vanilla_vae_cfg,
    },
}

# clear config instance first
hydra.core.global_hydra.GlobalHydra.instance().clear()

# register preset configs
register_config(_merged_cfg_presets)

# initialize & make config
## select mode here ##
# .................. #
hydra.initialize(config_path=None)
cfg = hydra.compose("vanilla_vae_fashion_mnist")

# override some cfg
run_name = f"{datetime.now().isoformat(timespec='seconds')}-{cfg.model.name}-{cfg.data.name}"

# Define other train configs & log_configs 
# Merge configs into one & register it to Hydra.
project_root_dir = os.path.join(
    drive_project_root, "runs", "generative-dnn-tutorial-fashion-mnist-runs"
)
save_dir = os.path.join(project_root_dir, run_name)
run_root_dir = os.path.join(project_root_dir, run_name)

train_cfg = {
    "train_batch_size": 256,
    "val_batch_size": 64,
    "test_batch_size": 64,
    "train_val_split": [0.9, 0.1],
    "run_root_dir": run_root_dir,
    "trainer_kwargs": {
        "accelerator": "dp",
        "gpus": "0",
        "max_epochs": 50,
        "val_check_interval": 1.0,
        "log_every_n_steps": 100,
        "flush_logs_every_n_steps": 100,
    }
}

# logger config
log_cfg = {
    "loggers": {
        "WandbLogger": {
            "project": "fastcampus_generative_fashion_mnist_tutorials",
            "name": run_name,
            "tags": ["fastcampus_generative_fashion_mnist_tutorials"],
            "save_dir": run_root_dir,
        },
        "TensorBoardLogger": {
            "save_dir": project_root_dir,
            "name": run_name,
            "log_graph": True,
        }
    },
    "callbacks": {
        "ModelCheckpoint": {
            "save_top_k": 3,
            "monitor": "val_loss",
            "mode": "min",
            "verbose": True,
            "dirpath": os.path.join(run_root_dir, "weights"),
            "filename": "{epoch}-{val_loss:.3f}",
        },
        "EarlyStopping": {
            "monitor": "val_loss",
            "mode": "min",
            "patience": 10,
            "verbose": True,
        }
    }
}

# unlock config & set train_cfg & log_cfg
OmegaConf.set_struct(cfg, False)
cfg.train = train_cfg 
cfg.log = log_cfg

# print config
print(OmegaConf.to_yaml(cfg))

opt:
  optimizers:
  - name: RAdam
    kwargs:
      lr: 0.001
  lr_schedulers:
  - name: null
    kwargs:
      warmup_end_steps: 1000
data:
  name: fashion_mnist
  data_root: /content/data
  transforms:
  - name: ToTensor
    kwargs: {}
  W: 28
  H: 28
  C: 1
  n_class: 10
model:
  name: VanillaVAE
  latent_dim: 4
  posterior:
    hidden_dims:
    - 784
    - 512
    - 256
  prior:
    hidden_dims:
    - 256
    - 512
    - 784
  kld_weight: 1
train:
  train_batch_size: 256
  val_batch_size: 64
  test_batch_size: 64
  train_val_split:
  - 0.9
  - 0.1
  run_root_dir: /content/drive/MyDrive/#fastcampus/runs/generative-dnn-tutorial-fashion-mnist-runs/2021-09-12T15:51:37-VanillaVAE-fashion_mnist
  trainer_kwargs:
    accelerator: dp
    gpus: '0'
    max_epochs: 50
    val_check_interval: 1.0
    log_every_n_steps: 100
    flush_logs_every_n_steps: 100
log:
  loggers:
    WandbLogger:
      project: fastcampus_generative_fashion_mnist_tutorials
      name: 2021-09-12T15:51:37-VanillaVAE-

## Data & dataloader

In [None]:
# get transforms from torch.vision
def get_transforms(cfg: DictConfig):
    transforms_list = []
    for tfm in cfg.data.transforms:
        if hasattr(transforms, tfm.name):
            transforms_list.append(
                getattr(transforms, tfm.name)(**tfm.kwargs)
            )
        else:
            raise ValueError(
                f"Not supported transform {tfm} in torch.vision.transform"
            )
    return transforms.Compose(transforms_list)

transform = get_transforms(cfg)


def get_datasets(
    cfg: DictConfig, download: bool = True
) -> Dict[str, torch.utils.data.Dataset]:
    data_root = cfg.data.data_root
    datasets = {}
    if cfg.data.name == "fashion_mnist":
        fashion_mnist_dataset = FashionMNIST(data_root, download=download, train=True, transform=transform)
        datasets = dataset_split(
            fashion_mnist_dataset, split=cfg.train.train_val_split
        )
        datasets["test"] = FashionMNIST(data_root, download=download, train=False, transform=transforms.ToTensor())
    else:
        raise NotImplementedError("Not supported dataset yet")
    return datasets

datasets = get_datasets(cfg, download=True)

train_dataset = datasets["train"]
val_dataset = datasets["val"]
test_dataset = datasets["test"]

# save_dataset_N
cfg.data.num_train_imgs = len(datasets["train"])
cfg.data.num_val_imgs = len(datasets["val"])
cfg.data.num_test_imgs = len(datasets["test"])

# define dataloader
train_batch_size = cfg.train.train_batch_size
val_batch_size = cfg.train.val_batch_size
test_batch_size = cfg.train.test_batch_size

train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=0
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=val_batch_size, shuffle=False, num_workers=0
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=test_batch_size, shuffle=False, num_workers=0
)


In [None]:
# model define
def get_pl_model(cfg: DictConfig, checkpoint_path: Optional[str] = None):
    if cfg.model.name == "VanillaVAE":
        model = VanillaVAE(cfg)
    else:
        raise NotImplementedError("not implemented model")
    
    if checkpoint_path is not None:
        model = model.load_from_checkpoint(cfg=cfg, checkpoint_path=checkpoint_path)
    return model

model = None
model = get_pl_model(cfg)

print(model)


VanillaVAE(
  (posterior_mlp_modules): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): Linear(in_features=512, out_features=256, bias=True)
  )
  (posterior_mu): Linear(in_features=256, out_features=4, bias=True)
  (posterior_log_var): Linear(in_features=256, out_features=4, bias=True)
  (prior_mlp_modules): Sequential(
    (0): Linear(in_features=4, out_features=256, bias=True)
    (1): Linear(in_features=256, out_features=512, bias=True)
    (2): Linear(in_features=512, out_features=784, bias=True)
  )
)


In [None]:
# pytorch-lightning trainer def
logger = get_loggers(cfg)
callbacks = get_callbacks(cfg)

trainer = pl.Trainer(
    callbacks=callbacks,
    logger=logger,
    default_root_dir=cfg.train.run_root_dir,
    num_sanity_val_steps=2,
    **cfg.train.trainer_kwargs,
)

VBox(children=(Label(value=' 0.48MB of 0.48MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
_runtime,6.0
_timestamp,1631461886.0
_step,2.0
train_loss,1.47884
train_reconstruction_loss,1.43011
train_kld_loss,0.04874
loss,1.47884
epoch,0.0
trainer/global_step,99.0


0,1
_runtime,▁▁█
_timestamp,▁▁█
_step,▁▅█
train_loss,▁
train_reconstruction_loss,▁
train_kld_loss,▁
loss,▁
epoch,▁
trainer/global_step,▁


  f"Parsing of the Trainer argument gpus='{s}' (string) will change in the future."
GPU available: True, used: True
TPU available: False, using: 0 TPU cores




```
# 코드로 형식 지정됨
```



In [None]:
%load_ext tensorboard
%tensorboard --logdir /content/drive/MyDrive/\#fastcampus/runs/generative-dnn-tutorial-fashion-mnist-runs/

# train/val fit
# model = modelload_from_checkpoint()
trainer.fit(model, train_dataloader, val_dataloader)

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 604), started 0:10:24 ago. (Use '!kill 604' to kill it.)

<IPython.core.display.Javascript object>

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[34m[1mwandb[0m: wandb version 0.12.1 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade



  | Name                  | Type       | Params
-----------------------------------------------------
0 | posterior_mlp_modules | Sequential | 533 K 
1 | posterior_mu          | Linear     | 1.0 K 
2 | posterior_log_var     | Linear     | 1.0 K 
3 | prior_mlp_modules     | Sequential | 535 K 
-----------------------------------------------------
1.1 M     Trainable params
0         Non-trainable params
1.1 M     Total params
4.281     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Metric val_loss improved. New best score: 3.059
Epoch 0, global step 210: val_loss reached 3.05861 (best 3.05861), saving model to "/content/drive/MyDrive/#fastcampus/runs/generative-dnn-tutorial-fashion-mnist-runs/2021-09-12T15:51:37-VanillaVAE-fashion_mnist/weights/epoch=0-val_loss=3.059.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Metric val_loss improved by 0.040 >= min_delta = 0.0. New best score: 3.018
Epoch 1, global step 421: val_loss reached 3.01843 (best 3.01843), saving model to "/content/drive/MyDrive/#fastcampus/runs/generative-dnn-tutorial-fashion-mnist-runs/2021-09-12T15:51:37-VanillaVAE-fashion_mnist/weights/epoch=1-val_loss=3.018.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Metric val_loss improved by 0.013 >= min_delta = 0.0. New best score: 3.005
Epoch 2, global step 632: val_loss reached 3.00550 (best 3.00550), saving model to "/content/drive/MyDrive/#fastcampus/runs/generative-dnn-tutorial-fashion-mnist-runs/2021-09-12T15:51:37-VanillaVAE-fashion_mnist/weights/epoch=2-val_loss=3.005.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Epoch 3, global step 843: val_loss reached 3.00896 (best 3.00550), saving model to "/content/drive/MyDrive/#fastcampus/runs/generative-dnn-tutorial-fashion-mnist-runs/2021-09-12T15:51:37-VanillaVAE-fashion_mnist/weights/epoch=3-val_loss=3.009.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Metric val_loss improved by 0.000 >= min_delta = 0.0. New best score: 3.005
Epoch 4, global step 1054: val_loss reached 3.00515 (best 3.00515), saving model to "/content/drive/MyDrive/#fastcampus/runs/generative-dnn-tutorial-fashion-mnist-runs/2021-09-12T15:51:37-VanillaVAE-fashion_mnist/weights/epoch=4-val_loss=3.005.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Epoch 5, global step 1265: val_loss was not in top 3


Validating: 0it [00:00, ?it/s]

Metric val_loss improved by 0.005 >= min_delta = 0.0. New best score: 3.000
Epoch 6, global step 1476: val_loss reached 2.99999 (best 2.99999), saving model to "/content/drive/MyDrive/#fastcampus/runs/generative-dnn-tutorial-fashion-mnist-runs/2021-09-12T15:51:37-VanillaVAE-fashion_mnist/weights/epoch=6-val_loss=3.000.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Epoch 7, global step 1687: val_loss reached 3.00009 (best 2.99999), saving model to "/content/drive/MyDrive/#fastcampus/runs/generative-dnn-tutorial-fashion-mnist-runs/2021-09-12T15:51:37-VanillaVAE-fashion_mnist/weights/epoch=7-val_loss=3.000.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Metric val_loss improved by 0.000 >= min_delta = 0.0. New best score: 2.999
Epoch 8, global step 1898: val_loss reached 2.99949 (best 2.99949), saving model to "/content/drive/MyDrive/#fastcampus/runs/generative-dnn-tutorial-fashion-mnist-runs/2021-09-12T15:51:37-VanillaVAE-fashion_mnist/weights/epoch=8-val_loss=2.999.ckpt" as top 3


Validating: 0it [00:00, ?it/s]

Epoch 9, global step 2109: val_loss was not in top 3


Validating: 0it [00:00, ?it/s]

Epoch 10, global step 2320: val_loss was not in top 3


Validating: 0it [00:00, ?it/s]

Monitored metric val_loss did not improve in the last 3 records. Best score: 2.999. Signaling Trainer to stop.
Epoch 11, global step 2531: val_loss was not in top 3


In [None]:
# test 
ckpt_path = os.path.join(
    cfg.log.callbacks.ModelCheckpoint.dirpath,
    "epoch=8-val_loss=2.999.ckpt"
)
model = get_pl_model(cfg, ckpt_path).eval()
print(model)

VanillaVAE(
  (posterior_mlp_modules): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): Linear(in_features=512, out_features=256, bias=True)
  )
  (posterior_mu): Linear(in_features=256, out_features=4, bias=True)
  (posterior_log_var): Linear(in_features=256, out_features=4, bias=True)
  (prior_mlp_modules): Sequential(
    (0): Linear(in_features=4, out_features=256, bias=True)
    (1): Linear(in_features=256, out_features=512, bias=True)
    (2): Linear(in_features=512, out_features=784, bias=True)
  )
)


In [None]:
def create_interpolation_images(
    model,
    axis1=0,
    axis2=1,
    latent_dim=4,
    save_img_path=None,
    range1=np.arange(-2, 2, 0.2),
    range2=np.arange(-2, 2, 0.2),
):
    assert len(range1) == len(range2)
    z = []
    for i in range1:
        for j in range2:
            cur = [0. for _ in range(latent_dim)]
            cur[axis1] = i
            cur[axis2] = j
            z.append(cur)
    z = torch.Tensor(z)
    out = model.sample_generate(z=z)
    out = vutils.make_grid(out, nrow=len(range1))

    if save_img_path is None:
        save_img_path = f"interpolation_results_{axis1}vs{axis2}.png"    
    vutils.save_image(out, save_img_path)

create_interpolation_images(model, 0, 1)
create_interpolation_images(model, 0, 2)
create_interpolation_images(model, 0, 3)
create_interpolation_images(model, 1, 2)
create_interpolation_images(model, 1, 3)
create_interpolation_images(model, 2, 3)

