In [1]:
# Standard library
from collections import defaultdict
from typing import List, Tuple
import importlib.resources as pkg_resources
from multiprocessing.pool import ThreadPool

# Third-party scientific stack
import numpy as np
import pandas as pd
from scipy.sparse import coo_matrix
from scipy.sparse.linalg import eigsh

# RDKit
from rdkit import Chem

# PyTorch core
import torch
import torch.nn as nn
import torch.nn.functional as F

# PyTorch Geometric
from torch_geometric.data import Data, Batch, DataLoader
from torch_geometric.nn import GCNConv, GINConv, BatchNorm, global_mean_pool
from torch_geometric.loader import DataLoader as PyGDataLoader

# Polyatomic complexes
from polyatomic_complexes.src.complexes.abstract_complex import AbstractComplex
from polyatomic_complexes.src.complexes import PolyatomicGeometrySMILE

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def set_seed(seed):
    """
    Fix all random seeds for reproducibility across Python, NumPy, and PyTorch.
    """
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(59)

In [3]:
def load_lipophil_data():
  data_path = pkg_resources.files('polyatomic_complexes.dataset.lipophilicity') / 'Lipophilicity.csv'
  df = pd.read_csv(data_path)
  return df

In [4]:
data = load_lipophil_data()
data.dropna(inplace=True)
data.reset_index(drop=True, inplace=True)
data

