In [145]:
import numpy as np
import pandas as pd
import os
import json

In [146]:
def parse_config_name(config):
    """Parse config name to extract variant, cost, and reg information"""
    # Handle basic EOT case
    if config.endswith('_EOT'):
        return 'EOT', config.split('_EOT')[0].split('_')[-1], 'nil'
    
    # Handle REOT cases
    if '_REOT' in config:
        parts = config.split('_REOT')
        cost_part = parts[0].split('_')[-1]  # e.g., 'g'
        reot_part = parts[1]  # e.g., '1_mc+lr'
        
        if '_' in reot_part:
            lam_part, reg_part = reot_part.split('_', 1)
            variant = f'REOT{lam_part} (lam: {get_lambda_value(lam_part)})'
            reg = reg_part if reg_part != 'nil' else 'nil'
        else:
            variant = f'REOT{reot_part} (lam: {get_lambda_value(reot_part)})'
            reg = 'nil'
        
        return variant, cost_part, reg
    
    # Handle EOT with lambda cases
    if '_EOT' in config:
        parts = config.split('_EOT')
        cost_part = parts[0].split('_')[-1]  # e.g., 'g+mc'
        eot_num = parts[1]  # e.g., '1', '2', etc.
        
        if eot_num:
            variant = f'EOT{eot_num} (lam: {get_lambda_value(eot_num)})'
        else:
            variant = 'EOT'
        
        return variant, cost_part, 'nil'
    
    return config, 'unknown', 'nil'

def get_lambda_value(num_str):
    """Convert EOT/REOT number to lambda value"""
    lambda_map = {
        '1': '0',
        '2': '0.2', 
        '3': '0.4',
        '4': '0.6',
        '5': '0.8',
        '6': '1'
    }
    return lambda_map.get(num_str, num_str)

def format_table_data(error_dict, configs_to_evaluate):
    """Format error dictionary into table data"""
    table_data = []
    
    for config in configs_to_evaluate:
        if config in error_dict:
            variant, cost, reg = parse_config_name(config)
            
            row = {
                'Variant': variant,
                'Cost': cost,
                'Reg': reg,
                'W. Wasserstein': error_dict[config]['weighted_wasserstein'],
                'Wasserstein': error_dict[config]['wasserstein'],
                'MMD': error_dict[config]['mmd'],
                'Energy': error_dict[config]['energy']
            }
            table_data.append(row)
    
    return table_data

def print_markdown_table(table_data, title):
    """Print data as markdown table"""
    if not table_data:
        print(f"No data for {title}")
        return
        
    print(f"\n### {title} Table\n")
    
    # Header
    headers = ['Variant', 'Cost', 'Reg', 'W. Wasserstein', 'Wasserstein', 'MMD', 'Energy']
    print('| ' + ' | '.join(headers) + ' |')
    print('|' + '|'.join(['-' * (len(h) + 2) for h in headers]) + '|')
    
    # Data rows
    for row in table_data:
        values = [str(row.get(h, '')) for h in headers]
        print('| ' + ' | '.join(values) + ' |')

