In [1]:
import json, os, numpy as np, requests
import matplotlib.pyplot as plt
import heapq
import pickle
import random
from collections import defaultdict
from tqdm import tqdm
from rdkit import Chem
from rdkit.Chem import AllChem

In [2]:
hsqc_pickle = pickle.load(open("SMART_dataset_v2_new.pkl", "rb"))

In [3]:
# generate mapping of cannonical smiles -> hsqc
smiles_to_hsqc = defaultdict(list)
for item in tqdm(hsqc_pickle.values()):
    mol = Chem.MolFromSmiles(item["SMILES"])
    cannonical_smiles = Chem.MolToSmiles(mol)
    smiles_to_hsqc[cannonical_smiles].append(np.array(item["HSQC"]))   

100%|█████████████████████████████████| 137267/137267 [00:20<00:00, 6703.25it/s]


In [4]:
all_gnps = json.load(open("ALL_GNPS.json"))
print(len(all_gnps))

210569


In [5]:
# Remove samples without mass spec
all_gnps = [i for i in all_gnps if i["peaks_json"] != "[]"]
print(len(all_gnps))

209654


In [6]:
# Remove unecessary metadata
ok_keys = set(["spectrum_id", "peaks_json", "Ion_Mode", "Smiles", "Adduct", "Precursor_MZ"])
all_gnps = [{key:item[key] for key in ok_keys} for item in all_gnps]

In [7]:
%%capture --no-stdout

# Remove invalid SMILES 
all_gnps = [i for i in all_gnps if Chem.MolFromSmiles(i["Smiles"]) is not None]
print(len(all_gnps))

98620


In [8]:
# Only use positive ion mode
all_gnps = [i for i in all_gnps if i["Ion_Mode"].lower() == "positive"]
print(len(all_gnps))

69294


In [9]:
# Delete MS with < threshold peaks and normalize by highest peak
threshold = 5
max_peaks = 150
remove = set()
for item in tqdm(all_gnps):
    peaks = json.loads(item["peaks_json"])
    if len(peaks) < threshold:
        remove.add(item["spectrum_id"])
        continue
    
    m = max(peaks, key=lambda p: p[1])
    peaks = np.array([[x, (y/m[1])**0.5] for x, y in peaks])
    if len(peaks) > max_peaks:
        # only preserve max_peaks highest peaks
        inds = np.argpartition(peaks[:, 1], -max_peaks)
        peaks = peaks[inds][-max_peaks:]
    item["peaks_json"] = peaks
    

all_gnps = [i for i in all_gnps if i["spectrum_id"] not in remove]
print(len(all_gnps))

100%|███████████████████████████████████| 69294/69294 [00:42<00:00, 1623.60it/s]

57938





In [10]:
# generate mapping of cannonical smiles -> spectra
smiles_to_spectra = defaultdict(list)
for item in tqdm(all_gnps):
    mol = Chem.MolFromSmiles(item["Smiles"])
    cannonical_smiles = Chem.MolToSmiles(mol)   
    smiles_to_spectra[cannonical_smiles].append(item["peaks_json"])

100%|███████████████████████████████████| 57938/57938 [00:06<00:00, 9405.69it/s]


In [11]:
# get smiles strings with both ms and hsqc data
inter = set(smiles_to_spectra.keys()).intersection(set(smiles_to_hsqc.keys()))

In [12]:
# create mapping of sample id -> sample
pairs = defaultdict(dict)
i = 0
for smiles in tqdm(inter):
    mol = Chem.MolFromSmiles(smiles)
    fp = np.array(AllChem.GetMorganFingerprintAsBitVect(mol,3,nBits=6144))
    spectra = smiles_to_spectra[smiles]
    random.shuffle(spectra)
    hsqc = smiles_to_hsqc[smiles]
    random.shuffle(hsqc)
    while len(hsqc) > 0 and len(spectra) > 0:
        pairs[i]["HSQC"] = hsqc.pop()
        pairs[i]["MS"] = spectra.pop()
        pairs[i]["FP"] = fp
        pairs[i]["SMILES"] = smiles
        i += 1

100%|████████████████████████████████████| 11834/11834 [00:18<00:00, 629.88it/s]


In [13]:
# create pair dataset split
pair_ids = list(pairs.keys())
pair_dataset = {"train":{}, "val":{}, "test":{}}
random.shuffle(pair_ids)
split_index = int(len(pair_ids)*.15) 
for pair_id in pair_ids[:split_index]:
    pair_dataset["test"][pair_id] = pairs[pair_id]
for pair_id in pair_ids[split_index:split_index*2]:
    pair_dataset["val"][pair_id] = pairs[pair_id] 
for pair_id in pair_ids[split_index*2:]:
    pair_dataset["train"][pair_id] = pairs[pair_id]


