In [1]:
cd /home/tvangraft/tudelft/thesis/metaengineering

/home/tvangraft/tudelft/thesis/metaengineering


In [2]:
from collections import defaultdict
from typing import DefaultDict, List, Hashable, Dict, Any

from src.utils.utils import get_generator, get_project_root
from src.utils.test_result_store import TestResultStore

from src.pipeline.config import DataLoaderConfig, TaskLoaderConfig
from src.pipeline.taskloader import TaskLoader, TaskFrame
from src.pipeline.dataloader import DataLoader

from src.orchestrator.trainer import Trainer

from src.settings.tier import Tier
from src.settings.strategy import Strategy
from src.settings.metabolites import ENZYMES, METABOLITES, PRECURSOR_METABOLITES, PRECURSOR_METABOLITES_NO_TRANSFORM

from src.gnn.data_augmentation import DataAugmentation
from src.gnn.embeddings import generate_embedding
from src.gnn.graph_builder import get_samples_hetero_graph, get_graph_fc

import pandas as pd
import numpy as np

import cobra
from cobra.util import create_stoichiometric_matrix
from cobra.core import Reaction

import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns

from more_itertools import flatten

from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler
from sklearn.metrics import mean_absolute_error

from scipy.stats import pearsonr

import os
from functools import partial

from tqdm import tqdm

import torch
from torch.nn import BatchNorm1d, ModuleList
import torch.nn.functional as F

from torch_geometric.data import Data, HeteroData
from torch_geometric.utils import from_networkx, to_networkx
from torch_geometric.loader import DataLoader as GeoDataLoader
from torch_geometric.nn import GAT, GCNConv, to_hetero, SAGEConv, GATConv, HeteroLinear, Linear, Node2Vec
import torch_geometric.transforms as T
from torch_geometric.nn.conv import HeteroConv

import mlflow.pytorch

from config import HYPERPARAMETERS, BEST_PARAMETERS

from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
from ray.air import session, RunConfig
from ray.tune.integration.mlflow import mlflow_mixin
from ray.tune.integration.mlflow import MLflowLoggerCallback

# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
mlflow.set_tracking_uri("http://localhost:5000")
device = torch.device("cpu")
torch.manual_seed(42)

  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x7f90e75af330>

In [3]:
path = "/home/tvangraft/tudelft/thesis/metaengineering/data"
model = cobra.io.read_sbml_model(f'{path}/iMM904.xml')

Scaling...
 A: min|aij| =  1.000e+00  max|aij| =  1.000e+00  ratio =  1.000e+00
Problem data seem to be well scaled


In [4]:
edge_list_df_unfiltered = pd.read_csv('./data/training/edge_list_unfiltered.csv')
graph_fc_unfiltered = get_graph_fc(edge_list_df_unfiltered, PRECURSOR_METABOLITES_NO_TRANSFORM)
embedding_unfiltered = generate_embedding(graph_fc_unfiltered, edge_list_df_unfiltered, PRECURSOR_METABOLITES_NO_TRANSFORM, device)

edge_list_df_strict = pd.read_csv('./data/training/edge_list_strict.csv')
graph_fc_strict = get_graph_fc(edge_list_df_strict, PRECURSOR_METABOLITES)
embedding_strict = generate_embedding(graph_fc_strict, edge_list_df_strict, PRECURSOR_METABOLITES, device)

edge_list_df_all = pd.read_csv('./data/training/edge_list_all.csv')
graph_fc_all = get_graph_fc(edge_list_df_all, METABOLITES)
embedding_all = generate_embedding(graph_fc_all, edge_list_df_all, METABOLITES, device)

