In [117]:
import pandas as pd
from tqdm import tqdm
from rxnfp.transformer_fingerprints import (
    RXNBERTFingerprintGenerator,
    get_default_model_and_tokenizer,
    generate_fingerprints,
)
import numpy as np
import tmap as tm
import os
from rdkit.Chem import Descriptors
from scipy import stats
from matplotlib.colors import LinearSegmentedColormap
from faerun import Faerun
from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt
from pathlib import Path

In [118]:
BASE_NAME = 'light'
DATA_DIR = Path('./data/light/raw')
DATAS = [DATA_DIR / f'{i+1}.xlsx' for i in range(5)]
# USPTO_50k = 'data/schneider50k/schneider50k.tsv'  # 此处可以导入 uspto 作为背景
REACTION_IN_DATA = 'reaction'
REACTION_IN_USPTO_50k = 'rxn'
SUBCLASS = 'solvent'

In [119]:
NAME = BASE_NAME + '_' + SUBCLASS

In [120]:
df_datas = []
for DATA in DATAS:
    df_datas.append(pd.read_excel(DATA, engine='openpyxl'))  # use openpyxl for better excel
df_data = pd.concat(df_datas)
df_data.dropna(subset=['photocatalyst'], inplace=True)  # xlsx 的最后面往往都是空的
df_data = df_data.reset_index(drop=True)  # 保证顺序

In [121]:
len(df_data)

6068

In [122]:
# 标准化反应
from rdkit import Chem
def con_smi(smi: str):
    return Chem.MolToSmiles(Chem.MolFromSmiles(smi), canonical=True)
def con_rxn(rxn: str):
    new_rxn = rxn.split('>>')
    r = new_rxn[0]
    p = new_rxn[1]
    new_r = '.'.join([con_smi(x) for x in r.split('.')])
    new_p = con_smi(p)
    return new_r + '>>' + new_p
df_data['con_rxn'] = None
n_error = 0
for i, row in df_data.iterrows():
    try:
        df_data.at[i, 'con_rxn'] = con_rxn(row['reaction'])
    except:
        n_error += 1
print(f'n_error: {n_error}')
df_data.dropna(subset=['con_rxn'], inplace=True)  # xlsx 的最后面往往都是空的
print(len(df_data))

[17:27:44] Explicit valence for atom # 3 Br, 3, is greater than permitted
[17:27:44] Explicit valence for atom # 1 N, 4, is greater than permitted
[17:27:44] Explicit valence for atom # 1 N, 4, is greater than permitted
[17:27:45] SMILES Parse Error: syntax error while parsing: [H2O18]
[17:27:45] SMILES Parse Error: Failed parsing SMILES '[H2O18]' for input: '[H2O18]'
[17:27:46] Explicit valence for atom # 3 Br, 3, is greater than permitted
[17:27:46] Explicit valence for atom # 0 Cl, 2, is greater than permitted
[17:27:47] Explicit valence for atom # 6 S, 8, is greater than permitted
[17:27:47] SMILES Parse Error: syntax error while parsing: vCC1=CC=NC2=CC=CC=C21
[17:27:47] SMILES Parse Error: Failed parsing SMILES 'vCC1=CC=NC2=CC=CC=C21' for input: 'vCC1=CC=NC2=CC=CC=C21'


n_error: 10
6058


In [123]:
# 根据光催化剂类型生成索引
_class = SUBCLASS + '_class'
df_data[_class] = ''
cat_list = df_data[SUBCLASS].value_counts().index.to_list()
n_catalysis = len(cat_list)
print(n_catalysis)

# 将每类映射成 int
from collections import defaultdict
map_dict = defaultdict(lambda : n_catalysis)
for v, k in enumerate(cat_list):
    map_dict[k] = v

# 将种类信息写入 df
for i, row in tqdm(df_data.iterrows(), total=len(df_data)):
    df_data[_class][i] = map_dict[row[SUBCLASS]]

# 创建逆映射字典，以备后用
_reverse_map_dict = {v: k for k,v in map_dict.items()}
reverse_map_dict = defaultdict(lambda : 'None')
for k, v in _reverse_map_dict.items():
    reverse_map_dict[k] = v
print(map_dict)
print(reverse_map_dict)

43


A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  app.launch_new_instance()
100%|██████████| 6058/6058 [00:01<00:00, 4053.52it/s]

