
# Introduction
**Github: https://github.com/xiaorandu/dl4h_project**
<br>**Background of the problem**
  * **type of problem:** Artificial intelligence (AI) is increasingly used to aid drug discovery; however, most of these approaches focus on chemical structures and largely ignore the plethora of information found in textual descriptions. This limitation hinders the use of text for drug design, molecule editing, and the prediction of complex biological activities.

  * **importance/meaning of solving the problem:** Solving this problem will help enable faster iterations of drug discovery, such as re-purposing and multi-objective lead optimization.

  * **the difficulty of the problem:** The problem has several difficult aspects, including zero-shot learning (which is especially hard in the context of biochemistry) and natural language understanding. Data insufficiency is another limitation: PubChemSTM consists of 250,000 molecules and 281,000 structure-text pairs, compared with roughly 400 million pairs used in the vision-language domain. The expressiveness of chemical structure models is a further bottleneck of this work.

  * **the state of the art methods and effectiveness:** The paper designs a multi-modal model, MoleculeSTM, that incorporates both molecular structural information and textual knowledge. It consists of two branches: the chemical structure branch (to handle a molecule's internal structure) and the textual description branch (to handle external domain knowledge). This design allows the model to be integrated with existing models trained on each modality separately, i.e., molecular structural models and scientific language models. A large multi-modal structure-text dataset was created to align the two branches of MoleculeSTM. Two challenging downstream tasks were designed, the structure-text retrieval task and the text-based molecule editing task, and the pretrained MoleculeSTM was applied to them in a zero-shot manner. Studying these tasks highlighted two main attributes of MoleculeSTM: open vocabulary and compositionality. Open vocabulary means the model can explore a wide range of biochemical concepts with an unbounded vocabulary. Compositionality means a complex concept can be expressed by decomposing it into several simpler concepts. The results show that MoleculeSTM reaches the best performance on six zero-shot retrieval tasks (up to 50% higher accuracy) and twenty zero-shot text-based editing tasks (up to 40% higher hit ratio) compared with state-of-the-art methods. Additionally, MoleculeSTM was able to detect critical structures implied by the text descriptions in molecule editing tasks.

<br>**Paper explanation**
  * **what did the paper propose:** The paper introduced MoleculeSTM, a model that integrates chemical structures of molecules with their textual descriptions using a contrastive learning approach. The model aims to perform tasks such as structure-text retrieval and molecule editing based on text instructions in a zero-shot setting. It utilizes a vast, multi-modal dataset, PubChemSTM, containing over 280,000 chemical structure-text pairs.

  * **innovations of the method:** MoleculeSTM uniquely combines chemical structural data with textual information, enhancing the model's understanding and generalization capabilities across various biochemical contexts. It demonstrates significant effectiveness in zero-shot scenarios, where the model performs tasks without having been explicitly trained on them. The model supports open-ended vocabularies and can decompose complex instructions into simpler concepts, making it versatile in handling diverse and novel scientific queries.

  * **how well the proposed method works (in its own metrics):** MoleculeSTM significantly outperformed existing methods in zero-shot retrieval and text-based molecule editing tasks. It achieved up to 50% higher accuracy in retrieval tasks and up to 40% higher hit ratios in editing tasks compared to state-of-the-art methods. This indicates a robust ability to generalize and effectively apply learned knowledge to new and unseen data.

  * **what is the contribution to the research regime (referring the Background above, how important the paper is to the problem):** The research incorporates textual descriptions alongside chemical structures for molecule representation learning. The multi-modal model, MoleculeSTM, consistently showed improved performance compared with existing models. MoleculeSTM might accelerate various downstream drug discovery practices, since the model was observed to successfully modify molecule substructures to obtain desired properties and to retrieve novel drug-target relations. The paper is important because it demonstrates the effectiveness of incorporating textual descriptions in addition to chemical structures for molecule representation learning. It has two limitations: data insufficiency and the limited expressiveness of the chemical structure models.


# Scope of Reproducibility:

We test the following hypotheses from the paper with the experiments listed below.

Hypotheses:
1. MoleculeSTM achieves state-of-the-art performance in zero-shot structure-text retrieval and molecule editing tasks compared to existing methods.
2. Incorporating textual descriptions through contrastive learning significantly improves the model's performance on zero-shot retrieval and molecule editing tasks.

Experiments:
We retrain the model on the constructed PubChemSTM dataset, which consists of over 280,000 chemical structure-text pairs, and evaluate it on the specified zero-shot tasks: structure-text retrieval and molecule editing.
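
As a rough illustration of what the zero-shot retrieval evaluation measures, the sketch below picks, for each query embedding, the candidate with the highest cosine similarity and reports accuracy. The function names and the toy embeddings are hypothetical stand-ins for the outputs of the pretrained encoders; this is not the paper's evaluation code.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def retrieve(query, candidates):
    """Return the index of the candidate most similar to the query."""
    scores = [cosine(query, c) for c in candidates]
    return max(range(len(candidates)), key=scores.__getitem__)


def retrieval_accuracy(queries, candidate_sets, answers):
    """Fraction of queries whose top-scoring candidate is the labeled answer."""
    hits = sum(
        retrieve(q, cands) == ans
        for q, cands, ans in zip(queries, candidate_sets, answers)
    )
    return hits / len(queries)


# Hypothetical toy embeddings: two queries, three candidates each.
queries = [[1.0, 0.0], [0.0, 1.0]]
candidate_sets = [
    [[0.9, 0.1], [0.0, 1.0], [-1.0, 0.0]],
    [[1.0, 0.0], [0.2, 0.8], [0.5, 0.5]],
]
answers = [0, 1]
print(retrieval_accuracy(queries, candidate_sets, answers))
```

Because no task-specific training is involved, the quality of the joint embedding space alone determines the accuracy, which is what makes the setting zero-shot.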






# Methodology

This section is the core of the project. It contains runnable code with the necessary annotations to show the experiments we executed to test the hypotheses.

It covers the **environment**, the **data**, and the **model** used in our experiments.

##  Environment

### Python version


In [12]:
!python --version

Python 3.11.3


In [37]:
from google.colab import drive
drive.mount('/content/drive')

ModuleNotFoundError: No module named 'google'

Python 3.10.12

### Install Dependencies

In [30]:
!pip install rdkit
!pip install torch torchvision
!pip install requests tqdm matplotlib spacy Levenshtein boto3 deepspeed
!pip install ogb==1.2.0
!pip install transformers==4.30.2

!pip install torch_geometric
!pip install torch-scatter -f https://data.pyg.org/whl/torch-2.2.1+cu121.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-2.2.1+cu121.html
!pip install torch-cluster -f https://data.pyg.org/whl/torch-2.2.1+cu121.html
!pip install torch-spline-conv -f https://data.pyg.org/whl/torch-2.2.1+cu121.html
!pip install git+https://github.com/MolecularAI/pysmilesutils.git

Collecting matplotlib
  Using cached matplotlib-3.8.4-cp311-cp311-win_amd64.whl.metadata (5.9 kB)
Collecting spacy
  Using cached spacy-3.7.4-cp311-cp311-win_amd64.whl.metadata (27 kB)
Collecting Levenshtein
  Using cached Levenshtein-0.25.1-cp311-cp311-win_amd64.whl.metadata (3.4 kB)
Collecting boto3
  Using cached boto3-1.34.98-py3-none-any.whl.metadata (6.6 kB)
Collecting deepspeed
  Using cached deepspeed-0.14.2.tar.gz (1.3 MB)
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Getting requirements to build wheel: started
  Getting requirements to build wheel: finished with status 'error'


  error: subprocess-exited-with-error
  
  × Getting requirements to build wheel did not run successfully.
  │ exit code: 1
  ╰─> [23 lines of output]
      DS_BUILD_OPS=1
      Traceback (most recent call last):
        File "c:\Users\Benjamin\Desktop\Illinois\598_DLH\dl4h_project\.venv\Lib\site-packages\pip\_vendor\pyproject_hooks\_in_process\_in_process.py", line 353, in <module>
          main()
        File "c:\Users\Benjamin\Desktop\Illinois\598_DLH\dl4h_project\.venv\Lib\site-packages\pip\_vendor\pyproject_hooks\_in_process\_in_process.py", line 335, in main
          json_out['return_val'] = hook(**hook_input['kwargs'])
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "c:\Users\Benjamin\Desktop\Illinois\598_DLH\dl4h_project\.venv\Lib\site-packages\pip\_vendor\pyproject_hooks\_in_process\_in_process.py", line 118, in get_requires_for_build_wheel
          return hook(config_settings)
                 ^^^^^^^^^^^^^^^^^^^^^
        File "C:\Users\Benjam

Looking in links: https://data.pyg.org/whl/torch-2.2.1+cu121.html
Looking in links: https://data.pyg.org/whl/torch-2.2.1+cu121.html
Looking in links: https://data.pyg.org/whl/torch-2.2.1+cu121.html
Looking in links: https://data.pyg.org/whl/torch-2.2.1+cu121.html
Collecting git+https://github.com/MolecularAI/pysmilesutils.git
  Cloning https://github.com/MolecularAI/pysmilesutils.git to c:\users\benjamin\appdata\local\temp\pip-req-build-69mbdatx
  Resolved https://github.com/MolecularAI/pysmilesutils.git to commit b1e7ced15a42e18e629b984816020b1f0b0a2aa3
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Getting requirements to build wheel: started
  Getting requirements to build wheel: finished with status 'done'
  Installing backend dependencies: started
  Installing backend dependencies: finished with status 'done'
  Preparing metadata (pyproject.toml): started
  Preparing metadata (pyproject.toml): finished with status 'done'


  Running command git clone --filter=blob:none --quiet https://github.com/MolecularAI/pysmilesutils.git 'C:\Users\Benjamin\AppData\Local\Temp\pip-req-build-69mbdatx'


In [None]:
%cd ..

/content/drive/MyDrive/MoleculeSTM


In [31]:
!pip install packaging
!pip install "jedi>=0.16"
!pip install "cxxfilt>=0.2.0"
!pip install "PyYAML>=5.1"



In [None]:
%pwd

'/content/drive/MyDrive/MoleculeSTM'

Verify package installation

In [32]:
import pkg_resources

# List of packages from apex requirements.txt
required_packages = {
    "cxxfilt": "0.2.0",
    "tqdm": "4.28.1",
    "numpy": "1.15.3",
    "PyYAML": "5.1",
    "pytest": "3.5.1",
    "packaging": "14.0"
}

def check_packages(packages):
    """Print whether each package is installed at or above its minimum version."""
    for package, version in packages.items():
        try:
            current_version = pkg_resources.get_distribution(package).version
        except pkg_resources.DistributionNotFound:
            # get_distribution raises for missing packages instead of returning None
            print(f"{package} >= {version} is NOT installed.")
            continue
        if pkg_resources.parse_version(current_version) >= pkg_resources.parse_version(version):
            print(f"{package} >= {version} is installed (Current version: {current_version})")
        else:
            print(f"{package} >= {version} is NOT satisfied (Current version: {current_version})")

check_packages(required_packages)

cxxfilt >= 0.2.0 is installed (Current version: 0.3.0)
tqdm >= 4.28.1 is installed (Current version: 4.66.4)
numpy >= 1.15.3 is installed (Current version: 1.26.4)
PyYAML >= 5.1 is installed (Current version: 6.0.1)
pytest >= 3.5.1 is installed (Current version: 3.5.1)
packaging >= 14.0 is installed (Current version: 24.0)


In [None]:
# install apex
!git clone https://github.com/chao1224/apex.git
%cd apex
!pip install -v --disable-pip-version-check --no-cache-dir ./
%cd ..

In [None]:
# install Megatron
# TODO: delete me, lets focus on graph
!git clone https://github.com/MolecularAI/MolBART.git --branch megatron-molbart-with-zinc
%cd MolBART/megatron_molbart/Megatron-LM-v1.1.5-3D_parallelism
!pip install .
%cd ../../..

##  Data


### Data download instruction
The following Python script downloads the pretraining dataset and the downstream datasets from the Hugging Face Hub.

In [None]:
from huggingface_hub import HfApi, snapshot_download
api = HfApi()
snapshot_download(repo_id="chao1224/MoleculeSTM", repo_type="dataset", local_dir='data')

The data folder can be found at [google drive link](https://drive.google.com/drive/u/0/folders/1pCr0WrY-3lbxxy44D68u2cDzmDONVLBD).
<br/>The data folder will include:
```
data
├── PubChemSTM_data/
│   ├── raw/
│   │   ├── CID2SMILES.csv
│   │   ├── CID2name.json
│   │   ├── CID2name_raw.json
│   │   └── molecules.sdf
│   └── processed/
├── pretrained_SciBERT/
├── pretrained_MegaMolBART/
├── pretrained_KV-PLM/
├── pretrained_GraphMVP/
├── pretrained_MoleculeSTM_Raw/
├── pretrained_MoleculeSTM/
├── DrugBank_data/
├── ZINC250K_data/
├── Editing_data/
│   ├── single_multi_property_SMILES.txt
│   ├── neighbor2drug/
│   └── ChEMBL_data/
└── MoleculeNet_data/
```

### Data descriptions with helpful charts and visualizations

In [None]:
%cd MoleculeSTM/data/PubChemSTM_data/raw

SMILES (Simplified Molecular Input Line Entry System) encodes a molecule's structure, i.e., its atoms and bonds, as a single line of text.
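
To make the string representation concrete, the following sketch splits a SMILES string into atom, bond, ring-closure, and branch tokens with a regular expression. The pattern is a simplified assumption that only covers common organic-subset symbols; it is not the tokenizer used by MoleculeSTM, and `tokenize_smiles` is a hypothetical helper name.

```python
import re

# Simplified token pattern (an assumption, not an official SMILES grammar):
# bracket atoms, two-letter halogens, single-letter atoms, aromatic atoms,
# bond/branch symbols, two-digit ring closures, single-digit ring closures.
SMILES_TOKEN = re.compile(
    r"\[[^\]]+\]|Br|Cl|[BCNOSPFI]|[bcnosp]|[()=#\-+\\/:~@*.]|%\d{2}|\d"
)


def tokenize_smiles(smiles):
    """Split a SMILES string into tokens; fail loudly on unsupported symbols."""
    tokens = SMILES_TOKEN.findall(smiles)
    if "".join(tokens) != smiles:
        raise ValueError(f"unsupported characters in SMILES: {smiles!r}")
    return tokens


# Two strings taken from the CID2SMILES.csv preview in this notebook.
print(tokenize_smiles("CC(O)CN"))
print(tokenize_smiles("O=[N+]([O-])c1ccc(Cl)c([N+](=O)[O-])c1"))
```

Parentheses mark branches, digits mark ring closures, and bracketed tokens such as `[N+]` carry charge information, so the connectivity of the molecule can be rebuilt from the token sequence.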

In [None]:
import pandas as pd
CID2SMILES = 'CID2SMILES.csv'
df_CID2SMILES = pd.read_csv(CID2SMILES, usecols=['CID', 'SMILES'])
df_CID2SMILES.head()

```
index	CID	SMILES
0	     1	CC(=O)OC(CC(=O)[O-])C[N+](C)(C)C
1	     3	O=C(O)C1=CC=CC(O)C1O
2	     4	CC(O)CN
3	     5	NCC(=O)COP(=O)(O)O
4	     6	O=[N+]([O-])c1ccc(Cl)c([N+](=O)[O-])c1
```

In [None]:
import json
CID2name_raw = "CID2name_raw.json"
with open(CID2name_raw, 'r') as file:
  data = json.load(file)

df_CID2name_raw = pd.DataFrame(list(data.items()), columns=['CID', 'Names'])
df_CID2name_raw.head()

```
index	CID	Names
0	    180	Acetone,ACETONE,Acetone
1	    222	Ammonia,AMMONIA SOLUTIONS (CONTAINING MORE THAN 35% BUT NOT MORE THAN 50% AMMONIA),AMMONIA, ANHYDROUS,AMMONIA, SOLUTION, WITH MORE THAN 10% BUT NOT MORE THAN 35% AMMONIA,Ammonia
2	   5359596	Arsenic,ARSENIC,Arsenic,Arsenic atom,Arsenic Compounds
3	    241	Benzene,BENZENE,Benzene,Benzene
4	   23973	Cadmium,CADMIUM,Cadmium atom,Cadmium,Cadmium Compounds
```

In [None]:
import json
CID2name = "CID2name.json"
with open(CID2name, 'r') as file:
  data = json.load(file)

df_CID2name = pd.DataFrame(list(data.items()), columns=['CID', 'Names'])
df_CID2name.head()

```
index	CID	Names
0	    180	Acetone,Acetone,Acetone,Acetone,Acetone
1	    222	Ammonia,Ammonia solutions (containing more than 35% but not more than 50% ammonia),Ammonia, anhydrous,Ammonia, solution, with more than 10% but not more than 35% ammonia,Ammonia,Ammonia,Ammonia,Ammonia
2	   5359596	Arsenic,Arsenic,Arsenic,Arsenic atom,Arsenic, a naturally occurring element,,Arsenic
3	    241	Benzene,Benzene,Benzene,Benzene,Benzene,Benzene
4	   23973	Cadmium,Cadmium,Cadmium atom,The main sources of cadmium in the air,Cadmium,Cadmium,Cadmium
```
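
Basic dataset statistics, such as the number of molecules and the number of structure-text pairs, can be computed directly from this mapping. The sketch below assumes that `CID2name.json` maps each CID to a list of text descriptions, as the preview above suggests; the toy dictionary, its placeholder strings, and the function name `pair_statistics` are hypothetical and only illustrate the counting.

```python
from statistics import mean


def pair_statistics(cid_to_texts):
    """Count molecules and structure-text pairs in a CID -> list-of-texts mapping."""
    counts = [len(texts) for texts in cid_to_texts.values()]
    return {
        "num_molecules": len(counts),
        "num_pairs": sum(counts),
        "mean_texts_per_molecule": mean(counts) if counts else 0.0,
    }


# Hypothetical toy mapping with placeholder descriptions (not real dataset content).
toy = {
    "180": ["description A", "description B"],
    "241": ["description C"],
}
stats = pair_statistics(toy)
print(stats)
```

On the real file, the same function could be applied to the dictionary returned by `json.load`.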

In [None]:
from rdkit.Chem import PandasTools

sdf_file = "molecules.sdf"
df_molecules = PandasTools.LoadSDF(sdf_file)
print(df_molecules.head())

```
# the output would be like this.
  PUBCHEM_COMPOUND_CID PUBCHEM_COMPOUND_CANONICALIZED  \
0             29500027                              1   
1             29500038                              1   
2             29500039                              1   
3             29500070                              1   
4             29500073                              1   

  PUBCHEM_CACTVS_COMPLEXITY PUBCHEM_CACTVS_HBOND_ACCEPTOR  \
0                       460                             7   
1                       480                             6   
2                       480                             6   
3                       543                             6   
4                       451                             6   

  PUBCHEM_CACTVS_HBOND_DONOR PUBCHEM_CACTVS_ROTATABLE_BOND  \
0                          1                             4   
1                          2                             6   
2                          2                             6   
3                          1                             4   
4                          2                             3   

                             PUBCHEM_CACTVS_SUBSKEYS  \
0  AAADccB7oABAAAAAAAAAAAAAAAAAAWLAAAA8QAAAAAAAAA...   
1  AAADceB7sABgAAAAAAAAAAAAAAAAAWJAAAAsAAAAAAAAAA...   
2  AAADceB7sABgAAAAAAAAAAAAAAAAAWJAAAAsAAAAAAAAAA...   
3  AAADceB7oABAAAAAAAAAAAAAAAAAAWLAAAAwYAAAAAAAAA...   
4  AAADceB7oABAAAAAAAAAAAAAAAAAAWAAAAA8QAAAAAAAAA...   

                          PUBCHEM_IUPAC_OPENEYE_NAME  \
0  N-[4-(2-pyridyl)thiazol-2-yl]-4-(tetrazol-1-yl...   
1  (3S)-3-acetamido-N-[4-(2-pyridyl)thiazol-2-yl]...   
2  (3R)-3-acetamido-N-[4-(2-pyridyl)thiazol-2-yl]...   
3  4-(tetrazol-1-yl)-N-[4-(2,4,5-trimethylphenyl)...   
4  3-amino-N-[4-(2,4,5-trimethylphenyl)thiazol-2-...   

                              PUBCHEM_IUPAC_CAS_NAME  \
0  N-[4-(2-pyridinyl)-2-thiazolyl]-4-(1-tetrazoly...   
1  (3S)-3-acetamido-N-[4-(2-pyridinyl)-2-thiazoly...   
2  (3R)-3-acetamido-N-[4-(2-pyridinyl)-2-thiazoly...   
3  4-(1-tetrazolyl)-N-[4-(2,4,5-trimethylphenyl)-...   
4  3-amino-N-[4-(2,4,5-trimethylphenyl)-2-thiazol...   

                           PUBCHEM_IUPAC_NAME_MARKUP  ...  \
0  <I>N</I>-(4-pyridin-2-yl-1,3-thiazol-2-yl)-4-(...  ...   
1  (3<I>S</I>)-3-acetamido-<I>N</I>-(4-pyridin-2-...  ...   
2  (3<I>R</I>)-3-acetamido-<I>N</I>-(4-pyridin-2-...  ...   
3  4-(tetrazol-1-yl)-<I>N</I>-[4-(2,4,5-trimethyl...  ...   
4  3-amino-<I>N</I>-[4-(2,4,5-trimethylphenyl)-1,...  ...   

  PUBCHEM_BOND_UDEF_STEREO_COUNT PUBCHEM_ISOTOPIC_ATOM_COUNT  \
0                              0                           0   
1                              0                           0   
2                              0                           0   
3                              0                           0   
4                              0                           0   

  PUBCHEM_COMPONENT_COUNT PUBCHEM_CACTVS_TAUTO_COUNT PUBCHEM_COORDINATE_TYPE  \
0                       1                         -1               1\n5\n255   
1                       1                         -1               1\n5\n255   
2                       1                         -1               1\n5\n255   
3                       1                         -1               1\n5\n255   
4                       1                         -1               1\n5\n255   

                             PUBCHEM_BONDANNOTATIONS        ID  \
0  1  18  8\n1  20  8\n10  12  8\n10  13  8\n11  ...  29500027   
1  1  11  8\n1  16  8\n11  13  8\n13  15  8\n15  ...  29500038   
2  1  11  8\n1  16  8\n11  13  8\n13  15  8\n15  ...  29500039   
3  1  19  8\n1  20  8\n10  14  8\n11  12  8\n11  ...  29500070   
4  1  18  8\n1  19  8\n10  11  8\n10  12  8\n11  ...  29500073   

                                              ROMol PUBCHEM_XLOGP3  \
0  <rdkit.Chem.rdchem.Mol object at 0x7d30b3ca4cf0>            NaN   
1  <rdkit.Chem.rdchem.Mol object at 0x7d30b3ca4c80>            NaN   
2  <rdkit.Chem.rdchem.Mol object at 0x7d30b3ca5070>            NaN   
3  <rdkit.Chem.rdchem.Mol object at 0x7d30b3ca51c0>            NaN   
4  <rdkit.Chem.rdchem.Mol object at 0x7d30b3ca5310>            NaN   

  PUBCHEM_REFERENCE_STANDARDIZATION  
0                               NaN  
1                               NaN  
2                               NaN  
3                               NaN  
4                               NaN  
```

### Preprocessing code + command

The preprocessing code is located in `preprocessing/PubchemSTM`.

##   Model

### **Citation to the original paper**
```
@article{liu2023moleculestm,
    title={Multi-modal molecule structure--text model for text-based retrieval and editing},
    author={Liu, Shengchao and Nie, Weili and Wang, Chengpeng and Lu, Jiarui and Qiao, Zhuoran and Liu, Ling and Tang, Jian and Xiao, Chaowei and Anandkumar, Anima},
    journal={Nature Machine Intelligence},
    year={2023},
    month={Dec},
    day={01},
    volume={5},
    number={12},
    pages={1447-1457},
    issn={2522-5839},
    doi={10.1038/s42256-023-00759-6},
    url={https://doi.org/10.1038/s42256-023-00759-6}
}
```

### **Link to the original paper’s repo**
https://github.com/chao1224/MoleculeSTM

### **Model descriptions**
MoleculeSTM consists of two branches: the **chemical structure branch $x_c$** and the **textual description branch $x_t$**.

*   The chemical structure branch describes the arrangement of atoms in a molecule. We specifically focus on the **two-dimensional molecular graph**, which takes the atoms and bonds as the nodes and edges, respectively.
*   The textual description branch provides a high-level description of the molecule’s functionality.


The paper applies several models within the multi-modal molecule structure-text framework for text-based retrieval and editing tasks. We use the models below in our project.

*   **Molecule graph encoder $f_c$**: apply a graph neural network (GNN) encoder to get a latent vector as the molecule representation. Specifically, we take a graph isomorphism network (GIN) pretrained with **GraphMVP**.

  *   GraphMVP performs multi-view pretraining between two-dimensional topologies and three-dimensional geometries on 250,000 conformations from the Geometric Ensemble of Molecules (GEOM) dataset.
    
*   **Text encoder $f_t$**: adapt the pretrained **SciBERT** language model, which was pretrained on textual data from the chemical and biological domains.



### **Implementation code**


#### Molecule graph encoder $f_c$
*   `GINConv` and `GCNConv` classes:
  *   These are the convolutional layers used within the GNN, configured with `emb_dim` and an aggregation method (`aggr`).
*   `GNN` class:
  *   Constructs a sequence of GNN layers (`gnns`) based on `num_layer` and `gnn_type`.
  *   Applies batch normalization (`batch_norms`) after each layer.
  *   Applies dropout as specified by `drop_ratio`.
  *   Combines the per-layer outputs according to the Jumping Knowledge (JK) setting `JK`.
*   `GNN_graphpred` class:
  *   Extends the GNN model to predict properties at the graph level.
  *   Uses `graph_pooling` to reduce node features to graph-level features.
  *   Wraps the node-level model (`molecule_node_model`) that processes the graphs.
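
The aggregation and pooling steps described above can be sketched in plain Python, without PyTorch. This is a simplified illustration under stated assumptions: `gin_aggregate` mirrors the sum inside `GINConv.forward` before the MLP is applied, `pool` mirrors the sum/mean/max graph pooling options, and the toy graph, feature values, and function names are hypothetical.

```python
def relu(vec):
    """Element-wise ReLU on a list of floats."""
    return [max(0.0, a) for a in vec]


def gin_aggregate(x, edges, edge_emb, eps=0.0):
    """Compute (1 + eps) * x_i + sum over edges j->i of relu(x_j + e_ji).

    x: list of node feature vectors; edges: list of (source, target) pairs;
    edge_emb: one embedding vector per edge. The MLP of GINConv is omitted.
    """
    out = [[(1.0 + eps) * a for a in xi] for xi in x]
    for (src, dst), e in zip(edges, edge_emb):
        message = relu([a + b for a, b in zip(x[src], e)])
        out[dst] = [o + m for o, m in zip(out[dst], message)]
    return out


def pool(node_feats, batch, mode="mean"):
    """Reduce node features to one vector per graph; batch[i] is node i's graph id."""
    groups = {}
    for feat, graph_id in zip(node_feats, batch):
        groups.setdefault(graph_id, []).append(feat)
    pooled = []
    for graph_id in sorted(groups):
        columns = list(zip(*groups[graph_id]))
        if mode == "sum":
            pooled.append([sum(c) for c in columns])
        elif mode == "mean":
            pooled.append([sum(c) / len(c) for c in columns])
        elif mode == "max":
            pooled.append([max(c) for c in columns])
        else:
            raise ValueError("Invalid graph pooling type.")
    return pooled


# Hypothetical 3-node path graph 0 - 1 - 2, with each bond stored in both directions.
x = [[1.0, -2.0], [0.5, 0.5], [-1.0, 3.0]]
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
edge_emb = [[0.0, 0.0]] * len(edges)

h = gin_aggregate(x, edges, edge_emb)
print(h)
print(pool(h, batch=[0, 0, 0], mode="sum"))
```

In the actual model, the aggregated vector is passed through the layer's MLP and batch normalization, and this is repeated for `num_layer` layers before pooling.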








In [34]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import (MessagePassing, global_add_pool,
                                global_max_pool, global_mean_pool)
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.utils import add_self_loops, softmax, degree
from torch_scatter import scatter_add
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
from collections import OrderedDict


class GINConv(MessagePassing):
    def __init__(self, emb_dim, aggr="add"):
        '''
            emb_dim (int): node embedding dimensionality
        '''
        super(GINConv, self).__init__(aggr=aggr)

        self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, emb_dim))
        self.eps = torch.nn.Parameter(torch.Tensor([0]))

        self.bond_encoder = BondEncoder(emb_dim = emb_dim)

    def forward(self, x, edge_index, edge_attr):
        edge_embedding = self.bond_encoder(edge_attr)
        out = self.mlp((1 + self.eps) *x + self.propagate(edge_index, x=x, edge_attr=edge_embedding))
        return out

    def message(self, x_j, edge_attr):
        return F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        return aggr_out


class GCNConv(MessagePassing):
    def __init__(self, emb_dim, aggr="add"):
        super(GCNConv, self).__init__(aggr=aggr)

        self.linear = torch.nn.Linear(emb_dim, emb_dim)
        self.root_emb = torch.nn.Embedding(1, emb_dim)
        self.bond_encoder = BondEncoder(emb_dim = emb_dim)

    def forward(self, x, edge_index, edge_attr):
        x = self.linear(x)
        edge_embedding = self.bond_encoder(edge_attr)

        row, col = edge_index

        #edge_weight = torch.ones((edge_index.size(1), ), device=edge_index.device)
        deg = degree(row, x.size(0), dtype = x.dtype) + 1
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        return self.propagate(edge_index, x=x, edge_attr = edge_embedding, norm=norm) + F.relu(x + self.root_emb.weight) * 1./deg.view(-1,1)

    def message(self, x_j, edge_attr, norm):
        return norm.view(-1, 1) * F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        return aggr_out


class GNN(nn.Module):
    def __init__(self, num_layer, emb_dim, JK="last", drop_ratio=0., gnn_type="gin"):

        if num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        super(GNN, self).__init__()
        self.drop_ratio = drop_ratio
        self.num_layer = num_layer
        self.JK = JK

        self.atom_encoder = AtomEncoder(emb_dim)

        ###List of MLPs
        self.gnns = nn.ModuleList()
        for layer in range(num_layer):
            if gnn_type == "gin":
                self.gnns.append(GINConv(emb_dim, aggr="add"))
            elif gnn_type == "gcn":
                self.gnns.append(GCNConv(emb_dim))

        ###List of batchnorms
        self.batch_norms = nn.ModuleList()
        for layer in range(num_layer):
            self.batch_norms.append(nn.BatchNorm1d(emb_dim))

    def forward(self, *argv):
        if len(argv) == 3:
            x, edge_index, edge_attr = argv[0], argv[1], argv[2]
        elif len(argv) == 1:
            data = argv[0]
            x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        else:
            raise ValueError("unmatched number of arguments.")

        x = self.atom_encoder(x)

        h_list = [x]
        for layer in range(self.num_layer):
            h = self.gnns[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            # h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)
            if layer == self.num_layer - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
            h_list.append(h)

        ### Different implementations of Jk-concat
        if self.JK == "concat":
            node_representation = torch.cat(h_list, dim=1)
        elif self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "max":
            h_list = [h.unsqueeze_(0) for h in h_list]
            node_representation = torch.max(torch.cat(h_list, dim=0), dim=0)[0]
        elif self.JK == "sum":
            h_list = [h.unsqueeze_(0) for h in h_list]
            # torch.sum returns a tensor (not a (values, indices) pair), so no [0] indexing
            node_representation = torch.sum(torch.cat(h_list, dim=0), dim=0)
        else:
            raise ValueError("not implemented.")
        return node_representation


class GNN_graphpred(nn.Module):
    """
    Extension of GIN to incorporate edge information by concatenation.

    Args:
        num_layer (int): the number of GNN layers
        arg.emb_dim (int): dimensionality of embeddings
        num_tasks (int): number of tasks in multi-task learning scenario
        JK (str): last, concat, max or sum.
        graph_pooling (str): sum, mean, max, attention, set2set

    See https://arxiv.org/abs/1810.00826
    JK-net: https://arxiv.org/abs/1806.03536 """

    def __init__(self, num_layer, emb_dim, num_tasks, JK, graph_pooling, molecule_node_model=None):
        super(GNN_graphpred, self).__init__()

        if num_layer < 2:
            raise ValueError("# layers must > 1.")

        self.molecule_node_model = molecule_node_model
        self.num_layer = num_layer
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks
        self.JK = JK

        # Different kind of graph pooling
        if graph_pooling == "sum":
            self.pool = global_add_pool
        elif graph_pooling == "mean":
            self.pool = global_mean_pool
        elif graph_pooling == "max":
            self.pool = global_max_pool
        else:
            raise ValueError("Invalid graph pooling type.")

        # For graph-level binary classification
        self.mult = 1

        if self.JK == "concat":
            self.graph_pred_linear = nn.Linear(self.mult * (self.num_layer + 1) * self.emb_dim,
                                               self.num_tasks)
        else:
            self.graph_pred_linear = nn.Linear(self.mult * self.emb_dim, self.num_tasks)
        return

    def from_pretrained(self, model_file):
        print("Loading from {} ...".format(model_file))
        state_dict = torch.load(model_file)
        self.molecule_node_model.load_state_dict(state_dict)
        return

    def forward(self, *argv):
        if len(argv) == 4:
            x, edge_index, edge_attr, batch = argv[0], argv[1], argv[2], argv[3]
        elif len(argv) == 1:
            data = argv[0]
            x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        else:
            raise ValueError("unmatched number of arguments.")

        node_representation = self.molecule_node_model(x, edge_index, edge_attr)
        graph_representation = self.pool(node_representation, batch)
        output = self.graph_pred_linear(graph_representation)
        return graph_representation, output


For GraphMVP, see this [repo](https://github.com/chao1224/GraphMVP) and use the checkpoints from this [link](https://drive.google.com/drive/u/1/folders/1uPsBiQF3bfeCAXSDd4JfyXiTh-qxYfu6). The expected folder layout is:
```
pretrained_GraphMVP/
├── GraphMVP_C
│   └── model.pth
└── GraphMVP_G
    └── model.pth
```
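
Before loading, it can help to confirm that the checkpoints were unpacked into the layout above. The following is a minimal sketch and not part of the GraphMVP or MoleculeSTM code: the helper name `missing_checkpoints` and its `root` argument are hypothetical, and it only checks that the two `model.pth` files exist.

```python
from pathlib import Path

# Expected layout, taken from the directory tree above.
EXPECTED = ("GraphMVP_C/model.pth", "GraphMVP_G/model.pth")


def missing_checkpoints(root):
    """Return the expected checkpoint files that are absent under `root`.

    `root` is assumed to be the `pretrained_GraphMVP/` folder.
    """
    root = Path(root)
    return [rel for rel in EXPECTED if not (root / rel).is_file()]
```

An empty list means both files are in place, for example when calling `missing_checkpoints` on the `pretrained_GraphMVP` folder inside the data directory used by this notebook.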

#### Text encoder $f_t$
The text encoder is initialized from the pretrained SciBERT checkpoint through the Hugging Face `transformers` API:
```
SciBERT_tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=pretrained_SciBERT_folder)
SciBERT_model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=pretrained_SciBERT_folder).to(device)
```

### **Pretrained model**
In the pretraining phase, MoleculeSTM maps the representations extracted from the chemical structure branch and the textual description branch into a joint space via contrastive learning.
We initialize the encoders of both branches from the pretrained single-modal checkpoints and then perform contrastive pretraining on the dataset.

<img src="https://raw.githubusercontent.com/xiaorandu/dl4h_project/main/img/pretraining.png" width=1500 />
Figure 1: MoleculeSTM Contrastive Pretraining (source: https://github.com/chao1224/MoleculeSTM)

\
Contrastive pretraining uses either EBM-NCE or InfoNCE as the objective. Both align the structure–text pair of the same molecule while contrasting pairs from different molecules. The two objectives are as follows:
<img src="https://raw.githubusercontent.com/xiaorandu/dl4h_project/main/img/formula_1.png" width=500 />
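
To make the two objectives concrete, the sketch below evaluates them on a small matrix of similarity scores using only the Python standard library. It mirrors the logic of the `do_CL` function used in this notebook (binary cross-entropy on positive and cyclically shifted negative pairs for EBM-NCE, row-wise cross-entropy for InfoNCE), but it is not the training code; the function names and the toy scores are illustrative assumptions.

```python
import math


def _softplus(z):
    """Numerically stable log(1 + exp(z))."""
    return math.log1p(math.exp(-abs(z))) + max(z, 0.0)


def ebm_nce(scores, temperature=0.1):
    """EBM-NCE on a B x B score matrix, where scores[i][j] = <x_i, y_j>.

    Positives are the diagonal entries; one negative per row is taken from
    the next column (a cyclic shift by one), as with CL_neg_samples=1.
    """
    n = len(scores)
    # Binary cross-entropy with logits: target 1 -> softplus(-z), target 0 -> softplus(z)
    loss_pos = sum(_softplus(-scores[i][i] / temperature) for i in range(n)) / n
    loss_neg = sum(_softplus(scores[i][(i + 1) % n] / temperature) for i in range(n)) / n
    return (loss_pos + loss_neg) / 2


def info_nce(scores, temperature=0.1):
    """InfoNCE: cross-entropy over each row, with the diagonal entry as the label."""
    n = len(scores)
    total = 0.0
    for i, row in enumerate(scores):
        logits = [s / temperature for s in row]
        m = max(logits)
        log_z = m + math.log(sum(math.exp(l - m) for l in logits))
        total += log_z - logits[i]
    return total / n


# Toy scores: `aligned` has high similarity on matching pairs, `shuffled` does not.
aligned = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
shuffled = [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]
print(ebm_nce(aligned), ebm_nce(shuffled))
print(info_nce(aligned), info_nce(shuffled))
```

Both losses come out lower for `aligned` than for `shuffled`, which is the behavior the pretraining step relies on.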


In [None]:
# 1. Load and Customize Arguments
import warnings
warnings.filterwarnings('ignore')
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'False'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

import argparse

parser = argparse.ArgumentParser()

parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--device", type=int, default=0)

parser.add_argument("--dataspace_path", type=str, default="/content/drive/MyDrive/MoleculeSTM/data")
parser.add_argument("--dataset", type=str, default="PubChemSTM1K")
parser.add_argument("--text_type", type=str, default="SciBERT", choices=["SciBERT"])
parser.add_argument("--molecule_type", type=str, default="Graph", choices=["Graph"])

parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--text_lr", type=float, default=1e-4)
parser.add_argument("--mol_lr", type=float, default=1e-4)
parser.add_argument("--text_lr_scale", type=float, default=0.1)
parser.add_argument("--mol_lr_scale", type=float, default=0.1)
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument("--epochs", type=int, default=5)  # reduced for the demo; the original setting is 100
parser.add_argument("--decay", type=float, default=0)
parser.add_argument("--verbose", type=int, default=1)
parser.add_argument("--output_model_dir", type=str, default=None)

########## for SciBERT ##########
parser.add_argument("--max_seq_len", type=int, default=512)

########## for 2D GNN ##########
parser.add_argument("--pretrain_gnn_mode", type=str, default="GraphMVP_G", choices=["GraphMVP_G"])
parser.add_argument("--gnn_emb_dim", type=int, default=300)
parser.add_argument("--num_layer", type=int, default=5)
parser.add_argument('--JK', type=str, default='last')
parser.add_argument("--dropout_ratio", type=float, default=0.5)
parser.add_argument("--gnn_type", type=str, default="gin")
parser.add_argument('--graph_pooling', type=str, default='mean')

########## for contrastive SSL ##########
parser.add_argument("--SSL_loss", type=str, default="EBM_NCE", choices=["EBM_NCE", "InfoNCE"])
parser.add_argument("--SSL_emb_dim", type=int, default=256)
parser.add_argument("--CL_neg_samples", type=int, default=1)
parser.add_argument("--T", type=float, default=0.1)
parser.add_argument('--normalize', dest='normalize', action='store_true')
parser.add_argument('--no_normalize', dest='normalize', action='store_false')
parser.set_defaults(normalize=True)

args = parser.parse_args("")
print("arguments\t", args)

arguments	 Namespace(seed=42, device=0, dataspace_path='../data', dataset='PubChemSTM1K', text_type='SciBERT', molecule_type='Graph', batch_size=4, text_lr=0.0001, mol_lr=0.0001, text_lr_scale=0.1, mol_lr_scale=0.1, num_workers=8, epochs=5, decay=0, verbose=1, output_model_dir=None, max_seq_len=512, pretrain_gnn_mode='GraphMVP_G', gnn_emb_dim=300, num_layer=5, JK='last', dropout_ratio=0.5, gnn_type='gin', graph_pooling='mean', SSL_loss='EBM_NCE', SSL_emb_dim=256, CL_neg_samples=1, T=0.1, normalize=True)
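
Because `parse_args("")` receives an empty argument list, every option takes its default value. In a notebook, individual defaults can be overridden by passing an explicit list instead. The sketch below uses a reduced, hypothetical parser (`demo_parser`) with three of the options above to show this, including the paired `--normalize` / `--no_normalize` flags; it relies only on standard `argparse` behavior.

```python
import argparse

# Reduced parser for illustration; the option names match the cell above.
demo_parser = argparse.ArgumentParser()
demo_parser.add_argument("--epochs", type=int, default=5)
demo_parser.add_argument("--SSL_loss", type=str, default="EBM_NCE", choices=["EBM_NCE", "InfoNCE"])
demo_parser.add_argument("--normalize", dest="normalize", action="store_true")
demo_parser.add_argument("--no_normalize", dest="normalize", action="store_false")
demo_parser.set_defaults(normalize=True)

defaults = demo_parser.parse_args([])  # same effect as parse_args("")
custom = demo_parser.parse_args(["--epochs", "3", "--SSL_loss", "InfoNCE", "--no_normalize"])

print(defaults.epochs, defaults.SSL_loss, defaults.normalize)  # 5 EBM_NCE True
print(custom.epochs, custom.SSL_loss, custom.normalize)        # 3 InfoNCE False
```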


In [None]:
# 2. Load Packages
import sys
sys.path.append('/content/drive/MyDrive/MoleculeSTM')
import os
import time
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader as torch_DataLoader

from torch_geometric.loader import DataLoader as pyg_DataLoader
from transformers import AutoModel, AutoTokenizer

from MoleculeSTM.datasets import (
    PubChemSTM_Datasets_Graph, PubChemSTM_SubDatasets_Graph,
    PubChemSTM_Datasets_Raw_Graph, PubChemSTM_SubDatasets_Raw_Graph
)
from MoleculeSTM.models import GNN, GNN_graphpred
from MoleculeSTM.utils import prepare_text_tokens, get_molecule_repr_MoleculeSTM, freeze_network

In [None]:
# 3. Supporting functions

# create a cyclically shifted version of an array
def cycle_index(num, shift):
    arr = torch.arange(num) + shift
    arr[-shift:] = torch.arange(shift)
    return arr

# compute the contrastive loss and accuracy between two batches of embeddings
def do_CL(X, Y, args):
    if args.normalize:
        X = F.normalize(X, dim=-1)
        Y = F.normalize(Y, dim=-1)

    # Energy-Based Model with Noise-Contrastive Estimation
    if args.SSL_loss == 'EBM_NCE':
        criterion = nn.BCEWithLogitsLoss()
        # Generate negative samples for Y by cyclically shifting each sample
        neg_Y = torch.cat([Y[cycle_index(len(Y), i + 1)] for i in range(args.CL_neg_samples)], dim=0)
        # Repeat X to match the number of negative samples
        neg_X = X.repeat((args.CL_neg_samples, 1))

        # Compute positive and negative predictions and apply temperature scaling
        pred_pos = torch.sum(X * Y, dim=1) / args.T
        pred_neg = torch.sum(neg_X * neg_Y, dim=1) / args.T

        # Compute loss for positive and negative predictions
        loss_pos = criterion(pred_pos, torch.ones(len(pred_pos)).to(pred_pos.device))
        loss_neg = criterion(pred_neg, torch.zeros(len(pred_neg)).to(pred_neg.device))

        # Calculate the overall contrastive learning loss
        CL_loss = (loss_pos + args.CL_neg_samples * loss_neg) / (1 + args.CL_neg_samples)

        # Calculate the accuracy for positive and negative predictions
        CL_acc = (torch.sum(pred_pos > 0).float() + torch.sum(pred_neg < 0).float()) / \
                 (len(pred_pos) + len(pred_neg))
        CL_acc = CL_acc.detach().cpu().item()

    # Information Noise-Contrastive Estimation
    elif args.SSL_loss == 'InfoNCE':
        criterion = nn.CrossEntropyLoss()
        B = X.size()[0]

        # Compute the dot product between all pairs of X and Y and apply temperature scaling
        logits = torch.mm(X, Y.transpose(1, 0))  # B*B
        logits = torch.div(logits, args.T)
        labels = torch.arange(B).long().to(logits.device)  # shape (B,); row i should match column i

        # Compute the loss using cross-entropy
        CL_loss = criterion(logits, labels)

        # Determine the predicted class and calculate accuracy
        pred = logits.argmax(dim=1, keepdim=False)
        CL_acc = pred.eq(labels).sum().detach().cpu().item() * 1. / B

    else:
        raise ValueError("Unknown SSL_loss: {}".format(args.SSL_loss))

    return CL_loss, CL_acc
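
The negative pairs in the EBM-NCE branch come from `cycle_index`, which rotates the batch indices so that sample `i` is paired with sample `(i + shift) mod B`. A pure-Python sketch of the same indexing makes the pairing explicit; the helper name `cycle_index_list` and the toy batch are hypothetical.

```python
def cycle_index_list(num, shift):
    """Pure-Python equivalent of cycle_index for 0 < shift < num."""
    return [(i + shift) % num for i in range(num)]


batch = ["mol_0", "mol_1", "mol_2", "mol_3"]
negatives = [batch[j] for j in cycle_index_list(len(batch), 1)]
print(list(zip(batch, negatives)))
```

Each sample is paired with a different sample from the same batch only while `shift` stays below the batch size, so a batch needs at least two samples (and `CL_neg_samples` must be smaller than the batch size) for the negatives to be valid.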

In [None]:
# 4. Training Function
def train(
    epoch,
    dataloader,
    text_model, text_tokenizer,
    molecule_model, MegaMolBART_wrapper=None):

    text_model.train()
    molecule_model.train()
    text2latent.train()
    mol2latent.train()

    if args.verbose:
        L = tqdm(dataloader)
    else:
        L = dataloader

    start_time = time.time()
    accum_loss, accum_acc = 0, 0
    for step, batch in enumerate(L):
        description = batch[0]
        molecule_data = batch[1]

        description_tokens_ids, description_masks = prepare_text_tokens(
            device=device, description=description, tokenizer=text_tokenizer, max_seq_len=args.max_seq_len)
        description_output = text_model(input_ids=description_tokens_ids, attention_mask=description_masks)
        description_repr = description_output["pooler_output"]
        description_repr = text2latent(description_repr)

        molecule_data = molecule_data.to(device)
        molecule_repr = get_molecule_repr_MoleculeSTM(
            molecule_data, mol2latent=mol2latent,
            molecule_type=molecule_type, molecule_model=molecule_model)

        loss_01, acc_01 = do_CL(description_repr, molecule_repr, args)
        loss_02, acc_02 = do_CL(molecule_repr, description_repr, args)
        loss = (loss_01 + loss_02) / 2
        acc = (acc_01 + acc_02) / 2
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        accum_loss += loss.item()
        accum_acc += acc

    accum_loss /= len(L)
    accum_acc /= len(L)

    global optimal_loss
    temp_loss = accum_loss
    if temp_loss < optimal_loss:
        optimal_loss = temp_loss
    print("CL Loss: {:.5f}\tCL Acc: {:.5f}\tTime: {:.5f}".format(accum_loss, accum_acc, time.time() - start_time))
    return

In [None]:
# 5. Pretraining
# 5.1 set seed
torch.manual_seed(args.seed)
np.random.seed(args.seed)
device = torch.device("cuda:" + str(args.device)) \
    if torch.cuda.is_available() else torch.device("cpu")
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(args.seed)

In [None]:
# 5.2 prepare text model
kwargs = {}

if args.text_type == "SciBERT":
    pretrained_SciBERT_folder = os.path.join(args.dataspace_path, 'pretrained_SciBERT')
    print("Download SciBert to {}".format(pretrained_SciBERT_folder))
    text_tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=pretrained_SciBERT_folder)
    text_model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=pretrained_SciBERT_folder).to(device)
    kwargs["text_tokenizer"] = text_tokenizer
    kwargs["text_model"] = text_model
    text_dim = 768
else:
    raise ValueError("Unsupported text_type: {}".format(args.text_type))

Download SciBert to ../data/pretrained_SciBERT


Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
print(text_model)

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(31090, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
  

In [None]:
# 5.3 Start training MoleculeSTM-Graph

# prepare GraphMVP (Graph Model) and Dataset
dataset_root = os.path.join(args.dataspace_path, "PubChemSTM_data")

molecule_type = "Graph"

# PubChemSTM_Datasets_Graph(dataset_root)
dataset = PubChemSTM_SubDatasets_Graph(dataset_root, size=100) #size = 1000

dataloader_class = pyg_DataLoader

molecule_node_model = GNN(
    num_layer=args.num_layer, emb_dim=args.gnn_emb_dim,
    JK=args.JK, drop_ratio=args.dropout_ratio,
    gnn_type=args.gnn_type)
molecule_model = GNN_graphpred(
    num_layer=args.num_layer, emb_dim=args.gnn_emb_dim, JK=args.JK, graph_pooling=args.graph_pooling,
    num_tasks=1, molecule_node_model=molecule_node_model)
pretrained_model_path = os.path.join(args.dataspace_path, "pretrained_GraphMVP", args.pretrain_gnn_mode, "model.pth")
molecule_model.from_pretrained(pretrained_model_path)

molecule_model = molecule_model.to(device)

kwargs["molecule_model"] = molecule_model
molecule_dim = args.gnn_emb_dim

dataloader = dataloader_class(dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)

Loading from ../data/pretrained_GraphMVP/GraphMVP_G/model.pth ...


In [None]:
print(f"dataset size is: {dataset.size}, sample data shows below:")
for i in range(20):
  print(dataset.CID_list[i], dataset.text_list[i])

dataset size is: 100, sample data shows below:
29927686 Scutellarin(1-)
29982675 9-cis-4-oxoretinoate
29918871 Monacolin J carboxylate
29986894 11-cis-retinoate
29919282 (R)-imazamox(1-)
29919280 (S)-imazamox(1-)
29986450 6-(O-phosphocholine)oxyhexanoate(1-)
29918994 Tenofovir(1-)
29922751 Cidofovir(1-)
29986451 6-(O-phosphocholine)oxyhexanoic acid
29986850 Cis-resveratrol 3-O-glucuronide
29969962 Abacavir 5'-glucuronide
29919281 (S)-imazamox
29922189 (2S)-2-methylbutanoic acid [(1S,7S,8S,8aR)-8-[2-[(2R,4S)-4-hydroxy-2-oxanyl]ethyl]-7-methyl-1,2,3,7,8,8a-hexahydronaphthalen-1-yl] ester
29981063 Homovanillic acid sulfate
29919046 N-Desmethylzolmitriptan
29921593 Beta-D-Gal-(1->4)-alpha-D-Man
29976540 Cyanomethyl 2,3,4,6-tetra-O-acetyl-1-thio-alpha-D-mannoside
29918831 6-hydroxyetodolac
29976188 (3S)-3-(4-chlorophenyl)-4-hydroxybutanoic acid


In [None]:
print(molecule_model)

GNN_graphpred(
  (molecule_node_model): GNN(
    (atom_encoder): AtomEncoder(
      (atom_embedding_list): ModuleList(
        (0): Embedding(119, 300)
        (1): Embedding(4, 300)
        (2-3): 2 x Embedding(12, 300)
        (4): Embedding(10, 300)
        (5-6): 2 x Embedding(6, 300)
        (7-8): 2 x Embedding(2, 300)
      )
    )
    (gnns): ModuleList(
      (0-4): 5 x GINConv()
    )
    (batch_norms): ModuleList(
      (0-4): 5 x BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (graph_pred_linear): Linear(in_features=300, out_features=1, bias=True)
)


In [None]:
# prepare the two projection layers
text2latent = nn.Linear(text_dim, args.SSL_emb_dim).to(device)
mol2latent = nn.Linear(molecule_dim, args.SSL_emb_dim).to(device)
print(f"text2latent: {text2latent}")
print(f"mol2latent: {mol2latent}")

text2latent: Linear(in_features=768, out_features=256, bias=True)
mol2latent: Linear(in_features=300, out_features=256, bias=True)


In [None]:
# prepare optimizers
model_param_group = [
    {"params": text_model.parameters(), "lr": args.text_lr},
    {"params": molecule_model.parameters(), "lr": args.mol_lr},
    {"params": text2latent.parameters(), "lr": args.text_lr * args.text_lr_scale},
    {"params": mol2latent.parameters(), "lr": args.mol_lr * args.mol_lr_scale},
]
optimizer = optim.Adam(model_param_group, weight_decay=args.decay)
optimal_loss = 1e10
print(optimizer)

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.0001
    maximize: False
    weight_decay: 0

Parameter Group 1
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.0001
    maximize: False
    weight_decay: 0

Parameter Group 2
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 1e-05
    maximize: False
    weight_decay: 0

Parameter Group 3
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 1e-05
    maximize: False
    weight_decay: 0
)


In [None]:
# start training (3 epochs for the demo; args.epochs is not used here)
for e in range(3):
    print("Epoch {}".format(e))
    train(e, dataloader, **kwargs)

## Training


### **Computational requirements**


*   Hardware: Google Colab V100 GPU, RAM 16GB
*   Random seed: 42
*   Training epochs: 100

### **Implementation code**

#### **1 Downstream: Zero-shot Structure-text Retrieval**

In [None]:
# 1. Load Packages
import warnings
warnings.filterwarnings('ignore')
import sys
sys.path.append('/content/drive/MyDrive/MoleculeSTM')
import os
import time
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import defaultdict

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader as torch_DataLoader
from torch_geometric.loader import DataLoader as pyg_DataLoader

from transformers import AutoModel, AutoTokenizer
from MoleculeSTM.datasets import DrugBank_Datasets_SMILES_retrieval, DrugBank_Datasets_Graph_retrieval
from MoleculeSTM.models import GNN, GNN_graphpred
from MoleculeSTM.utils import prepare_text_tokens, get_molecule_repr_MoleculeSTM, freeze_network

# Disable tokenizers parallelism to avoid fork-related warnings from DataLoader workers
os.environ['TOKENIZERS_PARALLELISM'] = 'False'

In [None]:
# 2. Setup Arguments
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--device", type=int, default=0)
parser.add_argument("--SSL_emb_dim", type=int, default=256)
parser.add_argument("--text_type", type=str, default="SciBERT", choices=["SciBERT", "BioBERT"])
parser.add_argument("--load_latent_projector", type=int, default=1)
parser.add_argument("--training_mode", type=str, default="zero_shot", choices=["zero_shot"])

########## for dataset and split ##########
parser.add_argument("--dataspace_path", type=str, default="../data")
parser.add_argument("--dataset", type=str, default="PubChem")
parser.add_argument("--task", type=str, default="molecule_description",
    choices=[
        "molecule_description", "molecule_description_Raw",
        "molecule_description_removed_PubChem", "molecule_description_removed_PubChem_Raw",
        "molecule_pharmacodynamics", "molecule_pharmacodynamics_Raw",
        "molecule_pharmacodynamics_removed_PubChem", "molecule_pharmacodynamics_removed_PubChem_Raw"])
parser.add_argument("--test_mode", type=str, default="given_text", choices=["given_text", "given_molecule"])

########## for optimization ##########
parser.add_argument("--T_list", type=int, nargs="+", default=[4, 10, 20])
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument("--epochs", type=int, default=1)
parser.add_argument("--text_lr", type=float, default=1e-5)
parser.add_argument("--mol_lr", type=float, default=1e-5)
parser.add_argument("--text_lr_scale", type=float, default=0.1)
parser.add_argument("--mol_lr_scale", type=float, default=0.1)
parser.add_argument("--decay", type=float, default=0)

########## for contrastive objective ##########
parser.add_argument("--SSL_loss", type=str, default="EBM_NCE", choices=["EBM_NCE", "InfoNCE"])
parser.add_argument("--CL_neg_samples", type=int, default=1)
parser.add_argument("--T", type=float, default=0.1)
parser.add_argument('--normalize', dest='normalize', action='store_true')
parser.add_argument('--no_normalize', dest='normalize', action='store_false')
parser.set_defaults(normalize=True)

########## for BERT model ##########
parser.add_argument("--max_seq_len", type=int, default=512)

########## for molecule model ##########
parser.add_argument("--molecule_type", type=str, default="Graph", choices=["SMILES", "Graph"])

########## for 2D GNN ##########
parser.add_argument("--gnn_emb_dim", type=int, default=300)
parser.add_argument("--num_layer", type=int, default=5)
parser.add_argument('--JK', type=str, default='last')
parser.add_argument("--dropout_ratio", type=float, default=0.5)
parser.add_argument("--gnn_type", type=str, default="gin")
parser.add_argument('--graph_pooling', type=str, default='mean')

########## for saver ##########
parser.add_argument("--eval_train", type=int, default=0)
parser.add_argument("--verbose", type=int, default=0)

parser.add_argument("--input_model_dir", type=str, default="demo/demo_checkpoints_Graph")
parser.add_argument("--input_model_path", type=str, default="demo/demo_checkpoints_Graph/molecule_model.pth")


args = parser.parse_args("")
print("arguments\t", args)

arguments	 Namespace(seed=42, device=0, SSL_emb_dim=256, text_type='SciBERT', load_latent_projector=1, training_mode='zero_shot', dataspace_path='../data', dataset='PubChem', task='molecule_description', test_mode='given_text', T_list=[4, 10, 20], batch_size=32, num_workers=8, epochs=1, text_lr=1e-05, mol_lr=1e-05, text_lr_scale=0.1, mol_lr_scale=0.1, decay=0, SSL_loss='EBM_NCE', CL_neg_samples=1, T=0.1, normalize=True, max_seq_len=512, molecule_type='Graph', gnn_emb_dim=300, num_layer=5, JK='last', dropout_ratio=0.5, gnn_type='gin', graph_pooling='mean', eval_train=0, verbose=0, input_model_dir='demo/demo_checkpoints_Graph', input_model_path='demo/demo_checkpoints_Graph/molecule_model.pth')


In [None]:
# 3. Setup Seed
np.random.seed(args.seed)
torch.random.manual_seed(args.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(args.seed)
device = torch.device("cuda:" + str(args.device)) \
    if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# 4. Load SciBERT
pretrained_SciBERT_folder = os.path.join(args.dataspace_path, 'pretrained_SciBERT')
text_tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=pretrained_SciBERT_folder)
text_model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=pretrained_SciBERT_folder).to(device)
text_dim = 768

# input_model_path = os.path.join(args.input_model_dir, "text_model.pth")
input_model_path = "/content/drive/MyDrive/MoleculeSTM/data/demo/demo_checkpoints_Graph/text_model.pth"
print("Loading from {}...".format(input_model_path))
state_dict = torch.load(input_model_path, map_location='cpu')
text_model.load_state_dict(state_dict)

Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Loading from /content/drive/MyDrive/MoleculeSTM/data/demo/demo_checkpoints_Graph/text_model.pth...


<All keys matched successfully>

In [None]:
print(text_model)

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(31090, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
  

In [None]:
# 5. Load MoleculeSTM-Graph
molecule_node_model = GNN(
    num_layer=args.num_layer, emb_dim=args.gnn_emb_dim,
    JK=args.JK, drop_ratio=args.dropout_ratio,
    gnn_type=args.gnn_type)
molecule_model = GNN_graphpred(
    num_layer=args.num_layer, emb_dim=args.gnn_emb_dim, JK=args.JK, graph_pooling=args.graph_pooling,
    num_tasks=1, molecule_node_model=molecule_node_model)
molecule_dim = args.gnn_emb_dim

input_model_path = "/content/drive/MyDrive/MoleculeSTM/data/demo/demo_checkpoints_Graph/molecule_model.pth"
print("Loading from {}...".format(input_model_path))
state_dict = torch.load(input_model_path, map_location='cpu')
molecule_model.load_state_dict(state_dict)

Loading from /content/drive/MyDrive/MoleculeSTM/data/demo/demo_checkpoints_Graph/molecule_model.pth...


<All keys matched successfully>

In [None]:
print(molecule_model)

GNN_graphpred(
  (molecule_node_model): GNN(
    (atom_encoder): AtomEncoder(
      (atom_embedding_list): ModuleList(
        (0): Embedding(119, 300)
        (1): Embedding(4, 300)
        (2-3): 2 x Embedding(12, 300)
        (4): Embedding(10, 300)
        (5-6): 2 x Embedding(6, 300)
        (7-8): 2 x Embedding(2, 300)
      )
    )
    (gnns): ModuleList(
      (0-4): 5 x GINConv()
    )
    (batch_norms): ModuleList(
      (0-4): 5 x BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (graph_pred_linear): Linear(in_features=300, out_features=1, bias=True)
)


In [None]:
# 6. Load Projection Layers
text2latent = nn.Linear(text_dim, args.SSL_emb_dim)
input_model_path = "/content/drive/MyDrive/MoleculeSTM/data/demo/demo_checkpoints_Graph/text2latent_model.pth"
print("Loading from {}...".format(input_model_path))
state_dict = torch.load(input_model_path, map_location='cpu')
text2latent.load_state_dict(state_dict)

mol2latent = nn.Linear(molecule_dim, args.SSL_emb_dim)
input_model_path = "/content/drive/MyDrive/MoleculeSTM/data/demo/demo_checkpoints_Graph/mol2latent_model.pth"
print("Loading from {}...".format(input_model_path))
state_dict = torch.load(input_model_path, map_location='cpu')
mol2latent.load_state_dict(state_dict)

print(f"text2latent: {text2latent}")
print(f"mol2latent: {mol2latent}")

Loading from /content/drive/MyDrive/MoleculeSTM/data/demo/demo_checkpoints_Graph/text2latent_model.pth...
Loading from /content/drive/MyDrive/MoleculeSTM/data/demo/demo_checkpoints_Graph/mol2latent_model.pth...
text2latent: Linear(in_features=768, out_features=256, bias=True)
mol2latent: Linear(in_features=300, out_features=256, bias=True)


In [None]:
# 7. support functions
def cycle_index(num, shift):
    arr = torch.arange(num) + shift
    arr[-shift:] = torch.arange(shift)
    return arr


def do_CL_eval(X, Y, neg_Y, args):
    X = F.normalize(X, dim=-1)
    X = X.unsqueeze(1) # B, 1, d

    Y = Y.unsqueeze(0)
    Y = torch.cat([Y, neg_Y], dim=0) # T, B, d
    Y = Y.transpose(0, 1)  # B, T, d
    Y = F.normalize(Y, dim=-1)

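    # Squeeze only dim 1 so that a batch of size 1 keeps its batch dimension
    logits = torch.bmm(X, Y.transpose(1, 2)).squeeze(1)  # (B, T)
    B = X.size()[0]
    labels = torch.zeros(B).long().to(logits.device)  # (B,); the positive is always at index 0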

    criterion = nn.CrossEntropyLoss()

    CL_loss = criterion(logits, labels)
    pred = logits.argmax(dim=1, keepdim=False)
    confidence = logits
    CL_conf = confidence.max(dim=1)[0]
    CL_conf = CL_conf.cpu().numpy()

    CL_acc = pred.eq(labels).sum().detach().cpu().item() * 1. / B
    return CL_loss, CL_conf, CL_acc


def get_text_repr(text):
    text_tokens_ids, text_masks = prepare_text_tokens(
        device=device, description=text, tokenizer=text_tokenizer, max_seq_len=args.max_seq_len)
    text_output = text_model(input_ids=text_tokens_ids, attention_mask=text_masks)
    text_repr = text_output["pooler_output"]
    text_repr = text2latent(text_repr)
    return text_repr


@torch.no_grad()
def eval_epoch(dataloader):
    text_model.eval()
    molecule_model.eval()
    text2latent.eval()
    mol2latent.eval()

    accum_acc_list = [0 for _ in args.T_list]
    if args.verbose:
        L = tqdm(dataloader)
    else:
        L = dataloader
    for batch in L:
        text = batch[0]
        molecule_data = batch[1]
        neg_text = batch[2]
        neg_molecule_data = batch[3]

        text_repr = get_text_repr(text)

        molecule_data = molecule_data.to(device)
        molecule_repr = get_molecule_repr_MoleculeSTM(
            molecule_data, mol2latent=mol2latent,
            molecule_type="Graph", molecule_model=molecule_model)

        if test_mode == "given_text":
            neg_molecule_repr = [
                get_molecule_repr_MoleculeSTM(
                    neg_molecule_data[idx].to(device), mol2latent=mol2latent,
                    molecule_type="Graph", molecule_model=molecule_model) for idx in range(T_max)
            ]
            neg_molecule_repr = torch.stack(neg_molecule_repr)
            for T_idx, T in enumerate(args.T_list):
                _, _, acc = do_CL_eval(text_repr, molecule_repr, neg_molecule_repr[:T-1], args)
                accum_acc_list[T_idx] += acc
        elif test_mode == "given_molecule":
            neg_text_repr = [get_text_repr(neg_text[idx]) for idx in range(T_max)]
            neg_text_repr = torch.stack(neg_text_repr)
            for T_idx, T in enumerate(args.T_list):
                _, _, acc = do_CL_eval(molecule_repr, text_repr, neg_text_repr[:T-1], args)
                accum_acc_list[T_idx] += acc
        else:
            raise ValueError("Unknown test_mode: {}".format(test_mode))

    accum_acc_list = np.array(accum_acc_list)
    accum_acc_list /= len(dataloader)
    return accum_acc_list
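
The evaluation above is a T-choose-one retrieval test: for each query, the matching candidate is placed at index 0 alongside `T - 1` negatives, and the prediction counts as correct when index 0 has the highest score. The sketch below reproduces that bookkeeping with plain Python lists. The function name and the toy scores are illustrative assumptions, and ties are resolved here toward the lowest index.

```python
def t_choose_one_accuracy(pos_scores, neg_scores, T_list):
    """Accuracy per T, using the first T - 1 negatives of each query.

    pos_scores[i]: score of query i with its true match
    neg_scores[i]: scores of query i with its negatives (length >= max(T_list) - 1)
    """
    results = {}
    for T in T_list:
        correct = 0
        for pos, negs in zip(pos_scores, neg_scores):
            candidates = [pos] + list(negs[:T - 1])
            # Correct when the true match (index 0) attains the maximum score.
            if candidates.index(max(candidates)) == 0:
                correct += 1
        results[T] = correct / len(pos_scores)
    return results


pos = [0.9, 0.4]
neg = [[0.1, 0.2, 0.3], [0.3, 0.2, 0.8]]
print(t_choose_one_accuracy(pos, neg, T_list=[2, 4]))  # {2: 1.0, 4: 0.5}
```

Because the negatives for a smaller `T` are a prefix of those for a larger `T`, the accuracy can only stay the same or drop as `T` grows.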

In [None]:
# 8. Text retrieval
text_model = text_model.to(device)
molecule_model = molecule_model.to(device)
text2latent = text2latent.to(device)
mol2latent = mol2latent.to(device)

T_max = max(args.T_list) - 1

initial_test_acc_list = []
test_mode = args.test_mode
dataset_folder = os.path.join(args.dataspace_path, "DrugBank_data")


dataset_class = DrugBank_Datasets_Graph_retrieval
dataloader_class = pyg_DataLoader
processed_dir_prefix = args.task

if args.task == "molecule_description":
    template = "SMILES_description_{}.txt"
elif args.task == "molecule_description_removed_PubChem":
    template = "SMILES_description_removed_from_PubChem_{}.txt"
elif args.task == "molecule_description_Raw":
    template = "SMILES_description_{}_Raw.txt"
elif args.task == "molecule_description_removed_PubChem_Raw":
    template = "SMILES_description_removed_from_PubChem_{}_Raw.txt"
elif args.task == "molecule_pharmacodynamics":
    template = "SMILES_pharmacodynamics_{}.txt"
elif args.task == "molecule_pharmacodynamics_removed_PubChem":
    template = "SMILES_pharmacodynamics_removed_from_PubChem_{}.txt"
elif args.task == "molecule_pharmacodynamics_Raw":
    template = "SMILES_pharmacodynamics_{}_Raw.txt"
elif args.task == "molecule_pharmacodynamics_removed_PubChem_Raw":
    template = "SMILES_pharmacodynamics_removed_from_PubChem_{}_Raw.txt"
else:
    raise ValueError(f"Unknown task: {args.task}")

full_dataset = dataset_class(dataset_folder, 'full', neg_sample_size=T_max, processed_dir_prefix=processed_dir_prefix, template=template)
full_dataloader = dataloader_class(full_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) # The program will get blocked with non-zero num_workers

initial_test_acc_list = eval_epoch(full_dataloader)
print('Results', initial_test_acc_list)

```
Data: Data(x=[40309, 2], edge_index=[2, 85886], edge_attr=[85886, 2], id=[1168])
Index(['text', 'smiles'], dtype='object')
Loading negative samples from ../data/DrugBank_data/index/SMILES_description_full.txt
```

#### **2 Downstream: Zero-shot Text-based Molecule Editing**

In [None]:
# For graphs
! python downstream_02_molecule_edit_step_01_MoleculeSTM_Space_Alignment.py \
    --MoleculeSTM_molecule_type=Graph \
    --MoleculeSTM_model_dir=../data/demo/demo_checkpoints_Graph


! python downstream_02_molecule_edit_step_02_MoleculeSTM_Latent_Optimization.py \
    --MoleculeSTM_molecule_type=Graph \
    --MoleculeSTM_model_dir=../data/demo/demo_checkpoints_Graph \
    --language_edit_model_dir=../data/demo/demo_checkpoints_Graph \
    --input_description_id=101

#### **3 Downstream: Molecular Property Prediction**

In [None]:
import warnings
warnings.filterwarnings('ignore')
import sys
sys.path.append('/content/drive/MyDrive/MoleculeSTM')
import os
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import accuracy_score, average_precision_score, roc_auc_score, mean_absolute_error, mean_squared_error

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader as torch_DataLoader
from torch_geometric.loader import DataLoader as pyg_DataLoader

from MoleculeSTM.datasets import MoleculeNetSMILESDataset, MoleculeNetGraphDataset
from MoleculeSTM.splitters import scaffold_split
from MoleculeSTM.utils import get_num_task_and_type, get_molecule_repr_MoleculeSTM
from MoleculeSTM.models import GNN, GNN_graphpred

In [None]:
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--device", type=int, default=0)
parser.add_argument("--training_mode", type=str, default="fine_tuning", choices=["fine_tuning", "linear_probing"])
parser.add_argument("--molecule_type", type=str, default="Graph", choices=["SMILES", "Graph"])

########## for dataset and split ##########
parser.add_argument("--dataspace_path", type=str, default="../data")
parser.add_argument("--dataset", type=str, default="bace")
parser.add_argument("--split", type=str, default="scaffold")

########## for optimization ##########
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--lr_scale", type=float, default=1)
parser.add_argument("--num_workers", type=int, default=1)
parser.add_argument("--epochs", type=int, default=5)
parser.add_argument("--weight_decay", type=float, default=0)
parser.add_argument("--schedule", type=str, default="cycle")
parser.add_argument("--warm_up_steps", type=int, default=10)

########## for 2D GNN ##########
parser.add_argument("--gnn_emb_dim", type=int, default=300)
parser.add_argument("--num_layer", type=int, default=5)
parser.add_argument('--JK', type=str, default='last')
parser.add_argument("--dropout_ratio", type=float, default=0.5)
parser.add_argument("--gnn_type", type=str, default="gin")
parser.add_argument('--graph_pooling', type=str, default='mean')

########## for saver ##########
parser.add_argument("--eval_train", type=int, default=0)
parser.add_argument("--verbose", type=int, default=1)

parser.add_argument("--input_model_path", type=str, default="/content/drive/MyDrive/MoleculeSTM/data/demo/demo_checkpoints_Graph/molecule_model.pth")
parser.add_argument("--output_model_dir", type=str, default=None)

args = parser.parse_args("")
print("arguments\t", args)

arguments	 Namespace(seed=42, device=0, training_mode='fine_tuning', molecule_type='Graph', dataspace_path='../data', dataset='bace', split='scaffold', batch_size=32, lr=0.0001, lr_scale=1, num_workers=1, epochs=5, weight_decay=0, schedule='cycle', warm_up_steps=10, gnn_emb_dim=300, num_layer=5, JK='last', dropout_ratio=0.5, gnn_type='gin', graph_pooling='mean', eval_train=0, verbose=1, input_model_path='/content/drive/MyDrive/MoleculeSTM/data/demo/demo_checkpoints_Graph/molecule_model.pth', output_model_dir=None)


In [None]:
# setup seed
torch.manual_seed(args.seed)
np.random.seed(args.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(args.seed)
device = torch.device("cuda:" + str(args.device)) \
    if torch.cuda.is_available() else torch.device("cpu")

In [None]:
num_tasks, task_mode = get_num_task_and_type(args.dataset)
dataset_folder = os.path.join(args.dataspace_path, "MoleculeNet_data", args.dataset)

dataset = MoleculeNetGraphDataset(dataset_folder, args.dataset)
dataloader_class = pyg_DataLoader
use_pyg_dataset = True

smiles_list = pd.read_csv(
    dataset_folder + "/processed/smiles.csv", header=None)[0].tolist()
train_dataset, valid_dataset, test_dataset = scaffold_split(
    dataset, smiles_list, null_value=0, frac_train=0.8,
    frac_valid=0.1, frac_test=0.1, pyg_dataset=use_pyg_dataset)


train_loader = dataloader_class(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
val_loader = dataloader_class(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
test_loader = dataloader_class(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)


In [None]:
def print_dataset_info(dataset):
    print(f"Dataset Name: {dataset.dataset}")
    print(f"Number of Graphs: {len(dataset)}")
    print(f"Transformations: {dataset.transform}")
    print(f"Pre-transformations: {dataset.pre_transform}")
    print(f"Pre-filters: {dataset.pre_filter}")
    print(f"Processed Path: {dataset.processed_paths}")
    print(f"Raw Files: {dataset.raw_file_names}")

print_dataset_info(train_dataset)

Dataset Name: bace
Number of Graphs: 1210
Transformations: None
Pre-transformations: None
Pre-filters: None
Processed Path: ['../data/MoleculeNet_data/bace/processed/geometric_data_processed.pt']
Raw Files: ['bace.csv']


In [None]:
molecule_node_model = GNN(
    num_layer=args.num_layer, emb_dim=args.gnn_emb_dim,
    JK=args.JK, drop_ratio=args.dropout_ratio,
    gnn_type=args.gnn_type)
model = GNN_graphpred(
    num_layer=args.num_layer, emb_dim=args.gnn_emb_dim, JK=args.JK, graph_pooling=args.graph_pooling,
    num_tasks=1, molecule_node_model=molecule_node_model)
molecule_dim = args.gnn_emb_dim

if "GraphMVP" in args.input_model_path:
    print("Start from pretrained model (GraphMVP) in {}.".format(args.input_model_path))
    model.from_pretrained(args.input_model_path)
else:
    print("Start from pretrained model (MoleculeSTM) in {}.".format(args.input_model_path))
    state_dict = torch.load(args.input_model_path, map_location='cpu')
    model.load_state_dict(state_dict)


model = model.to(device)
linear_model = nn.Linear(molecule_dim, num_tasks).to(device)

# Rewrite the seed by MegaMolBART
torch.manual_seed(args.seed)
np.random.seed(args.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(args.seed)

Start from pretrained model (MoleculeSTM) in /content/drive/MyDrive/MoleculeSTM/data/demo/demo_checkpoints_Graph/molecule_model.pth.


In [None]:
print(f"model: {model}")
print(f"linear model: {linear_model}")

model: GNN_graphpred(
  (molecule_node_model): GNN(
    (atom_encoder): AtomEncoder(
      (atom_embedding_list): ModuleList(
        (0): Embedding(119, 300)
        (1): Embedding(4, 300)
        (2-3): 2 x Embedding(12, 300)
        (4): Embedding(10, 300)
        (5-6): 2 x Embedding(6, 300)
        (7-8): 2 x Embedding(2, 300)
      )
    )
    (gnns): ModuleList(
      (0-4): 5 x GINConv()
    )
    (batch_norms): ModuleList(
      (0-4): 5 x BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (graph_pred_linear): Linear(in_features=300, out_features=1, bias=True)
)
linear model: Linear(in_features=300, out_features=1, bias=True)


In [None]:
if args.training_mode == "fine_tuning":
    model_param_group = [
        {"params": model.parameters()},
        {"params": linear_model.parameters(), 'lr': args.lr * args.lr_scale}
    ]
else:
    model_param_group = [
        {"params": linear_model.parameters(), 'lr': args.lr * args.lr_scale}
    ]
optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.weight_decay)
print(optimizer)

In [None]:
def train_classification(model, device, loader, optimizer):
    if args.training_mode == "fine_tuning":
        model.train()
    else:
        model.eval()
    linear_model.train()
    total_loss = 0

    if args.verbose:
        L = tqdm(loader)
    else:
        L = loader
    for step, batch in enumerate(L):
        batch = batch.to(device)
        molecule_repr = get_molecule_repr_MoleculeSTM(
            batch, mol2latent=None,
            molecule_type="Graph", molecule_model=model)
        pred = linear_model(molecule_repr)
        pred = pred.float()
        y = batch.y.view(pred.shape).to(device).float()

        is_valid = y ** 2 > 0
        loss_mat = criterion(pred, (y + 1) / 2)
        loss_mat = torch.where(
            is_valid, loss_mat,
            torch.zeros(loss_mat.shape).to(device).to(loss_mat.dtype))

        optimizer.zero_grad()
        loss = torch.sum(loss_mat) / torch.sum(is_valid)
        loss.backward()
        optimizer.step()
        total_loss += loss.detach().item()

    return total_loss / len(loader)


@torch.no_grad()
def eval_classification(model, device, loader):
    model.eval()
    linear_model.eval()
    y_true, y_scores = [], []

    if args.verbose:
        L = tqdm(loader)
    else:
        L = loader
    for step, batch in enumerate(L):
        batch = batch.to(device)
        molecule_repr = get_molecule_repr_MoleculeSTM(
            batch, mol2latent=None,
            molecule_type="Graph", molecule_model=model)
        pred = linear_model(molecule_repr)
        pred = pred.float()
        y = batch.y.view(pred.shape).to(device).float()

        y_true.append(y)
        y_scores.append(pred)

    y_true = torch.cat(y_true, dim=0).cpu().numpy()
    y_scores = torch.cat(y_scores, dim=0).cpu().numpy()

    roc_list = []
    for i in range(y_true.shape[1]):
        # AUC is only defined when the task has at least one positive and one negative sample.
        if np.sum(y_true[:, i] == 1) > 0 and np.sum(y_true[:, i] == -1) > 0:
            is_valid = y_true[:, i] ** 2 > 0
            roc_list.append(roc_auc_score((y_true[is_valid, i] + 1) / 2, y_scores[is_valid, i]))
        else:
            print("{} is invalid".format(i))

    if len(roc_list) < y_true.shape[1]:
        print(len(roc_list))
        print("Some target is missing!")
        print("Missing ratio: %f" %(1 - float(len(roc_list)) / y_true.shape[1]))

    return sum(roc_list) / len(roc_list), 0, y_true, y_scores

In [None]:
train_func = train_classification
eval_func = eval_classification

train_roc_list, val_roc_list, test_roc_list = [], [], []
train_acc_list, val_acc_list, test_acc_list = [], [], []
best_val_roc, best_val_idx = -1, 0
criterion = nn.BCEWithLogitsLoss(reduction="none")

for epoch in range(1, args.epochs + 1):
    loss_acc = train_func(model, device, train_loader, optimizer)
    print("Epoch: {}\nLoss: {}".format(epoch, loss_acc))

    if args.eval_train:
        train_roc, train_acc, train_target, train_pred = eval_func(model, device, train_loader)
    else:
        train_roc = train_acc = 0
    val_roc, val_acc, val_target, val_pred = eval_func(model, device, val_loader)
    test_roc, test_acc, test_target, test_pred = eval_func(model, device, test_loader)

    train_roc_list.append(train_roc)
    train_acc_list.append(train_acc)
    val_roc_list.append(val_roc)
    val_acc_list.append(val_acc)
    test_roc_list.append(test_roc)
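    test_acc_list.append(test_acc)

    # Track the epoch with the best validation ROC-AUC so the final summary reports it.
    if val_roc > best_val_roc:
        best_val_roc = val_roc
        best_val_idx = epoch - 1

    print("train: {:.6f}\tval: {:.6f}\ttest: {:.6f}".format(train_roc, val_roc, test_roc))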
    print()

print("best train: {:.6f}\tval: {:.6f}\ttest: {:.6f}".format(train_roc_list[best_val_idx], val_roc_list[best_val_idx], test_roc_list[best_val_idx]))


Get pre-trained model

In [14]:
# WARNING: these files are huge
!git lfs install
!git clone https://huggingface.co/chao1224/MoleculeSTM

Git LFS initialized.
^C


In [52]:
PATH_TO_MOLECULE_MODEL = "pretrainedModels/molecule_model.pth"
PATH_TO_TEXT_MODEL = "pretrainedModels/text_model.pth"
PATH_TO_TEXT2LATENT_MODEL = "pretrainedModels/text2latent_model.pth"
PATH_TO_MOL2LATENT_MODEL = "pretrainedModels/mol2latent_model.pth"

# Keep tensors on their saved device when CUDA is available; otherwise remap them to the CPU.
map_location = None if torch.cuda.is_available() else torch.device('cpu')
molecule_model_loaded = torch.load(PATH_TO_MOLECULE_MODEL, map_location=map_location)
text_model_loaded = torch.load(PATH_TO_TEXT_MODEL, map_location=map_location)
text2latent_loaded = torch.load(PATH_TO_TEXT2LATENT_MODEL, map_location=map_location)
mol2latent_loaded = torch.load(PATH_TO_MOL2LATENT_MODEL, map_location=map_location)

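# Each checkpoint is a state dict, so build the module first and then load the weights into it.
# The constructor arguments mirror the property-prediction cell above.
molecule_node_model = GNN(
    num_layer=args.num_layer, emb_dim=args.gnn_emb_dim,
    JK=args.JK, drop_ratio=args.dropout_ratio,
    gnn_type=args.gnn_type)
molecule_model = GNN_graphpred(
    num_layer=args.num_layer, emb_dim=args.gnn_emb_dim, JK=args.JK, graph_pooling=args.graph_pooling,
    num_tasks=1, molecule_node_model=molecule_node_model)
molecule_model.load_state_dict(molecule_model_loaded)
# TODO: add other models here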

##   Evaluation

### Metric descriptions

This evaluation uses several different metrics:


*   Contrastive loss - the cross-entropy over the similarity scores between a query and its candidates (the true match plus the negative samples); it is lower when the true match is scored above the negatives
*   Accuracy - the proportion of queries for which the highest-scoring candidate is the true match
*   Confidence score - the highest similarity score among the candidates for each query


Additionally, the evaluation varies the number of negative samples (through `args.T_list`) to measure how well the model holds up as the candidate pool grows.
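
The three metrics can be illustrated on a toy example. The following is a minimal sketch in plain Python, not the notebook's implementation: the helper names `cosine` and `retrieval_metrics` and the two-dimensional vectors are made up for illustration, and a single query is scored instead of a batch.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def retrieval_metrics(query, positive, negatives):
    """Score one query against its true match (index 0) and the negatives."""
    candidates = [positive] + list(negatives)
    logits = [cosine(query, c) for c in candidates]
    # Cross-entropy with the true match (index 0) as the target class.
    loss = math.log(sum(math.exp(z) for z in logits)) - logits[0]
    # The prediction is correct when the true match has the highest score.
    correct = max(range(len(logits)), key=logits.__getitem__) == 0
    # Confidence is the highest similarity among the candidates.
    confidence = max(logits)
    return loss, correct, confidence


# Hypothetical toy embeddings (illustration only).
query = [1.0, 0.0]
positive = [0.9, 0.1]
negatives = [[0.0, 1.0], [-1.0, 0.0], [0.5, 0.5]]

# T counts all candidates, so T - 1 negatives are used for each setting.
for T in (2, 4):
    loss, correct, confidence = retrieval_metrics(query, positive, negatives[:T - 1])
    print(f"T={T}: loss={loss:.3f}, correct={correct}, confidence={confidence:.3f}")
```

With more negatives the loss can only grow, because every extra candidate adds a term to the softmax denominator; accuracy typically drops for the same reason.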




### Implementation code

#### Support Functions and Setup

In [21]:
# from demo downstream retrieval Graph
def cycle_index(num, shift):
    arr = torch.arange(num) + shift
    arr[-shift:] = torch.arange(shift)
    return arr


def do_CL_eval(X, Y, neg_Y, args):
    X = F.normalize(X, dim=-1)
    X = X.unsqueeze(1) # B, 1, d

    Y = Y.unsqueeze(0)
    Y = torch.cat([Y, neg_Y], dim=0) # T, B, d
    Y = Y.transpose(0, 1)  # B, T, d
    Y = F.normalize(Y, dim=-1)

    logits = torch.bmm(X, Y.transpose(1, 2)).squeeze(1)  # B, T (squeeze only the singleton query axis so B=1 still works)
    B = X.size()[0]
    labels = torch.zeros(B).long().to(logits.device)  # B*1

    criterion = nn.CrossEntropyLoss()

    CL_loss = criterion(logits, labels)
    pred = logits.argmax(dim=1, keepdim=False)
    confidence = logits
    CL_conf = confidence.max(dim=1)[0]
    CL_conf = CL_conf.cpu().numpy()

    CL_acc = pred.eq(labels).sum().detach().cpu().item() * 1. / B
    return CL_loss, CL_conf, CL_acc


def get_text_repr(text):
    text_tokens_ids, text_masks = prepare_text_tokens(
        device=device, description=text, tokenizer=text_tokenizer, max_seq_len=args.max_seq_len)
    text_output = text_model(input_ids=text_tokens_ids, attention_mask=text_masks)
    text_repr = text_output["pooler_output"]
    text_repr = text2latent(text_repr)
    return text_repr


@torch.no_grad()
def eval_epoch(dataloader):
    text_model.eval()
    molecule_model.eval()
    text2latent.eval()
    mol2latent.eval()

    accum_acc_list = [0 for _ in args.T_list]
    if args.verbose:
        L = tqdm(dataloader)
    else:
        L = dataloader
    for batch in L:
        text = batch[0]
        molecule_data = batch[1]
        neg_text = batch[2]
        neg_molecule_data = batch[3]

        text_repr = get_text_repr(text)

        molecule_data = molecule_data.to(device)
        molecule_repr = get_molecule_repr_MoleculeSTM(
            molecule_data, mol2latent=mol2latent,
            molecule_type="Graph", molecule_model=molecule_model)

        if test_mode == "given_text":
            neg_molecule_repr = [
                get_molecule_repr_MoleculeSTM(
                    neg_molecule_data[idx].to(device), mol2latent=mol2latent,
                    molecule_type="Graph", molecule_model=molecule_model) for idx in range(T_max)
            ]
            neg_molecule_repr = torch.stack(neg_molecule_repr)
            for T_idx, T in enumerate(args.T_list):
                _, _, acc = do_CL_eval(text_repr, molecule_repr, neg_molecule_repr[:T-1], args)
                accum_acc_list[T_idx] += acc
        elif test_mode == "given_molecule":
            neg_text_repr = [get_text_repr(neg_text[idx]) for idx in range(T_max)]
            neg_text_repr = torch.stack(neg_text_repr)
            for T_idx, T in enumerate(args.T_list):
                _, _, acc = do_CL_eval(molecule_repr, text_repr, neg_text_repr[:T-1], args)
                accum_acc_list[T_idx] += acc
        else:
            raise ValueError(f"Unknown test_mode: {test_mode}")

    accum_acc_list = np.array(accum_acc_list)
    accum_acc_list /= len(dataloader)
    return accum_acc_list

#### Eval

In [49]:
# from demo downstream retrieval Graph
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
text_model = text_model.to(device)
molecule_model = molecule_model.to(device)
text2latent = text2latent.to(device)
mol2latent = mol2latent.to(device)

T_max = max(args.T_list) - 1

initial_test_acc_list = []
test_mode = args.test_mode
dataset_folder = os.path.join(args.dataspace_path, "DrugBank_data")


dataset_class = DrugBank_Datasets_Graph_retrieval
dataloader_class = pyg_DataLoader
processed_dir_prefix = args.task

if args.task == "molecule_description":
    template = "SMILES_description_{}.txt"
elif args.task == "molecule_description_removed_PubChem":
    template = "SMILES_description_removed_from_PubChem_{}.txt"
elif args.task == "molecule_description_Raw":
    template = "SMILES_description_{}_Raw.txt"
elif args.task == "molecule_description_removed_PubChem_Raw":
    template = "SMILES_description_removed_from_PubChem_{}_Raw.txt"
elif args.task == "molecule_pharmacodynamics":
    template = "SMILES_pharmacodynamics_{}.txt"
elif args.task == "molecule_pharmacodynamics_removed_PubChem":
    template = "SMILES_pharmacodynamics_removed_from_PubChem_{}.txt"
elif args.task == "molecule_pharmacodynamics_Raw":
    template = "SMILES_pharmacodynamics_{}_Raw.txt"
elif args.task == "molecule_pharmacodynamics_removed_PubChem_Raw":
    template = "SMILES_pharmacodynamics_removed_from_PubChem_{}_Raw.txt"
else:
    raise ValueError(f"Unknown task: {args.task}")

full_dataset = dataset_class(dataset_folder, 'full', neg_sample_size=T_max, processed_dir_prefix=processed_dir_prefix, template=template)
full_dataloader = dataloader_class(full_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) # The program will get blocked with non-zero num_workers

initial_test_acc_list = eval_epoch(full_dataloader)

AttributeError: 'collections.OrderedDict' object has no attribute 'to'
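
This error most likely means that one of the variables moved with `.to(device)` holds a loaded state dict rather than a model. A checkpoint saved from `state_dict()` comes back from `torch.load` as an `OrderedDict` of tensors, which has no `.to` method, so the weights have to be loaded into an instantiated module first. The sketch below reproduces the situation with a small `nn.Linear` standing in for the MoleculeSTM models (an assumption made for illustration) and an in-memory buffer standing in for the `.pth` file.

```python
import io
from collections import OrderedDict

import torch
import torch.nn as nn

# Stand-in model (hypothetical); the real models are the MoleculeSTM modules.
layer = nn.Linear(4, 2)

# Save only the weights, as a state-dict checkpoint would.
buffer = io.BytesIO()
torch.save(layer.state_dict(), buffer)
buffer.seek(0)

# Loading gives back a mapping of parameter names to tensors, not a module.
loaded = torch.load(buffer, map_location="cpu")
print(type(loaded).__name__, "has .to():", hasattr(loaded, "to"))

# Build a module with the same architecture, then load the weights into it.
restored = nn.Linear(4, 2)
restored.load_state_dict(loaded)
restored = restored.to("cpu")  # works, because this is a module
```

The same pattern would apply to each of the four checkpoints: construct the module with the architecture used at training time, call `load_state_dict`, and only then move it to the device.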

## Results

## Discussion

### Problems & Solutions

Problem: We had great difficulty setting up the environment. Specifically, many of the instructions in the README were outdated and did not work as intended, and the authors of the paper did not share the versions of many of their Python dependencies.

Solution: Despite testing many different package versions and environments, we were unable to solve this problem. To continue testing our hypothesis, we decided to remove some dependencies (and their related code) from the project.
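
One way to avoid passing the same problem on to later readers is to record the package versions of whichever environment ends up working. The snippet below is a minimal sketch using only the standard library; the package list is a guess at the relevant dependencies and the helper name `snapshot_versions` is made up for illustration.

```python
from importlib import metadata


def snapshot_versions(packages):
    """Return 'name==version' lines for the packages that are installed."""
    lines = []
    for name in packages:
        try:
            lines.append(f"{name}=={metadata.version(name)}")
        except metadata.PackageNotFoundError:
            # Keep a placeholder so missing packages are visible in the snapshot.
            lines.append(f"# {name} is not installed")
    return lines


# Hypothetical package list; adjust it to the dependencies the project imports.
packages = ["torch", "torch-geometric", "transformers", "rdkit", "numpy"]
pinned = snapshot_versions(packages)
print("\n".join(pinned))
```

The printed lines can be pasted into a `requirements.txt` so that the environment can be rebuilt with the same versions.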

---

Problem: After adjusting the code for our modified environment, we ultimately found ourselves unable to run the training for the model.

Solution: First, we tried to debug the error by reading through the code ourselves, but we were unable to find a fix. We then opened an issue on the authors' Git repository in the hope of getting help. Fortunately, we found a pretrained model published by the author on Hugging Face, which allowed us to continue testing our hypothesis.