# MTL

In [None]:
# ============= MODIFICATIONS FOR MULTI-TASK LEARNING =============

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# 1. MASKED MSE LOSS WITH TASK WEIGHTING
class MaskedMSELoss(nn.Module):
    """
    Masked MSE loss for multi-task regression with missing labels.

    NaN entries in ``target`` mark labels that are not available; they are
    excluded from the loss. When ``task_ranges`` is given, every remaining
    squared error is scaled by a per-task weight proportional to the inverse
    of that task's range, with the weights normalized to sum to 1.

    Args:
        task_ranges: Optional mapping of task name -> value range. Weights are
            assigned positionally from the iteration order of this mapping, so
            that order must match the column order of ``pred``/``target``.
    """
    def __init__(self, task_ranges=None):
        super(MaskedMSELoss, self).__init__()
        self.task_ranges = task_ranges
        if task_ranges is not None:
            # Inverse-range weighting; a non-positive range falls back to 1.0
            weights = [
                1.0 / range_val if range_val > 0 else 1.0
                for range_val in task_ranges.values()
            ]
            total_weight = sum(weights)
            self.task_weights = torch.tensor([w / total_weight for w in weights])
        else:
            self.task_weights = None

    def forward(self, pred, target):
        """
        Compute the (optionally task-weighted) mean squared error over the
        non-NaN entries of ``target``.

        pred: [batch_size, n_tasks]
        target: [batch_size, n_tasks]

        Returns:
            Scalar tensor. If ``target`` contains no valid entry, a zero
            scalar on ``pred``'s device and dtype is returned.
        """
        # True where a label is present
        mask = ~torch.isnan(target)

        if not mask.any():
            # Nothing to learn from in this batch. Build the zero on pred's
            # device/dtype so callers can treat it like any other loss value;
            # it is a leaf, so backward() is valid and leaves model grads unset.
            return torch.zeros((), device=pred.device, dtype=pred.dtype,
                               requires_grad=True)

        # Boolean indexing flattens to the valid entries only
        pred_masked = pred[mask]
        target_masked = target[mask]

        # Squared error per valid entry
        se = (pred_masked - target_masked) ** 2

        if self.task_weights is not None:
            # Column index of each valid entry == its task index
            task_indices = torch.where(mask)[1]
            weights = self.task_weights.to(pred.device)[task_indices]
            loss = (se * weights).mean()
        else:
            loss = se.mean()

        return loss

    
# 2. MODIFIED DTA MODEL FOR MULTI-TASK LEARNING
class MTL_DTAModel(nn.Module):
    """
    Multi-task drug-target affinity model.

    A drug encoder and a protein encoder each feed a small FC stack; the two
    resulting vectors are concatenated, passed through shared FC layers, and
    then through one linear head per task. The forward pass returns one
    column per entry of ``task_names``, in that order.
    """
    def __init__(self,
            task_names=['pKi', 'pEC50', 'pKd', 'pIC50'],  # List of tasks
            prot_emb_dim=1280,
            prot_gcn_dims=[128, 256, 256],
            prot_fc_dims=[1024, 128],
            drug_node_in_dim=[66, 1], drug_node_h_dims=[128, 64],
            drug_edge_in_dim=[16, 1], drug_edge_h_dims=[32, 1],            
            drug_fc_dims=[1024, 128],
            mlp_dims=[1024, 512], mlp_dropout=0.25):
        super().__init__()

        self.task_names = task_names
        self.n_tasks = len(task_names)

        # Options shared by every FC stack built below
        fc_options = dict(
            dropout=mlp_dropout, batchnorm=False,
            no_last_dropout=True, no_last_activation=True,
        )

        # Encoders (sub-modules are registered in the same order as before so
        # that seeded initialisation and state_dict layout are unaffected)
        self.drug_model = DrugGVPModel(
            node_in_dim=drug_node_in_dim, node_h_dim=drug_node_h_dims,
            edge_in_dim=drug_edge_in_dim, edge_h_dim=drug_edge_h_dims,
        )
        drug_out_dim = drug_node_h_dims[0]

        self.prot_model = Prot3DGraphModel(
            d_pretrained_emb=prot_emb_dim, d_gcn=prot_gcn_dims
        )
        prot_out_dim = prot_gcn_dims[-1]

        # Per-modality projection stacks
        self.drug_fc = self.get_fc_layers([drug_out_dim] + drug_fc_dims, **fc_options)
        self.prot_fc = self.get_fc_layers([prot_out_dim] + prot_fc_dims, **fc_options)

        # Shared trunk over the concatenated drug/protein vectors
        joint_dim = drug_fc_dims[-1] + prot_fc_dims[-1]
        self.shared_fc = self.get_fc_layers([joint_dim] + mlp_dims, **fc_options)

        # One scalar regression head per task
        heads = {}
        for task in task_names:
            heads[task] = nn.Linear(mlp_dims[-1], 1)
        self.task_heads = nn.ModuleDict(heads)

    def get_fc_layers(self, hidden_sizes,
            dropout=0, batchnorm=False,
            no_last_dropout=True, no_last_activation=True):
        """
        Build an ``nn.Sequential`` of Linear layers sized by consecutive
        pairs of ``hidden_sizes``.

        After each Linear: a LeakyReLU (omitted after the final Linear when
        ``no_last_activation``), a Dropout when ``dropout > 0`` (omitted
        after the final Linear when ``no_last_dropout``), and a BatchNorm1d
        when ``batchnorm`` (never after the final Linear).
        """
        activation = torch.nn.LeakyReLU()  # single instance reused at every position
        final_idx = len(hidden_sizes) - 2
        modules = []
        for idx in range(len(hidden_sizes) - 1):
            width_in, width_out = hidden_sizes[idx], hidden_sizes[idx + 1]
            is_final = idx == final_idx

            modules.append(nn.Linear(width_in, width_out))
            if not (no_last_activation and is_final):
                modules.append(activation)
            if dropout > 0 and not (no_last_dropout and is_final):
                modules.append(nn.Dropout(dropout))
            if batchnorm and not is_final:
                modules.append(nn.BatchNorm1d(width_out))
        return nn.Sequential(*modules)

    def forward(self, xd, xp):
        """
        Run both encoders and return the per-task predictions concatenated
        along dim 1, one column per task in ``self.task_names`` order.
        """
        # Encoders first, then the projection stacks (same call order as the
        # layer definitions so stochastic layers draw in a fixed sequence)
        drug_enc = self.drug_model(xd)
        prot_enc = self.prot_model(xp)
        drug_vec = self.drug_fc(drug_enc)
        prot_vec = self.prot_fc(prot_enc)

        # Joint representation
        shared_repr = self.shared_fc(torch.cat([drug_vec, prot_vec], dim=1))

        # One column per task head
        per_task = [self.task_heads[task](shared_repr) for task in self.task_names]
        return torch.cat(per_task, dim=1)

# 3. MODIFIED DTA DATASET CLASS
class MTL_DTA(data.Dataset):
    """
    Dataset of drug/protein pairs with one regression target per task.

    Each element of ``data_list`` is a dict holding ``'drug'`` and
    ``'protein'`` entries plus, optionally, one value per task column.
    Missing task values are returned as NaN. With ``onthefly=True`` the
    stored drug/protein entries are passed through the featurize functions
    at access time (the record must then also carry ``'drug_name'`` and
    ``'protein_name'``).
    """
    def __init__(self, df=None, data_list=None, task_cols=None, onthefly=False,
                prot_featurize_fn=None, drug_featurize_fn=None):
        super().__init__()
        self.data_df = df
        self.data_list = data_list
        self.task_cols = task_cols or ['pKi', 'pEC50', 'pKd', 'pIC50']
        self.onthefly = onthefly
        self.prot_featurize_fn = prot_featurize_fn
        self.drug_featurize_fn = drug_featurize_fn

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        record = self.data_list[idx]

        if not self.onthefly:
            # Entries were featurized ahead of time
            drug = record['drug']
            prot = record['protein']
        else:
            drug = self.drug_featurize_fn(record['drug'], name=record['drug_name'])
            prot = self.prot_featurize_fn(record['protein'], name=record['protein_name'])

        # One target per task; anything pandas considers missing becomes NaN
        targets = []
        for task in self.task_cols:
            raw = record.get(task, np.nan)
            targets.append(np.nan if pd.isna(raw) else raw)

        return {
            'drug': drug,
            'protein': prot,
            'y': torch.tensor(targets, dtype=torch.float32),
        }

    



# Dataset builder for DTA class
def build_dataset(df_fold, pdb_structures, exp_cols = "pKi", is_pred = False):
    """
    Build a single-task ``DTA`` dataset from a fold DataFrame.

    For each row, the structure key is the ligand SDF file name up to its
    first dot; the matching entry of ``pdb_structures`` is featurized as the
    protein graph and the SDF path is featurized as the drug.

    Args:
        df_fold: DataFrame with a ``standardized_ligand_sdf`` column and,
            unless ``is_pred`` is set, the ``exp_cols`` label column.
        pdb_structures: Mapping (anything with ``.get``) from structure key
            to the protein record passed to ``featurize_protein_graph``.
        exp_cols: Name of the label column to read.
        is_pred: When truthy, labels are not read and ``y`` is set to 0.

    Returns:
        ``DTA(df=df_fold, data_list=...)``.
    """
    data_list = []
    for _, row in df_fold.iterrows():
        ligand_sdf = row["standardized_ligand_sdf"]
        pdb_id = os.path.basename(ligand_sdf).split(".")[0]
        protein = featurize_protein_graph(pdb_structures.get(pdb_id))
        drug = featurize_drug(ligand_sdf)

        # Prediction mode carries a placeholder label; the label column is
        # only touched when a real target is needed.
        y = 0 if is_pred else float(row[exp_cols])

        data_list.append({
            "protein": protein,
            "drug": drug,
            "y": y,
        })
    return DTA(df=df_fold, data_list=data_list)