def generate_tables(configs_to_evaluate, run_version, GSE_id='GSE232025', use_all_data=False, evaluate_on=-3, return_df=True):
    """Main function to generate tables"""
    
    base_folder = f'/Users/rssantanu/Desktop/codebase/constrained_FM/experiment_figures/use_all_data_{use_all_data}_{GSE_id}/{run_version}/'
    
    if not os.path.exists(base_folder):
        print(f"Error: Base folder does not exist: {base_folder}")
        return None, None
    
    all_configs = os.listdir(base_folder)
    IVP = {}
    next_step_prediction = {}
    
    # Load data
    for config in configs_to_evaluate:
        saved_configs = [c for c in all_configs if c.startswith(config)]
        for saved_config in saved_configs:
            folder_address = base_folder + saved_config + '/'
            IVP[saved_config] = {}
            next_step_prediction[saved_config] = {}
            
            try:
                IVP_dict = json.load(open(folder_address + 'IVP_error.json'))
                next_step_dict = json.load(open(folder_address + 'next_step_error.json'))

                # if IVP_dict['wasserstein'][-1] > 500 or IVP_dict['weighted_wasserstein'][-1] > 500:
                #     print(f"Error: Wasserstein is greater than 1000 for {saved_config}")
                #     continue

                # if next_step_dict['wasserstein'][-1] > 500 or next_step_dict['weighted_wasserstein'][-1] > 500:
                #     print(f"Error: Wasserstein is greater than 1000 for {saved_config}")
                #     continue

                IVP[saved_config]['wasserstein'] = IVP_dict['wasserstein']
                IVP[saved_config]['weighted_wasserstein'] = IVP_dict['weighted_wasserstein']
                IVP[saved_config]['mmd'] = IVP_dict['mmd']
                IVP[saved_config]['energy'] = IVP_dict['energy']
                
                next_step_prediction[saved_config]['wasserstein'] = next_step_dict['wasserstein']
                next_step_prediction[saved_config]['weighted_wasserstein'] = next_step_dict['weighted_wasserstein']
                next_step_prediction[saved_config]['mmd'] = next_step_dict['mmd']
                next_step_prediction[saved_config]['energy'] = next_step_dict['energy']
            except Exception as e:
                print(f"Error loading {folder_address}: {e}")
                continue
    
    # Process data
    error_dict_IVP = {config: {'wasserstein': None, 'weighted_wasserstein': None, 'mmd': None, 'energy': None} for config in configs_to_evaluate}
    error_dict_next_step = {config: {'wasserstein': None, 'weighted_wasserstein': None, 'mmd': None, 'energy': None} for config in configs_to_evaluate}
    
    for config in configs_to_evaluate:
        all_keys_ivp = IVP.keys()
        all_keys_next_step = next_step_prediction.keys()
        
        # Collect data for this config
        wasserstein_error_data_ivp = []
        weighted_wasserstein_error_data_ivp = []
        mmd_error_data_ivp = []
        energy_error_data_ivp = []
        
        wasserstein_error_data_next_step = []
        weighted_wasserstein_error_data_next_step = []
        mmd_error_data_next_step = []
        energy_error_data_next_step = []
        
        for k in [c for c in all_keys_ivp if c.startswith(config)]:
            wasserstein_error_data_ivp.append(IVP[k]['wasserstein'][evaluate_on])
            weighted_wasserstein_error_data_ivp.append(IVP[k]['weighted_wasserstein'][evaluate_on])
            mmd_error_data_ivp.append(IVP[k]['mmd'][evaluate_on])
            energy_error_data_ivp.append(IVP[k]['energy'][evaluate_on])
        
        for k in [c for c in all_keys_next_step if c.startswith(config)]:
            wasserstein_error_data_next_step.append(next_step_prediction[k]['wasserstein'][evaluate_on])
            weighted_wasserstein_error_data_next_step.append(next_step_prediction[k]['weighted_wasserstein'][evaluate_on])
            mmd_error_data_next_step.append(next_step_prediction[k]['mmd'][evaluate_on])
            energy_error_data_next_step.append(next_step_prediction[k]['energy'][evaluate_on])
        
        # Calculate means and stds
        if wasserstein_error_data_ivp:
            error_dict_IVP[config]['wasserstein'] = f'{np.mean(wasserstein_error_data_ivp):.5f} ± {np.std(wasserstein_error_data_ivp):.5f}'
            error_dict_IVP[config]['weighted_wasserstein'] = f'{np.mean(weighted_wasserstein_error_data_ivp):.5f} ± {np.std(weighted_wasserstein_error_data_ivp):.5f}'
            error_dict_IVP[config]['mmd'] = f'{np.mean(mmd_error_data_ivp):.5f} ± {np.std(mmd_error_data_ivp):.5f}'
            error_dict_IVP[config]['energy'] = f'{np.mean(energy_error_data_ivp):.5f} ± {np.std(energy_error_data_ivp):.5f}'
        
        if wasserstein_error_data_next_step:
            error_dict_next_step[config]['wasserstein'] = f'{np.mean(wasserstein_error_data_next_step):.5f} ± {np.std(wasserstein_error_data_next_step):.5f}'
            error_dict_next_step[config]['weighted_wasserstein'] = f'{np.mean(weighted_wasserstein_error_data_next_step):.5f} ± {np.std(weighted_wasserstein_error_data_next_step):.5f}'
            error_dict_next_step[config]['mmd'] = f'{np.mean(mmd_error_data_next_step):.5f} ± {np.std(mmd_error_data_next_step):.5f}'
            error_dict_next_step[config]['energy'] = f'{np.mean(energy_error_data_next_step):.5f} ± {np.std(energy_error_data_next_step):.5f}'
    
    if return_df:
        # Create DataFrames
        ivp_table_data = format_table_data(error_dict_IVP, configs_to_evaluate)
        next_step_table_data = format_table_data(error_dict_next_step, configs_to_evaluate)
        
        ivp_df = pd.DataFrame(ivp_table_data)
        next_step_df = pd.DataFrame(next_step_table_data)
        
        # Also print markdown tables for convenience
        print_markdown_table(ivp_table_data, "IVP")
        print_markdown_table(next_step_table_data, "Next-Step")
        
        return ivp_df, next_step_df
    else:
        # Format and print tables (original behavior)
        ivp_table_data = format_table_data(error_dict_IVP, configs_to_evaluate)
        next_step_table_data = format_table_data(error_dict_next_step, configs_to_evaluate)
        
        print_markdown_table(ivp_table_data, "IVP")
        print_markdown_table(next_step_table_data, "Next-Step")
        
        return error_dict_IVP, error_dict_next_step



