# GP datasets with zero noise_var + NNs (2)

- On these GP datasets, the measures _do not_ increase as the number of training samples grows, but they _do_ increase on MNIST! Why is this?
  - Are the GP datasets too easy?
  - Is it an effect of dimensionality?

In [None]:
import os
import sys
# If we don't need CUDA, do this before importing TF
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
import tensorflow as tf
import numpy as np
import pandas as pd
import tqdm
import tqdm.notebook
import scipy.stats
import matplotlib.pyplot as plt
import seaborn as sns
import IPython
import GPy
sns.set()

# Restrict TensorFlow to a single physical GPU, if any are listed.
# NOTE(review): with CUDA_VISIBLE_DEVICES="-1" set above, the list is
# presumably empty and this block does nothing — confirm.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    # Keep the original preference for the second GPU, but fall back to the
    # only available one instead of raising IndexError on single-GPU hosts.
    tf.config.experimental.set_visible_devices(
        [gpus[1] if len(gpus) > 1 else gpus[0]], 'GPU'
    )

# Work from the experiment's log directory so that the relative file names
# used below (e.g. "measures.feather") resolve inside it.
os.chdir("/nfs/scistore12/chlgrp/vvolhejn/smooth/logs/0306_gp_nn_noiseless/")

In [None]:
%load_ext autoreload
%aimport smooth.datasets
%aimport smooth.model
%aimport smooth.analysis
%aimport smooth.callbacks
%aimport smooth.measures
%aimport smooth.util
%autoreload 1

### Measures of shallow relu neural networks

In [None]:
# Read the neural-network measures from disk and report how many rows exist.
ms_nn = pd.read_feather("measures.feather")
print("Before removal:", len(ms_nn))

# Keep only the rows whose "error" column is null.
# (Optional steps that are currently disabled:
#  smooth.analysis.expand_dataset_columns(ms_nn), and filtering on
#  np.isfinite(ms_nn["path_length_f_test"]).)
ms_nn = ms_nn[ms_nn["error"].isna()]
print("After removal:", len(ms_nn))

# The lengthscale is stored as a coefficient; multiply by the dimension to
# obtain the value the later cells filter on.
ms_nn["lengthscale"] = ms_nn["lengthscale_coef"].rmul(ms_nn["dim"])

ms_nn.head()

In [None]:
def get_unique_datasets(ms):
    """Return the distinct dataset configurations found in ``ms``.

    Selects the dataset-describing columns of ``ms``, removes the
    "dataset." prefix from column names that carry it, drops duplicate rows,
    sorts by all remaining columns (in the order listed below) and returns
    the result with a fresh 0..n-1 index.
    """
    prefix = "dataset."
    source_cols = [
        "dataset.name",
        "dim",
        "seed",
        "lengthscale_coef",
        "samples_train",
        "noise_var",
        "disjoint",
    ]
    # Same columns, with the "dataset." prefix removed where present.
    target_cols = [
        col[len(prefix):] if col.startswith(prefix) else col
        for col in source_cols
    ]

    unique_rows = ms[source_cols].rename(columns=dict(zip(source_cols, target_cols)))
    unique_rows = unique_rows.drop_duplicates()
    unique_rows = unique_rows.sort_values(target_cols)
    return unique_rows.reset_index(drop=True)

In [None]:
# Display `datasets`.
# NOTE(review): `datasets` is only assigned in a later cell of this file, so
# this cell depends on that cell having been executed first.
datasets

In [None]:
# Peek at the first entry of `datasets` (slice of length one).
datasets[:1]

In [None]:
# Spot check: GP measures with a Matern-3/2 kernel for the single dataset at
# position 15, once with lengthscale_coef=1. and once with 0.1, stacked into
# one frame for side-by-side inspection.
pd.concat([
    smooth.analysis.get_gp_measures(
        datasets[15:16],
        from_params=True,
        kernel_f=GPy.kern.Matern32,
        lengthscale_coef=coef,
    )
    for coef in (1., 0.1)
])

In [None]:
# Display `datasets` again.
# NOTE(review): `datasets` is assigned in the following cell; this cell
# relies on out-of-order execution.
datasets

In [None]:
# Distinct dataset configurations that occur in the NN measures.
# `datasets2` temporarily holds the DataFrame form before being narrowed down.
datasets2 = get_unique_datasets(ms_nn)

# All configurations, as a list of {column: value} dicts (one per row).
datasets = list(datasets2.to_dict("index").values())

# Only the configurations with lengthscale_coef == 1.0, in the same format.
datasets2 = list(
    datasets2.loc[datasets2["lengthscale_coef"] == 1.0]
    .to_dict("index")
    .values()
)

