# Inference simple

In [4]:
import torch
import pandas as pd
import numpy as np
from pathlib import Path
import json
from tqdm import tqdm
from torch_geometric.data import Batch

class DTAPredictor:
    """Simple predictor for drug-target affinity."""
    
    def __init__(self, model_path, config_path=None, device='cuda'):
        """
        Args:
            model_path: Path to saved model checkpoint
            config_path: Path to config file (optional)
            device: Device to run on
        """
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        
        # Load config
        if config_path and Path(config_path).exists():
            with open(config_path, 'r') as f:
                self.config = json.load(f)
        else:
            # Default config
            self.config = {
                'task_cols': ['pKi', 'pEC50', 'pKd', 'pIC50', 'pKd (Wang, FEP)', 'potency'],
                'model_config': {
                    'prot_emb_dim': 1280,
                    'prot_gcn_dims': [128, 256, 256],
                    'prot_fc_dims': [1024, 128],
                    'drug_node_in_dim': [66, 1],
                    'drug_node_h_dims': [128, 64],
                    'drug_edge_in_dim': [16, 1],
                    'drug_edge_h_dims': [32, 1],
                    'drug_fc_dims': [1024, 128],
                    'mlp_dims': [1024, 512],
                    'mlp_dropout': 0.25
                }
            }
        
        # Load model
        
        self.model = MTL_DTAModel(
            task_names=self.config['task_cols'],
            **self.config['model_config']
        ).to(self.device)
        
        # Load weights
        checkpoint = torch.load(model_path, map_location=self.device)
        if 'model_state_dict' in checkpoint:
            self.model.load_state_dict(checkpoint['model_state_dict'])
        else:
            self.model.load_state_dict(checkpoint)
        self.model.eval()
        
        # Initialize processors
        self._init_processors()
    
    def _init_processors(self):
        """Initialize standardization and featurization."""
        from rdkit import Chem
        from rdkit.Chem import AllChem
        from Bio.PDB import PDBParser
        self.pdb_parser = PDBParser(QUIET=True)
        
    def predict(self, protein_ligand_pairs):
        """
        Predict affinities for protein-ligand pairs.
        
        Args:
            protein_ligand_pairs: List of tuples (protein_path, ligand_path) or
                                 DataFrame with 'protein_path' and 'ligand_path' columns
        
        Returns:
            DataFrame with predictions for each task
        """
        # Convert to DataFrame if needed
        if isinstance(protein_ligand_pairs, list):
            df = pd.DataFrame(protein_ligand_pairs, 
                            columns=['protein_path', 'ligand_path'])
        else:
            df = protein_ligand_pairs.copy()
        
        # Process and predict
        predictions = []
        
        for idx, row in tqdm(df.iterrows(), total=len(df), desc="Predicting"):
            try:
                # Featurize
                drug_graph = self._featurize_drug(row['ligand_path'])
                prot_graph = self._featurize_protein(row['protein_path'])
                
                # Create batch
                

                drug_batch = Batch.from_data_list([drug_graph]).to(self.device)
                prot_batch = Batch.from_data_list([prot_graph]).to(self.device)
                
                # Predict
                with torch.no_grad():
                    pred = self.model(drug_batch, prot_batch)
                    pred = pred.cpu().numpy()[0]  # Get first (only) batch item
                
                # Store predictions
                pred_dict = {task: pred[i] for i, task in enumerate(self.config['task_cols'])}
                predictions.append(pred_dict)
                
            except Exception as e:
                print(f"Error processing {idx}: {e}")
                pred_dict = {task: np.nan for task in self.config['task_cols']}
                predictions.append(pred_dict)
        
        # Create results DataFrame
        results_df = pd.DataFrame(predictions)
        results_df = pd.concat([df[['protein_path', 'ligand_path']], results_df], axis=1)
        
        return results_df
    
    def _featurize_drug(self, sdf_path):
        """Quick drug featurization."""
        import torch_geometric
        from rdkit import Chem
        import torch_cluster
        
        mol = Chem.MolFromMolFile(sdf_path)
        if mol is None:
            raise ValueError(f"Could not read molecule from {sdf_path}")
        
        # Get coordinates
        conf = mol.GetConformer()
        coords = torch.tensor(conf.GetPositions(), dtype=torch.float32)
        
        # Simple atom features (simplified)
        atom_features = []
        for atom in mol.GetAtoms():
            features = [
                atom.GetAtomicNum(),
                atom.GetDegree(),
                atom.GetTotalNumHs(),
                int(atom.GetIsAromatic())
            ]
            atom_features.append(features)
        
        atom_features = torch.tensor(atom_features, dtype=torch.float32)
        
        # Create edges (radius graph)
        edge_index = torch_cluster.radius_graph(coords, r=4.5)
        
        # Create minimal graph
        data = torch_geometric.data.Data(
            x=coords,
            edge_index=edge_index,
            node_s=atom_features,
            node_v=coords.unsqueeze(1),
            edge_s=torch.ones(edge_index.size(1), 16),  # Dummy edge features
            edge_v=torch.ones(edge_index.size(1), 1, 3)  # Dummy edge vectors
        )
        
        return data
    
    def _featurize_protein(self, pdb_path):
        """Quick protein featurization."""
        import torch_geometric
        import torch_cluster
        
        # Parse PDB
        structure = self.pdb_parser.get_structure('protein', pdb_path)
        
        # Get CA coordinates
        ca_coords = []
        for model in structure:
            for chain in model:
                for residue in chain:
                    if 'CA' in residue:
                        ca_coords.append(residue['CA'].coord)
                if ca_coords:  # Use first chain only
                    break
            break
        
        coords = torch.tensor(ca_coords, dtype=torch.float32)
        
        # Create edges
        edge_index = torch_cluster.radius_graph(coords, r=8.0)
        
        # Dummy features
        seq_len = len(coords)
        seq = torch.zeros(seq_len, dtype=torch.long)  # Dummy sequence
        node_s = torch.randn(seq_len, 6)  # Dummy dihedral features
        node_v = torch.randn(seq_len, 3, 3)  # Dummy orientations
        seq_emb = torch.randn(seq_len, 1280)  # Dummy ESM embeddings
        
        # Create graph
        data = torch_geometric.data.Data(
            x=coords,
            seq=seq,
            edge_index=edge_index,
            node_s=node_s,
            node_v=node_v,
            edge_s=torch.randn(edge_index.size(1), 39),
            edge_v=torch.randn(edge_index.size(1), 1, 3),
            seq_emb=seq_emb
        )
        
        return data


