In [None]:
import math
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from copy import deepcopy

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from rdkit import Chem
from sklearn.metrics import mean_absolute_error, root_mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA

import torch
from torch import nn
import torch.nn.functional as F

import torch_geometric as tg
from torch_geometric.data import Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MessagePassing, global_add_pool

import optuna

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
# `device` is read as a notebook-level global by the training and prediction functions.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Datasets and Dataloaders

### Feature Extraction

In [None]:
def one_hot_encoding_unk(value, choices: list) -> list:
    """One-hot encode ``value`` against ``choices`` with a trailing "unknown" slot.

    The result has ``len(choices) + 1`` entries. The entry at the position of
    ``value`` in ``choices`` is set to 1; if ``value`` is not among the choices,
    the last entry (the "unknown" slot) is set instead.
    """
    unknown_slot = len(choices)
    try:
        hot_position = choices.index(value)
    except ValueError:
        # Value is not one of the known choices
        hot_position = unknown_slot

    encoding = [0 for _ in range(unknown_slot + 1)]
    encoding[hot_position] = 1
    return encoding


def get_atom_features(atom) -> list:
    """Build the flat feature vector for one RDKit atom.

    The vector concatenates, in this order: one-hot element symbol (17),
    total degree (7), formal charge (6), total hydrogen count (6) and
    hybridization (6), each with a trailing "unknown" slot, followed by an
    aromaticity flag (1) and the atomic mass scaled by 0.01 (1), i.e. 44
    values in total. The mass is always the last entry.
    """
    known_symbols = ['B','Be','Br','C','Cl','F','I','N','Nb','O','P','S','Se','Si','V','W']
    known_hybridizations = [
        Chem.rdchem.HybridizationType.SP,
        Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3,
        Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2
    ]

    feature_groups = (
        one_hot_encoding_unk(atom.GetSymbol(), known_symbols),
        one_hot_encoding_unk(atom.GetTotalDegree(), [0, 1, 2, 3, 4, 5]),
        one_hot_encoding_unk(atom.GetFormalCharge(), [-1, -2, 1, 2, 0]),
        one_hot_encoding_unk(int(atom.GetTotalNumHs()), [0, 1, 2, 3, 4]),
        # The hybridization is compared as an int against the enum members
        # NOTE(review): relies on RDKit enum members comparing equal to their int value
        one_hot_encoding_unk(int(atom.GetHybridization()), known_hybridizations),
        [1 if atom.GetIsAromatic() else 0],
        [atom.GetMass() * 0.01],
    )

    # Flatten the per-property lists into a single list
    return [value for group in feature_groups for value in group]


def get_bond_features(bond) -> list:
    """Build the 7-entry feature list for one RDKit bond.

    Layout: ``[is_none, single, double, triple, aromatic, conjugated, in_ring]``.
    When ``bond`` is None only the first entry is set. Otherwise the first
    entry is 0 and the remaining entries are the truth values of the
    corresponding bond properties.
    """
    bond_fdim = 7

    if bond is None:
        return [1] + [0] * (bond_fdim - 1)

    bt = bond.GetBondType()
    bond_type = Chem.rdchem.BondType
    # The bond exists in this branch, so its properties can be queried directly
    # (the previous `bt is not None` guards could never be False here)
    return [
        0,  # Zeroth index indicates if bond is None
        bt == bond_type.SINGLE,
        bt == bond_type.DOUBLE,
        bt == bond_type.TRIPLE,
        bt == bond_type.AROMATIC,
        bond.GetIsConjugated(),
        bond.IsInRing(),
    ]

### Dataset Construction

In [None]:
class MolGraph:
    """Molecular graph built from a SMILES string.

    Attributes:
        smiles: The SMILES string the graph was built from.
        atom_features: One feature list per atom, in atom-index order.
        bond_features: One feature list per directed edge. Every bond
            contributes two consecutive entries (one per direction).
        edge_index: List of ``(source, target)`` tuples aligned with
            ``bond_features``; for each bond ``(i, j)`` with ``i < j`` it holds
            ``(i, j)`` immediately followed by ``(j, i)``.

    Raises:
        ValueError: If RDKit cannot parse the SMILES string.
    """
    def __init__(self, smiles: str):
        self.smiles = smiles
        self.atom_features = []
        self.bond_features = []
        self.edge_index = []

        molecule = Chem.MolFromSmiles(self.smiles)
        if molecule is None:
            # Fail with a clear message instead of an AttributeError on None
            raise ValueError(f"RDKit could not parse SMILES: {self.smiles!r}")

        self.atom_features = [
            get_atom_features(molecule.GetAtomWithIdx(atom_idx))
            for atom_idx in range(molecule.GetNumAtoms())
        ]

        # Walk the bonds directly instead of testing every atom pair.
        # Sorting by (lower index, higher index) keeps the same edge order
        # that a nested loop over atom pairs produces.
        bonded_pairs = sorted(
            (
                min(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()),
                max(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()),
                bond.GetIdx(),
            )
            for bond in molecule.GetBonds()
        )

        for atom_1, atom_2, bond_idx in bonded_pairs:
            bond_features = get_bond_features(molecule.GetBondWithIdx(bond_idx))
            # Bond features are added twice, once per direction
            self.bond_features.append(bond_features)
            self.bond_features.append(bond_features)
            # Edge index list with tuples of connected nodes instead of an adjacency matrix
            self.edge_index.extend([(atom_1, atom_2), (atom_2, atom_1)])

