# Assignment 2
## ✅ Rename the file using your roll number. E.g. if your roll number is `MT24003` then rename the file `MT24003_a2.ipynb`.
## ✅ Write code only in the sections marked with `# YOUR CODE HERE`. No, you can NOT write code anywhere else.
## ✅ Download and extract the `data.zip` folder next to this file. If you extract it correctly, you will have a `data` folder next to this file.
## ✅ Submit a .zip (NOT .tar, .rar, etc) file containing:
###    1. This Notebook after filling the code where asked.
###    2. The loss and metric plots generated using the `save_training_report` functions [`auto_encoder.png` + `variational_auto_encoder.png` + `conditional_variational_auto_encoder.png`].
###    3. The model weights saved using the `save_model_weights` functions [`auto_encoder.pth` + `variational_auto_encoder.pth` + `conditional_variational_auto_encoder.pth`].
## ❌ Do not modify any other function or class definitions; doing so may lead to the autograder failing to judge your submission, resulting in a zero.
## ❌ Deleting or adding new cells may lead to the `autograder` failing to judge your submission, resulting in a zero. Even if a cell is empty, do NOT delete it.
## ❌ Do NOT install / import any other libraries. You should be able to solve all the questions using only the libraries imported below.

In [1]:
# Install the exact library versions this notebook was written against
# (each package is pinned with `==`); `-q` makes pip less verbose.
!pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 -q
!pip install numpy==1.25.2 -q
!pip install soundfile==0.13.0 -q
!pip install pandas==2.2.3 -q
!pip install matplotlib==3.9.4 -q
!pip install scikit-image==0.21.0 -q
!pip install tqdm==4.67.1 -q

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m670.2/670.2 MB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.8/6.8 MB[0m [31m59.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m49.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m410.6/410.6 MB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.1/14.1 MB[0m [31m69.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.7/23.7 MB[0m [31m53.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m823.6/823.6 kB[0m [31m26.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m731.7/731.7 MB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
import os
import random
import timeit
from pathlib import Path
from typing import Tuple
from skimage.metrics import structural_similarity as ssim
from numpy import array as NumpyArray
from typing import List

import matplotlib.pyplot as plt
import pandas as pd
import torch
import torchvision
from tqdm import tqdm

In [None]:
# Root of the extracted `data.zip`, relative to the notebook's working directory.
PATH_TO_DATA_DIR = Path("./data")
# The split directories are stored as plain `str` (not `Path`) because later cells
# build sub-directories by string concatenation, e.g. `PATH_TO_TRAIN_DATA_DIR + "/aug"`.
PATH_TO_TRAIN_DATA_DIR = str(PATH_TO_DATA_DIR / "train")
PATH_TO_TEST_DATA_DIR = str(PATH_TO_DATA_DIR / "test")

# `q1`: `FashionMNIST` Dataset
1. Implement a Dataset class for the `FashionMNIST` data for the task of `Image Restoration`.
2. The task of `Image Restoration` is an [Ill-posed problem](https://en.wikipedia.org/wiki/Well-posed_problem) where the goal is to restore the original image from a corrupted image. Thus there may be more than one augmented image for each clean image, and vice versa.
3. The `data` directory has the following directory structure:
4. ```
	data
    ├── train
    │   ├── aug
    │   │   ├── <imagenumber>_<classlabel>.png
    │   │   ├── ...
    │   ├── clean
    │   │   ├── <imagenumber>_<classlabel>.png
    │   │   ├── ...
    └── test
        ├── aug
        │   ├── <imagenumber>_<classlabel>.png
        │   ├── ...
        └── clean
            ├── <imagenumber>_<classlabel>.png
            ├── ...
    ```
5. Constraints:
   1. The `__getitem__` method should return a tuple of the form `(aug_image, clean_image, label)`.  `clean_image` is the clean image, `aug_image` is the augmented image, and `label` is the class label of the image.
   2. Both `clean_image` and `aug_image` tensors should be of the shape `(1, 28, 28)` and of type `torch.float32`.
   3. Both `clean_image` and `aug_image` tensors should have pixel values between `[0, 1]`.
   4. `label` should be of type `torch.int64`.


`q1` Grading [Total: 1]: `1` point if the code runs without any errors on hidden test cases, otherwise `0` points. No partial points for this question.

In [None]:
class FashionMNISTDataset(torch.utils.data.Dataset):
    """
    A PyTorch Dataset for loading paired FashionMNIST images (augmented and clean versions).

    Both directories are expected to contain ``<imagenumber>_<classlabel>.png`` files.
    Samples are paired by position after sorting each directory's ``.png`` paths, so the
    two directories must contain the same number of images (and, for the pairing to be
    meaningful, the same file names).

    Attributes:
        augmented_images (List[str]): List of file paths to augmented images, sorted alphabetically.
        clean_images (List[str]): List of file paths to clean images, sorted alphabetically.
    """
    def __init__(
        self, path_to_augmented_images_dir: str, path_to_clean_images_dir: str
    ):
        """
        Initializes the dataset by loading file paths for augmented and clean images.

        Args:
            path_to_augmented_images_dir (str): Path to the directory containing augmented images.
            path_to_clean_images_dir (str): Path to the directory containing clean images.

        Raises:
            FileNotFoundError: If either directory does not exist.
            ValueError: If the two directories hold a different number of ``.png`` images,
                which would make positional pairing impossible.
        """
        # YOUR CODE HERE
        def list_png_files(directory: str) -> List[str]:
            # Only PNG files are images; this skips stray files such as `.DS_Store`.
            return sorted(
                os.path.join(directory, file_name)
                for file_name in os.listdir(directory)
                if file_name.lower().endswith(".png")
            )

        self.augmented_images = list_png_files(path_to_augmented_images_dir)
        self.clean_images = list_png_files(path_to_clean_images_dir)
        if len(self.augmented_images) != len(self.clean_images):
            raise ValueError(
                f"Found {len(self.augmented_images)} augmented images in "
                f"{path_to_augmented_images_dir!r} but {len(self.clean_images)} clean "
                f"images in {path_to_clean_images_dir!r}; they must be paired one-to-one."
            )

    def __len__(self) -> int:
        """
        Returns the total number of samples in the dataset.

        Returns:
            int
        """
        # YOUR CODE HERE
        return len(self.augmented_images)

    def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Retrieves the augmented image, clean image, and label for a given index.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                - Augmented image as a float32 tensor of shape (1, H, W) with values in [0, 1].
                - Clean image as a float32 tensor of shape (1, H, W) with values in [0, 1].
                - Label as a 0-d int64 tensor, extracted from the clean image's filename.
        """
        # YOUR CODE HERE
        aug_path = self.augmented_images[idx]
        clean_path = self.clean_images[idx]

        # GRAY mode decodes to a single-channel (1, H, W) uint8 tensor for 8-bit PNGs;
        # dividing by 255 maps it to float32 values in [0, 1].
        gray_mode = torchvision.io.ImageReadMode.GRAY
        aug_image = torchvision.io.read_image(aug_path, mode=gray_mode).to(torch.float32) / 255.0
        clean_image = torchvision.io.read_image(clean_path, mode=gray_mode).to(torch.float32) / 255.0

        # Filenames look like "<imagenumber>_<classlabel>.png": the label follows the last "_".
        stem = os.path.splitext(os.path.basename(clean_path))[0]
        label = torch.tensor(int(stem.rsplit("_", 1)[-1]), dtype=torch.int64)
        return aug_image, clean_image, label

In [None]:
# tests for q1
# Smoke test: building the dataset over the training split (train/aug + train/clean)
# must not raise.

path_to_train_images_aug_dir = str(PATH_TO_TRAIN_DATA_DIR + "/aug")
path_to_train_images_clean_dir = str(PATH_TO_TRAIN_DATA_DIR + "/clean")
fashion_mnist_dataset = FashionMNISTDataset(
    path_to_augmented_images_dir=path_to_train_images_aug_dir,
    path_to_clean_images_dir=path_to_train_images_clean_dir,
)


In [None]:
# tests for q1
# Smoke test on the test split (test/aug + test/clean); the dataset object is
# released with `del` afterwards.

path_to_test_images_aug_dir = str(PATH_TO_TEST_DATA_DIR + "/aug")
path_to_test_images_clean_dir = str(PATH_TO_TEST_DATA_DIR + "/clean")
fashion_mnist_dataset = FashionMNISTDataset(
    path_to_augmented_images_dir=path_to_test_images_aug_dir,
    path_to_clean_images_dir=path_to_test_images_clean_dir,
)


del fashion_mnist_dataset

# `q2`: Encoder, and Decoder classes

* Your task is to create AutoEncoder models for the task of `Image Restoration` using the `FashionMNIST` dataset. You need to implement the `Encoder` and `Decoder` classes for the AutoEncoder model. The `Encoder` class will be used to encode the input image into a latent representation, and the `Decoder` class will be used to decode the latent representation back to the original image. The `Encoder` and `Decoder` classes will be used in the AutoEncoder model, Variational AutoEncoder model, and (optionally) Conditional Variational AutoEncoder model, so the implementation should be **generic and not specific** to any of the models.
* `q2a`: `Encoder` class: Implement a generic Encoder Module that will be used within all the AutoEncoder flavors (AutoEncoder, Variational AutoEncoder, and (optionally) Conditional Variational AutoEncoder). Constraints:
  1. The input tensor will be of shape `[batch_size, 1, 28, 28]` that comes out of the DataLoader of the `FashionMNIST` dataset.
  2. Feel free to use any architecture you like with any layer or activation function in it. **You can NOT use pre-trained model weights**.
  3. The output tensor must be of shape `[batch_size, output_channels, height, width]`. This tensor will be the latent representation of the input tensor and will be passed to the Decoder Module.
  4. The number of parameters in the Encoder Module must be between 2,000 and 1,000,000 (both inclusive). Note that the number of parameters in the Encoder Module and Decoder Module will be counted separately and may not be the same.

* `q2b`: `Decoder` class: Implement a generic Decoder Module that will be used within all the AutoEncoder flavors (AutoEncoder, Variational AutoEncoder, and (optionally) Conditional Variational AutoEncoder). Constraints:
  1. The input tensor will be of shape `[batch_size, input_channels, height, width]` that comes out of the Encoder Module.
  2. Feel free to use any architecture you like with any layer or activation function in it. **You can NOT use pre-trained model weights**.
  3. The output tensor must be of shape `[batch_size, 1, 28, 28]`. This tensor will be the reconstructed image of the input tensor.
  4. The number of parameters in the Decoder Module must be between 2,000 and 1,000,000 (both inclusive). Note that the number of parameters in the Encoder Module and Decoder Module will be counted separately and may not be the same.

`q2` Grading [Total: 1 point]:
1. `q2a`: `Encoder` class: `0.5` points if the code runs without any errors on hidden test cases, otherwise 0 points. No partial points for this question.
2. `q2b`: `Decoder` class:  `0.5` points if the code runs without any errors on hidden test cases, otherwise 0 points. No partial points for this question.

## `q2a`: `Encoder` class

In [None]:
class Encoder(torch.nn.Module):
    """Convolutional encoder shared by the AE, VAE and CVAE models.

    Maps ``[batch_size, 1, 28, 28]`` images to a ``[batch_size, output_channels, 4, 4]``
    latent feature map. One stride-1 convolution is followed by three stride-2
    convolutions (28 -> 28 -> 14 -> 7 -> 4). The last convolution has no
    normalisation or activation, so its output is unconstrained. It can be used
    directly as an AE code, or split channel-wise into VAE ``mu`` / ``log_var`` halves.

    Args:
        output_channels: Number of channels of the latent feature map (must be >= 1).
        type_of_autoencoder: Tag of the model using this encoder (e.g. ``"ae"``,
            ``"vae"``). It is stored for reference only; the architecture is the
            same for every flavour.

    Raises:
        ValueError: If ``output_channels`` is smaller than 1.
    """

    def __init__(self, output_channels: int, type_of_autoencoder: str = None):
        super(Encoder, self).__init__()
        # YOUR CODE HERE
        if output_channels < 1:
            raise ValueError(f"output_channels must be >= 1, got {output_channels}")
        self.output_channels = output_channels
        self.type_of_autoencoder = type_of_autoencoder
        self.layers = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, kernel_size=3, padding=1),  # -> 28x28
            torch.nn.BatchNorm2d(32),
            torch.nn.LeakyReLU(0.2),
            torch.nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # -> 14x14
            torch.nn.BatchNorm2d(64),
            torch.nn.LeakyReLU(0.2),
            torch.nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # -> 7x7
            torch.nn.BatchNorm2d(128),
            torch.nn.LeakyReLU(0.2),
            # Linear projection to the latent map (no BN / activation, see class docstring).
            torch.nn.Conv2d(128, output_channels, kernel_size=3, stride=2, padding=1),  # -> 4x4
        )

    def forward(self, x):
        """Encode ``x`` of shape [batch_size, 1, 28, 28] into [batch_size, output_channels, 4, 4]."""
        # YOUR CODE HERE
        return self.layers(x)

In [None]:
# tests for q2a
# Smoke test: an encoder with 64 output channels must accept one 1x28x28 input.

encoder = Encoder(output_channels=64, type_of_autoencoder="vae")

random_input_tensor = torch.randn(1, 1, 28, 28)
output_tensor = encoder(random_input_tensor)


del encoder

## `q2b`: `Decoder` class

In [None]:
class Decoder(torch.nn.Module):
    """Transposed-convolution decoder shared by the AE, VAE and CVAE models.

    Maps a ``[batch_size, input_channels, 4, 4]`` latent map back to a
    ``[batch_size, 1, 28, 28]`` image (4 -> 7 -> 14 -> 28). A final sigmoid keeps
    pixel values in [0, 1], matching the dataset's normalisation. If the latent map
    has another spatial size, the output is bilinearly resized to 28x28, so the
    output shape is always ``[batch_size, 1, 28, 28]``.

    Args:
        input_channels: Number of channels of the incoming latent map (must be >= 1).
        type_of_autoencoder: Tag of the model using this decoder (e.g. ``"ae"``,
            ``"vae"``). It is stored for reference only; the architecture is the
            same for every flavour.

    Raises:
        ValueError: If ``input_channels`` is smaller than 1.
    """

    _OUTPUT_SIZE = (28, 28)

    def __init__(self, input_channels: int, type_of_autoencoder: str):
        super(Decoder, self).__init__()
        # YOUR CODE HERE
        if input_channels < 1:
            raise ValueError(f"input_channels must be >= 1, got {input_channels}")
        self.input_channels = input_channels
        self.type_of_autoencoder = type_of_autoencoder
        self.layers = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(input_channels, 128, kernel_size=3, stride=2, padding=1),  # 4 -> 7
            torch.nn.BatchNorm2d(128),
            torch.nn.LeakyReLU(0.2),
            torch.nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 7 -> 14
            torch.nn.BatchNorm2d(64),
            torch.nn.LeakyReLU(0.2),
            torch.nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # 14 -> 28
            torch.nn.BatchNorm2d(32),
            torch.nn.LeakyReLU(0.2),
            torch.nn.Conv2d(32, 1, kernel_size=3, padding=1),
            torch.nn.Sigmoid(),
        )

    def forward(self, x):
        """Decode ``x`` of shape [batch_size, input_channels, H, W] into [batch_size, 1, 28, 28]."""
        # YOUR CODE HERE
        images = self.layers(x)
        if tuple(images.shape[-2:]) != self._OUTPUT_SIZE:
            # Latent maps that are not 4x4 do not land exactly on 28x28; resize so the
            # output contract holds. Bilinear interpolation keeps values inside [0, 1].
            images = torch.nn.functional.interpolate(
                images, size=self._OUTPUT_SIZE, mode="bilinear", align_corners=False
            )
        return images

In [None]:
# tests for q2b
# Smoke test: a decoder with 64 input channels must accept a [1, 64, 4, 4] latent map.

decoder = Decoder(input_channels=64, type_of_autoencoder="vae")

random_input_tensor = torch.randn(1, 64, 4, 4)
output_tensor = decoder(random_input_tensor)


del decoder

# `q3`: AutoEncoder Model
* `q3a`: `AutoEncoder` class: Implement an AutoEncoder that uses the Encoder and Decoder Modules implemented in `q2a` and `q2b`. Constraints:
  1. The number of parameters in the AutoEncoder must be between 4,000 and 2,000,000 (both inclusive).
  2. The input tensor will be of shape `[batch_size, 1, 28, 28]` that comes out of the DataLoader of the `FashionMNIST` dataset.
  3. The output tensor must be of shape `[batch_size, 1, 28, 28]`. This tensor will be the reconstructed image of the input tensor.

* `q3b`: Training the models: Implement the training loop for the AutoEncoder model. Constraints:
  1. Use the `FashionMNIST` dataset implemented in `q1` to load the data.
  2. Use the `AutoEncoder` model implemented in `q3a`.
  3. You are free to choose any loss function, optimizer, and hyperparameters.
  4. **You must**:
     1. Book-keep the training and validation losses and SSIM scores for each epoch and use it to plot the training curves with the `AutoEncoder.save_training_report` method.
     2. To calculate the SSIM score, you can use the `get_ssim` function provided below.
     3. Save the model weights using `AutoEncoder.save_model_weights` method.


`q3` Grading [Total: 1.5 points]:
1. `q3a`: `AutoEncoder` class: `0.5` points if the code runs without any errors on hidden test cases, otherwise 0 points. No partial points for this question.
2. `q3b`: Training the models: `1` point. You will be awarded points based on the SSIM score of the `AutoEncoder` model on a **hidden test set**. The grading will be as follows:
   1. 0.8 or more: `1` point
   2. 0.7 or more (but below 0.8): `0.8` points
   3. 0.6 or more (but below 0.7): `0.6` points
   4. 0.5 or more (but below 0.6): `0.4` points
   5. 0.4 or more (but below 0.5): `0.2` points
   6. Less than 0.4: `0` points


You are provided with the following template. **Populate only the sections marked as `# YOUR CODE HERE`. Do not modify other parts of the template.**

## `q3a`: `AutoEncoder` class

In [None]:
class AutoEncoder(torch.nn.Module):
    """Deterministic encoder-decoder that maps an augmented image to a restored image."""

    def __init__(self, latent_dim: int):
        super(AutoEncoder, self).__init__()
        self.encoder = Encoder(output_channels=latent_dim, type_of_autoencoder="ae")
        self.decoder = Decoder(input_channels=latent_dim, type_of_autoencoder="ae")
        self.latent_dim = latent_dim

    def forward(self, input_tensor):
        """Restore ``input_tensor`` of shape [batch_size, 1, 28, 28].

        Returns:
            torch.Tensor: The decoder's reconstruction, which per the q2b contract has
            shape [batch_size, 1, 28, 28].
        """
        # YOUR CODE HERE
        latent = self.encoder(input_tensor)
        return self.decoder(latent)

    def save_model_weights(self):
        torch.save(self.state_dict(), "auto_encoder.pth")

    def load_model_weights(self):
        self.load_state_dict(torch.load("auto_encoder.pth"))

    def save_training_report(
        self,
        list_of_train_losses: List[float],
        list_of_val_losses: List[float],
        list_of_train_ssim_scores: List[float],
        list_of_val_ssim_scores: List[float],
    ):
        plt.figure(figsize=(10, 5))

        plt.subplot(1, 2, 1)
        plt.title("Loss per Epoch")
        plt.plot(list_of_train_losses, label="Training")
        plt.plot(list_of_val_losses, label="Validation")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.title("SSIM per Epoch")
        plt.plot(list_of_train_ssim_scores, label="Training")
        plt.plot(list_of_val_ssim_scores, label="Validation")
        plt.xlabel("Epoch")
        plt.ylabel("SSIM")
        plt.legend()

        plt.suptitle("AutoEncoder Training Report")

        plt.savefig("auto_encoder.png")
        plt.show()

In [None]:
# tests for q3a
# Smoke test: a full forward pass through AutoEncoder(latent_dim=64) on one random image.

autoencoder = AutoEncoder(latent_dim=64)

random_input_tensor = torch.randn(1, 1, 28, 28)
output_tensor = autoencoder(random_input_tensor)


del autoencoder

## `q3b`: Training the model

In [None]:
def get_ssim(
    list_of_predicted_images: List[NumpyArray], list_of_true_images: List[NumpyArray]
) -> float:
    """Return the mean SSIM over paired predicted / ground-truth images.

    Each pair is scored with ``skimage.metrics.structural_similarity``, using the
    ground-truth image's own ``max - min`` as ``data_range``.

    Args:
        list_of_predicted_images: Predicted images, each a (28, 28) array.
        list_of_true_images: Ground-truth images, each a (28, 28) array, paired with
            the predictions by position.

    Returns:
        float: The arithmetic mean of the per-pair SSIM scores.

    Raises:
        ValueError: If the two lists differ in length (pairs would otherwise be
            silently dropped) or are empty (the mean is undefined).
        AssertionError: If any image is not of shape (28, 28).
    """
    if len(list_of_predicted_images) != len(list_of_true_images):
        raise ValueError(
            f"Got {len(list_of_predicted_images)} predicted images but "
            f"{len(list_of_true_images)} true images; they must be paired one-to-one."
        )
    if len(list_of_predicted_images) == 0:
        raise ValueError("Cannot compute the mean SSIM of zero image pairs.")

    ssim_values = []
    for predicted_image, true_image in zip(
        list_of_predicted_images, list_of_true_images
    ):
        assert predicted_image.shape == (
            28,
            28,
        ), f"Expected image of shape (28, 28) but got {predicted_image.shape}"
        assert true_image.shape == (
            28,
            28,
        ), f"Expected image of shape (28, 28) but got {true_image.shape}"
        data_range = true_image.max() - true_image.min()
        if data_range == 0:
            # A constant ground-truth image would give data_range == 0, which zeroes
            # SSIM's stabilising constants and can yield NaN. Images in this
            # notebook are normalised to [0, 1], so fall back to that full range.
            data_range = 1.0
        ssim_values.append(
            ssim(
                predicted_image,
                true_image,
                data_range=data_range,
            )
        )
    return float(sum(ssim_values) / len(ssim_values))

In [None]:
# Use this cell to:
# 1. Train the AutoEncoder model while bookkeeping the training and validation losses and SSIM scores for each epoch
# 2. Save the model weights using AutoEncoder.save_model_weights method
# 3. Save the training report using AutoEncoder.save_training_report method


# YOUR CODE HERE
# Hyper-parameters (free to tune).
AE_SEED = 42
AE_EPOCHS = 20
AE_BATCH_SIZE = 128
AE_LEARNING_RATE = 1e-3
AE_LATENT_DIM = 64

random.seed(AE_SEED)
torch.manual_seed(AE_SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE: on platforms that spawn DataLoader workers (e.g. Windows), workers may fail to
# import classes defined in a notebook, so data is loaded in-process there.
num_workers = 0 if os.name == "nt" else 2


def make_loader(split_dir: str, shuffle: bool, batch_size: int) -> torch.utils.data.DataLoader:
    """Build a DataLoader over the paired `<split_dir>/aug` and `<split_dir>/clean` images."""
    dataset = FashionMNISTDataset(
        path_to_augmented_images_dir=split_dir + "/aug",
        path_to_clean_images_dir=split_dir + "/clean",
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=device.type == "cuda",
    )


def run_epoch(model, loader, compute_loss, optimizer=None) -> Tuple[float, float]:
    """Make one pass over `loader`: train if `optimizer` is given, otherwise evaluate.

    `compute_loss(model, aug, clean, labels)` must return `(loss, predictions)`.
    Returns the sample-weighted mean loss and mean SSIM over the pass.
    """
    is_training = optimizer is not None
    model.train(is_training)
    loss_sum, ssim_sum, seen = 0.0, 0.0, 0
    with torch.set_grad_enabled(is_training):
        for aug_images, clean_images, labels in tqdm(loader, leave=False):
            aug_images = aug_images.to(device, non_blocking=True)
            clean_images = clean_images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            loss, predictions = compute_loss(model, aug_images, clean_images, labels)
            if is_training:
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                optimizer.step()
            batch_size = aug_images.size(0)
            loss_sum += loss.item() * batch_size
            # get_ssim returns the batch mean, so weight it by the batch size.
            ssim_sum += batch_size * get_ssim(
                list(predictions.detach().squeeze(1).cpu().numpy()),
                list(clean_images.squeeze(1).cpu().numpy()),
            )
            seen += batch_size
    return loss_sum / seen, ssim_sum / seen


def fit(model, compute_loss, train_loader, val_loader, epochs: int, learning_rate: float):
    """Train `model`, then restore the weights with the best validation SSIM.

    Returns the per-epoch lists (train_losses, val_losses, train_ssims, val_ssims).
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", factor=0.5, patience=2
    )
    train_losses, val_losses, train_ssims, val_ssims = [], [], [], []
    best_val_ssim, best_state = float("-inf"), None
    for epoch in range(1, epochs + 1):
        started = timeit.default_timer()
        train_loss, train_ssim = run_epoch(model, train_loader, compute_loss, optimizer)
        val_loss, val_ssim = run_epoch(model, val_loader, compute_loss)
        scheduler.step(val_ssim)

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_ssims.append(train_ssim)
        val_ssims.append(val_ssim)

        if val_ssim > best_val_ssim:
            best_val_ssim = val_ssim
            # Snapshot a CPU copy so later epochs cannot overwrite it.
            best_state = {
                name: tensor.detach().cpu().clone()
                for name, tensor in model.state_dict().items()
            }
        print(
            f"Epoch {epoch:02d}/{epochs} | "
            f"train loss {train_loss:.4f}, SSIM {train_ssim:.4f} | "
            f"val loss {val_loss:.4f}, SSIM {val_ssim:.4f} | "
            f"{timeit.default_timer() - started:.1f}s"
        )
    if best_state is not None:
        model.load_state_dict(best_state)
    return train_losses, val_losses, train_ssims, val_ssims


def autoencoder_loss(model, aug_images, clean_images, labels):
    """L1 reconstruction loss between the restored and the clean images."""
    predictions = model(aug_images)
    return torch.nn.functional.l1_loss(predictions, clean_images), predictions


ae_train_loader = make_loader(PATH_TO_TRAIN_DATA_DIR, shuffle=True, batch_size=AE_BATCH_SIZE)
ae_val_loader = make_loader(PATH_TO_TEST_DATA_DIR, shuffle=False, batch_size=AE_BATCH_SIZE)

auto_encoder = AutoEncoder(latent_dim=AE_LATENT_DIM).to(device)
ae_history = fit(
    auto_encoder, autoencoder_loss, ae_train_loader, ae_val_loader, AE_EPOCHS, AE_LEARNING_RATE
)

# Save from CPU: load_model_weights calls torch.load without map_location, so a
# checkpoint holding CUDA tensors would not load on a CPU-only machine.
auto_encoder.to("cpu")
auto_encoder.save_model_weights()
auto_encoder.save_training_report(*ae_history)

In [None]:
# tests for q3b

# `q4`: Variational AutoEncoder Model
* `q4a`: `VariationalAutoEncoder` class: Implement a VariationalAutoEncoder that uses the Encoder and Decoder Modules implemented in `q2a` and `q2b`. Constraints:
  1. The number of parameters in the VariationalAutoEncoder must be between 4,000 and 2,000,000 (both inclusive).
  2. The input tensor will be of shape `[batch_size, 1, 28, 28]` that comes out of the DataLoader of the `FashionMNIST` dataset.
  3. The output tensor must be of shape `[batch_size, 1, 28, 28]`. This tensor will be the reconstructed image of the input tensor.

* `q4b`: Training the models: Implement the training loop for the VariationalAutoEncoder model. Constraints:
  1. Use the `FashionMNIST` dataset implemented in `q1` to load the data.
  2. Use the `VariationalAutoEncoder` model implemented in `q4a`.
  3. You are free to choose any loss function, optimizer, and hyperparameters.
  4. **You must**:
     1. Book-keep the training and validation losses and SSIM scores for each epoch and use it to plot the training curves with the `VariationalAutoEncoder.save_training_report` method.
     2. To calculate the SSIM score, you can use the `get_ssim` function provided in `q3b` above.
     3. Save the model weights using `VariationalAutoEncoder.save_model_weights` method.


`q4` Grading [Total: 1.5 points]:
1. `q4a`: `VariationalAutoEncoder` class: `0.5` points if the code runs without any errors on hidden test cases, otherwise 0 points. No partial points for this question.
2. `q4b`: Training the models: `1` point. You will be awarded points based on the SSIM score of the `VariationalAutoEncoder` model on a **hidden test set**. The grading will be as follows:
   1. 0.8 or more: `1` point
   2. 0.7 or more: `0.8` points
   3. 0.6 or more: `0.6` points
   4. 0.5 or more: `0.4` points
   5. 0.4 or more: `0.2` points
   6. Less than 0.4: `0` points

You are provided with the following template. **Populate only the sections marked as `# YOUR CODE HERE`. Do not modify other parts of the template.**

## `q4a`: `VariationalAutoEncoder` class

In [None]:
class VariationalAutoEncoder(torch.nn.Module):
    """VAE whose encoder emits ``2 * latent_dim`` channels, split into ``mu`` and ``log_var``."""

    def __init__(self, latent_dim: int):
        super(VariationalAutoEncoder, self).__init__()
        self.encoder = Encoder(output_channels=latent_dim * 2, type_of_autoencoder="vae")
        self.decoder = Decoder(input_channels=latent_dim, type_of_autoencoder="vae")
        self.latent_dim = latent_dim

    def reparameterize(self, mu, log_var):
        """Sample ``z ~ N(mu, exp(log_var))`` with the reparameterisation trick.

        In training mode this returns ``mu + eps * exp(0.5 * log_var)`` with
        ``eps ~ N(0, I)``, so gradients flow through ``mu`` and ``log_var``. In eval
        mode it returns ``mu`` (the posterior mean), which makes inference deterministic.
        """
        # YOUR CODE HERE
        if not self.training:
            return mu
        std = torch.exp(0.5 * log_var)
        return mu + torch.randn_like(std) * std

    def forward(self, input_tensor, return_distribution: bool = False):
        """Restore ``input_tensor`` of shape [batch_size, 1, 28, 28].

        Args:
            input_tensor: Batch of (augmented) images.
            return_distribution: If True, also return the posterior parameters needed
                by ``loss_function``.

        Returns:
            The reconstruction (shape [batch_size, 1, 28, 28] per the q2b decoder
            contract). If ``return_distribution`` is True, returns the tuple
            ``(reconstruction, mu, log_var)`` instead.
        """
        # YOUR CODE HERE
        statistics = self.encoder(input_tensor)
        mu, log_var = torch.chunk(statistics, 2, dim=1)
        # Bound the log-variance so exp() can neither overflow nor collapse to zero.
        log_var = torch.clamp(log_var, min=-10.0, max=10.0)
        z = self.reparameterize(mu, log_var)
        reconstruction = self.decoder(z)
        if return_distribution:
            return reconstruction, mu, log_var
        return reconstruction

    def loss_function(self, predicted_images, gt_images, mu, log_var, kl_weight: float = 1.0):
        """Negative ELBO, averaged over the batch.

        Both terms are summed over pixels / latent elements and divided by the batch size:
        ``BCE(predicted, gt) + kl_weight * KL(N(mu, exp(log_var)) || N(0, I))``.

        Args:
            predicted_images: Reconstructions with values in [0, 1].
            gt_images: Target images with values in [0, 1].
            mu: Posterior means.
            log_var: Posterior log-variances (same shape as ``mu``).
            kl_weight: Weight of the KL term. The default of 1.0 gives the standard
                ELBO; smaller values trade prior matching for sharper reconstructions.

        Returns:
            torch.Tensor: Scalar loss.
        """
        # YOUR CODE HERE
        batch_size = gt_images.size(0)
        reconstruction_loss = torch.nn.functional.binary_cross_entropy(
            predicted_images, gt_images, reduction="sum"
        ) / batch_size
        kl_divergence = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) / batch_size
        return reconstruction_loss + kl_weight * kl_divergence

    def save_model_weights(self):
        torch.save(self.state_dict(), "variational_auto_encoder.pth")

    def load_model_weights(self):
        self.load_state_dict(torch.load("variational_auto_encoder.pth"))

    def save_training_report(
        self,
        list_of_train_losses: List[float],
        list_of_val_losses: List[float],
        list_of_train_ssim_scores: List[float],
        list_of_val_ssim_scores: List[float],
    ):
        plt.figure(figsize=(10, 5))

        plt.subplot(1, 2, 1)
        plt.title("Loss per Epoch")
        plt.plot(list_of_train_losses, label="Training")
        plt.plot(list_of_val_losses, label="Validation")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.title("SSIM per Epoch")
        plt.plot(list_of_train_ssim_scores, label="Training")
        plt.plot(list_of_val_ssim_scores, label="Validation")
        plt.xlabel("Epoch")
        plt.ylabel("SSIM")
        plt.legend()

        plt.suptitle("VariationalAutoEncoder Training Report")

        plt.savefig("variational_auto_encoder.png")
        plt.show()

In [None]:
# tests for q4a
# Smoke test: a full forward pass through VariationalAutoEncoder(latent_dim=64) on one
# random image.

variational_autoencoder = VariationalAutoEncoder(latent_dim=64)

random_input_tensor = torch.randn(1, 1, 28, 28)
output = variational_autoencoder(random_input_tensor)


del variational_autoencoder

## `q4b`: Training the model

In [None]:
# Use this cell to:
# 1. Train the VariationalAutoEncoder model while bookkeeping the training and validation losses and SSIM scores for each epoch
# 2. Save the model weights using VariationalAutoEncoder.save_model_weights method
# 3. Save the training report using VariationalAutoEncoder.save_training_report method


# YOUR CODE HERE
# Hyper-parameters (free to tune).
VAE_SEED = 42
VAE_EPOCHS = 20
VAE_BATCH_SIZE = 128
VAE_LEARNING_RATE = 1e-3
VAE_LATENT_DIM = 64
# beta-VAE weight on the KL term: a small value favours sharp restorations (high SSIM).
VAE_KL_WEIGHT = 0.05

random.seed(VAE_SEED)
torch.manual_seed(VAE_SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE: on platforms that spawn DataLoader workers (e.g. Windows), workers may fail to
# import classes defined in a notebook, so data is loaded in-process there.
num_workers = 0 if os.name == "nt" else 2


def make_loader(split_dir: str, shuffle: bool, batch_size: int) -> torch.utils.data.DataLoader:
    """Build a DataLoader over the paired `<split_dir>/aug` and `<split_dir>/clean` images."""
    dataset = FashionMNISTDataset(
        path_to_augmented_images_dir=split_dir + "/aug",
        path_to_clean_images_dir=split_dir + "/clean",
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=device.type == "cuda",
    )


def run_epoch(model, loader, compute_loss, optimizer=None) -> Tuple[float, float]:
    """Make one pass over `loader`: train if `optimizer` is given, otherwise evaluate.

    `compute_loss(model, aug, clean, labels)` must return `(loss, predictions)`.
    Returns the sample-weighted mean loss and mean SSIM over the pass.
    """
    is_training = optimizer is not None
    model.train(is_training)
    loss_sum, ssim_sum, seen = 0.0, 0.0, 0
    with torch.set_grad_enabled(is_training):
        for aug_images, clean_images, labels in tqdm(loader, leave=False):
            aug_images = aug_images.to(device, non_blocking=True)
            clean_images = clean_images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            loss, predictions = compute_loss(model, aug_images, clean_images, labels)
            if is_training:
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                optimizer.step()
            batch_size = aug_images.size(0)
            loss_sum += loss.item() * batch_size
            # get_ssim returns the batch mean, so weight it by the batch size.
            ssim_sum += batch_size * get_ssim(
                list(predictions.detach().squeeze(1).cpu().numpy()),
                list(clean_images.squeeze(1).cpu().numpy()),
            )
            seen += batch_size
    return loss_sum / seen, ssim_sum / seen


def fit(model, compute_loss, train_loader, val_loader, epochs: int, learning_rate: float):
    """Train `model`, then restore the weights with the best validation SSIM.

    Returns the per-epoch lists (train_losses, val_losses, train_ssims, val_ssims).
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", factor=0.5, patience=2
    )
    train_losses, val_losses, train_ssims, val_ssims = [], [], [], []
    best_val_ssim, best_state = float("-inf"), None
    for epoch in range(1, epochs + 1):
        started = timeit.default_timer()
        train_loss, train_ssim = run_epoch(model, train_loader, compute_loss, optimizer)
        val_loss, val_ssim = run_epoch(model, val_loader, compute_loss)
        scheduler.step(val_ssim)

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_ssims.append(train_ssim)
        val_ssims.append(val_ssim)

        if val_ssim > best_val_ssim:
            best_val_ssim = val_ssim
            # Snapshot a CPU copy so later epochs cannot overwrite it.
            best_state = {
                name: tensor.detach().cpu().clone()
                for name, tensor in model.state_dict().items()
            }
        print(
            f"Epoch {epoch:02d}/{epochs} | "
            f"train loss {train_loss:.4f}, SSIM {train_ssim:.4f} | "
            f"val loss {val_loss:.4f}, SSIM {val_ssim:.4f} | "
            f"{timeit.default_timer() - started:.1f}s"
        )
    if best_state is not None:
        model.load_state_dict(best_state)
    return train_losses, val_losses, train_ssims, val_ssims


def vae_loss(model, aug_images, clean_images, labels):
    """Weighted negative ELBO of reconstructing the clean image from the augmented one."""
    predictions, mu, log_var = model(aug_images, return_distribution=True)
    loss = model.loss_function(predictions, clean_images, mu, log_var, kl_weight=VAE_KL_WEIGHT)
    return loss, predictions


vae_train_loader = make_loader(PATH_TO_TRAIN_DATA_DIR, shuffle=True, batch_size=VAE_BATCH_SIZE)
vae_val_loader = make_loader(PATH_TO_TEST_DATA_DIR, shuffle=False, batch_size=VAE_BATCH_SIZE)

variational_auto_encoder = VariationalAutoEncoder(latent_dim=VAE_LATENT_DIM).to(device)
vae_history = fit(
    variational_auto_encoder,
    vae_loss,
    vae_train_loader,
    vae_val_loader,
    VAE_EPOCHS,
    VAE_LEARNING_RATE,
)

# Save from CPU: load_model_weights calls torch.load without map_location, so a
# checkpoint holding CUDA tensors would not load on a CPU-only machine.
variational_auto_encoder.to("cpu")
variational_auto_encoder.save_model_weights()
variational_auto_encoder.save_training_report(*vae_history)

In [None]:
# tests for q4b

# `q5`: [BONUS] Conditional Variational AutoEncoder Model
* `q5a`: `ConditionalVariationalAutoEncoder` class: Implement a ConditionalVariationalAutoEncoder that uses the Encoder and Decoder Modules implemented in `q2a` and `q2b`. Constraints:
  1. The number of parameters in the ConditionalVariationalAutoEncoder must be between 4,000 and 2,000,000 (both inclusive).
  2. The input tensor will be of shape `[batch_size, 1, 28, 28]` that comes out of the DataLoader of the `FashionMNIST` dataset.
  3. The output tensor must be of shape `[batch_size, 1, 28, 28]`. This tensor will be the reconstructed image of the input tensor.

* `q5b`: Training the models: Implement the training loop for the ConditionalVariationalAutoEncoder model. Constraints:
  1. Use the `FashionMNIST` dataset implemented in `q1` to load the data.
  2. Use the `ConditionalVariationalAutoEncoder` model implemented in `q5a`.
  3. You are free to choose any loss function, optimizer, and hyperparameters.
  4. **You must**:
     1. Book-keep the training and validation losses and SSIM scores for each epoch and use it to plot the training curves with the `ConditionalVariationalAutoEncoder.save_training_report` method.
     2. To calculate the SSIM score, you can use the `get_ssim` function provided in `q3b` above.
     3. Save the model weights using `ConditionalVariationalAutoEncoder.save_model_weights` method.


`q5` Grading [Total: 1 point]:
1. `q5b`: Training the models: `1` point. You will be awarded points based on the SSIM score of the `ConditionalVariationalAutoEncoder` model on a **hidden test set**. The grading will be as follows:
   1. 0.8 or more: `1` point
   2. 0.7 or more (but below 0.8): `0.5` points
   3. less than 0.7: `0` points

You are provided with the following template. **Populate only the sections marked as `# YOUR CODE HERE`. Do not modify other parts of the template.**

## `q5a`: [BONUS] `ConditionalVariationalAutoEncoder`

In [None]:
class ConditionalVariationalAutoEncoder(torch.nn.Module):
    """VAE conditioned on a per-sample vector (e.g. a one-hot class label).

    The condition is broadcast over the latent grid and concatenated:
    - with the encoder features, before a 1x1 conv that produces ``mu`` / ``log_var``
      (so the posterior is q(z | x, c)), and
    - with the sampled ``z``, before decoding (so the decoder models p(x | z, c)).
    """

    def __init__(self, latent_dim: int, condition_dim: int):
        super(ConditionalVariationalAutoEncoder, self).__init__()
        # YOUR CODE HERE
        self.latent_dim = latent_dim
        self.condition_dim = condition_dim
        self.encoder = Encoder(output_channels=latent_dim * 2, type_of_autoencoder="cvae")
        # Fuses image features with the broadcast condition into posterior statistics.
        self.posterior_head = torch.nn.Conv2d(
            latent_dim * 2 + condition_dim, latent_dim * 2, kernel_size=1
        )
        self.decoder = Decoder(
            input_channels=latent_dim + condition_dim, type_of_autoencoder="cvae"
        )

    def reparameterize(self, mu, log_var):
        """Sample ``z ~ N(mu, exp(log_var))`` with the reparameterisation trick.

        In training mode this returns ``mu + eps * exp(0.5 * log_var)`` with
        ``eps ~ N(0, I)``. In eval mode it returns ``mu`` (the posterior mean), which
        makes inference deterministic.
        """
        # YOUR CODE HERE
        if not self.training:
            return mu
        std = torch.exp(0.5 * log_var)
        return mu + torch.randn_like(std) * std

    def forward(self, input_tensor, condition_tensor, return_distribution: bool = False):
        """Restore ``input_tensor`` of shape [batch_size, 1, 28, 28] given a condition.

        Args:
            input_tensor: Batch of (augmented) images.
            condition_tensor: Either a float tensor of shape [batch_size, condition_dim]
                (e.g. one-hot labels), or an integer tensor of shape [batch_size] holding
                class indices, which is one-hot encoded here.
            return_distribution: If True, also return the posterior parameters needed
                by ``loss_function``.

        Returns:
            The reconstruction (shape [batch_size, 1, 28, 28] per the q2b decoder
            contract). If ``return_distribution`` is True, returns the tuple
            ``(reconstruction, mu, log_var)`` instead.
        """
        # YOUR CODE HERE
        if condition_tensor.dim() == 1 and not condition_tensor.is_floating_point():
            condition_tensor = torch.nn.functional.one_hot(
                condition_tensor.long(), num_classes=self.condition_dim
            )
        condition = condition_tensor.to(device=input_tensor.device, dtype=input_tensor.dtype)
        condition = condition.reshape(condition.size(0), self.condition_dim, 1, 1)

        features = self.encoder(input_tensor)
        height, width = features.shape[-2:]
        statistics = self.posterior_head(
            torch.cat([features, condition.expand(-1, -1, height, width)], dim=1)
        )
        mu, log_var = torch.chunk(statistics, 2, dim=1)
        # Bound the log-variance so exp() can neither overflow nor collapse to zero.
        log_var = torch.clamp(log_var, min=-10.0, max=10.0)
        z = self.reparameterize(mu, log_var)

        decoder_input = torch.cat(
            [z, condition.expand(-1, -1, z.size(2), z.size(3))], dim=1
        )
        reconstruction = self.decoder(decoder_input)
        if return_distribution:
            return reconstruction, mu, log_var
        return reconstruction

    def loss_function(self, predicted_images, gt_images, mu, log_var, kl_weight: float = 1.0):
        """Negative ELBO, averaged over the batch.

        Both terms are summed over pixels / latent elements and divided by the batch size:
        ``BCE(predicted, gt) + kl_weight * KL(N(mu, exp(log_var)) || N(0, I))``.

        Args:
            predicted_images: Reconstructions with values in [0, 1].
            gt_images: Target images with values in [0, 1].
            mu: Posterior means.
            log_var: Posterior log-variances (same shape as ``mu``).
            kl_weight: Weight of the KL term. The default of 1.0 gives the standard ELBO.

        Returns:
            torch.Tensor: Scalar loss.
        """
        # YOUR CODE HERE
        batch_size = gt_images.size(0)
        reconstruction_loss = torch.nn.functional.binary_cross_entropy(
            predicted_images, gt_images, reduction="sum"
        ) / batch_size
        kl_divergence = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) / batch_size
        return reconstruction_loss + kl_weight * kl_divergence

    def save_model_weights(self):
        torch.save(self.state_dict(), "conditional_variational_auto_encoder.pth")

    def load_model_weights(self):
        self.load_state_dict(torch.load("conditional_variational_auto_encoder.pth"))

    def save_training_report(
        self,
        list_of_train_losses: List[float],
        list_of_val_losses: List[float],
        list_of_train_ssim_scores: List[float],
        list_of_val_ssim_scores: List[float],
    ):
        plt.figure(figsize=(10, 5))

        plt.subplot(1, 2, 1)
        plt.title("Loss per Epoch")
        plt.plot(list_of_train_losses, label="Training")
        plt.plot(list_of_val_losses, label="Validation")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.title("SSIM per Epoch")
        plt.plot(list_of_train_ssim_scores, label="Training")
        plt.plot(list_of_val_ssim_scores, label="Validation")
        plt.xlabel("Epoch")
        plt.ylabel("SSIM")
        plt.legend()

        plt.suptitle("ConditionalVariationalAutoEncoder Training Report")

        plt.savefig("conditional_variational_auto_encoder.png")
        plt.show()

In [None]:
# tests for q5a
# Smoke test: a full forward pass through the CVAE with a 10-dimensional (random)
# condition vector for one random image.

conditional_variational_autoencoder = ConditionalVariationalAutoEncoder(
    latent_dim=64, condition_dim=10
)

random_input_tensor = torch.randn(1, 1, 28, 28)
random_condition_tensor = torch.randn(1, 10)
output = conditional_variational_autoencoder(
    random_input_tensor, random_condition_tensor
)


del conditional_variational_autoencoder

## `q5b`: [BONUS] Training the model

In [None]:
# Use this cell to:
# 1. Train the ConditionalVariationalAutoEncoder model while bookkeeping the training and validation losses and SSIM scores for each epoch
# 2. Save the model weights using ConditionalVariationalAutoEncoder.save_model_weights method
# 3. Save the training report using ConditionalVariationalAutoEncoder.save_training_report method


# YOUR CODE HERE
# Hyper-parameters (free to tune).
CVAE_SEED = 42
CVAE_EPOCHS = 20
CVAE_BATCH_SIZE = 128
CVAE_LEARNING_RATE = 1e-3
CVAE_LATENT_DIM = 64
CVAE_KL_WEIGHT = 0.05
# FashionMNIST has 10 classes; labels are one-hot encoded as the condition.
CVAE_NUM_CLASSES = 10

random.seed(CVAE_SEED)
torch.manual_seed(CVAE_SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE: on platforms that spawn DataLoader workers (e.g. Windows), workers may fail to
# import classes defined in a notebook, so data is loaded in-process there.
num_workers = 0 if os.name == "nt" else 2


def make_loader(split_dir: str, shuffle: bool, batch_size: int) -> torch.utils.data.DataLoader:
    """Build a DataLoader over the paired `<split_dir>/aug` and `<split_dir>/clean` images."""
    dataset = FashionMNISTDataset(
        path_to_augmented_images_dir=split_dir + "/aug",
        path_to_clean_images_dir=split_dir + "/clean",
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=device.type == "cuda",
    )


def run_epoch(model, loader, compute_loss, optimizer=None) -> Tuple[float, float]:
    """Make one pass over `loader`: train if `optimizer` is given, otherwise evaluate.

    `compute_loss(model, aug, clean, labels)` must return `(loss, predictions)`.
    Returns the sample-weighted mean loss and mean SSIM over the pass.
    """
    is_training = optimizer is not None
    model.train(is_training)
    loss_sum, ssim_sum, seen = 0.0, 0.0, 0
    with torch.set_grad_enabled(is_training):
        for aug_images, clean_images, labels in tqdm(loader, leave=False):
            aug_images = aug_images.to(device, non_blocking=True)
            clean_images = clean_images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            loss, predictions = compute_loss(model, aug_images, clean_images, labels)
            if is_training:
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                optimizer.step()
            batch_size = aug_images.size(0)
            loss_sum += loss.item() * batch_size
            # get_ssim returns the batch mean, so weight it by the batch size.
            ssim_sum += batch_size * get_ssim(
                list(predictions.detach().squeeze(1).cpu().numpy()),
                list(clean_images.squeeze(1).cpu().numpy()),
            )
            seen += batch_size
    return loss_sum / seen, ssim_sum / seen


def fit(model, compute_loss, train_loader, val_loader, epochs: int, learning_rate: float):
    """Train `model`, then restore the weights with the best validation SSIM.

    Returns the per-epoch lists (train_losses, val_losses, train_ssims, val_ssims).
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", factor=0.5, patience=2
    )
    train_losses, val_losses, train_ssims, val_ssims = [], [], [], []
    best_val_ssim, best_state = float("-inf"), None
    for epoch in range(1, epochs + 1):
        started = timeit.default_timer()
        train_loss, train_ssim = run_epoch(model, train_loader, compute_loss, optimizer)
        val_loss, val_ssim = run_epoch(model, val_loader, compute_loss)
        scheduler.step(val_ssim)

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_ssims.append(train_ssim)
        val_ssims.append(val_ssim)

        if val_ssim > best_val_ssim:
            best_val_ssim = val_ssim
            # Snapshot a CPU copy so later epochs cannot overwrite it.
            best_state = {
                name: tensor.detach().cpu().clone()
                for name, tensor in model.state_dict().items()
            }
        print(
            f"Epoch {epoch:02d}/{epochs} | "
            f"train loss {train_loss:.4f}, SSIM {train_ssim:.4f} | "
            f"val loss {val_loss:.4f}, SSIM {val_ssim:.4f} | "
            f"{timeit.default_timer() - started:.1f}s"
        )
    if best_state is not None:
        model.load_state_dict(best_state)
    return train_losses, val_losses, train_ssims, val_ssims


def cvae_loss(model, aug_images, clean_images, labels):
    """Weighted negative ELBO, conditioning on the one-hot class label."""
    condition = torch.nn.functional.one_hot(labels, num_classes=CVAE_NUM_CLASSES).float()
    predictions, mu, log_var = model(aug_images, condition, return_distribution=True)
    loss = model.loss_function(predictions, clean_images, mu, log_var, kl_weight=CVAE_KL_WEIGHT)
    return loss, predictions


cvae_train_loader = make_loader(PATH_TO_TRAIN_DATA_DIR, shuffle=True, batch_size=CVAE_BATCH_SIZE)
cvae_val_loader = make_loader(PATH_TO_TEST_DATA_DIR, shuffle=False, batch_size=CVAE_BATCH_SIZE)

conditional_variational_auto_encoder = ConditionalVariationalAutoEncoder(
    latent_dim=CVAE_LATENT_DIM, condition_dim=CVAE_NUM_CLASSES
).to(device)
cvae_history = fit(
    conditional_variational_auto_encoder,
    cvae_loss,
    cvae_train_loader,
    cvae_val_loader,
    CVAE_EPOCHS,
    CVAE_LEARNING_RATE,
)

# Save from CPU: load_model_weights calls torch.load without map_location, so a
# checkpoint holding CUDA tensors would not load on a CPU-only machine.
conditional_variational_auto_encoder.to("cpu")
conditional_variational_auto_encoder.save_model_weights()
conditional_variational_auto_encoder.save_training_report(*cvae_history)

In [None]:
# tests for q5b

In [None]:
# tests for q3b, q4b, q5b

