In [None]:
import pickle
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from matplotlib import animation, rc
from ipywidgets import interactive, FloatSlider, Button, VBox
from IPython.display import display as ipython_display
from tqdm import tqdm
#local
import dnnlib
from dnnlib import tflib

In [None]:
# Path of the pickled (G, D, Gs) network file; load_model() passes it to load_networks().
model_path = 'best_net.pkl'

In [None]:
#define load model functions
# Cache of already-loaded network triples, keyed by the exact `path` value passed in.
_cached_networks = dict()
def load_networks(path):
    """Load the pickled ``(G, D, Gs)`` network triple stored at ``path``.

    Results are memoised in ``_cached_networks`` so the same path is only
    unpickled once; later calls return the cached tuple without touching
    the file system or TensorFlow.

    Args:
        path: Location of the network pickle. Used verbatim as the cache key.

    Returns:
        The tuple ``(G, D, Gs)`` as unpickled from the file.

    Raises:
        OSError: If ``path`` cannot be opened for reading.
    """
    if path in _cached_networks:
        return _cached_networks[path]
    # Open first (fail fast on a bad path, as before), but inside `with` so the
    # handle is closed even when init_tf() or the unpickling raises.
    with open(path, 'rb') as stream:
        tflib.init_tf()
        # SECURITY: pickle.load executes code embedded in the file; only load
        # network pickles that come from a trusted source.
        G, D, Gs = pickle.load(stream, encoding='latin1')
    _cached_networks[path] = G, D, Gs
    return G, D, Gs
# Code to load the StyleGAN2 Model
def load_model():
    """Load the network at ``model_path`` and prepare what is needed to run it.

    Only the ``Gs`` member of the loaded ``(G, D, Gs)`` triple is kept.

    Returns:
        A tuple ``(Gs, noise_vars, Gs_kwargs)`` where ``noise_vars`` lists the
        synthesis variables whose name starts with ``'noise'`` and
        ``Gs_kwargs`` holds the keyword arguments for running the network
        (uint8 NHWC output transform, ``randomize_noise=False``).
    """
    _, _, Gs = load_networks(model_path)

    noise_vars = []
    for var_name, var in Gs.components.synthesis.vars.items():
        if var_name.startswith('noise'):
            noise_vars.append(var)

    Gs_kwargs = dnnlib.EasyDict(
        output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True),
        randomize_noise=False,
    )
    return Gs, noise_vars, Gs_kwargs

In [None]:
#define helper functions
def get_control_latent_vectors(path):
    """Load every ``.npy`` file directly inside ``path`` into a dictionary.

    The directory is not searched recursively. Entries that are not regular
    files are skipped even if their name ends in ``.npy``, because
    ``np.load`` cannot read them.

    Args:
        path: Directory containing the ``.npy`` files.

    Returns:
        dict mapping each file name without its ``.npy`` extension to the
        array returned by ``np.load`` for that file.

    Raises:
        OSError: If ``path`` does not exist or is not a directory.
    """
    latent_vectors = {}
    # sorted() makes the load (and therefore dict) order independent of the
    # arbitrary order in which the file system lists entries.
    for entry in sorted(Path(path).iterdir()):
        if entry.is_file() and entry.suffix == '.npy':
            latent_vectors[entry.stem] = np.load(entry)
    return latent_vectors
#load latent directions
# Dict built from the .npy files under 'trajectories/'. Other cells index it with
# "time", "sphericity_index", "lv_area" and consecutive frame-pair keys such as "01", "12".
latent_controls = get_control_latent_vectors('trajectories/')

def generate_image_from_projected_latents(latent_vector):
    """Run the synthesis component of the global ``Gs`` on ``latent_vector``.

    The global ``Gs_kwargs`` are forwarded as keyword arguments and the
    result of the run is returned unchanged.
    """
    synthesis = Gs.components.synthesis
    return synthesis.run(latent_vector, **Gs_kwargs)