In [None]:
class ChemDataset(Dataset):
    """PyG dataset that turns SMILES strings into graph ``Data`` objects.

    Each item carries the clean features (``x``, ``edge_attr``), the label
    (``y``), the SMILES string and, when noise is enabled, noisy copies
    (``x_noisy``, ``edge_attr_noisy``) used as input for the denoising task.

    The noise is sampled once per item when the item is first built and then
    cached, so the same noisy features are returned on every access.

    Args:
        smiles: Indexable collection of SMILES strings.
        labels: Indexable collection of target values, aligned with ``smiles``.
        flip_prob: Probability of flipping each binary feature (0 disables).
        noise_std: Relative standard deviation of the Gaussian noise added to
            the mass feature (0 disables).
        precompute: If True, build and cache all items in the constructor.
    """
    def __init__(self, smiles, labels, flip_prob: float=0.5, noise_std: float=0.5, precompute: bool=True):
        # Choose here how much noise to add for the denoising task
        super(ChemDataset, self).__init__()
        self.smiles = smiles
        self.labels = labels
        # Cache is keyed by item index: keying by SMILES would hand the label of
        # the first occurrence to every duplicate SMILES in the dataset
        self.cache = {}
        self.flip_prob = flip_prob
        self.noise_std = noise_std
        self.precompute = precompute

        # Precomputing the dataset so the get method is faster, and the GPU doesn't have to wait for the CPU
        if precompute:
            print(f"Precomputing data...")
            # Only the (deterministic) graph construction runs in the worker threads
            with ThreadPoolExecutor(max_workers=4) as executor:
                molgraphs = list(executor.map(MolGraph, self.smiles))

            # Tensor conversion and noise sampling happen sequentially in index
            # order, so the sampled noise is reproducible for a fixed torch seed
            for idx, molgraph in enumerate(molgraphs):
                self.cache[idx] = self.molgraph2data(molgraph, idx)

            print(f"Precomputation finished. {len(self.cache)} molecules cached.")

    def process_key(self, key):
        """Return the ``Data`` object for index ``key``, building and caching it if needed."""
        if key in self.cache:
            return self.cache[key]

        molecule = self.molgraph2data(MolGraph(self.smiles[key]), key)
        self.cache[key] = molecule
        return molecule

    def molgraph2data(self, molgraph, key):
        """Convert a ``MolGraph`` into a PyG ``Data`` object for index ``key``."""
        data = tg.data.Data()

        # Converting all features and labels to tensors
        # and adding them to the data object
        data.x = torch.tensor(molgraph.atom_features, dtype=torch.float)
        data.edge_index = torch.tensor(molgraph.edge_index, dtype=torch.long).t().contiguous()
        data.edge_attr = torch.tensor(molgraph.bond_features, dtype=torch.float)
        data.y = torch.tensor([self.labels[key]], dtype=torch.float)
        data.smiles = self.smiles[key]

        if self.flip_prob > 0 or self.noise_std > 0:
            # Work on copies so the clean features stay untouched
            x_noisy = data.x.clone()
            edge_attr_noisy = data.edge_attr.clone()

            # Apply bit flipping to binary features if probability > 0
            if self.flip_prob > 0:
                node_bits = x_noisy[:, :-1]  # All but last column, which contains mass
                node_flip = torch.rand_like(node_bits) < self.flip_prob
                x_noisy[:, :-1] = torch.where(node_flip, 1.0 - node_bits, node_bits)  # Flip 0->1 and 1->0

                # Edge features only contain one-hot encodings
                edge_flip = torch.rand_like(edge_attr_noisy) < self.flip_prob
                edge_attr_noisy = torch.where(edge_flip, 1.0 - edge_attr_noisy, edge_attr_noisy)

            # Apply Gaussian noise to continuous feature if std > 0
            if self.noise_std > 0:
                mass_feature = x_noisy[:, -1:]  # Just the last column, which contains mass
                # Adding noise which is a percentage of the mass feature
                x_noisy[:, -1:] = mass_feature + mass_feature * torch.randn_like(mass_feature) * self.noise_std

            data.x_noisy = x_noisy
            data.edge_attr_noisy = edge_attr_noisy

        return data

    def get(self,key):
        """PyG accessor; same as ``process_key``."""
        return self.process_key(key)

    def __getitem__(self, key):
        # Standard get method for PyTorch Dataset
        return self.process_key(key)

    def len(self):
        """PyG length accessor; number of SMILES strings."""
        return len(self.smiles)

    def __len__(self):
        # Standard len method for PyTorch Dataset
        return len(self.smiles)

### Dataloader construction

In [None]:
def construct_loader(data_df: pd.DataFrame, smiles_column: str, target_column: str, shuffle: bool=True, batch_size: int=16):
    """Build a PyTorch Geometric ``DataLoader`` from a DataFrame.

    Args:
        data_df: Non-empty frame holding the SMILES strings and targets.
        smiles_column: Name of the column with the SMILES strings.
        target_column: Name of the column with the target values (cast to float32).
        shuffle: Whether the loader shuffles the samples.
        batch_size: Number of molecules per batch.

    Returns:
        A ``DataLoader`` over a ``ChemDataset`` built with its default noise settings.
    """
    assert len(data_df) > 0, "DataFrame is empty"

    smiles_values = data_df[smiles_column].values
    target_values = data_df[target_column].values.astype(np.float32)

    return DataLoader(
        dataset=ChemDataset(smiles_values, target_values),
        batch_size=batch_size,
        shuffle=shuffle,
        pin_memory=True,
    )

# Model

### Encoder

