We will be training an adapter on the language model `[CLS]` token embedding. 

In [1]:
#%pip install pandas torch torch-geometric matplotlib scikit-learn numpy

We'll define our hyperparameters up top

In [2]:
import torch
import random
import numpy as np
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)

BATCH_SIZE = 128
NORMALIZE_TARGETS = False

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cuda:0"

To start off, we need to load the dataset. 

In [3]:
#More stupid python stuff
import sys
import os
root_dir = os.path.abspath(os.path.join(os.getcwd(), "..")) 
sys.path.append(root_dir)

from src.data.make_clsdataset import QM9CLS
from src.data.targets_dataset import TargetsDataset
import pandas as pd

data_dir = os.path.join(os.path.dirname(os.getcwd()), "data")

molecule_cls = QM9CLS("../data/etc/")
targets = TargetsDataset("../data/custom_qm9/raw", "../data/etc", normalize=True)

Loading embeddings from cached file...


In [4]:
import torch

print(f"[CLS] Embedding Size: {molecule_cls[0].size()}")
print(f"First Molecule Target Tensor: {targets[0]}")
print(f"[CLS] Dataset Size: {torch.stack(molecule_cls.embeddings).size()}")
print(f"Target Dataset Size: {targets.targets.size()}")

[CLS] Embedding Size: torch.Size([768])
First Molecule Target Tensor: tensor([ 1.5771e+02,  1.5771e+02,  1.5771e+02,  0.0000e+00,  1.3210e+01,
        -3.8770e-01,  1.1710e-01,  5.0480e-01,  3.5364e+01,  4.4749e-02,
        -4.0479e+01, -4.0476e+01, -4.0475e+01, -4.0499e+01,  6.4690e+00,
        -3.9600e+02, -3.9864e+02, -4.0101e+02, -3.7247e+02])
[CLS] Dataset Size: torch.Size([130831, 768])
Target Dataset Size: torch.Size([130831, 19])


So now we have inputs `molecule_cls` and targets for regression `targets`. Note that we currently have `targets` as normalized between $-1$ and $1$. This could be changed later if neccessary. 

In [5]:
from torch.utils.data import Dataset, TensorDataset, DataLoader, Subset


dataset = TensorDataset(torch.stack(molecule_cls.embeddings), targets.targets)
# train_size = int(len(dataset) * 0.8) # Use 80% of the dataset as a training dataset and leave the remaining 20% for testing
# valid_size = int(len(dataset) * 0.1)
split_indices_path = "../data/etc/split_idxs.pt"


assert(os.path.exists(split_indices_path)) # Split indices are generated by HydraGNN, so if they don't exist it should fail

train_indices, valid_indices, test_indices = torch.load(split_indices_path)


target_values = targets.targets[:,3] # Dipole moment $\mu$

# Normalize the target values

if NORMALIZE_TARGETS: 
    normalized_targets = (target_values - target_values.min()) / (target_values.max() - target_values.min())
else:
    normalized_targets = target_values
print(normalized_targets)

dataset = TensorDataset(torch.stack(molecule_cls.embeddings), normalized_targets)

#Create train and test datasets
train_dataset = Subset(dataset, train_indices)
valid_dataset = Subset(dataset, valid_indices)
test_dataset = Subset(dataset, test_indices)

print(len(train_dataset), len(valid_dataset), len(test_dataset))

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

# print(next(iter(test_dataloader)))

tensor([0.0000, 1.6256, 1.8511,  ..., 1.2480, 1.9576, 0.8626])
91581 19624 19626


In [6]:
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

Now, we need to train. We will use a neural network with 4 hidden layers with LeakyReLU activations and no output activation. We will use the AdamW optimizer and the MSE Loss

