In [None]:
import os
from collections import Counter
from pathlib import Path
from typing import Iterable, List, Tuple

In [None]:
import matplotlib.pyplot as plt
import numpy as np

In [None]:
from smiles2actions.utils import load_list_from_file, colorblind_color_palette

# Data distribution according to reaction classes

### Load the data

In [None]:
# Directory holding the rxn_classes_*.txt files; it must be supplied through the environment.
try:
    s2a_dir = Path(os.environ['S2A_PAPER_DATA_DIR'])
except KeyError as err:
    # Same exception type as before, but tell the reader what to do about it.
    raise KeyError(
        'Set the S2A_PAPER_DATA_DIR environment variable to the directory '
        'containing the rxn_classes_*.txt files.'
    ) from err

In [None]:
# Load every reaction-class list, in this fixed order, from the paper data directory.
(
    original_classes,
    dataset_classes,
    dataset_classes_with_duplicates,
    train_classes,
    valid_classes,
    test_classes,
) = (
    load_list_from_file(s2a_dir / class_file)
    for class_file in (
        'rxn_classes_original_data.txt',
        'rxn_classes_unique.txt',
        'rxn_classes_all.txt',
        'rxn_classes_train.txt',
        'rxn_classes_valid.txt',
        'rxn_classes_test.txt',
    )
)

### Plot superclass distribution

In [None]:
def get_superclass(rxn_class: str) -> str:
    """Return the part of a reaction class before its first dot.

    Example: "3.2.45" -> "3". A class that contains no dot is returned unchanged.
    """
    head, _separator, _rest = rxn_class.partition('.')
    return head

In [None]:
def show_classes_distribution(
    reaction_classes: List[Tuple[str, List[str]]], filename: str, n_superclasses: int = 12
) -> None:
    """Plot the normalized reaction-superclass distribution of several sets and save it.

    Args:
        reaction_classes: (set name, reaction classes) pairs; each pair becomes one
            histogram series. Reaction classes are strings such as "3.2.45".
        filename: output path passed to ``Figure.savefig`` (matplotlib infers the
            format from the extension).
        n_superclasses: number of superclasses; every superclass must be an integer
            in ``[0, n_superclasses)``. One unit-wide bin is drawn per superclass.

    Raises:
        ValueError: if a superclass is not an integer or lies outside
            ``[0, n_superclasses)``.
    """
    set_names = [name for name, _ in reaction_classes]
    # Convert to int: matplotlib treats string data as categorical and assigns bar
    # positions in order of first appearance, which would not line up with the
    # numeric tick labels set below.
    main_classes = [
        [int(get_superclass(rxn_class)) for rxn_class in classes]
        for _, classes in reaction_classes
    ]
    for name, superclasses in zip(set_names, main_classes):
        out_of_range = sorted({c for c in superclasses if not 0 <= c < n_superclasses})
        if out_of_range:
            # Such values would otherwise be dropped silently (or merged into the last bin).
            raise ValueError(
                f'Superclasses {out_of_range} of "{name}" are outside [0, {n_superclasses}).'
            )

    # Bin edges 0, 1, ..., n_superclasses -> bin i holds superclass i.
    bin_edges = np.arange(n_superclasses + 1)
    superclass_ids = np.arange(n_superclasses)

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 5))
    ax.hist(
        main_classes,
        bin_edges,
        label=set_names,
        color=colorblind_color_palette(len(main_classes)),
        # Normalize each series so that sets of different sizes are comparable.
        density=True
    )
    ax.legend(loc='upper right')
    ax.set_xlabel('Reaction superclass')
    ax.set_ylabel('Frequency')
    # Center each label on its unit-wide bin.
    ax.set_xticks(superclass_ids + 0.5)
    ax.set_xticklabels(superclass_ids)
    ax.margins(x=0.02)
    fig.tight_layout()
    fig.savefig(filename)

First, all the data, including the individual splits.

In [None]:
# (series label, reaction classes) pairs, in plotting order.
classes_to_show = list({
    'Original reaction data': original_classes,
    'All reactions (with duplicates)': dataset_classes_with_duplicates,
    'All reactions (without duplicates)': dataset_classes,
    'Train split': train_classes,
    'Validation split': valid_classes,
    'Test split': test_classes,
}.items())

In [None]:
# Give an explicit extension: without one, matplotlib falls back to its default
# format (PNG) and appends ".png", unlike the PDF figure saved for the 500 subset.
show_classes_distribution(classes_to_show, '/tmp/classes_distribution.pdf')

Second, the 500 reactions assessed by the chemist, which are the first 500 reactions of the test set.

In [None]:
# Full dataset vs. test split vs. the chemist-assessed subset (first 500 test reactions).
classes_to_show = list({
    'All reactions (without duplicates)': dataset_classes,
    'Test split': test_classes,
    'Subset of 500 reactions': test_classes[:500],
}.items())

In [None]:
show_classes_distribution(
    reaction_classes=classes_to_show,
    filename='/tmp/classes_distribution_500.pdf',
)

### Classes present or absent from splits

Check which reaction classes of the dataset are missing from the test split and from the other splits.

In [None]:
# Distinct classes in the original data and in the deduplicated dataset,
# plus per-class occurrence counts in the dataset.
all_unique_classes = set(original_classes)
classes_in_dataset = set(dataset_classes)
counter_in_dataset = Counter(dataset_classes)

print('All classes', len(all_unique_classes))
print('Classes in dataset', len(classes_in_dataset))

In [None]:
def info_not_in_split(classes_for_subset: Iterable[str]) -> None:
    """Print the dataset reaction classes that are absent from a split.

    For each class present in ``classes_in_dataset`` but missing from
    ``classes_for_subset``, print its count in ``dataset_classes``; then print the
    number of missing classes and (if any) their average count.

    Relies on the notebook globals ``classes_in_dataset``, ``counter_in_dataset``
    and ``dataset_classes``.

    Args:
        classes_for_subset: reaction classes of the split to check.
    """
    not_in_split = classes_in_dataset - set(classes_for_subset)
    n_reactions = len(dataset_classes)
    counts = []
    # Sort so the printed list is reproducible: iteration order of a set of strings
    # changes between interpreter runs because of string hash randomization.
    for rxn_class in sorted(not_in_split):
        count = counter_in_dataset[rxn_class]
        print(f'{rxn_class} - Original count in dataset of {n_reactions}: {count}')
        counts.append(count)
    print('Number of missing classes', len(counts))
    # np.mean([]) warns and yields nan, so only report an average when one exists.
    if counts:
        print(f'Average count in the original dataset of {n_reactions}: {np.mean(counts)}')

In [None]:
# Classes of the dataset that never occur in the training split.
print('Train')
info_not_in_split(classes_for_subset=train_classes)

In [None]:
# Classes of the dataset that never occur in the validation split.
print('Valid')
info_not_in_split(classes_for_subset=valid_classes)

In [None]:
# Classes of the dataset that never occur in the test split.
print('Test')
info_not_in_split(classes_for_subset=test_classes)