In [1]:
!pip install -q seaborn matplotlib pandas

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m24.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [5]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

def load_experiment_data(experiment_name, num_meta_loops):
    all_data = []
    for loop in range(num_meta_loops):
        loop_dir = os.path.join(f"{experiment_name}_models", f"loop_{loop}")
        csv_path = os.path.join(loop_dir, "results.csv")
        if os.path.exists(csv_path):
            df = pd.read_csv(csv_path)
            df['Loop'] = loop
            all_data.append(df)
        print(csv_path)
    if all_data:
        return pd.concat(all_data, ignore_index=True)
    else:
        print("No data found.")
        return None

def plot_learning_rate_vs_correct_count(df):
    plt.figure(figsize=(12, 6))
    sns.scatterplot(data=df, x='Learning Rate', y='Correct Count', hue='Loop', palette='viridis', s=100)
    plt.xscale('log')
    plt.title('Learning Rate vs Correct Count')
    plt.xlabel('Learning Rate')
    plt.ylabel('Correct Count')
    plt.legend(title='Loop', bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True)
    plt.tight_layout()
    plt.show()

def plot_learning_rate_vs_train_loss(df):
    plt.figure(figsize=(12, 6))
    sns.lineplot(data=df, x='Learning Rate', y='Train Loss', hue='Loop', marker='o', palette='viridis', linewidth=2.5)
    plt.xscale('log')
    plt.title('Learning Rate vs Train Loss')
    plt.xlabel('Learning Rate')
    plt.ylabel('Train Loss')
    plt.legend(title='Loop', bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True)
    plt.tight_layout()
    plt.show()

def plot_learning_rate_vs_perplexity(df):
    plt.figure(figsize=(12, 6))
    sns.lineplot(data=df, x='Learning Rate', y='Perplexity', hue='Loop', marker='o', palette='viridis', linewidth=2.5)
    plt.xscale('log')
    plt.title('Learning Rate vs Perplexity')
    plt.xlabel('Learning Rate')
    plt.ylabel('Perplexity')
    plt.legend(title='Loop', bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True)
    plt.tight_layout()
    plt.show()

def plot_correct_count_vs_train_loss(df):
    plt.figure(figsize=(12, 6))
    sns.scatterplot(data=df, x='Correct Count', y='Train Loss', hue='Loop', palette='viridis', s=100)
    plt.title('Correct Count vs Train Loss')
    plt.xlabel('Correct Count')
    plt.ylabel('Train Loss')
    plt.legend(title='Loop', bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True)
    plt.tight_layout()
    plt.show()

def plot_correct_count_vs_perplexity(df):
    plt.figure(figsize=(12, 6))
    sns.scatterplot(data=df, x='Correct Count', y='Perplexity', hue='Loop', palette='viridis', s=100)
    plt.title('Correct Count vs Perplexity')
    plt.xlabel('Correct Count')
    plt.ylabel('Perplexity')
    plt.legend(title='Loop', bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True)
    plt.tight_layout()
    plt.show()

def plot_pairwise_relationships(df):
    sns.pairplot(df, hue='Loop', palette='viridis', markers='o')
    plt.suptitle('Pairwise Relationships', y=1.02)
    plt.tight_layout()
    plt.show()

def plot_histograms(df):
    df.hist(bins=20, figsize=(15, 10), grid=False)
    plt.suptitle('Histograms of Variables', y=1.02)
    plt.tight_layout()
    plt.show()

def plot_learning_rate_vs_metrics(df):
    metrics = ['Correct Count', 'Train Loss', 'Perplexity']
    plt.figure(figsize=(15, 10))
    for i, metric in enumerate(metrics):
        plt.subplot(3, 1, i+1)
        sns.lineplot(data=df, x='Learning Rate', y=metric, hue='Loop', marker='o', palette='viridis', linewidth=2.5)
        plt.xscale('log')
        plt.title(f'Learning Rate vs {metric}')
        plt.xlabel('Learning Rate')
        plt.ylabel(metric)
        plt.legend(title='Loop', bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.grid(True)
    plt.tight_layout()
    plt.show()

def summarize_experiment_data(df):
    summary = df.describe(include='all')
    print("Summary of Experiment Data:\n", summary)

def main():
    experiment_name = "/workspace/slice-monorepo/sub_validations/HT_LR_predictor_studier/50_100_high_learning_models"
    num_meta_loops = 25

    df = load_experiment_data(experiment_name, num_meta_loops)
    if df is not None:
        summarize_experiment_data(df)
        plot_learning_rate_vs_correct_count(df)
        plot_learning_rate_vs_train_loss(df)
        plot_learning_rate_vs_perplexity(df)
        plot_correct_count_vs_train_loss(df)
        plot_correct_count_vs_perplexity(df)
        plot_pairwise_relationships(df)
        plot_histograms(df)
        plot_learning_rate_vs_metrics(df)

if __name__ == "__main__":
    main()


/workspace/slice-monorepo/sub_validations/HT_LR_predictor_studier/50_100_high_learning_models_models/loop_0/results.csv
/workspace/slice-monorepo/sub_validations/HT_LR_predictor_studier/50_100_high_learning_models_models/loop_1/results.csv
/workspace/slice-monorepo/sub_validations/HT_LR_predictor_studier/50_100_high_learning_models_models/loop_2/results.csv
/workspace/slice-monorepo/sub_validations/HT_LR_predictor_studier/50_100_high_learning_models_models/loop_3/results.csv
/workspace/slice-monorepo/sub_validations/HT_LR_predictor_studier/50_100_high_learning_models_models/loop_4/results.csv
/workspace/slice-monorepo/sub_validations/HT_LR_predictor_studier/50_100_high_learning_models_models/loop_5/results.csv
/workspace/slice-monorepo/sub_validations/HT_LR_predictor_studier/50_100_high_learning_models_models/loop_6/results.csv
/workspace/slice-monorepo/sub_validations/HT_LR_predictor_studier/50_100_high_learning_models_models/loop_7/results.csv
/workspace/slice-monorepo/sub_validation