# Scaling

## MTZ IO

``ess.nmx`` provides ``MTZ`` IO helper functions.
They can be used as providers in a scaling workflow.

They wrap the ``MTZ`` IO functions of ``gemmi``.

In [None]:
from ess.nmx.mtz_io import read_mtz_file, mtz_to_pandas, MTZFilePath
from ess.nmx.data import get_small_random_mtz_samples


# Pick the first of the sample files returned by ``get_small_random_mtz_samples``.
small_mtz_sample = get_small_random_mtz_samples()[0]
# Read the MTZ file, then convert it with ``mtz_to_pandas``.
mtz = read_mtz_file(MTZFilePath(small_mtz_sample))
df = mtz_to_pandas(mtz)
# Last expression of the cell: shows the first rows as the cell output.
df.head()

## Build Pipeline

The scaling routine includes:
- Reducing individual MTZ datasets
- Merging MTZ datasets
- Reducing the merged MTZ dataset

These operations are done on pandas dataframes, as recommended by ``gemmi``.
Multiple MTZ files are expected, so we need to use ``sciline.ParamTable``.
<!--TODO: Update it to use cyclebane instead of ParamTable if needed.-->

In [None]:
import sciline as sl
import scipp as sc

from ess.nmx.mtz_io import mtz_io_providers, mtz_io_params
from ess.nmx.mtz_io import MTZFileIndex, SpaceGroupDesc
from ess.nmx.scaling import scaling_providers, scaling_params
from ess.nmx.scaling import (
    WavelengthBinSize,
    FilteredEstimatedScaledIntensities,
    ReferenceWavelength,
    ScaledIntensityLeftTailThreshold,
    ScaledIntensityRightTailThreshold,
)

# Build the pipeline from the MTZ IO providers and the scaling providers.
#
# The parameter dictionaries shipped with the package (``mtz_io_params`` and
# ``scaling_params``) are unpacked FIRST: in a dict display the last
# occurrence of a key wins, so the explicit values below take precedence
# if a key also appears in those dictionaries.
pl = sl.Pipeline(
    providers=mtz_io_providers + scaling_providers,
    params={
        **mtz_io_params,
        **scaling_params,
        SpaceGroupDesc: "C 1 2 1",
        WavelengthBinSize: 250,
        ReferenceWavelength: sc.scalar(
            3, unit=sc.units.angstrom
        ),  # Remove it if you want to use the middle of the bin
        ScaledIntensityLeftTailThreshold: sc.scalar(
            0.1,  # Increase it to remove more outliers
        ),
        ScaledIntensityRightTailThreshold: sc.scalar(
            4.0,  # Decrease it to remove more outliers
        ),
    },
)


# One table row per MTZ file, indexed by ``MTZFileIndex``.
file_path_table = sl.ParamTable(
    # To run on your own files instead of the samples, use something like
    # the following (it needs ``import pathlib``):
    # row_dim=MTZFileIndex,
    # columns={
    #     MTZFilePath: [pathlib.Path(f"../developer/sample_{i}.mtz") for i in range(1, 6)]
    # },
    row_dim=MTZFileIndex, columns={MTZFilePath: get_small_random_mtz_samples()}
)

# Register the table on the pipeline.
pl.set_param_table(file_path_table)
# Last expression of the cell: shows the pipeline as the cell output.
pl

## Build Workflow

In [None]:
# Ask the pipeline for the workflow that ends in ``FilteredEstimatedScaledIntensities``.
scaling_nmx_workflow = pl.get(FilteredEstimatedScaledIntensities)
# Draw the workflow graph; ``"rankdir": "LR"`` is the Graphviz left-to-right layout.
scaling_nmx_workflow.visualize(graph_attr={"rankdir": "LR"})

## Compute Desired Type

In [None]:
# Run the workflow; the result is displayed as the cell output.
scaling_nmx_workflow.compute(FilteredEstimatedScaledIntensities)

In [None]:
# Compute again, this time keeping the result so that it can be plotted.
scaled = scaling_nmx_workflow.compute(FilteredEstimatedScaledIntensities)