Unnamed: 0,CMPD_CHEMBLID,exp,smiles
0,CHEMBL596271,3.54,Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14
1,CHEMBL1951080,-1.18,COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)...
2,CHEMBL1771,3.69,COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl
3,CHEMBL234951,3.37,OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C...
4,CHEMBL565079,3.10,Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N...
...,...,...,...
4194,CHEMBL496929,3.85,OCCc1ccc(NC(=O)c2cc3cc(Cl)ccc3[nH]2)cc1
4195,CHEMBL199147,3.21,CCN(C1CCN(CCC(c2ccc(F)cc2)c3ccc(F)cc3)CC1)C(=O...
4196,CHEMBL15932,2.10,COc1cccc2[nH]ncc12
4197,CHEMBL558748,2.65,Clc1ccc2ncccc2c1C(=O)NCC3CCCCC3


In [5]:
data

Unnamed: 0,CMPD_CHEMBLID,exp,smiles
0,CHEMBL596271,3.54,Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14
1,CHEMBL1951080,-1.18,COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)...
2,CHEMBL1771,3.69,COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl
3,CHEMBL234951,3.37,OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C...
4,CHEMBL565079,3.10,Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N...
...,...,...,...
4194,CHEMBL496929,3.85,OCCc1ccc(NC(=O)c2cc3cc(Cl)ccc3[nH]2)cc1
4195,CHEMBL199147,3.21,CCN(C1CCN(CCC(c2ccc(F)cc2)c3ccc(F)cc3)CC1)C(=O...
4196,CHEMBL15932,2.10,COc1cccc2[nH]ncc12
4197,CHEMBL558748,2.65,Clc1ccc2ncccc2c1C(=O)NCC3CCCCC3


In [6]:
data.columns

Index(['CMPD_CHEMBLID', 'exp', 'smiles'], dtype='object')

In [7]:
def build_graph_from_smiles(smile: str, descriptors: List[float], topk_lap: int = 5):
    """
    Builds a PyG Data object with:
      - Node features: chain_0 values (shape [n,1]) + broadcasted descriptors
      - Edge structure and features from get_bonds()
      - Graph-level features: mean/std of chain_1..chain_k + topk eigenvalues of molecule laplacians
    """
    try:
        # 1) Abstract complex
        pg = PolyatomicGeometrySMILE(smile=smile, mode="abstract")
        ac = pg.smiles_to_geom_complex()
        assert isinstance(ac, AbstractComplex)

        # 2) Raw chains
        chains = ac.get_raw_k_chains()
        # chain_0: node-level
        chain0 = chains.get('chain_0')
        if chain0 is None:
            print(f"⚠️ Missing chain_0 for {smile}")
            return None
        n = len(chain0)
        x_node = torch.tensor(chain0, dtype=torch.float32).unsqueeze(1)  # [n,1]
        # graph-level chain stats
        g_stats = []
        for k, arr in chains.items():
            if k == 'chain_0': continue
            a = np.array(arr, dtype=np.float32)
            g_stats.extend([a.mean(), a.std()])
        g_stats = np.array(g_stats, dtype=np.float32)

        # 3) Laplacian eigenvalues (flatten all dims)
        laps = ac.get_laplacians().get('molecule_laplacians', [])
        lap_tuples = []
        for group in laps:
            if isinstance(group, list):
                for item in group:
                    if isinstance(item, tuple) and len(item)==2:
                        lap_tuples.append(item)
            elif isinstance(group, tuple) and len(group)==2:
                lap_tuples.append(group)
        lap_feats = []
        for dim, mat in lap_tuples:
            # convert to sparse if needed
            M = coo_matrix(mat) if not isinstance(mat, coo_matrix) else mat
            k = min(topk_lap, M.shape[0]-1)
            if k <= 0:
                vals = np.zeros(topk_lap, dtype=np.float32)
            else:
                try:
                    vals, _ = eigsh(M, k=k, return_eigenvectors=False)
                except Exception:
                    vals = np.zeros(k, dtype=np.float32)
                if len(vals) < topk_lap:
                    vals = np.pad(vals, (0, topk_lap-len(vals)))
            lap_feats.extend(vals.tolist())
        g_feats = torch.tensor(np.concatenate([g_stats, lap_feats]), dtype=torch.float32)

        # 4) 0-simplices -> node_ids ordering) 0-simplices -> node_ids ordering
        sk = ac.get_skeleta().get('molecule_skeleta', [[]])[0]
        zero = next((lst for dim,lst in sk if dim=='0'), [])
        node_ids = [next(iter(fz))[0] for fz in zero]
        if len(node_ids) != n:
            print(f"⚠️ Node count mismatch for {smile}: {len(node_ids)} vs {n}")
            return None

        # 5) Edges and edge features from bonds
        bonds = ac.get_bonds()  # list of (atom1, atom2, [type, order])
        # Map atom symbols to node indices
        atom_map = defaultdict(list)
        for idx, nid in enumerate(node_ids):
            sym = nid.split('_')[0]
            atom_map[sym].append(idx)

        edge_index_list = []
        edge_attr_list = []
        for a1, a2, (btype, order) in bonds:
            for i in atom_map.get(a1, []):
                for j in atom_map.get(a2, []):
                    edge_index_list.append([i, j])
                    t_int = {'SINGLE': 1, 'DOUBLE': 2, 'TRIPLE': 3}.get(btype, 0)
                    edge_attr_list.append([t_int, float(order)])

        # Fallback: if no bond edges, use 1-simplices from CW-complex
        if not edge_index_list:
            one_list = next((lst for dim, lst in sk if dim == '1'), [])
            for fz in one_list:
                ids = [nid for nid, _ in fz]
                if len(ids) == 2 and ids[0] in node_ids and ids[1] in node_ids:
                    i, j = node_ids[ids[0]], node_ids[ids[1]]
                    edge_index_list.extend([[i, j], [j, i]])
                    # dummy edge_attr for fallback edges
                    edge_attr_list.extend([[0, 0.0], [0, 0.0]])

        edge_index = torch.tensor(edge_index_list, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(edge_attr_list, dtype=torch.float32)

        # 6) Combine node+graph descriptors
        desc = torch.tensor(descriptors, dtype=torch.float32)
        desc_feats = desc.unsqueeze(0).expand(n,-1)
        x = torch.cat([x_node, desc_feats], dim=1)

        # 7) Build Data
        data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
        data.graph_feats = g_feats
        return data

    except Exception as e:
        print(f"Failed {smile}: {e}")
        return None

In [8]:
def decode_feature_from_frozenset(fzset):
    """
    Expects a frozenset containing a single tuple:
    ('P_3', ('proton', byte_blob))
    Returns a 1D numpy array by decoding the byte blob.
    """
    for _, (_, byte_blob) in fzset:
        try:
            arr = np.frombuffer(byte_blob, dtype=np.float32)
        except (ValueError, TypeError):
            arr = np.frombuffer(byte_blob, dtype=np.float64)
        return arr
    return None

def load_dataset():
    df = load_lipophil_data()
    graphs = []
    for _, row in df.iterrows():
        descriptors = []
        g = build_graph_from_smiles(row['smiles'], descriptors)
        if g is not None:
            g.y = torch.tensor([row['expt']], dtype=torch.float)
            graphs.append(g)
    return graphs


def process_row(args):
    idx, row = args
    try:
        descriptors = []
        graph = build_graph_from_smiles(row['smiles'], descriptors)
        if graph is not None:
            graph.y = torch.tensor([row['exp']], dtype=torch.float)
        print(f"Success for smile {row['smiles']} at index {idx}")
        return idx, graph
    except Exception as e:
        print(f"Failed on row {idx}: {e}")
        return idx, None

In [9]:
def load_dataset_parallel(df: pd.DataFrame, num_workers=4):
    with ThreadPool(num_workers) as pool:
        results = pool.map(process_row, [(i, row) for i, row in df.iterrows()])
    results = sorted([r for r in results if r[1] is not None], key=lambda x: x[0])
    return [g for _, g in results]

In [10]:
loading = True
if not loading:
    all_data = load_dataset_parallel(data, num_workers=8)
    torch.save(all_data, 'all_data_lipophil.pt')

In [11]:
from torch_geometric.nn import SAGEConv, GATv2Conv, global_mean_pool, global_max_pool, GraphNorm

class ResidualGATv2Block(nn.Module):
    def __init__(self, hidden_dim, heads=4, dropout=0.2):
        super().__init__()
        self.attn = GATv2Conv(hidden_dim, hidden_dim, heads=heads, concat=False)
        self.norm = GraphNorm(hidden_dim)
        self.dropout = dropout

    def forward(self, x, edge_index):
        residual = x
        x = self.attn(x, edge_index)
        x = self.norm(x)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        return x + residual

class PolyatomicNet(nn.Module):
    def __init__(self, input_dim, hidden_dim=128, output_dim=1, num_layers=3, dropout=0.3):
        super().__init__()

        self.embed = nn.Linear(input_dim, hidden_dim)

        self.convs = nn.ModuleList([
            ResidualGATv2Block(hidden_dim, dropout=dropout)
            for _ in range(num_layers)
        ])

        self.readout = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, output_dim)
        )

    def forward(self, x, edge_index, batch):
        if edge_index.numel() == 0 or edge_index.size(1) == 0:
            n = x.size(0)
            edge_index = torch.stack([torch.arange(n), torch.arange(n)], dim=0).to(x.device)

        x = F.relu(self.embed(x))

        for conv in self.convs:
            x = conv(x, edge_index)

        x_mean = global_mean_pool(x, batch)
        x_max = global_max_pool(x, batch)
        x = torch.cat([x_mean, x_max], dim=1)

        return self.readout(x).view(-1)

