In [15]:
from lib.ekyn import *
from lib.env import *
from sage.utils import *
from sage.models import *
from sklearn.metrics import ConfusionMatrixDisplay,f1_score
import pandas as pd
import os

# Gather the saved training state of every experiment so the runs can be
# compared side by side: a hyper-parameter table plus the loss curves.

# Roots scanned for experiment folders. Append BEAST_EXPERIMENTS_PATH to this
# list to also include the runs stored under that root.
experiment_roots = [EXPERIMENTS_PATH]

# `start_time` values look like '2024_27_08_18_50_05', i.e.
# year_day_month_hour_minute_second. Sorting them as plain strings orders by
# day before month, which is not chronological once runs span several months.
START_TIME_FORMAT = '%Y_%d_%m_%H_%M_%S'

states = {}
for root in experiment_roots:
    for experiment in os.listdir(root):
        if experiment == '.Trash-1000':
            continue
        state_path = f'{root}/{experiment}/state.pt'
        if not os.path.isfile(state_path):
            # Stray files or runs that never wrote a checkpoint: nothing to load.
            continue
        # NOTE: weights_only=False unpickles arbitrary objects; only point this
        # at checkpoints you produced yourself.
        states[experiment] = torch.load(state_path, map_location='cpu', weights_only=False)

if not states:
    raise FileNotFoundError(f'no state.pt checkpoints found under {experiment_roots}')

pd.set_option('display.max_rows', 500)

# One row per experiment, newest run first. Timestamps that do not match
# START_TIME_FORMAT are coerced to NaT and end up at the bottom.
df = pd.DataFrame(list(states.values()))
df = df.sort_values(
    by='start_time',
    ascending=False,
    key=lambda col: pd.to_datetime(col, format=START_TIME_FORMAT, errors='coerce'),
)
df = df.reset_index(drop=True)

# Optional: restrict the comparison to a subset of runs (positions refer to the
# newest-first order established above).
# df = df.iloc[[0,2]]
# df = df.iloc[3:]
info_df = df[['start_time','best_dev_loss','dropout','lr','wd','batch_size','norm','widthi']]

# Table ranked by best dev loss (lowest first), then the loss curves of all runs.
display(info_df.sort_values(by='best_dev_loss'))
plot_loss_curves(df,moving_window_length=9,lstm=False)

Unnamed: 0,start_time,best_dev_loss,dropout,lr,wd,batch_size,norm,widthi
2,2024_27_08_18_50_05,0.231558,0.1,0.0003,0.01,512,layer,[64]
9,2024_27_08_17_01_30,0.231932,0.1,0.0003,0.01,512,layer,[64]
3,2024_27_08_18_42_50,0.247675,0.5,0.0003,0.01,512,layer,"[64, 128, 256, 512]"
6,2024_27_08_18_05_50,0.251379,0.1,0.003,0.01,512,layer,"[64, 128]"
8,2024_27_08_17_47_12,0.253714,0.1,0.0003,0.01,512,layer,[64]
5,2024_27_08_18_17_58,0.255064,0.1,0.0003,0.01,512,layer,"[64, 128, 256, 512]"
12,2024_27_08_15_31_27,0.267809,0.1,0.0003,0.01,512,batch,[4]
7,2024_27_08_17_52_00,0.271507,0.1,0.0003,0.01,512,layer,[4]
11,2024_27_08_16_35_33,0.287421,0.1,0.0003,0.01,512,layer,[4]
10,2024_27_08_17_00_13,0.288772,0.1,0.0003,0.01,512,layer,[4]


In [None]:
from sklearn.model_selection import train_test_split
from lib.ekyn import *

# Reload one trained model from its checkpoint, evaluate it on the test loader
# and show the row-normalised confusion matrix with the macro F1 in the title.
import copy  # not imported by name in the visible cells; make the dependency explicit

MODEL_ID = '2024_16_08_12_48_06'

# The checkpoint is loaded onto the CPU below, so fall back to CPU evaluation
# when no GPU is present instead of failing on a hard-coded 'cuda'.
# (assumes `evaluate` accepts any torch device string — TODO confirm)
eval_device = 'cuda' if torch.cuda.is_available() else 'cpu'

