In [None]:
import os
from itertools import cycle
from pathlib import Path
from typing import List

In [None]:
import attr
import matplotlib.pyplot as plt
import numpy as np
from paragraph2actions.analysis import partial_accuracy

In [None]:
from smiles2actions.utils import load_list_from_file

# Evaluate metrics according to the reaction classes

### File locations

In [None]:
# Root directory of the paper data; a KeyError here means that the
# S2A_PAPER_DATA_DIR environment variable has not been set.
s2a_dir = Path(os.environ['S2A_PAPER_DATA_DIR'])

# The three input files, kept as plain strings (not Path objects).
tgt_file, classes_file, transformer_file = (
    str(s2a_dir / filename)
    for filename in ('tgt-test.txt', 'rxn_classes_test.txt', 'transformer_test.txt')
)

### Load the samples and subdivide them into classes

In [None]:
# Read the three files in the same order as before: reaction classes,
# ground truths, then transformer predictions.
rxn_classes, truths, preds = (
    load_list_from_file(path) for path in (classes_file, tgt_file, transformer_file)
)

In [None]:
@attr.s(auto_attribs=True)
class Sample:
    """One test-set example: a reaction class together with the reference and predicted strings."""
    # Reaction class label; the part before the first '.' is later parsed as the superclass index.
    rxn_class: str
    # Ground-truth entry, taken from the target file.
    truth: str
    # Model output, taken from the transformer predictions file.
    pred: str

In [None]:
# The three lists are zipped together below, so they must be aligned one-to-one;
# zip() would otherwise silently drop the trailing entries of the longer lists.
assert len(truths) == len(rxn_classes) and len(rxn_classes) == len(preds)

In [None]:
# Bundle the i-th class, truth and prediction into one Sample each.
all_samples = [
    Sample(rxn_class=class_label, truth=reference, pred=prediction)
    for class_label, reference, prediction in zip(rxn_classes, truths, preds)
]

In [None]:
# One independent bucket per superclass index (0 to 11). Each bucket must be its own
# list object, hence the comprehension rather than list multiplication.
samples_per_class: List[List[Sample]] = [list() for _ in range(12)]

In [None]:
# Distribute the samples into the buckets: the superclass is the integer in front of
# the first '.' of the reaction class label.
for sample in all_samples:
    superclass = int(sample.rxn_class.partition('.')[0])
    samples_per_class[superclass].append(sample)

### Compute the metrics

In [None]:
# Similarity thresholds (in %) at which the partial accuracy is evaluated.
percentages = list(range(50, 101, 10))
# Tick labels for the x axis of the plot, one per threshold.
percentage_labels = [f'{threshold}% accuracy' for threshold in percentages]
# Human-readable names for the 12 superclasses.
line_labels = [f'Superclass {superclass_id}' for superclass_id in range(12)]

In [None]:
# Partial accuracy over the whole test set, from the strictest threshold to the loosest.
print('Metrics on all the data')
for label, threshold in (
    (' - 100% accuracy', 1.0),
    (' - 90% accuracy', 0.9),
    (' - 75% accuracy', 0.75),
    (' - 50% accuracy', 0.5),
):
    print(label, partial_accuracy(truths, preds, threshold))

In [None]:
print('Metrics for classes')
# One row per superclass, one column per threshold in `percentages`.
# The row count follows the buckets instead of repeating the literal 12.
results = np.zeros((len(samples_per_class), len(percentages)))
for superclass_index, class_samples in enumerate(samples_per_class):
    if not class_samples:
        # No sample for this superclass: its accuracy is undefined, so record NaN
        # instead of passing empty lists to partial_accuracy (its behavior on empty
        # input is not visible here -- presumably a division by zero; TODO confirm).
        results[superclass_index, :] = np.nan
        continue
    truth_for_class = [sample.truth for sample in class_samples]
    pred_for_class = [sample.pred for sample in class_samples]
    for percentage_index, percentage in enumerate(percentages):
        # partial_accuracy takes the threshold as a fraction, not a percentage.
        acc = partial_accuracy(truth_for_class, pred_for_class, percentage / 100)
        results[superclass_index, percentage_index] = acc

In [None]:
# Raw accuracy table: rows are superclasses, columns are the thresholds in `percentages`.
print(results)

In [None]:
# Matplotlib line styles to alternate between, so that curves sharing a colour
# can still be told apart.
lines = [
    "-",   # solid
    "--",  # dashed
    ":",   # dotted
]
# Endless iterator over the styles above.
linecycler = cycle(lines)

In [None]:
# Plot, for every superclass, the accuracy (in %) reached at each threshold.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(6, 6))
# Start a fresh style cycle on every run, so the figure does not depend on how far
# a shared iterator has already been advanced by earlier executions of this cell.
style_cycler = cycle(lines)
# Use the prepared 'Superclass i' names for the legend instead of bare integers.
for y_arr, label in zip(100 * results, line_labels):
    ax.plot(percentages, y_arr, next(style_cycler), label=label)
ax.legend(loc='upper right')
ax.set_ylabel('Score (in %)')
ax.set_xticks(percentages)
ax.set_xticklabels(percentage_labels, rotation=20)
fig.tight_layout()
# NOTE(review): the output path is hard-coded to /tmp; adjust on systems without it.
fig.savefig('/tmp/metrics_per_class.pdf')