In [1]:
import os
import pickle
import sys
from glob import glob

import ipywidgets as ipw
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import Audio

# Make the project packages (one directory up) importable before importing them
sys.path.append("..")

from imitative_agent import ImitativeAgent
from lib.dataset_wrapper import Dataset
from lib.notebooks import show_ema
from external import lpcynet

current path: /mnt/c/Users/vpaul/OneDrive - CentraleSupelec/Inner_Speech/agent/imitative_agent


In [8]:
# Get the paths of all imitative agents stored in a specific directory
agents_path = sorted(glob("../out/imitative_agent1/*/"))

print(f"Found {len(agents_path)} agents")

# Dictionary mapping a human-readable agent alias to the agent's directory path
agents_alias = {}

for agent_path in agents_path:
    # Load agent configuration without neural networks for efficiency
    agent = ImitativeAgent.reload(agent_path, load_nn=False)
    config = agent.config

    # Agent identifier taken from the directory name rather than from a single
    # character of the path, so that multi-character identifiers are not
    # truncated. The suffix after the last "-" is used when there is one
    # (NOTE(review): assumes directories are named "<hash>-<index>" — confirm);
    # otherwise the whole directory name is used.
    agent_dirname = os.path.basename(os.path.normpath(agent_path))
    agent_i = agent_dirname.rsplit("-", 1)[-1]

    # Parameters shown both in the alias and in the printed summary
    dataset_names = config["dataset"]["names"]
    synth_art_type = agent.synthesizer.config["dataset"]["art_type"]
    jerk_loss_ceil = config["training"]["jerk_loss_ceil"]
    jerk_loss_weight = config["training"]["jerk_loss_weight"]
    bidirectional = config["model"]["inverse_model"]["bidirectional"]

    # Create descriptive alias string containing key agent parameters
    agent_alias = " ".join((
        f"{','.join(dataset_names)}",  # Dataset names
        f"synth_art={synth_art_type}",  # Articulatory features type
        f"jerk_c={jerk_loss_ceil}",  # Jerk loss ceiling
        f"jerk_w={jerk_loss_weight}",  # Jerk loss weight
        f"bi={bidirectional}",  # Bidirectional model flag
        f"({agent_i})",  # Agent identifier
    ))

    # Print agent information
    print(f"\nPath: {agent_path}:")
    print(f"- Datasets: {dataset_names}")
    print(f"- Synthesizer art type: {synth_art_type}")
    print(f"- Jerk loss ceiling: {jerk_loss_ceil}")
    print(f"- Jerk loss weight: {jerk_loss_weight}")
    print(f"- Bidirectional: {bidirectional}")

    # Two agents sharing an alias would silently hide one of them from the
    # selector, so make the overwrite visible.
    if agent_alias in agents_alias:
        print(
            f"Warning: alias '{agent_alias}' was already used by "
            f"{agents_alias[agent_alias]} and now points to {agent_path}"
        )

    # Store mapping between alias and path
    agents_alias[agent_alias] = agent_path