# ============ SIMPLE USAGE ============

def predict_affinity(
    model_path,
    protein_ligand_pairs,
    output_path=None,
    device='cuda'
):
    """
    Simple function to predict affinities.
    
    Args:
        model_path: Path to trained model
        protein_ligand_pairs: List of (protein_pdb, ligand_sdf) or DataFrame
        output_path: Optional path to save predictions
        device: Device to use
    
    Returns:
        DataFrame with predictions
    """
    # Initialize predictor
    predictor = DTAPredictor(model_path, device=device)
    
    # Predict
    results = predictor.predict(protein_ligand_pairs)
    
    # Save if requested
    if output_path:
        results.to_csv(output_path, index=False)
        print(f"Predictions saved to {output_path}")
    
    return results

In [5]:
from Bio.PDB import PDBParser
from Bio.PDB import MMCIFParser, PDBIO, Select

class ProteinSelect(Select):
    def accept_residue(self, residue):
        return residue.get_id()[0] == ' '

    
class LigandSelect(Select):
    def accept_residue(self, residue):
        return residue.get_id()[0] != ' '
        
def check_files_exist_and_valid(protein_path, ligand_path, min_size_bytes=50):
    """Check if both protein and ligand files exist and are valid"""
    try:
        if not (os.path.exists(protein_path) and os.path.exists(ligand_path)):
            return False
        
        protein_size = os.path.getsize(protein_path)
        ligand_size = os.path.getsize(ligand_path)
        
        if protein_size < min_size_bytes or ligand_size < min_size_bytes:
            return False
            
        # Quick content validation
        try:
            with open(protein_path, 'r') as f:
                first_line = f.readline().strip()
                if not (first_line.startswith(('ATOM', 'HETATM', 'MODEL', 'HEADER'))):
                    return False
            
            with open(ligand_path, 'r') as f:
                content = f.read(100)
                if len(content.strip()) < 10:
                    return False
                    
        except Exception:
            return False
            
        return True
        
    except Exception:
        return False

