In [1]:
# Notebook setup: seed the random number generators, then import dependencies.
import os
# NOTE(review): PYTHONHASHSEED is normally read when the interpreter starts, so
# setting it from inside an already-running kernel presumably only affects
# child processes, not this kernel's str hashing — confirm if hash-stable
# ordering matters.
os.environ['PYTHONHASHSEED'] = '0'
import random as rn
rn.seed(12345)  # seed Python's built-in `random` module
from numpy.random import seed
seed(42)  # seed NumPy's legacy global RNG (np.random.*)

import numpy as np
from sklearn.ensemble import RandomForestClassifier
from dipy.io.gradients import read_bvals_bvecs
import pickle

import yaml

from hashlib import md5

import warnings
# Silences every warning for the rest of the kernel session; this can also
# hide useful diagnostics from numpy/sklearn/dipy.
warnings.filterwarnings('ignore')

In [2]:
def _config_hash(config):
    """Return the md5 hex digest that names this config's model directory.

    The digest is fed ``str(value)`` for every value of ``config`` in the
    dict's iteration (insertion) order, concatenated without keys or
    separators. The digest therefore depends on key order. This exact scheme
    is kept so that model directories saved earlier are still recognised.
    """
    hasher = md5()
    for value in config.values():
        hasher.update(str(value).encode())
    return hasher.hexdigest()


def _nearest_direction_classes(outputs, bvecs):
    """Label each output vector with the index of its best-matching b-vector.

    For every row of ``outputs`` this returns ``argmax_j dot(bvecs[j], row)``,
    using the signed dot product (no absolute value).

    Parameters
    ----------
    outputs : np.ndarray
        Output vectors, one per row (assumed shape ``(n_samples, 3)``).
    bvecs : np.ndarray
        Candidate directions, one per row (assumed shape ``(n_dirs, 3)``).

    Returns
    -------
    np.ndarray
        Integer class index per sample, shape ``(n_samples,)``.
    """
    # One (n_samples, n_dirs) matrix product replaces the Python double loop.
    return np.argmax(np.dot(outputs, bvecs.T), axis=1)


def train_model(config):
    """Train (or reload) a random forest that predicts the nearest b-vector.

    The function reads the ``"inputs"`` and ``"outputs"`` arrays from
    ``<sample_dir>/samples.npz`` and truncates both to ``max_n_samples`` rows.
    It labels every output vector with the index of the b-vector (read from
    the ``bvecs`` file) that has the largest dot product with it. It then fits
    a ``RandomForestClassifier`` on those labels. The config is saved as
    ``config.yml`` and the pickled classifier as ``model``, both under
    ``../models/<model_name>/<config hash>/``.

    If a saved model for this exact config already exists, it is loaded
    instead of retrained. Re-running the notebook cell therefore returns the
    same tuple rather than ``None``.

    Parameters
    ----------
    config : dict
        Must provide ``model_name``, ``sample_dir``, ``bvecs``,
        ``max_n_samples`` (int, or ``None`` to keep every sample),
        ``n_estimators``, ``max_depth`` and ``random_state``.

    Returns
    -------
    tuple
        ``(clf, inputs, outputs, output_classes)``.
    """
    save_dir = os.path.join("..", "models", config["model_name"], _config_hash(config))
    config_path = os.path.join(save_dir, "config" + ".yml")
    model_path = os.path.join(save_dir, 'model')

    sample_path = os.path.join(config["sample_dir"], "samples.npz")
    # The context manager closes the .npz archive. Indexing it reads each
    # array fully into memory, so the arrays stay valid after the block.
    with np.load(sample_path) as samples:
        inputs = samples["inputs"]
        outputs = samples["outputs"]

    # Slicing already clamps to the array length; None keeps all samples.
    max_n_samples = config["max_n_samples"]
    inputs = inputs[:max_n_samples]
    outputs = outputs[:max_n_samples]

    _, bvecs = read_bvals_bvecs(None, config["bvecs"])
    output_classes = _nearest_direction_classes(outputs, bvecs)

    # Check for the model file, not just the directory. An interrupted run
    # can leave an empty directory behind, and that case should retrain.
    if os.path.exists(model_path):
        print("This model config has been trained already:\n{}".format(save_dir))
        # The file was written by this function below. Never point model_path
        # at untrusted data: pickle.load can execute arbitrary code.
        with open(model_path, 'rb') as f:
            clf = pickle.load(f)
        return clf, inputs, outputs, output_classes

    clf = RandomForestClassifier(n_estimators=config["n_estimators"],
                                 max_depth=config["max_depth"],
                                 random_state=config["random_state"])
    clf.fit(inputs, output_classes)

    os.makedirs(save_dir, exist_ok=True)

    print("Saving {}".format(config_path))
    with open(config_path, "w") as file:
        yaml.dump(config, file, default_flow_style=False)

    print("Saving {}".format(model_path))
    with open(model_path, 'wb') as f:
        pickle.dump(clf, f)

    return clf, inputs, outputs, output_classes

In [3]:
# Experiment configuration for train_model. Keep the key order stable: the
# model directory name is an md5 over these values in insertion order.
config = {
    # Models are saved under ../models/<model_name>/<config hash>/.
    "model_name": "rf_model",
    # Directory holding samples.npz (arrays "inputs" and "outputs").
    "sample_dir": "../subjects/992774/samples/cd3f586ae51a17b1ca3fcaebfcd1484f",
    # b-vector file whose directions define the output classes.
    "bvecs": "../subjects/992774/bvecs_input",
    # Upper bound on the number of training samples used.
    "max_n_samples": 1000,
    # RandomForestClassifier hyper-parameters.
    "n_estimators": 100,
    "max_depth": 25,
    "random_state": 0,
}

In [4]:
# Train the random forest described by `config` and keep the fitted
# classifier together with the (truncated) training data and class labels.
# NOTE(review): this unpacking requires train_model to return its 4-tuple on
# every path, including when this config's model already exists on disk.
clf, inputs, outputs, output_classes = train_model(config)

Saving ../models/rf_model/1f52753195c734dda3b10ec006b03f7c/config.yml
Saving ../models/rf_model/1f52753195c734dda3b10ec006b03f7c/model
