In [1]:
cd /home/tvangraft/tudelft/thesis/metaengineering

/home/tvangraft/tudelft/thesis/metaengineering


In [2]:
from collections import defaultdict
from typing import DefaultDict, List, Hashable, Dict, Any

from src.utils.utils import get_generator, get_project_root
from src.utils.test_result_store import TestResultStore

from src.pipeline.config import DataLoaderConfig, TaskLoaderConfig
from src.pipeline.taskloader import TaskLoader, TaskFrame
from src.pipeline.dataloader import DataLoader

from src.orchestrator.trainer import Trainer

from src.settings.tier import Tier
from src.settings.strategy import Strategy
from src.settings.metabolites import ENZYMES, METABOLITES, PRECURSOR_METABOLITES, PRECURSOR_METABOLITES_NO_TRANSFORM

from src.gnn.data_augmentation import DataAugmentation
from src.gnn.embeddings import generate_embedding
from src.gnn.graph_builder import get_samples_hetero_graph, get_graph_fc, get_graph_fc_protein_only, edge_index_from_df_protein_only, get_samples_graph
from src.gnn.shared_training import count_parameters, log_metrics, tune_metabolite_hyper_parameters

import pandas as pd
import numpy as np

import cobra
from cobra.util import create_stoichiometric_matrix
from cobra.core import Reaction

import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns

from more_itertools import flatten

from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler
from sklearn.metrics import mean_absolute_error

from scipy.stats import pearsonr

import os
from functools import partial

from tqdm import tqdm

import torch
from torch.nn import BatchNorm1d, ModuleList
import torch.nn.functional as F

from torch_geometric.data import Data, HeteroData
from torch_geometric.utils import from_networkx, to_networkx
from torch_geometric.loader import DataLoader as GeoDataLoader
from torch_geometric.nn import GAT, GCNConv, to_hetero, SAGEConv, GATConv, HeteroLinear, Linear, Node2Vec
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
import torch_geometric.transforms as T
from torch_geometric.nn.conv import HeteroConv

import mlflow.pytorch

from config import HYPERPARAMETERS, BEST_PARAMETERS

from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler, FIFOScheduler
from ray.air import session, RunConfig
from ray.tune.integration.mlflow import mlflow_mixin
from ray.tune.integration.mlflow import MLflowLoggerCallback

# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
mlflow.set_tracking_uri("http://localhost:5000")
device = torch.device("cpu")
torch.manual_seed(42)

  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x7fe1aa538730>

In [3]:
path = "/home/tvangraft/tudelft/thesis/metaengineering/data"
model = cobra.io.read_sbml_model(f'{path}/iMM904.xml')

In [4]:
def prepare_data(path: str, valid_metabolites: List[str]):
    edge_list_df = pd.read_csv(path)
    graph_fc = get_graph_fc_protein_only(edge_list_df, valid_metabolites)
    edge_index = edge_index_from_df_protein_only(graph_fc, edge_list_df, valid_metabolites)
    embedding = generate_embedding(edge_index.T, device, hetero=False)

    return edge_list_df, graph_fc, edge_index, embedding

# Model

In [5]:
class GATModel(torch.nn.Module):
    def __init__(self, model_config) -> None:
        super(GATModel, self).__init__()
        embedding_size = model_config["model_embedding_size"]
        n_heads = model_config["model_attention_heads"]
        self.n_layers = model_config["model_layers"]
        
        self.conv_layers = ModuleList([])
        self.transf_layers = ModuleList([])
        self.pooling_layers = ModuleList([])
        self.bn_layers = ModuleList([])
        
        self.conv1 = GATConv(
            -1, out_channels=embedding_size, heads=n_heads, add_self_loops=False, bias=False
        )
        self.transf1 = Linear(
            in_channels=embedding_size*n_heads, 
            out_channels=embedding_size, 
        )

        for i in range(self.n_layers):
            self.conv_layers.append(
                GATConv(
                    -1, 
                    out_channels=embedding_size, 
                    heads=n_heads, 
                    add_self_loops=False
                )
            )
            self.transf_layers.append(
                Linear(
                    embedding_size*n_heads, 
                    embedding_size
                )
            )

        self.linear1 = Linear(embedding_size * 2, embedding_size)
        self.linear2 = Linear(embedding_size, model_config['output_size'])
    
    def forward(self, x, edge_index, batch_index):
        x = self.conv1(x, edge_index)

        global_representation = []

        for i in range(self.n_layers):
            x = self.conv_layers[i](x, edge_index)
            print(x.shape)
            x = torch.relu(self.transf_layers[i](x))
            print(x.shape)
            global_representation.append(torch.cat([gmp(x, batch_index), gap(x, batch_index)], dim=1))

        # This generates the last embeddings for all enzymes in the graph
        x = sum(global_representation)
        x = torch.relu(self.linear1(x))
        x = F.dropout(x, p=0.8, training=self.training)
        x = self.linear2(x)
        return x 

