In [None]:
import collections, re
from pathlib import Path
from matplotlib import pyplot as plt


# All paths are relative to the notebook's working directory; the labels
# file is looked up under ../input/bms-molecular-translation/.
PROJECT_DIR = Path('..')
INPUT_DIR = PROJECT_DIR.joinpath('input', 'bms-molecular-translation')
TRAIN_LABELS_PATH = INPUT_DIR.joinpath('train_labels.csv')

In [None]:
def count_elements(input_file_path):
    """Count, for each chemical element, how many formulas in a labels CSV contain it.

    The first line of the file is treated as a header and skipped. Every other
    line is expected to look like ``<hex id>,InChI=1S/<formula>/...`` (the InChI
    field may be double-quoted). Lines that do not match are reported and skipped.
    Element symbols are extracted from the formula layer, and each element is
    counted at most once per formula, so a count is the number of formulas in
    which that element appears.

    Args:
        input_file_path: path (str or pathlib.Path) to the labels CSV.

    Returns:
        A tuple ``(formula_count, element_counts)``. ``formula_count`` is the
        number of lines whose formula was parsed. ``element_counts`` is a list of
        ``(element, count)`` pairs sorted by descending count; ties keep
        first-seen order.
    """
    line_regex = re.compile(r'^[0-9a-f]+,"?InChI=1S/([^/]+)')
    element_regex = re.compile(r'([A-Z][a-z]?)[0-9]*')
    element_counts = collections.Counter()
    formula_count = 0
    with open(input_file_path, 'r', encoding='utf-8') as f:
        # Skip the header line; the default avoids StopIteration on an empty file.
        next(f, None)
        # Iterate lazily instead of readlines() so the whole file is never held in memory.
        for line in f:
            match = line_regex.match(line)
            if not match:
                print('Warning - line not matched:', line.rstrip('\n'))
                continue
            formula = match.group(1)
            formula_count += 1
            elements = element_regex.findall(formula)
            if 'C' not in elements:
                print('NOTE: no carbon in:', formula)
            # A formula with several '.'-separated components (e.g. 'C10H14N2.2ClH')
            # can name the same element more than once. Count it once per formula.
            # dict.fromkeys de-duplicates while keeping first-seen order, so
            # most_common() breaks ties deterministically (a set would not).
            element_counts.update(list(dict.fromkeys(elements)))
    # most_common() sorts by descending count and keeps insertion order for ties.
    return formula_count, element_counts.most_common()


# Parse the training labels once; the plotting cell below reuses TRAIN_ELEMENT_COUNTS.
(TRAIN_FORMULA_COUNT,
 TRAIN_ELEMENT_COUNTS) = count_elements(TRAIN_LABELS_PATH)
print(f'Training data element counts: {TRAIN_ELEMENT_COUNTS} in {TRAIN_FORMULA_COUNT} formulas.')

In [None]:
def show_element_counts(element_counts, title, log=False):
    """Draw a horizontal bar chart of element frequencies and show it.

    Args:
        element_counts: sequence of ``(element, count)`` pairs. Bars are drawn
            top to bottom in the given order (``count_elements`` returns them
            sorted by descending count).
        title: chart title.
        log: if True, plot the raw counts on a logarithmic x axis. Otherwise plot
            the counts divided by one million on a linear axis.
    """
    elements = [pair[0] for pair in element_counts]
    counts = [pair[1] for pair in element_counts]
    if not log:
        # Linear axis: rescale so the tick values read in millions.
        counts = [count / 10**6 for count in counts]
    # Negative y positions put the first pair at the top of the chart.
    y_pos = [-i for i in range(len(counts))]
    _, ax = plt.subplots(figsize=(10, 8), facecolor='#eef')
    ax.set_title(title)
    ax.set_xlabel('No. of formulas' + ('' if log else ' (millions)'))
    ax.set_ylabel('Element')
    ax.barh(y_pos, counts, log=log, fc='gold')
    # Separate calls (rather than set_yticks(ticks, labels)) keep this working on
    # Matplotlib versions before 3.5.
    ax.set_yticks(y_pos)
    ax.set_yticklabels(elements)
    plt.show()
    print('Done.')


# Linear axis with counts in millions; pass log=True for a log-scaled view.
show_element_counts(TRAIN_ELEMENT_COUNTS, title='Training Data Element Counts', log=False)