Import libraries

In [1]:
import os
import random
import collections
from pprint import pprint

import numpy as np
import matplotlib.pyplot as plt

from config import *
from utils import get_all_spectra

Define the color palette used for all plots

In [2]:
# Dark-to-light palette (17 hex colors); plotting cells permute it to pick bar colors.
colors = [
    "#120702", "#280706", "#3B0C19", "#421025", "#45112B", "#4D1434",
    "#612262", "#60296B", "#593678", "#534582", "#55598B", "#687792",
    "#7E8E95", "#93999A", "#A3A3A3", "#B0B0B0", "#C2C2C2",
]

Define some helper functions

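Count how many mass spectra (MS) were recorded for the same molecule under the same experimental settings. A molecule is identified by the first 14 characters of its InChIKey; the experimental hyper-parameters considered are: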
1. Adduct 
2. Collision energy 
3. Instrument

In [3]:
def get_n_repeat(data):
    """Count records that share the same molecule and experimental settings.

    Records are grouped by the first 14 characters of the InChIKey (used in
    this notebook as a stereo-agnostic molecule identifier), the adduct, the
    collision energy and the instrument type.

    Args:
        data: iterable of records, each exposing a ``metadata`` mapping with
            the keys "inchikey", "adduct", "collision_energy" and
            "instrument_type".

    Returns:
        collections.Counter (a dict subclass) mapping the string key
        ``"{inchikey[:14]}_{adduct}_{collision_energy}_{instrument_type}"``
        to the number of records with that combination. Empty input gives
        an empty Counter.
    """
    repeats = collections.Counter()

    for r in data:
        meta = r.metadata
        key = f"{meta['inchikey'][:14]}_{meta['adduct']}_{meta['collision_energy']}_{meta['instrument_type']}"
        repeats[key] += 1

    return repeats

Compute summary statistics and per-field breakdowns, optionally restricted to a single dataset

In [4]:
def get_statistics(data, subset = None):
    """Summarise spectra records, optionally restricted to one dataset.

    Args:
        data: records exposing a ``metadata`` mapping. It is traversed several
            times, so it must be re-iterable (e.g. a list).
        subset: when not None, only records whose metadata "dataset" equals
            this value are kept.

    Returns:
        dict with:
          - "n_records" and "n_unique_*" counts,
          - "*_breakdown" collections.Counter objects (first-seen key order)
            for the molecule (first 14 InChIKey characters), instrument type,
            adduct, collision energy and superclass,
          - "n_repeats": the per-setting repeat counts from get_n_repeat.
    """
    if subset is not None:
        data = [record for record in data if record.metadata["dataset"] == subset]

    def tally(field, transform=lambda value: value):
        # Counter preserves first-seen insertion order.
        return collections.Counter(transform(record.metadata[field]) for record in data)

    # Molecules are compared on the first 14 InChIKey characters only.
    by_molecule = tally("inchikey", lambda key: key[:14])
    by_instrument = tally("instrument_type")
    by_adduct = tally("adduct")
    by_energy = tally("collision_energy")
    by_superclass = tally("superclass")
    repeats = get_n_repeat(data)

    return {
        "n_records": len(data),
        "n_unique_molecules": len(by_molecule),
        "n_unique_instruments": len(by_instrument),
        "n_unique_adducts": len(by_adduct),
        "n_unique_energy": len(by_energy),
        "inchikey_breakdown": by_molecule,
        "instruments_breakdown": by_instrument,
        "adducts_breakdown": by_adduct,
        "energy_breakdown": by_energy,
        "superclass_breakdown": by_superclass,
        "n_repeats": repeats,
    }

Main: load the spectra and compute statistics for each dataset and in total

In [5]:
data_path = os.path.join(final_data_folder, "final_data.msp")
data = get_all_spectra(data_path)
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
os.makedirs(plots_folder, exist_ok=True)

# Unique datasets in a stable order. Iterating a set of strings gives an order
# that changes between Python sessions (hash randomisation), which would
# reshuffle bar positions and colors in every plot below. key=str keeps the
# sort safe even if the "dataset" values are not all of one type.
unique_datasets = sorted({r.metadata["dataset"] for r in data}, key=str)

# "total" is reserved below for the aggregate statistics.
assert "total" not in unique_datasets, "a dataset named 'total' would be overwritten by the aggregate statistics"

# Statistics for each dataset, then for all records together
data_stats_merged = {d: get_statistics(data, d) for d in unique_datasets}
data_stats_merged["total"] = get_statistics(data)

224361it [08:07, 460.05it/s]


KeyboardInterrupt: 

1. Get dataset breakdown

In [None]:
# One bar per dataset (plus the aggregate "total") showing its record count.
assert len(colors) >= len(data_stats_merged), "not enough colors for every dataset"

# Seeded RNG so the color assignment is reproducible across runs
random_colors = random.Random(0).sample(colors, len(colors))

fig, ax = plt.subplots()

for i, (dataset, stats) in enumerate(data_stats_merged.items()):
    rects = ax.bar(dataset, stats["n_records"], color=random_colors[i])
    # Offset in points: a fixed data-unit offset is invisible at large counts
    ax.bar_label(rects, padding=3)

ax.set_xlabel("Datasets")
ax.set_ylabel("Number of records")
ax.set_title("Number of MS records for each dataset")
fig.savefig(os.path.join(plots_folder, "dataset_breakdown.jpg"), bbox_inches="tight")

2. Get breakdown of adducts

In [None]:
# Grouped bars: one group per dataset, one bar per adduct inside each group.
unique_adducts = list(data_stats_merged["total"]["adducts_breakdown"].keys())[::-1]
unique_datasets = list(data_stats_merged.keys())

assert len(colors) >= len(unique_adducts), "not enough colors for every adduct"

# Seeded RNG so the color assignment is reproducible across runs
random_colors = random.Random(0).sample(colors, len(colors))

n_bars = max(len(unique_adducts), 1)
x = np.arange(len(unique_datasets))  # the label locations
# Split a fixed group width among the bars so neighbouring groups never overlap
width = 0.8 / n_bars

fig, ax = plt.subplots()

for i, a in enumerate(unique_adducts):
    offset = width * i
    counts = [data_stats_merged[d]["adducts_breakdown"][a] for d in unique_datasets]
    rects = ax.bar(x + offset, counts, width, label=a, color=random_colors[i])
    ax.bar_label(rects, padding=3)

ax.set_ylabel('Number of records')
# Bars sit at offsets 0 .. (n_bars - 1) * width; put the tick at the group centre
ax.set_xticks(x + width * (n_bars - 1) / 2, unique_datasets)
ax.set_title('Breakdown of adduct for each dataset')

# A single legend call: a second plt.legend() would replace this configured one
ax.legend(loc='upper left', ncols=3)

# Save before showing: plt.show() may close the figure (inline backend),
# after which savefig would write a blank image.
fig.savefig(os.path.join(plots_folder, "adducts_ind_datasets.jpg"), bbox_inches="tight")
plt.show()

In [None]:
# One bar per adduct showing its record count across all datasets.
total_adducts = data_stats_merged["total"]["adducts_breakdown"]
assert len(colors) >= len(total_adducts), "not enough colors for every adduct"

# Seeded RNG so the color assignment is reproducible across runs
random_colors = random.Random(0).sample(colors, len(colors))

fig, ax = plt.subplots()

for i, (adduct, count) in enumerate(total_adducts.items()):
    rects = ax.bar(adduct, count, color=random_colors[i])
    # Offset in points: a fixed data-unit offset is invisible at large counts
    ax.bar_label(rects, padding=3)

ax.set_xlabel("Adduct")
ax.set_ylabel("Number of records")
ax.set_title("Number of MS records for each adduct")
fig.savefig(os.path.join(plots_folder, "adducts_total.jpg"), bbox_inches="tight")

3. Get breakdown of instruments

In [None]:
# Grouped bars: one group per dataset, one bar per instrument type inside each group.
unique_instruments = list(data_stats_merged["total"]["instruments_breakdown"].keys())[::-1]
unique_datasets = list(data_stats_merged.keys())

assert len(colors) >= len(unique_instruments), "not enough colors for every instrument type"

# Seeded RNG so the color assignment is reproducible across runs
random_colors = random.Random(0).sample(colors, len(colors))

n_bars = max(len(unique_instruments), 1)
x = np.arange(len(unique_datasets))  # the label locations
# Split a fixed group width among the bars so neighbouring groups never overlap
width = 0.8 / n_bars

fig, ax = plt.subplots()

for i, u in enumerate(unique_instruments):
    offset = width * i
    counts = [data_stats_merged[d]["instruments_breakdown"][u] for d in unique_datasets]
    rects = ax.bar(x + offset, counts, width, label=u, color=random_colors[i])
    ax.bar_label(rects, padding=3)

ax.set_ylabel('Number of records')
# Bars sit at offsets 0 .. (n_bars - 1) * width; put the tick at the group centre
ax.set_xticks(x + width * (n_bars - 1) / 2, unique_datasets)
ax.set_title('Breakdown of instruments for each dataset')

# A single legend call: a second plt.legend() would replace this configured one
ax.legend(loc='upper left', ncols=3)

# Save before showing: plt.show() may close the figure (inline backend),
# after which savefig would write a blank image.
fig.savefig(os.path.join(plots_folder, "instruments_ind_datasets.jpg"), bbox_inches="tight")
plt.show()

In [None]:
# One bar per instrument type showing its record count across all datasets.
total_instruments = data_stats_merged["total"]["instruments_breakdown"]
assert len(colors) >= len(total_instruments), "not enough colors for every instrument type"

# Seeded RNG so the color assignment is reproducible across runs
random_colors = random.Random(0).sample(colors, len(colors))

fig, ax = plt.subplots()

for i, (instrument, count) in enumerate(total_instruments.items()):
    rects = ax.bar(instrument, count, color=random_colors[i])
    # Offset in points: a fixed data-unit offset is invisible at large counts
    ax.bar_label(rects, padding=3)

ax.tick_params(axis="x", labelrotation=45)
ax.set_xlabel("Instrument type")
ax.set_ylabel("Number of records")
ax.set_title("Number of MS records for each instrument type")
# bbox_inches="tight" keeps the rotated tick labels from being clipped in the file
fig.savefig(os.path.join(plots_folder, "instrument_type_total.jpg"), bbox_inches="tight")

4. Get breakdown of molecule superclasses

In [None]:
# Grouped bars: one group per dataset, one bar per molecule superclass inside each group.
unique_superclasses = list(data_stats_merged["total"]["superclass_breakdown"].keys())[::-1]
unique_datasets = list(data_stats_merged.keys())

assert len(colors) >= len(unique_superclasses), "not enough colors for every superclass"

# Seeded RNG so the color assignment is reproducible across runs
random_colors = random.Random(0).sample(colors, len(colors))

n_bars = max(len(unique_superclasses), 1)
x = np.arange(len(unique_datasets))  # the label locations
# Split a fixed group width among the bars so neighbouring groups never overlap
width = 0.8 / n_bars

fig, ax = plt.subplots()

for i, a in enumerate(unique_superclasses):
    offset = width * i
    counts = [data_stats_merged[d]["superclass_breakdown"][a] for d in unique_datasets]
    rects = ax.bar(x + offset, counts, width, label=a, color=random_colors[i])
    ax.bar_label(rects, padding=3)

ax.set_ylabel('Number of records')
# Bars sit at offsets 0 .. (n_bars - 1) * width; put the tick at the group centre
ax.set_xticks(x + width * (n_bars - 1) / 2, unique_datasets)
ax.set_title('Breakdown of molecule classes for each dataset')

# A single legend call: a second plt.legend() would replace this configured one
ax.legend(loc='upper left', ncols=3)

# Save before showing: plt.show() may close the figure (inline backend),
# after which savefig would write a blank image.
fig.savefig(os.path.join(plots_folder, "molecule_class_ind_datasets.jpg"), bbox_inches="tight")
plt.show()

In [None]:
# One bar per molecule superclass showing its record count across all datasets.
total_superclasses = data_stats_merged["total"]["superclass_breakdown"]
assert len(colors) >= len(total_superclasses), "not enough colors for every superclass"

# Seeded RNG so the color assignment is reproducible across runs
random_colors = random.Random(0).sample(colors, len(colors))

fig, ax = plt.subplots()

for i, (superclass, count) in enumerate(total_superclasses.items()):
    rects = ax.bar(superclass, count, color=random_colors[i])
    # Offset in points: a fixed data-unit offset is invisible at large counts
    ax.bar_label(rects, padding=3)

ax.tick_params(axis="x", labelrotation=90)
ax.set_xlabel("Molecule superclass")
ax.set_ylabel("Number of records")
ax.set_title("Number of MS records for each molecule class")
# bbox_inches="tight" keeps the vertical tick labels from being clipped in the file
fig.savefig(os.path.join(plots_folder, "molecule_class_total.jpg"), bbox_inches="tight")