In [5]:

## the command to export requirements:
## under the terminal of enviroment of needed. 
## pip list --format=freeze > requirements.txt
## for voila preview : http://localhost:8888/voila

## when publish to binder:
## https://ovh2.mybinder.org/
## GitHub repository name or URL: https://github.com/ueyupen/face_machine_demo
## URL to open (optional): /voila/render/demo.ipynb
## Set "file" to URL

import os
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import colors
import random
import pickle

RGB_img=True
#crop_and_resize = [128,128,32,32] # crop_width, crop_height, resize_width, resize_height
crop_and_resize = [128,128,64,64] # crop_width, crop_height, resize_width, resize_height
crop_width = crop_and_resize[0]
crop_height = crop_and_resize[1]
resize_width = crop_and_resize[2]
resize_height = crop_and_resize[3]
sample_size=10000
Ramdom_seed_for_sampling=1005


In [15]:
file_path1 = os.getcwd()+'/data_to_use/concised_U.pkl'
file_path2 = os.getcwd()+'/data_to_use/concised_S.pkl'

with open(file_path1, "rb") as f1:
    concised_U = pickle.load(f1)
with open(file_path2, "rb") as f2:
    concised_S = pickle.load(f2)

# file_path3 = os.getcwd()+'/data_to_use/concised_V.pkl'
# with open(file_path3, "rb") as f3:
#     concised_V = pickle.load(f3)


# deal with V
concised_V_list=[]
for i in range (5):
    V_file_path = os.getcwd()+f'/data_to_use/concised_V_{i}.pkl'
    with open(V_file_path, "rb") as f3:
        concised_V_list.append(pickle.load(f3))
concised_V = np.hstack(concised_V_list)


In [16]:
from ipywidgets import interact, interactive, fixed, interact_manual,Layout
import ipywidgets as widgets
import time

