# Visualize thresholds

In [1]:
import json
import matplotlib.pyplot as plt
from pathlib import Path

# Function to convert threshold names to numeric values
def convert_threshold_to_number(threshold_name):
    """Turn a threshold row name into a float that can be used as a sort key.

    Every occurrence of the ``threshold_`` marker is swapped for ``0.`` and the
    resulting text is parsed with ``float`` (``'threshold_02'`` gives ``0.02``).
    """
    numeric_text = '0.'.join(threshold_name.split('threshold_'))
    return float(numeric_text)

# Function to process a single JSON file and generate the plot
def process_json_file(file_path, output_dir):
    """Plot the per-threshold evaluation averages stored in one JSON file.

    The file must contain a ``test_rows`` list. Every row except the one named
    ``top_k`` becomes one x position; rows are ordered by the numeric value
    derived from their name. The figure is written to
    ``output_dir / music_evaluation_<parent folder name>.png``. When no row is
    left after excluding ``top_k`` nothing is plotted and no file is written.

    Args:
        file_path: Location (``str`` or ``Path``) of the evaluation JSON file.
            The name of its parent folder is used as the model name.
        output_dir: Directory (``str`` or ``Path``) that receives the PNG; it is
            created when missing.

    Raises:
        KeyError: If ``test_rows`` or a row's ``name``, ``harmony_average`` or
            ``notes_per_snapshot_average`` is missing.
        ValueError: If a row name cannot be converted to a number.
    """
    # Accept plain strings as well as Path objects, like the other plotting helpers do.
    file_path = Path(file_path)
    output_dir = Path(output_dir)
    model_name = file_path.parent.name

    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    # Keep every row except 'top_k', ordered by numeric threshold value (stable sort).
    rows = [entry for entry in data['test_rows'] if entry['name'] != "top_k"]
    rows.sort(key=lambda entry: convert_threshold_to_number(entry['name']))

    # Nothing to draw: skip instead of saving an empty figure.
    if not rows:
        print(f"No threshold rows found in {file_path}, skipping plot.")
        return

    thresholds = [entry['name'] for entry in rows]
    harmony_averages = [entry['harmony_average'] for entry in rows]
    notes_per_snapshot_averages = [entry['notes_per_snapshot_average'] for entry in rows]
    # The single_* values fall back to 0 when a row does not provide them.
    single_harmony_averages = [entry.get('single_harmony_average', 0) for entry in rows]
    single_disharmony_averages = [entry.get('single_disharmony_average', 0) for entry in rows]

    print(f"Creating plot for model {model_name}")

    # Plotting the data
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        ax.plot(thresholds, harmony_averages, marker='o', label='Overall Harmony Average', color='b')
        ax.plot(thresholds, notes_per_snapshot_averages, marker='o', label='Activity Average', color='g')
        ax.plot(thresholds, single_harmony_averages, marker='o', label='Harmony Average', color='r')
        ax.plot(thresholds, single_disharmony_averages, marker='o', label='Dissonance Average', color='m')

        ax.set_xlabel('Thresholds')
        ax.set_ylabel('Averages')
        ax.set_title(f'Impact of threshold adjustments on evaluation parameters - {model_name}')
        ax.legend()
        ax.grid(True)

        # Save plot as a PNG file in the specified output directory
        output_dir.mkdir(parents=True, exist_ok=True)  # Ensure the output directory exists
        output_file = output_dir / f"music_evaluation_{model_name}.png"
        fig.savefig(output_file)
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)

    print(f'Plot saved for {file_path} as {output_file}')

# Function to recursively search through a directory and process all JSON files
def process_all_json_files(input_dir, output_dir):
    """Create one threshold plot per evaluation file found below ``input_dir``.

    Every file matching ``*music_evaluation.json`` anywhere under ``input_dir``
    is handed to ``process_json_file`` together with ``output_dir``.
    """
    search_root = Path(input_dir)
    evaluation_files = search_root.rglob('*music_evaluation.json')
    for evaluation_file in evaluation_files:
        process_json_file(evaluation_file, output_dir)
        
# Example usage: folder holding the saved models and folder receiving the threshold plots
input_directory = "/home/falaxdb/Repos/minus1/transformer_decoder_training/saved_files/saved_models"
output_directory = Path("/home/falaxdb/Repos/minus1/transformer_decoder_training/saved_files/plots/music_scores_by_threshold")

