### How to run this file

Create a virtual environment:

`py -3.9 -m venv venv`


Activate the environment:

`.\venv\Scripts\activate`

Install a working version of DGL (the wheel file must be present in the current directory):

`pip install dgl-2.0.0-cp39-cp39-win_amd64.whl`

Install openpom

`pip install openpom`


In [39]:
#!pip install openpom
#!pip install  dgl -f https://data.dgl.ai/wheels/torch-2.4/repo.html
import deepchem as dc
from openpom.feat.graph_featurizer import GraphFeaturizer, GraphConvConstants
from openpom.utils.data_utils import get_class_imbalance_ratio
#from openpom.models.mpnn_pom import MPNNPOMModel
from datetime import datetime

In [40]:
from pom_models.models import MPNNPOMModel
import pandas as pd

In [41]:
# SMILES of the molecule to score; swap in any other SMILES string here.
single_smiles = 'COC1=C(C=CC(=C1)C=O)O'

# Featurize the SMILES string with openpom's graph featurizer.
featurizer = GraphFeaturizer()
single_molecule = featurizer.featurize([single_smiles])

# Wrap the featurized molecule in a one-sample dataset (features only, no labels).
single_dataset = dc.data.NumpyDataset(X=single_molecule)

The `train_ratios` parameter contains a list tracking the occurrence of each fragrance note relative to the most common note. The most common note has a ratio of 1; every other value is calculated by dividing the number of molecules with that fragrance note by the number of molecules with the most common note.

It is passed to `MPNNPOMModel` as the `class_imbalance_ratio` argument, which is documented as:


```python
class_imbalance_ratio: Optional[List]
            List of imbalance ratios per task.
```

The ratios are saved in `train_ratios.csv`. 

In [42]:
# Read the saved imbalance ratios; the first CSV column becomes the index.
df = pd.read_csv(filepath_or_buffer="train_ratios.csv", index_col=0)
# Preview the first rows as a quick check of what was loaded.
df.head()

Unnamed: 0,train_ratios
0,0.055848
1,0.074244
2,0.065703
3,0.044021
4,0.052562


In [43]:
# Pull the ratio column out of the frame as a plain Python list (one entry per row).
train_ratios = list(df.loc[:, "train_ratios"])

In [44]:
# Sanity check: display the total of all loaded imbalance ratios.
sum(train_ratios)

12.64454664914586

In [45]:
# Build the MPNN-POM classifier.
# NOTE(review): these hyperparameters presumably have to match the ones the
# checkpoint stored under `model_dir` was trained with -- confirm before editing.
model = MPNNPOMModel(
    # --- task and loss ---
    n_tasks=138,
    mode='classification',
    n_classes=1,
    class_imbalance_ratio=train_ratios,  # per-task ratios loaded from train_ratios.csv
    loss_aggr_type='sum',
    # --- input graph features ---
    number_atom_features=GraphConvConstants.ATOM_FDIM,
    number_bond_features=GraphConvConstants.BOND_FDIM,
    self_loop=False,
    # --- message passing ---
    node_out_feats=100,
    edge_hidden_feats=75,
    edge_out_feats=100,
    num_step_message_passing=5,
    mpnn_residual=True,
    message_aggregator_type='sum',
    # --- readout ---
    readout_type='set2set',
    num_step_set2set=3,
    num_layer_set2set=2,
    # --- feed-forward network ---
    ffn_hidden_list=[392, 392],
    ffn_embeddings=256,
    ffn_activation='relu',
    ffn_dropout_p=0.12,
    ffn_dropout_at_input_no_act=False,
    # --- optimisation ---
    batch_size=128,
    learning_rate=1e-4,
    weight_decay=1e-5,
    optimizer_name='adam',
    # --- logging, storage and device ---
    log_frequency=32,
    model_dir='data',
    device_name='cpu',
)

In [46]:
# Load the trained weights from the saved checkpoint into `model`.
# NOTE(review): the checkpoint is presumably looked up in the `model_dir`
# given at construction time -- confirm.
model.restore()

# Run inference on the one-molecule dataset built from `single_smiles`.
predicted_probabilities = model.predict(single_dataset)

# Print the raw prediction array (one value per fragrance note).
print("Predicted Probabilities for the 138 notes:", predicted_probabilities)

Predicted Probabilities for the 138 notes: [[0.04575704 0.04282943 0.05341665 0.28640267 0.0329218  0.20288868
  0.36172265 0.02111001 0.04237433 0.317042   0.4279455  0.03573177
  0.03719727 0.02371974 0.13851589 0.32038516 0.0289441  0.01613994
  0.32080555 0.09608915 0.03982432 0.08091616 0.27430353 0.02159856
  0.07350693 0.03478024 0.05178089 0.25843382 0.11839368 0.21073028
  0.06080894 0.06195458 0.34229103 0.06232229 0.146159   0.16066001
  0.01784373 0.0504692  0.09188339 0.07404196 0.28687137 0.35412017
  0.00835921 0.10641305 0.16819969 0.1959213  0.18295585 0.06957997
  0.05369798 0.06176599 0.44382554 0.099797   0.0437536  0.34240264
  0.04317852 0.04487146 0.06809818 0.1567298  0.02871715 0.03830222
  0.12342923 0.30967513 0.28897983 0.04847127 0.24057142 0.11724697
  0.0538983  0.05945813 0.02251486 0.02622347 0.09840045 0.01832962
  0.05089311 0.19233459 0.02440911 0.03873232 0.08897623 0.10484001
  0.42817798 0.06610651 0.13504575 0.09801025 0.06511155 0.01436159
  0.0

  data = torch.load(checkpoint, map_location=self.device)
