# 4.1. load train/test data

In [None]:
import os
import pandas as pd

import wandb
from utils.wandb_utils import wandb_log_artifact, get_wandb_artifact

from ml_src.data import preprocess
from ml_src.model import train_model, inference, compute_model_metrics

In [None]:
# Open the W&B run for this training job; `run` is passed to every artifact helper below.
run = wandb.init(job_type='model training', project='ex_census_wandb')

In [None]:
def _fetch_census_split(file_name):
    """Return the path of `file_name` from the latest `census_split` data artifact."""
    _, local_path = get_wandb_artifact(run,
                                       artifact_name="census_split",
                                       file_name=file_name,
                                       tag="latest",
                                       artifact_type='data')
    return local_path


train_file_path = _fetch_census_split("census_train.csv")
test_file_path = _fetch_census_split("census_test.csv")

for split_path in (train_file_path, test_file_path):
    print(split_path)

# Parsed as tab-separated (sep='\t') even though the files carry a .csv extension.
train_df, test_df = (pd.read_csv(path, sep='\t', encoding='utf-8')
                     for path in (train_file_path, test_file_path))

# 4.2. load feature engineering artifacts (encoder, label binarizer)

In [None]:
from joblib import load

In [None]:
# Latest fitted encoder from the feature-engineering artifact. The artifact object is
# kept because its metadata supplies the categorical feature names and label column.
encoder_artifact, encoder_path = get_wandb_artifact(
    run,
    artifact_name="feature_engineering",
    artifact_type='feature_engineering_artifact',
    file_name="encoder.joblib",
    tag="latest",
)

In [None]:
# Latest fitted label binarizer, fetched from the same feature-engineering artifact.
lb_artifact, lb_path = get_wandb_artifact(
    run,
    artifact_name="feature_engineering",
    artifact_type='feature_engineering_artifact',
    file_name="lb.joblib",
    tag="latest",
)

In [None]:
# Deserialize both fitted transformers (encoder first, then label binarizer).
# joblib.load unpickles its input, so only load artifacts from a trusted project.
encoder, lb = load(encoder_path), load(lb_path)

In [None]:
# Feature/label names are stored in the encoder artifact's metadata.
encoder_metadata = encoder_artifact.metadata
cat_features = encoder_metadata['categorical_feature']
label = encoder_metadata['label']

# 4.3. Train model

In [None]:
# RandomForest hyperparameters, forwarded to train_model(..., params=parameters).
parameters = dict(
    n_estimators=500,
    min_samples_split=3,
    min_samples_leaf=2,
    max_features="sqrt",
    max_depth=200,
    criterion="gini",
    bootstrap=True,
)

In [None]:
# Encode BOTH splits with the encoder / label binarizer loaded from the
# feature_engineering artifact (section 4.2). Previously the train split used
# training=True, which returned a new encoder/lb that overwrote the loaded ones
# (presumably refitted on train_df -- confirm in ml_src.data.preprocess). That meant
# the model was trained on an encoding that differs from the encoder artifact
# recorded in the run metadata, and the refitted transformers were never saved.
X_train, y_train, _, _ = preprocess(train_df, categorical_features=cat_features, label=label, training=False, encoder=encoder, lb=lb)
X_test, y_test, _, _ = preprocess(test_df, categorical_features=cat_features, label=label, training=False, encoder=encoder, lb=lb)

In [None]:
# Fit the model on the encoded training split using the `parameters` dict defined above.
model = train_model(X_train, y_train, params=parameters)

# 4.4. log model configs + parameters + performance

In [None]:
# Score the model on the held-out test split.
preds = inference(model, X_test)
test_scores = compute_model_metrics(y_test, preds)
precision, recall, fbeta = test_scores
print(precision, recall, fbeta)

In [None]:
# Run configuration recorded in W&B: feature setup, hyperparameters, and the
# paths of the data splits and feature-engineering files used for this model.
metadata = {
    "categorical_feature": cat_features,
    "label": label,
    "param": parameters,
    "train_data_path": train_file_path,
    "test_data_path": test_file_path,
    "encoder": encoder_path,
    "label_binarizer": lb_path,
    # Misspelled legacy key, kept so existing readers of this run config keep working.
    "lable_binarizer": lb_path,
}

In [None]:
# Attach the metadata dict above to this run's W&B config.
run.config.update(metadata)

In [None]:
# Record the test-split metrics in the run summary, in the same order as before.
for metric_name, metric_value in zip(('precision', 'recall', 'fbeta'),
                                     (precision, recall, fbeta)):
    run.summary[metric_name] = metric_value

# 4.5. save model artifact

In [None]:
from joblib import dump

In [None]:
model_file_path = "./model/model.joblib"

# joblib.dump opens the target path for writing and does not create missing parent
# directories; make sure ./model exists so the dump does not fail on a fresh checkout.
os.makedirs(os.path.dirname(model_file_path), exist_ok=True)
dump(model, model_file_path)

In [None]:
# Log the serialized model file as the "model" artifact of this run.
# remove_logged_file=True presumably deletes the local file after logging -- confirm in utils.wandb_utils.
wandb_log_artifact(
    run,
    "model",
    artifact_type="model_artifact",
    file_path=[model_file_path],
    description="baseline RandomForest model",
    remove_logged_file=True,
)

In [None]:
# Mark the W&B run as finished; nothing should be logged to `run` after this.
run.finish()