# Setup functions

In [6]:
def train_one_epoch(epoch, model, train_loader, optimizer, loss_fn, debug=False):
    # Enumerate over the data
    running_loss = 0.0
    step = 0
    for _, batch in enumerate(train_loader):
        # Use GPU
        batch.to(device)
        # Reset gradients
        optimizer.zero_grad() 
        # Passing the node features and the connection info
        pred = model.forward(
            batch.x, 
            batch.edge_index,
            batch.batch
        )
        # Calculating the loss and gradients
        train_mask = torch.reshape(batch.train_mask.bool(), pred.shape)

        if debug:
            print(batch)
            print(f"{pred.shape=}")
            print(f"{train_mask.sum()=}")
            print(
                f"{pred.shape=} \n"
                f"{pred[train_mask].mean()=} \n"
                f"{pred[train_mask].max()=} \n"
                f"{pred[train_mask].min()=} \n"
                f"{pred[train_mask].shape=} \n"
            )
            print(
                f"{torch.squeeze(pred).shape} \n"
                f"{torch.squeeze(batch.y.float()).shape} \n"
            )

        loss = loss_fn(
            torch.squeeze(pred), 
            torch.squeeze(torch.reshape(batch.y.float(), pred.shape))
        )

        loss.backward()  
        optimizer.step()  
        # Update tracking
        running_loss += loss.item()
        step += 1
    
    tune.report(loss=(running_loss/step))

    return running_loss/step
    
@mlflow_mixin
def test(epoch, model, test_loader, loss_fn, debug=False):
    all_preds_raw = []
    all_labels = []
    all_knockout_ids = []
    running_loss = 0.0
    step = 0
    for batch in test_loader:
        batch.to(device)
        pred = model(
            batch.x, 
            batch.edge_index,
            batch.batch,
        )
        test_mask = torch.reshape(batch.test_mask.bool(), pred.shape)

        if debug:
            print(batch)  
            print(f"{test_mask.sum()=}")
            print(
                f"{pred.shape=} \n"
                f"{pred[test_mask].mean()=} \n"
                f"{pred[test_mask].max()=} \n"
                f"{pred[test_mask].min()=} \n"
                f"{pred[test_mask].shape=} \n"
            ) 
        
        loss = loss_fn(
            torch.squeeze(torch.nan_to_num(pred[test_mask])), 
            torch.squeeze(torch.nan_to_num(batch.y.float()[torch.squeeze(test_mask)]))
        )

         # Update tracking
        running_loss += loss.item()
        step += 1
        all_preds_raw.append(torch.nan_to_num(pred[test_mask]).cpu().detach().numpy())
        all_labels.append(torch.nan_to_num(batch.y[torch.squeeze(test_mask)]).cpu().detach().numpy())
        # all_knockout_ids.append(batch['enzymes'].knockout_label_id.cpu().detach().numpy())
    
    all_preds_raw = np.concatenate(all_preds_raw).ravel()
    all_labels = np.concatenate(all_labels).ravel()
    # all_knockout_ids = np.concatenate(all_knockout_ids).ravel()
    log_metrics(all_preds_raw, all_labels, all_knockout_ids, epoch, "test")
    return running_loss/step