In [12]:
def collate_with_graph_feats(batch_list):
    graph_feats = torch.stack([data.graph_feats for data in batch_list], dim=0)
    for data in batch_list:
        del data.graph_feats
    batched = Batch.from_data_list(batch_list)
    batched.graph_feats = graph_feats
    return batched

In [13]:
def train(model, loader, opt, loss_fn):
    model.train()
    total = 0.0
    for batch in loader:
        opt.zero_grad()
        out = model(batch.x, batch.edge_index, batch.batch)
        loss = loss_fn(out, batch.y.view(-1))
        loss.backward()
        opt.step()
        total += loss.item() * batch.num_graphs
    return total / len(loader.dataset)

import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error
def compute_metrics_with_ci(trues, preds, n_boot=2000, alpha=0.05, seed=42):
    """
    Compute MAE and RMSE with bootstrap 95% confidence intervals.
    """
    trues = np.array(trues)
    preds = np.array(preds)
    mae = mean_absolute_error(trues, preds)
    rmse = np.sqrt(mean_squared_error(trues, preds))

    rng = np.random.RandomState(seed)
    mae_samples = []
    rmse_samples = []
    n = len(trues)
    for _ in range(n_boot):
        idx = rng.randint(0, n, n)
        t = trues[idx]
        p = preds[idx]
        mae_samples.append(mean_absolute_error(t, p))
        rmse_samples.append(np.sqrt(mean_squared_error(t, p)))

    lower = 100 * (alpha/2)
    upper = 100 * (1 - alpha/2)
    mae_ci = (np.percentile(mae_samples, lower), np.percentile(mae_samples, upper))
    rmse_ci = (np.percentile(rmse_samples, lower), np.percentile(rmse_samples, upper))
    return {'mae': mae, 'mae_ci': mae_ci, 'rmse': rmse, 'rmse_ci': rmse_ci}