def process_single(input_path, protein_dir, ligand_dir):
    """Process a single CIF file"""
    input_filename = os.path.basename(input_path)
    
    # Parse CIF file
    if input_path.endswith(".cif"):
        parser = MMCIFParser(QUIET=True)
        pdb_filename = input_filename.replace(".cif", ".pdb")
        sdf_filename = input_filename.replace(".cif", ".sdf")

    if input_path.endswith(".pdb"):
        parser = PDBParser(QUIET=True)
        pdb_filename = input_filename.replace(".pdb", ".pdb")
        sdf_filename = input_filename.replace(".pdb", ".sdf")

        
    
    pdb_path = os.path.join(str(protein_dir), str(pdb_filename))
    sdf_path = os.path.join(str(ligand_dir), str(sdf_filename))
    
    # Skip if files already exist and are valid
    if check_files_exist_and_valid(pdb_path, sdf_path):
        return {
            'protein_path': pdb_path,
            'ligand_path': sdf_path,
            'status': 'already_exists',
            'success': True
        }
    
    # Clean up any partially created files
    for path in [pdb_path, sdf_path]:
        if os.path.exists(path) :
            try:
                os.remove(path)
            except:
                pass
    
    try:
        # Parse CIF file
        structure = parser.get_structure("complex", input_path)
        
        # Write protein PDB
        io = PDBIO()
        io.set_structure(structure)
        io.save(pdb_path, select=ProteinSelect())
        
        # Write ligand to temporary PDB first
        ligand_temp_pdb = sdf_path.replace(".sdf", "_temp.pdb")
        io.save(ligand_temp_pdb, select=LigandSelect())
        
        # Convert ligand PDB to SDF using RDKit
        mol = rdmolfiles.MolFromPDBFile(ligand_temp_pdb, removeHs=False)
        
        if mol is not None:
            writer = Chem.SDWriter(sdf_path)
            writer.write(mol)
            writer.close()
            
            # Clean up temp file
            if os.path.exists(ligand_temp_pdb):
                os.remove(ligand_temp_pdb)
            
            # Final validation
            if check_files_exist_and_valid(pdb_path, sdf_path):
                return {
                    'protein_path': pdb_path,
                    'ligand_path': sdf_path,
                    'status': 'converted_successfully',
                    'success': True
                }
            else:
                return {
                    'protein_path': pdb_path,
                    'ligand_path': sdf_path,
                    'status': 'validation_failed',
                    'success': False
                }
        else:
            # RDKit conversion failed
            if os.path.exists(ligand_temp_pdb):
                os.remove(ligand_temp_pdb)
            return {
                'protein_path': pdb_path,
                'ligand_path': sdf_path,
                'status': 'rdkit_failed',
                'success': False
            }
            
    except Exception as e:
        # Clean up any partial files
        for path in [pdb_path, sdf_path]:
            if os.path.exists(path):
                try:
                    os.remove(path)
                except:
                    pass
                    
        return {
            'protein_path': pdb_path,
            'ligand_path': sdf_path,
            'status': f'error: {str(e)}',
            'success': False
        }


In [6]:
import os
import sys
import json
import gc
import warnings
warnings.filterwarnings('ignore')

import pandas as pd
import numpy as np
import torch
import torch_geometric
from pathlib import Path
from sklearn.model_selection import KFold, train_test_split
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

# Add the parent directory to the Python path
# This allows importing the gnn_dta_mtl package
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(''))))

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")

# Import your package - use absolute import instead of relative
from gnn_dta_mtl import (
    MTL_DTAModel, DTAModel,
    MTL_DTA, DTA,
    CrossValidator, MTLTrainer,
    StructureStandardizer, StructureProcessor, StructureChunkLoader,
    ESMEmbedder,
    add_molecular_properties_parallel,
    compute_ligand_efficiency,
    compute_mean_ligand_efficiency,
    filter_by_properties,
    prepare_mtl_experiment,
    build_mtl_dataset, build_mtl_dataset_optimized,
    evaluate_model,
    plot_results, plot_predictions, create_summary_report,
    ExperimentLogger,
    save_model, save_results, create_output_dir,
    featurize_drug, featurize_protein_graph
)

