## Imports

In [None]:
%load_ext autoreload
%autoreload 2
import pandas as pd
import numpy as np
import os
import json
import re
import matplotlib.pyplot as plt

from datasets import load_dataset, Dataset

import dist_heatmap

# Analyse Results

## Files list

In [None]:
# Raw per-benchmark MoE analysis logs, keyed by benchmark name.
logs_files = {
    'arc': 'log_files/moe_all_layers_analysis_arc_en_he_500_20250921_090254.log',
    'mmlu': 'log_files/moe_all_layers_analysis_mmlu_en_he_500_20250921_123201.log',
    'gsm': 'log_files/moe_all_layers_analysis_gsm_en_he_500_20250921_115824.log',
    'copa': 'log_files/moe_all_layers_analysis_copa_en_he_500_20250921_125427.log',
    'nli': 'log_files/moe_all_layers_analysis_nli_en_he_460_20250925_154220.log',
    'ted': 'log_files/moe_all_layers_analysis_ted_he_en_chunks_20250925_161115.log',
}

# Each log has matching CSV summaries that share its file stem:
#   log_files/<stem>.log  ->  moe_csv_results/<stem>_<suffix>.csv
# Deriving the CSV paths from `logs_files` keeps the two lists from drifting
# apart when a run is redone with a new timestamp.
#
# Two kinds of CSV exist per benchmark:
#   'dist' - the average distribution of the router's output over experts.
#   'actv' - the average *activation* of experts, i.e. the mean of the
#            per-token 0/1 vector that has 1 at the top-8 experts and 0 at
#            the others.
experts_csv_suffixes = {
    'dist': 'avg_dist',
    'actv': 'avg_activation',
}

# Which kind of CSV to load; set to 'dist' to load the router distributions.
experts_csv_kind = 'actv'


def experts_csv_path(log_path, kind=experts_csv_kind):
    """Return the CSV summary path that pairs with the analysis log *log_path*.

    *kind* is a key of ``experts_csv_suffixes`` ('dist' or 'actv').
    """
    stem = os.path.splitext(os.path.basename(log_path))[0]
    return f'moe_csv_results/{stem}_{experts_csv_suffixes[kind]}.csv'


# Keys look like '<benchmark>_<kind>' (e.g. 'arc_actv'); values are the
# loaded DataFrames, in the same order as `logs_files`.
experts_dist_files = {
    f'{bench}_{experts_csv_kind}': pd.read_csv(experts_csv_path(log_path, experts_csv_kind))
    for bench, log_path in logs_files.items()
}

# Expert value columns in each CSV: Expert_0 ... Expert_127.
experts_columns = [f'Expert_{i}' for i in range(128)]

## General metrics

In [None]:
# Run to re-create the graphs in the document.
# Builds three stacked panels over the layer axis (MoE cosine distance,
# overlap 3, overlap 8), draws every benchmark log in `logs_files` onto them,
# and saves the result to plots/avg_moe_graphs.jpeg.
fig, axs = plt.subplots(3, 1, figsize=(10, 24))
num_layers = 48  # x-axis ticks are placed at 0, 2, ..., num_layers - 2

for bench_name, log_path in logs_files.items():
    # NOTE(review): presumably adds a line labelled `bench_name` to each of
    # the three axes - confirm against dist_heatmap.
    dist_heatmap.plot_moe_analysis_from_log(log_path, bench_name, axs)

# (title, y-limits) for each panel, top to bottom.
panel_specs = [
    ('Average MoE Cosine Distance per Layer', (0, 1)),
    ('Average MoE Overlap 3 per Layer', (0, 3)),
    ('Average MoE Overlap 8 per Layer', (0, 8)),
]

for ax, (title, ylim) in zip(axs, panel_specs):
    ax.set_xlabel('Layer Number', fontsize=16)
    ax.set_ylabel('Average Value', fontsize=16)
    ax.set_xticks(np.arange(0, num_layers, 2))
    ax.legend(fontsize=16, loc='upper left')
    ax.grid()
    ax.set_ylim(*ylim)
    ax.set_title(title, fontsize=20)

fig.tight_layout()

# Save this specific figure: plt.savefig writes whichever figure is current,
# which an external plotting helper could have changed. Create the output
# directory first so the save cannot fail on a fresh checkout.
os.makedirs('plots', exist_ok=True)
fig.savefig('plots/avg_moe_graphs.jpeg')
plt.show()

## Distribution by layer

In [None]:
# Window of layers to show. The next cell keeps rows with
# from_layer <= Layer <= to_layer.
from_layer, to_layer = 44, 48

In [None]:
# Run to plot the distributions.
# Figure height grows linearly with the layer window: 0.3 in. of padding plus
# m inches per layer (matplotlib figsize is in inches).
m = 4.3 / 44
figsize = (15, 0.3 + (to_layer - from_layer) * m)

# Slice every benchmark down to the layer window and track the largest expert
# value seen across all of them; it is passed as vmax to both comparison
# plots below.
v_lst = []
global_vmax = 0
for bnch, bench_df in experts_dist_files.items():
    print(bnch)
    # NOTE(review): `between` is inclusive on both ends (Layer == to_layer is
    # kept), while figsize allots (to_layer - from_layer) rows - confirm which
    # bound convention dist_heatmap expects.
    dist_v = bench_df[bench_df['Layer'].between(from_layer, to_layer)]
    vmax = dist_v[experts_columns].max().max()
    global_vmax = max(global_vmax, vmax)
    v_lst.append(dist_v)
    # Per-benchmark view (disabled):
    # dist_heatmap.plot_heatmap_experts(dist_v, vmax, figsize=figsize,
    #     showdiff=False, show_top=False,
    #     from_layer=from_layer, to_layer=to_layer)

# Reverse into the top-to-bottom order used for the stacked heatmaps.
v_lst.reverse()

# Same comparison plot once per language.
for lan in ('English', 'Hebrew'):
    dist_heatmap.comp_one_lan(
        v_lst,
        vmax=global_vmax,
        lan=lan,
        figsize=figsize,
        show_top=False,
        from_layer=from_layer,
        to_layer=to_layer,
    )