# GP measures for every configuration, obtained through compute_or_load_df
# with "measures_gp.feather" as the associated file.
ms_gp = smooth.analysis.compute_or_load_df(
    lambda: smooth.analysis.get_gp_measures(datasets, from_params=True),
    "measures_gp.feather",
    always_compute=False,
)

def f():
    """Compute Matern-3/2 GP measures on ``datasets2`` for several coefficients.

    Runs ``smooth.analysis.get_gp_measures`` once per lengthscale coefficient
    in ``[0.1, 1., 10., 100.]``, records the coefficient in a
    "lengthscale_coef" column of each result, and returns all results
    concatenated with a fresh index.
    """
    def measures_for(coef):
        # One batch of measures, tagged with the coefficient that produced it.
        batch = smooth.analysis.get_gp_measures(
            datasets2,
            from_params=True,
            kernel_f=GPy.kern.Matern32,
            lengthscale_coef=coef,
        )
        batch["lengthscale_coef"] = coef
        return batch

    batches = [measures_for(coef) for coef in [0.1, 1., 10., 100.]]
    return pd.concat(batches).reset_index(drop=True)

# Matern-3/2 GP measures for each lengthscale coefficient, produced by f().
# NOTE(review): always_compute=True here, whereas ms_gp passes False —
# presumably this forces recomputation even when the feather file already
# exists; confirm in smooth.analysis.compute_or_load_df.
ms_gp_m32 = smooth.analysis.compute_or_load_df(
    f,
    "measures_gp_matern32.feather",
    always_compute=True,
)

# ms_gp_m52 = smooth.analysis.compute_or_load_df(
#     lambda: smooth.analysis.get_gp_measures(datasets, from_params=True, kernel_f=GPy.kern.Matern52),
#     "measures_gp_matern52.feather",
#     always_compute=False,
# )

In [None]:
# NOTE(review): presumably drops/reports the columns of ms_nn that hold a
# single constant value, keeping none of them explicitly (to_keep=[]);
# whether ms_nn is modified in place is not visible here — confirm in
# smooth.analysis.remove_constant_columns.
smooth.analysis.remove_constant_columns(ms_nn, verbose=True, to_keep=[])

In [None]:
# Measures to visualise, one figure each.
measure_cols = [
    "loss_train", "loss_test",
    "path_length_f_train", "path_length_f_test",
    "path_length_d_train", "path_length_d_test",
    "weights_rms",
]

# The plotted slice does not depend on the measure, so select it once:
# networks with 64 hidden units trained on datasets with lengthscale_coef 0.3.
# (Alternative slice used during exploration: ms_nn["init_scale"] == 10.)
ms1 = ms_nn[
    (ms_nn["hidden_size"] == 64)
    & (ms_nn["lengthscale_coef"] == 0.3)
]

for measure in measure_cols:
    # Markdown heading above each figure.
    IPython.display.display(IPython.display.Markdown(f"### {measure}"))

    # One panel per input dimension, one line per init_scale.
    grid = sns.relplot(
        data=ms1,
        x="samples_train",
        y=measure,
        hue="init_scale",
        col="dim",
        col_wrap=3,
        kind="line",
        palette=smooth.analysis.make_palette(ms1["init_scale"].unique()),
    )

    # Log-log axes for every measure; the scales are set on the first panel.
    ax = grid.axes[0]
    ax.set_xscale("log")
    ax.set_yscale("log")
    plt.show()

In [None]:
def add_normalized_cols(ms):
    """Add bound-normalised path-length columns to ``ms`` in place.

    Creates "plc" (test split) and "plct" (train split), each being the
    "path_length_f_<split>" column divided by its
    "path_length_f_<split>_bound" counterpart. Returns the same ``ms`` object.
    """
    for short_name, split in (("plc", "test"), ("plct", "train")):
        measured = "path_length_f_{}".format(split)
        ms[short_name] = ms[measured] / ms[measured + "_bound"]
    return ms

#add_normalized_cols(ms_gp)

# Variant of ms_gp in which the "*_bound" path-length columns take the place
# of the measured path lengths, so the bounds can be plotted with the same
# code as the measures. The measured path lengths and the losses are dropped.
ms_bound = ms_gp.drop(
    columns=["path_length_f_test", "path_length_f_train", "loss_train", "loss_test"]
).rename(
    columns={
        "path_length_f_{}_bound".format(split): "path_length_f_{}".format(split)
        for split in ("test", "train")
    }
)

