In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import statsmodels.formula.api as smf
from matplotlib import pyplot as plt

from alexlib.df import filter_df, get_distinct_col_vals
from alexlib.iters import idx_list

from etl.db_helpers import DbHelper


In [None]:
# Create the project database helper; later cells use it for table reads and
# for its `engine` when writing results back.
dbh = DbHelper()
# Last expression of the cell, so whatever it returns is shown as the cell
# output (presumably the generated SELECT text for eval.reg_lines — TODO confirm).
dbh.generate_select_query("eval", "reg_lines")

In [None]:
# NOTE(review): abroca_file is not referenced by any cell visible here —
# confirm whether it is still needed.
abroca_file = "abroca_analysis.sql"
# Fetch the ABROCA results via the DB helper (arguments assumed to be
# schema, table — TODO confirm against DbHelper.get_table).
abroca_df = dbh.get_table("analysis", "abroca")
# Distinct values of the model_type column. The regression cell evaluates
# `model_types + ["all"]`, so this is expected to be a list.
model_types = get_distinct_col_vals(abroca_df, "model_type")
# Preview the first rows as the cell output.
abroca_df.head()

In [None]:
# Columns kept for the regression analysis.
# Order matters: the regression cell derives its "abroca" and "ratio" sub-lists
# from this list by substring match and pairs them by position
# (is_female_abroca <-> female_ratio, has_disability_abroca <-> disabled_ratio).
cols = [
    "is_female_abroca",
    "has_disability_abroca",
    "mean_test_roc_auc",
    "female_ratio",
    "disabled_ratio",
    "model_type",
]
# Restrict the ABROCA frame to the analysis columns (all rows kept).
df = abroca_df.loc[:, cols]

In [None]:
def get_attr_dict(obj: object):
    """Snapshot the public attributes of ``obj`` as a ``{name: value}`` dict.

    Every name reported by ``dir(obj)`` is considered, except names that start
    with an underscore, ``get_`` or ``set_``. The remaining names are resolved
    with ``getattr``, so properties are evaluated and methods come back as
    bound callables.
    """
    hidden_prefixes = ("_", "get_", "set_")
    public_attrs = {}
    for name in dir(obj):
        if name.startswith(hidden_prefixes):
            continue
        public_attrs[name] = getattr(obj, name)
    return public_attrs


In [None]:
def polystr(var, n):
    """Return the formula term ``I(<var> ** <n>.0)`` for a power of ``var``.

    ``n`` is rendered as given and followed by a literal ``.0``, so an integer
    power such as ``2`` appears as ``2.0`` in the resulting term.
    """
    power_expr = "{} ** {}.0".format(var, n)
    return "I(" + power_expr + ")"


def mk_ols_str(
    xcol: str,
    ycol: str,
    n: int,
):
    """Build an OLS formula string regressing ``ycol`` on a polynomial in ``xcol``.

    The formula always starts with ``"<ycol> ~ 1 + <xcol>"``. When ``n`` is
    greater than 1, one ``polystr`` term is appended for every power from 2 up
    to and including ``n``, all joined with ``" + "``.
    """
    terms = [f"{ycol} ~ 1 + {xcol}"]
    if n > 1:
        # Powers 0 and 1 are already covered by the intercept and linear term.
        terms.extend(polystr(xcol, power) for power in range(2, n + 1))
    return " + ".join(terms)

In [None]:
# Fit one OLS model per (polynomial order, ABROCA column, model type), draw the
# data and the fitted curve on a grid of axes, and collect the fit statistics:
#   * `records`      - one dict of scalar statistics per fit
#   * `concat_lists` - per-statistic lists of DataFrames (confidence intervals
#                      and every Series-valued result attribute)
# Both containers are consumed by the export cell that follows.

# Result attributes that are never exported.
SKIPPED_RESULT_ATTRS = frozenset(
    {
        "HC0_se",
        "HC1_se",
        "HC2_se",
        "HC3_se",
        "resid",
        "wresid",
        "fittedvalues",
        "conf_int",
        "conf_int_el",
        "eigenvals",
        "el_test",
        "f_test",
        "info_criteria",
        "load",
        "save",
        "model",
    }
)
# Any remaining attribute whose name contains one of these markers is ignored
# instead of being reported as an unhandled type.
SKIPPED_NAME_MARKERS = (
    "test",
    "use_",
    "cov_",
    "compare",
    "summary",
    "resid",
    "remove",
    "predict",
    "initial",
)

