In [None]:
'''
Libraries
'''
import os, csv, time
import matplotlib.pyplot as plt
import numpy as np


# -----Split here-----


# Data paths: every entry under `basepath` is treated as the name of an
# algorithm whose result files live in the sub-directory of that name.
basepath = '../data/'
algorithms = [entry for entry in os.listdir(basepath)]


# -----Split here-----


# Results container: algorithm name -> array of per-run return curves.
data = dict()
# A single 12x8-inch figure with one set of axes that every curve is drawn on.
figure, axes = plt.subplots(
    figsize=(12, 8),
)


# -----Split here-----


# Read the per-run reward CSV files of each algorithm, turn every run's reward
# sequence into a running return, and store all runs of one algorithm in
# `data[alg]` as a 2-D array with one row per run.
# The first line of each CSV file is a description and is skipped.
for alg in algorithms:
    # Only the Expected Sarsa results are processed in this notebook.
    if alg != 'esarsa':
        continue
    alg_dir = os.path.join(basepath, alg)

    # Each file whose name contains 'rewards' holds the rewards of one run.
    rewards_files = [name for name in os.listdir(alg_dir) if 'rewards' in name]
    numruns = len(rewards_files)

    runs = []
    for filename in rewards_files:
        # `with` guarantees the handle is closed; newline='' is what the csv
        # module expects for files it reads.
        with open(os.path.join(alg_dir, filename), newline='') as rewards_file:
            rewards_csv = csv.reader(rewards_file)
            next(rewards_csv)  # skips the description line
            rewards_list = np.array(
                [float(reward) for row in rewards_csv for reward in row]
            )

        # Return at step t = sum of rewards up to and including t. cumsum keeps
        # the float dtype (an integer buffer would truncate fractional rewards)
        # and also handles a file with no reward rows.
        runs.append(np.cumsum(rewards_list))

    if not runs:
        # No reward files for this algorithm: nothing to aggregate or plot.
        continue

    # Always 2-D (runs x timesteps), even for a single run, so statistics over
    # axis 0 are per-timestep. All runs must have the same length.
    data[str(alg)] = np.vstack(runs)

    
# -----Split here-----


# For each algorithm, compute the per-timestep mean return across runs and a
# normal-approximation 95% confidence interval of that mean, then plot the
# mean curve with the interval as a shaded band. Two median variants are kept
# below as disabled code.
for alg, returns_list_files in data.items():
    # Accept a single run stored as a 1-D array by viewing it as one row.
    run_returns = np.atleast_2d(returns_list_files)
    if run_returns.size == 0:
        continue  # nothing was recorded for this algorithm

    # Run count of *this* algorithm, not a value left over from loading.
    numruns = run_returns.shape[0]

    mean_returns_list = np.mean(run_returns, axis=0)
    # NOTE(review): population std (ddof=0); a sample std (ddof=1) would give
    # a slightly wider interval for small numbers of runs.
    stddev_returns_list = np.std(run_returns, axis=0)

    # Half-width of the 95% CI of the mean: 1.96 * std / sqrt(n).
    CI95_returns_list = 1.96 * stddev_returns_list / np.sqrt(numruns)
    min_CI95_returns_list = mean_returns_list - CI95_returns_list
    max_CI95_returns_list = mean_returns_list + CI95_returns_list

    lenrun = len(mean_returns_list)
    xaxis = np.arange(1, lenrun + 1)  # timesteps are numbered from 1

    axes.plot(xaxis, mean_returns_list, label=str(alg) + '-mean')
    axes.fill_between(xaxis, min_CI95_returns_list, max_CI95_returns_list, alpha=0.25)

    # Disabled: median across runs at every timestep. This curve is generally
    # not the return curve of any single run.
    # median_returns_list = np.median(run_returns, axis=0)
    # axes.plot(xaxis, median_returns_list, label=str(alg) + '-median', color='black')

    # Disabled: the median run, selected by final return. With an even number
    # of runs the two middle runs are averaged.
    # indices = np.argsort(run_returns[:, -1])
    # if numruns % 2 == 1:
    #     median_returns_list = run_returns[indices[(numruns - 1) // 2]]
    # else:
    #     median_returns_list = (run_returns[indices[numruns // 2]]
    #                            + run_returns[indices[numruns // 2 - 1]]) / 2.0
    # axes.plot(xaxis, median_returns_list, label=str(alg) + '-median', color='black')

    
# -----Split here-----


# Plot cosmetics: axis labels, legend, tick label sizes and figure margins.
# The y label is drawn horizontally (rotation=0), hence its larger padding.
axes.set_xlabel('Timesteps', labelpad=35, size=15)
axes.set_ylabel('Return', labelpad=45, rotation=0, size=15)
axes.legend(fontsize=12, loc='best')
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.tight_layout()
# Explicit margins applied after tight_layout; these values take precedence.
plt.subplots_adjust(
    left=0.2,
    right=0.9,
    bottom=0.2,
    top=0.85,
    wspace=0.45,
    hspace=0.5,
)


# -----Split here-----


# Save the current figure at 500 dpi before displaying it; saving after
# show() can produce an empty image once the window has been closed.
# NOTE(review): the file is named "median" although only mean curves are
# currently drawn — confirm the intended output name.
plt.savefig(fname='esarsa-median.png', dpi=500)
plt.show()