In [None]:
# Show the first row of ms_nn to inspect the available columns and values.
ms_nn.iloc[0]

In [None]:
def plot_compare(groups, filter_f=None):
    """Plot several sets of measures against the number of training samples.

    Args:
        groups: iterable of ``(group_name, members)`` pairs, where ``members``
            is an iterable of ``(source_name, DataFrame)`` pairs. Each frame
            is copied, tagged with "source" and "group" columns, and all
            copies are concatenated.
        filter_f: optional callable applied to the concatenated DataFrame
            before plotting. ``None`` (the default) keeps every row.

    One figure is shown per measure, with one panel per "dim" value, line
    colour given by "source" and line style by "group".
    """
    # The parameter used to be declared as `filter_f: None`, which is an
    # annotation rather than a default and made the argument mandatory even
    # though the fallback below already handles None.
    filter_f = filter_f or (lambda df: df)

    tagged = []
    for group_name, group in groups:
        for name, ms_cur in group:
            # Copy so the caller's frames are not modified by the tagging.
            ms_cur = ms_cur.copy()
            ms_cur.loc[:, "source"] = name
            ms_cur.loc[:, "group"] = group_name
            tagged.append(ms_cur)

    ms_all = pd.concat(tagged, sort=False)
    ms_all = filter_f(ms_all)

    # Measures whose y axis is drawn on a log scale (currently all of them).
    log_y_measures = [
        "loss_train", "loss_test",
        "path_length_f_train", "path_length_f_test",
    ]

    for measure in ["loss_train", "loss_test", "path_length_f_test", "path_length_f_train"]:
        grid = sns.relplot(
            data=ms_all,
            x="samples_train",
            y=measure,
            hue="source",
            style="group",
            col="dim",
            col_wrap=2,
            kind="line",
        )
        # Axis scales are set on the first panel.
        ax = grid.axes[0]
        ax.set_xscale("log")
        if measure in log_y_measures:
            ax.set_yscale("log")
        plt.show()

# Series to compare: the NN runs (64 hidden units) for the smallest
# init_scale only, with the learning rate tied to it as 0.003 / init_scale,
# followed by the GP measures and the GP-derived bounds.
nn_group = []
for init in sorted(ms_nn["init_scale"].unique())[:1]:
    lr = (0.003 / init).round(5)
    nn_group.append((
        "nn, lr={:.1e}, is={:.1e}".format(lr, init),
        ms_nn.loc[
            (ms_nn["hidden_size"] == 64)
            & (ms_nn["init_scale"] == init)
            & (ms_nn["learning_rate"] == lr)
        ],
    ))

nn_group.extend([("gp", ms_gp), ("bound", ms_bound)])

def filter_f(ms):
    """Keep rows where "lengthscale" equals "dim" and "disjoint" equals 1.

    Further restrictions (a specific dim, a specific seed, dim <= 512) were
    tried during exploration and are currently not applied.
    """
    lengthscale_is_dim = ms["lengthscale"] == ms["dim"]
    is_disjoint = ms["disjoint"] == 1
    return ms.loc[lengthscale_is_dim & is_disjoint]

# Emit a heading for the smallest dimension and draw the comparison plots.
# Only the heading depends on `dim`; the figures are produced by plot_compare
# from nn_group after filter_f has been applied.
for dim in sorted(ms_nn["dim"].unique())[:1]:
    IPython.display.display(IPython.display.Markdown("### dim = {}".format(dim)))
    plot_compare([("nn", nn_group)], filter_f)

In [None]:
def plot_compare(groups, filter_f=None):
    """Plot several sets of measures against the number of training samples.

    Args:
        groups: iterable of ``(group_name, members)`` pairs, where ``members``
            is an iterable of ``(source_name, DataFrame)`` pairs. Each frame
            is copied, tagged with "source" and "group" columns, and all
            copies are concatenated.
        filter_f: optional callable applied to the concatenated DataFrame
            before plotting. ``None`` (the default) keeps every row.

    One figure is shown per measure, with one panel per "dim" value, line
    colour given by "source" and line style by "group".
    """
    # The parameter used to be declared as `filter_f: None`, which is an
    # annotation rather than a default and made the argument mandatory even
    # though the fallback below already handles None.
    filter_f = filter_f or (lambda df: df)

    tagged = []
    for group_name, group in groups:
        for name, ms_cur in group:
            # Copy so the caller's frames are not modified by the tagging.
            ms_cur = ms_cur.copy()
            ms_cur.loc[:, "source"] = name
            ms_cur.loc[:, "group"] = group_name
            tagged.append(ms_cur)

    ms_all = pd.concat(tagged, sort=False)
    ms_all = filter_f(ms_all)

    # Measures whose y axis is drawn on a log scale (currently all of them).
    log_y_measures = [
        "loss_train", "loss_test",
        "path_length_f_train", "path_length_f_test",
    ]

    for measure in ["loss_train", "loss_test", "path_length_f_test", "path_length_f_train"]:
        grid = sns.relplot(
            data=ms_all,
            x="samples_train",
            y=measure,
            hue="source",
            style="group",
            col="dim",
            col_wrap=2,
            kind="line",
        )
        # Axis scales are set on the first panel.
        ax = grid.axes[0]
        ax.set_xscale("log")
        if measure in log_y_measures:
            ax.set_yscale("log")
        plt.show()