# 4. MODIFIED BUILD DATASET FUNCTION
def build_mtl_dataset(df_fold, pdb_structures, task_cols=['pKi', 'pEC50', 'pKd', 'pIC50']):
    """
    Build an ``MTL_DTA`` dataset from a fold DataFrame.

    The structure key of a row is its ligand SDF file name up to the first
    dot. Each output record holds the featurized protein and drug plus one
    float per task column (NaN where the column is absent or the value is
    missing).
    """
    def _label(row, task):
        # Present and non-missing -> float, otherwise NaN
        if task in row and not pd.isna(row[task]):
            return float(row[task])
        return np.nan

    records = []
    for _, row in df_fold.iterrows():
        sdf_path = row["standardized_ligand_sdf"]
        structure_key = os.path.basename(sdf_path).split(".")[0]

        record = {
            "protein": featurize_protein_graph(pdb_structures.get(structure_key)),
            "drug": featurize_drug(sdf_path),
        }
        for task in task_cols:
            record[task] = _label(row, task)
        records.append(record)

    return MTL_DTA(df=df_fold, data_list=records, task_cols=task_cols)

# 5. MODIFIED TRAINING LOOP
def train_mtl_model(model, train_loader, valid_loader, task_cols, task_ranges,
                    n_epochs=100, lr=0.0005, device='cuda', patience=20):
    """
    Training loop for multi-task learning model.

    Trains with Adam and ``MaskedMSELoss``, evaluates on ``valid_loader``
    after every epoch, prints the losses and per-task validation MSE, and
    restores the weights of the best validation epoch before returning.

    Args:
        model: Model called as ``model(xd, xp)`` returning [batch, n_tasks].
        train_loader: Iterable of batches with 'drug', 'protein' and 'y'.
        valid_loader: Same structure as ``train_loader``.
        task_cols: List of task column names (column order of 'y').
        task_ranges: Dict mapping task names to their value ranges.
        n_epochs: Maximum number of epochs.
        lr: Adam learning rate.
        device: Device the batches are moved to. The model itself is not
            moved here.
        patience: Passed to ``EarlyStopping``.

    Returns:
        ``model``, loaded with the best checkpoint if one was recorded.
    """
    import copy  # local import: only needed for the checkpoint snapshot

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = MaskedMSELoss(task_ranges=task_ranges)
    stopper = EarlyStopping(patience=patience, higher_better=False)
    best_model = None

    for epoch in range(n_epochs):
        # Training
        model.train()
        train_loss = 0
        n_batches = 0

        for batch in train_loader:
            xd = batch['drug'].to(device)
            xp = batch['protein'].to(device)
            y = batch['y'].to(device)  # [batch_size, n_tasks]

            optimizer.zero_grad()
            pred = model(xd, xp)  # [batch_size, n_tasks]
            loss = criterion(pred, y)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            n_batches += 1

        # Validation
        model.eval()
        val_loss = 0
        val_n_batches = 0
        # Per task: sum of squared errors and number of labelled samples, so
        # the reported MSE is sample-weighted rather than a mean of batch means.
        task_metrics = {task: {'sse': 0.0, 'n': 0} for task in task_cols}

        with torch.no_grad():
            for batch in valid_loader:
                xd = batch['drug'].to(device)
                xp = batch['protein'].to(device)
                y = batch['y'].to(device)

                pred = model(xd, xp)
                loss = criterion(pred, y)
                val_loss += loss.item()
                val_n_batches += 1

                # Calculate per-task metrics
                for i, task in enumerate(task_cols):
                    mask = ~torch.isnan(y[:, i])
                    n_valid = int(mask.sum())
                    if n_valid > 0:
                        sse = F.mse_loss(pred[mask, i], y[mask, i], reduction='sum')
                        task_metrics[task]['sse'] += sse.item()
                        task_metrics[task]['n'] += n_valid

        # Guard both averages against empty loaders
        avg_train_loss = train_loss / n_batches if n_batches > 0 else float('inf')
        avg_val_loss = val_loss / val_n_batches if val_n_batches > 0 else float('inf')

        # Print metrics
        print(f"Epoch {epoch+1}/{n_epochs}")
        print(f"  Train Loss: {avg_train_loss:.4f}")
        print(f"  Valid Loss: {avg_val_loss:.4f}")

        for task in task_cols:
            if task_metrics[task]['n'] > 0:
                avg_task_mse = task_metrics[task]['sse'] / task_metrics[task]['n']
                print(f"  {task} MSE: {avg_task_mse:.4f}")

        # Early stopping. state_dict() hands back references to the live
        # parameter tensors, so the snapshot must be deep-copied; otherwise
        # later optimizer steps overwrite it and the restore below is a no-op.
        if stopper.update(avg_val_loss):
            best_model = copy.deepcopy(model.state_dict())
        if stopper.early_stop:
            print("Early stopping triggered")
            break

    if best_model is not None:
        model.load_state_dict(best_model)

    return model

# 6. EXAMPLE USAGE
def prepare_mtl_experiment(df, task_cols=['pKi', 'pEC50', 'pKd', 'pIC50']):
    """
    Compute the value range (max - min) of every task column for loss
    weighting, print the ranges with their normalized inverse-range weights,
    and return the ranges.

    A task whose column is absent from ``df``, or has no non-missing value,
    gets a range of 1.0.

    Returns:
        dict mapping each task in ``task_cols`` to its range.
    """
    def _inverse(span):
        # Same fallback as the loss: non-positive ranges weigh 1.0
        return 1.0 / span if span > 0 else 1.0

    task_ranges = {}
    for task in task_cols:
        span = 1.0
        if task in df.columns:
            observed = df[task].dropna()
            if len(observed) > 0:
                span = observed.max() - observed.min()
        task_ranges[task] = span

    print("Task ranges for weighting:")
    weight_total = sum(_inverse(span) for span in task_ranges.values())
    for task, range_val in task_ranges.items():
        normalized_weight = _inverse(range_val) / weight_total
        print(f"  {task}: range={range_val:.2f}, weight={normalized_weight:.4f}")

    return task_ranges



import json

def structureJSON(df, esm_model):
    """
    Build the protein structure dictionary sequentially and write it to
    ``../data/pockets_structure.json``.

    For every row of ``df`` the PDB at ``standardized_protein_pdb`` is
    parsed with the module-level ``parser``, the backbone sequence and
    coordinates are extracted, an ESM embedding is computed and saved to
    ``esm_embeddings/<pdb_id>.pt``, and a record is stored under the PDB id
    (file name up to the first dot). Rows that fail or yield no usable
    chain/coords/embedding are reported and skipped.

    Args:
        df: DataFrame with a ``standardized_protein_pdb`` column.
        esm_model: Model handed to ``get_esm_embedding``.

    Returns:
        dict mapping pdb_id -> structure record.
    """
    structure_dict = {}

    # torch.save / open below do not create missing directories
    os.makedirs("esm_embeddings", exist_ok=True)
    os.makedirs("../data", exist_ok=True)

    for _, row in tqdm(df.iterrows(), total=len(df)):
        pdb_path = row["standardized_protein_pdb"]
        # Fallback label so the error message below never reports a stale or
        # unbound id when deriving the real id itself fails.
        pdb_id = str(pdb_path)
        try:
            pdb_id = os.path.basename(pdb_path).split('.')[0]

            structure = parser.get_structure(pdb_id, pdb_path)
            seq, coords, chain_id = extract_backbone_coords(structure, pdb_id, pdb_path)
            if seq is None:
                available = [c.id for c in structure[0]]
                print(f"[SKIP] {pdb_id}: no usable chain found (available: {available})")
                continue

            # Stack in order: N, CA, C, O --> [L, 4, 3]
            coords_stacked = [
                [coords[atom][res_idx] for atom in ("N", "CA", "C", "O")]
                for res_idx in range(len(coords["N"]))
            ]

            # The stacked list is never None; an empty list is the real
            # "nothing usable" case.
            if not coords_stacked:
                print(f"[SKIP] {pdb_id}: no usable coords found (available: {pdb_path})")
                continue

            embedding = get_esm_embedding(seq, esm_model)
            if embedding is None:
                print(f"[SKIP] {pdb_id}: no embedding produced")
                continue

            embed_path = f"esm_embeddings/{pdb_id}.pt"
            torch.save(embedding, embed_path)

            structure_dict[pdb_id] = {
                "name": pdb_id,
                "UniProt_id": "UNKNOWN",
                "PDB_id": pdb_id,
                "chain": chain_id,
                "seq": seq,
                "coords": coords_stacked,
                "embed": embed_path,
            }

        except Exception as e:
            # Best-effort batch job: report the failure and move on
            print(f"[ERROR] Failed to process {pdb_id}: {e}")
            continue

    # Save to JSON
    with open("../data/pockets_structure.json", "w") as f:
        json.dump(structure_dict, f, indent=2)

    print(f"\n✅ Done. Saved {len(structure_dict)} protein structures to pockets_structure.json")

    return structure_dict