In [7]:
from tqdm import tqdm 
from math import sqrt
from torch import nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
import ray
from ray import train, tune
from ray.tune.search import ConcurrencyLimiter
# from ray.tune.search.bayesopt import BayesOptSearch
from ray.tune.search.hyperopt import HyperOptSearch
from torch.nn.functional import mse_loss, l1_loss
from ray.tune.schedulers import ASHAScheduler
# from ray.tune.search.bohb import TuneBOHB
# from ray.tune.schedulers.hb_bohb import HyperBandForBOHB

ray.init(num_gpus=1, num_cpus=24)

search_space = {
    'scheduler_factor': tune.loguniform(0.1, 0.9),
    'scheduler_patience': tune.randint(3, 10),
    'epochs': 200,
    'lr': tune.loguniform(1e-6, 1e-1),
    'early_stopping_patience': 10,
    'print': False,
    'dropout': tune.uniform(0.1, 0.5)
}

# intial_config = {
#     'epochs' : 100,
#     'lr' : 1e-3,
#     'scheduler_factor' : 0.5,
#     'scheduler_patience' : 5,
#     'early_stopping_patience' : 10,
#     'print': False,
#     'dropout': 0.2,
# }

intial_config = {'scheduler_factor': 0.27298185135469305, 'scheduler_patience': 8, 'epochs': 200, 'lr': 0.00136041146487004, 'early_stopping_patience': 10, 'print': False, 'dropout': 0.18660077210645098}

rtrain_dataloader = ray.put(train_dataloader)
rvalid_dataloader = ray.put(valid_dataloader)
rtest_dataloader  = ray.put( test_dataloader)

def run(config, return_model=False):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    train_dataloader = ray.get(rtrain_dataloader)
    valid_dataloader = ray.get(rvalid_dataloader)
    test_dataloader = ray.get(rtest_dataloader)

    model = nn.Sequential(
        nn.Linear(768, 500),
        nn.BatchNorm1d(500),
        nn.LeakyReLU(),
        nn.Dropout(p=config['dropout']),  # Add dropout for regularization
        nn.Linear(500, 500),
        nn.BatchNorm1d(500),
        nn.LeakyReLU(),
        nn.Dropout(p=config['dropout']),
        nn.Linear(500, 500),
        nn.BatchNorm1d(500),
        nn.LeakyReLU(),
        nn.Dropout(p=config['dropout']),
        nn.Linear(500, 1),
        # nn.ReLU()
    ).to(device)

    optimizer = AdamW(model.parameters(), lr=config['lr'])
    loss_function = l1_loss #nn.HuberLoss(delta=0.3) #nn.L1Loss() #nn.SmoothL1Loss() #torch.nn.MSELoss()
    MAE_Loss = l1_loss

    scheduler = ReduceLROnPlateau(optimizer, 'min', factor=config['scheduler_factor'], patience=config['scheduler_patience'], verbose=True) 

    best_valid_loss = float('inf')
    for i, epoch in enumerate(range(config['epochs'])):
        if config['print']:
            print(f"----Epoch {i+1}----")
        # Training phase
        model.train()  # Set model to training mode
        train_loss = 0.0

        for batch_idx, (inputs, _targets) in enumerate(tqdm(train_dataloader)):
            inputs, _targets = inputs.to(device), _targets.to(device)

            # Zero the gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)

            # Calculate loss
            loss = loss_function(outputs.squeeze(1), _targets)

            # Backward pass and optimization
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
            optimizer.step()

            train_loss += loss.item() * inputs.size(0) # Accumulate batch loss
        train_loss /= len(train_dataloader.dataset) 

        model.eval()  # Set model to eval mode
        with torch.no_grad():
            valid_loss = 0.0
            valid_loss_mae = 0.0
            for batch_idx, (inputs, _targets) in enumerate(tqdm(valid_dataloader)): 
                inputs, _targets = inputs.to(device), _targets.to(device)
                outputs = model(inputs)

                # print(outputs[0], outputs[1])
                # Calculate loss
                loss = loss_function(outputs.squeeze(1), _targets)
                loss_mae = MAE_Loss(outputs.squeeze(1), _targets)

                valid_loss += loss.item() * inputs.size(0)
                valid_loss_mae += loss_mae.item() * inputs.size(0) 

            valid_loss /= len(test_dataloader.dataset)
            valid_loss_mae /= len(test_dataloader.dataset)
        
        if config['print']:
            with torch.no_grad():
                model.eval()  # Set model to eval mode
                test_loss = 0.0
                test_loss_mae = 0.0
                for batch_idx, (inputs, _targets) in enumerate(tqdm(test_dataloader)): 
                    inputs, _targets = inputs.to(device), _targets.to(device)
                    outputs = model(inputs)
                    # print(outputs[0], outputs[1])
                    # Calculate loss
                    loss = loss_function(outputs.squeeze(1), _targets)
                    loss_mae = MAE_Loss(outputs.squeeze(1), _targets)
                    test_loss += loss.item() * inputs.size(0)
                    test_loss_mae += loss_mae.item() * inputs.size(0) 

                test_loss /= len(test_dataloader.dataset)
                test_loss_mae /= len(test_dataloader.dataset)

        train.report({
            'val_loss': valid_loss,
        })

        #Update LR based on test_loss
        scheduler.step(valid_loss)

        if config['print']:
            print(f"Epoch {epoch+1} - Learning rate: {scheduler.optimizer.param_groups[0]['lr']:.6f}") 
            print(f'Train Loss (MSE): {train_loss:.6f}, Validation Loss (MSE): {valid_loss:.6f}, Test Loss (MSE): {test_loss:.6f}')
            print(f'Validation RMSE: {sqrt(valid_loss):.6f}, Test RMSE: {sqrt(test_loss):.6f}')
            print(f'Validation MAE: {valid_loss_mae:.6f}, Test MAE: {test_loss_mae:.6f}')

        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= config['early_stopping_patience']:
                if config['print']:
                    print(f"Early stopping at epoch {epoch+1}")
                break 

        model.train()

    if return_model:
        return model

