# 03 forward plots
Make forward plots to verify that the spectra predictions used for data augmentation are reasonable.

In [None]:
import pandas as pd
from pathlib import Path
from rdkit import Chem
from rdkit.Chem import Draw
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime

from rdkit import Chem
from rdkit.Chem import Draw

In [None]:

import mist.utils  as utils
from mist.utils import plot_utils

# Call the project's plotting style hook once, before any figure is created.
# NOTE(review): presumably this configures shared matplotlib defaults —
# confirm in mist.utils.plot_utils.
plot_utils.set_style()

In [None]:
# Input locations: the paired-spectra dataset, the folder holding the forward
# predictions for this fold, and the dataset's label table.
input_dir = "../data/paired_spectra/canopus_train/"
input_dir = Path(input_dir)
forward_folder = input_dir / "morgan4096_spec_preds_fold_100_0"
labels_file = input_dir / "labels.tsv"
# NOTE(review): k is not used in the cells visible here — confirm it is still
# needed.
k = 6

# Output images go to a results folder stamped with today's date (YYYY_MM_DD).
date = datetime.now().strftime("%Y_%m_%d")
res_dir = f"../results/{date}_output_forward_imgs"
res_dir_orig = Path(res_dir)
# parents=True: without it, mkdir raises FileNotFoundError when ../results
# does not exist yet (e.g. on a fresh checkout).
res_dir_orig.mkdir(parents=True, exist_ok=True)

In [None]:
# Load the two tab-separated label tables: the one for the full dataset and
# the one stored next to the forward predictions (its first column is used as
# the index).
forward_labels_file = forward_folder / "labels.tsv"
full_labels = pd.read_csv(labels_file, delimiter="\t")
forward_labels = pd.read_csv(forward_labels_file, delimiter="\t", index_col=0)

In [None]:
# Lookup tables built from the label frames.
# Spectrum name -> SMILES, for the full dataset and the forward predictions.
name_to_full_smi = {
    name: smi for name, smi in full_labels[["spec", "smiles"]].values
}
name_to_forward_smi = {
    name: smi for name, smi in forward_labels[["spec", "smiles"]].values
}

# Spectrum name -> InChIKey.
name_to_full_ikey = {
    name: ikey for name, ikey in full_labels[["spec", "inchikey"]].values
}
name_to_forward_ikey = {
    name: ikey for name, ikey in forward_labels[["spec", "inchikey"]].values
}

# InChIKey -> spectrum name. When several rows share an InChIKey, the last
# row wins, so these dicts hold one spectrum per compound.
ikey_to_forward = {
    ikey: name for ikey, name in forward_labels[["inchikey", "spec"]].values
}
ikey_to_spec = {
    ikey: name for ikey, name in full_labels[["inchikey", "spec"]].values
}
print(len(ikey_to_forward))
print(len(ikey_to_spec))

In [None]:
# Unique InChIKeys on each side and the compounds present in both tables.
ikeys_forward = set(ikey_to_forward)
ikeys_full = set(ikey_to_spec)
ikey_overlap = ikeys_full & ikeys_forward
print("Len of overlap", len(ikey_overlap))


In [None]:
# Read the train/test assignment for fold "Fold_100_0" and report, per fold,
# how many of its InChIKeys also appear in the forward predictions.
split_file = "../data/paired_spectra/canopus_train/splits/canopus_hplus_100_0.csv"
split_df = pd.read_csv(split_file)
fold_assignment = split_df["Fold_100_0"]
split_df_test_names = split_df[fold_assignment == "test"]
split_df_train_names = split_df[fold_assignment == "train"]

test_names = split_df_test_names["name"].values
train_names = split_df_train_names["name"].values
# Translate spectrum names into InChIKeys via the full label table.
test_ikeys = [name_to_full_ikey[name] for name in test_names]
train_ikeys = [name_to_full_ikey[name] for name in train_names]
# Printed as "<overlapping keys> <fold size>", train first, then test.
for fold_ikeys in (train_ikeys, test_ikeys):
    print(len(ikey_overlap.intersection(fold_ikeys)), len(fold_ikeys))

In [None]:
# Sample forward ikey
# Draw a random test spectrum and look up the forward prediction for the same
# compound (matched by InChIKey).
# Only test names whose InChIKey has a forward prediction are eligible:
# ikey_to_forward[test_ikey] raises KeyError if the drawn compound is missing
# from the forward labels.
test_candidates = np.array(
    [name for name in test_names if name_to_full_ikey[name] in ikey_to_forward],
    dtype=object,  # keep the names as plain Python objects, as in test_names
)
if len(test_candidates) == 0:
    raise ValueError("No test spectrum has a matching forward prediction")
test_name = np.random.choice(test_candidates)
test_ikey = name_to_full_ikey[test_name]
sample_spec = ikey_to_forward[test_ikey]
# Alternative: sample directly from the forward predictions instead.
# sample_spec = np.random.choice(forward_labels['spec'].values)
sample_compound = name_to_forward_smi[sample_spec]
print(sample_compound)
# Last expression of the cell, so the notebook displays the molecule.
Chem.MolFromSmiles(sample_compound)

In [None]:
# Read the predicted spectrum table of the sampled compound (tab-separated,
# first column used as the index).
pred_file = forward_folder.joinpath("spectra", f"{sample_spec}.tsv")
spec = pd.read_csv(pred_file, delimiter="\t", index_col=0)