def create_video(image, interval=50):
    """Build a matplotlib animation that plays ``image`` one frame at a time.

    Args:
        image: Stack of frames; ``image.shape[0]`` is the number of frames and
            ``image[N]`` is drawn with a gray colormap for frame ``N``.
        interval: Delay between frames in milliseconds. Defaults to 50, the
            value that was previously hard-coded.

    Returns:
        The ``animation.FuncAnimation`` object driving the playback.
    """
    fig, ax = plt.subplots()
    # Close this specific figure right away (presumably so the bare figure is
    # not rendered next to the animation); the animation keeps its own reference.
    plt.close(fig)
    def animator(N): # N is the animation frame number
        # Remove the previous frame first; otherwise every call stacks one
        # more image artist on the axes.
        ax.clear()
        ax.imshow(image[N], cmap='gray')
        ax.axis('off')
        return ax
    PlotFrames = range(image.shape[0])
    anim = animation.FuncAnimation(fig, animator, frames=PlotFrames, interval=interval)
    rc('animation', html='jshtml') # embed in the HTML for Google Colab
    return anim

def create_videos(image1, image2):
    """Build one animation that plays two equally long videos side by side.

    ``image1`` is shown in the left panel titled "ED-to-ES" and ``image2`` in
    the right panel titled "Frame-to-frame"; both use a gray colormap.

    Args:
        image1: Frame stack for the left panel (frames along axis 0).
        image2: Frame stack for the right panel, same frame count as ``image1``.

    Returns:
        The ``animation.FuncAnimation`` object (50 ms between frames).

    Raises:
        AssertionError: If the two stacks differ in number of frames.
    """
    assert image1.shape[0] == image2.shape[0], "Both videos should have the same number of frames"

    fig, (left_ax, right_ax) = plt.subplots(1, 2)  # one row, two panels
    plt.close()

    panels = (
        (left_ax, "ED-to-ES", image1),
        (right_ax, "Frame-to-frame", image2),
    )

    def animator(N):
        # Each panel is wiped and redrawn from scratch for frame N.
        for panel_ax, panel_title, panel_video in panels:
            panel_ax.clear()
            panel_ax.set_title(panel_title)
            panel_ax.imshow(panel_video[N], cmap='gray')
            panel_ax.axis('off')
        return left_ax, right_ax

    # The assertion above guarantees image2 has the same length.
    frame_count = image1.shape[0]
    anim = animation.FuncAnimation(fig, animator, frames=range(0, frame_count, 1), interval=50)
    rc('animation', html='jshtml') # embed in the HTML for Google Colab
    return anim

In [None]:
## define video generation methods
def ED_to_ES(latent_code):
    """Render a 50-frame sweep along the "time" latent direction and back.

    The first 25 frames add ``latent_controls["time"]`` scaled by
    0, 1/25, ..., 24/25 to ``latent_code``; the next 25 use
    24/25, ..., 1/25, 0. One image (the first of each generated batch) is
    kept per step.

    Args:
        latent_code: Starting latent code; it is not modified.

    Returns:
        ``np.ndarray`` stacking the 50 generated images along axis 0.
    """
    time_direction = latent_controls["time"]

    def render_sweep(amounts):
        # One progress bar per sweep; keep only the first image of each batch.
        rendered = []
        for amount in tqdm(amounts):
            shifted_code = latent_code + time_direction*amount
            batch = generate_image_from_projected_latents(shifted_code)
            rendered.append(np.array(batch[0]))
        return rendered

    rising = [step/25 for step in range(0, 25)]
    falling = [1-step/25 for step in range(1, 26)]

    frames = render_sweep(rising)
    frames.extend(render_sweep(falling))
    return np.array(frames)

def frame_to_frame(latent_code):
    """Render 50 frames by accumulating consecutive frame-pair directions.

    Frame 0 is generated from a copy of ``latent_code``. For ``i`` in 0..48
    the direction ``latent_controls[f'{i}{i+1}']`` is added to the running
    code and another frame is generated.

    Args:
        latent_code: Starting latent code; it is copied, not modified.

    Returns:
        ``np.ndarray`` of all generated outputs stacked along axis 0, with
        every length-1 axis squeezed out.
    """
    running_code = np.copy(latent_code)
    frames = [generate_image_from_projected_latents(running_code)]
    for frame_idx in tqdm(range(49)):
        step_key = f'{frame_idx}{frame_idx+1}'
        # Out-of-place add: each step builds a new array from the previous one.
        running_code = running_code + latent_controls[step_key]
        frames.append(generate_image_from_projected_latents(running_code))
    return np.squeeze(np.array(frames))

In [None]:
#load the model
# These become notebook-level globals: generate_image_from_projected_latents() reads
# Gs and Gs_kwargs directly instead of taking them as arguments.
Gs, noise_vars, Gs_kwargs = load_model()