concat_lists = {}
records = []
polys = [2, 1]
abrocas = [x for x in cols if "abroca" in x]
ratios = [x for x in cols if "ratio" in x]
lp, la = len(polys), len(abrocas)
index = idx_list(lp, la)
wh = 8
# squeeze=False keeps `axs` two-dimensional even when a grid dimension is 1,
# so axs[i][j] below is always valid.
fig, axs = plt.subplots(nrows=lp, ncols=la, figsize=(wh, wh), squeeze=False)

for i, j in index:
    n = polys[i]
    y = abrocas[j]
    ratio = ratios[j]
    # First-order fits regress on the mean test ROC AUC; higher orders regress
    # on the ratio column paired (by position) with the ABROCA column.
    x = "mean_test_roc_auc" if n == 1 else ratio
    formula = mk_ols_str(x, y, n)
    # All model types of one (i, j) combination share a single axes object.
    ax = axs[i][j]
    for mtype in model_types + ["all"]:
        # "all" is the pseudo model type that fits on the unfiltered frame.
        if mtype in model_types:
            fitdf = filter_df(df, "model_type", mtype)
        else:
            fitdf = df
        polyfit = smf.ols(formula=formula, data=fitdf).fit()

        # Evaluate the fitted curve on 100 evenly spaced points across the
        # observed range of x and overlay it on the scatter of the data.
        xlin = np.linspace(fitdf.loc[:, x].min(), fitdf.loc[:, x].max(), 100)
        xdf = pd.DataFrame.from_dict({x: xlin})
        ylin = polyfit.predict(xdf)
        sns.scatterplot(fitdf, x=x, y=y, ax=ax)
        ax.plot(xlin, ylin)

        record = {
            "poly": n,
            "x": x,
            "y": y,
            "model_type": mtype,
        }
        attr_dict = get_attr_dict(polyfit)

        # NOTE(review): the frames collected in `concat_lists` are tagged with
        # i and j only, so rows coming from different model types cannot be
        # told apart after concatenation — consider adding a model_type column.
        conf_int = polyfit.conf_int().reset_index()
        conf_int["i"] = i
        conf_int["j"] = j
        concat_lists.setdefault("conf_int", []).append(conf_int)
        print(conf_int)

        for key, val in attr_dict.items():
            if key in SKIPPED_RESULT_ATTRS:
                continue
            _type = type(val)
            if _type in [str, float, int]:
                record[key] = val
            elif _type == np.float64:
                # Store plain Python floats so the records stay DB-friendly.
                record[key] = float(val)
            elif _type == pd.Series:
                _df = pd.DataFrame(val)
                _df["i"] = i
                _df["j"] = j
                concat_lists.setdefault(key, []).append(_df)
            elif callable(val):
                # Bound methods carry no data. The previous check compared the
                # type (and its string form) with an empty string, which can
                # never match, so callables are now skipped explicitly.
                continue
            elif any(marker in key for marker in SKIPPED_NAME_MARKERS):
                continue
            else:
                # Surface attribute types that no branch above handles.
                raise ValueError(key, str(_type))
        records.append(record)


In [None]:
# Write the collected regression outputs to the "eval" schema, replacing any
# existing tables: scalar statistics go to reg_records, and each entry of
# concat_lists goes to its own reg_<key> table.
sql_opts = {"schema": "eval", "if_exists": "replace"}

rec_df = pd.DataFrame.from_records(records)
rec_df.to_sql("reg_records", dbh.engine, **sql_opts)

for listkey, _list in concat_lists.items():
    tab = f"reg_{listkey}"
    # Stack the per-fit frames collected for this statistic into one table.
    new_df = pd.concat(_list)
    print(new_df)
    new_df.to_sql(tab, dbh.engine, **sql_opts)