# iCAM-Net: Inference and Case Study Reproduction

This Jupyter notebook walks through inference with a pre-trained iCAM-Net model, step by step. It is designed to reproduce the case study results reported in our paper, such as predicting the association of a specific herb with every disease in the dataset.

**Objectives:**
1.  Load a pre-trained iCAM-Net model.
2.  Load all necessary data, including hypergraphs and entity mappings.
3.  Predict the association scores between a given herb and all diseases in the dataset.
4.  Display and save the top-ranked disease predictions.

This notebook addresses the request from **Reviewer #3, Issue 7**, for a dedicated script that reproduces key analyses.

In [19]:
import torch
import pandas as pd
from tqdm.notebook import tqdm  # Use tqdm.notebook for better display in Jupyter
import os
import sys

# Add the project's source code directory to the Python path so the custom modules can be imported.
# '../code/' is resolved relative to the current working directory;
# adjust it if your .py files live elsewhere.
sys.path.append('../code/')

# Import custom modules from the project
# Make sure the file names (graph.py, models.py) match your project
from graph import build_HC, build_DP
from models import iCAM
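
If the imports above fail with `ModuleNotFoundError`, the relative path is a likely cause, because `'../code/'` is resolved against the current working directory, which may differ from the notebook's location. The sketch below shows one way to resolve the directory to an absolute path and confirm that the expected files exist before importing. The helper name `add_source_dir` is hypothetical and not part of the project; the file names `graph.py` and `models.py` are taken from the imports above.

```python
import sys
from pathlib import Path


def add_source_dir(source_dir, required=("graph.py", "models.py")):
    """Hypothetical helper: put source_dir on sys.path after checking its contents."""
    resolved = Path(source_dir).resolve()
    missing = [name for name in required if not (resolved / name).is_file()]
    if missing:
        raise FileNotFoundError(f"{resolved} is missing: {', '.join(missing)}")
    # Avoid adding a duplicate entry when the cell is re-run
    if str(resolved) not in sys.path:
        sys.path.append(str(resolved))
    return resolved
```

In this notebook it would be called as `add_source_dir('../code/')` before the `from graph import ...` lines.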

## 1. Configuration Setup

The next cell sets the paths and parameters for the inference task. Modify these variables to match your local setup.

In [20]:
# --- User Configuration Area ---

# 1. Path to the pre-trained model weights (.pth file)
# Example: A model from sensitivity analysis
MODEL_PATH = "/path/to/your/best_model.pth"

# 2. Root directory containing all data files (CSVs, embeddings, etc.)
DATA_ROOT_DIR = "../data/"

# 3. The original ID of the herb you want to predict for.
HERB_ID_TO_PREDICT = 2545

# 4. Name for the output CSV file that will store the predictions.
OUTPUT_CSV_NAME = f"predictions_herb_{HERB_ID_TO_PREDICT}.csv"

# 5. The device to use for inference. The predictor falls back to "cpu" automatically when CUDA is unavailable.
DEVICE = "cuda:0"

# --- End of Configuration ---

# Validate paths (isfile/isdir also reject a directory given as the model path, or a file given as the data directory)
if not os.path.isfile(MODEL_PATH):
    raise FileNotFoundError(f"Model file not found at: {MODEL_PATH}")
if not os.path.isdir(DATA_ROOT_DIR):
    raise FileNotFoundError(f"Data directory not found at: {DATA_ROOT_DIR}")

print("Configuration loaded successfully.")

Configuration loaded successfully.
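
The check above confirms only that the data directory exists. The predictor defined in the next section also expects four files inside it (`H_C_TCM.csv`, `ce_TCM.txt`, `D_P_TCM.csv`, `pe_TCM.txt`), so a missing file surfaces only when the hypergraphs are built. A minimal sketch of an earlier check follows; the helper name `find_missing_data_files` is hypothetical, and the file list is copied from `_load_data`.

```python
from pathlib import Path

# File names copied from HerbDiseasePredictor._load_data
REQUIRED_DATA_FILES = ("H_C_TCM.csv", "ce_TCM.txt", "D_P_TCM.csv", "pe_TCM.txt")


def find_missing_data_files(data_dir, required=REQUIRED_DATA_FILES):
    """Hypothetical helper: return the required file names absent from data_dir."""
    root = Path(data_dir)
    return [name for name in required if not (root / name).is_file()]
```

Calling `find_missing_data_files(DATA_ROOT_DIR)` and raising `FileNotFoundError` when the list is non-empty would report every missing file in one pass.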


## 2. The Predictor Class

Here, we define the `HerbDiseasePredictor` class. This class encapsulates all the logic for loading data, initializing the model, and performing predictions.

In [21]:
class HerbDiseasePredictor:
    """
    A predictor class to facilitate inference with a pre-trained iCAM-Net model.
    """
    def __init__(self, model_path, data_dir, device="cuda:0"):
        """
        Initializes the predictor.

        Args:
            model_path (str): Path to the pre-trained model weights (.pth file).
            data_dir (str): Directory containing dataset CSVs and embedding files.
            device (str): The device to run the model on (e.g., "cuda:0" or "cpu").
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        
        self.model_path = model_path
        self.data_dir = data_dir
        
        # Load data, hypergraphs, and ID mappings
        print("\nStep 1: Loading data and building hypergraphs...")
        self._load_data()
        
        # Load the pre-trained model
        print("\nStep 2: Loading the pre-trained model...")
        self.model = self._load_model()
        self.model.eval()
        print("Initialization complete. Predictor is ready.")

    def _load_data(self):
        """Loads all necessary data files and creates mappings."""
        # Define file paths using the data root directory
        hc_path = os.path.join(self.data_dir, "H_C_TCM.csv")
        ce_path = os.path.join(self.data_dir, "ce_TCM.txt")
        dp_path = os.path.join(self.data_dir, "D_P_TCM.csv")
        pe_path = os.path.join(self.data_dir, "pe_TCM.txt")
        
        # Build herb-component hypergraph
        self.H_C, self.X_compounds, self.herb_to_idx, self.comp_to_idx, self.herb2comp = build_HC(
            hc_path, ce_path, device=self.device
        )
        # Build disease-protein hypergraph
        self.D_P, self.X_proteins, self.disease_to_idx, self.protein_to_idx, self.disease2prot = build_DP(
            dp_path, pe_path, device=self.device
        )
        
        # Create a reverse mapping from internal index back to original disease ID
        self.idx_to_disease = {v: k for k, v in self.disease_to_idx.items()}
        print("Data loading successful.")

    def _load_model(self):
        """Loads and initializes the iCAM-Net model architecture."""
        # Define model parameters (these should match the parameters used during training)
        in_dim = self.X_compounds.shape[1]
        hidden_dim = 128
        out_dim = 64
        n_layers = 3
        
        # Instantiate the model
        model = iCAM(
            herb_graph=self.H_C,
            disease_graph=self.D_P,
            X_C=self.X_compounds,
            X_P=self.X_proteins,
            in_dim=in_dim,
            hidden_dim=hidden_dim,
            out_dim=out_dim,
            device=self.device,
            n_layers=n_layers,
            herb2comp=self.herb2comp,
            disease2prot=self.disease2prot,
            return_attn=False  # Set to False for faster inference
        ).to(self.device)
        
        # Load the saved weights
        model.load_state_dict(torch.load(self.model_path, map_location=self.device, weights_only=True))
        print("Model weights loaded successfully.")
        return model

    @torch.no_grad()
    def predict_all_diseases_for_herb(self, herb_id, batch_size=256):
        """
        Predicts association scores between a specific herb and all diseases in the dataset.

        Args:
            herb_id (int): The original ID of the herb to predict for.
            batch_size (int): Batch size for inference to manage memory and improve speed.
            
        Returns:
            pd.DataFrame: A DataFrame with all diseases and their predicted scores, sorted descending.
        """
        if herb_id not in self.herb_to_idx:
            raise ValueError(f"Herb ID {herb_id!r} not found in the dataset.")
        
        print(f"\nStep 3: Starting prediction for Herb ID {herb_id} against all diseases...")
        
        herb_idx = self.herb_to_idx[herb_id]
        herb_tensor = torch.tensor([herb_idx], device=self.device)
        
        all_results = []
        num_diseases = len(self.disease_to_idx)
        
        # Create a progress bar for the prediction loop
        progress_bar = tqdm(range(0, num_diseases, batch_size), desc=f"Predicting (Batch Size: {batch_size})")

        for i in progress_bar:
            end_idx = min(i + batch_size, num_diseases)
            disease_indices_batch = list(range(i, end_idx))
            
            # Prepare batch tensors
            disease_tensor_batch = torch.tensor(disease_indices_batch, device=self.device)
            herb_tensor_batch = herb_tensor.expand(len(disease_indices_batch))
            
            # Get model predictions for the batch
            scores_batch = self.model(herb_tensor_batch, disease_tensor_batch)
            # Flatten and copy to the CPU once per batch instead of once per item
            # (assumes the model returns exactly one score per herb-disease pair)
            scores = scores_batch.reshape(-1).tolist()
            
            # Collect results for each item in the batch
            for disease_idx, score in zip(disease_indices_batch, scores):
                all_results.append({
                    "Disease Index": disease_idx,
                    "Disease ID": self.idx_to_disease[disease_idx],
                    "Prediction Score": score
                })
        
        # Convert list of dictionaries to a Pandas DataFrame
        results_df = pd.DataFrame(all_results)
        # Sort the results by prediction score in descending order
        results_df = results_df.sort_values(by="Prediction Score", ascending=False).reset_index(drop=True)
        
        print("Prediction for all diseases is complete.")
        return results_df
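
The batching logic in `predict_all_diseases_for_herb` can be exercised without the model or a GPU. The sketch below reproduces the loop in plain Python: it walks the disease indices in fixed-size slices, scores each (herb, disease) pair, maps indices back to original IDs, and sorts by score. The function `rank_diseases`, the stand-in `score_fn`, and the toy IDs are all illustrative assumptions, not part of iCAM-Net.

```python
def rank_diseases(herb_idx, disease_to_idx, score_fn, batch_size=256):
    """Illustrative, model-free version of the batched prediction loop.

    score_fn(herb_indices, disease_indices) stands in for the model call and
    must return one score per pair.
    """
    idx_to_disease = {idx: disease_id for disease_id, idx in disease_to_idx.items()}
    num_diseases = len(disease_to_idx)
    results = []
    for start in range(0, num_diseases, batch_size):
        end = min(start + batch_size, num_diseases)
        disease_batch = list(range(start, end))
        # Repeat the single herb index so it pairs with every disease in the batch
        herb_batch = [herb_idx] * len(disease_batch)
        scores = score_fn(herb_batch, disease_batch)
        for disease_idx, score in zip(disease_batch, scores):
            results.append((idx_to_disease[disease_idx], score))
    # Highest score first, matching the sorted DataFrame returned by the class
    return sorted(results, key=lambda pair: pair[1], reverse=True)


# Toy data: five made-up disease IDs and a dummy scorer (not model output)
toy_disease_to_idx = {"D10": 0, "D11": 1, "D12": 2, "D13": 3, "D14": 4}


def toy_score(herb_indices, disease_indices):
    return [1.0 / (1 + d) for d in disease_indices]


ranking = rank_diseases(0, toy_disease_to_idx, toy_score, batch_size=2)
print(ranking[:3])
```

Like the class method, this sketch assumes the internal disease indices are contiguous from 0 to N-1; a mapping with gaps would raise a `KeyError` during the reverse lookup.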

## 3. Run Inference

Now we will instantiate the predictor and run the prediction for the herb specified in the configuration section.

In [22]:
# Instantiate the predictor
predictor = HerbDiseasePredictor(model_path=MODEL_PATH, data_dir=DATA_ROOT_DIR, device=DEVICE)

# Run the prediction
predictions_df = predictor.predict_all_diseases_for_herb(herb_id=HERB_ID_TO_PREDICT)

# Save the full results to a CSV file
predictions_df.to_csv(OUTPUT_CSV_NAME, index=False, encoding='utf-8-sig')

print("\nStep 4: Prediction finished!")
print(f"Full results have been saved to: {OUTPUT_CSV_NAME}")

Using device: cuda:0

Step 1: Loading data and building hypergraphs...
Data loading successful.

Step 2: Loading the pre-trained model...
Model weights loaded successfully.
Initialization complete. Predictor is ready.

Step 3: Starting prediction for Herb ID 2545 against all diseases...


Predicting (Batch Size: 256):   0%|          | 0/43 [00:00<?, ?it/s]

Prediction for all diseases is complete.

Step 4: Prediction finished!
Full results have been saved to: predictions_herb_2545.csv
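
To inspect the top-ranked diseases from the saved file, including outside the notebook, the CSV can be read back with the standard library. The sketch below assumes the column names written by `predict_all_diseases_for_herb` and opens the file with `utf-8-sig`, matching the encoding passed to `to_csv`, so the byte-order mark does not end up attached to the first header. The helper name `top_predictions` is hypothetical.

```python
import csv


def top_predictions(csv_path, k=10):
    """Hypothetical helper: return the k highest-scoring rows of a predictions CSV."""
    with open(csv_path, newline="", encoding="utf-8-sig") as handle:
        rows = list(csv.DictReader(handle))
    # Scores are read back as strings, so convert before sorting
    rows.sort(key=lambda row: float(row["Prediction Score"]), reverse=True)
    return rows[:k]
```

Inside the notebook, `predictions_df.head(10)` gives the same view, since the DataFrame is already sorted in descending order.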
