In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output
from model import deanGMVAE  # Adjust this import statement as needed


In [2]:
# Checkpoint to visualise.
# NOTE(review): the file name looks like it encodes the training hyper-parameters
# (z_dim, beta, dropout, K, ...) — confirm against the training script.
path = "./weights/C8_Disentanglement_17_6_0.792_0.3_0.00048114888561434383_adam_9_626.7850927065497_606.3880588061189_14.004697956450999_6.392335702295172_8_0.3387780262488029.pth"  # Update this path

# The constructor arguments must describe the same architecture as the checkpoint,
# otherwise load_state_dict() raises on mismatching keys/shapes.
model_trained = deanGMVAE(z_dim=9, beta=0.00048114888561434383, dropout=0.3, K=8)

# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# machine; the latent vectors built later in this notebook are CPU tensors too.
# NOTE(review): torch.load unpickles the file, so only open checkpoints you trust
# (or pass weights_only=True on PyTorch versions that support it).
model_trained.load_state_dict(torch.load(path, map_location="cpu"))

# Switch to inference mode (the model contains a Dropout layer). Kept as the last
# expression of the cell so the module structure is still displayed.
model_trained.eval()

deanGMVAE(
  (encoder): DeanEncoderGMM(
    (conv1): Conv2d(4, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (conv2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (conv3): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (conv4): Conv2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (dropout): Dropout(p=0.3, inplace=False)
    (fc_mu): Linear(in_features=4096, out_features=72, bias=True)
    (fc_logvar): Linear(in_features=4096, out_features=72, bias=True)
    (fc_pi): Linear(in_features=4096, out_features=8, bias=True)
  )
  (decoder): DeanDecGMM(
    (fc): Linear(in_features=9, out_features=4096, bias=True)
    (deconv1): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (deconv2): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (deconv3): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (deconv4): ConvTranspose2d

In [3]:
def generate_image_from_latent(model, dim_to_explore, val, latent_size=9):
    """Decode a latent vector with a single non-zero coordinate into an RGB image.

    A ``(1, latent_size)`` vector of zeros is built, coordinate
    ``dim_to_explore`` is set to ``val`` and the result is passed through
    ``model.decoder``.

    Args:
        model: Object exposing a ``decoder`` callable that accepts the latent
            tensor. Its output is indexed as a 4-channel ``(1, C, H, W)``
            tensor — TODO confirm against the decoder definition.
        dim_to_explore: Index of the latent coordinate to set.
        val: Value written to that coordinate.
        latent_size: Length of the latent vector (defaults to 9).

    Returns:
        ``np.ndarray`` of shape ``(H, W, 3)`` and dtype ``float32`` with values
        in ``[0, 1]``.

    Raises:
        IndexError: If ``dim_to_explore`` is outside ``[-latent_size, latent_size)``.
    """
    with torch.no_grad():
        # Baseline latent: all zeros except the explored coordinate.
        z = torch.zeros((1, latent_size))
        z[0, dim_to_explore] = val

        # Convert explicitly so the maths below is plain NumPy regardless of
        # where the decoder output lives.
        img = model.decoder(z).squeeze(0).detach().cpu().numpy()

    # Channel mapping actually used: channel 1 -> R, channel 2 -> G,
    # channel 0 -> B, and channel 3 is added to all three colour planes.
    # NOTE(review): the rotated order (1, 2, 0) is kept from the original code —
    # confirm it matches how the training data channels are laid out.
    shared = img[3]
    rgb_image = np.stack(
        [img[1] + shared, img[2] + shared, img[0] + shared],
        axis=-1,
    ).astype(np.float32)

    # Scale by the brightest value only when it is positive. Dividing by zero
    # produced an all-NaN image, and dividing by a negative maximum flipped the
    # sign so that the darkest pixels became saturated.
    peak = float(rgb_image.max())
    if peak > 0:
        rgb_image = rgb_image / peak

    # Negative values (and the whole image when peak <= 0) clip to black.
    return np.clip(rgb_image, 0, 1)


In [5]:
# Number of latent coordinates that get a slider; must match the model's latent size.
latent_dim = 9

# One slider per latent coordinate, each covering [-1, 1] in steps of 0.1 and
# starting at 0.
sliders = list(
    map(
        lambda index: widgets.FloatSlider(
            description=f'Dim {index}',
            value=0.0,
            min=-1,
            max=1,
            step=0.1,
        ),
        range(latent_dim),
    )
)

# Area below the sliders where the decoded image is rendered.
output = widgets.Output()

def on_value_change(change, dim):
    """Redraw the decoded image after the slider for latent coordinate ``dim`` moved.

    Only the moved slider's current value is used: the image is decoded from a
    latent vector that is zero everywhere except at ``dim``, so the positions
    of the other sliders do not influence the picture.

    Args:
        change: Change notification passed by the widget observer (unused; the
            value is read from the slider itself).
        dim: Index of the slider / latent coordinate that changed.
    """
    with output:
        # wait=True defers clearing until the new figure is ready, avoiding flicker.
        clear_output(wait=True)
        val = sliders[dim].value
        # Pass latent_dim explicitly so the decoded vector always has as many
        # coordinates as there are sliders, instead of relying on the default.
        img = generate_image_from_latent(model_trained, dim, val, latent_size=latent_dim)
        fig, ax = plt.subplots()
        ax.imshow(img)
        ax.axis('off')
        plt.show()
        # Release the figure explicitly; backends other than inline keep it
        # alive otherwise and every slider move would leak one figure.
        plt.close(fig)

# Attach the shared callback to every slider. The index is frozen through a
# default argument so each handler reports its own dimension rather than the
# loop's final value.
for i, slider in enumerate(sliders):
    def _handler(change, dim=i):
        on_value_change(change, dim)

    slider.observe(_handler, names='value')

# Show all sliders followed by the image output area.
display(*[*sliders, output])

FloatSlider(value=0.0, description='Dim 0', max=1.0, min=-1.0)

FloatSlider(value=0.0, description='Dim 1', max=1.0, min=-1.0)

FloatSlider(value=0.0, description='Dim 2', max=1.0, min=-1.0)

FloatSlider(value=0.0, description='Dim 3', max=1.0, min=-1.0)

FloatSlider(value=0.0, description='Dim 4', max=1.0, min=-1.0)

FloatSlider(value=0.0, description='Dim 5', max=1.0, min=-1.0)

FloatSlider(value=0.0, description='Dim 6', max=1.0, min=-1.0)

FloatSlider(value=0.0, description='Dim 7', max=1.0, min=-1.0)

FloatSlider(value=0.0, description='Dim 8', max=1.0, min=-1.0)

Output()