In [None]:
import os
import json
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm.auto import tqdm
import torch
from Bio.PDB import PDBParser  # make sure Biopython is installed

# Assumes you already have:
# - extract_backbone_coords(structure, pdb_id, pdb_path)
# - get_esm_embedding(seq, esm_model)

# Backbone atom names, in the order they are stacked per residue.
ATOMS = ("N", "CA", "C", "O")
# Directory where per-structure ESM embedding tensors (<pdb_id>.pt) are written.
EMBED_DIR = Path("esm_embeddings")
# Created eagerly when this cell/module runs so later saves find the directory.
EMBED_DIR.mkdir(parents=True, exist_ok=True)

def _stack_backbone(coords):
    """
    Regroup per-atom coordinate lists into per-residue groups.

    ``coords`` maps each name in ``ATOMS`` to an indexable sequence; the
    number of residues is taken from ``coords["N"]``. The result has one
    entry per residue, each holding that residue's entries in ``ATOMS``
    order.
    """
    stacked = []
    for res_idx in range(len(coords["N"])):
        stacked.append([coords[atom][res_idx] for atom in ATOMS])
    return stacked

def _process_pdb_path(pdb_path):
    """
    Worker: parse PDB, extract seq/coords/chain, return tuple or a skip marker.
    Runs in a separate process; initializes its own parser.
    """
    try:
        pdb_id = os.path.basename(pdb_path).split('.')[0]
        parser = PDBParser(QUIET=True)
        structure = parser.get_structure(pdb_id, pdb_path)

        seq, coords, chain_id = extract_backbone_coords(structure, pdb_id, pdb_path)
        if seq is None:
            available = [c.id for c in structure[0]]
            return ("skip", pdb_id, f"no usable chain (available: {available})")

        if not coords or any(k not in coords for k in ATOMS) or len(coords["N"]) == 0:
            return ("skip", pdb_id, "no usable coords")

        coords_stacked = _stack_backbone(coords)
        if not coords_stacked:
            return ("skip", pdb_id, "empty coords after stacking")

        return ("ok", pdb_id, seq, coords_stacked, chain_id)

    except Exception as e:
        return ("error", os.path.basename(pdb_path).split('.')[0], str(e))

def structureJSON(df, esm_model, max_workers=None, embed_batch_size=8, out_json="../data/pockets_structure.json"):
    """
    Build the protein structure dictionary and write it to ``out_json``.

    Phase 1 parses the PDB files in a process pool (``_process_pdb_path``);
    phase 2 computes ESM embeddings in this process, one sequence at a time,
    saving each to ``EMBED_DIR/<pdb_id>.pt``.

    Args:
        df: DataFrame with a ``standardized_protein_pdb`` column.
        esm_model: Model handed to ``get_esm_embedding``.
        max_workers: Worker count for the parsing pool (None = executor default).
        embed_batch_size: Currently unused; kept for interface compatibility.
        out_json: Path of the JSON file to write.

    Returns:
        dict mapping pdb_id -> structure record.
    """
    structure_dict = {}

    # Records are keyed by pdb_id, so parsing the same path twice only
    # repeats work; de-duplicate while keeping first-seen order.
    pdb_paths = list(dict.fromkeys(df["standardized_protein_pdb"].tolist()))
    results = []

    # Phase 1: parallel PDB parsing + coordinate extraction
    with ProcessPoolExecutor(max_workers=max_workers) as ex:
        futures = {ex.submit(_process_pdb_path, p): p for p in pdb_paths}
        for fut in tqdm(as_completed(futures), total=len(futures), desc="PDB -> seq/coords"):
            results.append(fut.result())

    # Log skips/errors (fast)
    for r in results:
        tag = r[0]
        if tag == "skip":
            _, pdb_id, msg = r
            print(f"[SKIP] {pdb_id}: {msg}")
        elif tag == "error":
            _, pdb_id, err = r
            print(f"[ERROR] Failed to process {pdb_id}: {err}")

    # Keep only successful items: (pdb_id, seq, coords_stacked, chain_id)
    ok_items = [r[1:] for r in results if r[0] == "ok"]

    # Phase 2: embeddings on a single device (GPU/CPU) to avoid per-process model copies
    for pdb_id, seq, coords_stacked, chain_id in tqdm(ok_items, desc="ESM embeddings"):
        try:
            embedding = get_esm_embedding(seq, esm_model)  # ensure this returns a tensor
            embed_path = EMBED_DIR / f"{pdb_id}.pt"
            torch.save(embedding, embed_path)

            structure_dict[pdb_id] = {
                "name": pdb_id,
                "UniProt_id": "UNKNOWN",
                "PDB_id": pdb_id,
                "chain": chain_id,
                "seq": seq,
                "coords": coords_stacked,         # [[N,CA,C,O], ...], each as [x,y,z]
                "embed": str(embed_path)
            }
        except Exception as e:
            print(f"[ERROR] ESM embedding failed for {pdb_id}: {e}")

    # Save to JSON. dirname() is '' for a bare file name, and makedirs('')
    # raises, so only create the directory when there is one.
    out_dir = os.path.dirname(out_json)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(out_json, "w") as f:
        json.dump(structure_dict, f, indent=2)

    print(f"\n✅ Done. Saved {len(structure_dict)} protein structures to {os.path.basename(out_json)}")
    return structure_dict





import os
import json
import pandas as pd
import torch
from pathlib import Path
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
import numpy as np

def structureJSON_chunked(df, esm_model, max_workers=None, embed_batch_size=8,
                          chunk_size=100000, out_dir="../data/structure_chunks/"):
    """
    Build the structure dictionary chunk by chunk, writing one JSON file per
    chunk plus a ``chunk_metadata.json`` index into ``out_dir``.

    Args:
        df: DataFrame with a ``standardized_protein_pdb`` column
        esm_model: Model handed to ``get_esm_embedding``
        max_workers: Worker count for each chunk's parsing pool
        embed_batch_size: Not used by this implementation
        chunk_size: Maximum number of PDB paths per chunk (default 100000)
        out_dir: Directory receiving the chunk files and the metadata file

    Returns:
        dict: ``num_chunks``, ``chunk_size`` and a ``chunks`` list with one
        info dict per written chunk
    """
    os.makedirs(out_dir, exist_ok=True)
    os.makedirs(EMBED_DIR, exist_ok=True)

    all_paths = df["standardized_protein_pdb"].tolist()
    n_total = len(all_paths)
    n_chunks = (n_total + chunk_size - 1) // chunk_size  # ceiling division

    print(f"Processing {n_total} PDBs in {n_chunks} chunks of max {chunk_size} each")

    chunk_metadata = {"num_chunks": n_chunks, "chunk_size": chunk_size, "chunks": []}

    for chunk_idx in range(n_chunks):
        label = chunk_idx + 1  # 1-based number used in messages
        start_idx = chunk_idx * chunk_size
        end_idx = min(start_idx + chunk_size, n_total)
        chunk_paths = all_paths[start_idx:end_idx]

        print(f"\n=== Processing chunk {label}/{n_chunks} ({len(chunk_paths)} PDBs) ===")

        # Phase 1: parse this chunk's PDB files in a process pool
        with ProcessPoolExecutor(max_workers=max_workers) as pool:
            pending = {pool.submit(_process_pdb_path, path): path for path in chunk_paths}
            progress = tqdm(as_completed(pending), total=len(pending),
                            desc=f"Chunk {label} - PDB parsing")
            outcomes = [done.result() for done in progress]

        # Report skipped and failed files
        for outcome in outcomes:
            status = outcome[0]
            if status == "skip":
                _, pdb_id, msg = outcome
                print(f"[SKIP] {pdb_id}: {msg}")
            elif status == "error":
                _, pdb_id, err = outcome
                print(f"[ERROR] Failed to process {pdb_id}: {err}")

        # Successful results as (pdb_id, seq, coords_stacked, chain_id)
        parsed = [tuple(outcome[1:]) for outcome in outcomes if outcome[0] == "ok"]

        # Phase 2: embed each sequence and assemble the chunk's records
        chunk_structures = {}
        for pdb_id, seq, coords_stacked, chain_id in tqdm(
                parsed, desc=f"Chunk {label} - ESM embeddings"):
            try:
                embedding = get_esm_embedding(seq, esm_model)
                embed_path = EMBED_DIR / f"{pdb_id}.pt"
                torch.save(embedding, embed_path)
            except Exception as e:
                print(f"[ERROR] ESM embedding failed for {pdb_id}: {e}")
            else:
                chunk_structures[pdb_id] = {
                    "name": pdb_id,
                    "UniProt_id": "UNKNOWN",
                    "PDB_id": pdb_id,
                    "chain": chain_id,
                    "seq": seq,
                    "coords": coords_stacked,
                    "embed": str(embed_path),
                }

        # Write the chunk file
        chunk_filename = f"structures_chunk_{chunk_idx:04d}.json"
        chunk_path = os.path.join(out_dir, chunk_filename)
        with open(chunk_path, "w") as f:
            json.dump(chunk_structures, f, indent=2)

        chunk_metadata["chunks"].append({
            "chunk_idx": chunk_idx,
            "filename": chunk_filename,
            "path": chunk_path,
            "num_structures": len(chunk_structures),
            "start_idx": start_idx,
            "end_idx": end_idx,
        })

        print(f"✅ Chunk {label} saved: {len(chunk_structures)} structures to {chunk_filename}")

    # Write the index of all chunks
    metadata_path = os.path.join(out_dir, "chunk_metadata.json")
    with open(metadata_path, "w") as f:
        json.dump(chunk_metadata, f, indent=2)

    print(f"\n✅ All chunks processed. Metadata saved to {metadata_path}")
    return chunk_metadata