In [14]:
# preventing data leakage: keep splits consistent
ms_dataset = defaultdict(lambda: defaultdict(dict))
hsqc_dataset = defaultdict(lambda: defaultdict(dict))
for split, samples in pair_dataset.items():
    for sample_num, sample in samples.items():
        hsqc_dataset[split][sample_num] = sample
        ms_dataset[split][sample_num] = sample

In [15]:
# everything remaining in smiles_to_hsqc does not have a corresponding pair available
# create splits for remaining HSQC
total_hsqc_size = sum([len(hsqc_list) for hsqc_list in smiles_to_hsqc.values()])
total_hsqc_size += len(hsqc_dataset["train"])
total_hsqc_size += len(hsqc_dataset["val"])
total_hsqc_size += len(hsqc_dataset["test"])
split_index = int(total_hsqc_size*.15) 
dict_items = list(smiles_to_hsqc.items())
random.shuffle(dict_items)
for smiles, hsqc_list in tqdm(dict_items):
    mol = Chem.MolFromSmiles(smiles)
    fp = np.array(AllChem.GetMorganFingerprintAsBitVect(mol,3,nBits=6144))
    for hsqc in hsqc_list:
        if len(hsqc_dataset["val"]) < split_index:
            split = "val"
        elif len(hsqc_dataset["test"]) < split_index:
            split = "test"
        else:
            split = "train"
        
        hsqc_dataset[split][i]["HSQC"] = hsqc
        hsqc_dataset[split][i]["SMILES"] = smiles
        hsqc_dataset[split][i]["FP"] = fp
        i += 1
    
print("test: ", len(hsqc_dataset["test"]))
print("val: ", len(hsqc_dataset["val"]))
print("train: ", len(hsqc_dataset["train"]))

100%|██████████████████████████████████| 129874/129874 [03:28<00:00, 622.38it/s]

test:  20590
val:  20590
train:  96087





In [16]:
# everything remaining in smiles_to_spectra does not have a corresponding pair available
# create splits for reamining spectra
total_ms_size = sum([len(spectra_list) for spectra_list in smiles_to_spectra.values()])
total_ms_size += len(ms_dataset["train"])
total_ms_size += len(ms_dataset["val"])
total_ms_size += len(ms_dataset["test"])
split_index = int(total_ms_size*.15) 
dict_items = list(smiles_to_spectra.items())
random.shuffle(dict_items)
for smiles, spectra_list in tqdm(dict_items):
    mol = Chem.MolFromSmiles(smiles)
    fp = np.array(AllChem.GetMorganFingerprintAsBitVect(mol,3,nBits=6144))
    for spectra in spectra_list:
        if len(ms_dataset["val"]) < split_index:
            split = "val"
        elif len(ms_dataset["test"]) < split_index:
            split = "test"
        else:
            split = "train"
        
        ms_dataset[split][i]["MS"] = spectra
        ms_dataset[split][i]["SMILES"] = smiles
        ms_dataset[split][i]["FP"] = fp
        i += 1
    
print("test: ", len(ms_dataset["test"]))
print("val: ", len(ms_dataset["val"]))
print("train: ", len(ms_dataset["train"]))

100%|████████████████████████████████████| 14662/14662 [00:23<00:00, 628.86it/s]

test:  8690
val:  8690
train:  40558





In [22]:
with open("data/ms_pretrain/train.pkl", "wb") as f:
    pickle.dump(ms_dataset["train"], f, protocol=pickle.HIGHEST_PROTOCOL)
with open("data/ms_pretrain/val.pkl", "wb") as f:
    pickle.dump(ms_dataset["val"], f, protocol=pickle.HIGHEST_PROTOCOL)
with open("data/ms_pretrain/test.pkl", "wb") as f:
    pickle.dump(ms_dataset["test"], f, protocol=pickle.HIGHEST_PROTOCOL)

In [23]:
with open("data/hsqc_pretrain/train.pkl", "wb") as f:
    pickle.dump(hsqc_dataset["train"], f, protocol=pickle.HIGHEST_PROTOCOL)
with open("data/hsqc_pretrain/val.pkl", "wb") as f:
    pickle.dump(hsqc_dataset["val"], f, protocol=pickle.HIGHEST_PROTOCOL)
with open("data/hsqc_pretrain/test.pkl", "wb") as f:
    pickle.dump(hsqc_dataset["test"], f, protocol=pickle.HIGHEST_PROTOCOL)

In [24]:
with open("data/hsqc_ms_pairs/train.pkl", "wb") as f:
    pickle.dump(pair_dataset["train"], f, protocol=pickle.HIGHEST_PROTOCOL)
with open("data/hsqc_ms_pairs/val.pkl", "wb") as f:
    pickle.dump(pair_dataset["val"], f, protocol=pickle.HIGHEST_PROTOCOL)
with open("data/hsqc_ms_pairs/test.pkl", "wb") as f:
    pickle.dump(pair_dataset["test"], f, protocol=pickle.HIGHEST_PROTOCOL)