# Walk the model folder and write one plot per evaluation file that is found
process_all_json_files(input_dir=input_directory, output_dir=output_directory)


Creating plot for model transformer_01
Plot saved for /home/falaxdb/Repos/minus1/transformer_decoder_training/saved_files/saved_models/transformer_01/music_evaluation.json as /home/falaxdb/Repos/minus1/transformer_decoder_training/saved_files/plots/music_scores_by_threshold/music_evaluation_transformer_01.png
Creating plot for model transformer_0.2
Plot saved for /home/falaxdb/Repos/minus1/transformer_decoder_training/saved_files/saved_models/transformer_0.2/music_evaluation.json as /home/falaxdb/Repos/minus1/transformer_decoder_training/saved_files/plots/music_scores_by_threshold/music_evaluation_transformer_0.2.png
Creating plot for model transformer_0.3
Plot saved for /home/falaxdb/Repos/minus1/transformer_decoder_training/saved_files/saved_models/transformer_0.3/music_evaluation.json as /home/falaxdb/Repos/minus1/transformer_decoder_training/saved_files/plots/music_scores_by_threshold/music_evaluation_transformer_0.3.png
Creating plot for model transformer_0.4
Plot saved for /home/

## Visualize all eval scores + loss and improvements

In [5]:
def collect_and_plot_scores(root_dir, project_names, threshold_name, output_dir):
    """Compare the evaluation averages of several models for one threshold row.

    For every name in ``project_names`` all ``music_evaluation.json`` files
    below ``root_dir / <name>`` are read, and each ``test_rows`` entry whose
    ``name`` equals ``threshold_name`` contributes one data point. The chart is
    saved as ``output_dir / music_evaluation_<threshold_name>.png``. If no
    matching row is found, a message is printed and nothing is saved.

    Args:
        root_dir: Directory that contains one sub-folder per project.
        project_names: Folder names of the projects to compare, in x-axis order.
        threshold_name: Value of the ``name`` field to select in ``test_rows``.
        output_dir: Directory (``str`` or ``Path``) for the PNG; created if missing.

    Raises:
        KeyError: If ``test_rows``, ``name``, ``harmony_average`` or
            ``notes_per_snapshot_average`` is missing in a file.
    """
    # Initialize lists to hold data for plotting
    selected_projects = []
    single_harmony_averages = []
    single_disharmony_averages = []
    harmony_averages = []
    notes_per_snapshot_averages = []

    # Iterate over each specified project name
    for project_name in project_names:
        print(f"Checking project: {project_name}")
        project_path = Path(root_dir) / project_name
        all_files = list(project_path.rglob("music_evaluation.json"))

        if not all_files:
            print(f"No 'music_evaluation.json' found in {project_name}")
            continue

        for json_path in all_files:
            # Only keep the file open while it is being parsed.
            with open(json_path, 'r', encoding='utf-8') as file:
                data = json.load(file)

            # Iterate through the test rows to find the specified threshold
            for test_row in data["test_rows"]:
                if test_row["name"] != threshold_name:
                    continue

                # The single_* values fall back to 0 when absent, matching how
                # process_json_file reads the same rows.
                single_harmony = test_row.get("single_harmony_average", 0)
                single_disharmony = test_row.get("single_disharmony_average", 0)
                harmony = test_row["harmony_average"]
                notes_per_snapshot = test_row["notes_per_snapshot_average"]

                selected_projects.append(project_name)
                single_harmony_averages.append(single_harmony)
                single_disharmony_averages.append(single_disharmony)
                harmony_averages.append(harmony)
                notes_per_snapshot_averages.append(notes_per_snapshot)

                print(f"Project: {project_name} has data for {threshold_name}")

                print(f"Project: {project_name}")
                print(f"Threshold: {test_row['name']}")
                print("Averages:")
                print(f"  Single Harmony Average: {single_harmony}")
                print(f"  Single Disharmony Average: {single_disharmony}")
                print(f"  Harmony Average: {harmony}")
                print(f"  Notes Per Snapshot Average: {notes_per_snapshot}")
                print("\n")

    # Ensure we have data to plot
    if len(selected_projects) == 0:
        print(f"No data found for threshold {threshold_name} in the specified projects.")
        return

    # Plotting the data
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        ax.plot(selected_projects, harmony_averages, marker='o', label='Overall Harmony Average', color='b')
        ax.plot(selected_projects, notes_per_snapshot_averages, marker='o', label='Activity Average', color='g')
        ax.plot(selected_projects, single_harmony_averages, marker='o', label='Harmony Average', color='r')
        ax.plot(selected_projects, single_disharmony_averages, marker='o', label='Dissonance Average', color='m')

        ax.set_xlabel('Models')
        ax.set_ylabel('Averages')
        ax.set_title('Impact of model configurations on evaluation parameters')
        ax.legend()
        ax.grid(True)

        # Save plot as a PNG file in the specified output directory
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)  # Ensure the output directory exists
        output_file = output_dir / f"music_evaluation_{threshold_name}.png"
        fig.savefig(output_file)
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)

    print(f'Plot saved as {output_file}')