class mvp:

    def __init__(self, image_w, image_h, start_sample, U_matrix, S_matrix, V_matrix, module_of_modes, module_of_modes_weights):

        self.U_ModeWise_MaxLoad = U_matrix.max(axis=0, keepdims=0)
        self.U_ModeWise_MinLoad = U_matrix.min(axis=0, keepdims=0)
        self.U_ModeWise_MeanLoad = U_matrix.mean(axis=0, keepdims=0)
        self.U_ModeWise_MaxMinMidLoad = (self.U_ModeWise_MaxLoad - self.U_ModeWise_MinLoad) / 2 + self.U_ModeWise_MinLoad
        self.U_ModeWise_MaxMinMidLoad = (self.U_ModeWise_MaxLoad - self.U_ModeWise_MinLoad) / 2 + self.U_ModeWise_MinLoad
        
        Normalized_U = 2*(U_matrix-self.U_ModeWise_MaxMinMidLoad)/(self.U_ModeWise_MaxLoad-self.U_ModeWise_MinLoad)
        Normalized_U_binned=np.around(Normalized_U, decimals=2, out=None)
        
        self.U_matrix_to_use = Normalized_U_binned
        self.S_matrix_to_use = S_matrix
        self.V_matrix_to_use = V_matrix
        self.modes_to_tune_list = [j for sub in module_of_modes for j in sub]
        self.plot_reshape=(image_w, image_h, 3)
        self.initial_sample1 = start_sample

        self.customized_U1_sample_original = self.U_matrix_to_use[self.initial_sample1].copy()
        self.customized_U1_sample          = self.U_matrix_to_use[self.initial_sample1,:].copy()

        # define widgets
        self.sample_seed_input1=widgets.IntText(value=0,description='sample ID',disabled=False,layout=Layout(width='50%'))
        self.resample_button = widgets.Button(description='resample',disabled=False,layout=Layout(width='100%'))
        self.reset_button = widgets.Button(description='reset',layout=Layout(width='100%'))
        self.slider_tuning_without_updating_the_plot = widgets.ToggleButton(value=False)
        self.Data_type = 'CelebA'
        
        # mode tuning
        self.sliders1=[]
        self.PC_thumbnails_list = []
        for mode in self.modes_to_tune_list:
            self.slider1 = widgets.FloatSlider(min=-2.5, max=2.5, step=0.01, value=self.U_matrix_to_use[self.initial_sample1, mode], orientation='vertical', disabled=True, description = f'PC {mode}', continuous_update = False, layout = Layout(min_height = '250px',width = '50px', align_self="center"))
            self.slider1.style.handle_color = 'black'
            self.PC_thumbnail_output = widgets.Output(layout = Layout(width = '70px',align_self = "center",margin = ('0px -10px 0px -10px'),padding = ('0px -10px 0px -10px')))
            self.sliders1.append(self.slider1)
            self.PC_thumbnails_list.append(self.PC_thumbnail_output)

        self.slider_to_mode_dictionary=dict(zip(self.sliders1,self.modes_to_tune_list))
        self.mode_to_slider_dictionary=dict(zip(self.modes_to_tune_list,self.sliders1))

        self.PC_thumbnail_to_mode_dictionary = dict(zip(self.PC_thumbnails_list, self.modes_to_tune_list))

        # module tuning
        self.module_sliders = []
        self.module_thumbnails_list = []
        self.module_list=[]
        for i in range (3):
            self.module_slider1 = widgets.FloatSlider(min=-2, max=2, step=0.01, value=0, description = f'module {i}', continuous_update = False, layout = Layout(min_width = '150px', align_self="center"))
            self.module_thumbnail_output = widgets.Output(layout = Layout(width = '70px',align_self = "center"))
            self.module_sliders.append(self.module_slider1)
            self.module_thumbnails_list.append(self.module_thumbnail_output)
            self.module_list.append(i)

        self.module_slider_to_module_dictionary=dict(zip(self.module_sliders, self.module_list))
        self.module_to_modes_list = module_of_modes
        self.module_to_modes_weight_list = module_of_modes_weights

        ## put thins together
        self.sample1_static_output = widgets.Output()
        self.sample1_dynamic_output = widgets.Output()
        mode_control_list=[]
        for _, (slider1, thumbnail) in enumerate(zip(self.PC_thumbnails_list, self.sliders1)):
            mode_control_list.append(widgets.VBox([slider1,thumbnail],layout = Layout(align_items = 'flex-start', min_width = '70px')))
        self.mode_box = widgets.HBox(mode_control_list,layout = Layout(width = '780px', align_self="center") )

        module_control_list=[]
        for module_slider, module_thumbnail in zip(self.module_thumbnails_list, self.module_sliders):
            module_control_list.append(widgets.HBox([module_slider,module_thumbnail],layout = Layout(align_items = 'flex-start', min_height = '70px')))
        self.module_box = widgets.VBox(module_control_list,layout = Layout(height = '300px', align_self="center") )
            
        self.image_output_box = widgets.HBox([self.sample1_static_output, self.sample1_dynamic_output])
        self.resample_box = widgets.HBox([self.sample_seed_input1, self.resample_button])
        #self.resample_box = widgets.HBox([self.resample_button])## hide sample ID
        self.control_pannel = widgets.VBox([self.image_output_box, self.resample_box, self.reset_button, self.module_box], layout = Layout(width = '400px', align_self="center"))
        self.final_GUI = widgets.HBox([self.control_pannel,  self.mode_box])

        # define button behavoir
        for slider in self.sliders1:
            slider.observe(self.slider_to_plot, names='value')
        
        for slider in self.module_sliders:
            slider.observe(self.modulr_slider_handler, names='value')
        
        self.reset_button.on_click(self.reset1)
        self.resample_button.on_click(self.resample)
        self.sample_seed_input1.observe(self.reset1, names='value')
    
        self.GUI_initiation()

    def GUI_initiation(self):
        self.plot_static1(self.customized_U1_sample_original)
        self.plot_dynamic1(self.customized_U1_sample)
        self.plot_PC_thumbnail()
        self.plot_module_thumbnail()

    def display_GUI(self):
        display(self.final_GUI)

    def image_recon_from_nob(self, customized_U):
        #customized_U_pretreatment=customized_U*(U_ModeWise_MaxLoad-U_ModeWise_MinLoad)/2+U_ModeWise_MeanLoad
        customized_U_pretreatment=customized_U*(self.U_ModeWise_MaxLoad-self.U_ModeWise_MinLoad)/2+self.U_ModeWise_MaxMinMidLoad
        image_generated = customized_U_pretreatment @ np.diag(self.S_matrix_to_use) @ self.V_matrix_to_use.T
        image_generated = (image_generated-np.min(image_generated))/(np.max(image_generated)-np.min(image_generated))
        return image_generated
    
    def plot_dynamic1(self, customized_U1_sample):
        with self.sample1_dynamic_output:
            self.sample1_dynamic_output.clear_output(wait=True)
            fig, axes = plt.subplots(1,1,figsize=(2,2.5))
            divnorm=colors.TwoSlopeNorm(vcenter=0)
            Reshaped_toplot = np.ndarray.reshape(self.image_recon_from_nob(customized_U1_sample), self.plot_reshape, order='C')
            axes.imshow(Reshaped_toplot, cmap="bwr", norm=divnorm)
            axes.set_title('tuned')
            axes.set_yticks([])
            axes.set_xticks([])
            fig.tight_layout()
            plt.show(block=False)
            plt.close()

    def plot_static1(self, customized_U1_sample_original):
        with self.sample1_static_output:
            self.sample1_static_output.clear_output(wait=True)
            fig, axes = plt.subplots(1,1,figsize=(2,2.5))
            divnorm=colors.TwoSlopeNorm(vcenter=0)
            Reshaped_toplot = np.ndarray.reshape(self.image_recon_from_nob(customized_U1_sample_original), self.plot_reshape, order='C')
            axes.imshow(Reshaped_toplot, cmap="bwr", norm=divnorm)
            axes.set_title('oringinal')
            axes.set_yticks([])
            axes.set_xticks([])
            fig.tight_layout()
            plt.show(block=False)
            plt.close()

    def plot_PC_thumbnail(self):
        for thumbnail in self.PC_thumbnails_list:
            mode = self.PC_thumbnail_to_mode_dictionary[thumbnail]
            with thumbnail:
                thumbnail.clear_output(wait=True)
                fig, ax = plt.subplots(1,1,figsize=(0.95,0.95))
                Reshaped_toplot=np.ndarray.reshape(self.V_matrix_to_use[:, mode], self.plot_reshape, order='C')
                image_norm = (Reshaped_toplot-np.min(Reshaped_toplot))/(np.max(Reshaped_toplot)-np.min(Reshaped_toplot))
                ax.imshow(np.array(image_norm), cmap="bwr",vmin=0,vmax=1)
                ax.set_yticks([])
                ax.set_xticks([])
                fig.tight_layout()
                plt.show(block=False)
                plt.close()

    def plot_module_thumbnail(self):
        for modes_in_module, mode_weight_in_module, thumbnail in zip(self.module_to_modes_list, self.module_to_modes_weight_list, self.module_thumbnails_list):
            to_plot=0
            for mode, weight in zip(modes_in_module,mode_weight_in_module):
                to_plot = to_plot + weight * self.V_matrix_to_use[:, mode] * self.S_matrix_to_use[mode]
            with thumbnail:
                thumbnail.clear_output(wait=True)
                fig, ax = plt.subplots(1,1,figsize=(0.95,0.95))
                Reshaped_toplot=np.ndarray.reshape(to_plot, self.plot_reshape, order='C')
                image_norm = (Reshaped_toplot-np.min(Reshaped_toplot))/(np.max(Reshaped_toplot)-np.min(Reshaped_toplot))
                ax.imshow(np.array(image_norm), cmap="bwr",vmin=0,vmax=1)
                ax.set_yticks([])
                ax.set_xticks([])
                fig.tight_layout()
                plt.show(block=False)
                plt.close()

    def slider_to_plot(self, change):
        input_slider=change.owner
        input_value=change.new
        input_mode = self.slider_to_mode_dictionary[input_slider]
        self.customized_U1_sample[input_mode] = input_value
        if self.slider_tuning_without_updating_the_plot.value == False:
            self.plot_dynamic1(self.customized_U1_sample)
        else:
            self.slider_tuning_without_updating_the_plot.value = False

    def modulr_slider_handler(self, change):
        sample1 = self.sample_seed_input1.value
        input_slider=change.owner
        input_value=change.new
        module = self.module_slider_to_module_dictionary[input_slider]
        U_matrix_sample_original = self.U_matrix_to_use[sample1]

        for idx ,(mode, weight) in enumerate(zip(self.module_to_modes_list[module],self.module_to_modes_weight_list[module])):
            len_of_this_module = len(self.module_to_modes_list[module])
            if idx == len_of_this_module-1:
                self.slider_tuning_without_updating_the_plot.value = False
            else:
                self.slider_tuning_without_updating_the_plot.value = True
            self.mode_to_slider_dictionary[mode].value = input_value * weight + U_matrix_sample_original[mode]
        
    def initialize_sliders1(self):
        sample1 = self.sample_seed_input1.value
        self.customized_U1_sample = self.U_matrix_to_use[sample1,:].copy()

        for idx, slider1 in enumerate(self.sliders1):
            mode = self.slider_to_mode_dictionary[slider1]
            if idx == len(self.sliders1)-1:
                self.slider_tuning_without_updating_the_plot.value = False
            else:
                self.slider_tuning_without_updating_the_plot.value = True
            slider1.value = self.customized_U1_sample[mode]
        
        self.customized_U1_sample_original = self.U_matrix_to_use[sample1,:].copy()
        self.plot_static1(self.customized_U1_sample_original)
        self.plot_dynamic1(self.customized_U1_sample)

    def initialize_module_sliders(self):
        for module_slider in self.module_sliders:
            module_slider.value = 0

    def reset1(self, b):
        self.initialize_sliders1()
        self.initialize_module_sliders()

    def resample(self,change):
        sample_to_draw=[i for i in range(self.U_matrix_to_use.shape[0])]
        sample_to_draw.remove(self.sample_seed_input1.value)
        self.sample_seed_input1.value = random.sample(sample_to_draw, 1)[0]
        self.initialize_sliders1()
        self.initialize_module_sliders()


In [17]:

U_matrix_to_use = concised_U
S_matrix_to_use = concised_S
V_matrix_to_use = concised_V
start_sample = 0
modes_to_tune=[7,8,9,25,26,27,30,31,32,33,35,78] 
image_width = 64
image_height = 64

module_of_modes = [[7,8,9],[33,78],[25,26,27,56,71]]
module_of_modes_weights = [[-1.2,-1.2,1.2],[1.5,-1.1],[1,1,1,1.5,-1.1]]


mvp_1 = mvp(image_width, image_height, start_sample, U_matrix_to_use, S_matrix_to_use, V_matrix_to_use, module_of_modes, module_of_modes_weights)
mvp_1.display_GUI()

HBox(children=(VBox(children=(HBox(children=(Output(), Output())), HBox(children=(IntText(value=0, descriptionâ€¦