In [None]:
import pandas as pd
import deepchem as dc
from deepchem.models.optimizers import ExponentialDecay

from qsar.gan.qsar_gan import QsarGan
from qsar.gan.extract_descriptors import DescriptorsExtractor
from qsar.utils.visualizer import Visualizer

# Data loading

In [None]:
# Local dataset (183 SMILES). The column is lower-cased so that every source
# exposes a single `smiles` column before concatenation.
df = pd.read_csv('../data/unfiltered_data_smiles.csv')
data_local = df[['SMILES']].rename(columns=str.lower)

# Tox21 dataset: bioactivity data for chemicals across toxicity assays, used in
# toxicology research and safety assessment.
# https://paperswithcode.com/dataset/tox21-1
# `splitter=None` returns every molecule in one dataset. With the default
# scaffold splitter, `datasets` would be (train, valid, test) and `datasets[0]`
# would hold only the training split, silently discarding ~20% of the SMILES.
_, datasets, _ = dc.molnet.load_tox21(splitter=None)
data_tox21 = pd.DataFrame(data={'smiles': datasets[0].ids})

# Lipophilicity dataset: experimental octanol/water distribution coefficients
# (logD at pH 7.4) curated from ChEMBL.
# https://www.ebi.ac.uk/chembl/document_report_card/CHEMBL3301361/
_, datasets, _ = dc.molnet.load_lipo(splitter=None)
data_lipo = pd.DataFrame(data={'smiles': datasets[0].ids})

# Pool all sources to give the GAN as much training data as possible.
# Missing SMILES are dropped, and so are duplicates within or across sources,
# so that no molecule is over-weighted during training.
data = (
    pd.concat([data_local, data_tox21, data_lipo], ignore_index=True)
    .dropna(subset=['smiles'])
    .drop_duplicates(subset='smiles')
    .reset_index(drop=True)
)
display(data)

# Determine the maximum atom-count threshold

In [None]:
# Exponentially decaying learning-rate schedule. Keyword arguments spell out
# the values: start at 1e-3 and apply a decay factor of 0.9 per 5000 steps.
learning_rate_schedule = ExponentialDecay(
    initial_rate=0.001,
    decay_rate=0.9,
    decay_steps=5000,
)
gan = QsarGan(learning_rate=learning_rate_schedule)

# Compute the atom-count distribution of the pooled SMILES and a threshold from
# it (presumably the 95th-percentile atom count, given `quantile=0.95`).
# NOTE(review): `max_atom_count` is not used by any later cell. Confirm that
# determine_atom_count configures the featurizer itself; otherwise the value
# should be passed to QsarGan explicitly.
max_atom_count, atom_counts = gan.featurizer.determine_atom_count(
    smiles=data,
    quantile=0.95,
)
visualizer = Visualizer()
visualizer.display_atom_count_distribution(atom_counts)

# Train the GAN and generate new SMILES

In [None]:
# Number of training epochs for the GAN.
N_EPOCHS = 128

# Fit the GAN on the pooled SMILES and collect the generated output.
# NOTE(review): no random seed is set anywhere in this notebook, so results are
# likely not reproducible between runs.
generated_smiles = gan.fit_predict(
    smiles=data,
    epochs=N_EPOCHS,
)
display(generated_smiles)

# Extract Descriptors

In [None]:
# Compute descriptors for the generated SMILES and show the resulting table.
data_with_descriptors = DescriptorsExtractor.extract_descriptors(
    generated_smiles
)
display(data_with_descriptors)

# Display the generated molecules

In [None]:
# Draw the generated molecules. This call is the cell's last expression, so if
# it returns a rich object (e.g. an image) the notebook renders it.
visualizer.draw_generated_molecules(
    generated_smiles
)