In [8]:
# Notebook setup: make the parent directory importable, then import the plotting,
# widget, audio and project modules used by the cells below.
import sys,os
sys.path.append("..")
print("current path:", os.getcwd())
from glob import glob
import matplotlib.pyplot as plt
import ipywidgets as ipw
from IPython.display import Audio
import numpy as np
import pickle  # NOTE(review): not used in the cells visible here — confirm before removing

from communicative_agent import CommunicativeAgent
from lib.dataset_wrapper import Dataset
from lib.notebooks import show_ema
from IPython.display import clear_output


print("current path:", os.getcwd())  # NOTE(review): duplicates the print a few lines above
# sys.path.insert(0, "/Users/ladislas/Desktop/motor_control_agent")
# NOTE(review): machine-specific absolute path, inserted with highest priority so that
# `external.lpcynet` resolves from it. The recorded cwd is under
# ".../OneDrive - CentraleSupelec/Inner_Speech/agent/", while this points at
# ".../Documents/Inner_Speech/agent/" — a different checkout may shadow the one
# reachable via "..". Consider making it configurable (e.g. an environment variable).
sys.path.insert(0, "/mnt/c/Users/vpaul/Documents/Inner_Speech/agent/")
from external import lpcynet
#from external import *

current path: /mnt/c/Users/vpaul/OneDrive - CentraleSupelec/Inner_Speech/agent/communicative_agent
current path: /mnt/c/Users/vpaul/OneDrive - CentraleSupelec/Inner_Speech/agent/communicative_agent


In [10]:
agents_path = glob("../out/communicative_agent/*/")
agents_path.sort()

print(f"Found {len(agents_path)} agents")

# Dictionary to store agent aliases mapped to their paths
agents_alias = {}

for agent_path in agents_path:
    # Load agent configuration without neural networks for efficiency
    agent = CommunicativeAgent.reload(agent_path, load_nn=False)
    config = agent.config
        
    # Get agent identifier from path
    agent_i = agent_path[-2] 
    
    # Handle nb_derivatives, always equal to 0 in our implementation
    try:
        nb_derivatives = config['model']['direct_model']['nb_derivatives']
    except:
        nb_derivatives = 0
        
    # Create descriptive alias string containing key agent parameters
    agent_alias = " ".join((
        f"{','.join(config['dataset']['names'])}",  # Dataset names
        f"synth_art={agent.synthesizer.config['dataset']['art_type']}", # Articulatory features type
        f"nd={nb_derivatives}", # Number of derivatives
        f"jerk={config['training']['jerk_loss_weight']}", # Jerk loss weight
        f"({agent_i})", # Agent identifier
    ))
    
    # Print agent information
    print(f"\nPath: {agent_path}:")
    print(f"- Datasets: {config['dataset']['names']}")
    print(f"- Synthesizer art type: {agent.synthesizer.config['dataset']['art_type']}")
    print(f"- Jerk loss weight: {config['training']['jerk_loss_weight']}")
    
    # Store mapping between alias and path
    agents_alias[agent_alias] = agent_path

Found 6 agents

Path: ../out/communicative_agent/5e95a73fee5902e08b8838c6778b10dd-0/:
- Datasets: ['pb2007']
- Synthesizer art type: art_params
- Jerk loss weight: 0.15

Path: ../out/communicative_agent/732af0286099098cd5488228100d9cb1-0/:
- Datasets: ['pb2007']
- Synthesizer art type: art_params
- Jerk loss weight: 0.15

Path: ../out/communicative_agent/884b552070dd7a21e9901ec1cdb5a1e5-0/:
- Datasets: ['pb2007']
- Synthesizer art type: art_params
- Jerk loss weight: 0.1

Path: ../out/communicative_agent/9be83c9471e7f0e7ec19c4bd0ac540f7-0/:
- Datasets: ['pb2007']
- Synthesizer art type: art_params
- Jerk loss weight: 0.15

Path: ../out/communicative_agent/cca7402d9866782e2bc60b6c2cffc9c5-0/:
- Datasets: ['pb2007']
- Synthesizer art type: art_params
- Jerk loss weight: 0.1

Path: ../out/communicative_agent/d1e8dc5c1a2b0275d89ac357066fa2d0-0/:
- Datasets: ['pb2007']
- Synthesizer art type: art_params
- Jerk loss weight: 0.1