class StructureChunkLoader:
    """
    Loader for chunked structure dictionaries.

    Reads ``chunk_metadata.json`` from ``chunk_dir``, scans every chunk file
    once to index which chunk holds which PDB id, and afterwards loads
    chunks on demand, keeping at most ``cache_size`` of them in memory
    (least recently used chunk is evicted first).
    """
    def __init__(self, chunk_dir="../data/structure_chunks/", cache_size=2):
        self.chunk_dir = chunk_dir
        self.cache_size = cache_size
        self.cache = {}  # chunk_idx -> structure_dict
        self.cache_order = []  # LRU tracking, oldest first

        # Load metadata
        metadata_path = os.path.join(chunk_dir, "chunk_metadata.json")
        with open(metadata_path, "r") as f:
            self.metadata = json.load(f)

        # chunk_idx -> chunk info. Looking chunks up by their recorded index
        # (not by list position) keeps loading correct when the metadata
        # entries are not stored in chunk_idx order.
        self._chunk_info_by_idx = {
            info["chunk_idx"]: info for info in self.metadata["chunks"]
        }

        # Build lookup: pdb_id -> chunk_idx
        self.pdb_to_chunk = {}
        for chunk_info in self.metadata["chunks"]:
            chunk_path = os.path.join(chunk_dir, chunk_info["filename"])
            # Full scan of each chunk (the index could be saved in metadata
            # to avoid this)
            with open(chunk_path, "r") as f:
                chunk_data = json.load(f)
            for pdb_id in chunk_data:
                self.pdb_to_chunk[pdb_id] = chunk_info["chunk_idx"]

    def _load_chunk(self, chunk_idx):
        """Return the chunk's dict, loading it into the cache if needed."""
        if chunk_idx in self.cache:
            # Move to end (most recently used)
            self.cache_order.remove(chunk_idx)
            self.cache_order.append(chunk_idx)
            return self.cache[chunk_idx]

        # Load chunk
        chunk_info = self._chunk_info_by_idx[chunk_idx]
        chunk_path = os.path.join(self.chunk_dir, chunk_info["filename"])
        with open(chunk_path, "r") as f:
            chunk_data = json.load(f)

        # Add to cache
        self.cache[chunk_idx] = chunk_data
        self.cache_order.append(chunk_idx)

        # Evict oldest if cache is full
        if len(self.cache) > self.cache_size:
            oldest = self.cache_order.pop(0)
            del self.cache[oldest]

        return chunk_data

    def get(self, pdb_id):
        """Return the structure record for ``pdb_id``, or None if unknown."""
        chunk_idx = self.pdb_to_chunk.get(pdb_id)
        if chunk_idx is None:
            return None
        return self._load_chunk(chunk_idx).get(pdb_id)

    def get_batch(self, pdb_ids):
        """
        Return ``{pdb_id: record}`` for every known id in ``pdb_ids``.

        Ids are grouped by chunk first so each chunk is loaded once per
        call; unknown ids are omitted from the result.
        """
        # Group PDB IDs by chunk
        chunk_groups = {}
        for pdb_id in pdb_ids:
            if pdb_id in self.pdb_to_chunk:
                chunk_groups.setdefault(self.pdb_to_chunk[pdb_id], []).append(pdb_id)

        # Load each chunk and extract structures
        results = {}
        for chunk_idx, chunk_pdb_ids in chunk_groups.items():
            chunk_data = self._load_chunk(chunk_idx)
            for pdb_id in chunk_pdb_ids:
                if pdb_id in chunk_data:
                    results[pdb_id] = chunk_data[pdb_id]

        return results

    def get_available_pdb_ids(self):
        """Return set of all available PDB IDs."""
        return set(self.pdb_to_chunk)


def build_mtl_dataset_optimized(df_fold, chunk_loader, task_cols=['pKi', 'pEC50', 'pKd', 'pIC50']):
    """
    Build an ``MTL_DTA`` dataset, fetching protein structures through a
    chunk loader.

    All ``protein_id`` values of the fold are requested in a single
    ``chunk_loader.get_batch`` call; rows whose id is not in the returned
    mapping are skipped and counted in a warning.

    Args:
        df_fold: DataFrame with ``protein_id`` and
            ``standardized_ligand_sdf`` columns plus the task columns
        chunk_loader: Object exposing ``get_batch(ids)``
        task_cols: List of task columns

    Returns:
        MTL_DTA dataset
    """
    wanted_ids = df_fold["protein_id"].tolist()

    # One batched lookup for the whole fold
    print(f"Loading structures for {len(wanted_ids)} proteins...")
    structures = chunk_loader.get_batch(wanted_ids)

    records = []
    n_missing = 0
    rows = tqdm(df_fold.iterrows(), total=len(df_fold), desc="Building dataset")
    for _, row in rows:
        protein_id = row["protein_id"]
        if protein_id not in structures:
            n_missing += 1
            continue

        record = {
            "protein": featurize_protein_graph(structures[protein_id]),
            "drug": featurize_drug(row["standardized_ligand_sdf"]),
        }

        # One float per task; NaN when the column is absent or the value missing
        for task in task_cols:
            has_label = task in row and not pd.isna(row[task])
            record[task] = float(row[task]) if has_label else np.nan

        records.append(record)

    if n_missing > 0:
        print(f"Warning: Skipped {n_missing} entries due to missing structures")

    return MTL_DTA(df=df_fold, data_list=records, task_cols=task_cols)

# 4. MODIFIED BUILD DATASET FUNCTION (redefines build_mtl_dataset; same body as the earlier definition in this file)
def build_mtl_dataset(df_fold, pdb_structures, task_cols=['pKi', 'pEC50', 'pKd', 'pIC50']):
    """
    Build an ``MTL_DTA`` dataset from a fold DataFrame.

    Each row's structure key is its ligand SDF file name up to the first
    dot. The resulting record carries the featurized protein and drug and
    one float per task column, with NaN standing in for an absent column or
    a missing value.
    """
    records = []
    for _, row in df_fold.iterrows():
        ligand_path = row["standardized_ligand_sdf"]
        key = os.path.basename(ligand_path).split(".")[0]

        protein = featurize_protein_graph(pdb_structures.get(key))
        drug = featurize_drug(ligand_path)

        # Per-task labels for this row
        labels = {
            task: float(row[task]) if (task in row and not pd.isna(row[task])) else np.nan
            for task in task_cols
        }

        record = {"protein": protein, "drug": drug}
        record.update(labels)
        records.append(record)

    return MTL_DTA(df=df_fold, data_list=records, task_cols=task_cols)

# ============= USAGE EXAMPLE =============


In [None]:
import os
import json
import pandas as pd
import torch
from pathlib import Path
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
import numpy as np
from multiprocessing import Pool, cpu_count
import gc

