In [None]:
from model.datamodule import MNISTDataModule
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
from model.vae import VAE
import torch
from ipywidgets import interact
import os
import glob

# Load and set up the data module

In [None]:
# Build the MNIST data module and run its setup step so that the
# dataloaders used by the cells below are available.
dm = MNISTDataModule(
    data_path='./data',
    num_workers=4,
    batch_size=32,
)
dm.setup()

# Visualize some random examples from the validation dataset

In [None]:
# Take the first batch delivered by the validation dataloader.
dataiter = iter(dm.val_dataloader())
image = next(dataiter)

# From element 0 of the batch, keep index 0 along the second axis for each of
# the first `num_samples` entries.
# NOTE(review): presumably element 0 holds the images and axis 1 is the single
# grayscale channel -- confirm against MNISTDataModule.
num_samples = 25
sample_images = []
for i in range(num_samples):
    sample_images.append(image[0][i, 0])

# 5 x 5 grid of tightly packed axes, one per sample.
fig = plt.figure(figsize=(5, 5))
grid = ImageGrid(fig, 111, nrows_ncols=(5, 5), axes_pad=0.1)

for cell_ax, digit in zip(grid, sample_images):
    cell_ax.imshow(digit, cmap='gray')
    cell_ax.axis('off')  # hide ticks and frame; only the digit matters

plt.show()

# Load the trained model from the checkpoint

In [None]:
# helper function to get the latest model
def find_latest_checkpoint(base_path: str = "./lightning_logs") -> str:
    """
    Find the most recent checkpoint file in the PyTorch Lightning logs.

    Directories matching ``version_<N>`` directly under ``base_path`` are ranked
    by their integer suffix ``N`` (numerically, so ``version_10`` beats
    ``version_9``). Entries whose suffix is not an integer, or which are not
    directories, are ignored. Within the highest-numbered version, the
    ``*.ckpt`` file in its ``checkpoints`` sub-directory with the newest
    modification time is returned.

    Args:
        base_path (str): The base path where the lightning logs are stored.

    Returns:
        str: The path to the latest checkpoint file.

    Raises:
        FileNotFoundError: If ``base_path`` contains no ``version_<N>``
            directory, or if the latest version directory contains no
            ``*.ckpt`` file.
    """
    def version_number(path):
        # Integer after the last underscore of the directory name, or None
        # when that suffix is not an integer (e.g. "version_backup").
        try:
            return int(os.path.basename(path).rsplit("_", 1)[-1])
        except ValueError:
            return None

    # Collect (number, path) for every usable 'version_*' directory
    numbered_dirs = []
    for candidate in glob.glob(os.path.join(base_path, "version_*")):
        number = version_number(candidate)
        if number is not None and os.path.isdir(candidate):
            numbered_dirs.append((number, candidate))

    if not numbered_dirs:
        raise FileNotFoundError(
            f"No 'version_<N>' directory found under {base_path!r}"
        )

    # Highest version number wins (numeric comparison, not lexicographic)
    _, latest_version_dir = max(numbered_dirs, key=lambda item: item[0])

    # Newest checkpoint (by modification time) inside that version
    checkpoint_dir = os.path.join(latest_version_dir, "checkpoints")
    checkpoints = glob.glob(os.path.join(checkpoint_dir, "*.ckpt"))
    if not checkpoints:
        raise FileNotFoundError(
            f"No '*.ckpt' file found in {checkpoint_dir!r}"
        )

    return max(checkpoints, key=os.path.getmtime)

In [None]:
latest_checkpoint = find_latest_checkpoint()
# map_location='cpu' lets a checkpoint that was saved on a GPU be restored on a
# CPU-only machine. The model is switched to evaluation mode right away because
# the interactive decoder cell below uses it for inference; chaining .eval()
# inside the assignment also keeps the cell from printing the module repr.
vae = VAE.load_from_checkpoint(latest_checkpoint, map_location='cpu').to('cpu').eval()

# Generate data using the decoder neural network

In [None]:
def generate_image(z1=0, z2=0):
    """
    Decode the latent point ``(z1, z2)`` with the VAE decoder and show the image.

    Args:
        z1: First latent coordinate.
        z2: Second latent coordinate.
    """
    plt.figure(figsize=(2, 2))  # small figure, 2x2 inches
    # torch.tensor with an explicit dtype replaces the legacy torch.Tensor
    # constructor; float32 is kept even for the integer defaults z1=0, z2=0.
    input_tensor = torch.tensor([z1, z2], dtype=torch.float32, device='cpu')
    # Pure inference: do not build an autograd graph on every slider move.
    with torch.no_grad():
        decoded = vae.decode(input_tensor)
    img_array = decoded.reshape(28, 28).cpu().detach().numpy()
    plt.imshow(img_array, cmap='gray')
    plt.axis('off')  # hide the axes
    plt.show()

# The trailing semicolon stops the notebook from echoing the function repr
# that interact() returns.
interact(generate_image, z1=(-5.0, 5.0, 0.1), z2=(-5.0, 5.0, 0.1));


# Visualize distribution of the means of the latent space

In [None]:
# torch and matplotlib.pyplot are already imported in the first cell of the
# notebook, so they are not re-imported here.

# Set the model to evaluation mode
vae.eval()

# Move the model to CPU for compatibility
vae.to('cpu')

# Run the whole validation set through the model, collecting the third element
# of the model output and the labels batch by batch.
# NOTE(review): the third output is presumably the latent mean, as the axis
# labels below suggest -- confirm against VAE.forward.
latent_batches = []
label_batches = []
with torch.no_grad():  # No gradient computation
    for batch in dm.val_dataloader():
        _, labels = batch
        # NOTE(review): the full (inputs, labels) batch is passed to the model,
        # exactly as before -- confirm that VAE.forward expects the whole batch
        # rather than only the inputs.
        _, _, latent_vars, _ = vae(batch)
        latent_batches.append(latent_vars.cpu())
        label_batches.append(labels.cpu())

# Concatenate the per-batch results into single tensors
all_latent_vars = torch.cat(latent_batches, dim=0)
all_labels = torch.cat(label_batches, dim=0)

# Create the scatter plot, one colour per class
num_classes = int(all_labels.max()) + 1  # Assuming 0-based class labels
plt.figure(figsize=(10, 8))
for class_id in range(num_classes):
    indices = all_labels == class_id
    plt.scatter(
        all_latent_vars[indices, 0],
        all_latent_vars[indices, 1],
        label=f'Class {class_id}',
        alpha=0.1,
        s=10,
    )

plt.xlabel(r'Means of latent variable z_1')
plt.ylabel('Means of latent variable z_2')
plt.legend()
plt.show()