In [None]:
class DMPNNConv(MessagePassing):
    """Directed message-passing layer operating on edge hidden states.

    Extends PyG's ``MessagePassing`` with sum aggregation. The layer relies on
    the edges being stored as consecutive pairs (edge ``2k`` and edge
    ``2k + 1`` are the two directions of the same bond).
    """
    def __init__(self, hidden_size: int):
        super().__init__(aggr='add')  # Sum aggregation
        self.linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, edge_index, edge_attr):
        """Return ``(node_sums, updated_edge_states)``.

        ``node_sums`` is the aggregated edge state per node as produced by
        ``propagate``; ``updated_edge_states`` is the linear transform of the
        source node's sum minus the state of the reverse edge.
        """
        source_nodes, _ = edge_index
        num_edges = edge_attr.size(0)

        node_sums = self.propagate(edge_index, x=None, edge_attr=edge_attr)

        # Swap the two entries of every (forward, backward) pair so that each
        # row holds the state of the opposite direction of the same bond
        paired_states = edge_attr.view(num_edges // 2, 2, -1)
        reverse_states = torch.flip(paired_states, dims=[1]).view(num_edges, -1)

        updated_edges = self.linear(node_sums[source_nodes] - reverse_states)
        return node_sums, updated_edges

    def message(self, edge_attr):
        # The message of an edge is simply its current hidden state
        return edge_attr

In [None]:
class GNNEncoder(nn.Module):
    """D-MPNN style encoder that maps a batch of molecular graphs to graph embeddings.

    Args:
        num_node_features: Length of the atom feature vector.
        num_edge_features: Length of the bond feature vector.
        hidden_size: Size of the edge/node hidden states and of the embedding.
        mode: ``'denoise'`` to read the noisy features, ``'predict'`` to read the clean ones.
        depth: Number of message-passing layers.
        dropout: Dropout probability applied after every message-passing layer.
    """
    def __init__(self, num_node_features: int, num_edge_features: int, hidden_size: int, mode: str, depth: int, dropout: float):
        super().__init__()
        self.depth = depth
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.mode = mode

        # Encoder layers
        self.edge_init = nn.Linear(num_node_features + num_edge_features, hidden_size)
        self.convs = nn.ModuleList([DMPNNConv(hidden_size) for _ in range(depth)])
        self.edge_to_node = nn.Linear(num_node_features + hidden_size, hidden_size)
        self.pool = global_add_pool  # Not learnable

    def forward(self, data):
        """Return one embedding of size ``hidden_size`` per graph in ``data``.

        Raises:
            ValueError: If ``self.mode`` is neither ``'denoise'`` nor ``'predict'``.
        """
        edge_index, edge_attr, batch = data.edge_index, data.edge_attr, data.batch

        if self.mode == 'denoise':
            x = data.x_noisy
            edge_attr = data.edge_attr_noisy
        elif self.mode == 'predict':
            x = data.x
        else:
            raise ValueError("Invalid mode. Choose 'denoise' or 'predict'.")

        # Edge initialization: every directed edge starts from its source atom and bond features
        row, col = edge_index
        h_0 = F.relu(self.edge_init(torch.cat([x[row], edge_attr], dim=1)))
        h = h_0

        # DMPNN Conv layers with a skip connection to the initial edge states
        for layer in self.convs:
            _, h = layer(edge_index, h)
            h = h + h_0
            h = F.dropout(F.relu(h), self.dropout, training=self.training)

        # Edge to node aggregation: sum the hidden states of all incoming edges per node.
        # The output is allocated with one row per node, so nodes without incoming
        # edges get a zero row. Aggregating through the conv layer instead can
        # presumably return fewer rows than there are nodes when the
        # highest-indexed nodes have no edges, which is why a size-fixing
        # workaround used to be needed here.
        s = h.new_zeros(x.size(0), h.size(1)).index_add_(0, col, h)

        q = torch.cat([x, s], dim=1)
        h = F.relu(self.edge_to_node(q))

        # Global pooling for the final node embeddings
        embedding = self.pool(h, batch)

        return embedding

### Decoder

In [None]:
class GNNDecoder(nn.Module):
    """Decoder for the self-supervised denoising task.

    Reconstructs node and edge features from the graph embedding: every node
    and every edge receives the embedding of the graph it belongs to, which is
    then mapped to feature space by a linear layer.

    Args:
        hidden_size: Size of the graph embedding.
        num_node_features: Length of the reconstructed atom feature vector.
        num_edge_features: Length of the reconstructed bond feature vector.
        dropout: Dropout probability applied to the embeddings before decoding.
    """
    def __init__(self, hidden_size: int, num_node_features: int, num_edge_features: int, dropout: float):
        super().__init__()
        # Node decoding layer
        self.node_lin = nn.Linear(hidden_size, num_node_features)
        # Edge decoding layers
        self.edge_lin = nn.Linear(hidden_size, num_edge_features)
        self.dropout = dropout

    def forward(self, graph_embedding, batch, edge_index):
        """Decode node and edge features.

        Args:
            graph_embedding: Tensor with one embedding row per graph.
            batch: For every node, the index of the graph it belongs to.
            edge_index: Edge list whose first row holds the source nodes.

        Returns:
            ``(x_hat, edge_hat)`` with one row per node and per edge respectively.
        """
        # Expand graph embeddings for nodes: one lookup replaces the per-graph
        # repeat-and-concatenate loop and also stays correct if `batch` is not sorted
        expanded_nodes = graph_embedding[batch]  # shape = (num_nodes, hidden_size)

        # Decode node features
        x_hat = F.dropout(expanded_nodes, p=self.dropout, training=self.training)
        x_hat = self.node_lin(x_hat)

        # Map edges to their source node's graph
        edge_src = edge_index[0]  # Source nodes of edges
        edge_batch = batch[edge_src]  # Batch indices of source nodes (which graph they belong to)

        # Expand graph embeddings for edges
        expanded_edges = graph_embedding[edge_batch]  # shape = (num_edges, hidden_size)

        # Decode edge features
        edge_hat = F.dropout(expanded_edges, p=self.dropout, training=self.training)
        edge_hat = self.edge_lin(edge_hat)

        return x_hat, edge_hat

### Prediction Head

In [None]:
class GNNHead(nn.Module):
    """Prediction head mapping a graph embedding to a single scalar per graph.

    A two-layer feed-forward network: linear, ReLU, dropout, linear.
    """
    def __init__(self, hidden_size: int, dropout: float):
        super().__init__()
        # Feed-forward layers applied to the graph embedding
        self.ffn1 = nn.Linear(hidden_size, hidden_size)
        self.ffn2 = nn.Linear(hidden_size, 1)
        self.dropout = dropout

    def forward(self, graph_embedding):
        """Return the predictions with the trailing singleton dimension removed."""
        hidden = self.ffn1(graph_embedding)
        hidden = F.dropout(F.relu(hidden), p=self.dropout, training=self.training)
        prediction = self.ffn2(hidden)
        return prediction.squeeze(-1)

### Main Model

In [None]:
class GNN(nn.Module):
    """Full model: a shared encoder feeding a denoising decoder and a prediction head.

    The model runs in one of two modes, stored on the encoder:
    ``'denoise'`` (noisy input, reconstruct node and edge features) and
    ``'predict'`` (clean input, predict the target value).
    """
    def __init__(self, num_node_features: int, num_edge_features: int, hidden_size: int=300, depth: int=5, mode: str='denoise', dropout: float=0.1):
        super().__init__()
        self.encoder = GNNEncoder(
            num_node_features,
            num_edge_features,
            hidden_size=hidden_size,
            mode=mode,
            depth=depth,
            dropout=dropout,
        )
        self.head = GNNHead(hidden_size=hidden_size, dropout=dropout)
        self.decoder = GNNDecoder(
            hidden_size=hidden_size,
            num_node_features=num_node_features,
            num_edge_features=num_edge_features,
            dropout=dropout,
        )

    def set_mode(self, mode: str):
        """Switch the encoder between reading noisy (``'denoise'``) and clean (``'predict'``) features."""
        self.encoder.mode = mode

    def get_embedding(self, data):
        """Return the graph embedding produced by the encoder."""
        return self.encoder(data)

    def forward(self, data):
        """Run the encoder and the branch selected by the current mode.

        Returns:
            The predictions in ``'predict'`` mode, or the tuple of decoded
            node and edge features in ``'denoise'`` mode.

        Raises:
            ValueError: If the mode is neither ``'predict'`` nor ``'denoise'``.
        """
        graph_embedding = self.encoder(data)
        mode = self.encoder.mode

        if mode == 'predict':
            return self.head(graph_embedding)

        if mode == 'denoise':
            node_features, edge_features = self.decoder(graph_embedding, data.batch, data.edge_index)
            return node_features, edge_features

        raise ValueError("Invalid mode. Choose 'predict' or 'denoise'.")

    def encode(self, data):
        """Alias of ``get_embedding``."""
        return self.encoder(data)

In [None]:
class Standardizer:
    """Standardizes values with a fixed mean and standard deviation.

    Calling the object maps ``x`` to ``(x - mean) / std``; with ``rev=True``
    it applies the inverse mapping ``x * std + mean``.
    """
    def __init__(self, mean: float, std: float):
        self.mean, self.std = mean, std

    def __call__(self, x, rev: bool=False):
        if not rev:
            # Forward direction: raw value -> standardized value
            return (x - self.mean) / self.std
        # Reverse direction: standardized value -> raw value
        return (x * self.std) + self.mean

In [None]:
def train_epoch(model, loader, optimizer, loss, alpha: float, stdzer: Standardizer=None):
    """Train the model for one epoch on the denoising and prediction tasks simultaneously.

    Args:
        model: Model exposing ``set_mode`` and returning decoded features in
            ``'denoise'`` mode and predictions in ``'predict'`` mode.
        loader: Loader yielding batches with ``x``, ``edge_attr`` and ``y``.
        optimizer: Optimizer over the model parameters.
        loss: Loss callable applied to (output, target).
        alpha: Weight of the denoising loss; the prediction loss is weighted by ``1 - alpha``.
        stdzer: Optional standardizer applied to the labels. If None, the raw labels are used.

    Returns:
        Tuple ``(total, denoise, pred)``: for each weighted loss, the square
        root of its sum over all batches divided by the number of samples in
        the dataset. Uses the notebook-level ``device``.
    """
    model.train()
    total_loss_count = 0
    denoise_loss_count = 0
    pred_loss_count = 0

    # Unfreeze all parts of the model
    for param in model.parameters():
        param.requires_grad = True

    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        # Get losses for the denoising task
        model.set_mode('denoise')
        node_out, edge_out = model(batch)
        node_loss = loss(node_out, batch.x)
        edge_loss = loss(edge_out, batch.edge_attr)
        denoise_loss = alpha * (node_loss + edge_loss)

        # Get losses for the prediction task
        model.set_mode('predict')
        pred_out = model(batch)
        # The standardizer is optional, fall back to the raw labels without it
        targets = batch.y if stdzer is None else stdzer(batch.y)
        pred_loss = (1 - alpha) * loss(pred_out, targets)

        # Combine the weighted losses as a sum and backpropagate them
        combined_loss = denoise_loss + pred_loss
        combined_loss.backward()
        optimizer.step()

        # Cui et al. did a non-weighted sum of self-supervised and supervised losses
        # https://doi.org/10.1038/s41467-025-57101-4

        # Wang et al. did a weighted sum of self-supervised and supervised losses
        # https://doi.org/10.48550/arXiv.2210.08813

        # Also experimented with alternating backpropagation steps for each task
        # Results were worse

        total_loss_count += combined_loss.item()
        denoise_loss_count += denoise_loss.item()
        pred_loss_count += pred_loss.item()

    n_samples = len(loader.dataset)
    return (
        math.sqrt(total_loss_count / n_samples),
        math.sqrt(denoise_loss_count / n_samples),
        math.sqrt(pred_loss_count / n_samples),
    )


def train_epoch_without_SSL(model, loader, optimizer, loss, alpha: float, stdzer: Standardizer=None):
    """Train the model for one epoch on the prediction task only (reference run).

    The encoder and prediction head are trainable, the decoder is frozen.

    Args:
        model: Model exposing ``encoder``, ``decoder``, ``head`` and ``set_mode``.
        loader: Loader yielding batches with a ``y`` attribute.
        optimizer: Optimizer over the model parameters.
        loss: Loss callable applied to (output, target).
        alpha: The prediction loss is weighted by ``1 - alpha``, as in ``train_epoch``.
        stdzer: Optional standardizer applied to the labels. If None, the raw labels are used.

    Returns:
        Square root of the summed weighted prediction loss divided by the
        number of samples in the dataset. Uses the notebook-level ``device``.
    """
    # Unfreeze the encoder
    for param in model.encoder.parameters():
        param.requires_grad = True

    # Freeze the decoder
    for param in model.decoder.parameters():
        param.requires_grad = False

    # Unfreeze the prediction head
    for param in model.head.parameters():
        param.requires_grad = True

    # Train the model for one epoch on the prediction task
    model.train()
    pred_loss_count = 0

    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        # Get losses for the prediction task
        model.set_mode('predict')
        pred_out = model(batch)
        # The standardizer is optional, fall back to the raw labels without it
        targets = batch.y if stdzer is None else stdzer(batch.y)
        pred_loss = (1 - alpha) * loss(pred_out, targets)

        # Only backpropagate the prediction loss
        pred_loss.backward()
        optimizer.step()

        pred_loss_count += pred_loss.item()

    return math.sqrt(pred_loss_count / len(loader.dataset))


def pred(model, loader, mode: str, stdzer: Standardizer=None):
    """Run a plain forward pass over ``loader`` (no test-time adaptation).

    Args:
        model: Model exposing ``set_mode``.
        loader: Loader yielding the batches to evaluate.
        mode: ``'denoise'`` or ``'predict'``.
        stdzer: Optional standardizer; in ``'predict'`` mode its reverse
            transform is applied to the outputs. If None, raw outputs are returned.

    Returns:
        A flat list of floats. In ``'denoise'`` mode every batch contributes
        its flattened node reconstruction followed by its flattened edge
        reconstruction. In ``'predict'`` mode the list holds one value per sample.

    Raises:
        ValueError: If ``mode`` is neither ``'denoise'`` nor ``'predict'``.
    """
    if mode not in ('denoise', 'predict'):
        raise ValueError("Invalid mode. Choose 'denoise' or 'predict'.")

    model.set_mode(mode)
    model.eval()

    preds = []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)

            if mode == 'denoise':
                node_out, edge_out = model(batch)
                preds.extend(node_out.cpu().flatten().tolist())
                preds.extend(edge_out.cpu().flatten().tolist())
            else:
                out = model(batch)
                # Map the outputs back to the original label scale when possible
                if stdzer is not None:
                    out = stdzer(out, rev=True)
                preds.extend(out.cpu().tolist())

    return preds