def process_single_chunk(args):
    """
    Process a single chunk of PDB files independently.
    This function is designed to be run in parallel.

    Loads its own ESM model and tokenizer, parses the chunk's PDB files in a
    nested process pool, embeds every successfully parsed sequence, saves
    each embedding to ``embed_dir/<pdb_id>.pt`` and writes the chunk's
    records to ``out_dir/structures_chunk_<idx>.json``.

    NOTE(review): a nested ProcessPoolExecutor is started here; daemonic
    multiprocessing workers are not allowed to create child processes, so
    confirm how the caller launches this function.

    Args:
        args: tuple of (chunk_idx, pdb_paths, out_dir, embed_dir, esm_model_name)

    Returns:
        dict with ``chunk_idx``, ``filename``, ``path``, ``num_structures``,
        ``num_errors`` and ``num_skipped``
    """
    chunk_idx, pdb_paths, out_dir, embed_dir, esm_model_name = args

    # Import inside function for multiprocessing
    from transformers import EsmModel, EsmTokenizer
    from concurrent.futures import ProcessPoolExecutor, as_completed
    from tqdm import tqdm
    import torch
    import json
    import os

    print(f"\n[Chunk {chunk_idx}] Starting processing of {len(pdb_paths)} PDBs")

    # Load ESM model for this process
    print(f"[Chunk {chunk_idx}] Loading ESM model...")
    tokenizer = EsmTokenizer.from_pretrained(esm_model_name)
    esm_model = EsmModel.from_pretrained(esm_model_name)
    esm_model.eval()

    # Pick the device; with several GPUs, chunks are spread round-robin
    if torch.cuda.is_available():
        gpu_id = chunk_idx % torch.cuda.device_count()
        device = torch.device(f'cuda:{gpu_id}')
        esm_model = esm_model.to(device)
        print(f"[Chunk {chunk_idx}] Using GPU {gpu_id}")
    else:
        device = torch.device('cpu')
        print(f"[Chunk {chunk_idx}] Using CPU")

    structure_dict = {}
    results = []
    failed_futures = 0

    # Phase 1: Parallel PDB parsing within this chunk.
    # cpu_count() // 4 is 0 on machines with fewer than 4 cores and the
    # executor rejects max_workers=0, so clamp to at least one worker.
    max_workers = max(1, min(8, cpu_count() // 4))
    with ProcessPoolExecutor(max_workers=max_workers) as ex:
        futures = {ex.submit(_process_pdb_path, p): p for p in pdb_paths}
        for fut in tqdm(as_completed(futures), total=len(futures),
                      desc=f"Chunk {chunk_idx} - PDB parsing", position=chunk_idx):
            try:
                # Futures yielded by as_completed() are already finished, so
                # result() returns (or raises) immediately.
                results.append(fut.result())
            except Exception as e:
                failed_futures += 1
                print(f"[Chunk {chunk_idx}] Error processing PDB: {e}")

    # Tally skips and errors (worker crashes count as errors too)
    skip_count = sum(1 for r in results if r[0] == "skip")
    error_count = failed_futures + sum(1 for r in results if r[0] == "error")

    if error_count > 0 or skip_count > 0:
        print(f"[Chunk {chunk_idx}] Skipped: {skip_count}, Errors: {error_count}")

    # Keep only successful items: (pdb_id, seq, coords_stacked, chain_id)
    ok_items = [r[1:] for r in results if r[0] == "ok"]

    # Phase 2: ESM embeddings
    print(f"[Chunk {chunk_idx}] Computing ESM embeddings for {len(ok_items)} proteins...")

    os.makedirs(embed_dir, exist_ok=True)

    # Items are walked in groups of batch_size; each sequence is still
    # embedded individually.
    batch_size = 8
    for i in tqdm(range(0, len(ok_items), batch_size),
                  desc=f"Chunk {chunk_idx} - ESM embeddings", position=chunk_idx):
        batch = ok_items[i:i+batch_size]

        for pdb_id, seq, coords_stacked, chain_id in batch:
            try:
                # Compute embedding
                with torch.no_grad():
                    embedding = get_esm_embedding(seq, esm_model, tokenizer, device)

                # Save embedding
                embed_path = os.path.join(embed_dir, f"{pdb_id}.pt")
                torch.save(embedding.cpu(), embed_path)  # Save on CPU to free GPU memory

                structure_dict[pdb_id] = {
                    "name": pdb_id,
                    "UniProt_id": "UNKNOWN",
                    "PDB_id": pdb_id,
                    "chain": chain_id,
                    "seq": seq,
                    "coords": coords_stacked,
                    "embed": embed_path
                }
            except Exception as e:
                print(f"[Chunk {chunk_idx}] ESM embedding failed for {pdb_id}: {e}")

        # Periodic garbage collection. i advances in steps of batch_size, so
        # with batch_size=8 this fires every 200 items.
        if i % 100 == 0:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Save chunk
    chunk_filename = f"structures_chunk_{chunk_idx:04d}.json"
    chunk_path = os.path.join(out_dir, chunk_filename)
    with open(chunk_path, "w") as f:
        json.dump(structure_dict, f, indent=2)

    print(f"[Chunk {chunk_idx}] ✅ Completed: {len(structure_dict)} structures saved")

    # Clean up GPU memory
    del esm_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return {
        "chunk_idx": chunk_idx,
        "filename": chunk_filename,
        "path": chunk_path,
        "num_structures": len(structure_dict),
        "num_errors": error_count,
        "num_skipped": skip_count
    }


def get_esm_embedding(seq, esm_model, tokenizer, device):
    """
    Embed one protein sequence with an ESM model and mean-pool the output.

    The sequence is tokenized (truncated to at most 1024 tokens), every
    tensor returned by the tokenizer is moved to ``device``, and the model's
    ``last_hidden_state`` is averaged over dimension 1 with autograd disabled.

    Args:
        seq: Protein sequence handed to the tokenizer
        esm_model: Model called with the tokenizer output as keyword arguments;
            its result must expose ``last_hidden_state``
        tokenizer: Tokenizer returning a mapping of tensors
        device: torch device the inputs are moved to

    Returns:
        torch.Tensor: ``last_hidden_state`` averaged over dim 1
    """
    encoded = tokenizer(seq, return_tensors="pt", truncation=True, max_length=1024)

    # Move each tokenizer output onto the target device before the forward pass.
    model_kwargs = {}
    for key, tensor in encoded.items():
        model_kwargs[key] = tensor.to(device)

    with torch.no_grad():
        hidden_states = esm_model(**model_kwargs).last_hidden_state
        # Mean over dim 1 collapses the per-token states into one vector.
        pooled = hidden_states.mean(dim=1)

    return pooled


def structureJSON_chunked(df, esm_model_name="facebook/esm2_t33_650M_UR50D",
                         chunk_size=100000, max_chunks_parallel=4,
                         out_dir="../data/structure_chunks/",
                         embed_dir="../data/embeddings/"):
    """
    Process structures in parallel chunks to avoid memory issues and maximize speed.

    The unique values of ``df["standardized_protein_pdb"]`` are split into
    consecutive chunks of at most ``chunk_size`` paths. Each chunk is handed to
    ``process_single_chunk`` as a tuple
    ``(chunk_idx, chunk_paths, out_dir, embed_dir, esm_model_name)``, either
    directly (one chunk in the batch) or through a ``multiprocessing.Pool``.
    A ``chunk_metadata.json`` summary is written to ``out_dir``.

    Args:
        df: DataFrame with a ``standardized_protein_pdb`` column of PDB paths
        esm_model_name: Name of ESM model to use
        chunk_size: Maximum entries per chunk (default 100000)
        max_chunks_parallel: Maximum number of chunks to process in parallel
        out_dir: Directory to save chunked JSON files
        embed_dir: Directory to save embeddings

    Returns:
        dict: Metadata about created chunks (totals plus one entry per chunk
        with cumulative ``start_idx``/``end_idx`` offsets over the number of
        structures saved, in ``chunk_idx`` order)
    """
    os.makedirs(out_dir, exist_ok=True)
    os.makedirs(embed_dir, exist_ok=True)

    # Get unique PDB paths (avoid duplicates)
    pdb_paths = df["standardized_protein_pdb"].unique().tolist()
    total_pdbs = len(pdb_paths)
    num_chunks = (total_pdbs + chunk_size - 1) // chunk_size  # ceiling division

    print("=" * 80)
    print(f"Processing {total_pdbs} unique PDBs in {num_chunks} chunks")
    print(f"Chunk size: {chunk_size}, Parallel chunks: {max_chunks_parallel}")
    print("=" * 80)

    # Prepare chunk arguments
    chunk_args = []
    for chunk_idx in range(num_chunks):
        start_idx = chunk_idx * chunk_size
        end_idx = min((chunk_idx + 1) * chunk_size, total_pdbs)
        chunk_paths = pdb_paths[start_idx:end_idx]

        chunk_args.append((
            chunk_idx,
            chunk_paths,
            out_dir,
            embed_dir,
            esm_model_name
        ))

    # Process chunks in parallel
    chunk_results = []

    # Determine optimal number of parallel processes
    num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
    if num_gpus > 0:
        # If we have GPUs, process one chunk per GPU
        parallel_chunks = min(max_chunks_parallel, num_gpus, num_chunks)
    else:
        # CPU only - limit parallelism to avoid memory issues
        parallel_chunks = min(max_chunks_parallel, cpu_count() // 4, num_chunks)

    # The min() above is 0 when there are no chunks, fewer than 4 CPU cores,
    # or a non-positive max_chunks_parallel. A zero step would make range()
    # below raise ValueError, so always process at least one chunk at a time.
    parallel_chunks = max(1, parallel_chunks)

    if num_gpus > 0:
        print(f"Using {num_gpus} GPUs, processing {parallel_chunks} chunks in parallel")
    else:
        print(f"Using CPU only, processing {parallel_chunks} chunks in parallel")

    # Process in batches of parallel chunks
    for batch_start in range(0, num_chunks, parallel_chunks):
        batch_end = min(batch_start + parallel_chunks, num_chunks)
        batch_args = chunk_args[batch_start:batch_end]

        print(f"\nProcessing chunk batch {batch_start+1}-{batch_end} of {num_chunks}")

        if len(batch_args) == 1:
            # Single chunk - process directly (no worker process needed)
            result = process_single_chunk(batch_args[0])
            chunk_results.append(result)
        else:
            # Multiple chunks - use multiprocessing, one worker per chunk
            with Pool(processes=len(batch_args)) as pool:
                results = pool.map(process_single_chunk, batch_args)
                chunk_results.extend(results)

        # Garbage collection between batches
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Create metadata
    chunk_metadata = {
        "num_chunks": num_chunks,
        "chunk_size": chunk_size,
        "total_structures": sum(r["num_structures"] for r in chunk_results),
        "total_errors": sum(r["num_errors"] for r in chunk_results),
        "total_skipped": sum(r["num_skipped"] for r in chunk_results),
        "chunks": []
    }

    # Add chunk info with proper indices (running offsets over saved structures)
    start_idx = 0
    for result in sorted(chunk_results, key=lambda x: x["chunk_idx"]):
        end_idx = start_idx + result["num_structures"]
        chunk_info = {
            "chunk_idx": result["chunk_idx"],
            "filename": result["filename"],
            "path": result["path"],
            "num_structures": result["num_structures"],
            "num_errors": result["num_errors"],
            "num_skipped": result["num_skipped"],
            "start_idx": start_idx,
            "end_idx": end_idx
        }
        chunk_metadata["chunks"].append(chunk_info)
        start_idx = end_idx

    # Save metadata
    metadata_path = os.path.join(out_dir, "chunk_metadata.json")
    with open(metadata_path, "w") as f:
        json.dump(chunk_metadata, f, indent=2)

    print(f"\n{'=' * 80}")
    print(f"✅ Processing complete!")
    print(f"  - Total structures: {chunk_metadata['total_structures']}")
    print(f"  - Total errors: {chunk_metadata['total_errors']}")
    print(f"  - Total skipped: {chunk_metadata['total_skipped']}")
    print(f"  - Metadata saved: {metadata_path}")
    print(f"{'=' * 80}")

    return chunk_metadata


# ============= Helper function for PDB processing =============
def _process_pdb_path(pdb_path):
    """
    Process a single PDB file to extract sequence and coordinates.
    This function runs in a separate process.

    Only the first model of the structure is read, and only the first chain
    that contains at least one amino-acid residue is returned. The PDB id is
    the file's basename without its extension.

    Args:
        pdb_path: Path to the PDB file

    Returns:
        tuple: (status, pdb_id, data...) where status is "ok", "skip", or "error":
            ("ok", pdb_id, seq, coords, chain_id) - ``coords`` holds one
                [N, CA, C, O] coordinate list per residue; residues missing
                any of those atoms get all-zero placeholders
            ("skip", pdb_id, reason) - no chain with amino-acid residues
            ("error", pdb_id, message) - parsing or extraction raised
    """
    from Bio.PDB import PDBParser, is_aa

    parser = PDBParser(QUIET=True)
    pdb_id = os.path.splitext(os.path.basename(pdb_path))[0]

    try:
        structure = parser.get_structure(pdb_id, pdb_path)

        # Get first model
        model = structure[0]

        # Process each chain; return on the first one that has residues
        for chain in model:
            residues = [r for r in chain if is_aa(r)]
            if not residues:
                continue

            # Extract sequence
            # NOTE(review): `seq1` is not imported here; it is expected to be
            # available at module scope - confirm.
            seq = ''.join(seq1(r.resname) for r in residues)

            # Extract coordinates [N, CA, C, O] for each residue
            coords = []
            for residue in residues:
                try:
                    coords.append([
                        residue[atom_name].coord.tolist()
                        for atom_name in ('N', 'CA', 'C', 'O')
                    ])
                except KeyError:
                    # Missing atoms - use zeros. Only the missing-atom lookup
                    # failure is handled here; anything else falls through to
                    # the outer handler and is reported as an "error" result.
                    coords.append([[0,0,0], [0,0,0], [0,0,0], [0,0,0]])

            return ("ok", pdb_id, seq, coords, chain.id)

        return ("skip", pdb_id, "No valid chains found")

    except Exception as e:
        # Report the failure to the parent instead of killing the worker.
        return ("error", pdb_id, str(e))


# ============= Optimized Chunk Loader (same as before) =============
class StructureChunkLoader:
    """
    On-demand reader for chunked structure dictionaries.

    At construction every chunk file listed in ``chunk_metadata.json`` is read
    once to build a ``pdb_id -> chunk_idx`` lookup. Afterwards chunks are
    loaded lazily and at most ``cache_size`` of them are kept in memory,
    evicting the least recently used one first.
    """
    def __init__(self, chunk_dir="../data/structure_chunks/", cache_size=2):
        self.chunk_dir = chunk_dir
        self.cache_size = cache_size
        self.cache = {}  # chunk_idx -> structure_dict
        self.cache_order = []  # least recently used first

        # Read the metadata written alongside the chunk files
        with open(os.path.join(chunk_dir, "chunk_metadata.json"), "r") as meta_file:
            self.metadata = json.load(meta_file)

        print(f"Loaded metadata: {self.metadata['total_structures']} structures in {self.metadata['num_chunks']} chunks")

        # Index every PDB id by the chunk that contains it
        self.pdb_to_chunk = {}
        for entry in self.metadata["chunks"]:
            entry_path = os.path.join(chunk_dir, entry["filename"])
            if not os.path.exists(entry_path):
                print(f"Warning: Chunk file not found: {entry_path}")
                continue
            with open(entry_path, "r") as chunk_file:
                contents = json.load(chunk_file)
            self.pdb_to_chunk.update(
                {pdb_id: entry["chunk_idx"] for pdb_id in contents}
            )

    def _load_chunk(self, chunk_idx):
        """Return the chunk's dict, reading it from disk on a cache miss."""
        if chunk_idx in self.cache:
            # Cache hit: mark as most recently used
            self.cache_order.remove(chunk_idx)
            self.cache_order.append(chunk_idx)
            return self.cache[chunk_idx]

        # Cache miss: read the chunk file named in the metadata
        filename = self.metadata["chunks"][chunk_idx]["filename"]
        with open(os.path.join(self.chunk_dir, filename), "r") as chunk_file:
            loaded = json.load(chunk_file)

        self.cache[chunk_idx] = loaded
        self.cache_order.append(chunk_idx)

        # Drop the least recently used chunk once over capacity
        if len(self.cache) > self.cache_size:
            evicted = self.cache_order.pop(0)
            del self.cache[evicted]
            gc.collect()  # release the evicted chunk's memory promptly

        return loaded

    def get(self, pdb_id):
        """Return the structure for ``pdb_id``, or None when it is unknown."""
        try:
            owner = self.pdb_to_chunk[pdb_id]
        except KeyError:
            return None
        return self._load_chunk(owner).get(pdb_id)

    def get_batch(self, pdb_ids):
        """Return ``{pdb_id: structure}`` for the known ids, loading each chunk once."""
        # Bucket the requested ids by owning chunk; unknown ids are dropped
        by_chunk = {}
        for pdb_id in pdb_ids:
            if pdb_id not in self.pdb_to_chunk:
                continue
            by_chunk.setdefault(self.pdb_to_chunk[pdb_id], []).append(pdb_id)

        # Visit each chunk once and pull out the requested entries
        found = {}
        for owner, wanted in by_chunk.items():
            contents = self._load_chunk(owner)
            for pdb_id in wanted:
                if pdb_id in contents:
                    found[pdb_id] = contents[pdb_id]

        return found

    def get_available_pdb_ids(self):
        """Return the set of every PDB id present in the lookup."""
        return set(self.pdb_to_chunk)


def build_mtl_dataset_optimized(df_fold, chunk_loader, task_cols=['pKi', 'pEC50', 'pKd', 'pIC50']):
    """
    Assemble the multi-task dataset for one fold from the chunked structures.

    All structures needed by the fold are fetched in a single
    ``chunk_loader.get_batch`` call; rows whose ``protein_id`` has no
    structure are counted and left out of ``data_list``.

    Args:
        df_fold: DataFrame with ``protein_id`` and ``standardized_ligand_sdf``
            columns plus any of the task columns
        chunk_loader: StructureChunkLoader instance
        task_cols: List of task columns; a task that is absent from a row or
            NaN is stored as NaN

    Returns:
        MTL_DTA dataset built from ``df_fold``, the collected entries and
        ``task_cols``
    """
    # Fetch every structure the fold needs in one chunk-grouped request
    fold_protein_ids = df_fold["protein_id"].tolist()
    print(f"Loading structures for {len(fold_protein_ids)} proteins...")
    available = chunk_loader.get_batch(fold_protein_ids)

    entries = []
    n_missing = 0
    for _, record in tqdm(df_fold.iterrows(), total=len(df_fold), desc="Building dataset"):
        pid = record["protein_id"]

        if pid in available:
            protein = featurize_protein_graph(available[pid])
            drug = featurize_drug(record["standardized_ligand_sdf"])

            # One label per task, NaN where the row has no measurement
            labels = {
                task: (
                    float(record[task])
                    if task in record and not pd.isna(record[task])
                    else np.nan
                )
                for task in task_cols
            }

            entries.append({"protein": protein, "drug": drug, **labels})
        else:
            n_missing += 1

    if n_missing > 0:
        print(f"Warning: Skipped {n_missing} entries due to missing structures")

    return MTL_DTA(df=df_fold, data_list=entries, task_cols=task_cols)


# ============= USAGE EXAMPLE =============


# Input

In [None]:
import pandas as pd
import os
import json
from Bio.PDB import PDBParser, is_aa
from tqdm import tqdm
from transformers import EsmModel, EsmTokenizer
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
parser = PDBParser(QUIET=True)
import os
import pandas as pd
import torch
from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error, r2_score
import matplotlib.pyplot as plt
import torch.nn.functional as F
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define tasks to train on
task_cols = ['pKi', 'pEC50', 'pKd', 'pKd (Wang, FEP)', 'pIC50', 'potency']


# Load your dataframe and shuffle the rows reproducibly (fixed seed)
df = pd.read_parquet("../data/standardized/standardized_input.parquet", engine="fastparquet")
df = df.sample(frac=1, random_state=42).reset_index(drop=True)

# Columns checked for missing values: both structure paths plus every task column
col_nan = ["standardized_protein_pdb", "standardized_ligand_sdf"] + task_cols
# df = df[df['is_experimental'] == True]
# Drop rows only when ALL of the columns above are NaN.
# NOTE(review): rows that have both paths but no task value survive this
# filter - confirm whether `subset=task_cols` was intended.
df = df.dropna(how = "all", subset=col_nan)
df = df.reset_index(drop=True)
# Keep only rows that have both a protein PDB path and a ligand SDF path
df = df[df["standardized_protein_pdb"].isna()==False]
df = df[df["standardized_ligand_sdf"].isna()==False]
# Cap the working set at the first 50000 remaining rows
df = df[:50000]

# Calculate task ranges from your dataframe
# (`prepare_mtl_experiment` is defined elsewhere in this notebook)
task_ranges = prepare_mtl_experiment(df, task_cols)


# Load ESM2

In [None]:
# Load the ESM-2 tokenizer and model, switch the model to inference mode and
# move it to `device` (defined in the input cell: CUDA when available, else
# CPU). An unconditional `.cuda()` would raise on a machine without a GPU.
model_name = "facebook/esm2_t33_650M_UR50D"
tokenizer = EsmTokenizer.from_pretrained(model_name)
esm_model = EsmModel.from_pretrained(model_name)
esm_model.eval().to(device)


# Generate structure dict

In [None]:

# Optional: Load and use the chunks later
# (the loader class is imported from the standalone processing module)
from parallel_structure_processing_optimized import StructureChunkLoader

# Create a chunk loader (caches 2 chunks in memory at a time).
# `chunk_dir` must contain `chunk_metadata.json` and the chunk JSON files.
chunk_loader = StructureChunkLoader(
    chunk_dir="../data/structure_chunks/",
    cache_size=2
)


# Check validity

In [None]:
import json

# Load the structure dictionary of the first chunk only (chunk 0000);
# structures stored in any other chunk file are not included here.
with open("../data/structure_chunks/structures_chunk_0000.json", "r") as f:
    pdb_structures = json.load(f)

In [None]:
# TODO: load the structures from every chunk file into a single `pdb_structures` dict (only chunk 0000 is loaded so far).

In [None]:
# Build a lookup set of the structure ids present in `pdb_structures`
pdb_keys = set(pdb_structures.keys())

# Add a canonical protein id column from the protein PDB path
# (file basename with its extension removed)
df["protein_id"] = df["standardized_protein_pdb"].apply(
    lambda p: os.path.splitext(os.path.basename(p))[0]
)

# Keep only rows with a matching protein structure
df_clean = df[df["protein_id"].isin(pdb_keys)].reset_index(drop=True)

# Report how many rows survive the filter versus the total
print("Available over total:", len(df_clean), len(df))

In [None]:
# Sanity check: total rows vs. rows that have a matching structure
print(len(df), len(df_clean))

# Cross validation

# Check validity

In [None]:
# Build a lookup set of the structure ids present in `pdb_structures`
# NOTE(review): this cell repeats the earlier "Check validity" cell - confirm
# whether both are needed.
pdb_keys = set(pdb_structures.keys())

# Add a canonical protein id column from the protein PDB path
# (file basename with its extension removed)
df["protein_id"] = df["standardized_protein_pdb"].apply(
    lambda p: os.path.splitext(os.path.basename(p))[0]
)

# Keep only rows with a matching protein structure
df_clean = df[df["protein_id"].isin(pdb_keys)].reset_index(drop=True)

# Report how many rows survive the filter versus the total
print("Available over total:", len(df_clean), len(df))

# Build data 

In [None]:
# 4. MODIFIED BUILD DATASET FUNCTION
def build_mtl_dataset(df_fold, pdb_structures, task_cols=['pKi', 'pEC50', 'pKd', 'pIC50']):
    """
    Build the multi-task dataset for one fold from an in-memory structure dict.

    Args:
        df_fold: DataFrame with ``standardized_protein_pdb`` and
            ``standardized_ligand_sdf`` columns plus any of the task columns
        pdb_structures: dict mapping protein id (PDB file basename without
            extension) to its structure entry
        task_cols: List of task columns; a task that is absent from a row or
            NaN is stored as NaN

    Returns:
        MTL_DTA dataset built from ``df_fold``, the collected entries and
        ``task_cols``

    Raises:
        KeyError: if a row's protein has no entry in ``pdb_structures``
    """
    data_list = []
    for _, row in df_fold.iterrows():
        # Structure dicts are keyed by the protein PDB file's basename, so the
        # lookup key must come from the protein path (not the ligand SDF path),
        # using the same derivation as the `protein_id` column.
        pdb_id = os.path.splitext(os.path.basename(row["standardized_protein_pdb"]))[0]
        protein_json = pdb_structures.get(pdb_id)
        if protein_json is None:
            # Fail with a clear message instead of handing None to the featurizer.
            raise KeyError(f"No structure found for protein '{pdb_id}'")
        protein = featurize_protein_graph(protein_json)
        drug = featurize_drug(row["standardized_ligand_sdf"])

        # Collect all task values
        task_values = {}
        for task in task_cols:
            if task in row and not pd.isna(row[task]):
                task_values[task] = float(row[task])
            else:
                task_values[task] = np.nan

        data_entry = {
            "protein": protein,
            "drug": drug,
        }
        data_entry.update(task_values)
        data_list.append(data_entry)

    return MTL_DTA(df=df_fold, data_list=data_list, task_cols=task_cols)


# Cross validation

In [None]:
#!/usr/bin/env python3
"""
Fixed training for Graph Neural Networks
DataParallel doesn't work well with graph batches, so we'll use single GPU
but with optimized batching
"""

import os
import gc
import torch
import torch.nn as nn
import torch_geometric
from sklearn.model_selection import KFold
import numpy as np
import pandas as pd
from tqdm import tqdm
import math
from sklearn.metrics import r2_score, mean_squared_error
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

# Clean up: release cached CUDA memory and run a garbage-collection pass
# before the training function is defined and used
torch.cuda.empty_cache()
gc.collect()

def train_fold_single_gpu(
    fold_idx, n_folds,
    df_train, df_valid, df_test,
    chunk_loader, task_cols, task_ranges,
    n_epochs=100, batch_size=256, lr=0.0005, patience=20,
    device_id=0
):
    """
    Training function using single GPU (Graph NNs don't work well with DataParallel)

    Builds the train/valid/test datasets, trains an ``MTL_DTAModel`` with
    Adam and ``MaskedMSELoss``, applies early stopping on the validation loss,
    restores the best weights and reports per-task RMSE / R² on the test set.

    Parameters:
    - fold_idx, n_folds: zero-based fold index and total fold count (used in logs)
    - df_train, df_valid, df_test: DataFrames for the three splits
    - chunk_loader: structure loader passed to ``build_mtl_dataset_optimized``
    - task_cols: task column names; also the order of the model's outputs
    - task_ranges: per-task ranges passed to ``MaskedMSELoss``
    - n_epochs: maximum number of epochs
    - batch_size: batch size for all three loaders
    - lr: Adam learning rate
    - patience: epochs without validation improvement before stopping
    - device_id: which GPU to use (0-7 for your 8 GPUs)

    Returns:
    - (fold_results, train_losses, val_losses): ``fold_results`` maps each
      task with test data to its predictions, targets, r2 and rmse; the two
      lists hold the average loss per epoch
    """

    print(f"\n{'='*60}")
    print(f"FOLD {fold_idx + 1}/{n_folds} - Using GPU {device_id}")
    print(f"{'='*60}")
    print(f"  Train: {len(df_train)} samples")
    print(f"  Valid: {len(df_valid)} samples")
    print(f"  Test:  {len(df_test)} samples")
    print(f"  Batch size: {batch_size}")

    # Create datasets
    print("\nBuilding datasets...")
    train_dataset = build_mtl_dataset_optimized(df_train, chunk_loader, task_cols)
    valid_dataset = build_mtl_dataset_optimized(df_valid, chunk_loader, task_cols)
    test_dataset = build_mtl_dataset_optimized(df_test, chunk_loader, task_cols)

    # Create data loaders - num_workers=0 to avoid file issues
    train_loader = torch_geometric.loader.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=True
    )

    valid_loader = torch_geometric.loader.DataLoader(
        valid_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True
    )

    test_loader = torch_geometric.loader.DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True
    )

    print(f"✓ DataLoaders created")

    # Set device
    device = torch.device(f"cuda:{device_id}")
    print(f"✓ Using device: {device}")

    # Create model - NO DataParallel for graph models
    print("Initializing model...")
    model = MTL_DTAModel(
        task_names=task_cols,
        prot_emb_dim=1280,
        prot_gcn_dims=[128, 256, 256],
        prot_fc_dims=[1024, 128],
        drug_node_in_dim=[66, 1],
        drug_node_h_dims=[128, 64],
        drug_fc_dims=[1024, 128],
        mlp_dims=[1024, 512],
        mlp_dropout=0.25
    ).to(device)

    print(f"✓ Model loaded on GPU {device_id}")

    # Initialize optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = MaskedMSELoss(task_ranges=task_ranges).to(device)

    # Training state
    best_val_loss = float('inf')
    patience_counter = 0
    best_model_state = None
    train_losses = []
    val_losses = []

    # Training loop
    print("\nStarting training...")
    pbar = tqdm(range(n_epochs), desc=f"Training", ncols=100)

    for epoch in pbar:
        # ========== TRAINING PHASE ==========
        model.train()
        train_loss = 0
        n_train_batches = 0

        for batch_idx, batch in tqdm(enumerate(train_loader), desc = "Batch iter."):
            # Move batch to GPU
            xd = batch['drug'].to(device)
            xp = batch['protein'].to(device)
            y = batch['y'].to(device)

            optimizer.zero_grad()
            pred = model(xd, xp)
            loss = criterion(pred, y)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            n_train_batches += 1

            # Clear cache periodically
            if batch_idx % 20 == 0:
                torch.cuda.empty_cache()

        avg_train_loss = train_loss / n_train_batches if n_train_batches > 0 else 0
        train_losses.append(avg_train_loss)

        # ========== VALIDATION PHASE ==========
        model.eval()
        val_loss = 0
        n_val_batches = 0

        with torch.no_grad():
            for batch in valid_loader:
                xd = batch['drug'].to(device)
                xp = batch['protein'].to(device)
                y = batch['y'].to(device)

                pred = model(xd, xp)
                loss = criterion(pred, y)
                val_loss += loss.item()
                n_val_batches += 1

        avg_val_loss = val_loss / n_val_batches if n_val_batches > 0 else float('inf')
        val_losses.append(avg_val_loss)

        # Update progress bar ('best' is the best value before this epoch)
        pbar.set_postfix({
            'train': f"{avg_train_loss:.4f}",
            'val': f"{avg_val_loss:.4f}",
            'best': f"{best_val_loss:.4f}"
        })

        # Early stopping check
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # state_dict() returns references to the live parameter tensors,
            # which later optimizer steps overwrite in place. Clone them so the
            # snapshot really holds the weights of the best epoch.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"\nEarly stopping at epoch {epoch+1}")
            break

        # Periodic cleanup
        if epoch % 10 == 0:
            torch.cuda.empty_cache()

    # Load best model
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    # ========== EVALUATION PHASE ==========
    print(f"\nEvaluating on test set...")
    model.eval()

    task_predictions = {task: [] for task in task_cols}
    task_targets = {task: [] for task in task_cols}

    with torch.no_grad():
        for batch in test_loader:
            xd = batch['drug'].to(device)
            xp = batch['protein'].to(device)
            y = batch['y'].to(device)

            pred = model(xd, xp)

            # Collect predictions for each task (column i <-> task_cols[i]),
            # keeping only the samples that have a label for that task
            for i, task in enumerate(task_cols):
                mask = ~torch.isnan(y[:, i])
                if mask.sum() > 0:
                    task_preds = pred[mask, i].cpu().numpy()
                    task_trues = y[mask, i].cpu().numpy()
                    task_predictions[task].extend(task_preds)
                    task_targets[task].extend(task_trues)

    # Calculate and print metrics
    fold_results = {}
    print(f"\n{'='*50}")
    print(f"Fold {fold_idx + 1} Results:")
    print(f"{'='*50}")

    for task in task_cols:
        if len(task_predictions[task]) > 0:
            preds = np.array(task_predictions[task])
            targets = np.array(task_targets[task])

            r2 = r2_score(targets, preds)
            rmse = math.sqrt(mean_squared_error(targets, preds))

            fold_results[task] = {
                'predictions': preds,
                'targets': targets,
                'r2': r2,
                'rmse': rmse
            }

            print(f"{task:20s} | RMSE: {rmse:6.3f} | R²: {r2:6.3f} | n={len(preds):5d}")

    # Clean up
    del model
    del optimizer
    del train_loader
    del valid_loader
    del test_loader
    torch.cuda.empty_cache()
    gc.collect()

    return fold_results, train_losses, val_losses


In [None]:


# ============= MAIN TRAINING SCRIPT =============
# Banner plus the number of CUDA devices visible to PyTorch
print(f"\n{'='*70}")
print("GRAPH NEURAL NETWORK TRAINING - OPTIMIZED")
print(f"{'='*70}")
print(f"GPUs available: {torch.cuda.device_count()}")

# Configuration (passed positionally to train_fold_single_gpu in the CV loop)
BATCH_SIZE = 64     # Can use larger batch on single GPU
N_EPOCHS = 100       # maximum epochs per fold
LEARNING_RATE = 0.0001
PATIENCE = 30  # early-stopping patience, in epochs
N_FOLDS = 5

print(f"\nConfiguration:")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Epochs: {N_EPOCHS}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Patience: {PATIENCE}")
print(f"  Folds: {N_FOLDS}")

# Initialize results: per task, the per-fold metrics and the pooled
# predictions/targets across all folds
cv_results = {
    task: {
        'r2_list': [],
        'rmse_list': [],
        'all_predictions': [],
        'all_targets': []
    } for task in task_cols
}

# K-Fold cross-validation (shuffled with a fixed seed for reproducible splits)
kf = KFold(n_splits=N_FOLDS, shuffle=True, random_state=42)

print(f"\nStarting {N_FOLDS}-fold cross-validation...")
print("="*70)


In [None]:

# We'll rotate through GPUs for different folds
# NOTE(review): with no CUDA device `n_gpus` is 0 and the modulo below raises
# ZeroDivisionError - this cell assumes at least one GPU.
n_gpus = torch.cuda.device_count()

for fold_idx, (train_idx, test_idx) in enumerate(kf.split(df_clean)):
    # Select GPU for this fold (rotate through available GPUs)
    gpu_id = fold_idx % n_gpus
    
    # Split data
    df_train = df_clean.iloc[train_idx].reset_index(drop=True)
    df_test = df_clean.iloc[test_idx].reset_index(drop=True)
    
    # Create validation set (10% of training); the sampled rows are then
    # removed from the training frame by index
    valid_size = int(0.1 * len(df_train))
    df_valid = df_train.sample(n=valid_size, random_state=42)
    df_train = df_train.drop(df_valid.index).reset_index(drop=True)
    
    # Train fold on selected GPU
    fold_results, train_losses, val_losses = train_fold_single_gpu(
        fold_idx, N_FOLDS,
        df_train, df_valid, df_test,
        chunk_loader, task_cols, task_ranges,
        N_EPOCHS, BATCH_SIZE, LEARNING_RATE, PATIENCE,
        device_id=gpu_id
    )
    
    # Store results (tasks without test data in this fold are absent
    # from fold_results and are skipped)
    for task in task_cols:
        if task in fold_results:
            cv_results[task]['r2_list'].append(fold_results[task]['r2'])
            cv_results[task]['rmse_list'].append(fold_results[task]['rmse'])
            cv_results[task]['all_predictions'].extend(fold_results[task]['predictions'])
            cv_results[task]['all_targets'].extend(fold_results[task]['targets'])
    
    # Clean up after each fold
    torch.cuda.empty_cache()
    gc.collect()

In [None]:


# ========== FINAL SUMMARY ==========
print(f"\n{'='*70}")
print(f"CROSS-VALIDATION COMPLETE")
print(f"{'='*70}")

# One summary row per task that produced metrics in at least one fold
summary_results = []
for task in task_cols:
    if len(cv_results[task]['r2_list']) > 0:
        # Mean and standard deviation of the per-fold metrics
        avg_r2 = np.mean(cv_results[task]['r2_list'])
        std_r2 = np.std(cv_results[task]['r2_list'])
        avg_rmse = np.mean(cv_results[task]['rmse_list'])
        std_rmse = np.std(cv_results[task]['rmse_list'])
        # Number of test samples pooled over all folds
        n_samples = len(cv_results[task]['all_targets'])
        
        summary_results.append({
            'Task': task,
            'R² (mean±std)': f"{avg_r2:.3f}±{std_r2:.3f}",
            'RMSE (mean±std)': f"{avg_rmse:.3f}±{std_rmse:.3f}",
            'N samples': n_samples
        })
        
        print(f"\n{task}:")
        print(f"  R²:    {avg_r2:.3f} ± {std_r2:.3f}")
        print(f"  RMSE:  {avg_rmse:.3f} ± {std_rmse:.3f}")
        print(f"  Total samples: {n_samples}")

# Create summary dataframe
summary_df = pd.DataFrame(summary_results)
print("\n" + "="*70)
print("SUMMARY TABLE:")
print(summary_df.to_string(index=False))
print("="*70)

# ========== VISUALIZATION ==========
# Draws one predicted-vs-experimental scatter plot per task that has pooled
# test predictions in `cv_results`, laid out on a grid of up to 3 columns.
print("\nCreating visualizations...")

# Individual task plots
tasks_with_data = [task for task in task_cols if len(cv_results[task]['all_targets']) > 0]
n_tasks_with_data = len(tasks_with_data)

if n_tasks_with_data == 0:
    # Nothing to plot; a 0-column grid would otherwise divide by zero.
    print("No task has test predictions; skipping plots.")
else:
    n_cols = min(3, n_tasks_with_data)
    n_rows = (n_tasks_with_data + n_cols - 1) // n_cols

    # squeeze=False always returns a 2-D array of Axes, so the single-plot,
    # single-row and multi-row layouts are all handled by one flatten().
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 4*n_rows), squeeze=False)
    axes = axes.flatten()

    for ax, task in zip(axes, tasks_with_data):
        targets = np.array(cv_results[task]['all_targets'])
        preds = np.array(cv_results[task]['all_predictions'])

        # Scatter plot
        ax.scatter(targets, preds, alpha=0.4, s=10, color='blue')

        # Diagonal line (perfect prediction) spanning the joint value range
        min_val = min(targets.min(), preds.min())
        max_val = max(targets.max(), preds.max())
        ax.plot([min_val, max_val], [min_val, max_val], 'r--', lw=1, alpha=0.7)

        # Calculate overall metrics on the predictions pooled over all folds
        overall_r2 = r2_score(targets, preds)
        overall_rmse = math.sqrt(mean_squared_error(targets, preds))

        # Labels and title
        ax.set_xlabel(f'Experimental {task}')
        ax.set_ylabel(f'Predicted {task}')
        ax.set_title(f'{task}\nR²={overall_r2:.3f}, RMSE={overall_rmse:.3f}')
        ax.grid(True, alpha=0.3)

    # Hide unused subplots in the last grid row
    for ax in axes[n_tasks_with_data:]:
        ax.set_visible(False)

    plt.suptitle(f'{N_FOLDS}-Fold Cross-Validation Results', fontsize=14, y=1.02)
    plt.tight_layout()
    plt.show()

print(f"\n✓ Training complete!")
print(f"✓ Results saved in 'cv_results' variable")

In [None]:
1