Skip to content

Commit

Permalink
return plt object directly, pep8
Browse files Browse the repository at this point in the history
  • Loading branch information
GreatYYX committed Jul 18, 2018
1 parent 0d0c295 commit 2850133
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions rltk/evaluation/evaluation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from rltk.evaluation.trial import Trial
import matplotlib.pyplot as plt


class Evaluation(object):
def __init__(self, trial_list: list = None):
if not trial_list:
Expand All @@ -15,6 +16,7 @@ def add_trial(self, trial: Trial):
x (list): list of x coordinates
y (list): list of y coordinates
"""

def auc(self, x, y):
coords = sorted([(x[i], y[i]) for i in range(len(x))])
coords_reduced = dict()
Expand All @@ -41,7 +43,7 @@ def auc(self, x, y):
y1 = value
first = False
return [auc, list(coords_reduced.keys()), list(coords_reduced.values())]

def plot(self, parameter_list, label_max=False, label_min=False, auc_params=None, aoc_params=None):
"""
Args:
Expand Down Expand Up @@ -96,17 +98,19 @@ def plot(self, parameter_list, label_max=False, label_min=False, auc_params=None

if label_max:
global_max = max([(x[i], y[i]) for i in range(len(x))], key=lambda i: (i[1], -i[0]))
plt.annotate("(" + ("%.3f" % global_max[0]) + ", " + ("%.3f" % global_max[1]) + ")", xy = (global_max[0] - 0.1, global_max[1] + 0.05))

plt.annotate("(" + ("%.3f" % global_max[0]) + ", " + ("%.3f" % global_max[1]) + ")",
xy=(global_max[0] - 0.1, global_max[1] + 0.05))

if label_min:
global_min = min([(x[i], y[i]) for i in range(len(x))], key=lambda i: (i[1], -i[0]))
plt.annotate("(" + ("%.3f" % global_min[0]) + ", " + ("%.3f" % global_min[1]) + ")", xy = (global_min[0] - 0.1, global_min[1] - 0.05))
plt.annotate("(" + ("%.3f" % global_min[0]) + ", " + ("%.3f" % global_min[1]) + ")",
xy=(global_min[0] - 0.1, global_min[1] - 0.05))

if auc_params != None:
vals = self.auc(x, y)
auc = vals[0]
area_label = 'AUC: ' + ('%.5f' % auc)
plt.annotate(area_label, xy = (auc_params[0], auc_params[1]))
plt.annotate(area_label, xy=(auc_params[0], auc_params[1]))

if auc_params[2]:
x_vals = vals[1]
Expand All @@ -117,7 +121,7 @@ def plot(self, parameter_list, label_max=False, label_min=False, auc_params=None
vals = self.auc(x, y)
aoc = 1 - vals[0]
area_label = 'AOC: ' + ('%.5f' % aoc)
plt.annotate(area_label, xy = (aoc_params[0], aoc_params[1]))
plt.annotate(area_label, xy=(aoc_params[0], aoc_params[1]))

if aoc_params[2]:
x_vals = vals[1]
Expand All @@ -126,8 +130,8 @@ def plot(self, parameter_list, label_max=False, label_min=False, auc_params=None

plt.xlim(0, 1.05)
plt.ylim(0, 1.05)
plt.show()

return plt

def plot_precision_recall(self):
return self.plot([{
Expand Down

0 comments on commit 2850133

Please sign in to comment.