# Series to compare: for every init_scale, the NN runs with 256 hidden units
# whose learning rate equals 0.003 / init_scale (rounded to 5 decimals),
# followed by the GP-derived bounds and the GP measures.
nn_group = []
for init in sorted(ms_nn["init_scale"].unique()):
    lr = (0.003 / init).round(5)
    nn_group.append((
        "nn, lr={:.3f}, is={:.3f}".format(lr, init),
        ms_nn.loc[
            (ms_nn["hidden_size"] == 256)
            & (ms_nn["init_scale"] == init)
            & (ms_nn["learning_rate"] == lr)
        ],
    ))

nn_group.extend([("bound", ms_bound), ("gp", ms_gp)])

# One series per distinct lengthscale coefficient of the Matern-3/2 GP runs.
# Iterate over the unique coefficients: looping over the raw column would
# append the same (label, rows) pair once per row of ms_gp_m32, so every
# series would be concatenated many times over in plot_compare.
gp_group = []
for lsc in sorted(ms_gp_m32["lengthscale_coef"].unique()):
    gp_group.append(("gp m32, lsc={}".format(lsc), ms_gp_m32.loc[ms_gp_m32["lengthscale_coef"] == lsc]))

def filter_f(ms):
    """Keep rows where "lengthscale" equals "dim" and "disjoint" equals 1.

    Further restrictions (a specific dim, a specific seed, dim <= 512) were
    tried during exploration and are currently not applied.
    """
    lengthscale_is_dim = ms["lengthscale"] == ms["dim"]
    is_disjoint = ms["disjoint"] == 1
    return ms.loc[lengthscale_is_dim & is_disjoint]

# Emit a heading for the smallest dimension and draw the comparison plots
# for both the NN series and the Matern-3/2 GP series.
# Only the heading depends on `dim`; the figures are produced by plot_compare
# from the two groups after filter_f has been applied.
for dim in sorted(ms_nn["dim"].unique())[:1]:
    IPython.display.display(IPython.display.Markdown("### dim = {}".format(dim)))
    plot_compare([("nn", nn_group), ("gp", gp_group)], filter_f)

In [None]:
# Load MNIST through the project helper.
mnist = smooth.datasets.get_mnist()
# Parity as a +/-1 target: label % 2 is 1 for odd and 0 for even labels, so
# the expression yields +1 for odd and -1 for even.
# assumes y_train holds integer digit labels — TODO confirm
mnist.y_train % 2 * 2 - 1

In [None]:
def f(a):
    """Return a new class ``C`` whose instances store ``a + b`` in ``val``.

    ``a`` is captured by closure; ``b`` is the constructor argument of the
    returned class. Every call to ``f`` creates a distinct class object.
    """

    class C:
        """Holds the sum of the enclosing ``a`` and the constructor's ``b``."""

        def __init__(self, b):
            total = a + b
            self.val = total

    return C

In [None]:
# Check the closure-based class factory: with a=3 and b=5 this shows 8.
f(3)(5).val

In [None]:
# Instantiate the MNIST "lightness" dataset from smooth.datasets.
# NOTE(review): the meaning of the argument 10 is defined by
# MnistLightnessDataset and is not visible here — confirm.
dataset = smooth.datasets.MnistLightnessDataset(10)

In [None]:
# Instantiate the MNIST "parity" dataset from smooth.datasets; this rebinds
# `dataset`.
# NOTE(review): the meaning of the argument 10 is defined by
# MnistParityDataset and is not visible here — confirm.
dataset = smooth.datasets.MnistParityDataset(10)

In [None]:
# Inspect the test targets of whichever dataset was created last.
dataset.y_test

In [None]:
# List the distinct hidden-layer sizes present in the NN measures.
sorted(ms_nn["hidden_size"].unique())