In [None]:
import pandas as pd
from utils.error_metrics import MulticlassErrorMetrics, DatasetCategory

from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline

In [None]:
# Baseline evaluation on the SAIS splits.
#
# The cell scores three predictors against the "mapped_hazard_forecast" label
# on the train / dev / test splits and records each under its own name in
# `errs_summary`:
#   * "constant"           - always the most frequent training label
#   * "observed"           - the "mapped_hazard_observed" column used as the prediction
#   * "softmax regression" - StandardScaler + LogisticRegression fit on the training split

# Column roles shared by every split.
TARGET_COLUMN = "mapped_hazard_forecast"
OBSERVED_COLUMN = "mapped_hazard_observed"
# Columns removed before fitting / predicting: the label, the observed hazard
# and "Area".
NON_FEATURE_COLUMNS = [TARGET_COLUMN, OBSERVED_COLUMN, "Area"]

# Load the three pre-split CSVs, discarding the "Date" column from each.
df_train, df_dev, df_test = (
    pd.read_csv(csv_path).drop(columns="Date")
    for csv_path in (
        "../data/processed/SAIS_train.csv",
        "../data/processed/SAIS_dev.csv",
        "../data/processed/SAIS_test.csv",
    )
)
sets = {
    DatasetCategory.TRAINING: df_train,
    DatasetCategory.DEVELOPMENT: df_dev,
    DatasetCategory.TEST: df_test,
}

y_train = df_train[TARGET_COLUMN]
# value_counts() sorts by descending frequency, so the first index entry is
# the most common training label.
most_frequent_value = y_train.value_counts().index[0]

# Project-local metrics collector; the class list comes from the training
# labels only.
errs_summary = MulticlassErrorMetrics(
    dataset_name="sais",
    classes=sorted(y_train.unique()),
    y_true_train=y_train,
    y_true_dev=df_dev[TARGET_COLUMN],
    y_true_test=df_test[TARGET_COLUMN],
)

# Model: standardise the features, then logistic regression with a fixed seed.
X_train = df_train.drop(columns=NON_FEATURE_COLUMNS)
pipeline = Pipeline(
    steps=[
        ("scaler", StandardScaler()),
        ("model", LogisticRegression(random_state=1)),
    ]
)
pipeline.fit(X_train, y_train)

# Baseline 1: the same constant label passed for train, dev and test.
errs_summary.compute_errors_all_sets(
    "constant",
    most_frequent_value,
    most_frequent_value,
    most_frequent_value,
)
# Baseline 2: the observed-hazard column of each split, in train / dev / test order.
errs_summary.compute_errors_all_sets(
    "observed",
    *(split[OBSERVED_COLUMN] for split in (df_train, df_dev, df_test)),
)

# Fitted pipeline: predict on each split after removing the same non-feature
# columns that were removed for training.
for ds_type, ds in sets.items():
    y_pred = pipeline.predict(ds.drop(columns=NON_FEATURE_COLUMNS))
    errs_summary.compute_errors("softmax regression", ds_type, y_pred)

# NOTE(review): presumably persists the collected metrics/figures - confirm
# against utils.error_metrics.
errs_summary.save_assets()

In [3]:
# TODO:
#  - rerun previous models

# Reload the train and dev splits, this time using the first CSV column as the
# row index, and stack them into a single frame.
# NOTE(review): this rebinds df_train / df_dev with a different layout from the
# earlier load (no explicit "Date" drop, index taken from column 0); objects
# built from the earlier frames, such as `sets`, still reference the old ones -
# confirm this is intended.
df_train, df_dev = (
    pd.read_csv(csv_path, index_col=0)
    for csv_path in (
        "../data/processed/SAIS_train.csv",
        "../data/processed/SAIS_dev.csv",
    )
)
# ignore_index=False keeps each split's original index labels in the result.
df_combined = pd.concat([df_train, df_dev], ignore_index=False)