In [None]:
import os
import typing
from collections import Counter
from pathlib import Path
from typing import Dict, List, Any

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

In [None]:
from smiles2actions.utils import load_list_from_file

In [None]:
# Render matplotlib figures directly in the notebook output.
%matplotlib inline

# Compare reaction class frequencies between the original Pistachio data and the S2A dataset

### Load data

In [None]:
# Root directory of the S2A paper data; os.environ raises KeyError if the variable is not set.
s2a_dir = Path(os.environ['S2A_PAPER_DATA_DIR'])

In [None]:
# Reaction classes for the original Pistachio data and for the S2A dataset
# (presumably one dot-separated class string per entry, e.g. '5.2.8' — confirm in load_list_from_file).
original_classes, dataset_classes = (
    load_list_from_file(s2a_dir / file_name)
    for file_name in ('rxn_classes_original_data.txt', 'rxn_classes_unique.txt')
)

In [None]:
# Occurrences of each reaction class per source, plus the total number of entries per source;
# the totals are used later to turn counts into relative frequencies.
pis_class_counts = Counter(original_classes)
pis_total_count = sum(pis_class_counts.values())

s2a_class_counts = Counter(dataset_classes)
s2a_total_count = sum(s2a_class_counts.values())

### Helpers to build pandas DataFrames of class counts

In [None]:
def merge_counter_down(counter: typing.Counter[str], class_level: int) -> typing.Counter[str]:
    """
    Coarsen reaction-class counts to a given granularity.

    Each key is truncated to its first ``class_level`` dot-separated components, and
    the counts of keys that collapse onto the same prefix are summed. For example,
    with ``class_level=2`` the counts of all 5.2.X classes are added up under 5.2.

    Args:
        counter: Counts keyed by dot-separated reaction class, e.g. '5.2.8'.
        class_level: Number of leading components to keep.

    Returns:
        A new Counter keyed by the truncated classes; the input is not modified.
    """
    coarse_counts: typing.Counter[str] = Counter()
    for rxn_class, occurrences in counter.items():
        prefix_parts = rxn_class.split('.')[:class_level]
        coarse_counts['.'.join(prefix_parts)] += occurrences
    return coarse_counts

In [None]:
def get_into_pandas(class_level: int) -> pd.DataFrame:
    """
    Tabulate reaction-class counts and frequencies for Pistachio and the s2a dataset.

    Both global class counters are first coarsened to ``class_level`` components with
    ``merge_counter_down``. Every class seen in either source becomes one row, and rows
    are ordered numerically by their dot-separated components (so 2.10 sorts after 2.9).

    Args:
        class_level: What degree of fineness to keep: 1->5, 2->5.2, 3->5.2.8.

    Returns:
        DataFrame with columns 'reaction class', 'Count in Pistachio', 'Count in s2a',
        'Frequency in Pistachio' and 'Frequency in s2a'. Each frequency is the count divided
        by the global total count of the respective source.
    """
    coarse_s2a = merge_counter_down(s2a_class_counts, class_level)
    coarse_pis = merge_counter_down(pis_class_counts, class_level)

    def numeric_sort_key(rxn_class: str) -> tuple:
        # Compare components as integers so that '5.10' comes after '5.9'.
        return tuple(map(int, rxn_class.split('.')))

    ordered_classes = sorted(set(coarse_s2a) | set(coarse_pis), key=numeric_sort_key)

    # Indexing a Counter with an absent class yields 0 (and does not insert the key).
    rows = [
        (rxn_class, coarse_pis[rxn_class], coarse_s2a[rxn_class])
        for rxn_class in ordered_classes
    ]
    counts_df = pd.DataFrame(rows, columns=['reaction class', 'Count in Pistachio', 'Count in s2a'])
    return counts_df.assign(**{
        'Frequency in Pistachio': counts_df['Count in Pistachio'] / pis_total_count,
        'Frequency in s2a': counts_df['Count in s2a'] / s2a_total_count,
    })

### Build and display the DataFrames at each class level

In [None]:
# Class tables at three granularities: named reaction (5.2.8), category (5.2), superclass (5).
df_name_rxn, df_category, df_superclass = (get_into_pandas(level) for level in (3, 2, 1))

In [None]:
# Superclass level (class_level=1), e.g. '5'.
df_superclass

In [None]:
# Category level (class_level=2), e.g. '5.2'.
df_category

In [None]:
# Named-reaction level (class_level=3), e.g. '5.2.8'.
df_name_rxn

### Prepare the plot

In [None]:
def original_count_category(df_row) -> str:
    """
    Bucket a reaction class by how often it occurs in Pistachio.

    Args:
        df_row: Row (or mapping) providing 'Count in Pistachio'; the count must be > 0.

    Returns:
        '1–9', '10–99', '100–999' or '>1000'.
        NOTE(review): a count of exactly 1000 also lands in '>1000', so that label
        effectively means ">= 1000".
    """
    pis_count = df_row['Count in Pistachio']
    assert pis_count > 0
    # (exclusive upper bound, label), checked from the smallest bucket upwards.
    for upper_bound, label in ((10, '1–9'), (100, '10–99'), (1000, '100–999')):
        if pis_count < upper_bound:
            return label
    return '>1000'