_metabolites_of_interest=['pyr', 'e4p', 'accoa', 'r5p', 'akg', 'f6p', 'oaa']
train set	 [23 42 49 53 37  5 13 39 59 74]
test set 	 [38 22 19 44 62 69 50 46  7 69]
val set  	 [12 33 68 21 54 47 41 20 60 18]
train mask 	 tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
test mask  	 tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
val mask   	 tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
Epoch: 10, Loss: 2.1837
Epoch: 20, Loss: 1.6550
Epoch: 30, Loss: 1.4554
Epoch: 40, Loss: 1.2567
Epoch: 50, Loss: 1.0666
Epoch: 60, Loss: 0.9513
Epoch: 70, Loss: 0.8678
Epoch: 80, Loss: 0.8396
Epoch: 90, Loss: 0.8221
Epoch: 100, Loss: 0.7696
_metabolites_of_interest=['3pg;2pg', 'pep', 'e4p', 'accoa', 'pyr', 'r5p', 'akg', 'g6p;f6p;g6p-B', 'f6p', 'oaa', 'dhap', 'g6p;g6p-B']
train set	 [32 74 14 53 28  2 36  5 10 43]
test set 	 [20 70 66 82 64 50 88  8 55 69]
val set  	 [85 51 27 31  1  6 76 48 52 67]
train mask 	 tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
test mask  	 tensor([0,

# Modeling

## Data preparation

## Model prep

In [5]:
def count_parameters(model):
    # for p in model.parameters():
    #     print(p)

    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def log_metrics(all_preds, all_ground_truth, all_knockout_ids, epoch, type: str, debug=False):
    mae = mean_absolute_error(all_ground_truth, all_preds)
    r2 = pearsonr(all_preds, all_ground_truth)[0]
    # all_knockout_ids = [data_augmentation_get_knockout_label(knockout_id) for knockout_id in all_knockout_ids]
    
    # k = np.array(X_test[X_test['metabolite_id'] == 'pyr']['KO_ORF'].unique())
    # mask_idx = np.argwhere(np.isin(all_knockout_ids, k)).flatten()
    # masked_mae = mean_absolute_error(all_ground_truth[mask_idx], all_preds[mask_idx])
    # masked_r2 = pearsonr(all_preds[mask_idx], all_ground_truth[mask_idx])[0]
    
    if debug:
        print(f"{mae=}")
        print(f"{r2=}")
    mlflow.log_metric(key="Mean absolute error", value=float(mae), step=epoch)
    mlflow.log_metric(key="R2 score", value=float(r2), step=epoch)
    # mlflow.log_metric(key="Masked Mean absolute error", value=float(masked_mae), step=epoch)
    # mlflow.log_metric(key="Masked R2 score", value=float(masked_r2), step=epoch)

In [21]:
def train_one_epoch(epoch, model, train_loader, optimizer, loss_fn):
    # Enumerate over the data
    running_loss = 0.0
    step = 0
    for _, batch in enumerate(train_loader):
        # Use GPU
        batch.to(device)
        # Reset gradients
        optimizer.zero_grad() 
        # Passing the node features and the connection info
        pred = model.forward(
            batch.x_dict, 
            batch.edge_index_dict,
            # batch.edge_attr.float(),
            # batch.batch
        ) 
        # Calculating the loss and gradients
        train_mask = batch['metabolites'].train_mask.bool()
        loss = loss_fn(
            torch.squeeze(torch.nan_to_num(pred['metabolites'])), 
            batch['metabolites'].y.float()
        )

        loss.backward()  
        optimizer.step()  
        # Update tracking
        running_loss += loss.item()
        step += 1
    
    with tune.checkpoint_dir(epoch) as checkpoint_dir:
        path = os.path.join(checkpoint_dir, "checkpoint")
        torch.save((model.state_dict(), optimizer.state_dict()), path)

    tune.report(loss=(running_loss/step))

    return running_loss/step
    
@mlflow_mixin
def test(epoch, model, test_loader, loss_fn, debug=False):
    all_preds_raw = []
    all_labels = []
    all_knockout_ids = []
    running_loss = 0.0
    step = 0
    for batch in test_loader:
        batch.to(device)
        # print(batch)  
        pred = model(
            batch.x_dict, 
            batch.edge_index_dict,
        )
        test_mask = batch['metabolites'].test_mask.bool()

        if debug:
            print(f"{test_mask.sum()=}")
            print(
                f"{pred['metabolites'].shape=} \n"
                f"{pred['metabolites'][test_mask].mean()=} \n"
                f"{pred['metabolites'][test_mask].max()=} \n"
                f"{pred['metabolites'][test_mask].min()=} \n"
                f"{pred['metabolites'][test_mask].shape=} \n"
            ) 
        
        loss = loss_fn(
            torch.squeeze(pred['metabolites'][test_mask]), 
            torch.squeeze(batch['metabolites'].y.float()[test_mask])
        )

         # Update tracking
        running_loss += loss.item()
        step += 1
        all_preds_raw.append(pred['metabolites'][test_mask].cpu().detach().numpy())
        all_labels.append(batch['metabolites'].y[test_mask].cpu().detach().numpy())
        # all_knockout_ids.append(batch['enzymes'].knockout_label_id.cpu().detach().numpy())
    
    all_preds_raw = np.concatenate(all_preds_raw).ravel()
    all_labels = np.concatenate(all_labels).ravel()
    # all_knockout_ids = np.concatenate(all_knockout_ids).ravel()
    log_metrics(all_preds_raw, all_labels, all_knockout_ids, epoch, "test")
    return running_loss/step

## Model construction

In [22]:
class HeteroGCNModel(torch.nn.Module):
    def __init__(self, model_config) -> None:
        super(HeteroGCNModel, self).__init__()
        embedding_size = model_config["model_embedding_size"]
        n_heads = model_config["model_attention_heads"]
        self.n_layers = model_config["model_layers"]
        
        self.conv_layers = ModuleList([])
        self.transf_layers = ModuleList([])
        self.pooling_layers = ModuleList([])
        self.bn_layers = ModuleList([])
        
        self.conv1 = GATConv(
            (-1, -1), out_channels=embedding_size, heads=n_heads, add_self_loops=False, bias=False
        )
        self.transf1 = Linear(
            in_channels=embedding_size*n_heads, 
            out_channels=embedding_size, 
        )

        for i in range(self.n_layers):
            self.conv_layers.append(
                GATConv(
                    (-1, -1), 
                    out_channels=embedding_size, 
                    heads=n_heads, 
                    add_self_loops=False
                )
            )
            self.transf_layers.append(
                Linear(
                    embedding_size*n_heads, 
                    embedding_size
                )
            )

        self.conv2 = GATConv(
            (-1, -1), out_channels=1, add_self_loops=False, bias=False
        )

        self.linear1 = Linear(n_heads * embedding_size, embedding_size)
        self.linear2 = Linear(embedding_size, 1)
    
    def forward(self, x, edge_index):
        # random weights for metabolite nodes should cancel out their contribution
        # metabolite_fc = torch.rand(x.shape[0], device=device)
        
        x = self.conv1(x, edge_index)
        # print(x)
        # x = torch.relu(self.transf1(x))

        for i in range(self.n_layers):
            x = self.conv_layers[i](x, edge_index)
            x = torch.relu(self.transf_layers[i](x))

        x = self.conv2(x, edge_index)
        
        # x = torch.relu(self.linear1(x))
        # x = F.dropout(x, p=0.8, training=self.training)
        # x = self.linear2(x)
        return x 

## Model running

In [23]:
model_config = {
    "model_embedding_size": 64,
    "model_attention_heads": 3,
    "model_layers": 5,
    "batch_size": 4,
    "learning_rate": 0.01,
    "sgd_momentum": 0.8,
    "scheduler_gamma": 1,
}

@mlflow_mixin
def run_one_training(model_config, samples, checkpoint_dir):
    print(f"creating model {model_config=}")
    params = model_config
    run_id = model_config['mlflow']['tags']['mlflow.parentRunId']
    with mlflow.start_run(nested=True) as run:
        mlflow.set_tag("mlflow.parentRunId", run_id)
        # Logging params
        for key in params.keys():
            mlflow.log_param(key, params[key])

        # Preparing training
        train_loader = GeoDataLoader(samples, batch_size=params['batch_size'])
        test_loader = GeoDataLoader(samples, batch_size=1)
        
        # Loading the model
        print("Loading model...")
        model = HeteroGCNModel(model_config=params)
        model = to_hetero(model, samples[0].metadata(), aggr='mean')
        model = model.to(device)

        if checkpoint_dir:
            model_state, optimizer_state = torch.load(
                os.path.join(checkpoint_dir, "checkpoint"))
            model.load_state_dict(model_state)
            optimizer.load_state_dict(optimizer_state)

        # < 1 increases precision, > 1 recall
        # loss_fn = torch.nn.MSELoss(reduction='none')
        loss_fn = torch.nn.MSELoss()
        # we need to keep the lr quite low since otherwise the weights explode
        optimizer = torch.optim.SGD(
            model.parameters(), 
            lr=params['learning_rate'],
            momentum=params['sgd_momentum'],
            # weight_decay=5e-4
        )
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=params['scheduler_gamma'])
        
        # Start training
        best_loss = 1000
        early_stopping_counter = 0
        max_epochs = 1
        for epoch in tqdm(range(max_epochs)): 
            if early_stopping_counter <= 10: # = x * 5 
                # Training
                model.train()
                loss = train_one_epoch(epoch, model, train_loader, optimizer, loss_fn)
                # print(f"Epoch {epoch} | Train Loss {loss}")
                mlflow.log_metric(key="Train loss", value=float(loss), step=epoch)

                # Testing
                model.eval()
                if epoch % 5 == 0 or epoch == max_epochs - 1:
                    loss = test(epoch, model, test_loader, loss_fn)
                    # print(f"Epoch {epoch} | Test Loss {loss}")
                    mlflow.log_metric(key="Test loss", value=float(loss), step=epoch)
                    
                    # Update best loss
                    if float(loss) < best_loss:
                        best_loss = loss
                        # Save the currently best model 
                        # mlflow.pytorch.log_model(model, "model", signature=SIGNATURE)
                        early_stopping_counter = 0
                    else:
                        early_stopping_counter += 1

                scheduler.step()
                mlflow.log_metric(key="Learning rate", value=float(scheduler.get_last_lr()[0]), step=epoch)
            else:
                print("Early stopping due to no improvement.")
                session.report({
                    "loss": best_loss
                })
                return {"loss": best_loss}
    print(f"Finishing training with best test loss: {best_loss}")

    with torch.no_grad():
        sample = samples[0].to(device)
        model.forward(sample.x_dict, sample.edge_index_dict)
        print(f"Number of parameters: {count_parameters(model)}")

    session.report({
        "loss": best_loss
    })
    return {"loss": best_loss}

In [24]:
def get_dataloader(samples, staregy, bs):
    train_loader = GeoDataLoader(samples, batch_size=bs)
    test_loader = GeoDataLoader(samples, batch_size=1)
    return train_loader, test_loader

@mlflow_mixin
def run_one_training(model_config, train_samples, test_samples, checkpoint_dir):
    print(f"creating model {model_config=}")
    params = model_config
    run_id = model_config['mlflow']['tags']['mlflow.parentRunId']
    mlflow.set_tag("mlflow.parentRunId", run_id)
    # Logging params
    for key in params.keys():
        mlflow.log_param(key, params[key])

    # Preparing training
    train_loader = GeoDataLoader(train_samples, batch_size=params['batch_size'])
    test_loader = GeoDataLoader(test_samples, batch_size=1)
    
    # Loading the model
    print("Loading model...")
    model = HeteroGCNModel(model_config=params)
    model = to_hetero(model, train_samples[0].metadata(), aggr=params['aggr_strategy'])
    model = model.to(device)

    if checkpoint_dir:
        model_state, optimizer_state = torch.load(
            os.path.join(checkpoint_dir, "checkpoint"))
        model.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    # < 1 increases precision, > 1 recall
    # loss_fn = torch.nn.MSELoss(reduction='none')
    loss_fn = torch.nn.MSELoss()
    # we need to keep the lr quite low since otherwise the weights explode
    optimizer = torch.optim.SGD(
        model.parameters(), 
        lr=params['learning_rate'],
        momentum=params['sgd_momentum'],
        # weight_decay=5e-4
    )
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=params['scheduler_gamma'])
    
    # Start training
    best_loss = 1000
    early_stopping_counter = 0
    max_epochs = 300
    for epoch in tqdm(range(max_epochs)): 
        if early_stopping_counter <= 25: # = x * 5 
            # Training
            model.train()
            loss = train_one_epoch(epoch, model, train_loader, optimizer, loss_fn)
            # print(f"Epoch {epoch} | Train Loss {loss}")
            mlflow.log_metric(key="Train loss", value=float(loss), step=epoch)
            
            # Testing
            model.eval()
            if epoch % 5 == 0 or epoch == max_epochs - 1:
                loss = test(epoch, model, test_loader, loss_fn)
                # print(f"Epoch {epoch} | Test Loss {loss}")
                mlflow.log_metric(key="Test loss", value=float(loss), step=epoch)
                
                # Update best loss
                if float(loss) < best_loss:
                    best_loss = loss
                    # Save the currently best model 
                    # mlflow.pytorch.log_model(model, "model", signature=SIGNATURE)
                    early_stopping_counter = 0
                else:
                    early_stopping_counter += 1

            scheduler.step()
            mlflow.log_metric(key="Learning rate", value=float(scheduler.get_last_lr()[0]), step=epoch)
            
        else:
            print("Early stopping due to no improvement.")
            session.report({
                "loss": best_loss
            })
            return {"loss": best_loss}
    print(f"Finishing training with best test loss: {best_loss}")

    with torch.no_grad():
        sample = train_samples[0].to(device)
        model.forward(sample.x_dict, sample.edge_index_dict)
        print(f"Number of parameters: {count_parameters(model)}")

    session.report({
        "loss": best_loss
    })

    mlflow.end_run()

    return {"loss": best_loss}

