In [None]:
'''
Libraries
'''
import os, csv
import matplotlib.pyplot as plt
import numpy as np


# -----Split here-----


'''
Data paths
basepath holds one sub-directory per algorithm; each directory name is used as
that algorithm's label in the plot.
'''
basepath = '../data/'
# os.listdir returns entries in arbitrary order and includes plain files.
# Keep only directories (a stray file such as .DS_Store would otherwise make the
# per-algorithm os.listdir call fail), and sort so the legend/color order is
# deterministic across machines.
algorithms = sorted(
    entry for entry in os.listdir(basepath)
    if os.path.isdir(os.path.join(basepath, entry))
)


# -----Split here-----


'''
Define figure and axes
A single figure with one axes holds every algorithm's curve; `data` will map
each algorithm name to its per-run results.
'''
figure, axes = plt.subplots(nrows=1, ncols=1, figsize=(12, 8))
data = dict()


# -----Split here-----


'''
Reads the CSV files
For every algorithm directory, reads each file whose name contains 'rewards'
(expected to be rewards*.csv), skipping any cell equal to the header 'rewards'.
Each file is one run; its rewards are turned into cumulative returns.
Stores, per algorithm, a 2-D array of shape (number of runs, timesteps) in the
data dictionary. Algorithms with no rewards files are skipped.
'''
for alg in algorithms:
    alg_dir = os.path.join(basepath, alg)
    # Sorted so runs are stacked in a deterministic order.
    rewards_files = sorted(i for i in os.listdir(alg_dir) if 'rewards' in i)
    # Kept as a module-level name because other code may read it; it only
    # reflects the algorithm processed last.
    numruns = len(rewards_files)
    if numruns == 0:
        # Nothing to stack; an empty entry would break the statistics later.
        continue

    returns_per_run = []
    for file in rewards_files:
        # `with` guarantees the file is closed; newline='' is what csv expects.
        with open(os.path.join(alg_dir, file), newline='') as handle:
            rewards_list = np.array([float(reward)
                                     for row in csv.reader(handle)
                                     for reward in row
                                     if reward != 'rewards'])
        # Return at timestep t is the running sum of rewards up to t.
        returns_per_run.append(np.cumsum(rewards_list))

    # Always 2-D (runs x timesteps), even for a single run, so axis=0
    # statistics downstream average over runs. Like the original vstack, this
    # raises ValueError if runs have different lengths.
    data[str(alg)] = np.vstack(returns_per_run)

    
# -----Split here-----


'''
Calculates the mean performance across multiple runs
Calculates the 95% confidence interval
Plots the mean performance with the 95% confidence interval in a shaded region
'''
for alg, returns_list_files in data.items():
    # Promote a single run stored as 1-D to shape (1, timesteps) so that axis 0
    # is always the run axis.
    runs = np.atleast_2d(returns_list_files)
    if runs.size == 0:
        continue  # no data for this algorithm; nothing to plot
    # Run count of THIS algorithm. The module-level numruns only holds the count
    # of whichever algorithm was read last, which gives wrong intervals when
    # algorithms have different numbers of runs.
    n_runs = runs.shape[0]

    mean_returns_list = np.mean(runs, axis=0)
    # NOTE(review): np.std defaults to ddof=0 (population std); ddof=1 is the
    # usual choice for a confidence interval. Kept as-is to preserve plots.
    stddev_returns_list = np.std(runs, axis=0)

    # Normal-approximation 95% CI half-width: 1.96 * std / sqrt(n).
    CI95_returns_list = 1.96 * stddev_returns_list / np.sqrt(n_runs)
    min_CI95_returns_list = mean_returns_list - CI95_returns_list
    max_CI95_returns_list = mean_returns_list + CI95_returns_list

    lenrun = len(mean_returns_list)
    xaxis = np.arange(1, lenrun + 1)  # timesteps are 1-based on the plot

    line, = axes.plot(xaxis, mean_returns_list, label=str(alg))
    # Reuse the line's color so each shaded band matches its mean curve instead
    # of relying on a separate color cycle staying in step.
    axes.fill_between(xaxis, min_CI95_returns_list, max_CI95_returns_list,
                      color=line.get_color(), alpha=0.25)

    
# -----Split here-----


'''
Configuration settings of the plot
'''
axes.set_xlabel('Timesteps', size=15, labelpad=35)
axes.set_ylabel('Return', size=15, rotation=0, labelpad=45)
axes.legend(loc='best', fontsize=12)
# Configure the `axes`/`figure` objects directly rather than pyplot's "current
# figure": in a notebook with the inline backend (close_figures=True by default)
# the figure created in an earlier cell is closed and no longer current, so
# plt.xticks/plt.subplots_adjust would act on a new, blank figure.
axes.tick_params(axis='both', labelsize=14)
# Explicit margins. A preceding tight_layout() would be fully overridden by
# these six settings, so it is not called.
figure.subplots_adjust(left=0.15, right=0.9, bottom=0.2, top=0.85,
                       wspace=0.45, hspace=0.5)


    
# -----Split here-----


'''
Save and display the plot
'''
# Save the figure object that holds the curves, not pyplot's current figure:
# with a notebook inline backend the figure created in an earlier cell has been
# closed, so plt.savefig would write a different (empty) figure.
figure.savefig('comparison.png', dpi=500)
plt.show()