def evaluate(model, loader):
    model.eval()
    preds, trues = [], []
    with torch.no_grad():
        for batch in loader:
            out = model(batch.x, batch.edge_index, batch.batch)
            preds.append(out)
            trues.append(batch.y.view(-1))
    preds = torch.cat(preds)
    trues = torch.cat(trues)
    return torch.sqrt(torch.mean((preds - trues)**2)).item()


def evaluate_with_ci(model, loader):
    model.eval()
    preds, trues = [], []
    with torch.no_grad():
        for batch in loader:
            out = model(batch.x, batch.edge_index, batch.batch)
            y_true = batch.y.view(-1).cpu().tolist()
            y_pred = out.view(-1).cpu().tolist()
            trues.extend(y_true)
            preds.extend(y_pred)
    metrics = compute_metrics_with_ci(trues, preds)
    return metrics


In [14]:
loading = True
if loading:
    all_data = torch.load('all_data_lipophil.pt', weights_only=False)
    from sklearn.preprocessing import StandardScaler
    scaler = StandardScaler()
    ys = np.array([d.y.item() for d in all_data]).reshape(-1, 1)
    ys_scaled = scaler.fit_transform(ys)
    for i, d in enumerate(all_data):
        d.y = torch.tensor([ys_scaled[i][0]], dtype=torch.float32)

In [17]:
data_list = all_data
train_n = int(0.8*len(data_list))
train_ds, test_ds = torch.utils.data.random_split(data_list, [train_n,len(data_list)-train_n])
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
test_loader  = DataLoader(test_ds, batch_size=32)



input_dim = all_data[0].x.size(1)
model     = PolyatomicNet(input_dim=input_dim, hidden_dim=64, output_dim=1, num_layers=2, dropout=0.1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
loss_fn   = nn.L1Loss()

In [None]:
history = {}

for epoch in range(1, 21):
    tr_loss = train(model, train_loader, optimizer, loss_fn)
    metrics = evaluate_with_ci(model, test_loader)
    print(f"Epoch {epoch:02d} | Train Loss: {tr_loss:.4f} | "
          f"Test MAE: {metrics['mae']:.4f} (95% CI [{metrics['mae_ci'][0]:.4f}, {metrics['mae_ci'][1]:.4f}]) | "
          f"Test RMSE: {metrics['rmse']:.4f} (95% CI [{metrics['rmse_ci'][0]:.4f}, {metrics['rmse_ci'][1]:.4f}])")
    history[epoch] = metrics

Epoch 01 | Train Loss: 0.8103 | Test MAE: 0.7827 (95% CI [0.7429, 0.8239]) | Test RMSE: 0.9923 (95% CI [0.9425, 1.0430])


In [None]:
final = history[20]
print("*"*20)
print(f"Test MAE: {final['mae']:.4f} (95% CI [{final['mae_ci'][0]:.4f}, {final['mae_ci'][1]:.4f}])")
print(f"Test RMSE: {final['rmse']:.4f} (95% CI [{final['rmse_ci'][0]:.4f}, {final['rmse_ci'][1]:.4f}])")
print("*"*20)