In [25]:
HYPERPARAMETERS = {
    "batch_size": tune.choice([2, 4, 8]),
    "learning_rate": tune.choice([0.1, 0.05, 0.01, 0.001]),
    "sgd_momentum": tune.choice([0.9, 0.8, 0.5]),
    "scheduler_gamma": tune.choice([0.995, 1]),
    "aggr_strategy": tune.choice(['mean', 'sum']),
    "model_embedding_size": tune.choice([8, 16, 32, 64, 128]),
    "model_attention_heads": tune.choice([1, 2, 3, 4]),
    "model_layers": tune.choice([1, 3, 5, 7]),
}

In [26]:
train_samples, test_samples = get_samples_hetero_graph(
    "pyr", Strategy.ALL, METABOLITES, graph_fc_all, edge_list_df_all, embedding_all)
print(len(train_samples))
print(len(test_samples))

print(f"{train_samples[0]['metabolites']['train_mask']=}")
print(f"{train_samples[0]['metabolites']['test_mask']=}")

print(f"{test_samples[0]['metabolites']['train_mask']=}")
print(f"{test_samples[0]['metabolites']['test_mask']=}")

66
29
train_samples[0]['metabolites']['train_mask']=tensor([1., 0., 1., 0., 1., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 1., 0., 1.,
        0., 1., 1., 1., 0., 0.])