# Example usage. Make sure each entry matches the project's folder name exactly.
project_names = ["transformer_1.0", "transformer_3.1", "transformer_3.2"]

# Compare the listed models on the "threshold_02" row and save the chart
collect_and_plot_scores(
    root_dir="/home/falaxdb/Repos/minus1/transformer_decoder_training/saved_files/saved_models",
    project_names=project_names,
    threshold_name="threshold_02",
    output_dir="/home/falaxdb/Repos/minus1/transformer_decoder_training/saved_files/plots/music_scores_by_model",
)


Checking project: transformer_1.0
Project: transformer_1.0 has data for threshold_02
Project: transformer_1.0
Threshold: threshold_02
Averages:
  Single Harmony Average: 1.25390625
  Single Disharmony Average: 0.494140625
  Harmony Average: 0.759765625
  Notes Per Snapshot Average: 0.41388916289160127


Checking project: transformer_3.1
Project: transformer_3.1 has data for threshold_02
Project: transformer_3.1
Threshold: threshold_02
Averages:
  Single Harmony Average: 1.9833984375
  Single Disharmony Average: 0.24988756465032605
  Harmony Average: 1.733510872849674
  Notes Per Snapshot Average: 0.2209772991898091


Checking project: transformer_3.2
Project: transformer_3.2 has data for threshold_02
Project: transformer_3.2
Threshold: threshold_02
Averages:
  Single Harmony Average: 1.2513020833333333
  Single Disharmony Average: 0.0
  Harmony Average: 1.2513020833333333
  Notes Per Snapshot Average: 0.25112874622493464


Plot saved as /home/falaxdb/Repos/minus1/transformer_decoder_tr

In [3]:
import transformer_decoder_training.analyze_results_functions as analyze

def plot_validation_loss_and_improvement(data, selected_project_names, output_dir):
    """Plot the best validation loss of the selected projects as a line chart.

    Despite the function and file name, only ``best_val_loss`` is drawn; the
    ``improvement`` field of the items is not read. The chart is saved as
    ``output_dir / validation_loss_and_improvement.png``. When none of the
    items belongs to a selected project, a message is printed and nothing is
    saved.

    Args:
        data: Iterable of dicts providing at least ``project_name`` and
            ``best_val_loss``.
        selected_project_names: Sequence of project names to include; its order
            defines the order on the x-axis.
        output_dir: Directory (``str`` or ``Path``) for the PNG; created if missing.

    Raises:
        KeyError: If an item lacks ``project_name`` or a selected item lacks
            ``best_val_loss``.
    """
    # Keep only the selected projects, arranged in the order they were requested
    filtered_data = sorted(
        (item for item in data if item['project_name'] in selected_project_names),
        key=lambda item: selected_project_names.index(item['project_name']),
    )

    # Nothing to draw: skip instead of saving an empty figure.
    if not filtered_data:
        print("No data found for the selected projects.")
        return

    # Values for plotting
    project_names = [item['project_name'] for item in filtered_data]
    best_val_losses = [item['best_val_loss'] for item in filtered_data]

    # Single-axis figure showing the best validation loss per model
    fig, ax1 = plt.subplots(figsize=(12, 6))
    try:
        ax1.plot(project_names, best_val_losses, marker='o', label='Best Validation Loss', color='b')
        ax1.set_xlabel('Models')
        ax1.set_ylabel('Best Validation Loss', color='b')
        ax1.tick_params(axis='y', labelcolor='b')
        ax1.grid(True)
        ax1.set_title('Best Validation Loss by Model')

        # Save plot as a PNG file in the specified output directory
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)  # Ensure the output directory exists
        output_file = output_dir / "validation_loss_and_improvement.png"
        fig.savefig(output_file)
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)

    print(f'Plot saved as {output_file}')
    

