
# Introduction
* **GitHub:** https://github.com/xiaorandu/dl4h_project
*   **Background of the problem**
  * **type of problem:** Artificial intelligence (AI) is increasingly used to aid drug discovery; however, most existing approaches focus on chemical structures and largely ignore the wealth of information available in textual descriptions. This limitation hinders the use of natural language for drug design, molecule editing, and the prediction of complex biological activities.

  * **importance/meaning of solving the problem:** Solving this problem would enable faster iteration in drug discovery tasks such as drug repurposing and multi-objective lead optimization.

  * **the difficulty of the problem:** The problem is difficult in several respects. It involves zero-shot learning, which is especially hard in biochemistry, and it requires understanding natural language. Data is also scarce: PubChemSTM contains about 250,000 molecules and 281,000 structure-text pairs, compared with roughly 400 million pairs used in the vision-language domain. Finally, the expressiveness of current chemical structure models is a bottleneck.

  * **the state of the art methods and effectiveness:** The paper designs a multi-modal model, MoleculeSTM, that incorporates both molecular structural information and textual knowledge. It consists of two branches: a chemical structure branch (to handle a molecule's internal structure) and a textual description branch (to handle external domain knowledge). This design lets the model reuse existing models trained on each modality separately, i.e., molecular structure models and scientific language models. A large multi-modal structure-text dataset was constructed to align the two branches. Two challenging downstream tasks were designed, structure-text retrieval and text-based molecule editing, and the pretrained MoleculeSTM was applied to both in a zero-shot manner. Studying these tasks highlights two main attributes of MoleculeSTM: open vocabulary and compositionality. Open vocabulary means the model can explore a wide range of biochemical concepts with an unbounded vocabulary. Compositionality means a complex concept can be expressed by decomposing it into several simpler concepts. The reported results show that MoleculeSTM achieves the best performance on six zero-shot retrieval tasks (up to 50% higher accuracy) and twenty zero-shot text-based editing tasks (up to 40% higher hit ratio) compared with state-of-the-art methods. Additionally, MoleculeSTM was able to detect critical structures implied by text descriptions in molecule editing tasks.

*   **Paper explanation**
  * **what did the paper propose:** The paper introduces MoleculeSTM, a model that aligns the chemical structures of molecules with their textual descriptions using a contrastive learning approach. The model is designed to perform tasks such as structure-text retrieval and text-guided molecule editing in a zero-shot setting. It is pretrained on a multi-modal dataset, PubChemSTM, containing over 280,000 chemical structure-text pairs.

  * **innovations of the method:** MoleculeSTM uniquely combines chemical structural data with textual information, enhancing the model's understanding and generalization capabilities across various biochemical contexts. It demonstrates significant effectiveness in zero-shot scenarios, where the model performs tasks without having been explicitly trained on them. The model supports open-ended vocabularies and can decompose complex instructions into simpler concepts, making it versatile in handling diverse and novel scientific queries.

  * **how well the proposed method work (in its own metrics):** MoleculeSTM significantly outperformed existing methods in zero-shot retrieval and text-based molecule editing tasks. It achieved up to 50% higher accuracy in retrieval tasks and up to 40% higher hit ratios in editing tasks compared to state-of-the-art methods. This indicates a robust ability to generalize and effectively apply learned knowledge to new and unseen data.

  * **what is the contribution to the research regime (referring the Background above, how important the paper is to the problem):** The paper incorporates textual descriptions alongside chemical structures for molecule representation learning. The resulting multi-modal model, MoleculeSTM, consistently outperformed existing models. MoleculeSTM may accelerate downstream drug discovery practices, since it was observed to modify molecule substructures to obtain desired properties and to retrieve novel drug-target relations. The paper is important because it demonstrates the effectiveness of adding textual descriptions to chemical structures for molecule representation learning. It has two main limitations: data insufficiency and the limited expressiveness of the chemical structure models.


# Scope of Reproducibility:

We test the following hypotheses from the paper and describe the corresponding experiments below.

Hypotheses:
1. MoleculeSTM achieves state-of-the-art performance in zero-shot structure-text retrieval and molecule editing tasks compared to existing methods.
2. Incorporating textual descriptions through contrastive learning will significantly improve the model's performance on zero-shot retrieval and molecule editing tasks.

Experiments:
Retraining the model on the constructed PubChemSTM dataset, consisting of over 280,000 chemical structure-text pairs, and evaluating it on specified zero-shot tasks: structure-text retrieval and molecule editing.






# Methodology

This section contains the runnable code, with annotations, for the experiments used to test the hypotheses.

It is organized into the subsections **Environment**, **Data**, **Model**, **Training**, and **Evaluation**.

##  Environment

### Python version


In [None]:
!python --version

Python 3.10.12

### Dependencies/packages needed

In [None]:
# installed packages
!pip install rdkit
!pip install torch torchvision
!pip install requests tqdm matplotlib spacy Levenshtein boto3 deepspeed
!pip install ogb==1.2.0
!pip install transformers==4.30.2

!pip install torch_geometric
!pip install torch-scatter -f https://data.pyg.org/whl/torch-2.2.1+cu121.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-2.2.1+cu121.html
!pip install torch-cluster -f https://data.pyg.org/whl/torch-2.2.1+cu121.html
!pip install torch-spline-conv -f https://data.pyg.org/whl/torch-2.2.1+cu121.html
!pip install git+https://github.com/MolecularAI/pysmilesutils.git

In [None]:
# install apex
!git clone https://github.com/chao1224/apex.git
%cd apex
!pip install -v --disable-pip-version-check --no-cache-dir ./
%cd ..

In [None]:
# install Megatron
!git clone https://github.com/MolecularAI/MolBART.git --branch megatron-molbart-with-zinc
%cd MolBART/megatron_molbart/Megatron-LM-v1.1.5-3D_parallelism
!pip install .
%cd ../../..
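
Because several of the packages above are pinned or built from source, it can help to check what actually got installed before running anything else. The following is a minimal sketch using only the standard library. The helper name `installed_versions` is hypothetical, and the package list simply repeats some of the distribution names from the install cells above.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if it is missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Distribution names taken from the pip install cells above.
wanted = ["rdkit", "torch", "ogb", "transformers", "torch_geometric"]
for name, version in installed_versions(wanted).items():
    print(f"{name:20s} {version or 'NOT INSTALLED'}")
```

Comparing this listing against the pinned versions (for example `ogb==1.2.0` and `transformers==4.30.2`) is a quick way to spot version conflicts.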

##  Data


### Data download instruction
We can use the following Python script to download the pretraining dataset and the downstream datasets.

In [None]:
from huggingface_hub import snapshot_download

# Download the MoleculeSTM dataset repository from the Hugging Face Hub into ./data
snapshot_download(repo_id="chao1224/MoleculeSTM", repo_type="dataset", local_dir="data")

The data folder can be found at [google drive link](https://drive.google.com/drive/u/0/folders/1pCr0WrY-3lbxxy44D68u2cDzmDONVLBD).
<br/>The data folder will include:
```
data
├── PubChemSTM_data/
│   ├── raw/
│   │   ├── CID2SMILES.csv
│   │   ├── CID2name.json
│   │   ├── CID2name_raw.json
│   │   └── molecules.sdf
│   └── processed/
├── pretrained_SciBERT/
├── pretrained_MegaMolBART/
├── pretrained_KV-PLM/
├── pretrained_GraphMVP/
├── pretrained_MoleculeSTM_Raw/
├── pretrained_MoleculeSTM/
├── DrugBank_data/
├── ZINC250K_data/
├── Editing_data/
│   ├── single_multi_property_SMILES.txt
│   ├── neighbor2drug/
│   └── ChEMBL_data/
└── MoleculeNet_data/
```
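
After the download finishes, a short check can confirm that the layout matches the listing above. This is only a sketch: the helper `missing_paths` and the constant `EXPECTED_PATHS` are hypothetical names, and the list covers a subset of the entries shown in the tree.

```python
from pathlib import Path

# Relative paths copied from the folder listing above (a subset, for illustration).
EXPECTED_PATHS = [
    "PubChemSTM_data/raw/CID2SMILES.csv",
    "PubChemSTM_data/raw/CID2name.json",
    "PubChemSTM_data/raw/CID2name_raw.json",
    "PubChemSTM_data/raw/molecules.sdf",
    "DrugBank_data",
    "Editing_data/single_multi_property_SMILES.txt",
    "MoleculeNet_data",
]


def missing_paths(data_root, expected=EXPECTED_PATHS):
    """Return the expected relative paths that do not exist under data_root."""
    root = Path(data_root)
    return [rel for rel in expected if not (root / rel).exists()]


missing = missing_paths("data")
print("all expected paths present" if not missing else f"missing: {missing}")
```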

### Data descriptions with helpful charts and visualizations

In [None]:
%cd MoleculeSTM/data/PubChemSTM_data/raw

In [None]:
'''
The SMILES string views the molecule as a sequence
'''
import pandas as pd
CID2SMILES = 'CID2SMILES.csv'
df_CID2SMILES = pd.read_csv(CID2SMILES, usecols=['CID', 'SMILES'])
df_CID2SMILES.head()

```
index	CID	SMILES
0	     1	CC(=O)OC(CC(=O)[O-])C[N+](C)(C)C
1	     3	O=C(O)C1=CC=CC(O)C1O
2	     4	CC(O)CN
3	     5	NCC(=O)COP(=O)(O)O
4	     6	O=[N+]([O-])c1ccc(Cl)c([N+](=O)[O-])c1
```
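
As a first descriptive statistic, one could summarize the length of the SMILES strings. The sketch below uses the five rows shown above as a stand-in for the full `SMILES` column; `smiles_length_summary` is a hypothetical helper, and string length is only a rough proxy for molecule size (it counts characters, not atoms).

```python
import statistics


def smiles_length_summary(smiles_list):
    """Basic summary statistics of SMILES string lengths (characters, not atoms)."""
    lengths = [len(s) for s in smiles_list]
    return {
        "count": len(lengths),
        "min": min(lengths),
        "max": max(lengths),
        "mean": statistics.mean(lengths),
        "median": statistics.median(lengths),
    }


# The five SMILES strings shown in the output above.
sample = [
    "CC(=O)OC(CC(=O)[O-])C[N+](C)(C)C",
    "O=C(O)C1=CC=CC(O)C1O",
    "CC(O)CN",
    "NCC(=O)COP(=O)(O)O",
    "O=[N+]([O-])c1ccc(Cl)c([N+](=O)[O-])c1",
]
print(smiles_length_summary(sample))
```

On the real data, the same function could be applied to `df_CID2SMILES['SMILES'].tolist()`.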

In [None]:
import json
CID2name_raw = "CID2name_raw.json"
with open(CID2name_raw, 'r') as file:
  data = json.load(file)

df_CID2name_raw = pd.DataFrame(list(data.items()), columns=['CID', 'Names'])
df_CID2name_raw.head()

```
index	CID	Names
0	    180	Acetone,ACETONE,Acetone
1	    222	Ammonia,AMMONIA SOLUTIONS (CONTAINING MORE THAN 35% BUT NOT MORE THAN 50% AMMONIA),AMMONIA, ANHYDROUS,AMMONIA, SOLUTION, WITH MORE THAN 10% BUT NOT MORE THAN 35% AMMONIA,Ammonia
2	   5359596	Arsenic,ARSENIC,Arsenic,Arsenic atom,Arsenic Compounds
3	    241	Benzene,BENZENE,Benzene,Benzene
4	   23973	Cadmium,CADMIUM,Cadmium atom,Cadmium,Cadmium Compounds
```

In [None]:
import json
CID2name = "CID2name.json"
with open(CID2name, 'r') as file:
  data = json.load(file)

df_CID2name = pd.DataFrame(list(data.items()), columns=['CID', 'Names'])
df_CID2name.head()

```
index	CID	Names
0	    180	Acetone,Acetone,Acetone,Acetone,Acetone
1	    222	Ammonia,Ammonia solutions (containing more than 35% but not more than 50% ammonia),Ammonia, anhydrous,Ammonia, solution, with more than 10% but not more than 35% ammonia,Ammonia,Ammonia,Ammonia,Ammonia
2	   5359596	Arsenic,Arsenic,Arsenic,Arsenic atom,Arsenic, a naturally occurring element,,Arsenic
3	    241	Benzene,Benzene,Benzene,Benzene,Benzene,Benzene
4	   23973	Cadmium,Cadmium,Cadmium atom,The main sources of cadmium in the air,Cadmium,Cadmium,Cadmium
```
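
The number of molecules and structure-text pairs can be counted directly from this mapping. The sketch below assumes that the JSON file maps each CID to a list of text strings, which is what the table above suggests; `summarize_pairs` is a hypothetical helper and the toy dictionary only imitates the shape of the file.

```python
from collections import Counter


def summarize_pairs(cid_to_texts):
    """Count molecules, structure-text pairs, and how many texts each molecule has."""
    counts = {cid: len(texts) for cid, texts in cid_to_texts.items()}
    return {
        "molecules": len(counts),
        "pairs": sum(counts.values()),
        "texts_per_molecule": Counter(counts.values()),
    }


# Toy stand-in shaped like the loaded JSON (assumed: CID -> list of texts).
toy = {"180": ["Acetone", "Acetone"], "241": ["Benzene"]}
print(summarize_pairs(toy))
```

Applied to the `data` dictionary loaded above, `summarize_pairs(data)` would give the totals to compare against the figures reported for PubChemSTM.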

In [None]:
from rdkit import Chem
from rdkit.Chem import PandasTools

sdf_file = "molecules.sdf"
df_molecules = PandasTools.LoadSDF(sdf_file)
print(df_molecules.head())

```
# the output would be like this.
  PUBCHEM_COMPOUND_CID PUBCHEM_COMPOUND_CANONICALIZED  \
0             29500027                              1   
1             29500038                              1   
2             29500039                              1   
3             29500070                              1   
4             29500073                              1   

  PUBCHEM_CACTVS_COMPLEXITY PUBCHEM_CACTVS_HBOND_ACCEPTOR  \
0                       460                             7   
1                       480                             6   
2                       480                             6   
3                       543                             6   
4                       451                             6   

  PUBCHEM_CACTVS_HBOND_DONOR PUBCHEM_CACTVS_ROTATABLE_BOND  \
0                          1                             4   
1                          2                             6   
2                          2                             6   
3                          1                             4   
4                          2                             3   

                             PUBCHEM_CACTVS_SUBSKEYS  \
0  AAADccB7oABAAAAAAAAAAAAAAAAAAWLAAAA8QAAAAAAAAA...   
1  AAADceB7sABgAAAAAAAAAAAAAAAAAWJAAAAsAAAAAAAAAA...   
2  AAADceB7sABgAAAAAAAAAAAAAAAAAWJAAAAsAAAAAAAAAA...   
3  AAADceB7oABAAAAAAAAAAAAAAAAAAWLAAAAwYAAAAAAAAA...   
4  AAADceB7oABAAAAAAAAAAAAAAAAAAWAAAAA8QAAAAAAAAA...   

                          PUBCHEM_IUPAC_OPENEYE_NAME  \
0  N-[4-(2-pyridyl)thiazol-2-yl]-4-(tetrazol-1-yl...   
1  (3S)-3-acetamido-N-[4-(2-pyridyl)thiazol-2-yl]...   
2  (3R)-3-acetamido-N-[4-(2-pyridyl)thiazol-2-yl]...   
3  4-(tetrazol-1-yl)-N-[4-(2,4,5-trimethylphenyl)...   
4  3-amino-N-[4-(2,4,5-trimethylphenyl)thiazol-2-...   

                              PUBCHEM_IUPAC_CAS_NAME  \
0  N-[4-(2-pyridinyl)-2-thiazolyl]-4-(1-tetrazoly...   
1  (3S)-3-acetamido-N-[4-(2-pyridinyl)-2-thiazoly...   
2  (3R)-3-acetamido-N-[4-(2-pyridinyl)-2-thiazoly...   
3  4-(1-tetrazolyl)-N-[4-(2,4,5-trimethylphenyl)-...   
4  3-amino-N-[4-(2,4,5-trimethylphenyl)-2-thiazol...   

                           PUBCHEM_IUPAC_NAME_MARKUP  ...  \
0  <I>N</I>-(4-pyridin-2-yl-1,3-thiazol-2-yl)-4-(...  ...   
1  (3<I>S</I>)-3-acetamido-<I>N</I>-(4-pyridin-2-...  ...   
2  (3<I>R</I>)-3-acetamido-<I>N</I>-(4-pyridin-2-...  ...   
3  4-(tetrazol-1-yl)-<I>N</I>-[4-(2,4,5-trimethyl...  ...   
4  3-amino-<I>N</I>-[4-(2,4,5-trimethylphenyl)-1,...  ...   

  PUBCHEM_BOND_UDEF_STEREO_COUNT PUBCHEM_ISOTOPIC_ATOM_COUNT  \
0                              0                           0   
1                              0                           0   
2                              0                           0   
3                              0                           0   
4                              0                           0   

  PUBCHEM_COMPONENT_COUNT PUBCHEM_CACTVS_TAUTO_COUNT PUBCHEM_COORDINATE_TYPE  \
0                       1                         -1               1\n5\n255   
1                       1                         -1               1\n5\n255   
2                       1                         -1               1\n5\n255   
3                       1                         -1               1\n5\n255   
4                       1                         -1               1\n5\n255   

                             PUBCHEM_BONDANNOTATIONS        ID  \
0  1  18  8\n1  20  8\n10  12  8\n10  13  8\n11  ...  29500027   
1  1  11  8\n1  16  8\n11  13  8\n13  15  8\n15  ...  29500038   
2  1  11  8\n1  16  8\n11  13  8\n13  15  8\n15  ...  29500039   
3  1  19  8\n1  20  8\n10  14  8\n11  12  8\n11  ...  29500070   
4  1  18  8\n1  19  8\n10  11  8\n10  12  8\n11  ...  29500073   

                                              ROMol PUBCHEM_XLOGP3  \
0  <rdkit.Chem.rdchem.Mol object at 0x7d30b3ca4cf0>            NaN   
1  <rdkit.Chem.rdchem.Mol object at 0x7d30b3ca4c80>            NaN   
2  <rdkit.Chem.rdchem.Mol object at 0x7d30b3ca5070>            NaN   
3  <rdkit.Chem.rdchem.Mol object at 0x7d30b3ca51c0>            NaN   
4  <rdkit.Chem.rdchem.Mol object at 0x7d30b3ca5310>            NaN   

  PUBCHEM_REFERENCE_STANDARDIZATION  
0                               NaN  
1                               NaN  
2                               NaN  
3                               NaN  
4                               NaN  
```

### Preprocessing code + command

The preprocessing code can be found in the preprocessing/PubchemSTM folder.

##   Model

### Citation to the original paper
```
@article{liu2023moleculestm,
    author={Liu, Shengchao and Nie, Weili and Wang, Chengpeng and Lu, Jiarui and Qiao, Zhuoran and Liu, Ling and Tang, Jian and Xiao, Chaowei and Anandkumar, Anima},
    title={Multi-modal molecule structure--text model for text-based retrieval and editing},
    journal={Nature Machine Intelligence},
    year={2023},
    month={Dec},
    day={01},
    volume={5},
    number={12},
    pages={1447-1457},
    issn={2522-5839},
    doi={10.1038/s42256-023-00759-6},
    url={https://doi.org/10.1038/s42256-023-00759-6}
}
```

### Link to the original paper’s repo
https://github.com/chao1224/MoleculeSTM

### Model descriptions
The multi-modal molecule structure-text model (MoleculeSTM) combines the chemical structures of molecules with their textual descriptions to enhance the capabilities of artificial intelligence in drug discovery. MoleculeSTM uses a contrastive learning strategy to jointly learn from over 280,000 chemical structure-text pairs in PubChemSTM. This training approach allows the model to perform zero-shot tasks based on textual instructions, such as structure-text retrieval and molecule editing. MoleculeSTM leverages an open vocabulary and compositionality through natural language, and the paper reports state-of-the-art generalization across various benchmarks without the need for labeled examples or fine-tuning.
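
To make the contrastive objective more concrete, here is a minimal InfoNCE-style sketch in plain Python. It is not the authors' implementation (the repository implements its losses in PyTorch), it covers only the structure-to-text direction, and the function names and the temperature value are illustrative assumptions.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit length (left unchanged if it is all zeros)."""
    norm = math.sqrt(sum(v * v for v in vec)) or 1.0
    return [v / norm for v in vec]


def info_nce(struct_embs, text_embs, temperature=0.1):
    """Mean cross-entropy of matching structure i to text i among all texts in the batch."""
    s = [l2_normalize(v) for v in struct_embs]
    t = [l2_normalize(v) for v in text_embs]
    total = 0.0
    for i, s_i in enumerate(s):
        # Cosine similarity of structure i to every text, scaled by the temperature.
        logits = [sum(a * b for a, b in zip(s_i, t_j)) / temperature for t_j in t]
        # Negative log-softmax at the positive index i, computed stably.
        m = max(logits)
        log_denom = m + math.log(sum(math.exp(l - m) for l in logits))
        total += log_denom - logits[i]
    return total / len(s)


aligned = info_nce([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
swapped = info_nce([[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]])
print(aligned < swapped)  # True: correctly paired embeddings give the lower loss
```

Minimizing this quantity pulls each structure embedding toward its own description and pushes it away from the other descriptions in the batch, which is the intuition behind aligning the two branches.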

### Implementation code

#### **Pretraining**

##### **1. Chemical structure branch fc**
*   Pretraining SMILES
*   Pretraining Graph
<br/>
Notes: We currently focus on pretraining Graph.



In [None]:
import sys
sys.path.append('MoleculeSTM')

In [None]:
%cd scripts

In [None]:
# For the molecular graph, we use a graph isomorphism network (GIN) pretrained with GraphMVP.
! python pretrain.py \
    --verbose --batch_size=8 \
    --molecule_type=Graph

#### **1 Downstream: Zero-shot Structure-text Retrieval**

##### for DrugBank-Description

In [None]:
# For graphs
! python downstream_01_retrieval_Description_Pharmacodynamics.py \
    --task=molecule_description_removed_PubChem \
    --molecule_type=Graph \
    --input_model_dir=../data/demo/demo_checkpoints_Graph

##### for DrugBank-Pharmacodynamics

In [None]:
! python downstream_01_retrieval_Description_Pharmacodynamics.py \
    --task=molecule_pharmacodynamics_removed_PubChem \
    --molecule_type=Graph \
    --input_model_dir=../data/demo/demo_checkpoints_Graph

##### for DrugBank-ATC

In [None]:
! python downstream_01_retrieval_ATC.py \
    --molecule_type=Graph \
    --input_model_dir=../data/demo/demo_checkpoints_Graph

#### **2 Downstream: Zero-shot Text-based Molecule Editing**

In [None]:
# For graphs
! python downstream_02_molecule_edit_step_01_MoleculeSTM_Space_Alignment.py \
    --MoleculeSTM_molecule_type=Graph \
    --MoleculeSTM_model_dir=../data/demo/demo_checkpoints_Graph


! python downstream_02_molecule_edit_step_02_MoleculeSTM_Latent_Optimization.py \
    --MoleculeSTM_molecule_type=Graph \
    --MoleculeSTM_model_dir=../data/demo/demo_checkpoints_Graph \
    --language_edit_model_dir=../data/demo/demo_checkpoints_Graph \
    --input_description_id=101

#### **3 Downstream: Molecular Property Prediction**

In [None]:
# For graphs
! python downstream_03_property_prediction.py \
    --dataset=bace --molecule_type=Graph

##   Training

### Computational requirements


*   Hardware: Google Colab V100 GPU, RAM 16GB
*   Seeds: 42
*   Training epochs: 100
*   Computing time:
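
One way to apply the seed listed above is sketched below. This is an assumption about how the seed could be fixed in the notebook, not a description of what the MoleculeSTM scripts do; `set_seed` is a hypothetical helper, and NumPy and PyTorch are seeded only if they are installed.

```python
import random


def set_seed(seed=42):
    """Seed Python's RNG, plus NumPy and PyTorch when they are available."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # safe to call without a GPU
    except ImportError:
        pass


set_seed(42)
first = random.random()
set_seed(42)
print(first == random.random())  # True: the same seed reproduces the same draw
```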



### Implementation code

### Support Functions and Setup:

In [None]:
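# Training and evaluation loops for property prediction (Graph).
# Note: `args`, `model`, `linear_model`, `criterion`, `MegaMolBART_wrapper`, and
# `get_molecule_repr_MoleculeSTM` must be defined before these functions are called.
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import roc_auc_score
from tqdm import tqdm
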
def train_classification(model, device, loader, optimizer):
    if args.training_mode == "fine_tuning":
        model.train()
    else:
        model.eval()
    linear_model.train()
    total_loss = 0

    if args.verbose:
        L = tqdm(loader)
    else:
        L = loader
    for step, batch in enumerate(L):
        if args.molecule_type == "MegaMolBART":
            SMILES_list, y = batch
            SMILES_list = list(SMILES_list)
            molecule_repr = get_molecule_repr_MoleculeSTM(
                SMILES_list, mol2latent=None,
                molecule_type="MegaMolBART", MegaMolBART_wrapper=MegaMolBART_wrapper)
            pred = linear_model(molecule_repr)
            pred = pred.float()
            y = y.to(device).float()
        else:
            batch = batch.to(device)
            molecule_repr = get_molecule_repr_MoleculeSTM(
                batch, mol2latent=None,
                molecule_type="Graph", molecule_model=model)
            pred = linear_model(molecule_repr)
            pred = pred.float()
            y = batch.y.view(pred.shape).to(device).float()

        is_valid = y ** 2 > 0
        loss_mat = criterion(pred, (y + 1) / 2)
        loss_mat = torch.where(
            is_valid, loss_mat,
            torch.zeros(loss_mat.shape).to(device).to(loss_mat.dtype))

        optimizer.zero_grad()
        loss = torch.sum(loss_mat) / torch.sum(is_valid)
        loss.backward()
        optimizer.step()
        total_loss += loss.detach().item()

    return total_loss / len(loader)


@torch.no_grad()
def eval_classification(model, device, loader):
    model.eval()
    linear_model.eval()
    y_true, y_scores = [], []

    if args.verbose:
        L = tqdm(loader)
    else:
        L = loader
    for step, batch in enumerate(L):
        if args.molecule_type == "MegaMolBART":
            SMILES_list, y = batch
            SMILES_list = list(SMILES_list)
            molecule_repr = get_molecule_repr_MoleculeSTM(
                SMILES_list, mol2latent=None,
                molecule_type="MegaMolBART", MegaMolBART_wrapper=MegaMolBART_wrapper)
            pred = linear_model(molecule_repr)
            pred = pred.float()
            y = y.to(device).float()
        else:
            batch = batch.to(device)
            molecule_repr = get_molecule_repr_MoleculeSTM(
                batch, mol2latent=None,
                molecule_type="Graph", molecule_model=model)
            pred = linear_model(molecule_repr)
            pred = pred.float()
            y = batch.y.view(pred.shape).to(device).float()

        y_true.append(y)
        y_scores.append(pred)

    y_true = torch.cat(y_true, dim=0).cpu().numpy()
    y_scores = torch.cat(y_scores, dim=0).cpu().numpy()

    roc_list = []
    for i in range(y_true.shape[1]):
        # AUC is only defined when there is at least one positive and one negative sample.
        if np.sum(y_true[:, i] == 1) > 0 and np.sum(y_true[:, i] == -1) > 0:
            is_valid = y_true[:, i] ** 2 > 0
            roc_list.append(roc_auc_score((y_true[is_valid, i] + 1) / 2, y_scores[is_valid, i]))
        else:
            print("{} is invalid".format(i))

    if len(roc_list) < y_true.shape[1]:
        print(len(roc_list))
        print("Some target is missing!")
        print("Missing ratio: %f" %(1 - float(len(roc_list)) / y_true.shape[1]))

    return sum(roc_list) / len(roc_list), 0, y_true, y_scores

# setup for optimizer
if args.training_mode == "fine_tuning":
    model_param_group = [
        {"params": model.parameters()},
        {"params": linear_model.parameters(), 'lr': args.lr * args.lr_scale}
    ]
else:
    model_param_group = [
        {"params": linear_model.parameters(), 'lr': args.lr * args.lr_scale}
    ]
optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.weight_decay)
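
The loss computation in `train_classification` reads as follows: labels are stored as -1 or +1, a label of 0 marks a missing value, `(y + 1) / 2` maps the labels to 0/1 targets, and missing entries are masked out before averaging. The plain-Python sketch below illustrates that reading on a few numbers. It is an illustration only, and `masked_bce_with_logits` is a hypothetical name.

```python
import math


def masked_bce_with_logits(logits, labels):
    """Mean binary cross-entropy over entries labeled -1 or +1 (0 means missing)."""
    total, n_valid = 0.0, 0
    for z, y in zip(logits, labels):
        if y == 0:  # missing label: contributes nothing, like the torch.where mask
            continue
        target = (y + 1) / 2  # -1 -> 0.0, +1 -> 1.0
        # Numerically stable form: max(z, 0) - z * target + log(1 + exp(-|z|))
        total += max(z, 0.0) - z * target + math.log1p(math.exp(-abs(z)))
        n_valid += 1
    # Unlike the loop above, return 0.0 instead of dividing by zero when nothing is valid.
    return total / n_valid if n_valid else 0.0


print(masked_bce_with_logits([2.0, -1.0, 0.5], [1, -1, 0]))
```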

### Perform Training

In [None]:
# for property prediction Graph

train_func = train_classification
eval_func = eval_classification

train_roc_list, val_roc_list, test_roc_list = [], [], []
train_acc_list, val_acc_list, test_acc_list = [], [], []
best_val_roc, best_val_idx = -1, 0
criterion = nn.BCEWithLogitsLoss(reduction="none")

for epoch in range(1, args.epochs + 1):
    loss_acc = train_func(model, device, train_loader, optimizer)
    print("Epoch: {}\nLoss: {}".format(epoch, loss_acc))

    if args.eval_train:
        train_roc, train_acc, train_target, train_pred = eval_func(model, device, train_loader)
    else:
        train_roc = train_acc = 0
    val_roc, val_acc, val_target, val_pred = eval_func(model, device, val_loader)
    test_roc, test_acc, test_target, test_pred = eval_func(model, device, test_loader)

    train_roc_list.append(train_roc)
    train_acc_list.append(train_acc)
    val_roc_list.append(val_roc)
    val_acc_list.append(val_acc)
    test_roc_list.append(test_roc)
    test_acc_list.append(test_acc)
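    print("train: {:.6f}\tval: {:.6f}\ttest: {:.6f}".format(train_roc, val_roc, test_roc))
    print()

    # Track the epoch with the best validation ROC-AUC so the final summary reports it.
    if val_roc > best_val_roc:
        best_val_roc = val_roc
        best_val_idx = epoch - 1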

print("best train: {:.6f}\tval: {:.6f}\ttest: {:.6f}".format(train_roc_list[best_val_idx], val_roc_list[best_val_idx], test_roc_list[best_val_idx]))

##   Evaluation

### Metric descriptions

This evaluation uses several different metrics:


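*   Contrastive loss - cross-entropy over the similarity scores between a query and its candidates (one true match plus sampled negatives); lower values mean the true match is scored higher than the negatives.
*   Accuracy - the proportion of queries for which the true match receives the highest similarity score among the candidates.
*   Confidence scores - the highest similarity score for each query, used as a measure of confidence in the prediction.


Additionally, accuracy is computed for several candidate-set sizes T (`args.T_list`), using T-1 negative samples per query, to measure how robust retrieval is as the task gets harder.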
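
The accuracy metric can be illustrated with a small example. In the sketch below each row holds the similarity scores for T candidates, with the true match first and T-1 negatives after it, which mirrors the all-zero labels used in `do_CL_eval`. The function name `t_choose_one_accuracy` and the scores are made up for illustration.

```python
def t_choose_one_accuracy(score_rows):
    """Fraction of rows whose highest score is at index 0 (the true match)."""
    hits = 0
    for row in score_rows:
        best = max(range(len(row)), key=row.__getitem__)
        if best == 0:
            hits += 1
    return hits / len(score_rows)


scores = [
    [0.9, 0.1, 0.3, 0.2],  # true match ranked first: hit
    [0.2, 0.8, 0.1, 0.4],  # a negative outranks the true match: miss
]
print(t_choose_one_accuracy(scores))  # 0.5
```

With T = 4 as here, random guessing would be right about 25% of the time, so accuracy should always be read against 1/T.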




### Implementation code

#### Support Functions and Setup

In [None]:
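# Adapted from the demo for downstream retrieval (Graph).
# Note: `args`, `device`, the models, `text_tokenizer`, `prepare_text_tokens`, and
# `get_molecule_repr_MoleculeSTM` must be defined before these functions are called.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
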
def cycle_index(num, shift):
    arr = torch.arange(num) + shift
    arr[-shift:] = torch.arange(shift)
    return arr


def do_CL_eval(X, Y, neg_Y, args):
    X = F.normalize(X, dim=-1)
    X = X.unsqueeze(1) # B, 1, d

    Y = Y.unsqueeze(0)
    Y = torch.cat([Y, neg_Y], dim=0) # T, B, d
    Y = Y.transpose(0, 1)  # B, T, d
    Y = F.normalize(Y, dim=-1)

    logits = torch.bmm(X, Y.transpose(1, 2)).squeeze(1)  # B, T (squeeze only dim 1 so B == 1 still works)
    B = X.size()[0]
    labels = torch.zeros(B).long().to(logits.device)  # B; the true match is always at index 0

    criterion = nn.CrossEntropyLoss()

    CL_loss = criterion(logits, labels)
    pred = logits.argmax(dim=1, keepdim=False)
    confidence = logits
    CL_conf = confidence.max(dim=1)[0]
    CL_conf = CL_conf.cpu().numpy()

    CL_acc = pred.eq(labels).sum().detach().cpu().item() * 1. / B
    return CL_loss, CL_conf, CL_acc


def get_text_repr(text):
    text_tokens_ids, text_masks = prepare_text_tokens(
        device=device, description=text, tokenizer=text_tokenizer, max_seq_len=args.max_seq_len)
    text_output = text_model(input_ids=text_tokens_ids, attention_mask=text_masks)
    text_repr = text_output["pooler_output"]
    text_repr = text2latent(text_repr)
    return text_repr


@torch.no_grad()
def eval_epoch(dataloader):
    text_model.eval()
    molecule_model.eval()
    text2latent.eval()
    mol2latent.eval()

    accum_acc_list = [0 for _ in args.T_list]
    if args.verbose:
        L = tqdm(dataloader)
    else:
        L = dataloader
    for batch in L:
        text = batch[0]
        molecule_data = batch[1]
        neg_text = batch[2]
        neg_molecule_data = batch[3]

        text_repr = get_text_repr(text)

        molecule_data = molecule_data.to(device)
        molecule_repr = get_molecule_repr_MoleculeSTM(
            molecule_data, mol2latent=mol2latent,
            molecule_type="Graph", molecule_model=molecule_model)

        if test_mode == "given_text":
            neg_molecule_repr = [
                get_molecule_repr_MoleculeSTM(
                    neg_molecule_data[idx].to(device), mol2latent=mol2latent,
                    molecule_type="Graph", molecule_model=molecule_model) for idx in range(T_max)
            ]
            neg_molecule_repr = torch.stack(neg_molecule_repr)
            for T_idx, T in enumerate(args.T_list):
                _, _, acc = do_CL_eval(text_repr, molecule_repr, neg_molecule_repr[:T-1], args)
                accum_acc_list[T_idx] += acc
        elif test_mode == "given_molecule":
            neg_text_repr = [get_text_repr(neg_text[idx]) for idx in range(T_max)]
            neg_text_repr = torch.stack(neg_text_repr)
            for T_idx, T in enumerate(args.T_list):
                _, _, acc = do_CL_eval(molecule_repr, text_repr, neg_text_repr[:T-1], args)
                accum_acc_list[T_idx] += acc
        else:
            raise ValueError("Unknown test_mode: {}".format(test_mode))

    accum_acc_list = np.array(accum_acc_list)
    accum_acc_list /= len(dataloader)
    return accum_acc_list

#### Eval

In [None]:
# from demo downstream retrieval Graph
import os

text_model = text_model.to(device)
molecule_model = molecule_model.to(device)
text2latent = text2latent.to(device)
mol2latent = mol2latent.to(device)

T_max = max(args.T_list) - 1

initial_test_acc_list = []
test_mode = args.test_mode
dataset_folder = os.path.join(args.dataspace_path, "DrugBank_data")


dataset_class = DrugBank_Datasets_Graph_retrieval
dataloader_class = pyg_DataLoader
processed_dir_prefix = args.task

if args.task == "molecule_description":
    template = "SMILES_description_{}.txt"
elif args.task == "molecule_description_removed_PubChem":
    template = "SMILES_description_removed_from_PubChem_{}.txt"
elif args.task == "molecule_description_Raw":
    template = "SMILES_description_{}_Raw.txt"
elif args.task == "molecule_description_removed_PubChem_Raw":
    template = "SMILES_description_removed_from_PubChem_{}_Raw.txt"
elif args.task == "molecule_pharmacodynamics":
    template = "SMILES_pharmacodynamics_{}.txt"
elif args.task == "molecule_pharmacodynamics_removed_PubChem":
    template = "SMILES_pharmacodynamics_removed_from_PubChem_{}.txt"
elif args.task == "molecule_pharmacodynamics_Raw":
    template = "SMILES_pharmacodynamics_{}_Raw.txt"
elif args.task == "molecule_pharmacodynamics_removed_PubChem_Raw":
    template = "SMILES_pharmacodynamics_removed_from_PubChem_{}_Raw.txt"

full_dataset = dataset_class(dataset_folder, 'full', neg_sample_size=T_max, processed_dir_prefix=processed_dir_prefix, template=template)
full_dataloader = dataloader_class(full_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) # The program may block with a non-zero num_workers

initial_test_acc_list = eval_epoch(full_dataloader)
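
The task-to-template chain above leaves `template` undefined when `args.task` matches none of the branches. A dictionary lookup is one possible alternative that fails with a clear message instead. The sketch below copies the eight task names and templates from the cell above; `TASK_TEMPLATES` and `template_for` are hypothetical names.

```python
# Task names and file-name templates copied from the if/elif chain above.
TASK_TEMPLATES = {
    "molecule_description": "SMILES_description_{}.txt",
    "molecule_description_removed_PubChem": "SMILES_description_removed_from_PubChem_{}.txt",
    "molecule_description_Raw": "SMILES_description_{}_Raw.txt",
    "molecule_description_removed_PubChem_Raw": "SMILES_description_removed_from_PubChem_{}_Raw.txt",
    "molecule_pharmacodynamics": "SMILES_pharmacodynamics_{}.txt",
    "molecule_pharmacodynamics_removed_PubChem": "SMILES_pharmacodynamics_removed_from_PubChem_{}.txt",
    "molecule_pharmacodynamics_Raw": "SMILES_pharmacodynamics_{}_Raw.txt",
    "molecule_pharmacodynamics_removed_PubChem_Raw": "SMILES_pharmacodynamics_removed_from_PubChem_{}_Raw.txt",
}


def template_for(task):
    """Return the file-name template for a task, or raise a clear error if unknown."""
    try:
        return TASK_TEMPLATES[task]
    except KeyError:
        raise ValueError(f"Unknown task: {task!r}") from None


print(template_for("molecule_description_removed_PubChem"))
```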

### Results, Analyses and Plans

### Plans
*   Resolve setup issues when running the code.
<br>Notes: (1) Because several packages in the environment setup are outdated or have conflicting versions, the team has not yet resolved all setup issues despite many attempts on different machines. As a result, the code cannot be run end to end for now, and model training is not yet possible. (2) The team is working to resolve the environment issues and fix bugs in the relevant code so that the code becomes executable as soon as possible.

*   Complete model training.
*   Evaluate the training results and summarize them.
*   Document the learning journey in the discussion section.
*   Create a video presentation of our work.





<!-- # References

1. Liu, S., Nie, W., Wang, C. et al. Multi-modal molecule structure–text model for
text-based retrieval and editing. Nat Mach Intell 5, 1447–1457 (2023).
https://doi.org/10.1038/s42256-023-00759-6 -->