train_samples[0]['metabolites']['test_mask']=tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
test_samples[0]['metabolites']['train_mask']=tensor([1., 0., 1., 0., 1., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 1., 0., 1.,
        0., 1., 1., 1., 0., 0.])
test_samples[0]['metabolites']['test_mask']=tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 1., 0., 0., 0.])


In [27]:
def tune_metabolite_hyper_parameters(
    experiment_name, 
    train_samples, 
    test_samples,
    num_samples=10,
):

    parent_id = mlflow.active_run().info.run_id if mlflow.active_run() else None

    scheduler = ASHAScheduler(
        metric="loss",
        mode="min",
        max_t=300,
        grace_period=10,
        reduction_factor=2
    )
    reporter = CLIReporter(
        metric_columns=["loss", "training_iteration"]
    )

    trainable = tune.with_parameters(
        run_one_training, 
        train_samples=train_samples,
        test_samples=test_samples,
        checkpoint_dir=None,
    )

    result = tune.run(
        trainable,
        config={
            **HYPERPARAMETERS,
            "mlflow": {
                "experiment_name": experiment_name,
                "tracking_uri": mlflow.get_tracking_uri(),
                "save_artifacts": True,
                "tags": {
                    "mlflow.parentRunId": parent_id
                }
            }
        },
        num_samples=num_samples,
        scheduler=scheduler,
        progress_reporter=reporter,
    )
    return result