# algo = BayesOptSearch(utility_kwargs={"kind": "ucb", "kappa": 2.5, "xi": 0.0})


# algo = TuneBOHB(metric="mean_loss", mode="min")
# bohb = HyperBandForBOHB(
#     time_attr="training_iteration",
#     # metric="mean_loss",
#     mode="min",
#     max_t=100)



algo = HyperOptSearch(
    points_to_evaluate=[intial_config],
    random_state_seed = 0
)
algo = ConcurrencyLimiter(algo, max_concurrent=10)

asha_scheduler = ASHAScheduler(
    # metric="val_loss",  # The metric to use for early stopping
    # mode="min",         # Whether to minimize or maximize the metric
    max_t=100,          # Maximum number of epochs/iterations per trial (adjust as needed)
    grace_period=10     # Minimum number of epochs/iterations before stopping a trial 
)

tuner = tune.Tuner(
    tune.with_resources(run, {"gpu": 0.1}),
    tune_config=tune.TuneConfig(
        metric="val_loss",
        mode="min",
        search_alg=algo,
        num_samples=100,
        scheduler=asha_scheduler,
    ),
    param_space=search_space,
)

if htune:=True:
    results = tuner.fit()
    best_config = results.get_best_result().config
    print(best_config)
    best_config["print"] = True
    model = run(best_config, return_model=True)
else:
    intial_config["print"] = True
    model = run(intial_config, True)

0,1
Current time:,2024-06-25 12:57:27
Running for:,00:02:22.33
Memory:,22.6/31.1 GiB

