In [5]:
import torch
import optuna
from torch import nn as nn
from pathlib import Path
import mlflow
from optuna.integration.mlflow import MLflowCallback

In [None]:
from esmhotprot.trunk import FoldingTrunk, FoldingTrunkConfig
from esmhotprot.esmfold import ESMFoldConfig

In [None]:
def prep_minimal_dataset(root_path: Path):
    """Placeholder hook for dataset preparation; currently a no-op.

    Args:
        root_path: Directory handed in by the caller. It is not read yet.

    Returns:
        None.
    """
    # TODO: implement the actual preparation step.
    return None


prep_minimal_dataset(Path("datasets/FLIP/"))

In [None]:
def get_minimal_dataset():
    """Return the train/test split used by the hyperparameter search.

    Not implemented yet. ``optimize_thermostability`` unpacks the result as
    ``(X_train, X_test, Y_train, Y_test)``, so a silent ``None`` return would
    only surface there as an opaque unpacking ``TypeError``. Fail loudly here
    instead so the missing piece is obvious.

    Returns:
        Once implemented: a 4-tuple ``(X_train, X_test, Y_train, Y_test)``.

    Raises:
        NotImplementedError: Always, until the dataset loading is written.
    """
    raise NotImplementedError(
        "get_minimal_dataset() is a stub; it must return "
        "(X_train, X_test, Y_train, Y_test)."
    )

In [None]:
def build_model(
    learning_rate: float,
    dropout_rate: float,
    # batchnormalization: bool,
    hidden_units: int,
    # activation_function,
    esmfold_config=None,
):
    """Build the thermostability regression head.

    Args:
        learning_rate: Accepted for symmetry with the hyperparameter search
            space; it is not used when constructing the model.
        dropout_rate: Drop probability for both ``nn.Dropout`` layers.
        hidden_units: Width of the two hidden ``nn.Linear`` layers.
        esmfold_config: Optional config object exposing
            ``trunk.structure_module.c_s``. When ``None`` a default
            ``ESMFoldConfig()`` is created.

    Returns:
        An ``nn.Sequential`` made of LayerNorm -> Linear -> Dropout -> Linear
        -> Dropout -> Linear, mapping ``c_s`` input features to one output.
    """
    # Explicit None check so that a provided config is never replaced just
    # because it happens to evaluate as falsy.
    cfg = ESMFoldConfig() if esmfold_config is None else esmfold_config

    # NOTE(review): instantiated for the TODO below but not yet part of the
    # returned model.
    trunk = FoldingTrunk()

    # Read the feature width once so LayerNorm and the first Linear layer
    # cannot disagree. The original read ``cfg.structure_module.c_s`` for the
    # LayerNorm and ``cfg.trunk.structure_module.c_s`` for the Linear layer.
    # NOTE(review): assumes ESMFoldConfig nests the structure-module config
    # under ``trunk`` as upstream ESMFold does — confirm for esmhotprot.
    c_s = cfg.trunk.structure_module.c_s

    # NOTE(review): there is no activation between the Linear layers (the
    # ``activation_function`` parameter above is still commented out).
    thermo_predictor = nn.Sequential(
        nn.LayerNorm(c_s),
        nn.Linear(c_s, hidden_units),
        nn.Dropout(dropout_rate),
        nn.Linear(hidden_units, hidden_units),
        nn.Dropout(dropout_rate),
        nn.Linear(hidden_units, 1),
    )

    # TODO Define model class concatenating trunk and thermo_predictor and defining custom forward pass
    # Like ESMFold...
    model = thermo_predictor
    return model

In [None]:
# Address of the MLflow tracking server the Optuna callback reports to.
YOUR_TRACKING_URI = "http://127.0.0.1:5000"

# Optuna -> MLflow bridge; meant to be handed to ``study.optimize`` through
# its ``callbacks`` argument.
# NOTE(review): the callback's metric name is "metric_score", whereas the
# objective below logs its own metric under "score" — confirm this is intended.
mlflc = MLflowCallback(tracking_uri=YOUR_TRACKING_URI, metric_name="metric_score")

def optimize_thermostability(trial):
    """Optuna objective: sample hyperparameters, fit a model and report its score.

    Args:
        trial: Optuna trial used to sample the hyperparameters through
            ``suggest_float`` / ``suggest_int``.

    Returns:
        The value returned by ``model.evaluate()``; it is also logged to
        MLflow under the metric name ``"score"``.
    """
    X_train, X_test, Y_train, Y_test = get_minimal_dataset()

    # This has to be a dict (``key: value``). The original used commas, which
    # builds a set, so ``params[...]`` below raised TypeError and
    # ``mlflow.log_params`` would not have received a mapping.
    params = {
        'model_learning_rate': trial.suggest_float('model_learning_rate', 0.025, 2, step=0.025),
        'model_dropout_rate': trial.suggest_float('model_dropout_rate', 0.0, 0.5, step=0.1),
        'model_hidden_units': trial.suggest_int('model_hidden_units', 100, 1000, step=100),
    }

    model = build_model(
        learning_rate=params['model_learning_rate'],
        dropout_rate=params['model_dropout_rate'],
        hidden_units=params['model_hidden_units'],
    )

    # TODO training of model
    # NOTE(review): ``build_model`` returns an ``nn.Sequential``, which has no
    # ``fit`` method — this line is a placeholder for a real training loop.
    model.fit(X_train, Y_train)

    mlflow.log_params(params)

    # Work on a copy so that attaching the label column does not mutate the
    # caller-visible ``X_test``.
    # NOTE(review): assumes X_test is a pandas DataFrame (it is indexed by
    # column name here and passed to ``mlflow.evaluate``) — TODO confirm.
    eval_data = X_test.copy()
    eval_data["label"] = Y_test

    # ``artifact_path`` is a required argument of ``log_model`` in MLflow 2.x.
    candidate_model_uri = mlflow.pytorch.log_model(model, artifact_path="model").model_uri
    mlflow.evaluate(model=candidate_model_uri, data=eval_data, targets="label", model_type="regressor")

    # TODO evaluation of model
    # NOTE(review): ``nn.Sequential`` has no ``evaluate`` method either — this
    # is a placeholder for the real evaluation.
    score = model.evaluate()
    mlflow.log_metric("score", score)
    return score



    