# Histogram the values only (``sc.values`` leaves the variances out) and plot.
# NOTE(review): ``intensities=30`` presumably requests 30 bins along the
# ``intensities`` dimension - confirm against the data layout.
sc.values(scaled.data).hist(intensities=30).plot(
    grid=True, linewidth=3, title="Density Plot of Estimated Scaled Intensities"
)

## Change Provider
Here is an example of how to insert a different filter function.

In this example, we will swap the provider that filters ``EstimatedScaledIntensities`` and provides ``FilteredEstimatedScaledIntensities``.

In [None]:
from typing import NewType
import scipp as sc
from ess.nmx.scaling import (
    EstimatedScaledIntensities,
    FilteredEstimatedScaledIntensities,
)

# Define the new types for the filtering function.
# They annotate the filter's arguments and are the keys under which the
# corresponding values are set on the pipeline (``pl[NRoot] = ...``).
NRoot = NewType("NRoot", int)
"""The n-th root to be taken for the standard deviation."""
NRootStdDevCut = NewType("NRootStdDevCut", float)
"""The number of standard deviations to be cut from the n-th root data."""


def _calculate_sample_standard_deviation(var: sc.Variable) -> sc.Variable:
    """Calculate the standard deviation of the data, ignoring NaN values.

    This helper function is a temporary solution before
    we release new scipp version with the statistics helper.

    Parameters
    ----------
    var:
        The variable whose values are reduced over all elements.

    Returns
    -------
    :
        A scalar variable holding the standard deviation,
        with the same unit as ``var``.

    Notes
    -----
    ``numpy.nanstd`` is called with its default ``ddof=0``, i.e. the
    population form without Bessel's correction.
    NOTE(review): pass ``ddof=1`` if the corrected sample estimate is intended.

    """
    import numpy as np

    # A standard deviation has the unit of the data it is computed from.
    # Carrying the unit over keeps the result compatible with the mean
    # of the same variable in later arithmetic.
    return sc.scalar(np.nanstd(var.values), unit=var.unit)


# Define the filtering function with right argument types and return type
def cut_estimated_scaled_intensities_by_n_root_std_dev(
    scaled_intensities: EstimatedScaledIntensities,
    n_root: NRoot,
    n_root_std_dev_cut: NRootStdDevCut,
) -> FilteredEstimatedScaledIntensities:
    """Keep only the intensities whose n-th root lies close to the mean.

    The n-th root of the data is taken and every entry whose root lies
    strictly inside ``mean +/- n_root_std_dev_cut * std_dev`` of the rooted
    data is kept. NaN values are ignored when computing the mean and the
    standard deviation.

    Parameters
    ----------
    scaled_intensities:
        The scaled intensities to be filtered.

    n_root:
        The n-th root to be taken before the statistics are computed.
        Higher n-th root means cutting is more effective on the right tail.

    n_root_std_dev_cut:
        Half width of the accepted window, in units of the standard
        deviation of the n-th root data.

    Returns
    -------
    :
        The filtered scaled intensities.

    Raises
    ------
    ValueError
        If ``n_root`` is smaller than 1.

    """
    # Guard clause: roots below 1 are rejected.
    if n_root < 1:
        raise ValueError("The n-th root should be equal to or greater than 1.")

    # Shallow copy, so the selection below is taken from the copy.
    shallow = scaled_intensities.copy(deep=False)
    rooted = shallow.data ** (1 / n_root)

    # Centre and spread of the rooted data.
    centre = rooted.nanmean()
    spread = _calculate_sample_standard_deviation(rooted)

    # Open interval around the centre; both bounds are exclusive.
    margin = n_root_std_dev_cut * spread
    lower_bound = centre - margin
    upper_bound = centre + margin
    inside_window = (rooted > lower_bound) & (rooted < upper_bound)

    return FilteredEstimatedScaledIntensities(shallow[inside_window])


# Insert the custom filter into the pipeline. Its return type is
# ``FilteredEstimatedScaledIntensities``, the same type requested below.
pl.insert(cut_estimated_scaled_intensities_by_n_root_std_dev)
# Set the two extra parameters the custom filter takes.
pl[NRoot] = 4
pl[NRootStdDevCut] = 1.0

# Compute the filtered intensities; the result is shown as the cell output.
pl.compute(FilteredEstimatedScaledIntensities)