Trial name,status,loc,dropout,early_stopping_patie nce,epochs,lr,print,scheduler_factor,scheduler_patience,iter,total time (s),val_loss
run_6be99b29,RUNNING,10.158.10.3:396710,0.186601,10,200,0.00136041,False,0.272982,8,6,138.403,0.756407
run_8d803cf7,RUNNING,10.158.10.3:396782,0.332614,10,200,0.00584849,False,0.214712,5,5,120.158,0.746546
run_4b1ed4f1,RUNNING,10.158.10.3:397018,0.366719,10,200,1.26297e-06,False,0.433124,7,5,123.762,2.05072
run_e48f1279,RUNNING,10.158.10.3:397090,0.139634,10,200,0.0611894,False,0.255693,6,5,120.725,0.782782
run_1dc562ca,RUNNING,10.158.10.3:397160,0.393408,10,200,0.00217462,False,0.119198,5,5,130.042,0.769241
run_2957f8f7,RUNNING,10.158.10.3:397238,0.25975,10,200,0.00160665,False,0.452572,7,4,105.224,0.759946
run_aeaeb00b,RUNNING,10.158.10.3:397310,0.427803,10,200,0.000849773,False,0.258156,6,4,103.075,0.772512
run_d1146878,RUNNING,10.158.10.3:397385,0.144986,10,200,0.00156018,False,0.153462,8,4,108.775,0.749037
run_27dc8b92,RUNNING,10.158.10.3:397462,0.314872,10,200,8.39948e-06,False,0.212626,8,4,109.756,1.07991
run_48fc6c88,RUNNING,10.158.10.3:397539,0.428012,10,200,1.88625e-06,False,0.165193,8,4,110.129,1.91472


  0%|          | 0/716 [00:00<?, ?it/s]
  1%|          | 7/716 [00:00<00:10, 67.98it/s]
  2%|▏         | 16/716 [00:00<00:08, 77.78it/s]
  3%|▎         | 24/716 [00:00<00:08, 78.72it/s]
  5%|▍         | 33/716 [00:00<00:08, 79.05it/s]
  6%|▌         | 41/716 [00:00<00:08, 78.77it/s]
  7%|▋         | 49/716 [00:00<00:08, 78.31it/s]
  8%|▊         | 57/716 [00:00<00:08, 78.06it/s]
  9%|▉         | 65/716 [00:00<00:08, 77.92it/s]
 10%|█         | 73/716 [00:00<00:08, 77.83it/s]
 11%|█▏        | 81/716 [00:01<00:08, 77.93it/s]
 12%|█▏        | 89/716 [00:01<00:08, 77.63it/s]
 14%|█▎        | 97/716 [00:01<00:07, 77.57it/s]
 15%|█▍        | 105/716 [00:01<00:07, 76.89it/s]
 16%|█▌        | 113/716 [00:01<00:07, 76.47it/s]
 17%|█▋        | 121/716 [00:01<00:07, 77.47it/s]
 18%|█▊        | 130/716 [00:01<00:07, 78.92it/s]
 19%|█▉        | 138/716 [00:01<00:07, 78.95it/s]
 21%|██        | 147/716 [00:01<00:07, 79.64it/s]
 22%|██▏       | 156/716 [00:01<00:06, 80.36it/s]
 23%|██▎       | 165/71

{'scheduler_factor': 0.21471150120190907, 'scheduler_patience': 5, 'epochs': 200, 'lr': 0.0058484902526790995, 'early_stopping_patience': 10, 'print': False, 'dropout': 0.33261411259331913}




----Epoch 1----


100%|██████████| 716/716 [00:03<00:00, 191.59it/s]
100%|██████████| 154/154 [00:00<00:00, 547.76it/s]
100%|██████████| 154/154 [00:00<00:00, 564.45it/s]


Epoch 1 - Learning rate: 0.005848
Train Loss (MSE): 0.924686, Validation Loss (MSE): 0.832809, Test Loss (MSE): 0.815772
Validation RMSE: 0.912584, Test RMSE: 0.903201
Validation MAE: 0.832809, Test MAE: 0.815772
----Epoch 2----


100%|██████████| 716/716 [00:03<00:00, 186.24it/s]
100%|██████████| 154/154 [00:00<00:00, 594.49it/s]
100%|██████████| 154/154 [00:00<00:00, 597.23it/s]


