# train

> Train one mokapot `PercolatorModel` per cross-validation fold of the training PSMs.

In [1]:
#| export
from pathlib import Path
from joblib import Parallel, delayed
import copy
import numpy as np

from mokapot.model import PercolatorModel
from mokapot.brew import _fit_model

from misugaru.core import *
from misugaru.data import Data

In [2]:
#| export
# Seeded generator assigned to the template model's `rng` attribute,
# presumably so fold training is reproducible across runs.
RNG = np.random.default_rng(42)
# Number of parallel jobs used to fit the folds.
# NOTE(review): assumes `Data.get_train_psms_splits` yields N_FOLDS splits — confirm.
N_FOLDS = 3

def train_models(
    data: Data,
    subset_max_train: int,
) -> "tuple[PercolatorModel, ...]":
    """Train one mokapot ``PercolatorModel`` per training split of ``data``.

    A single template model is configured with the module-level ``RNG``.
    Each split returned by ``data.get_train_psms_splits`` is then fitted
    with mokapot's ``_fit_model`` on its own deep copy of that template.
    Fits run in parallel through joblib with shared memory.

    Parameters
    ----------
    data:
        Dataset providing the training splits and the full ``psms`` set.
    subset_max_train:
        Forwarded as ``n_subset`` to ``data.get_train_psms_splits``.
        Presumably it caps the number of PSMs used for training; confirm
        against ``Data``.

    Returns
    -------
    tuple of PercolatorModel
        The fitted models, sorted by their ``fold`` attribute.

    Raises
    ------
    ValueError
        If no training splits were produced, or if ``_fit_model`` flagged
        any fold as needing a reset.
    """
    model = PercolatorModel()
    model.rng = RNG
    psm_folds = data.get_train_psms_splits(n_subset=subset_max_train)

    # Each element of `fitted` is a `(model, needs_reset)` pair, e.g.
    # `[(model_a: PercolatorModel, reset_a: bool), (model_b, reset_b), (model_c, reset_c)]`.
    # Every fold receives its own deep copy so fits do not share estimator state.
    fitted = Parallel(n_jobs=N_FOLDS, require="sharedmem")(
        delayed(_fit_model)(fold_psms, [data.psms], copy.deepcopy(model), fold_idx)
        for fold_idx, fold_psms in enumerate(psm_folds)
    )
    if not fitted:
        # Without this guard, unpacking `zip(*[])` below fails with an
        # opaque "not enough values to unpack" error.
        raise ValueError("No training folds were produced; cannot train models")

    # sort by fold for a deterministic order
    models, needs_reset = zip(*sorted(fitted, key=lambda pair: pair[0].fold))
    if any(needs_reset):
        # TODO: create specific exception
        raise ValueError("Model training failed")
    return models

In [3]:
# Usage
# `Path` does not expand `~` by itself, so expand it before handing the path to `Data`.
# NOTE(review): machine-specific location; consider a configurable DATA_DIR.
path = Path("~/repos/matcha/data/10k_psms_test.parquet").expanduser()
data = Data(path)

models = train_models(data, subset_max_train=10000)

In [4]:
# Sanity check: one trained model is expected per fold.
assert N_FOLDS == len(models)

In [5]:
# Display the model with the lowest fold index (the models are sorted by fold).
first_model = models[0]
first_model

A trained mokapot.model.Model object:
	estimator: LinearSVC(class_weight={0: 10, 1: 10}, dual=False, random_state=7)
	scaler: StandardScaler()
	features: ['Mass', 'MS8_feature_5', 'missedCleavages', 'MS8_feature_7', 'MS8_feature_13', 'MS8_feature_20', 'MS8_feature_21', 'MS8_feature_22', 'MS8_feature_24', 'MS8_feature_29', 'MS8_feature_30', 'MS8_feature_32']