def build_model(model_config, output_size, debug=False):
    if debug:
        print(f"creating model {model_config=}")
    # Set the number of outputs we expect
    model_config['output_size'] = output_size
    params = model_config

    if 'mlflow' in model_config:
        run_id = model_config['mlflow']['tags']['mlflow.parentRunId']
        mlflow.set_tag("mlflow.parentRunId", run_id)
    # Logging params
    for key in params.keys():
        mlflow.log_param(key, params[key])
  
    # Loading the model
    print("Loading model...")
    model = GATModel(model_config=params)
    model = model.to(device)

    return model

@mlflow_mixin
def run_one_training(model_config, train_samples, test_samples, checkpoint_dir):
    # Build the model
    model = build_model(model_config, train_samples[0].y.shape[0])

    # Preparing training
    train_loader = GeoDataLoader(train_samples, batch_size=model_config['batch_size'])
    test_loader = GeoDataLoader(test_samples, batch_size=1)

    # < 1 increases precision, > 1 recall
    # loss_fn = torch.nn.MSELoss(reduction='none')
    loss_fn = torch.nn.MSELoss()
    # we need to keep the lr quite low since otherwise the weights explode
    optimizer = torch.optim.SGD(
        model.parameters(), 
        lr=model_config['learning_rate'],
        momentum=model_config['sgd_momentum'],
        # weight_decay=5e-4
    )
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=model_config['scheduler_gamma'])

    # Use this to debug the train and test function
    debug = False
    
    # Start training
    best_loss = 1000
    early_stopping_counter = 0
    max_epochs = 300
    for epoch in tqdm(range(max_epochs)): 
        if early_stopping_counter <= 25: # = x * 5 
            # Training
            model.train()
            loss = train_one_epoch(epoch, model, train_loader, optimizer, loss_fn, debug=debug)
            # print(f"Epoch {epoch} | Train Loss {loss}")
            mlflow.log_metric(key="Train loss", value=float(loss), step=epoch)
            
            # Testing
            model.eval()
            if epoch % 5 == 0 or epoch == max_epochs - 1:
                loss = test(epoch, model, test_loader, loss_fn, debug=debug)
                # print(f"Epoch {epoch} | Test Loss {loss}")
                mlflow.log_metric(key="Test loss", value=float(loss), step=epoch)
                
                # Update best loss
                if float(loss) < best_loss:
                    best_loss = loss
                    # Save the currently best model 
                    # mlflow.pytorch.log_model(model, "model", signature=SIGNATURE)
                    early_stopping_counter = 0
                else:
                    early_stopping_counter += 1

            scheduler.step()
            mlflow.log_metric(key="Learning rate", value=float(scheduler.get_last_lr()[0]), step=epoch)
            
        else:
            print("Early stopping due to no improvement.")
            session.report({
                "loss": best_loss
            })
            return {"loss": best_loss}
    print(f"Finishing training with best test loss: {best_loss}")

    with torch.no_grad():
        sample = train_samples[0].to(device)
        model.forward(sample.x_dict, sample.edge_index_dict)
        print(f"Number of parameters: {count_parameters(model)}")

    session.report({
        "loss": best_loss
    })

    mlflow.end_run()

    return {"loss": best_loss}

In [7]:
HYPERPARAMETERS = {
    "batch_size": tune.choice([2, 4, 8]),
    "learning_rate": tune.choice([0.1, 0.05, 0.01, 0.001]),
    "sgd_momentum": tune.choice([0.9, 0.8, 0.5]),
    "scheduler_gamma": tune.choice([0.995, 1]),
    "model_embedding_size": tune.choice([8, 16, 32, 64, 128]),
    "model_attention_heads": tune.choice([1, 2, 3, 4]),
    "model_layers": tune.choice([1, 3, 5, 7]),
}

In [8]:
edge_list_df_unfiltered, graph_fc_unfiltered, edge_index_unfiltered, embedding_unfiltered = prepare_data(
    './data/training/edge_list_unfiltered_protein_only.csv',
    PRECURSOR_METABOLITES_NO_TRANSFORM,
)
edge_list_df_strict, graph_fc_strict, edge_index_strict, embedding_strict = prepare_data(
    './data/training/edge_list_strict_protein_only.csv',
    PRECURSOR_METABOLITES,
)
edge_list_df_all, graph_fc_all, edge_index_all, embedding_all = prepare_data(
    './data/training/edge_list_all_protein_only.csv',
    METABOLITES,
)