Epoch 2 - Learning rate: 0.005848
Train Loss (MSE): 0.830828, Validation Loss (MSE): 0.787819, Test Loss (MSE): 0.771664
Validation RMSE: 0.887591, Test RMSE: 0.878444
Validation MAE: 0.787819, Test MAE: 0.771664
----Epoch 3----


100%|██████████| 716/716 [00:03<00:00, 192.41it/s]
100%|██████████| 154/154 [00:00<00:00, 566.42it/s]
100%|██████████| 154/154 [00:00<00:00, 600.06it/s]


Epoch 3 - Learning rate: 0.005848
Train Loss (MSE): 0.797628, Validation Loss (MSE): 0.764798, Test Loss (MSE): 0.744383
Validation RMSE: 0.874527, Test RMSE: 0.862776
Validation MAE: 0.764798, Test MAE: 0.744383
----Epoch 4----


100%|██████████| 716/716 [00:03<00:00, 192.14it/s]
100%|██████████| 154/154 [00:00<00:00, 587.95it/s]
100%|██████████| 154/154 [00:00<00:00, 584.64it/s]


Epoch 4 - Learning rate: 0.005848
Train Loss (MSE): 0.784890, Validation Loss (MSE): 0.756087, Test Loss (MSE): 0.735102
Validation RMSE: 0.869533, Test RMSE: 0.857381
Validation MAE: 0.756087, Test MAE: 0.735102
----Epoch 5----


100%|██████████| 716/716 [00:03<00:00, 189.57it/s]
100%|██████████| 154/154 [00:00<00:00, 486.63it/s]
100%|██████████| 154/154 [00:00<00:00, 484.35it/s]


Epoch 5 - Learning rate: 0.005848
Train Loss (MSE): 0.770092, Validation Loss (MSE): 0.747416, Test Loss (MSE): 0.728551
Validation RMSE: 0.864532, Test RMSE: 0.853552
Validation MAE: 0.747416, Test MAE: 0.728551
----Epoch 6----


100%|██████████| 716/716 [00:03<00:00, 190.08it/s]
100%|██████████| 154/154 [00:00<00:00, 588.77it/s]
100%|██████████| 154/154 [00:00<00:00, 590.78it/s]


Epoch 6 - Learning rate: 0.005848
Train Loss (MSE): 0.763291, Validation Loss (MSE): 0.747655, Test Loss (MSE): 0.727291
Validation RMSE: 0.864670, Test RMSE: 0.852813
Validation MAE: 0.747655, Test MAE: 0.727291
----Epoch 7----


100%|██████████| 716/716 [00:03<00:00, 185.58it/s]
100%|██████████| 154/154 [00:00<00:00, 589.26it/s]
100%|██████████| 154/154 [00:00<00:00, 588.52it/s]


Epoch 7 - Learning rate: 0.005848
Train Loss (MSE): 0.756729, Validation Loss (MSE): 0.745067, Test Loss (MSE): 0.727742
Validation RMSE: 0.863173, Test RMSE: 0.853078
Validation MAE: 0.745067, Test MAE: 0.727742
----Epoch 8----


100%|██████████| 716/716 [00:03<00:00, 188.32it/s]
100%|██████████| 154/154 [00:00<00:00, 531.98it/s]
100%|██████████| 154/154 [00:00<00:00, 595.50it/s]


Epoch 8 - Learning rate: 0.005848
Train Loss (MSE): 0.750168, Validation Loss (MSE): 0.740908, Test Loss (MSE): 0.724145
Validation RMSE: 0.860760, Test RMSE: 0.850967
Validation MAE: 0.740908, Test MAE: 0.724145
----Epoch 9----


100%|██████████| 716/716 [00:03<00:00, 185.21it/s]
100%|██████████| 154/154 [00:00<00:00, 593.12it/s]
100%|██████████| 154/154 [00:00<00:00, 581.23it/s]


