In [None]:
%matplotlib inline
from __future__ import absolute_import, division, print_function, unicode_literals
from builtins import range

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns; sns.set_style("white")

from functools import partial
from joblib import delayed, Parallel
from mmit import MaxMarginIntervalTree
from mmit.core.solver import compute_optimal_costs
from mmit.metrics import mean_squared_error, zero_one_loss
from mmit.model import TreeExporter
from mmit.model_selection import GridSearchCV
from os import listdir, mkdir, system
from os.path import abspath, basename, exists, join
from shutil import rmtree as rmdir
from time import time

In [None]:
class Dataset(object):
    """
    In-memory copy of one dataset directory.

    The directory at ``path`` must contain three CSV files, each loaded with
    ``pandas.read_csv`` and its default settings:

    * ``features.csv`` -- values stored as ``X``, column labels as ``feature_names``
    * ``targets.csv``  -- values stored as ``y``
    * ``folds.csv``    -- values flattened to a one-dimensional array, stored as ``folds``

    ``path`` is kept as given and ``name`` is ``basename(path)``.
    """

    def __init__(self, path):
        self.path = path
        self.name = basename(path)

        features = pd.read_csv(join(path, "features.csv"))
        self.feature_names = features.columns.values
        self.X = features.values
        del features  # only the extracted arrays are needed from here on

        targets = pd.read_csv(join(path, "targets.csv"))
        self.y = targets.values

        fold_table = pd.read_csv(join(path, "folds.csv"))
        self.folds = fold_table.values.reshape(-1)

    @property
    def n_examples(self):
        """Number of rows of ``X``."""
        return len(self.X)

    @property
    def n_features(self):
        """Number of columns of ``X``."""
        return self.X.shape[1]
    
def find_datasets(path):
    """
    Yield a Dataset for every entry of ``path`` that holds the three required files.

    An entry qualifies when ``features.csv``, ``targets.csv`` and ``folds.csv``
    all exist inside it; every other entry is skipped without complaint. The
    Dataset is built from the absolute path of the entry.

    Entries are visited in sorted name order: ``os.listdir`` gives no ordering
    guarantee, so iterating its raw result would make the order of the
    datasets (and of everything printed for them) vary between machines.
    """
    required_files = ("features.csv", "targets.csv", "folds.csv")
    for entry in sorted(listdir(path)):
        candidate = join(path, entry)
        if all(exists(join(candidate, filename)) for filename in required_files):
            yield Dataset(abspath(candidate))

# Materialize the generator into a list: the datasets are iterated once per loss function below.
datasets = list(find_datasets("./data"))

In [None]:
def evaluate_on_dataset(d, metric, result_dir):
    start_time = time()
    
    ds_result_dir = join(result_dir, d.name)
    if not exists(ds_result_dir):
        mkdir(ds_result_dir)
    
    fold_models = []
    fold_predictions = np.zeros(d.n_examples)
    for fold in np.unique(d.folds):
        fold_train = d.folds != fold
        
        X_train = d.X[fold_train]
        y_train = d.y[fold_train]
        X_test = d.X[~fold_train]
        y_test = d.y[~fold_train]
        
        cv = GridSearchCV(estimator=MaxMarginIntervalTree(), param_grid=params, cv=10, n_jobs=-1, 
                          scoring=metric)
        cv.fit(X_train, y_train, d.feature_names)
        fold_predictions[~fold_train] = cv.predict(X_test)
        fold_models.append(cv.best_estimator_)
    print("MSE:", mean_squared_error(d.y, fold_predictions))
    print("ACC:", 1.0 - zero_one_loss(d.y, fold_predictions))
    open(join(ds_result_dir, "predictions.csv"), "w").write("\n".join(str(x) for x in fold_predictions))
    
    latex_exporter = TreeExporter("latex")
    string_exporter = TreeExporter("string")
    f_models = open(join(ds_result_dir, "models.tsv"), "w")
    for i, m in enumerate(fold_models):
        open(join(ds_result_dir, "model_fold_{0:d}.tex".format(i + 1)), "w").write(latex_exporter(m))
        f_models.write("{0:d}\t{1!s}\n".format((i + 1), string_exporter(m)))
    f_models.close()
    print("Took", time() - start_time, "seconds.")
    
    # Generate the PDF file for each tree
    build_cmd = "cd {0!s}; for i in ./model_fold_*.tex; do lualatex $i > /dev/null; rm ./*.aux ./*.log;done".format(ds_result_dir)
    !$build_cmd

In [None]:
# Hyper-parameter grid passed to GridSearchCV; the "loss" entry is filled in before each run.
params = dict(
    max_depth=[1000],
    min_samples_split=[0],
    margin=np.logspace(start=-3, stop=0, num=20),
)

def prep_result_dir(result_dir):
    """
    Make ``result_dir`` an existing, empty directory.

    If ``result_dir`` already exists it is deleted together with everything
    inside it (``rmdir`` is this notebook's alias for ``shutil.rmtree``) and
    then recreated. Missing parent directories are created as well, so the
    call also succeeds when e.g. ``./predictions`` has not been created yet.
    """
    # Local import: the notebook's import cell only brings in the non-recursive mkdir.
    from os import makedirs

    if exists(result_dir):
        rmdir(result_dir)
    makedirs(result_dir)

def mse_metric(estimator, X, y):
    """
    Scoring callable: the mean squared error of ``estimator`` on ``(X, y)``, negated.

    The sign is flipped because GridSearchCV maximizes a metric.
    """
    predictions = estimator.predict(X)
    return -mean_squared_error(y_true=y, y_pred=predictions)

# Run the full evaluation once per loss function. Every run writes into its own result
# directory, which is emptied first. evaluate_on_dataset reads its grid from `params`,
# so the "loss" entry is overwritten before each run.
experiments = [("hinge", "./predictions/mmit.linear.hinge"),
               ("squared_hinge", "./predictions/mmit.squared.hinge")]

for loss, result_dir in experiments:
    params["loss"] = [loss]
    prep_result_dir(result_dir)
    for d in datasets:
        print(d.name)
        evaluate_on_dataset(d, mse_metric, result_dir)
        print()