# DeepDrugDomain Training Script
This notebook provides an example script for training a model using the DeepDrugDomain package.

## Import Libraries
Importing necessary Python libraries and modules from DeepDrugDomain.

In [None]:
import numpy as np
import torch
from deepdrugdomain.optimizers.factory import OptimizerFactory
from deepdrugdomain.schedulers.factory import SchedulerFactory
from deepdrugdomain.data.collate import CollateFactory
from torch.utils.data import DataLoader
from deepdrugdomain.models.factory import ModelFactory
from dgllife.utils import CanonicalAtomFeaturizer
import deepdrugdomain as ddd

## Configuration Settings
Set up the configuration for data paths, model parameters, and other settings.

In [None]:
# Run-time settings for this example notebook, collected in one place.
config = dict(
    device='cpu',     # passed to torch.device(); use 'cuda' for a GPU ('gpu' is not a valid torch device string)
    seed=4,           # used to seed torch and numpy, and passed to the dataset split
    resume='',        # not read by any cell visible in this notebook
    start_epoch=0,    # not read by any cell visible in this notebook
    eval=False,       # not read by any cell visible in this notebook
    num_workers=4,    # forwarded to every DataLoader
    batch_size=32,    # forwarded to every DataLoader
    pin_mem=True,     # forwarded to every DataLoader as pin_memory
)

## Setting Environment
Seed the random number generators (PyTorch and NumPy) so that runs are reproducible.

In [None]:
# Apply the configured seed to each RNG this notebook relies on.
# `seed` stays in scope because the dataset split below reuses it.
seed = config['seed']
for seed_rng in (torch.manual_seed, np.random.seed):
    seed_rng(seed)

## Defining Preprocessing functions
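This part differs from model to model, because it follows the preprocessing used by the authors of the original paper. In this notebook the preprocessing definitions are requested from the model itself (see the `model.get_preprocess(...)` call in the Model section below).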

## Model
Creating the model this way builds it with the default hyperparameters and layers. You can find the defaults in the config folder of the package.

In [None]:
# Model setup: the model is looked up by name in the model factory.
# Swap the name for any other model registered there (e.g. "attentionsitedti", "fragxsite", ...).
model_name = "attentionsitedti"
model = ModelFactory.create(model_name)

In [None]:
# Ask the model for its three preprocessing definitions.
# NOTE(review): the string arguments look like dataset column names for drug, protein and label - confirm.
preprocess_drug, preprocess_protein, preprocess_label = model.get_preprocess(
    "SMILES",
    "pdb_id",
    "Label",
)

# The model also supplies the collate function handed to the DataLoaders later on.
collate_fn = model.collate

# Concatenate the three parts into the single object passed to the dataset factory.
preprocesses = preprocess_drug + preprocess_protein + preprocess_label

print(preprocesses)

## Handling the data
Define the dataset, split it into train/validation/test subsets, and create the data loaders.

In [None]:
# Load dataset
dataset = ddd.data.DatasetFactory.create(
    "human",  # change to the name of your dataset in the dataset factory (e.g. "human", "drugbank", "celegans")
    file_paths="data/human/",  # change to the path of your dataset
    preprocesses=preprocesses)

# Calling the dataset object performs the split; the result is indexed below as
# [0] train, [1] validation, [2] test, matching the three fractions.
datasets = dataset(split_method="random_split",  # other split methods in the dataset factory: e.g. "scaffold_split", "cold_split"
                   frac=[0.8, 0.1, 0.1],
                   seed=seed)

# Arguments shared by all three loaders.
loader_kwargs = dict(
    batch_size=config['batch_size'],
    num_workers=config['num_workers'],
    pin_memory=config['pin_mem'],
    collate_fn=collate_fn,
)

# Training: shuffle every epoch and drop the incomplete last batch.
data_loader_train = DataLoader(datasets[0], shuffle=True, drop_last=True, **loader_kwargs)

# Evaluation: keep every sample and a fixed order, so the reported metrics cover the
# whole validation/test subset and are identical from run to run.
data_loader_val = DataLoader(datasets[1], shuffle=False, drop_last=False, **loader_kwargs)
data_loader_test = DataLoader(datasets[2], shuffle=False, drop_last=False, **loader_kwargs)

## Training Parameters

In [None]:
# Loss function; any other loss can be substituted here.
criterion = torch.nn.BCELoss()

# Optimizer, created by name through the optimizer factory (other registered optimizers can be used).
optimizer = OptimizerFactory.create(
    "adamw",
    model.parameters(),
    lr=1e-3,
    weight_decay=0.0,
)

# Learning-rate scheduler, created by name through the scheduler factory.
# num_epochs here is the same value as the epoch count used by the training loop.
scheduler = SchedulerFactory.create(
    "cosine",
    optimizer,
    warmup_epochs=0,
    warmup_lr=1e-3,
    num_epochs=200,
)

# Move the model to the device selected in the configuration.
device = torch.device(config['device'])
model.to(device)

# Evaluators: a light one for training, a fuller metric set for validation and testing.
# Other metrics from the metric factory can be listed instead.
train_evaluator = ddd.metrics.Evaluator(["accuracy_score"], threshold=0.5)
test_evaluator = ddd.metrics.Evaluator(
    [
        "accuracy_score",
        "f1_score",
        "auc",
        "precision_score",
        "recall_score",
    ],
    threshold=0.5,
)

## Training Loop

In [None]:
# Total number of training epochs. The scheduler created earlier is also configured
# with num_epochs=200; keep the two values in sync if you change this.
epochs = 200
# Gradient accumulation steps passed to train_one_epoch.
accum_iter = 1

for epoch in range(epochs):
    print(f"Epoch {epoch}:")
    # Pass `epochs` instead of repeating the literal, so the loop length and the
    # value given to the model cannot drift apart.
    model.train_one_epoch(data_loader_train, device, criterion,
                          optimizer, num_epochs=epochs, scheduler=scheduler,
                          evaluator=train_evaluator, grad_accum_steps=accum_iter)
    # Report validation metrics after every epoch.
    print(model.evaluate(data_loader_val, device, criterion, evaluator=test_evaluator))

## Testing the Trained Model

In [None]:
# Final evaluation on the held-out test loader; the result is displayed as the cell output.
test_results = model.evaluate(
    data_loader_test,
    device,
    criterion,
    evaluator=test_evaluator,
)
test_results