{'dataset': {'batch_size': 8, 'datasplits_size': [64, 16, 20], 'names': ['pb2007'], 'num_workers': 6, 'shuffle_between_epochs': True, 'sound_type': 'cepstrum'}, 'model': {'direct_model': {'activation': 'relu', 'batch_norm': True, 'dropout_p': 0.25, 'hidden_layers': [256, 256, 256, 256]}, 'inverse_model': {'bidirectional': True, 'dropout_p': 0.25, 'hidden_size': 32, 'num_layers': 2}}, 'synthesizer': {'name': 'ea587b76c95fecef01cfd16c7f5f289d-0/'}, 'training': {'jerk_loss_ceil': 0.014, 'jerk_loss_weight': 1, 'vel_loss_ceil': 0.014, 'vel_loss_weight': 1, 'learning_rate': 0.001, 'max_epochs': 500, 'patience': 25}}
{'direct_model': {'activation': 'relu', 'batch_norm': True, 'dropout_p': 0.25, 'hidden_layers': [256, 256, 256, 256]}, 'inverse_model': {'bidirectional': True, 'dropout_p': 0.25, 'hidden_size': 32, 'num_layers': 2}}
{'direct_model': {'activation': 'relu', 'batch_norm': True, 'dropout_p': 0.25, 'hidden_layers': [256, 256, 256, 256]}, 'inverse_model': {'bidirectional': True, 'dropo

In [10]:
# Per-dataset memory of the utterance that was selected last, keyed by dataset name
datasets_current_item = dict()

def show_agent(agent_alias):
    """
    Display interactive widgets for listening to and inspecting an agent's repetitions.

    The agent registered under `agent_alias` in the module-level `agents_alias`
    mapping is reloaded, then a dataset selector is displayed. Selecting a
    dataset displays an utterance selector, and selecting an utterance plays
    and plots the original, repeated and estimated sounds.

    Args:
        agent_alias (str): Key of `agents_alias` identifying the agent to visualize

    Returns:
        None. The widgets are rendered through `display` as a side effect.
    """
    # Load the agent and get key configuration parameters
    agent_path = agents_alias[agent_alias]
    agent = ImitativeAgent.reload(agent_path)

    sound_type = agent.config["dataset"]["sound_type"]
    art_type = agent.synthesizer.config["dataset"]["art_type"]
    synth_dataset = agent.synthesizer.dataset

    def show_dataset(dataset_name):
        """
        Display the utterance selector for one dataset.

        Args:
            dataset_name (str): Name of the dataset to visualize
        """
        # Load dataset and extract the features used below
        dataset = Dataset(dataset_name)
        items_cepstrum = dataset.get_items_data(sound_type, cut_silences=True)
        items_source = dataset.get_items_data("source", cut_silences=True)
        items_ema = dataset.get_items_data("ema", cut_silences=True)
        sampling_rate = dataset.features_config["wav_sampling_rate"]

        # Start from the utterance remembered for this dataset, if any,
        # otherwise from the first entry of the items list
        items_name = dataset.get_items_list()
        if dataset_name in datasets_current_item:
            current_item = datasets_current_item[dataset_name]
        else:
            current_item = items_name[0][1]

        def resynth_item(item_name=current_item, freeze_source=False):
            """
            Resynthesize one utterance and show original, repeated and estimated versions.

            Args:
                item_name (str): Name of the utterance to process
                freeze_source (bool): When True, every frame's source features
                    are replaced by the constant pair (1, 0) before synthesis
            """
            # Remember the selection so it is restored the next time this dataset is shown
            datasets_current_item[dataset_name] = item_name

            # Load item data
            item_cepstrum = items_cepstrum[item_name]
            item_source = items_source[item_name]
            item_wave = dataset.get_item_wave(item_name)
            nb_frames = len(item_cepstrum)

            # Generate repetition using the agent
            repetition = agent.repeat(item_cepstrum)
            repeated_cepstrum = repetition["sound_repeated"]    # Via synthesizer
            estimated_cepstrum = repetition["sound_estimated"]  # Via direct model
            estimated_art = repetition["art_estimated"]

            # Optionally freeze source features. Work on a copy: writing into
            # the array held by `items_source` would keep the item frozen for
            # every later call, even with freeze_source=False.
            if freeze_source:
                item_source = np.array(item_source, copy=True)
                item_source[:] = (1, 0)

            # Combine cepstral coefficients with source features for synthesis
            repeated_sound = np.concatenate((repeated_cepstrum, item_source), axis=1)
            estimated_sound = np.concatenate((estimated_cepstrum, item_source), axis=1)

            # Convert to waveforms using LPCNet
            repeated_wave = lpcynet.synthesize_frames(repeated_sound)
            estimated_wave = lpcynet.synthesize_frames(estimated_sound)

            # Display audio players for comparison
            print("Original sound:")
            display(Audio(item_wave, rate=sampling_rate))
            print("Repetition (Inverse model → Synthesizer → LPCNet):")
            display(Audio(repeated_wave, rate=sampling_rate))
            print("Estimation (Inverse model → Direct model → LPCNet):")
            display(Audio(estimated_wave, rate=sampling_rate))

            # Plot the three sounds one above the other; figure width grows
            # with the number of frames
            plt.figure(figsize=(nb_frames/20, 6), dpi=120)

            ax = plt.subplot(311)
            ax.set_title("original %s" % (sound_type))
            ax.imshow(item_cepstrum.T, origin="lower")

            ax = plt.subplot(312)
            ax.set_title("Repetition")
            ax.imshow(repeated_cepstrum.T, origin="lower")

            ax = plt.subplot(313)
            ax.set_title("Estimation")
            ax.imshow(estimated_cepstrum.T, origin="lower")

            plt.tight_layout()
            plt.show()

            # Convert articulatory parameters to EMA coordinates if needed
            if art_type == "art_params":
                estimated_art = dataset.art_to_ema(estimated_art)
            item_ema = items_ema[item_name]
            show_ema(estimated_art, reference=item_ema, dataset=synth_dataset)

        # Create interactive widget for utterance selection
        display(ipw.interactive(resynth_item, item_name=items_name, freeze_source=False))

    # Create interactive widget for dataset selection
    display(ipw.interactive(show_dataset, dataset_name=agent.config["dataset"]["names"]))

# Entry point of the viewer: a selector over the agent aliases (sorted) that drives show_agent
display(ipw.interactive(show_agent, agent_alias=sorted(agents_alias)))

interactive(children=(Dropdown(description='agent_alias', options=('pb2007 synth_art=art_params jerk_c=0.014 j…