# Set random seeds for reproducibility
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

Using device: cuda
GPU: Tesla T4
Number of GPUs: 4


In [27]:
import torch
import pandas as pd
import numpy as np
from pathlib import Path
import json
from tqdm import tqdm
import torch_geometric
from torch_geometric.data import Batch
from Bio.PDB import PDBParser
from Bio.SeqUtils import seq1
import os
import tempfile

class DTAPredictor:
    """Simple predictor for drug-target affinity."""
    
    def __init__(self, model_path, config_path=None, device='cuda', esm_model=None):
        """
        Args:
            model_path: Path to saved model checkpoint
            config_path: Path to config file (optional)
            device: Device to run on
            esm_model: Pre-loaded ESM model (optional)
        """
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        
        # Load config
        if config_path and Path(config_path).exists():
            with open(config_path, 'r') as f:
                self.config = json.load(f)
        else:
            # Default config
            self.config = {
                'task_cols': ['pKi', 'pEC50', 'pKd', 'pIC50', 'pKd (Wang, FEP)', 'potency'],
                'model_config': {
                    'prot_emb_dim': 1280,
                    'prot_gcn_dims': [128, 256, 256],
                    'prot_fc_dims': [1024, 128],
                    'drug_node_in_dim': [66, 1],
                    'drug_node_h_dims': [128, 64],
                    'drug_edge_in_dim': [16, 1],
                    'drug_edge_h_dims': [32, 1],
                    'drug_fc_dims': [1024, 128],
                    'mlp_dims': [1024, 512],
                    'mlp_dropout': 0.25
                }
            }
        
        # Load model
        self.model = MTL_DTAModel(
            task_names=self.config['task_cols'],
            **self.config['model_config']
        ).to(self.device)
        
        # Load weights
        checkpoint = torch.load(model_path, map_location=self.device)
        if 'model_state_dict' in checkpoint:
            self.model.load_state_dict(checkpoint['model_state_dict'])
        else:
            self.model.load_state_dict(checkpoint)
        self.model.eval()
        
        # Initialize ESM model for protein embeddings
        if esm_model is not None:
            self.esm_model = esm_model
            self.tokenizer = None
        else:
            from transformers import EsmModel, EsmTokenizer
            model_name = "facebook/esm2_t33_650M_UR50D"
            self.tokenizer = EsmTokenizer.from_pretrained(model_name)
            self.esm_model = EsmModel.from_pretrained(model_name)
            self.esm_model.eval()
            self.esm_model = self.esm_model.to(self.device)
        
        # Initialize parser
        self.parser = PDBParser(QUIET=True)
        
        print(f"✓ Model loaded from {model_path}")
        print(f"✓ Using device: {self.device}")
        
    def extract_backbone_coords(self, structure, pdb_id, pdb_path):
        """Extract backbone coordinates from protein structure (from your code)."""
        coords = {"N": [], "CA": [], "C": [], "O": []}
        seq = ""
        
        model = structure[0]
        
        # Find valid chain
        valid_chain = None
        for chain in model:
            for res in chain:
                if res.id[0] == ' ':  # Standard amino acid
                    valid_chain = chain
                    break
            if valid_chain:
                break
        
        if valid_chain is None:
            return None, None, None
        
        chain_id = valid_chain.id
        
        # Extract coordinates and sequence
        for res in valid_chain:
            if res.id[0] != ' ':  # Skip non-standard residues
                continue
            
            # Get one-letter code
            try:
                seq += seq1(res.resname)
            except:
                seq += 'X'  # Unknown residue
            
            # Get backbone atom coordinates
            for atom_name in ["N", "CA", "C", "O"]:
                if atom_name in res:
                    coords[atom_name].append(res[atom_name].coord.tolist())
                else:
                    coords[atom_name].append([float("nan")] * 3)
        
        return seq, coords, chain_id
    
    def get_esm_embedding(self, seq):
        """Get ESM embedding for sequence (from your code)."""
        if self.tokenizer is None:
            from transformers import EsmTokenizer
            self.tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
        
        inputs = self.tokenizer(seq, return_tensors="pt", truncation=True, max_length=1024)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        
        with torch.no_grad():
            outputs = self.esm_model(**inputs)
            # Remove CLS and EOS tokens
            embedding = outputs.last_hidden_state[0, 1:-1]
        
        return embedding
    
    def create_protein_structure_dict(self, pdb_path):
        """Create protein structure dictionary with real ESM embeddings."""
        pdb_id = os.path.basename(pdb_path).split('.')[0]
        
        structure = self.parser.get_structure(pdb_id, pdb_path)
        seq, coords, chain_id = self.extract_backbone_coords(structure, pdb_id, pdb_path)
        
        if seq is None:
            raise ValueError(f"No valid chain found in {pdb_path}")
        
        # Stack coordinates in order: N, CA, C, O
        coords_stacked = []
        for i in range(len(coords["N"])):
            coord_group = []
            for atom in ["N", "CA", "C", "O"]:
                coord_group.append(coords[atom][i])
            coords_stacked.append(coord_group)
        
        # Get ESM embedding
        embedding = self.get_esm_embedding(seq)
        
        # Save embedding temporarily
        with tempfile.NamedTemporaryFile(suffix='.pt', delete=False) as f:
            torch.save(embedding.cpu(), f.name)
            embed_path = f.name
        
        structure_dict = {
            "name": pdb_id,
            "UniProt_id": "UNKNOWN",
            "PDB_id": pdb_id,
            "chain": chain_id,
            "seq": seq,
            "coords": coords_stacked,
            "embed": embed_path
        }
        
        return structure_dict
    
    #def predict(self, protein_ligand_pairs):
    #    """
    #    Predict affinities for protein-ligand pairs.
    #    
    #    Args:
    #        protein_ligand_pairs: List of tuples (protein_path, ligand_path) or
    #                             DataFrame with 'protein_path' and 'ligand_path' columns
    #    
    #    Returns:
    #        DataFrame with predictions for each task
    #    """
    #    # Import featurization functions
    #    
    #    # Convert to DataFrame if needed
    #    if isinstance(protein_ligand_pairs, list):
    #        df = pd.DataFrame(protein_ligand_pairs, 
    #                        columns=['protein_path', 'ligand_path'])
    #    else:
    #        df = protein_ligand_pairs.copy()
    #    
    #    # Process and predict
    #    predictions = []
    #    
    #    for idx, row in tqdm(df.iterrows(), total=len(df), desc="Predicting"):
    #        try:
    #            # Featurize drug
    #            drug_graph = featurize_drug(row['ligand_path'])
    #            
    #            # Create protein structure with real ESM embeddings
    #            protein_struct = self.create_protein_structure_dict(row['protein_path'])
    #            
    #            # Featurize protein
    #            prot_graph = featurize_protein_graph(protein_struct)
    #            
    #            # Batch the data properly
    #            drug_batch = Batch.from_data_list([drug_graph]).to(self.device)
    #            prot_batch = Batch.from_data_list([prot_graph]).to(self.device)
    #            
    #            # Predict
    #            with torch.no_grad():
    #                pred = self.model(drug_batch, prot_batch)
    #                pred = pred.cpu().numpy()[0]  # Get first (only) batch item
    #            
    #            # Store predictions
    #            pred_dict = {task: float(pred[i]) for i, task in enumerate(self.config['task_cols'])}
    #            predictions.append(pred_dict)
    #            
    #            # Clean up temp embedding file
    #            if os.path.exists(protein_struct['embed']):
    #                os.remove(protein_struct['embed'])
    #            
    #        except Exception as e:
    #            print(f"Error processing {idx} ({row['protein_path']}): {e}")
    #            pred_dict = {task: np.nan for task in self.config['task_cols']}
    #            predictions.append(pred_dict)
    #    
    #    # Create results DataFrame
    #    results_df = pd.DataFrame(predictions)
    #    results_df = pd.concat([df[['protein_path', 'ligand_path']], results_df], axis=1)
    #    
    #    return results_df

    def predict(self, protein_ligand_pairs):
        """
        Predict affinities for protein-ligand pairs.

        Args:
            protein_ligand_pairs: List of tuples (protein_path, ligand_path) or
                                 DataFrame with 'protein_path' and 'ligand_path' columns

        Returns:
            DataFrame with predictions for each task
        """
        # Convert to DataFrame if needed
        if isinstance(protein_ligand_pairs, list):
            df = pd.DataFrame(protein_ligand_pairs, 
                            columns=['protein_path', 'ligand_path'])
        else:
            df = protein_ligand_pairs.copy()

        # Step 1: Featurize all pairs
        drug_graphs = []
        prot_graphs = []
        valid_indices = []
        temp_embed_files = []

        print("Featurizing all protein-ligand pairs...")
        for idx, row in tqdm(df.iterrows(), total=len(df), desc="Featurizing"):
            try:
                # Featurize drug
                drug_graph = featurize_drug(row['ligand_path'])

                # Create protein structure with real ESM embeddings
                protein_struct = self.create_protein_structure_dict(row['protein_path'])
                temp_embed_files.append(protein_struct['embed'])

                # Featurize protein
                prot_graph = featurize_protein_graph(protein_struct)

                # Store graphs
                drug_graphs.append(drug_graph)
                prot_graphs.append(prot_graph)
                valid_indices.append(idx)

            except Exception as e:
                print(f"Error featurizing {idx} ({row['protein_path']}): {e}")

        # Step 2: Batch all data
        if drug_graphs:
            drug_batch = Batch.from_data_list(drug_graphs).to(self.device)
            prot_batch = Batch.from_data_list(prot_graphs).to(self.device)

            # Step 3: Predict on full batch
            print("Running batch prediction...")
            with torch.no_grad():
                batch_preds = self.model(drug_batch, prot_batch)
                batch_preds = batch_preds.cpu().numpy()

        # Step 4: Format results
        predictions = []
        pred_idx = 0

        for idx in range(len(df)):
            if idx in valid_indices:
                # Get predictions for this sample
                pred = batch_preds[pred_idx]
                pred_dict = {task: float(pred[i]) for i, task in enumerate(self.config['task_cols'])}
                pred_idx += 1
            else:
                # Failed featurization - use NaN
                pred_dict = {task: np.nan for task in self.config['task_cols']}

            predictions.append(pred_dict)

        # Clean up temp embedding files
        for embed_file in temp_embed_files:
            if os.path.exists(embed_file):
                os.remove(embed_file)

        # Create results DataFrame
        results_df = pd.DataFrame(predictions)
        results_df = pd.concat([df[['protein_path', 'ligand_path']], results_df], axis=1)

        return results_df

