# Checking init scale

Our initialization method in the previous experiments was somewhat non-standard (due to my mistake):

```python
model.add(
        layers.Dense(
            hidden_size,
            activation=activation,
            kernel_initializer=VarianceScaling(scale=init_scale, mode="fan_out"), # this
            bias_initializer=VarianceScaling(scale=init_scale, mode="fan_out"), # this
        )
    )
    model.add(
        layers.Dense(
            dataset.n_classes if classification else 1,
            kernel_initializer=VarianceScaling(scale=init_scale, mode="fan_in"), # this
            bias_initializer=VarianceScaling(scale=init_scale, mode="fan_in"), # this
            activation=None,
        )
    )
```

This notebook checks whether the non-standard initialization had a significant effect on performance, i.e. what happens if we comment out the lines annotated with `# this`. The results of training with those lines commented out are in `0402_check_init_scale`.

(Indeed, I later switched to a more standard initialization method.)

In [None]:
import os
import sys
# If we don't need CUDA, do this before importing TF
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
import tensorflow as tf
import numpy as np
import pandas as pd
import tqdm
import tqdm.notebook
import matplotlib.pyplot as plt
import seaborn as sns
import IPython
sns.set()

In [None]:
%load_ext autoreload
%aimport smooth.config
%aimport smooth.datasets
%aimport smooth.model
%aimport smooth.analysis
%aimport smooth.callbacks
%aimport smooth.measures
%aimport smooth.util
%autoreload 1

In [None]:
# All measure files below are given relative to this log root.
os.chdir("/nfs/scistore12/chlgrp/vvolhejn/smooth/logs/")


def _load_measures(path, init_type):
    """Read one measures file, label it, and drop rows lacking a test gradient norm.

    Args:
        path: Feather file to read (relative to the current working directory).
        init_type: Label stored in the new "init_type" column.

    Returns:
        The loaded frame restricted to rows where "gradient_norm_test" is not NaN.
    """
    frame = pd.read_feather(path)
    frame["init_type"] = init_type
    return frame[frame["gradient_norm_test"].notna()]


ms1 = _load_measures("0326_mnist_binary/measures.feather", "ours")
ms2 = _load_measures("0402_check_init_scale/measures.feather", "standard")
ms3 = _load_measures("0407_check_init_scale/measures.feather", "standard-bias")

# Only the first run is narrowed further: hidden size 256 with both
# regularization coefficients at zero.
# NOTE(review): presumably this matches the configuration used in the other
# two runs -- confirm against their configs.
is_baseline = (
    ms1["model.hidden_size"].eq(256)
    & ms1["model.gradient_norm_reg_coef"].eq(0)
    & ms1["model.weights_product_reg_coef"].eq(0)
)
ms1 = ms1[is_baseline]

# Index every frame by dataset name so the frames can be joined on it later.
ms1 = ms1.set_index("dataset.name")
ms2 = ms2.set_index("dataset.name")
ms3 = ms3.set_index("dataset.name")

In [None]:
# Preview the two runs joined on their shared "dataset.name" index; columns
# present in both frames get the "_l" (ms1) / "_r" (ms2) suffixes.
ms1.join(ms2, lsuffix="_l", rsuffix="_r")

In [None]:
def compare(ms1, ms2, name1=None, name2=None):
    """Plot each measure of one run against the same measure of another run.

    The two frames are joined on their index, with overlapping columns
    suffixed by the run names. For every measure a relational plot is shown
    with ``ms1``'s values on the x-axis and ``ms2``'s values on the y-axis,
    both axes sharing the same limits. Loss measures use log-log axes.

    Args:
        ms1: Measures of the first run (x-axis).
        ms2: Measures of the second run (y-axis).
        name1: Suffix/label for ``ms1``; defaults to the "init_type" value of
            its first row.
        name2: Suffix/label for ``ms2``; defaults to the "init_type" value of
            its first row.
    """
    name1 = ms1.iloc[0]["init_type"] if name1 is None else name1
    name2 = ms2.iloc[0]["init_type"] if name2 is None else name2

    joined = ms1.join(
        ms2,
        lsuffix="_{}".format(name1),
        rsuffix="_{}".format(name2),
    )

    measures = ("loss_train", "loss_test", "gradient_norm_test", "weights_product")
    for measure in measures:
        x_col = "{}_{}".format(measure, name1)
        y_col = "{}_{}".format(measure, name2)

        grid = sns.relplot(data=joined, x=x_col, y=y_col)
        ax = grid.axes[0][0]

        # Losses are easier to read on logarithmic axes.
        if "loss" in measure:
            ax.set_xscale("log")
            ax.set_yscale("log")

        # Give both axes the same range, spanning the values of both runs.
        x_values = joined[x_col]
        y_values = joined[y_col]
        lower = min(x_values.min(), y_values.min())
        upper = max(x_values.max(), y_values.max())
        ax.set_xlim((lower, upper))
        ax.set_ylim((lower, upper))

        plt.show()

In [None]:
# Runs labelled "ours" (x-axis) against runs labelled "standard" (y-axis).
compare(ms1, ms2)

In [None]:
# Runs labelled "standard-bias" (x-axis) against runs labelled "standard" (y-axis).
compare(ms3, ms2)

In [None]:
# Inline comparison of ms1 (suffix "_l", x-axis) and ms2 (suffix "_r", y-axis).
ms = ms1.join(ms2, lsuffix="_l", rsuffix="_r")

for measure in ("loss_train", "loss_test", "gradient_norm_test", "weights_product"):
    x_col = "{}_l".format(measure)
    y_col = "{}_r".format(measure)

    grid = sns.relplot(data=ms, x=x_col, y=y_col)
    ax = grid.axes[0][0]

    # Only loss measures get log-log axes and shared axis limits here; the
    # other measures keep the limits chosen by the plotting library.
    if "loss" in measure:
        ax.set_xscale("log")
        ax.set_yscale("log")

        lowest = min(ms[x_col].min(), ms[y_col].min())
        highest = max(ms[x_col].max(), ms[y_col].max())
        lim = (lowest, highest)
        ax.set_xlim(lim)
        ax.set_ylim(lim)

    plt.show()