ekyn_ids = get_ekyn_ids()
train_ids, test_ids = train_test_split(ekyn_ids, test_size=.2, shuffle=True, random_state=0)
print(train_ids, test_ids)
# NOTE(review): this split is not passed to the dataloader helper below; verify
# that the helper builds its test loader from the same ids.

state = torch.load(f'{EXPERIMENTS_PATH}/{MODEL_ID}/state.pt', map_location='cpu', weights_only=False)

# Work on a copy so the object stored in `state` is left untouched, then
# restore the weights saved under 'best_model_wts'.
model = copy.deepcopy(state['model'])
model.load_state_dict(state['best_model_wts'])

trainloader, testloader = get_sequenced_dataloaders(
    batch_size=state['batch_size'],
    sequence_length=state['sequence_length'],
)
# trainloader,testloader = get_epoched_dataloaders(batch_size=state['batch_size'],robust=False)
loss, y_true, y_pred = evaluate(
    dataloader=testloader,
    model=model,
    criterion=state['criterion'],
    device=eval_device,
)

macro_f1 = f1_score(y_true, y_pred, average="macro")
cm_display = ConfusionMatrixDisplay.from_predictions(y_true, y_pred, normalize='true', colorbar=False)
# Title the axes the display actually drew on rather than going through
# pyplot's implicit "current axes" (pyplot is not imported in this cell).
cm_display.ax_.set_title(f'f1 : {macro_f1:.3f}')

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import f1_score

# Per-recording diagnostics for the held-out ids:
#   left  - ECDF of the per-row std of the 'PF' samples labelled as class index 1
#   right - macro F1 of the loaded model on each recording
# Relies on `test_ids`, `train_ids`, `state` and `model` defined by the
# evaluation cell that precedes this one.

SHOW_TRAIN_ECDFS = False    # set True to overlay the training ids as thin black curves
ECDF_XLIM = (0, .000150)    # x-range shown on the ECDF panel
eval_device = 'cuda' if torch.cuda.is_available() else 'cpu'  # avoid failing on CPU-only machines


def _class1_row_std(recording_id):
    """Return the per-row std (dim=1) of the 'PF' rows whose label argmax is 1."""
    X, y = load_ekyn_pt(id=recording_id, condition='PF')
    class1_rows = torch.where(y.argmax(axis=1) == 1)[0]
    row_std, _ = torch.std_mean(X[class1_rows], dim=1)
    return row_std


fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(14, 8))
ecdf_ax, f1_ax = axes

bar_labels = []
bar_heights = []
for recording_id in test_ids:
    sns.ecdfplot(_class1_row_std(recording_id), linewidth=2, label=recording_id, ax=ecdf_ax)

    dataloader = get_epoched_dataloader_for_ids(ids=[recording_id], robust=state['robust'])
    loss, y_true, y_pred = evaluate(
        dataloader=dataloader,
        model=model,
        criterion=state['criterion'],
        device=eval_device,
    )
    bar_labels.append(str(recording_id))
    bar_heights.append(float(f1_score(y_true, y_pred, average="macro")))

if SHOW_TRAIN_ECDFS:
    for recording_id in train_ids:
        sns.ecdfplot(_class1_row_std(recording_id), linewidth=.5, label=recording_id,
                     ax=ecdf_ax, color='black')

# Draw all bars in ONE call: a separate barplot call per id places every bar on
# categorical position 0, so the bars overlap and only the last tick label survives.
if bar_heights:
    sns.barplot(x=bar_labels, y=bar_heights, ax=f1_ax)

# Final adjustments
ecdf_ax.legend()
ecdf_ax.set_xlim(ECDF_XLIM)
ecdf_ax.set_xlabel('std')
ecdf_ax.set_title('ECDF of per-row std (class 1, PF)')
f1_ax.set_ylim([0, 1])
f1_ax.set_xlabel('id')
f1_ax.set_ylabel('macro F1')
f1_ax.set_title('Macro F1 per test id')

# Display the plot
plt.tight_layout()
plt.show()