# tune_metabolite_hyper_parameters("pyr")

In [29]:
for strategy in [Strategy.ALL, Strategy.ONE_VS_ALL, Strategy.METABOLITE_CENTRIC]:
    for mode, graph_fc in [('all', graph_fc_all)]:
        run_name = "model_gat_node_embeddings"
        experiment_name = f'metabolite_gnn_sweep_full_{mode}'
        mlflow.set_experiment(experiment_name)
        with mlflow.start_run(run_name=run_name) as run:
            for metabolite_id in list(set(graph_fc.columns.difference(ENZYMES).to_list()) & set(PRECURSOR_METABOLITES)):
                print(f"training {metabolite_id=}")

                if mode == 'unfiltered':
                    train_samples, test_samples = get_samples_hetero_graph(metabolite_id, strategy, PRECURSOR_METABOLITES_NO_TRANSFORM, graph_fc_unfiltered, edge_list_df_unfiltered, embedding_unfiltered)
                elif mode == 'strict':
                    train_samples, test_samples = get_samples_hetero_graph(metabolite_id, strategy, PRECURSOR_METABOLITES, graph_fc_strict, edge_list_df_strict, embedding_strict)
                elif mode == 'all':
                    train_samples, test_samples = get_samples_hetero_graph(metabolite_id, strategy, METABOLITES, graph_fc_all, edge_list_df_all, embedding_all)

                with mlflow.start_run(run_name=f"model_{metabolite_id}_{strategy}", nested=True):
                    result = tune_metabolite_hyper_parameters(experiment_name, train_samples, test_samples, num_samples=16)
                    print(result)

