In [1]:
"""Analysis of model weight distribution."""
# pylint: disable=redefined-outer-name, expression-not-assigned, import-error, not-callable, pointless-statement, no-value-for-parameter, undefined-variable, unused-argument
from __future__ import annotations

from pathlib import Path
from typing import Tuple

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.io as pio

# from IPython.display import display

# Make "notebook" the default plotly renderer for figures displayed in this session.
pio.renderers.default = "notebook"

# from epi_ml.core.model_pytorch import LightningDenseClassifier

In [None]:
%matplotlib inline

In [23]:
def get_parent_directories(directory: str | Path) -> Tuple[str, str]:
    """Give the names of the parent and grandparent of a path.

    The path is resolved first, so relative inputs are interpreted against the
    current working directory.

    Args:
        directory (str | Path): Path whose ancestors are looked up.

    Returns:
        Tuple[str, str]: (name of the parent, name of the parent's parent).
    """
    resolved = Path(directory).resolve()
    parent = resolved.parent
    grandparent = parent.parent
    return (parent.name, grandparent.name)

In [3]:
# Root directory from which every other path in this notebook is derived.
home = Path("/home/local/USHERBROOKE/rabj2301/Projects")

# Input side: base directory and the metadata JSON file.
input_dir = home / "epilap" / "input"
metadata_filename = "hg38_2023_epiatlas_dfreeze_plus_encode_noncore_formatted_JR.json"
metadata_path = input_dir / "metadata" / metadata_filename

# Output side: base directory under which the model directories are looked up.
output = home / "epilap" / "output"
# model_dir = output / "models/harmonized_donor_sex_1l_3000n-10fold_binary_onlyl1-split0_l1_0.1/"

In [11]:
def load_model(model_dir):
    """Restore a classifier from a model directory.

    Thin wrapper that delegates to `LightningDenseClassifier.restore_model`
    and hands back whatever it returns.

    Args:
        model_dir: Directory passed as-is to `restore_model`.

    Returns:
        The restored model object.
    """
    # NOTE(review): the import of LightningDenseClassifier is commented out in the
    # first cell, so this name must already exist in the kernel. Confirm the
    # notebook still works after "Restart Kernel -> Run All".
    return LightningDenseClassifier.restore_model(model_dir)

In [43]:
def plot_weights_dist(
    model: LightningDenseClassifier,
    logdir: Path,
    max_samples: int = 1000,
    seed: int | None = None,
) -> pd.Series:
    """Plot and save the distribution of the first parameter tensor of a model.

    The first tensor yielded by `model.parameters()` (presumably the weight
    matrix of the first layer; verify against the model definition) is
    flattened, and at most `max_samples` of its values are drawn at random
    without replacement. The absolute values of that sample are then summarised
    and plotted. Two files are written into `logdir`:

    - "weights_description.csv": output of `pandas.Series.describe` with the
      5th, 25th, 50th, 75th and 95th percentiles.
    - "weights_dist.png": violin plot (with inner box) whose y axis is limited
      to [0, 0.01].

    Args:
        model (LightningDenseClassifier): The model from which to extract the weights.
        logdir (Path): The directory where the plot and CSV file will be saved.
            Its parent and grandparent names are used in the figure title.
        max_samples (int): Upper bound on the number of weights sampled.
        seed (int | None): Seed for the sampling. When None (default) the
            global NumPy random state is used, so results vary between runs.

    Returns:
        pd.Series: The description of the sampled absolute weights.

    Raises:
        ValueError: If the model exposes no parameters.
    """
    # Only the first parameter tensor is analysed; fail loudly instead of
    # silently returning None (and writing nothing) when there is none.
    first_layer = next(iter(model.parameters()), None)
    if first_layer is None:
        raise ValueError("Model has no parameters; nothing to plot.")

    logdir = Path(logdir)

    # .cpu() keeps this working when the model was restored onto an accelerator;
    # it does not copy when the tensor already lives on the CPU.
    weights = first_layer.detach().cpu().flatten().numpy()

    sample_size = min(max_samples, len(weights))
    if seed is None:
        sampled = np.random.choice(weights, sample_size, replace=False)
    else:
        sampled = np.random.default_rng(seed).choice(
            weights, sample_size, replace=False
        )
    abs_weights = np.absolute(sampled)

    weights_description = pd.Series(abs_weights).describe(
        percentiles=[0.05, 0.25, 0.5, 0.75, 0.95]
    )
    weights_description.to_csv(logdir / "weights_description.csv", sep=",")

    parent_name, grandparent_name = get_parent_directories(logdir)
    fig_title = f"Weights distribution <br> {parent_name} : {grandparent_name}"
    fig = px.violin(
        abs_weights, box=True, points=False, range_y=[0, 0.01], title=fig_title
    )
    fig.write_image(logdir / "weights_dist.png")
    return weights_description

In [44]:
# Produce the weight-distribution files for every "*dropout-*" run that lacks them.
dirname = output / "models" / "harmonized_donor_sex_1l_3000n"
# sorted() makes the processing (and log) order deterministic across runs.
for run_dir in sorted(dirname.glob("*dropout-*")):
    model_dir = run_dir / "split0"
    # The PNG is the last file plot_weights_dist writes, so its presence marks
    # a run as already processed.
    if (model_dir / "weights_dist.png").exists():
        continue
    model = load_model(model_dir)
    plot_weights_dist(model, model_dir)

Reading checkpoint list and taking last line.
Loading model from /home/local/USHERBROOKE/rabj2301/Projects/epilap/output/models/harmonized_donor_sex_1l_3000n/10fold-binary-l1-500_dropout-0.25/split0/EpiLaP/52c1ebf0aac8470c8b161ed46ecf9ff5/checkpoints/last.ckpt
Reading checkpoint list and taking last line.
Loading model from /home/local/USHERBROOKE/rabj2301/Projects/epilap/output/models/harmonized_donor_sex_1l_3000n/10fold-binary-l1-100_dropout-0.10/split0/EpiLaP/2a6d35735c06427599d7bb619b142e87/checkpoints/last.ckpt
Reading checkpoint list and taking last line.
Loading model from /home/local/USHERBROOKE/rabj2301/Projects/epilap/output/models/harmonized_donor_sex_1l_3000n/10fold-binary-l1-500_dropout-0.10/split0/EpiLaP/80a2792dde9f4932b5bedcb7d72dcabd/checkpoints/last.ckpt
Reading checkpoint list and taking last line.
Loading model from /home/local/USHERBROOKE/rabj2301/Projects/epilap/output/models/harmonized_donor_sex_1l_3000n/10fold-binary-l1-50_dropout-0.25/split0/EpiLaP/7d0d12b7a18d