### Train

In [2]:
from predict import MolPredict
from train import MolTrain
import pandas as pd
import numpy as np
from  visual_utils.visualization_module import YourVisualizationClass
from visual_utils.plot import _image_scatter, facecolors_customize
from rdkit import Chem
import rdkit.Chem.Draw as Draw
from matplotlib import pyplot as plt
import umap

In [None]:
# Training configuration for the regression model on molecule inputs.
# The values are passed to MolTrain unchanged; save_path is the same directory
# the Test cell later hands to MolPredict(load_model=...).
train_config = dict(
    task='regression',
    data_type='molecule',
    epochs=200,
    learning_rate=0.00005,
    batch_size=4,
    early_stopping=5,
    save_path='./random_LDS',
    remove_hs=True,
)

clf = MolTrain(**train_config)
# Fit on the training split of the "Random" dataset.
clf.fit('./dataset/Random/train.csv')

### Test

In [None]:
# Load the model stored under ./random_LDS and predict on the held-out split.
clf = MolPredict(load_model='./random_LDS', visual=False)
test_pred = clf.predict('./dataset/Random/test.csv')

# Pair each flattened prediction with the SMILES string the predictor loaded.
# Both columns go through one dict so pandas builds the frame in a single step.
result_columns = {
    'pred': test_pred.flatten(),
    'smiles': clf.datahub.data['smiles'],
}
test_results = pd.DataFrame(result_columns)

# Show the first rows, then persist the full table next to the model files.
print(test_results.head())
test_results.to_csv("./random_LDS/random_LDS.csv")

### Visual

In [None]:
# Load the model in visual mode. Here predict() is unpacked into four values:
# predictions, encodings, ground-truth values and the SMILES strings.
clf = MolPredict(load_model='./dataset/Scaffold/FDS', visual=True)
y_prediction, encodings, y_truths, smiles_data_ = clf.predict('./visual_utils/example.csv')

# Overview figure produced by the project's visualization helper.
visualization_instance = YourVisualizationClass()
figure = visualization_instance.visualize(encodings, y_truths.flatten(), y_prediction.flatten())
# NOTE(review): this path used to be spelled f"./visual_utils/{00000}_umap.png".
# The literal 00000 evaluates to the integer 0, so the file has always been
# written as "0_umap.png". The path is kept as-is; rename it deliberately if
# "00000_umap.png" was the intent.
figure.savefig("./visual_utils/0_umap.png")

# Reduce the encodings with UMAP (cosine metric, fixed seed). `encodings` is
# rebound to the embedding, whose first two columns are plotted below and
# reused by the following cell.
reducer = umap.UMAP(
    n_neighbors=60,
    min_dist=1.0,
    spread=1.0,
    metric="cosine",
    random_state=12,
)
encodings = reducer.fit_transform(encodings)

feature_name = "Predicted efficiency"
# Flatten once and reuse for colouring, arg-extrema and annotation titles.
feature_ = y_prediction.flatten()

fig, ax = plt.subplots(figsize=(12, 10), dpi=300)
im = ax.scatter(
    encodings[:, 0],
    encodings[:, 1],
    c=feature_,
    # Marker area shrinks as the number of points grows.
    s=90 * 1000 / len(feature_),
    cmap="Spectral",
    alpha=0.9,
)

# Remove ticks and all four spines so only the point cloud remains.
ax.set(xticks=[], yticks=[])
for side in ("top", "right", "bottom", "left"):
    ax.spines[side].set_visible(False)

# Annotate the molecules with the highest and the lowest prediction with a
# drawing of their structure; each gets its own image offset.
for index, offset in (
    (np.argmax(feature_), (1.2, 0.95)),  # highest prediction
    (np.argmin(feature_), (0.1, 0.1)),   # lowest prediction
):
    mol = Chem.MolFromSmiles(smiles_data_[index])
    pil_img = Draw.MolToImage(mol, size=(150, 150))
    _image_scatter(
        [encodings[index, 0]],
        [encodings[index, 1]],
        [pil_img],
        [f"{feature_name} {feature_[index]:.2f}"],
        ["#7f7f7f"],
        ax,
        offset=offset,
    )

# Colour bar for the predicted values.
cbar = fig.colorbar(
    im,
    ax=ax,
    orientation="vertical",
    shrink=0.5,
)
cbar.ax.tick_params(labelsize=20)
cbar.ax.set_ylabel(feature_name, rotation=270, fontsize=20, labelpad=20)

fig.savefig(f"./visual_utils/{feature_name}_umap.png")


In [None]:

# Atom count per molecule, used below as the colour channel. MolFromSmiles
# returns None for a SMILES string it cannot parse, so report the offending
# string instead of failing with an AttributeError on None.
num_atoms = []
for smiles in smiles_data_:
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"RDKit could not parse SMILES: {smiles!r}")
    num_atoms.append(mol.GetNumAtoms())

# NOTE(review): not used in this cell. This counts the character "C", so "Cl"
# is counted too and lowercase aromatic "c" is not — confirm before relying
# on it as a carbon count.
num_carbons = [smiles.count("C") for smiles in smiles_data_]

fig, ax = plt.subplots(figsize=(12, 10))
im = ax.scatter(
    encodings[:, 0],
    encodings[:, 1],
    c=num_atoms,
    # Marker area is the prediction cubed; the legend below inverts this.
    # Assumes predictions are non-negative — TODO confirm.
    s=np.power(y_prediction.flatten(), 3),
    cmap="gnuplot2",
    alpha=0.4,
    edgecolors="white",
)

# Remove ticks and spines.
ax.set(xticks=[], yticks=[])
for side in ("top", "right", "bottom", "left"):
    ax.spines[side].set_visible(False)

# Size legend: `func` takes the cube root so labels show predictions rather
# than raw marker areas.
handles, labels = im.legend_elements(
    prop="sizes", alpha=0.6, func=lambda x: x ** (1 / 3)
)
legend = ax.legend(
    handles,
    labels,
    bbox_to_anchor=(0.17, 0.52),
    title="Predicted efficiency",
    fontsize=20,
)
# Registering the legend as an artist keeps it on the axes even if
# ax.legend() is called again later.
ax.add_artist(legend)

# Colour bar for the atom counts.
cbar = fig.colorbar(
    im,
    ax=ax,
    orientation="vertical",
    shrink=0.5,
)
cbar.ax.tick_params(labelsize=20)
cbar.ax.set_ylabel("num atoms", rotation=270, fontsize=20, labelpad=20)

fig.savefig("./visual_utils/numatoms_umap.png")