In [147]:
### stereoseq
# post prior correction extrapolation; tstep: 4
# evaluate_on= -1


# configs_to_evaluate= ['vi_pca_C_g+mc_mEOT1', 'vi_pca_C_g+mc_mEOT2', 'vi_pca_C_g+mc_mEOT3']
# run_version = 'v_post_prior_correction_mEOT_w_g+mc'

# configs_to_evaluate= ['vi_pca_C_g+lr_mEOT1', 'vi_pca_C_g+lr_mEOT2']
# run_version = 'v_post_prior_correction_mEOT_w_g+lr'

# configs_to_evaluate= ['vi_pca_C_g+mc+lr_mEOT1', 'vi_pca_C_g+mc+lr_mEOT2', 'vi_pca_C_g+mc+lr_mEOT3']
# run_version = 'v_post_prior_correction_mEOT_w_g+mc+lr'

# configs_to_evaluate= ['vi_pca_C_g_mREOT1_lr']
# run_version = 'v_post_prior_correction_mREOT_w_C_g_reg_lr'

# configs_to_evaluate= ['vi_pca_C_g_mREOT1_mc']
# run_version = 'v_post_prior_correction_mREOT_w_C_g_reg_mc'

# configs_to_evaluate= ['vi_pca_C_g_mREOT1_mc+lr']
# run_version = 'v_post_prior_correction_mREOT_w_C_g_reg_mc+lr'



# post prior correction interpolation: tstep: 2
evaluate_on= -3


# configs_to_evaluate= ['vi_pca_C_g+mc_mEOT1', 'vi_pca_C_g+mc_mEOT2', 'vi_pca_C_g+mc_mEOT3']
# run_version = 'v_interp_post_prior_correction_mEOT_w_g+mc_ti2'

# configs_to_evaluate= ['vi_pca_C_g+lr_mEOT1', 'vi_pca_C_g+lr_mEOT2']
# run_version = 'v_interp_post_prior_correction_mEOT_w_g+lr_ti2'

# configs_to_evaluate= ['vi_pca_C_g+mc+lr_mEOT1', 'vi_pca_C_g+mc+lr_mEOT2', 'vi_pca_C_g+mc+lr_mEOT3']
# run_version = 'v_interp_post_prior_correction_mEOT_w_g+mc+lr_ti2'

# configs_to_evaluate= ['vi_pca_C_g_mREOT1_lr']
# run_version = 'v_interp_post_prior_correction_mREOT_w_C_g_reg_lr_ti2'

# configs_to_evaluate= ['vi_pca_C_g_mREOT1_mc']
# run_version = 'v_interp_post_prior_correction_mREOT_w_C_g_reg_mc_ti2'

configs_to_evaluate= ['vi_pca_C_g_mREOT1_mc+lr']
run_version = 'v_interp_post_prior_correction_mREOT_w_C_g_reg_mc+lr_ti2'





In [148]:
## generic parameters
GSE_id = 'GSE232025'
# GSE_id = 'GSE062025'
# GSE_id = 'GSE092025'

use_all_data = False

# Get DataFrames
ivp_df, next_step_df = generate_tables(
    configs_to_evaluate=configs_to_evaluate,
    run_version=run_version,
    GSE_id=GSE_id,
    use_all_data=use_all_data,
    evaluate_on=evaluate_on,
    return_df=True  # Set to True to get DataFrames
)


### IVP Table

| Variant | Cost | Reg | W. Wasserstein | Wasserstein | MMD | Energy |
|---------|------|-----|----------------|-------------|-----|--------|
| vi_pca_C_g_mREOT1_mc+lr | unknown | nil | 4.15434 ± 0.33990 | 4.38535 ± 0.41422 | 0.16697 ± 0.01580 | 32.36193 ± 4.21622 |

### Next-Step Table

| Variant | Cost | Reg | W. Wasserstein | Wasserstein | MMD | Energy |
|---------|------|-----|----------------|-------------|-----|--------|
| vi_pca_C_g_mREOT1_mc+lr | unknown | nil | 2.53461 ± 0.07981 | 2.40120 ± 0.12588 | 0.03792 ± 0.00276 | 10.56963 ± 1.00340 |


In [149]:
next_step_df

Unnamed: 0,Variant,Cost,Reg,W. Wasserstein,Wasserstein,MMD,Energy
0,vi_pca_C_g_mREOT1_mc+lr,unknown,nil,2.53461 ± 0.07981,2.40120 ± 0.12588,0.03792 ± 0.00276,10.56963 ± 1.00340


In [150]:
ivp_df

Unnamed: 0,Variant,Cost,Reg,W. Wasserstein,Wasserstein,MMD,Energy
0,vi_pca_C_g_mREOT1_mc+lr,unknown,nil,4.15434 ± 0.33990,4.38535 ± 0.41422,0.16697 ± 0.01580,32.36193 ± 4.21622