# Compute the validation-loss summary for the saved models (helper from analyze_results_functions)
improvements = analyze.compute_val_loss_improvement(
    "/home/falaxdb/Repos/minus1/transformer_decoder_training/saved_files/saved_models"
)

# Plot the best validation loss of the models listed in `project_names`,
# which must already have been defined by a previously executed cell
plot_validation_loss_and_improvement(
    data=improvements,
    selected_project_names=project_names,
    output_dir="/home/falaxdb/Repos/minus1/transformer_decoder_training/saved_files/plots/model_losses",
)

Found 23 projects
Project transformer_1.9: First Val Loss = 0.3000, Best Val Loss = 0.1295 (Epoch 179), Improvement = 56.82%
Project transformer_1.10: First Val Loss = 0.3236, Best Val Loss = 0.1675 (Epoch 200), Improvement = 48.24%
Project transformer_1.12: First Val Loss = 0.1270, Best Val Loss = 0.1003 (Epoch 17), Improvement = 20.95%
Project transformer_1.11: First Val Loss = 0.1237, Best Val Loss = 0.0982 (Epoch 9), Improvement = 20.64%
Project transformer_1.7: First Val Loss = 0.1740, Best Val Loss = 0.1400 (Epoch 63), Improvement = 19.54%
Project transformer_1.11_overfitted: First Val Loss = 0.1171, Best Val Loss = 0.0947 (Epoch 9), Improvement = 19.12%
Project transformer_2.0: First Val Loss = 0.0350, Best Val Loss = 0.0290 (Epoch 14), Improvement = 17.14%
Project transformer_1.1: First Val Loss = 0.1630, Best Val Loss = 0.1380 (Epoch 12), Improvement = 15.34%
Project transformer_1.0: First Val Loss = 0.1650, Best Val Loss = 0.1450 (Epoch 12), Improvement = 12.12%
Project trans

In [4]:
import json
from pathlib import Path
import matplotlib.pyplot as plt