def predict_affinity(
    model_path,
    protein_ligand_pairs,
    output_path=None,
    device='cuda',
    esm_model=None
):
    """
    Simple function to predict affinities with real featurization.
    
    Args:
        model_path: Path to trained model
        protein_ligand_pairs: List of (protein_pdb, ligand_sdf) or DataFrame
        output_path: Optional path to save predictions
        device: Device to use
        esm_model: Pre-loaded ESM model (optional, will load if not provided)
    
    Returns:
        DataFrame with predictions
    """
    # Initialize predictor
    predictor = DTAPredictor(model_path, device=device, esm_model=esm_model)
    
    # Predict
    results = predictor.predict(protein_ligand_pairs)
    
    # Save if requested
    if output_path:
        results.to_csv(output_path, index=False)
        print(f"Predictions saved to {output_path}")
    
    return results

In [28]:
# Complete working script
import torch
import pandas as pd
from pathlib import Path
from rdkit import Chem
from rdkit.Chem import rdmolfiles

protein_dir = Path('../output/protein')
ligand_dir = Path('../output/ligand')
input_path = '6OH4.pdb'
    
processed_files = process_single(input_path, protein_dir, ligand_dir)

processed_files

{'protein_path': '../output/protein/6OH4.pdb',
 'ligand_path': '../output/ligand/6OH4.sdf',
 'status': 'already_exists',
 'success': True}