train set	 [48 22 25  8  8 57 54 41  5 52]
test set 	 [23 44 56 16 23 51  5 46 43  2]
val set  	 [62 47  2 14 14 44 42 30 68 35]
train mask 	 tensor([1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
test mask  	 tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0])
val mask   	 tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0])
Epoch: 10, Loss: 1.9115
Epoch: 20, Loss: 1.4799
Epoch: 30, Loss: 1.1357
Epoch: 40, Loss: 0.9913
Epoch: 50, Loss: 0.8774
Epoch: 60, Loss: 0.8635
Epoch: 70, Loss: 0.7983
Epoch: 80, Loss: 0.7897
Epoch: 90, Loss: 0.8029
Epoch: 100, Loss: 0.7558
train set	 [12 17  7 56 60 47  5 72 48 62]
test set 	 [21 19 50 45 52 31 16 63 10 23]
val set  	 [ 3  6 55 38  1 66 31 35 67 27]
train mask 	 tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1])
test mask  	 tensor([1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0])
val mask   	 tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0])
Epoch: 10, Loss: 1.9326
Epoch: 20, Loss: 1.3530
Epoch: 30, Loss: 1.0757
Epoch: 40, Loss: 0.9186
Epo

In [9]:
def get_train_test_split(mode, metabolite_id, strategy):
    if mode == 'unfiltered':
        train_samples, test_samples = get_samples_graph(
            metabolite_id, strategy, PRECURSOR_METABOLITES_NO_TRANSFORM, graph_fc_unfiltered, edge_index_unfiltered, embedding_unfiltered
        )
    elif mode == 'strict':
        train_samples, test_samples = get_samples_graph(
            target_metabolite_id=metabolite_id, strategy=strategy, valid_metabolites=PRECURSOR_METABOLITES, graph_fc_df=graph_fc_strict, edge_index=edge_index_strict, node_embeddings=embedding_strict,
        )
    elif mode == 'all':
        train_samples, test_samples = get_samples_graph(
            metabolite_id, strategy, METABOLITES, graph_fc_all, edge_index_all, embedding_all
        )
    return train_samples, test_samples

In [None]:
mlflow.end_run()
train_samples, test_samples = get_train_test_split('unfiltered', 'pyr', Strategy.ALL)
print(test_samples[0])

model_config = {
    "batch_size": 6,
    "learning_rate": 0.001,
    "sgd_momentum": 0.8,
    "scheduler_gamma": 1,
    "model_embedding_size": 16,
    "model_attention_heads": 1,
    "model_layers": 1,
}

run_one_training(model_config, train_samples, test_samples, None)
mlflow.end_run()

# Training

In [11]:
mlflow.end_run()
for strategy in [Strategy.ONE_VS_ALL, Strategy.METABOLITE_CENTRIC]:
    for mode, graph_fc in [('unfiltered', graph_fc_unfiltered), ('strict', graph_fc_strict), ('all', graph_fc_all)]:
    # for mode, graph_fc in [('unfiltered', graph_fc_unfiltered)]:
        run_name = "model_gat_node_embeddings"
        experiment_name = f'protein_only_sweep_{mode}'
        mlflow.set_experiment(experiment_name)
        with mlflow.start_run(run_name=run_name) as run:
            for metabolite_id in list(set(graph_fc.columns.difference(ENZYMES).to_list()) & set(PRECURSOR_METABOLITES)):
                print(f"training {metabolite_id=}")
                train_samples, test_samples = get_train_test_split(mode, metabolite_id, strategy)
                print(len(train_samples))

                with mlflow.start_run(run_name=f"model_{metabolite_id}_{strategy}", nested=True):
                    result = tune_metabolite_hyper_parameters(experiment_name, HYPERPARAMETERS, train_samples, test_samples, run_fn=run_one_training, num_samples=16)
                    print(result)

training metabolite_id='accoa'
95


2022-12-19 14:35:49,042	INFO worker.py:1538 -- Started a local Ray instance.

from ray.air import session

def train(config):
    # ...
    session.report({"metric": metric}, checkpoint=checkpoint)

For more information please see https://docs.ray.io/en/master/tune/api_docs/trainable.html