In [None]:
# Select a random latent code. RandomState(3) fixes the seed, so rerunning this
# cell draws the same values as long as the statement order below is kept.
rnd = np.random.RandomState(3)
# One latent sample; its trailing dimensions are taken from Gs.input_shape.
z = rnd.randn(1, *Gs.input_shape[1:])
# Re-collect the synthesis variables named 'noise*' and overwrite each with values
# drawn from the same seeded generator (drawn after z).
noise_vars = [var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')]
tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars})
# Run the mapping component on z; the second argument is None
# (presumably "no labels" -- TODO confirm against the network definition).
random_img_latent_code = Gs.components.mapping.run(z,None)

# Make it the ED frame (original author's note; ED presumably means end-diastole --
# confirm): shift the code in place by -0.7 along the 'time' direction.
random_img_latent_code -= 0.7*latent_controls['time']

In [None]:
def create_and_display_videos(sphericity_index=0.0, lv_area=0.0):
    """Show a still preview of the latent code adjusted by the two controls.

    Despite the name, no video is generated here: a single image is produced
    from the adjusted code and drawn in both panels ("ED-to-ES" and
    "Frame-to-frame").

    Args:
        sphericity_index: Weight applied to ``latent_controls['sphericity_index']``.
        lv_area: Weight applied to ``latent_controls['lv_area']``.
    """
    #apply physiological adjustment (on a copy, so the global code stays untouched)
    adjusted_latent_code = np.copy(random_img_latent_code)
    weighted_controls = (
        (sphericity_index, 'sphericity_index'),
        (lv_area, 'lv_area'),
    )
    for weight, control_name in weighted_controls:
        adjusted_latent_code += weight * latent_controls[control_name]

    fig, (left_ax, right_ax) = plt.subplots(1, 2)  # one row, two panels
    img = generate_image_from_projected_latents(adjusted_latent_code).squeeze()

    for panel_ax, panel_title in ((left_ax, "ED-to-ES"), (right_ax, "Frame-to-frame")):
        panel_ax.clear()
        panel_ax.set_title(panel_title)
        panel_ax.imshow(img, cmap='gray')
        panel_ax.axis('off')

    plt.show()

def on_generate_videos_button_click(button):
    """Button callback: render and display both videos for the slider values.

    Reads the current values of the global ``sphericity_index_slider`` and
    ``lv_area_slider``, applies them to a copy of ``random_img_latent_code``,
    generates the ED-to-ES and frame-to-frame videos from that code and
    displays them side by side.

    Args:
        button: The clicked widget (unused).
    """
    # Snapshot the slider positions before doing any work.
    weighted_controls = (
        (sphericity_index_slider.value, 'sphericity_index'),
        (lv_area_slider.value, 'lv_area'),
    )

    # Physiological adjustment, applied to a copy of the global code.
    adjusted_latent_code = np.copy(random_img_latent_code)
    for weight, control_name in weighted_controls:
        adjusted_latent_code += weight * latent_controls[control_name]

    # Both videos start from the same adjusted code.
    ED_to_ES_video = ED_to_ES(adjusted_latent_code)
    frame_to_frame_video = frame_to_frame(adjusted_latent_code)
    ipython_display(create_videos(ED_to_ES_video, frame_to_frame_video))

In [None]:
# Button that renders both videos for the current slider values when clicked.
generate_videos_button = Button(description="Generate Videos")
generate_videos_button.on_click(on_generate_videos_button_click)

# Sliders giving the weights applied to the 'sphericity_index' and 'lv_area' latent directions.
sphericity_index_slider = FloatSlider(min=-2, max=3, step=0.1, value=0.0, description='Sphericity:')
lv_area_slider = FloatSlider(min=-2, max=3, step=0.1, value=0.0, description='LV Area:')

# Display the interactive interface
interactive_plot = interactive(create_and_display_videos, sphericity_index=sphericity_index_slider, lv_area=lv_area_slider)
# The last child of the interactive container is taken to be its output area;
# a fixed height is set on it (see the ipywidgets interact documentation).
output = interactive_plot.children[-1]
output.layout.height = '200px'
# Use the alias imported at the top of the notebook (as the button callback does)
# instead of the bare `display` name, which this file never imports and which only
# exists because IPython injects it as a builtin.
ipython_display(VBox([interactive_plot, generate_videos_button]))