In [None]:
# Load the measured spectrum of the sampled test compound. The SIRIUS summary
# table maps each spec_name to its spec_file path.
sirius_df = forward_folder.parent / "sirius_outputs/summary_statistics/summary_df.tsv"
sirius_df = pd.read_csv(sirius_df, sep="\t", index_col=0)

# The former assignment of real_spec to the raw "<name>.ms" file was dead
# code (immediately overwritten below) and has been dropped.
spec_name_to_sirius = dict(sirius_df[['spec_name', "spec_file"]].values)
real_spec = spec_name_to_sirius[test_name]
# The spec_file entry is joined onto "../" before reading.
real_df = pd.read_csv(Path("../") / real_spec, sep="\t")
# Pull the columns out explicitly: tuple-unpacking a DataFrame iterates over
# its column labels, which bound mzs/intens to the strings "mz" and
# "rel.intensity" instead of the peak values.
mzs = real_df["mz"].values
intens = real_df["rel.intensity"].values

In [None]:
# Stick plot of the spectrum currently held in mzs / intens.
#
# Disabled alternative: parse the raw spectrum file, keep the 20 most intense
# peaks and rescale intensities by their maximum.
# spec_ar = utils.parse_spectra(real_spec)[1][0][1]
# spec_ar = np.array([j for j in sorted(spec_ar, key = lambda x : x[1])][-20:])
# spec_ar[:, 1] = spec_ar[:, 1] /  spec_ar[:, 1].max()
# mzs, intens = spec_ar[:, 0], spec_ar[:, 1]
fig, ax = plt.subplots(figsize=(3.5, 1.7), dpi=300)
# One thin vertical line per peak, from the baseline up to its intensity.
for peak_mz, peak_inten in zip(mzs, intens):
    ax.vlines(x=peak_mz, ymin=0, ymax=peak_inten, color="black", linewidth=0.5)
ax.set_xlabel("M/Z")
ax.set_ylabel("Intensity")
ax.set_ylim(0, 1.2)
print(sample_compound)

In [None]:
# Mirror plot: predicted spectrum pointing up (black), measured spectrum
# pointing down (red).
fig = plt.figure(figsize=(3.5,1.7), dpi=300)
ax = fig.gca()

# Predicted peaks. Column access (instead of zip(*frame.values)) also works
# when the prediction table has no rows.
mzs, intens = spec["mz"].values, spec["intensity"].values
for m, i in zip(mzs, intens):
    ax.vlines(x=m, ymin=0, ymax=i, color="black", linewidth=0.5)

# Measured peaks. The columns are selected explicitly because tuple-unpacking
# a DataFrame yields its column labels, not the column values; the previous
# code therefore iterated over the characters of "mz" / "rel.intensity".
mzs, intens = real_df["mz"].values, real_df["rel.intensity"].values
for m, i in zip(mzs, intens):
    # Negate the intensity so the measured spectrum is drawn below the axis.
    ax.vlines(x=m, ymin=0, ymax=-float(i), color="red", linewidth=0.5)
ax.set_xlabel("M/Z")
ax.set_ylabel("Intensity")
ax.set_ylim([-1.2,1.2])
print(sample_compound)

In [None]:
# Create plot
# Annotated spectrum: draws every peak, labels the peaks whose intensity is in
# top_intens with their index, and exports the matching structures as PDFs.
# NOTE(review): spec_name, spec_smiles, smiles, top_intens and export_mol are
# not defined in the cells visible here — confirm they exist before running.

# res_dir is built as a plain string, which has neither .mkdir nor the "/"
# operator; Path(res_dir) accepts a str or a Path.
out_dir = Path(res_dir)
out_dir.mkdir(parents=True, exist_ok=True)
fig = plt.figure(figsize=(3.5,1.7), dpi=300)
ax = fig.gca()
full_out_smi = out_dir / f"{spec_name}_mol_full.pdf"

# Export the full molecule first.
mol = Chem.MolFromSmiles(spec_smiles)
Chem.Kekulize(mol)
export_mol(mol, full_out_smi)
for ind, (smi, (m, i)) in enumerate(zip(smiles, zip(mzs, intens))):
    ax.vlines(x=m, ymin=0, ymax=i, color="black", linewidth=0.5)
    # Only peaks that carry a SMILES and are among the top intensities get a
    # label and an exported structure.
    if smi is None or i not in top_intens:
        continue
    mol = Chem.MolFromSmiles(smi, sanitize=False)
    if mol is None:
        continue
    try:
        Chem.Kekulize(mol)
    except Exception:
        # Best effort: skip structures RDKit cannot kekulize. Catching
        # Exception (not a bare except) lets KeyboardInterrupt through.
        continue
    ax.text(x=m, y=i + 0.2, s=f"{ind}", fontsize=4, rotation=90)
    temp_out_smi = out_dir / f"{spec_name}_mol_{ind}.pdf"
    export_mol(mol, temp_out_smi)

ax.set_xlabel("M/Z")
ax.set_ylabel("Intensity")
ax.set_ylim([0,1.5])
ax.set_title(f"Spectra: {spec_name}")
plt.savefig(out_dir / f"{spec_name}_out_spec.pdf", bbox_inches="tight")