### ~ Overview of our project ~

1. We use a Graph Attention Network (GAT) on the protein–protein interaction (PPI) network to learn a numeric vector (embedding) for each protein.  
   These embeddings combine:
   
   - the protein’s own GNPC summary statistics, and  
   - information from its neighbors in the PPI graph
     

2. We train the model to predict **disease-related protein signals** from the GNPC data, specifically one continuous score per protein for each of four neurodegenerative diseases:

   - Alzheimer’s disease (AD)  
   - Parkinson’s disease (PD)  
   - Frontotemporal dementia (FTD)  
   - Amyotrophic lateral sclerosis (ALS)

3. We compare this **graph-based model** (GAT, and optionally GCN) against a **non-graph baseline model** (a decision tree) that uses the same protein features but ignores the PPI network.

   **This comparison shows whether using the PPI structure actually improves prediction performance.**

4. We use the learned protein embeddings for downstream analysis:
   
   - cluster proteins in the embedding space,  
   - identify “disease modules” (clusters that are AD-enriched, ALS-enriched, or pan-neurodegenerative), and  
   - perform pathway enrichment to link each cluster to biological processes and pathways


### 1. Importing the MultiTaskGNN Model 

This imports the `MultiTaskGNN` model class (a GAT) defined in `src/model.py`.

The model includes:

- 2 GAT layers to learn shared embeddings 
- 4 output heads (AD, PD, FTD, and ALS)

Why this model?

- It uses GAT layers to incorporate PPI structure
- It learns a shared embedding for each protein 
- It predicts four disease signals simultaneously (multi-task learning)

Why multi-task?

- The 4 diseases share biological pathways 
- Joint learning lets the model capture both shared and disease-specific information
  

### 2. Hyperparameters:

- **lr (learning rate)** – step size of each weight update
- **weight_decay** – penalty on large weights; helps prevent overfitting
- **epochs** – number of training passes (how long we train)
- **hidden_dim** – size of the protein embedding  
- **num_heads** – number of attention heads in GAT
- **dropout** – fraction of activations randomly dropped during training to improve generalization
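
As a minimal sketch, these settings could be collected in one configuration object. The values for `epochs`, `hidden_dim`, and `num_heads` are the ones stated elsewhere in these notes (300, 64, and 8). The values for `lr`, `weight_decay`, and `dropout` are placeholders assumed for illustration only, and the `TrainConfig` name is hypothetical.

```python
from dataclasses import dataclass, asdict


@dataclass(frozen=True)
class TrainConfig:
    """Hypothetical container for the hyperparameters listed above."""
    lr: float = 1e-3            # ASSUMED placeholder, not taken from the project
    weight_decay: float = 5e-4  # ASSUMED placeholder, not taken from the project
    epochs: int = 300           # stated in the training-loop notes
    hidden_dim: int = 64        # stated in the model-initialization notes
    num_heads: int = 8          # stated in the model-initialization notes
    dropout: float = 0.2        # ASSUMED placeholder, not taken from the project


config = TrainConfig()
print(asdict(config))
```

Keeping the settings in a frozen dataclass means they cannot be changed by accident mid-run, and the whole configuration can be logged alongside the saved weights.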
 


## 3. Loading the Pre-processed PPI Graph (`processed_graph.pt`)


### Node Features (`x`)

Protein-level GNPC values (beta coefficients and p-values from the summary statistics).
Each row corresponds to one protein.

### Node Targets (`y`)

Four disease summary statistics per protein, one each for AD, PD, FTD, and ALS.

### Edge Index (`edge_index`)

The STRING PPI network, defining which proteins interact

### Train / Val / Test Masks 

These ensure:

- we train on **some proteins**
- validate on **unseen proteins**
- test on **held-out proteins**
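
A minimal, framework-free sketch of how such Boolean masks could be built with a random split of roughly 70% / 15% / 15% (the proportions given later in these notes). The function name `make_split_masks`, the seed, and the node count of 1000 are assumptions for illustration. The actual masks are stored in the pre-processed graph file.

```python
import random


def make_split_masks(num_nodes, train_frac=0.70, val_frac=0.15, seed=0):
    """Return three Boolean lists (train, val, test) that partition the nodes."""
    rng = random.Random(seed)          # fixed seed so the split is reproducible
    order = list(range(num_nodes))
    rng.shuffle(order)                 # random order of protein indices

    n_train = round(train_frac * num_nodes)
    n_val = round(val_frac * num_nodes)
    train_idx = set(order[:n_train])
    val_idx = set(order[n_train:n_train + n_val])

    train_mask = [i in train_idx for i in range(num_nodes)]
    val_mask = [i in val_idx for i in range(num_nodes)]
    # Every node that is neither train nor val goes to the test set.
    test_mask = [not (t or v) for t, v in zip(train_mask, val_mask)]
    return train_mask, val_mask, test_mask


train_mask, val_mask, test_mask = make_split_masks(1000)
print(sum(train_mask), sum(val_mask), sum(test_mask))
```

Each protein is `True` in exactly one of the three masks, so no protein used for training is also used for validation or testing.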

### The training loop runs 300 epochs:

- At each epoch, we compute validation loss
- Whenever validation loss improves, we save `best_weights`



### 4. `train_step(model, data, optimizer)`

What the function does:

1. Runs the model
2. Gets 4 predictions per protein (one predicted summary statistic each for AD, PD, FTD, and ALS)
3. Computes loss only on `train_mask` nodes
4. Uses **MSE** (Mean Squared Error) for each disease separately

   For one protein and one disease:
   - True value from GNPC: `y_true`
   - Model's prediction: `y_pred`
   - Error = `y_pred - y_true`

   The MSE for a disease is the mean of the squared errors over the training proteins.
5. Sums all 4 losses to get the total multi-task loss
6. Backprop + weight update

**We sum the four losses with equal weight so that no disease is prioritized over the others. The model learns a shared embedding useful for all four tasks.**
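
The loss computation can be sketched without any deep-learning framework. The example below is an illustration under assumptions, not the project's `train_step`: the helper names `masked_mse` and `multitask_loss` are hypothetical, the numbers are toy values, and plain Python lists stand in for tensors.

```python
DISEASES = ["AD", "PD", "FTD", "ALS"]


def masked_mse(preds, targets, mask):
    """Mean squared error over the entries where mask is True."""
    sq_errors = [(p - t) ** 2 for p, t, m in zip(preds, targets, mask) if m]
    return sum(sq_errors) / len(sq_errors)


def multitask_loss(preds, targets, mask):
    """Unweighted sum of the per-disease masked MSEs."""
    return sum(masked_mse(preds[d], targets[d], mask) for d in DISEASES)


# Toy data for three proteins (values are made up for illustration).
targets = {d: [1.0, 2.0, 3.0] for d in DISEASES}
preds = {d: [1.5, 2.0, 100.0] for d in DISEASES}
train_mask = [True, True, False]   # the third protein is not a training node

loss = multitask_loss(preds, targets, train_mask)
print(loss)
```

The third protein has a very large error, but because it is masked out it does not affect the training loss. Each disease contributes (0.25 + 0) / 2 = 0.125, so the summed loss is 0.5.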




### 5. Evaluation of the Model on Validation or Test Nodes

### `eval_model(model, data, mask)`

- Turns gradients off. During training we compute gradients (information about how to change the weights to reduce the error); during evaluation (validation/test) we do not change the weights, we only measure how good the model is.
- Computes predictions for all proteins
- Applies the loss only on the masked nodes (val or test). A mask is a True/False vector marking which proteins are used for a given stage.

  We randomly split proteins into training, validation, and test sets using Boolean masks. These masks indicate which proteins are used for learning (train), model selection (validation), and final performance estimation (test). We used approximately 70% of nodes for training, 15% for validation, and 15% for testing.
- Returns the total multi-task loss

**We use masked evaluation because we want to estimate how well the model generalizes to *unseen* proteins.**

### 6. Initialize the MultiTask GAT Model 

Creating GAT: 

- 2 GATv2 layers, with the first layer using:
  - a 64-dimensional hidden representation: the length of the embedding vector for each protein, so each protein ends up represented by a 64-dimensional vector (chosen based on empirical performance)
  - 8 attention heads: 8 separate ways of looking at neighbors
- shared embeddings: we compute one embedding per protein, then use the same embedding to predict AD, PD, FTD, and ALS
- 4 disease outputs: for each protein, the model produces 4 outputs, one number for each of AD, PD, FTD, and ALS

Then we create:

- Adam optimizer (default choice)
- Weight decay for regularization: if weights grow very large the model can overfit and memorize noise, so penalizing large weights encourages simpler models that generalize better

**Aim:**  

Learn protein embeddings that combine:

- GNPC features  
- PPI neighborhood structure  
- Multi-disease signals  


### 7. Training Loop

1. Runs one training step
2. Computes validation loss
3. Checks if validation improved
4. Saves the best weights

**Why validation tracking?**
    
We want the model that performs best on unseen data, not necessarily the one that fits the training data best. Keeping the weights with the lowest validation loss guards against overfitting and favors generalization.
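
The checkpoint-selection logic can be sketched in isolation. In the example below the validation losses and the per-epoch "weights" are toy stand-ins, and `select_best_weights` is a hypothetical helper name. In a real PyTorch loop the stored object would typically be a deep copy of the model's state dict.

```python
import copy


def select_best_weights(val_losses, weights_per_epoch):
    """Return (best_epoch, best_loss, best_weights) by lowest validation loss."""
    best_epoch, best_loss, best_weights = None, float("inf"), None
    for epoch, (val_loss, weights) in enumerate(zip(val_losses, weights_per_epoch)):
        if val_loss < best_loss:  # keep a checkpoint only on strict improvement
            best_epoch, best_loss = epoch, val_loss
            # Deep copy so later training updates cannot overwrite the snapshot.
            best_weights = copy.deepcopy(weights)
    return best_epoch, best_loss, best_weights


# Toy run: validation loss falls, then rises again as the model starts to overfit.
val_losses = [0.9, 0.6, 0.4, 0.5, 0.7]
weights_per_epoch = [{"w": float(i)} for i in range(len(val_losses))]

best_epoch, best_loss, best_weights = select_best_weights(val_losses, weights_per_epoch)
print(best_epoch, best_loss, best_weights)
```

Here the snapshot from epoch index 2 is kept, even though training continued for two more epochs.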

### 8. Final Step 

--> We load the best model weights and test on held-out proteins. This tells us how well the model generalizes to proteins it has never seen.

--> We save the learned weights to a file so we can later:

- extract node embeddings
- cluster proteins
- run UMAP
- study disease modules
- perform pathway enrichment
