Skip to content

Commit

Permalink
Merge pull request #15 from preinaj/save-predictions-in-evaluator
Browse files Browse the repository at this point in the history
New evaluator feature: save predictions
  • Loading branch information
pedrolarben committed Jan 18, 2022
2 parents e788e15 + ce274b3 commit 3671fff
Showing 1 changed file with 25 additions and 9 deletions.
34 changes: 25 additions & 9 deletions ADLStream/evaluation/base_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,16 @@ def evaluate(self):

def __init__(
self,
results_file="ADLStream.csv",
results_file=None,
predictions_file=None,
dataset_name=None,
show_plot=True,
plot_file=None,
ylabel="",
):
self.results_file = results_file
self.dataset_name = dataset_name
self.predictions_file = predictions_file
self.show_plot = show_plot
self.plot_file = plot_file
self.ylabel = ylabel
Expand All @@ -104,6 +106,14 @@ def _create_results_file(self):

def start(self):
self.visualizer.start()
if self.predictions_file:
self.predictions_file = open(self.predictions_file, "a")
if self.results_file:
self.results_file = open(self.results_file, "a")

def end(self):
self.predictions_file.close()
self.results_file.close()

@abstractmethod
def evaluate(self):
Expand All @@ -127,15 +137,19 @@ def evaluate(self):

def write_results(self, new_results, instances):
if self.results_file is not None:
with open(self.results_file, "a") as f:
for i, value in enumerate(new_results):
f.write(
"{},{},{}\n".format(
str(datetime.now()),
instances[i],
value,
)
for i, value in enumerate(new_results):
self.results_file.write(
"{},{},{}\n".format(
str(datetime.now()),
instances[i],
value,
)
)

def write_predictions(self, preds):
if self.predictions_file is not None:
for _, prediction in enumerate(preds):
self.predictions_file.write(f"{','.join(map(str, prediction))}\n")

def update_plot(self, new_results, instances):
if self.show_plot or self.plot_file is not None:
Expand All @@ -151,6 +165,7 @@ def update_predictions(self, context):
self.x_eval += x
self.y_eval += y
self.o_eval += o
self.write_predictions(o)

def run(self, context):
"""Run evaluator
Expand All @@ -173,3 +188,4 @@ def run(self, context):
self.visualizer.savefig(self.plot_file)
if self.show_plot:
self.visualizer.show()
self.end()

0 comments on commit 3671fff

Please sign in to comment.