defaultdict(<function <lambda> at 0x7fe703492170>, {'MeCN': 0, 'DCM': 1, 'DMSO': 2, 'DMF': 3, 'DCE': 4, 'Methanol': 5, '1,4-dioxane': 6, 'actone': 7, 'DCM/toluene': 8, 'ethanol': 9, 'actetone': 10, 'acetone': 11, 'DMA': 12, 'MeCN/H2O': 13, 'MeCN,H2O': 14, 'DMF,MeOH': 15, 'DCM,H2O': 16, 'ethanol,H2O': 17, 'Toluene/MeCN': 18, 'PhH': 19, 'i-PrOH': 20, 'MeCN, H2O': 21, 'CF3CH2OH': 22, 'PhCF3': 23, 'actetone, ethyl acetate ': 24, 't-amyl-OH': 25, 'DME,DMF': 26, 'ethyl acetate': 27, 'Methanol,DMSO': 28, 'toluene, H2O': 29, 'TFE': 30, 'CHCl3': 31, 'EtOH': 32, 'NMP,PhCl': 33, 'THF': 34, '1,4-dioxane,MeCN': 35, 'actone/H2O': 36, 'DCM/MeCN': 37, 'DMF/H2O': 38, 'MeCN、DMSO': 39, 'MeOH': 40, 'DMA/H2O': 41, 'propanol': 42})
defaultdict(<function <lambda> at 0x7fe70348add0>, {0: 'MeCN', 1: 'DCM', 2: 'DMSO', 3: 'DMF', 4: 'DCE', 5: 'Methanol', 6: '1,4-dioxane', 7: 'actone', 8: 'DCM/toluene', 9: 'ethanol', 10: 'actetone', 11: 'acetone', 12: 'DMA', 13: 'MeCN/H2O', 14: 'MeCN,H2O', 15: 'DMF,MeOH', 16: 'DCM




In [124]:
model, tokenizer = get_default_model_and_tokenizer('bert_ft_10k_25s')
ft_rxnfp_generator = RXNBERTFingerprintGenerator(model, tokenizer)
fps_ft_data = generate_fingerprints(df_data.reaction.values.tolist(), ft_rxnfp_generator, batch_size=8)

Some weights of the model checkpoint at /home/seeyou/anaconda3/envs/rdkit/lib/python3.7/site-packages/rxnfp/models/transformers/bert_ft_10k_25s were not used when initializing BertModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|██████████| 757/757 [00:07<00:00, 108.02it/s]


In [125]:
lf = tm.LSHForest(256, 128)
mh_encoder = tm.Minhash()
# slow
mhfps = [mh_encoder.from_weight_array(fp.tolist(), method="I2CWS") for fp in tqdm(fps_ft_data)]

100%|██████████| 6056/6056 [00:01<00:00, 4756.84it/s]


In [126]:
df_data.head()

Unnamed: 0,reaction,photocatalyst,base,additive,solvent,time(h),yield/%,source,Unnamed: 8,Unnamed: 9,Unnamed: 10,base/acid,con_rxn,solvent_class
0,C=CC1=CC=CC=C1.FC([S+]2C3=C(C=CC=C3)C4=C2C=CC=...,fac-Ir(ppy)3,,,DCE,2.0,78,"Angew. Chem., Int. Ed., 2012, 51, 9567–9571",37.0,1.0,,,C=Cc1ccccc1.FC(F)(F)[s+]1c2ccccc2c2ccccc21.F[B...,4
1,C=CC1=CC=CC=C1.FC([S+]2C3=C(C=CC=C3)C4=C2C=CC=...,fac-Ir(ppy)3,,,DCE,2.0,75,"Angew. Chem., Int. Ed., 2012, 51, 9567–9571",37.0,2.0,,,C=Cc1ccccc1.FC(F)(F)[s+]1c2ccccc2c2ccccc21.F[B...,4
2,C=CC1=CC=CC=C1.FC([S+]2C3=C(C=CC=C3)C4=C2C=CC=...,fac-Ir(ppy)3,,,DCE,2.0,76,"Angew. Chem., Int. Ed., 2012, 51, 9567–9571",37.0,3.0,,,C=Cc1ccccc1.FC(F)(F)[s+]1c2ccccc2c2ccccc21.F[B...,4
3,C=CC1=CC=CC=C1.FC([S+]2C3=C(C=CC=C3)C4=C2C=CC=...,fac-Ir(ppy)3,,,DCE,2.0,84,"Angew. Chem., Int. Ed., 2012, 51, 9567–9571",37.0,4.0,,,C=Cc1ccccc1.FC(F)(F)[s+]1c2ccccc2c2ccccc21.F[B...,4
4,C=CC1=CC=CC=C1.FC([S+]2C3=C(C=CC=C3)C4=C2C=CC=...,fac-Ir(ppy)3,,,DCE,2.0,51,"Angew. Chem., Int. Ed., 2012, 51, 9567–9571",37.0,5.0,,,C=Cc1ccccc1.FC(F)(F)[s+]1c2ccccc2c2ccccc21.F[B...,4


In [127]:
# slow

labels = []
# superclasses
superclasses = []

# product properties
tpsa = []
logp = []
mw = []
h_acceptors = []
h_donors = []
ring_count = []

# metals in precursors
has_Pd = []
has_Li = []
has_Mg = []
has_Al = []

n_errors = 0  # 记录错误个数

for i, row in tqdm(df_data.iterrows(), total=len(df_data)):
    try:
        rxn = row["reaction"]
        labels.append(
            str(rxn)
            + "__"
            + str(rxn)
            + f"__source: {row['source']}"
            + f"__solvent: {row['solvent']} yield/%: {row['yield/%']} base: {row['base']} additive: {row['additive']}"
        )
        superclasses.append(int(row[_class]))

        precursors, products = rxn.split('>>')

        mol = Chem.MolFromSmiles(products)

        tpsa.append(Descriptors.TPSA(mol))
        logp.append(Descriptors.MolLogP(mol))
        mw.append(Descriptors.MolWt(mol))
        h_acceptors.append(Descriptors.NumHAcceptors(mol))
        h_donors.append(Descriptors.NumHDonors(mol))
        ring_count.append(Descriptors.RingCount(mol))

        has_Pd.append('Pd' in precursors)
        has_Li.append('Li' in precursors)
        has_Mg.append('Mg' in precursors)
        has_Al.append('Al' in precursors)
    except:
        print(i, row["reaction"])  # 记录错误位置

100%|██████████| 6058/6058 [00:03<00:00, 1805.09it/s]


In [128]:
tpsa_ranked = stats.rankdata(np.array(tpsa) / max(tpsa)) / len(tpsa)
logp_ranked = stats.rankdata(np.array(logp) / max(logp)) / len(logp)
mw_ranked = stats.rankdata(np.array(mw) / max(mw)) / len(mw)
h_acceptors_ranked = stats.rankdata(np.array(h_acceptors) / max(h_acceptors)) / len(
    h_acceptors
)
h_donors_ranked = stats.rankdata(np.array(h_donors) / max(h_donors)) / len(h_donors)
ring_count_ranked = stats.rankdata(np.array(ring_count) / max(ring_count)) / len(
    ring_count
)
labels_groups, groups = Faerun.create_categories(superclasses)
labels_groups = [(label[0], f"{label[1]} - {reverse_map_dict[label[1]]}") for label in labels_groups]
labels_groups

[(0, '0 - MeCN'),
 (1, '1 - DCM'),
 (2, '2 - DMSO'),
 (3, '3 - DMF'),
 (4, '4 - DCE'),
 (5, '5 - Methanol'),
 (6, '6 - 1,4-dioxane'),
 (7, '7 - actone'),
 (8, '8 - DCM/toluene'),
 (9, '9 - ethanol'),
 (10, '10 - actetone'),
 (11, '11 - acetone'),
 (12, '12 - DMA'),
 (13, '13 - MeCN/H2O'),
 (14, '14 - MeCN,H2O'),
 (15, '15 - DMF,MeOH'),
 (16, '16 - DCM,H2O'),
 (17, '17 - ethanol,H2O'),
 (18, '18 - Toluene/MeCN'),
 (19, '19 - PhH'),
 (20, '20 - i-PrOH'),
 (21, '21 - MeCN, H2O'),
 (22, '22 - CF3CH2OH'),
 (23, '23 - PhCF3'),
 (24, '24 - actetone, ethyl acetate '),
 (25, '25 - t-amyl-OH'),
 (26, '26 - DME,DMF'),
 (27, '27 - ethyl acetate'),
 (28, '28 - Methanol,DMSO'),
 (29, '29 - toluene, H2O'),
 (30, '30 - TFE'),
 (31, '31 - CHCl3'),
 (32, '32 - EtOH'),
 (33, '33 - NMP,PhCl'),
 (34, '34 - THF'),
 (35, '35 - 1,4-dioxane,MeCN'),
 (36, '36 - actone/H2O'),
 (37, '37 - DCM/MeCN'),
 (38, '38 - DMF/H2O'),
 (39, '39 - MeCN、DMSO'),
 (40, '40 - MeOH'),
 (41, '41 - DMA/H2O'),
 (42, '42 - propanol')]

In [129]:
# slow
lf.batch_add(mhfps)
lf.index()

# Layout
cfg = tm.LayoutConfiguration()
cfg.k = 50
cfg.kc = 50
cfg.sl_scaling_min = 1.0
cfg.sl_scaling_max = 1.0
cfg.sl_repeats = 1
cfg.sl_extra_scaling_steps = 2
cfg.placer = tm.Placer.Barycenter
cfg.merger = tm.Merger.LocalBiconnected
cfg.merger_factor = 2.0
cfg.merger_adjustment = 0
cfg.fme_iterations = 1000
cfg.sl_scaling_type = tm.ScalingType.RelativeToDesiredLength
cfg.node_size = 1 / 37
cfg.mmm_repeats = 1

# Define colormaps
set1 = plt.get_cmap("Set1").colors
rainbow = plt.get_cmap("rainbow")
colors = rainbow(np.linspace(0, 1, len(set(groups))))[:, :3].tolist()
custom_cm = LinearSegmentedColormap.from_list("my_map", colors, N=len(colors))
bin_cmap = ListedColormap([set1[8], "#5400F6"], name="bin_cmap")

# Get tree coordinates
x, y, s, t, _ = tm.layout_from_lsh_forest(lf, config=cfg)

In [130]:
# slow
f = Faerun(clear_color="#000000", coords=False, view="front",)

f.add_scatter(
"ReactionAtlas",
{
    "x": x, "y": y,
    "c": [
        groups, # superclasses
        has_Pd,
        has_Li,
        has_Mg,
        has_Al,
        tpsa_ranked,
        logp_ranked,
        mw_ranked,
        h_acceptors_ranked,
        h_donors_ranked,
        ring_count_ranked,
    ],
    "labels": labels
},
shader="smoothCircle",
colormap=[
    custom_cm,
    bin_cmap,
    bin_cmap,
    bin_cmap,
    bin_cmap,
    "rainbow",
    "rainbow",
    "rainbow",
    "rainbow",
    "rainbow",
    "rainbow",

],
point_scale=6.0,
categorical=[
    True,
    True,
    True,
    True,
    True,
    False,
    False,
    False,
    False,
    False,
    False,
],
has_legend=True,
legend_labels=[
    labels_groups,
    [(0, "No"), (1, "Yes")],
    [(0, "No"), (1, "Yes")],
    [(0, "No"), (1, "Yes")],
    [(0, "No"), (1, "Yes")],
    None,
    None,
    None,
    None,
    None,
    None,
],
selected_labels=["SMILES", "SMILES", "Patent ID",  "Named Reaction", "Category", "Superclass"],
series_title=[
    "Superclass",
    "Pd",
    "Li",
    "Mg",
    "Al",
    "TPSA",
    "logP",
    "Mol Weight",
    "H Acceptors",
    "H Donors",
    "Ring Count",
],
max_legend_label=[
    None,
    None,
    None,
    None,
    None,
    str(round(max(tpsa))),
    str(round(max(logp))),
    str(round(max(mw))),
    str(round(max(h_acceptors))),
    str(round(max(h_donors))),
    str(round(max(ring_count))),
],
min_legend_label=[
    None,
    None,
    None,
    None,
    None,
    str(round(min(tpsa))),
    str(round(min(logp))),
    str(round(min(mw))),
    str(round(min(h_acceptors))),
    str(round(min(h_donors))),
    str(round(min(ring_count))),
],
title_index=2,
legend_title=NAME,
)

f.add_tree("reactiontree", {"from": s, "to": t}, point_helper="ReactionAtlas")

In [131]:
# slow
plot = f.plot(NAME, template="reaction_smiles")