In [None]:
# import os
import math
import pandas as pd
import matplotlib.pyplot as plt
import glob
import re
from matplotlib.backends.backend_pdf import PdfPages

# Specify the path to your CSV files.  The result is sorted so the PDF pages
# come out in a reproducible order (glob's own ordering depends on the
# file system).
file_paths = sorted(glob.glob("./output_EGD_with_Adam_Optimizer(switch_targets)*"))


def method_title(path):
    """Return the optimizer name used in plot and file titles for ``path``.

    The name is picked by looking for the substring "EGD" or "PGD" in the
    path; "Undefined" is returned when neither occurs.
    """
    if "EGD" in path:
        return "Exponentiated Gradient Descent"
    if "PGD" in path:
        return "Projected Gradient Descent"
    return "Undefined"


# Set file's name to save plots (falls back to "Undefined" when the glob
# matched nothing instead of raising IndexError on file_paths[0])
save_title = method_title(file_paths[0]) if file_paths else "Undefined"

# Regular expression to find the substring between `)_` and `.csv`
pattern = r'\)_([^\.]+)\.csv'

# Columns every result file must provide
required_columns = ('epoch', 'unsafe_loss', 'safe_loss')

# Initialize dictionaries to store the data (keyed by epoch; read by the
# median cell further down)
epoch_data = {}
unsafe_loss_data = {}
safe_loss_data = {}


def extract_label(file_path):
    """Return the plot title for the run stored in ``file_path``.

    When ``pattern`` matches, the captured substring between `)_` and
    `.csv` is returned as is.  Otherwise the path without its `.csv`
    suffix is used, starting at the first `(` if there is one.
    """
    match = re.search(pattern, file_path)
    if match:
        return match.group(1)
    stem = file_path[:-len('.csv')] if file_path.endswith('.csv') else file_path
    start = stem.find('(')
    return stem[start:] if start != -1 else stem


def accumulate_losses(epochs, unsafe_loss, safe_loss):
    """Append one run's losses to the module-level per-epoch dictionaries.

    Iteration stops at the first NaN unsafe loss.  The NaN check happens
    before the epoch is registered, so an epoch never ends up with an
    entry that has no loss samples.
    """
    for epoch, u_loss, s_loss in zip(epochs, unsafe_loss, safe_loss):
        if math.isnan(u_loss):
            break
        epoch_data.setdefault(epoch, []).append(epoch)
        unsafe_loss_data.setdefault(epoch, []).append(u_loss)
        safe_loss_data.setdefault(epoch, []).append(s_loss)


def plot_run(pdf, epochs, unsafe_loss, safe_loss, title):
    """Draw both loss curves of one run, add the page to ``pdf`` and show it."""
    plt.figure(figsize=(6, 4))
    plt.plot(epochs, unsafe_loss, label='Unsafe Loss', color='blue')
    plt.plot(epochs, safe_loss, label='Safe Loss', color='red')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(title)
    plt.legend()
    plt.grid(True)

    # Save the current plot to the PDF file
    pdf.savefig()
    plt.show()
    plt.close()  # Close the current figure to release memory


pdf_title = save_title+" (Switching Targets).pdf"
csv_paths = [path for path in file_paths if path.endswith('.csv')]
if not csv_paths:
    print("No CSV files matched; nothing to plot.")
else:
    with PdfPages(pdf_title) as pdf:
        # Loop over each CSV file in the folder
        for file_path in csv_paths:
            # Read the CSV file
            df = pd.read_csv(file_path)
            # Skip files that lack any of the required columns
            if not all(col in df.columns for col in required_columns):
                print(f"File {file_path} does not contain the required columns.")
                continue
            epochs = df['epoch']
            unsafe_loss = df['unsafe_loss']
            safe_loss = df['safe_loss']
            # Extract label from the file name
            label = extract_label(file_path)

            accumulate_losses(epochs, unsafe_loss, safe_loss)

            # Plot the results
            plot_run(pdf, epochs, unsafe_loss, safe_loss, label)


In [None]:
import numpy as np
# Calculate the median losses for each epoch.  Epochs that have no recorded
# loss samples are left out: np.median of an empty list is NaN (with a
# RuntimeWarning) and would turn the max/min printed below into NaN too.
epochs = sorted(
    epoch for epoch in epoch_data
    if unsafe_loss_data.get(epoch) and safe_loss_data.get(epoch)
)
median_unsafe_loss = [np.median(unsafe_loss_data[epoch]) for epoch in epochs]
median_safe_loss = [np.median(safe_loss_data[epoch]) for epoch in epochs]

pdf_title = save_title+" (Switching Targets - Median).pdf"

if not epochs:
    # np.max / np.min raise ValueError on empty input, and an empty plot is
    # of no use, so stop here with a clear message instead.
    print("No loss data collected; nothing to summarise or plot.")
else:
    print("Unsafe Loss:", median_unsafe_loss)
    print("max:", np.max(median_unsafe_loss), end=", ")
    print("min:", np.min(median_unsafe_loss))
    print("Safe Loss:", median_safe_loss)
    print("max:", np.max(median_safe_loss), end=", ")
    print("min:", np.min(median_safe_loss))

    # Plot the results
    plt.figure(figsize=(6, 4))
    plt.plot(epochs, median_unsafe_loss, label='Median Unsafe Loss', color='blue')
    plt.plot(epochs, median_safe_loss, label='Median Safe Loss', color='red')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(label=save_title)
    plt.legend()
    plt.grid(True)
    plt.savefig(pdf_title, format="pdf", bbox_inches="tight")
    plt.show()