The history saving thread hit an unexpected error (OperationalError('database or disk is full')).History will not be written to the database.


KeyboardInterrupt: 

[2m[33m(raylet)[0m [2022-12-13 19:49:12,029 E 7059 7106] (raylet) file_system_monitor.cc:105: /tmp/ray/session_2022-12-13_17-18-41_527887_16466 is over 95% full, available space: 9530351616; capacity: 269490393088. Object creation will fail if spilling is required.


# Fetch MLflow results

In [112]:
import mlflow
from mlflow.entities import ViewType, Run
from mlflow import MlflowClient

# MlflowClient().search_experiments(

# )

# run = MlflowClient().search_runs(
#   experiment_ids="0",
#   filter_string="",
#   run_view_type=ViewType.ACTIVE_ONLY,
#   max_results=1,
#   order_by=["metrics.accuracy DESC"]
# )[0]

modes = ['unfiltered', 'strict', 'all']

experiments = []

for mode in modes:
    experiment_query = f"name = 'metabolite_gnn_sweep_full_{mode}'"
    experiments.append([exp.experiment_id for exp in mlflow.search_experiments(max_results=1, filter_string=experiment_query)][0])
print(experiments)

run_query = f"tags.mlflow.runName = 'model_gat_node_embeddings'"
newest_root_runs = mlflow.search_runs(experiment_ids=experiments, max_results=3, filter_string=run_query)['run_id'].values.tolist()

optimal_runs = []

for newest_root_run, mode in zip(newest_root_runs, modes):
    run_query = f"tags.mlflow.parentRunId = '{newest_root_run}'"
    metabolite_runs: List[Run] = mlflow.search_runs(experiment_ids=experiments, filter_string=run_query, output_format='list')

    for metabolite_run in metabolite_runs:
        metabolite_id = metabolite_run.data.tags['mlflow.runName'].strip('model_')
        run_query = f"tags.mlflow.parentRunId = '{metabolite_run.info.run_id}'"
        runs = mlflow.search_runs(experiment_ids=experiments, filter_string=run_query)

        if 'metrics.R2 score' not in runs.columns:
            continue

        # print(runs)
        runs = runs.assign(
            metabolite_id=metabolite_id,
            mode=f"metabolite_gnn_sweep_full_{mode}",
            strategy=Strategy.ONE_VS_ALL
        ).fillna(0.0)
        optimal_runs.append(runs.loc[runs['metrics.R2 score'].idxmax()].to_frame())

result = pd.concat(optimal_runs, axis=1).T \
    .drop(
        ['experiment_id', 'status', 'artifact_uri', 'start_time', 'end_time', 'tags.mlflow.source.git.commit', 'tags.mlflow.source.type', 'tags.mlflow.user', 'params.mlflow', 'tags.mlflow.runName', 'tags.mlflow.parentRunId', 'tags.mlflow.source.name'], 
        axis=1
    )
result

['405896940355503870', '624325160207995459', '128972388166220243']


Unnamed: 0,run_id,metrics.Learning rate,metrics.Train loss,metrics.Mean absolute error,metrics.Test loss,metrics.R2 score,params.model_layers,params.model_embedding_size,params.model_attention_heads,params.scheduler_gamma,params.learning_rate,params.batch_size,params.sgd_momentum,metabolite_id,mode,strategy
0,145d754f688d43c993dc3de9226b3e61,0.001,0.303672,0.025652,0.000658,0.0,5,32,3,1.0,0.001,4,0.5,succ,metabolite_gnn_sweep_full_unfiltered,Strategy.ONE_VS_ALL
0,55ebe439866540bea3d1ca44161ab90a,0.001342,0.304054,0.022022,0.000485,0.0,5,8,3,0.8,0.01,2,0.9,s7p,metabolite_gnn_sweep_full_unfiltered,Strategy.ONE_VS_ALL
3,837d12241bf149c2941457dbf8ba106a,0.0,0.303434,0.731192,0.995041,0.223131,1,8,4,0.5,0.1,8,0.5,r5p,metabolite_gnn_sweep_full_unfiltered,Strategy.ONE_VS_ALL
9,44478193acd247928ead8798900e74dd,2e-06,0.305192,0.792403,1.000494,0.137848,1,8,4,0.5,0.001,2,0.8,pyr,metabolite_gnn_sweep_full_unfiltered,Strategy.ONE_VS_ALL
8,5a0434e742ee4574a3ffae3d9d59f294,0.1,0.240636,0.609749,0.68381,0.632126,1,32,3,1.0,0.1,8,0.5,pep,metabolite_gnn_sweep_full_unfiltered,Strategy.ONE_VS_ALL
0,b7b7a3bc14b94bdeb644ac6789b4f144,0.0,0.30296,0.01256,0.000246,0.0,1,64,4,0.9,0.01,8,0.5,rn,metabolite_gnn_sweep_full_unfiltered,Strategy.ONE_VS_ALL
10,b9e5baa1819c4a808562253bb4f61dc3,0.0,0.303432,0.143993,0.189474,0.127368,5,16,4,0.5,0.05,8,0.8,aa,metabolite_gnn_sweep_full_unfiltered,Strategy.ONE_VS_ALL
0,7fb1b075939646c5828cf598cc1dba77,2e-06,0.309702,0.080201,0.006433,0.0,3,64,1,0.5,0.001,4,0.5,gly,metabolite_gnn_sweep_full_unfiltered,Strategy.ONE_VS_ALL
9,80c9ac0c7b27478e8f145152093720c0,0.05,0.294254,0.189109,0.186903,0.167754,3,8,3,1.0,0.05,8,0.9,g6p;g6p-B,metabolite_gnn_sweep_full_unfiltered,Strategy.ONE_VS_ALL
7,457768291ce04d8497cac0e032389cf7,0.002234,0.263649,0.630982,0.772137,0.628208,1,128,2,0.995,0.01,4,0.9,f6p,metabolite_gnn_sweep_full_unfiltered,Strategy.ONE_VS_ALL


In [110]:
pd.concat(optimal_runs, axis=1).T.drop(['experiment_id', 'status', 'artifact_uri', 'start_time', 'end_time', 'tags.mlflow.source.git.commit', 'tags.mlflow.source.type', 'tags.mlflow.user', 'params.mlflow', 'tags.mlflow.runName', 'tags.mlflow.parentRunId', 'tags.mlflow.source.name'], axis=1)

Unnamed: 0,run_id,metrics.Learning rate,metrics.Train loss,metrics.Mean absolute error,metrics.Test loss,metrics.R2 score,params.model_layers,params.model_embedding_size,params.model_attention_heads,params.scheduler_gamma,params.learning_rate,params.batch_size,params.sgd_momentum,metabolite_id,mode,strategy
0,145d754f688d43c993dc3de9226b3e61,0.001,0.303672,0.025652,0.000658,0.0,5,32,3,1.0,0.001,4,0.5,succ,metabolite_gnn_sweep_full_all,Strategy.ONE_VS_ALL
0,55ebe439866540bea3d1ca44161ab90a,0.001342,0.304054,0.022022,0.000485,0.0,5,8,3,0.8,0.01,2,0.9,s7p,metabolite_gnn_sweep_full_all,Strategy.ONE_VS_ALL
3,837d12241bf149c2941457dbf8ba106a,0.0,0.303434,0.731192,0.995041,0.223131,1,8,4,0.5,0.1,8,0.5,r5p,metabolite_gnn_sweep_full_all,Strategy.ONE_VS_ALL
9,44478193acd247928ead8798900e74dd,2e-06,0.305192,0.792403,1.000494,0.137848,1,8,4,0.5,0.001,2,0.8,pyr,metabolite_gnn_sweep_full_all,Strategy.ONE_VS_ALL
8,5a0434e742ee4574a3ffae3d9d59f294,0.1,0.240636,0.609749,0.68381,0.632126,1,32,3,1.0,0.1,8,0.5,pep,metabolite_gnn_sweep_full_all,Strategy.ONE_VS_ALL
0,b7b7a3bc14b94bdeb644ac6789b4f144,0.0,0.30296,0.01256,0.000246,0.0,1,64,4,0.9,0.01,8,0.5,rn,metabolite_gnn_sweep_full_all,Strategy.ONE_VS_ALL
10,b9e5baa1819c4a808562253bb4f61dc3,0.0,0.303432,0.143993,0.189474,0.127368,5,16,4,0.5,0.05,8,0.8,aa,metabolite_gnn_sweep_full_all,Strategy.ONE_VS_ALL
0,7fb1b075939646c5828cf598cc1dba77,2e-06,0.309702,0.080201,0.006433,0.0,3,64,1,0.5,0.001,4,0.5,gly,metabolite_gnn_sweep_full_all,Strategy.ONE_VS_ALL
9,80c9ac0c7b27478e8f145152093720c0,0.05,0.294254,0.189109,0.186903,0.167754,3,8,3,1.0,0.05,8,0.9,g6p;g6p-B,metabolite_gnn_sweep_full_all,Strategy.ONE_VS_ALL
7,457768291ce04d8497cac0e032389cf7,0.002234,0.263649,0.630982,0.772137,0.628208,1,128,2,0.995,0.01,4,0.9,f6p,metabolite_gnn_sweep_full_all,Strategy.ONE_VS_ALL