In [12]:
def show_agent(agent_path):
    """
    Creates an interactive visualization for analyzing a communicative agent's speech synthesis and articulation.
    
    Args:
        agent_path (str): Path to the saved agent model
        
    Returns:
        Interactive widget displaying audio and visualizations of speech repetition
    """
    # Load the agent and get key configuration parameters
    agent = CommunicativeAgent.reload(agent_path)
    sound_type = agent.synthesizer.config["dataset"]["sound_type"]
    art_type = agent.synthesizer.config["dataset"]["art_type"] 
    synth_dataset = agent.synthesizer.dataset
    
    def show_dataset(dataset_name):
        """
        Creates interactive visualization for a specific dataset.
        
        Args:
            dataset_name (str): Name of the dataset to visualize
        """
        dataset = Dataset(dataset_name)
        
        # Load dataset and extract features
        items_cepstrum = dataset.get_items_data(sound_type, cut_silences=False)
        items_source = dataset.get_items_data("source", cut_silences=False)
        sampling_rate = dataset.features_config["wav_sampling_rate"]
        
        # items_ema = dataset.get_items_data("ema", cut_silences=True)
        
        items_name = dataset.get_items_list()
        
        def resynth_item(item_name):
            """
            Resynthesize and visualize a specific utterance, showing original, repeated and estimated versions.
            
            Args:
                item_name (str): Name of the utterance to process
            """
            # Clear any existing plots
            plt.close('all')
            
            # Load item data
            item_cepstrum = items_cepstrum[item_name]
            item_source = items_source[item_name]
            item_wave = dataset.get_item_wave(item_name)
            nb_frames = len(item_cepstrum)
            
            # Generate repetition using the agent
            repetition = agent.repeat(item_cepstrum)
            repeated_cepstrum = repetition["sound_repeated"]    # Via synthesizer
            estimated_cepstrum = repetition["sound_estimated"]  # Via direct model
            estimated_art = repetition["art_estimated"]
            
            # Combine cepstral coefficients with source features for synthesis
            repeated_sound = np.concatenate((repeated_cepstrum, item_source), axis=1)
            estimated_sound = np.concatenate((estimated_cepstrum, item_source), axis=1)

            # Convert to waveforms using LPCNet
            repeated_wave = lpcynet.synthesize_frames(repeated_sound)
            estimated_wave = lpcynet.synthesize_frames(estimated_sound)
            
            # Display audio players for comparison
            print("Original sound:")
            display(Audio(item_wave, rate=sampling_rate))
            print("Repetition (Inverse model → Synthesizer → LPCNet):")
            display(Audio(repeated_wave, rate=sampling_rate))
            print("Estimation (Inverse model → Direct model → LPCNet):")
            display(Audio(estimated_wave, rate=sampling_rate))
            
            # Create spectrogram visualizations
            fig = plt.figure(figsize=(nb_frames/20, 6), dpi=120)
            
            ax = plt.subplot(311)
            ax.set_title("original %s" % (sound_type))
            ax.imshow(item_cepstrum.T, origin="lower")
            
            ax = plt.subplot(312)
            ax.set_title("Repetition")
            ax.imshow(repeated_cepstrum.T, origin="lower")
            
            ax = plt.subplot(313)
            ax.set_title("Estimation")
            ax.imshow(estimated_cepstrum.T, origin="lower")
            
            plt.tight_layout()
            display(fig)
            clear_output(wait=True)

            # Convert articulatory parameters to EMA coordinates if needed
            if art_type == "art_params":
                estimated_art = synth_dataset.art_to_ema(estimated_art)
            # item_ema = items_ema[item_name]
            show_ema(estimated_art, reference=None, dataset=synth_dataset)
        
        # Create interactive widget for utterance selection
        display(ipw.interactive(resynth_item, item_name=items_name))
    
    # Create interactive widget for dataset selection
    display(ipw.interactive(show_dataset, dataset_name=agent.sound_quantizer.config["dataset"]["names"]))

# Create top-level interactive widget for agent selection: the dropdown shows the
# aliases (keys of agents_alias) and passes the matching agent directory to show_agent.
display(ipw.interactive(show_agent, agent_path=agents_alias))


interactive(children=(Dropdown(description='agent_path', options={'pb2007 synth_art=art_params nd=0 jerk=0.15 …