# Training

In [None]:
# Number of molecules per batch for the training and validation loaders
batch_size = 512

### Loading and preprocessing the Data

In [None]:
torch.manual_seed(0)


def _is_unparsable(smiles) -> bool:
    # True when the entry is not a string or RDKit cannot parse it
    return not isinstance(smiles, str) or Chem.MolFromSmiles(smiles) is None


def _is_single_atom_or_multi_fragment(smiles) -> bool:
    # True for multi-fragment SMILES (contain a '.') and for single-atom molecules
    return '.' in smiles or Chem.MolFromSmiles(smiles).GetNumAtoms() == 1


def _drop_rows(df, mask, reason):
    # Drops the rows selected by the boolean mask and reports their index labels.
    # The labels come from df.index, so this stays correct after earlier drops
    # (positional indices would point at the wrong rows once rows are missing).
    idx = df.index[np.array(mask, dtype=bool)].tolist()
    if len(idx) > 0:
        print(f"Removing {idx} due to {reason}")
    return df.drop(index=idx), idx


data_df = pd.read_csv("data/AqSolDBc.csv")
# Drop entries RDKit cannot parse
data_df, idx_nonetype = _drop_rows(data_df, [_is_unparsable(s) for s in data_df['SmilesCurated']], "Nonetypes")
# Drop single atoms and multi-fragment entries
data_df, idx_single = _drop_rows(data_df, [_is_single_atom_or_multi_fragment(s) for s in data_df['SmilesCurated']], "single atoms")

test_df = pd.read_csv("data/OChemUnseen.csv")
# Drop some Nonetypes (got an error for a SMILES which was None)
test_df, idx_nonetype = _drop_rows(test_df, [_is_unparsable(s) for s in test_df['SMILES']], "Nonetypes")
# Drop single atoms and multi-fragment entries
test_df, idx_single = _drop_rows(test_df, [_is_single_atom_or_multi_fragment(s) for s in test_df['SMILES']], "single atoms")