In [29]:
# Load ESM model once (optional - for efficiency with multiple predictions)
from transformers import EsmModel, EsmTokenizer
from rdkit import RDLogger

# Disable RDKit warnings
RDLogger.DisableLog('rdApp.*')

print("Loading ESM model...")
model_name = "facebook/esm2_t33_650M_UR50D"
tokenizer = EsmTokenizer.from_pretrained(model_name)
esm_model = EsmModel.from_pretrained(model_name)
esm_model.eval()
if torch.cuda.is_available():
    esm_model = esm_model.cuda()
print("✓ ESM model loaded")

# Now run predictions
model_checkpoint = "../output/experiments/gnn_dta_mtl_experiment_20250910_153508/models/final_model.pt"

# Your test complexes
test_complexes = [
    ('../output/protein/6OH4.pdb', '../output/ligand/6OH4.sdf'),
    ('../output/protein/6OH4.pdb', '../output/ligand/6OH4.sdf'),
    ('../output/protein/6OH4.pdb', '../output/ligand/6OH4.sdf'),
    ('../output/protein/6OH4.pdb', '../output/ligand/6OH4.sdf'),
    ('../output/protein/6OH4.pdb', '../output/ligand/6OH4.sdf'),
    ('../output/protein/6OH4.pdb', '../output/ligand/6OH4.sdf'),
    ('../output/protein/6OH4.pdb', '../output/ligand/6OH4.sdf'),
    ('../output/protein/6OH4.pdb', '../output/ligand/6OH4.sdf'),
    ('../output/protein/6OH4.pdb', '../output/ligand/6OH4.sdf'),
    ('../output/protein/6OH4.pdb', '../output/ligand/6OH4.sdf'),
]