def collect_and_plot_scores(root_dir, project_names, threshold_name, output_dir, loss_values):
    """Compare evaluation averages and validation loss across models for one threshold row.

    For every name in ``project_names`` all ``music_evaluation.json`` files
    below ``root_dir / <name>`` are read, and each ``test_rows`` entry whose
    ``name`` equals ``threshold_name`` contributes one data point. The averages
    are drawn on the left y-axis and ``loss_values`` on a second y-axis. The
    chart is saved as ``output_dir / music_evaluation_<threshold_name>.png``.
    If no matching row is found, a message is printed and nothing is saved.

    Args:
        root_dir: Directory that contains one sub-folder per project.
        project_names: Folder names of the projects to compare, in x-axis order.
        threshold_name: Value of the ``name`` field to select in ``test_rows``.
        output_dir: Directory (``str`` or ``Path``) for the PNG; created if missing.
        loss_values: One loss value per collected data point, in the same order
            as the data points (i.e. following ``project_names``).

    Raises:
        KeyError: If ``test_rows``, ``name``, ``harmony_average`` or
            ``notes_per_snapshot_average`` is missing in a file.
        ValueError: If the number of ``loss_values`` differs from the number of
            collected data points.
    """
    # Initialize lists to hold data for plotting
    selected_projects = []
    single_harmony_averages = []
    single_disharmony_averages = []
    harmony_averages = []
    notes_per_snapshot_averages = []

    # Iterate over each specified project name
    for project_name in project_names:
        print(f"Checking project: {project_name}")
        project_path = Path(root_dir) / project_name
        all_files = list(project_path.rglob("music_evaluation.json"))

        if not all_files:
            print(f"No 'music_evaluation.json' found in {project_name}")
            continue

        for json_path in all_files:
            # Only keep the file open while it is being parsed.
            with open(json_path, 'r', encoding='utf-8') as file:
                data = json.load(file)

            # Iterate through the test rows to find the specified threshold
            for test_row in data["test_rows"]:
                if test_row["name"] != threshold_name:
                    continue

                # The single_* values fall back to 0 when absent, matching how
                # process_json_file reads the same rows.
                single_harmony = test_row.get("single_harmony_average", 0)
                single_disharmony = test_row.get("single_disharmony_average", 0)
                harmony = test_row["harmony_average"]
                notes_per_snapshot = test_row["notes_per_snapshot_average"]

                selected_projects.append(project_name)
                single_harmony_averages.append(single_harmony)
                single_disharmony_averages.append(single_disharmony)
                harmony_averages.append(harmony)
                notes_per_snapshot_averages.append(notes_per_snapshot)

                print(f"Project: {project_name} has data for {threshold_name}")

                print(f"Project: {project_name}")
                print(f"Threshold: {test_row['name']}")
                print("Averages:")
                print(f"  Single Harmony Average: {single_harmony}")
                print(f"  Single Disharmony Average: {single_disharmony}")
                print(f"  Harmony Average: {harmony}")
                print(f"  Notes Per Snapshot Average: {notes_per_snapshot}")
                print("\n")

    # Ensure we have data to plot
    if len(selected_projects) == 0:
        print(f"No data found for threshold {threshold_name} in the specified projects.")
        return

    # Fail early with a clear message: a skipped project or a second evaluation
    # file would otherwise only surface as a shape error inside matplotlib.
    loss_values = list(loss_values)
    if len(loss_values) != len(selected_projects):
        raise ValueError(
            f"loss_values has {len(loss_values)} entries but {len(selected_projects)} "
            f"data points were collected for {threshold_name}; they must match one-to-one."
        )

    # Plotting the data with dual y-axes
    fig, ax1 = plt.subplots(figsize=(12, 6))
    try:
        ax1.plot(selected_projects, harmony_averages, marker='o', label='Overall Harmony Average', color='b')
        ax1.plot(selected_projects, notes_per_snapshot_averages, marker='o', label='Activity Average', color='g')
        ax1.plot(selected_projects, single_harmony_averages, marker='o', label='Harmony Average', color='r')
        ax1.plot(selected_projects, single_disharmony_averages, marker='o', label='Dissonance Average', color='m')

        ax1.set_xlabel('Models')
        ax1.set_ylabel('Averages')
        ax1.grid(True)

        # Create a second y-axis for the loss values
        ax2 = ax1.twinx()
        ax2.plot(selected_projects, loss_values, marker='o', label='Best Validation Loss', color='orange')
        ax2.set_ylabel('Best Validation Loss', color='orange')
        ax2.tick_params(axis='y', labelcolor='orange')

        # One combined legend for both axes (a separate ax1-only legend would just be replaced)
        lines_1, labels_1 = ax1.get_legend_handles_labels()
        lines_2, labels_2 = ax2.get_legend_handles_labels()
        ax1.legend(lines_1 + lines_2, labels_1 + labels_2, loc='upper center')

        plt.title('Impact of model configurations on evaluation parameters')

        # Save plot as a PNG file in the specified output directory
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)  # Ensure the output directory exists
        output_file = output_dir / f"music_evaluation_{threshold_name}.png"
        fig.savefig(output_file)
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)

    print(f'Plot saved as {output_file}')


# Example usage: models to compare (each entry is used as a folder name below the models directory)
project_names = ["transformer_1.0", "transformer_1.8", "transformer_1.12"]

# Disabled example call, kept for reference.
# NOTE(review): the comprehension below keeps the order of `improvements`, not of
# `project_names`, so the loss values may not line up with the plotted models - confirm
# before re-enabling.
# improvements = analyze.compute_val_loss_improvement("/home/falaxdb/Repos/minus1/transformer_decoder_training/saved_files/saved_models")
# loss_values = [item['best_val_loss'] for item in improvements if item['project_name'] in project_names]

# collect_and_plot_scores(
#    "/home/falaxdb/Repos/minus1/transformer_decoder_training/saved_files/saved_models",
#    project_names,
#    "threshold_02",
#    "/home/falaxdb/Repos/minus1/transformer_decoder_training/saved_files/plots/music_scores_by_model",
#    loss_values
#)