In [None]:
def difference_category(df_row) -> str:
    """
    Bucket how a class's relative frequency changed from Pistachio to the s2a dataset.

    The ratio ``Frequency in s2a / Frequency in Pistachio`` is placed into half-open
    bins [0, 0.5), [0.5, 0.75), [0.75, 1.25), [1.25, 1.5) and [1.5, inf). Classes that
    are absent from the s2a dataset (frequency 0) get their own '–100%' bucket.

    Args:
        df_row: Row (or mapping) providing 'Frequency in Pistachio' (must be > 0)
            and 'Frequency in s2a'.

    Returns:
        One of '–100%', '–100% to –50%', '–50% to –25%', '–25% to +25%',
        '+25% to +50%' or '>+50%'.
    """
    pis_freq = df_row['Frequency in Pistachio']
    s2a_freq = df_row['Frequency in s2a']
    assert pis_freq > 0

    if s2a_freq == 0:
        return '–100%'

    ratio = s2a_freq / pis_freq
    # (exclusive upper bound on the ratio, label), checked in increasing order.
    ratio_bins = (
        (0.5, '–100% to –50%'),
        (0.75, '–50% to –25%'),
        (1.25, '–25% to +25%'),
        (1.50, '+25% to +50%'),
    )
    for upper_bound, label in ratio_bins:
        if ratio < upper_bound:
            return label
    return '>+50%'

In [None]:
# Bucket labels in display order. They must match exactly the strings returned by
# original_count_category (rows) and difference_category (columns); their order fixes
# the row/column order of the count matrix and of the heatmap.
counts = [
    '1–9',
    '10–99',
    '100–999',
    '>1000',
]
enrichments = [
    '–100%',
    '–100% to –50%',
    '–50% to –25%',
    '–25% to +25%',
    '+25% to +50%',
    '>+50%',
]

In [None]:
# Empty bins: Pistachio-count label -> change-in-prevalence label -> list of reaction classes.
res: Dict[str, Dict[str, List[Any]]] = {
    count_label: {enrichment_label: [] for enrichment_label in enrichments}
    for count_label in counts
}

In [None]:
# Assign every named reaction (class level 3) to a (Pistachio count, change in prevalence) bin.
# The bins are rebuilt from scratch here, so re-running this cell on its own does not append
# duplicate entries to lists filled by a previous run (which would silently double `m`).
res: Dict[str, Dict[str, List[Any]]] = {
    count_label: {enrichment_label: [] for enrichment_label in enrichments}
    for count_label in counts
}
for _, row in df_name_rxn.iterrows():
    res[original_count_category(row)][difference_category(row)].append(row['reaction class'])

In [None]:
# Bin sizes as a float matrix: rows follow `counts`, columns follow `enrichments`.
m = np.zeros((len(counts), len(enrichments)))
for row_idx, count_label in enumerate(counts):
    m[row_idx, :] = [len(res[count_label][enrichment_label]) for enrichment_label in enrichments]

In [None]:
# Sanity check: each named reaction class must fall into exactly one bin. This also catches
# duplicate entries left behind by re-running the binning cell without resetting `res`.
assert m.sum() == len(df_name_rxn), 'reaction classes were binned zero or multiple times'
print('Total classes:', m.sum())
print(m)

In [None]:
# Among the classes with at least 100 occurrences in Pistachio (rows '100–999' and '>1000'),
# compute the share whose frequency dropped by more than 50% in the S2A dataset
# (columns '–100%' and '–100% to –50%', i.e. a frequency ratio below 0.5).
# Slice bounds are derived from the label lists rather than hard-coded indices.
frequent_rows = slice(counts.index('100–999'), len(counts))
strong_drop_cols = slice(0, enrichments.index('–100% to –50%') + 1)
share_dropped = m[frequent_rows, strong_drop_cols].sum() / m[frequent_rows, :].sum()
print('Share of classes with >= 100 examples whose frequency dropped by more than 50%:', share_dropped)
# The complement: the share of those frequent classes that kept at least half their frequency.
print('Share of classes with >= 100 examples whose frequency did not drop by more than 50%:', 1 - share_dropped)

In [None]:
# Heatmap of how many named reaction classes fall into each (original count, change in
# prevalence) bin. The colour scale is fixed to 0–100 classes per cell.
fig, ax = plt.subplots(1, 1, figsize=(10, 6))
sns.heatmap(
    m,
    ax=ax,
    annot=True,
    fmt='g',
    cmap='Blues',
    vmin=0,
    vmax=100,
    xticklabels=enrichments,
    yticklabels=counts,
)
# Keep both sets of tick labels horizontal (seaborn may rotate them).
plt.setp(ax.get_yticklabels(), rotation=0)
plt.setp(ax.get_xticklabels(), rotation=0)
ax.set_xlabel('Change in reaction class prevalence')
ax.set_ylabel('Original count')
# Lay out once, after all labels exist, so they are taken into account.
fig.tight_layout()
fig.savefig('/tmp/class_prevalence.pdf')