# Get predictions with real ESM embeddings
predictions = predict_affinity(
    model_path=model_checkpoint,
    protein_ligand_pairs=test_complexes,
    output_path='affinity_predictions.csv',
    device='cuda',
    esm_model=esm_model  # Pass pre-loaded model
)

# Display results
print("\nPrediction Results:")
print("="*60)
for idx, row in predictions.iterrows():
    print(f"\nComplex {idx+1}:")
    print(f"  Protein: {Path(row['protein_path']).name}")
    print(f"  Ligand: {Path(row['ligand_path']).name}")
    print(f"  Predictions:")
    for task in ['pKi', 'pEC50', 'pKd', 'pIC50', 'pKd (Wang, FEP)', 'potency']:
        if task in predictions.columns and not pd.isna(row[task]):
            print(f"    {task}: {row[task]:.3f}")

print(f"\n✓ Predictions complete!")

Loading ESM model...


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['esm.pooler.dense.weight', 'esm.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


✓ ESM model loaded
✓ Model loaded from ../output/experiments/gnn_dta_mtl_experiment_20250910_153508/models/final_model.pt
✓ Using device: cuda
Featurizing all protein-ligand pairs...


Featurizing: 100%|██████████| 10/10 [00:38<00:00,  3.89s/it]


Running batch prediction...
Predictions saved to affinity_predictions.csv

Prediction Results:

Complex 1:
  Protein: 6OH4.pdb
  Ligand: 6OH4.sdf
  Predictions:
    pKi: 5.644
    pEC50: 6.288
    pKd: 6.650
    pIC50: 5.907
    pKd (Wang, FEP): 0.332
    potency: -0.738

Complex 2:
  Protein: 6OH4.pdb
  Ligand: 6OH4.sdf
  Predictions:
    pKi: 5.644
    pEC50: 6.288
    pKd: 6.650
    pIC50: 5.907
    pKd (Wang, FEP): 0.332
    potency: -0.738

Complex 3:
  Protein: 6OH4.pdb
  Ligand: 6OH4.sdf
  Predictions:
    pKi: 5.644
    pEC50: 6.288
    pKd: 6.650
    pIC50: 5.907
    pKd (Wang, FEP): 0.332
    potency: -0.738

Complex 4:
  Protein: 6OH4.pdb
  Ligand: 6OH4.sdf
  Predictions:
    pKi: 5.644
    pEC50: 6.288
    pKd: 6.650
    pIC50: 5.907
    pKd (Wang, FEP): 0.332
    potency: -0.738

Complex 5:
  Protein: 6OH4.pdb
  Ligand: 6OH4.sdf
  Predictions:
    pKi: 5.644
    pEC50: 6.288
    pKd: 6.650
    pIC50: 5.907
    pKd (Wang, FEP): 0.332
    potency: -0.738

Complex 6:
  Protei

In [39]:
predictions

Unnamed: 0,protein_path,ligand_path,pKi,pEC50,pKd,pIC50,"pKd (Wang, FEP)",potency
0,../output/protein/6OH4.pdb,../output/ligand/6OH4.sdf,5.643721,6.287929,6.650221,5.907095,0.332366,-0.738483
