Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions cascade/eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import ast
import numpy as np
import pandas as pd
from cascade.utility.metrics import Metrics


def load_submission_data(submission_path):
"""
Extract necessary data for model evaluation from the submitted csv file.

Args:
submission_path (str): Complete path to the submission file.

Returns:
tuple: Contains:
- trial indices (1D array)
- image IDs (1D array)
- neuron IDs (1D array)
- predictions (2d array: trials x neurons)
"""
submission_df = pd.read_csv(submission_path)
trial_idx = submission_df["trial_indices"].values
image_ids = submission_df["image_ids"].values
neuron_ids = np.array(ast.literal_eval(submission_df["neuron_ids"].values[0]))
predictions = np.array(
[ast.literal_eval(v) for v in submission_df["prediction"].values]
)

return trial_idx, image_ids, neuron_ids, predictions


def load_groundtruth_data(groundtruth_path):
"""
Extract necessary data for model evaluation from the ground truth data file.

Args:
groundtruth_path (str): Absolute path to the ground truth data file.

Returns:
tuple: Contains:
- trial indices (1D array)
- image IDs (1D array)
- neuron IDs (1D array)
- responses (2d array: trials x neurons)
"""
raise NotImplementedError()


def evaluate(submission_path, ground_truth_path):
"""
Compute evaluation metrics for a specific submission given the ground truth data.

Args:
submission_path (str): Absolute path to the submission csv file.
ground_truth_path (str): Absolute path to the ground truth data file.

Returns:
dict: Containing all the evaluation results for all the evaluation metrics.
"""
trial_idx_gt, image_ids_gt, neuron_ids_gt, responses = load_groundtruth_data(
ground_truth_path
)
(
trial_idx_submitted,
image_ids_submitted,
neuron_ids_submitted,
predictions,
) = load_submission_data(submission_path)

metric = Metrics(responses, trial_idx_gt, image_ids_gt, neuron_ids_gt)

output = {}
output["Correlation (single trial)"] = metric.correlation_to_single_trials(
predictions, trial_idx_submitted, neuron_ids_submitted, per_neuron=False
)
output["Correlation (mean)"] = metric.correlation_to_mean_across_repeats(
predictions, trial_idx_submitted, neuron_ids_submitted, per_neuron=False
)
output["FEVE"] = metric.feve(
predictions, trial_idx_submitted, neuron_ids_submitted, per_neuron=False
)
return output