In [None]:
# Hold out 10% of the training data for validation (fixed seed for a reproducible split)
train_df, val_df = train_test_split(data_df, test_size=0.1, random_state=0)
train_loader = construct_loader(train_df, 'SmilesCurated', 'ExperimentalLogS', shuffle=True, batch_size=batch_size)
val_loader = construct_loader(val_df, 'SmilesCurated', 'ExperimentalLogS', shuffle=False, batch_size=batch_size)
# The test loader uses a batch size of 1
test_loader = construct_loader(test_df, 'SMILES', 'LogS', shuffle=False, batch_size=1)
print(f"Train size: {len(train_loader.dataset)}, Val size: {len(val_loader.dataset)}, Test size: {len(test_loader.dataset)}")

# Standardizer for the solubility labels, fitted on the training labels only
mean = np.mean(train_loader.dataset.labels)
std = np.std(train_loader.dataset.labels)
stdzer = Standardizer(mean, std)

### Hyperparameter optimization

In [None]:
hyperparam_opt = True # Optional: set to False to skip the Optuna studies and use the stored hyperparameters

In [None]:
# Optional hyperparameter optimization on the validation set

def objective(trial):
    """Optuna objective: train for a few epochs and return the validation RMSE.

    The validation RMSE of the prediction task is reported after every epoch
    so that the pruner configured on the study can stop unpromising trials.
    """
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True)
    alpha = trial.suggest_float("alpha", 1e-3, 1.0, log=True)
    hidden_size = trial.suggest_int("hidden_size", 64, 512, step=64)
    depth = trial.suggest_int("depth", 1, 7, step=1)
    dropout = trial.suggest_float("dropout", 0.0, 0.8, step=0.1)

    model = GNN(train_loader.dataset.num_node_features, train_loader.dataset.num_edge_features, hidden_size=hidden_size, depth=depth, dropout=dropout).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    loss = nn.MSELoss(reduction='sum')

    for epoch in range(5):
        train_epoch(model, train_loader, optimizer, loss, alpha, stdzer)

        preds = pred(model, val_loader, mode='predict', stdzer=stdzer)
        pred_val_loss = root_mean_squared_error(preds, val_loader.dataset.labels)

        # Without intermediate reports the pruner never has anything to act on
        trial.report(pred_val_loss, step=epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()

    return pred_val_loss

In [None]:
# Either search the training hyperparameters with Optuna or use stored values
if hyperparam_opt:    
    study = optuna.create_study(direction="minimize", pruner=optuna.pruners.HyperbandPruner(), study_name="hyperparam_opt")
    # 50 trials, run with 4 parallel jobs
    study.optimize(objective, n_trials=50, n_jobs=4)

    print(f"Best learning rate: {study.best_params['learning_rate']:.4g}")
    print(f"Best alpha: {study.best_params['alpha']:.4g}")
    print(f"Best hidden size: {study.best_params['hidden_size']:.4g}")
    print(f"Best depth: {study.best_params['depth']:.4g}")
    print(f"Best dropout: {study.best_params['dropout']:.4g}")

    # Keep a record of all trials next to the notebook
    study_results = study.trials_dataframe()
    study_results.to_csv("hyperparam_opt.csv", index=False)
    print("Study results saved to hyperparam_opt.csv")

    # Use the best trial's values for the actual training below
    learning_rate = study.best_params['learning_rate']
    alpha = study.best_params['alpha']
    hidden_size = study.best_params['hidden_size']
    depth = study.best_params['depth']
    dropout = study.best_params['dropout']

# Standard hyperparameters are the result of a previous hyperparameter optimization
else:
    learning_rate = 0.0064
    alpha = 0.002 # Weighting the losses of the two tasks
    hidden_size = 384
    depth = 3
    dropout = 0.1

# Number of training epochs; set to 0 further below when a trained model is loaded
epochs = 10

# If wanted, we can load an already trained model
save_model = False  # Set to True to write the best model to trained_models/
load_trained_model = False  # Set to True to skip training and load existing model
save_plots = False  # Set to True to write the loss plot to figures/

### Training on both Tasks

In [None]:
# Model trained on both tasks; the feature sizes are taken from the training dataset
model = GNN(train_loader.dataset.num_node_features, train_loader.dataset.num_edge_features, hidden_size=hidden_size, depth=depth, dropout=dropout).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Summed (not averaged) squared error; also read as a global by pred_with_TTA
loss = nn.MSELoss(reduction='sum')
print('\n', model, '\n')

In [None]:
# If wanted, we can load an already trained model

# Path for the best model
model_path = os.path.join("trained_models", "model.pt")
if load_trained_model:
    if os.path.exists(model_path):
        # Load pretrained model; map_location lets a model saved on the GPU load on a CPU-only machine
        print(f"Loading model from {model_path}")
        model.load_state_dict(torch.load(model_path, map_location=device))
        best_model = deepcopy(model).to(device)
        # Skip training
        epochs = 0
    else:
        # Without this fallback neither the loading nor the training cell would
        # define best_model, and the prediction cells would fail with a NameError
        print(f"No saved model found at {model_path}, training from scratch instead.")
        load_trained_model = False

In [None]:
# Training the model on both tasks simultaneously
if not load_trained_model:
    best_model = deepcopy(model).to(device)
    best_pred_val_loss = 1e5

    total_train_loss_list = []
    denoise_train_loss_list = []
    pred_train_loss_list = []

    total_val_loss_list = []
    denoise_val_loss_list = []
    pred_val_loss_list = []

    if epochs == 0:
        print("Skipping training, previously trained model was loaded.")

    for epoch in range(0, epochs):
        total_train_loss, denoise_train_loss, pred_train_loss = train_epoch(model, train_loader, optimizer, loss, alpha=alpha, stdzer=stdzer)

        denoised = pred(model, val_loader, mode='denoise')
        # pred() returns, per batch, the node reconstruction followed by the edge
        # reconstruction. The targets must be interleaved the same way, otherwise
        # predictions and targets are misaligned as soon as there is more than one batch.
        denoise_targets = []
        for batch in val_loader:
            denoise_targets.extend(batch.x.cpu().flatten().tolist())
            denoise_targets.extend(batch.edge_attr.cpu().flatten().tolist())
        denoise_val_loss = root_mean_squared_error(denoised, denoise_targets)

        preds = pred(model, val_loader, mode='predict', stdzer=stdzer)
        pred_val_loss = root_mean_squared_error(preds, val_loader.dataset.labels)
        total_val_loss = alpha * denoise_val_loss + (1 - alpha) * pred_val_loss

        print(f"Epoch {epoch}  Train Total Loss: {total_train_loss:.2f}  Train Denoise Loss: {denoise_train_loss:.2f}  Train Pred Loss: {pred_train_loss:.2f}  Val Total Loss: {total_val_loss:.2f}  Val Denoise Loss: {denoise_val_loss:.2f}  Val Pred Loss: {pred_val_loss:.2f}")

        total_train_loss_list.append(total_train_loss)
        denoise_train_loss_list.append(denoise_train_loss)
        pred_train_loss_list.append(pred_train_loss)
        denoise_val_loss_list.append(denoise_val_loss)
        pred_val_loss_list.append(pred_val_loss)
        total_val_loss_list.append(total_val_loss)

        # Model selection is based on the validation loss of the prediction task only
        if pred_val_loss < best_pred_val_loss:
            best_model = deepcopy(model).to(device)
            best_pred_val_loss = pred_val_loss

            # Save the best model
            if save_model:
                print(f"Saving best model based on Val Pred Loss...")
                # Make sure the target directory exists before writing
                os.makedirs("trained_models", exist_ok=True)
                torch.save(best_model.state_dict(), os.path.join("trained_models", f"model.pt"))

In [None]:
# Only plot if actual training was done
if epochs != 0:

    fig, ax = plt.subplots(1, 1, figsize=(7, 4))


    # Training curves are solid, validation curves dashed; one color per loss type
    ax.plot(list(range(epochs)), total_train_loss_list, label='Total Train Loss', color='blue')
    ax.plot(list(range(epochs)), denoise_train_loss_list, label='Denoise Train Loss', color='orange')
    ax.plot(list(range(epochs)), pred_train_loss_list, label='Pred Train Loss', color='green')
    ax.plot(list(range(epochs)), total_val_loss_list, label='Total Val Loss', color='blue', linestyle='dashed')
    ax.plot(list(range(epochs)), denoise_val_loss_list, label='Denoise Val Loss', color='orange', linestyle='dashed')
    ax.plot(list(range(epochs)), pred_val_loss_list, label='Pred Val Loss', color='green', linestyle='dashed')
    ax.set_title('Training and Validation Losses for both Tasks')
    ax.legend(ncol=2)
    ax.set_xlabel('Epochs')
    # One tick every five epochs
    ax.set_xticks(list(range(0, epochs, 5)))
    ax.set_ylabel('Loss')

    plt.tight_layout()
    if save_plots:   
        fig.savefig("figures/loss_plot.jpg", dpi=300)
    plt.show()

### Training without SSL for reference

In [None]:
# Reference model with the same architecture, trained on the prediction task only
model_non_SSL = GNN(train_loader.dataset.num_node_features, train_loader.dataset.num_edge_features, hidden_size=hidden_size, depth=depth, dropout=dropout).to(device)
loss_non_SSL = nn.MSELoss(reduction='sum')
# Print the reference model itself (not the SSL model)
print('\n', model_non_SSL, '\n')

In [None]:
# If wanted, we can load an already trained model

# Path for the best model
model_path = os.path.join("trained_models", "model_non_SSL.pt")
if load_trained_model and os.path.exists(model_path):
    # Load pretrained model; map_location lets a model saved on the GPU load on a CPU-only machine
    print(f"Loading model from {model_path}")
    model_non_SSL.load_state_dict(torch.load(model_path, map_location=device))
    best_model_non_SSL = deepcopy(model_non_SSL).to(device)
    # Skip training
    epochs = 0
elif load_trained_model:
    # Make the missing checkpoint visible instead of silently continuing
    print(f"No saved model found at {model_path}, the non-SSL model was not loaded.")

In [None]:
# Different optimizers for different tasks: this one must hold the parameters
# of the reference model, not those of the SSL model
optimizer = torch.optim.Adam(model_non_SSL.parameters(), lr=learning_rate)

best_model_non_SSL = deepcopy(model_non_SSL).to(device)
best_val_loss = 1e5

train_loss_list = []
val_loss_list = []

if epochs == 0:
    print("Skipping training, previously trained model was loaded.")

for epoch in range(0, epochs):
    # Train and validate the reference model (training `model` here would keep
    # updating the SSL model and leave the reference model untrained)
    train_loss = train_epoch_without_SSL(model_non_SSL, train_loader, optimizer, loss_non_SSL, alpha=alpha, stdzer=stdzer)

    preds = pred(model_non_SSL, val_loader, mode='predict', stdzer=stdzer)
    val_loss = root_mean_squared_error(preds, val_loader.dataset.labels)

    print(f"Epoch {epoch}  Train Loss: {train_loss:.2f}  Val Loss: {val_loss:.2f}")

    train_loss_list.append(train_loss)
    val_loss_list.append(val_loss)

    if val_loss < best_val_loss:
        best_model_non_SSL = deepcopy(model_non_SSL).to(device)
        best_val_loss = val_loss

        # Save the best model
        if save_model:
            print(f"Saving best model...")
            # Make sure the target directory exists before writing
            os.makedirs("trained_models", exist_ok=True)
            torch.save(best_model_non_SSL.state_dict(), os.path.join("trained_models", f"model_non_SSL.pt"))

# Prediction

In [None]:
def pred_with_TTA(model, loader, alpha: float, lr: float, stdzer:Standardizer=None):
    """Predict with test-time adaptation (TTA).

    For every batch the model starts from the given weights, takes one
    optimizer step on the denoising loss of that batch (prediction head
    frozen) and then predicts on the same batch. The adaptation is discarded
    before the next batch. A batch size of 1 is intended.

    The passed model is not modified; all work happens on a copy. The loss
    callable is read from the notebook-level ``loss`` and the device from the
    notebook-level ``device``.

    Args:
        model: Trained model exposing ``encoder``, ``decoder``, ``head`` and ``set_mode``.
        loader: Loader yielding the batches to predict.
        alpha: Weight of the denoising loss, as in the training step.
        lr: Learning rate of the adaptation step.
        stdzer: Optional standardizer; its reverse transform is applied to the
            outputs. If None, raw outputs are returned.

    Returns:
        List with one prediction per sample.
    """
    # Work on a private copy so the caller's model keeps its weights and flags
    model = deepcopy(model).to(device)

    # Unfreeze the encoder
    for param in model.encoder.parameters():
        param.requires_grad = True

    # Unfreeze the decoder
    for param in model.decoder.parameters():
        param.requires_grad = True

    # Freeze the prediction head
    for param in model.head.parameters():
        param.requires_grad = False

    initial_state = deepcopy(model.state_dict())
    trainable_params = [param for param in model.parameters() if param.requires_grad]

    preds = []

    for batch in loader:
        batch = batch.to(device)

        # Start every batch from the unadapted weights and a fresh optimizer.
        # Replacing the model object instead would leave the optimizer pointing
        # at the old parameters, so no batch after the first would be adapted.
        model.load_state_dict(initial_state)
        optimizer = torch.optim.Adam(trainable_params, lr=lr)
        optimizer.zero_grad()

        model.set_mode('denoise')
        model.train()

        node_out, edge_out = model(batch)
        node_loss = loss(node_out, batch.x)
        edge_loss = loss(edge_out, batch.edge_attr)
        # Losses get the same weighting as in the training step
        denoise_loss = alpha * (node_loss + edge_loss)
        denoise_loss.backward()
        optimizer.step()

        model.set_mode('predict')
        model.eval()

        with torch.no_grad():
            out = model(batch)
            if stdzer is not None:
                out = stdzer(out, rev=True)
            preds.extend(out.cpu().tolist())

    return preds

In [None]:
# Prediction with non-SSL model, using the checkpoint with the best validation loss
# (consistent with the SSL model, which is also evaluated through its best checkpoint)
preds_non_SSL = pred(best_model_non_SSL, test_loader, mode='predict', stdzer=stdzer)
print('Reference results on test set for model trained without SSL:')
print(f"RMSE: {root_mean_squared_error(preds_non_SSL, test_loader.dataset.labels):.2f}")

In [None]:
# Prediction with the SSL model without TTA, using the checkpoint with the best validation loss
preds = pred(best_model, test_loader, mode='predict', stdzer=stdzer)
print('Results on test set without TTA:')  
print(f"RMSE: {root_mean_squared_error(preds, test_loader.dataset.labels):.4g}")

In [None]:
# Optional hyperparameter optimization on the validation set for the TTA step

def objective(trial):
    # Optuna objective for the TTA step
    # Samples a TTA learning rate and a TTA loss weight (both on a log scale) and
    # returns the RMSE that TTA predictions with these values reach on the validation set
    lr_candidate = trial.suggest_float("learning_rate_TTA", 1e-5, 1.0, log=True)
    alpha_candidate = trial.suggest_float("alpha_TTA", 1e-4, 1.0, log=True)

    val_preds = pred_with_TTA(
        best_model,
        val_loader,
        alpha=alpha_candidate,
        lr=lr_candidate,
        stdzer=stdzer,
    )
    val_rmse = root_mean_squared_error(val_loader.dataset.labels, val_preds)

    return val_rmse

In [None]:
if hyperparam_opt:
    # No pruner is configured: objective() only returns one final value per trial
    # and never reports intermediate values, so a pruner could not stop any trial
    study = optuna.create_study(direction="minimize", study_name="TTA_hyperparam_opt")
    study.optimize(objective, n_trials=100, n_jobs=6)

    # Use the best trial of the study for the TTA step
    learning_rate_TTA = study.best_params['learning_rate_TTA']
    alpha_TTA = study.best_params['alpha_TTA']

    print(f"Best TTA learning rate: {learning_rate_TTA:.4g}")
    print(f"Best TTA alpha: {alpha_TTA:.4g}")

    # Keep a record of all trials
    study_results = study.trials_dataframe()
    study_results.to_csv("TTA_hyperparam_opt.csv", index=False)
    print("Study results saved to TTA_hyperparam_opt.csv")

else:
    # Standard hyperparameters are the result of a previous hyperparameter optimization
    learning_rate_TTA = 0.00175
    alpha_TTA = 0.0958

In [None]:
# Test set predictions of the SSL model, with the TTA step
# Takes about three times as long as prediction without TTA
preds_TTA = pred_with_TTA(
    best_model,
    test_loader,
    alpha=alpha_TTA,
    lr=learning_rate_TTA,
    stdzer=stdzer,
)
rmse_TTA = root_mean_squared_error(test_loader.dataset.labels, preds_TTA)
print('Results on test set with TTA:')
print(f"RMSE: {rmse_TTA:.4g}")

# Analysis

### Getting the Embeddings

In [None]:
def embeddings_with_TTA(model, loader, alpha: float, lr: float):
    # Get embeddings with test-time adaptation (TTA)
    # For every batch, a fresh copy of the model takes a single Adam step on the
    # denoising loss of that batch and then embeds the same batch
    # We want a batch size of 1 for this, so that each sample is adapted on its own
    #
    # model:  model with encoder, decoder and head submodules; it is not modified
    # loader: yields the batches to adapt on and to embed
    # alpha:  weight of the denoising loss (node loss + edge loss)
    # lr:     learning rate of the adaptation step
    # Returns a list with the embeddings (NumPy arrays) of all batches, in loader order

    # Work on a private copy so that the requires_grad changes below do not leak
    # into the caller's model
    pristine = deepcopy(model).to(device)

    # Unfreeze the encoder
    for param in pristine.encoder.parameters():
        param.requires_grad = True

    # Unfreeze the decoder
    for param in pristine.decoder.parameters():
        param.requires_grad = True

    # Freeze the prediction head
    for param in pristine.head.parameters():
        param.requires_grad = False

    embeddings = []

    for batch in loader:
        batch = batch.to(device)

        # Every batch starts from the unadapted weights. The optimizer is rebuilt
        # together with the copy: it has to hold the parameters of the copy that is
        # actually used, and its Adam statistics must not carry over between batches
        adapted = deepcopy(pristine)
        optimizer = torch.optim.Adam(adapted.parameters(), lr=lr)
        optimizer.zero_grad()

        # Adaptation step on the denoising task
        adapted.set_mode('denoise')
        adapted.train()

        node_out, edge_out = adapted(batch)
        node_loss = loss(node_out, batch.x)
        edge_loss = loss(edge_out, batch.edge_attr)
        # Losses get the same weighting as in the training step
        denoise_loss = alpha * (node_loss + edge_loss)
        denoise_loss.backward()
        optimizer.step()

        # Embedding with the adapted copy
        adapted.set_mode('predict')
        adapted.eval()

        with torch.no_grad():
            embedding = adapted.get_embedding(batch)
            embeddings.extend(embedding.cpu().detach().numpy())

    return embeddings

In [None]:
best_model = best_model.to(device)
# Put the model into the same state that embeddings_with_TTA uses right before it
# calls get_embedding, so that the embeddings with and without TTA are comparable
best_model.set_mode('predict')
best_model.eval()


def _collect_embeddings(model, loader) -> list:
    # Returns the embeddings of all batches of the loader as a list of NumPy vectors
    # The list has length "samples", each vector has length "hidden_size"
    collected = []
    # Pure inference, so no autograd graph has to be built
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            collected.extend(model.get_embedding(batch).cpu().numpy())
    return collected


# Getting embeddings for analysis using PCA
# NOTE(review): the plots below pair these embeddings with loader.dataset.labels, which
# assumes that the loaders iterate in dataset order (no shuffling) - confirm for train_loader
train_embeddings = _collect_embeddings(best_model, train_loader)
val_embeddings = _collect_embeddings(best_model, val_loader)
test_embeddings = _collect_embeddings(best_model, test_loader)

# Create and fit PCA on the training set, then project the other sets with it
pca = PCA(n_components=2)
train_embeddings_2d = pca.fit_transform(np.array(train_embeddings))
val_embeddings_2d = pca.transform(np.array(val_embeddings))
test_embeddings_2d = pca.transform(np.array(test_embeddings))
print(f"Explained variance ratio: {pca.explained_variance_ratio_}")

# Also getting the test set embeddings after TTA, projected with the same PCA
test_embeddings_with_TTA = embeddings_with_TTA(best_model, test_loader, alpha=alpha_TTA, lr=learning_rate_TTA)
test_embeddings_with_TTA_2d = pca.transform(np.array(test_embeddings_with_TTA))

### Analyzing the Solubility Distribution

In [None]:
# Solubility labels of each set, used to color the points
train_solubility = train_loader.dataset.labels
val_solubility = val_loader.dataset.labels
test_solubility = test_loader.dataset.labels

# All three sets are drawn into one figure
fig, ax = plt.subplots(figsize=(7, 7))

# One color scale for all sets, spanning the overall range of the labels
cmap = plt.cm.viridis
all_solubility = (train_solubility, val_solubility, test_solubility)
norm = plt.Normalize(min(map(min, all_solubility)), max(map(max, all_solubility)))

# The sets share the same point style and only differ in their data
point_style = dict(cmap=cmap, norm=norm, alpha=0.5, s=3)
sc1 = ax.scatter(train_embeddings_2d[:, 0], train_embeddings_2d[:, 1],
                 c=train_solubility, **point_style)
sc2 = ax.scatter(val_embeddings_2d[:, 0], val_embeddings_2d[:, 1],
                 c=val_solubility, **point_style)
sc3 = ax.scatter(test_embeddings_2d[:, 0], test_embeddings_2d[:, 1],
                 c=test_solubility, **point_style)

# Colorbar, title and axis labels
cbar = fig.colorbar(sc1, ax=ax, aspect=50)
cbar.set_label('Solubility')
ax.set(
    title='PCA Projection of Embedding Vectors with Solubility',
    xlabel='Principal Component 1',
    ylabel='Principal Component 2',
)

# Axis limits that focus on the main data distribution:
# the 5th and 95th percentiles per component are insensitive to outliers
(x_min_train, y_min_train), (x_max_train, y_max_train) = np.percentile(train_embeddings_2d, [5, 95], axis=0)
(x_min_test, y_min_test), (x_max_test, y_max_test) = np.percentile(test_embeddings_2d, [5, 95], axis=0)
x_min = min(x_min_train, x_min_test)
x_max = max(x_max_train, x_max_test)
y_min = min(y_min_train, y_min_test)
y_max = max(y_max_train, y_max_test)
# Widen the limits by a small margin on each side
margin_x = 0.05 * (x_max - x_min)
margin_y = 0.05 * (y_max - y_min)

ax.set_xlim(x_min - margin_x, x_max + margin_x)
ax.set_ylim(y_min - margin_y, y_max + margin_y)

plt.tight_layout()
if save_plots:
    fig.savefig("figures/pca_solubility.jpg", dpi=300)
plt.show()

### Analyzing the Class Distribution and Effect of TTA

In [None]:
# Plot the 2D embeddings of all sets into one figure
fig, ax = plt.subplots(figsize=(7, 7))

ax.scatter(train_embeddings_2d[:, 0], train_embeddings_2d[:, 1], alpha=0.5, label='Train Set', s=3)
ax.scatter(val_embeddings_2d[:, 0], val_embeddings_2d[:, 1], alpha=0.5, label='Validation Set', s=3)
ax.scatter(test_embeddings_2d[:, 0], test_embeddings_2d[:, 1], alpha=0.5, label='Test Set', s=3)
ax.scatter(test_embeddings_with_TTA_2d[:, 0], test_embeddings_with_TTA_2d[:, 1], alpha=0.5, label='Test Set with TTA', s=3)

# Calculate centroids with the median instead of the mean since we have outliers
train_centroid = np.median(train_embeddings_2d, axis=0)
val_centroid = np.median(val_embeddings_2d, axis=0)
test_centroid = np.median(test_embeddings_2d, axis=0)
test_tta_centroid = np.median(test_embeddings_with_TTA_2d, axis=0)

# Plot centroids with larger markers
# '+' and 'x' are unfilled markers: Matplotlib draws them in the color given by c and
# ignores edgecolors (with a warning), so no edge color is passed
centroid_style = dict(s=100, linewidths=2, alpha=0.9)
ax.scatter(train_centroid[0], train_centroid[1],
           c='blue', marker='+', label='Train Centroid', **centroid_style)
ax.scatter(val_centroid[0], val_centroid[1],
           c='orange', marker='x', label='Val Centroid', **centroid_style)
ax.scatter(test_centroid[0], test_centroid[1],
           c='green', marker='+', label='Test Centroid', **centroid_style)
ax.scatter(test_tta_centroid[0], test_tta_centroid[1],
           c='red', marker='x', label='Test Centroid with TTA', **centroid_style)

ax.legend()
ax.set_title('PCA Projection of Embedding with Sets')
ax.set_xlabel('Principal Component 1')
ax.set_ylabel('Principal Component 2')
# Same axis limits as in the solubility plot
ax.set_xlim(x_min - margin_x, x_max + margin_x)
ax.set_ylim(y_min - margin_y, y_max + margin_y)

plt.tight_layout()
if save_plots:
    fig.savefig("figures/pca_sets.jpg", dpi=300)
plt.show()

In [None]:
# Plot the 2D test set embeddings with and without TTA into one figure
fig, ax = plt.subplots(figsize=(7, 7))
ax.scatter(test_embeddings_2d[:, 0], test_embeddings_2d[:, 1], alpha=0.5, label='Test Set without TTA', s=3)
ax.scatter(test_embeddings_with_TTA_2d[:, 0], test_embeddings_with_TTA_2d[:, 1], alpha=0.5, label='Test Set with TTA', s=3)

# Plot centroids with larger markers
# '+' and 'x' are unfilled markers: Matplotlib draws them in the color given by c and
# ignores edgecolors (with a warning), so no edge color is passed
centroid_style = dict(s=100, linewidths=2, alpha=0.9)
ax.scatter(train_centroid[0], train_centroid[1],
           c='blue', marker='+', label='Train Centroid', **centroid_style)
ax.scatter(val_centroid[0], val_centroid[1],
           c='orange', marker='x', label='Val Centroid', **centroid_style)
ax.scatter(test_centroid[0], test_centroid[1],
           c='green', marker='+', label='Test Centroid', **centroid_style)
ax.scatter(test_tta_centroid[0], test_tta_centroid[1],
           c='red', marker='x', label='Test Centroid with TTA', **centroid_style)

ax.legend()
ax.set_title('PCA Projection of Embedding with TTA vs No TTA')
ax.set_xlabel('Principal Component 1')
ax.set_ylabel('Principal Component 2')
# Same axis limits as in the solubility plot
ax.set_xlim(x_min - margin_x, x_max + margin_x)
ax.set_ylim(y_min - margin_y, y_max + margin_y)

plt.tight_layout()
# Only write the figure to disk when saving is enabled, like the other PCA figures
if save_plots:
    fig.savefig("figures/pca_sets_TTA_vs_no_TTA.jpg", dpi=300)
plt.show()