Epoch 9 - Learning rate: 0.005848
Train Loss (MSE): 0.742819, Validation Loss (MSE): 0.730204, Test Loss (MSE): 0.713090
Validation RMSE: 0.854520, Test RMSE: 0.844447
Validation MAE: 0.730204, Test MAE: 0.713090
----Epoch 10----


100%|██████████| 716/716 [00:03<00:00, 191.49it/s]
100%|██████████| 154/154 [00:00<00:00, 596.34it/s]
100%|██████████| 154/154 [00:00<00:00, 591.67it/s]


Epoch 10 - Learning rate: 0.005848
Train Loss (MSE): 0.738606, Validation Loss (MSE): 0.715061, Test Loss (MSE): 0.702610
Validation RMSE: 0.845613, Test RMSE: 0.838218
Validation MAE: 0.715061, Test MAE: 0.702610
----Epoch 11----


100%|██████████| 716/716 [00:03<00:00, 187.33it/s]
100%|██████████| 154/154 [00:00<00:00, 568.63it/s]
100%|██████████| 154/154 [00:00<00:00, 445.45it/s]


Epoch 11 - Learning rate: 0.005848
Train Loss (MSE): 0.733648, Validation Loss (MSE): 0.713654, Test Loss (MSE): 0.699713
Validation RMSE: 0.844781, Test RMSE: 0.836488
Validation MAE: 0.713654, Test MAE: 0.699713
----Epoch 12----


  0%|          | 0/716 [00:00<?, ?it/s]

In [None]:
torch.save(model.state_dict(), '../models/lms/fine_tuned_qm9/good_model.pth')

Now, lets plot a scatter plot for the test dataset and the $R^2$ value.

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score
from scipy.stats import spearmanr


model.eval()  # Set model to eval mode
outputs_list = []

y_true = normalized_targets[test_indices]
with torch.no_grad():
    model.eval()  # Set model to eval mode
    test_loss = 0.0
    test_loss_mae = 0.0
    for batch_idx, (inputs, _targets) in enumerate(tqdm(test_dataloader)): 
        inputs, _targets = inputs.to(device), _targets.to(device)
        outputs = model(inputs)
        # print(outputs[0], outputs[1])
        # Calculate loss
        loss = mse_loss(outputs.squeeze(1), _targets)
        loss_mae = l1_loss(outputs.squeeze(1), _targets)
        test_loss += loss.item() * inputs.size(0)
        test_loss_mae += loss_mae.item() * inputs.size(0) 

    test_loss /= len(test_dataloader.dataset)
    test_loss_mae /= len(test_dataloader.dataset)

    print(f'Test Loss (MSE): {test_loss:.6f}')
    print(f'Test RMSE: {sqrt(test_loss):.6f}')
    print(f'Test MAE: {test_loss_mae:.6f}')


    y_pred = model(torch.stack(molecule_cls.embeddings)[test_indices].to(device)).to(torch.device("cpu")).cpu().numpy()
# print(y_pred.size())

r2 = r2_score(y_true.cpu().numpy(), y_pred)
src, _ = spearmanr(y_true.cpu().numpy(), y_pred)

# Create the plot
plt.figure(figsize=(20, 20)) 
plt.scatter(y_true, y_pred, alpha=0.7)  # Plot the data points

# Add labels and title
plt.xlabel("Actual Dipole Moment (Debye)", fontsize=12)
plt.ylabel("Predicted Dipole Moment (Debye)", fontsize=12)
plt.title("Dipole Moment Prediction - Language Model (CLS Token)", fontsize=14)

# Add R-squared value to the plot
plt.text(0.05, 0.9, f"$R^2$ = {r2:.3f}", fontsize=12, transform=plt.gca().transAxes)
plt.text(0.05, 0.85, f"$r_s$ = {src:.3f}", fontsize=12, transform=plt.gca().transAxes)

# Add a diagonal line for reference (perfect prediction)
plt.plot([min(y_true), max(y_true)], [min(y_true), max(y_true)], linestyle='--', color='red')

plt.grid(True) 
plt.show()


In [None]:
%load_ext tensorboard