# Zoom Polarization Analysis

## Introduction

```mermaid
graph TD
    A[Sample Run] --> B([SANS Workflow])
    B --> C["I(Qx, Qy) in event mode"]
    D[Runs with He3 cell at a few time points] --> E([SANS Workflow])
    E --> F[wavelength-dependent He3 cell transmission fraction at a few time points]
    F --> G([<font color=black>He3 Cell Workflow])
    G --> H[<font color=black>time- and wavelength-dependent transmission function]
    C --> I([Polarization Correction])
    H --> I
    I --> J["Corrected I(Qx, Qy) in 4 spin channels"]

    style B fill:green
    style D fill:green
    style E fill:green
    style F fill:green
    style G fill:yellowgreen
    style H fill:yellowgreen
```

In [None]:
%matplotlib widget
import sciline
import scipp as sc
from ess import polarization as pol
from ess import isissans as isis
from ess.sans.types import *

In [None]:
sans_workflow = isis.zoom.ZoomWorkflow()
sans_workflow.set_param_series(PixelMaskFilename, [])

In [None]:
from pathlib import Path

data_folder = Path('zoom_polarized_data')
# Runs with analyzer at 4 different times
cell_runs = [str(data_folder / f'ZOOM00022{run}.nxs') for run in [710, 712, 714, 716]]
empty_run = data_folder / 'ZOOM00034787.nxs'
depolarized_run = data_folder / 'ZOOM00022718.nxs'
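# Append the depolarized run last (later cells rely on this ordering); use str like the other entries
cell_runs = cell_runs + [str(depolarized_run)]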

## Set up SANS workflow for computing transmission fractions

In [None]:
from ess.polarization.zoom import ZoomTransmissionFractionWorkflow

sans_workflow = ZoomTransmissionFractionWorkflow()
sans_workflow[Filename[EmptyBeamRun]] = str(empty_run)
sans_workflow[WavelengthBins] = sc.geomspace(
    'wavelength', start=1.75, stop=16.5, num=141, unit='Å'
)
sans_workflow[isis.mantidio.Period] = 0
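
`sc.geomspace` places the wavelength bin edges in a geometric progression. Every bin therefore has the same *relative* width $\Delta\lambda/\lambda$ rather than the same absolute width. The dependency-free sketch below reproduces that spacing for illustration only. Plain Python stands in for scipp, and `geomspace_edges` is a hypothetical helper:

```python
import math


def geomspace_edges(start: float, stop: float, num: int) -> list:
    """Return `num` geometrically spaced values from `start` to `stop`, inclusive."""
    ratio = (stop / start) ** (1.0 / (num - 1))
    return [start * ratio**i for i in range(num)]


# Same start/stop/num as the WavelengthBins cell above (values in Å).
edges = geomspace_edges(1.75, 16.5, 141)
# 141 edges define 140 bins; each bin has the same relative width Δλ/λ.
rel_widths = [(hi - lo) / lo for lo, hi in zip(edges[:-1], edges[1:])]
print(f"{len(edges) - 1} bins, Δλ/λ ≈ {rel_widths[0]:.4f}")
```

With this choice, long-wavelength bins are wider in absolute terms than short-wavelength ones.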

## Inspect data for one of the runs with the analyzer cell

In [None]:
# Load only the first run
sans_workflow.set_param_series(Filename[TransmissionRun[SampleRun]], cell_runs[:1])
loaded = sans_workflow.compute(
    sciline.Series[
        Filename[TransmissionRun[SampleRun]],
        isis.mantidio.LoadedFileContents[TransmissionRun[SampleRun]],
    ]
)
first_run = list(loaded.values())[0]
sc.DataGroup(sc.collapse(first_run['data'], keep='tof')).plot()

In [None]:
first_run

We can load the combined time-dependent incident and transmission monitors.
Note that the last run is the depolarized run:

In [None]:
sans_workflow.set_param_series(Filename[TransmissionRun[SampleRun]], cell_runs)
mons = sans_workflow.compute(
    (
        RawMonitor[TransmissionRun[SampleRun], Incident],
        RawMonitor[TransmissionRun[SampleRun], Transmission],
    )
)
mons = sc.DataGroup(
    incident=mons[RawMonitor[TransmissionRun[SampleRun], Incident]],
    transmission=mons[RawMonitor[TransmissionRun[SampleRun], Transmission]],
)
display(sc.DataGroup(sc.collapse(mons['incident'], keep='tof')).plot())
display(sc.DataGroup(sc.collapse(mons['transmission'], keep='tof')).plot())

The task graph for computing the transmission fraction is:

In [None]:
sans_workflow.visualize(TransmissionFraction[SampleRun], graph_attr={'rankdir': 'LR'})

## Compute transmission fractions

Multiple files together define the time dependence of the analyzer cell transmission.
Note that, as before, the final run (time) is the depolarized run:

In [None]:
raw_transmission = sans_workflow.compute(TransmissionFraction[SampleRun])
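
Conceptually, each transmission fraction compares how much beam reaches the transmission monitor with and without the cell. Both runs are first normalized by their incident monitor, which cancels differences in run duration and source intensity. The sketch below assumes the common SANS definition, applied bin by bin: the sample-run monitor ratio divided by the empty-beam monitor ratio. For the runs here, the "sample" is the ³He analyzer cell. The function name and the monitor counts are hypothetical, and this is not the `ess.sans` implementation itself:

```python
def transmission_fraction(sample_incident, sample_transmission,
                          empty_incident, empty_transmission):
    """Bin-by-bin transmission fraction from four monitor spectra.

    Assumes the common SANS definition: the transmission/incident monitor
    ratio of the run with the cell, divided by the same ratio for the
    empty-beam run. Inputs are equal-length sequences of counts per bin.
    """
    return [
        (st / si) / (et / ei)
        for si, st, ei, et in zip(
            sample_incident, sample_transmission, empty_incident, empty_transmission
        )
    ]


# Hypothetical counts in three wavelength bins, for illustration only.
fractions = transmission_fraction(
    sample_incident=[1000.0, 2000.0, 1500.0],
    sample_transmission=[300.0, 400.0, 150.0],
    empty_incident=[1000.0, 2000.0, 1500.0],
    empty_transmission=[600.0, 1000.0, 750.0],
)
print(fractions)
```

Doubling the duration of the cell run doubles both of its monitors, so the ratio, and hence the fraction, is unchanged.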

We can plot the computed transmission fractions:

In [None]:
transmission_depolarized = raw_transmission['time', -1].copy()
transmission = raw_transmission['time', :-1].copy()
trans = sc.DataGroup(
    {f"{time:c}": transmission['time', time] for time in transmission.coords['time']}
)
trans['depolarized'] = transmission_depolarized
display(trans.plot())
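
The sanity check in the next cell inverts a cosh model of the analyzer-cell transmission. In the formulas below, $T_E$ is `transmission_empty_glass`, $\rho_0$ is `opacity0` and $P$ is the ³He polarization. They are a reading of that cell's code rather than a statement of the library's internal model:

$$
T(\lambda) = T_E\, e^{-\rho_0 \lambda} \cosh\!\left(\rho_0 \lambda P\right)
\quad\Longrightarrow\quad
P = \frac{1}{\rho_0 \lambda}\,\operatorname{arcosh}\!\left(\frac{T(\lambda)\, e^{\rho_0 \lambda}}{T_E}\right)
$$

A real solution exists only where the arcosh argument is at least 1. That is, it requires $T(\lambda) \ge T_E\, e^{-\rho_0\lambda}$, the transmission of a fully depolarized cell ($P = 0$).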

In [None]:
# Sanity check: invert the cosh transmission model; acosh is only real where its argument is >= 1,
# which limits the wavelengths where the model can be fitted
transmission_empty_glass = 0.9 * sc.Unit('dimensionless')
wavelength = sc.midpoints(transmission.coords['wavelength'])
opacity0 = 0.8797823016804095 * sc.Unit('1/angstrom')
(
    sc.acosh(transmission * sc.exp(opacity0 * wavelength) / transmission_empty_glass)
    / (opacity0 * wavelength)
).plot()
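
The inversion above can be checked with a round trip in plain Python. The sketch evaluates the cosh model for a known polarization, then recovers that polarization with `math.acosh`. It returns NaN where the argument drops below 1. The function names and the test polarization 0.7 are hypothetical. `rho0` and `t_glass` reuse the values from the cell above, renamed to avoid shadowing the notebook variables:

```python
import math


def cosh_transmission_model(wavelength, polarization, rho0, t_glass):
    """Unpolarized-beam transmission of a polarized 3He cell (cosh model)."""
    rho = rho0 * wavelength
    return t_glass * math.exp(-rho) * math.cosh(rho * polarization)


def implied_polarization(transmission, wavelength, rho0, t_glass):
    """Invert the cosh model; NaN where acosh is undefined (argument < 1)."""
    rho = rho0 * wavelength
    arg = transmission * math.exp(rho) / t_glass
    if arg < 1.0:
        return math.nan
    return math.acosh(arg) / rho


rho0 = 0.8797823016804095  # 1/Å, as in the sanity-check cell
t_glass = 0.9
for lam in (2.0, 2.5, 3.0):  # wavelengths in Å
    t = cosh_transmission_model(lam, 0.7, rho0, t_glass)
    print(lam, implied_polarization(t, lam, rho0, t_glass))
```

A measured transmission below the fully depolarized level $T_E e^{-\rho_0\lambda}$, for example due to noise or a wrong $\rho_0$, yields NaN. This is why only part of the wavelength range is usable.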

In [None]:
# TODO Which wavelength bounds should be used?
wav_min = 2.2 * sc.Unit('angstrom')
wav_max = 2.8 * sc.Unit('angstrom')
transmission_truncated = raw_transmission['wavelength', wav_min:wav_max]
transmission_depolarized = transmission_truncated['time', -1].copy()
transmission = transmission_truncated['time', :-1].copy()

We can now set up the polarization analysis workflow.
The previously computed transmission fractions are used as workflow inputs:

In [None]:
pol_workflow = pol.he3.He3CellWorkflow(in_situ=False, incoming_polarized=True)
pol_workflow[
    pol.he3.He3CellTransmissionFraction[pol.Analyzer, pol.Polarized]
] = transmission
pol_workflow[
    pol.he3.He3CellTransmissionFraction[pol.Analyzer, pol.Depolarized]
] = transmission_depolarized

# With in_situ=False, these parameters serve as starting guesses for the fit
pol_workflow[pol.he3.He3CellLength[pol.Analyzer]] = 0.1 * sc.Unit('m')
pol_workflow[pol.he3.He3CellPressure[pol.Analyzer]] = 1.0 * sc.Unit('bar')
pol_workflow[pol.he3.He3CellTemperature[pol.Analyzer]] = 300.0 * sc.Unit('K')

pol_workflow[pol.he3.He3TransmissionEmptyGlass[pol.Analyzer]] = transmission_empty_glass
pol_workflow.visualize(
    pol.he3.He3TransmissionFunction[pol.Analyzer], graph_attr={'rankdir': 'LR'}
)
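
Cell length, pressure and temperature can act as a starting guess because they set the opacity slope $\rho_0$, where opacity $= \rho_0 \lambda$. The sketch below uses a textbook estimate under two assumptions: an ideal gas, and an absorption cross section proportional to $\lambda$ with 5333 b at 1.798 Å. It is not necessarily how `ess.polarization` converts these parameters, and `opacity_per_angstrom` is a hypothetical helper:

```python
K_B = 1.380649e-23        # Boltzmann constant, J/K (exact SI value)
SIGMA_ABS_REF = 5333e-28  # assumed 3He absorption cross section at 1.798 Å, m^2
LAMBDA_REF = 1.798        # Å (2200 m/s neutrons)


def opacity_per_angstrom(pressure_bar, length_m, temperature_k):
    """Estimate rho0 in 1/Å such that opacity(λ) = rho0 * λ.

    number density n = p / (k_B T); opacity = n * sigma(λ) * length,
    with sigma(λ) = SIGMA_ABS_REF * λ / LAMBDA_REF.
    """
    number_density = pressure_bar * 1e5 / (K_B * temperature_k)  # 1/m^3
    sigma_per_angstrom = SIGMA_ABS_REF / LAMBDA_REF  # m^2 per Å
    return number_density * sigma_per_angstrom * length_m


# Same values as the starting guesses set in the cell above.
rho0_guess = opacity_per_angstrom(1.0, 0.1, 300.0)
print(f"starting-guess rho0 ≈ {rho0_guess:.3f} 1/Å")
```

The estimate is of the same order as the `opacity0` used in the sanity-check cell, so it is a reasonable initial value for the fit to refine.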

The workflow can compute the transmission function:

In [None]:
func = pol_workflow.compute(pol.he3.He3TransmissionFunction[pol.Analyzer])

We can evaluate this transmission function at desired time and wavelength points:

In [None]:
wavelength = sc.linspace('wavelength', start=2, stop=16.0, num=141, unit='angstrom')
time = sc.linspace('time', start=0, stop=100000, num=101, unit='s')
display(func.opacity_function(wavelength=wavelength).plot())
display(func.polarization_function(time=time).plot())
display(func(wavelength=wavelength, time=time, plus_minus='plus').plot(norm='log'))
display(func(wavelength=wavelength, time=time, plus_minus='minus').plot(norm='log'))
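
The parameter names inspected at the end of this notebook are `opacity0`, `C` and `T1`. They suggest the transmission function combines a linear-in-wavelength opacity with an exponentially decaying polarization. The sketch below assumes the standard ³He spin-filter form $T_\pm(\lambda, t) = T_E\, e^{-\rho_0\lambda\,(1 \mp P(t))}$ with $P(t) = C\, e^{-t/T_1}$. Reading `'plus'`/`'minus'` as neutron spin parallel/antiparallel to the ³He polarization is likewise an assumption, not something the notebook confirms. All names and parameter values are hypothetical:

```python
import math


def polarization_decay(t, C, T1):
    """Assumed exponential relaxation of the 3He polarization."""
    return C * math.exp(-t / T1)


def he3_transmission(wavelength, t, plus_minus, rho0, C, T1, t_glass):
    """Transmission for spin parallel ('plus') or antiparallel ('minus')."""
    rho = rho0 * wavelength
    p = polarization_decay(t, C, T1)
    sign = {'plus': 1.0, 'minus': -1.0}[plus_minus]
    return t_glass * math.exp(-rho * (1.0 - sign * p))


# Hypothetical parameters for illustration only (not the fitted values).
params = dict(rho0=0.88, C=0.8, T1=5.0e4, t_glass=0.9)
for lam in (2.0, 8.0):  # Å
    tp = he3_transmission(lam, 0.0, 'plus', **params)
    tm = he3_transmission(lam, 0.0, 'minus', **params)
    print(f"λ={lam} Å: T+={tp:.3g}, T-={tm:.3g}")
```

Under this model, the 'minus' channel falls off by orders of magnitude at long wavelengths. That is consistent with plotting it on a log scale. The average $(T_+ + T_-)/2$ reduces to the cosh model used in the earlier sanity check.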

In [None]:
trans = func(wavelength=wavelength, time=time, plus_minus='plus')
sc.DataGroup(
    {f"{time:c}": trans['time', time] for time in trans.coords['time'][::20]}
).plot(norm='linear', linestyle='solid', marker=None)

In [None]:
trans = func(wavelength=wavelength, time=time, plus_minus='plus')
sc.DataGroup(
    {f"{wav:c}": trans['wavelength', wav] for wav in trans.coords['wavelength'][::20]}
).plot(norm='log', linestyle='solid', marker=None)

In [None]:
func.opacity_function.opacity0

In [None]:
func.polarization_function.C

In [None]:
func.polarization_function.T1
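
Assume the polarization function has the form $P(t) = C\, e^{-t/T_1}$. The parameter names fit this, but the notebook does not confirm it. Then `C` is the polarization at $t = 0$, `T1` is the relaxation time, and the polarization halves every $T_1 \ln 2$. The sketch below is a simplified illustration. It assumes $P(t)$ is observed directly and noise-free, and fits $\ln P$ linearly in $t$. The workflow itself fits the measured transmission fractions instead. The helper and the data are hypothetical:

```python
import math


def fit_exponential_decay(times, polarizations):
    """Least-squares fit of ln P = ln C - t / T1; returns (C, T1)."""
    n = len(times)
    ys = [math.log(p) for p in polarizations]
    t_mean = sum(times) / n
    y_mean = sum(ys) / n
    slope = sum((t - t_mean) * (y - y_mean) for t, y in zip(times, ys)) / sum(
        (t - t_mean) ** 2 for t in times
    )
    intercept = y_mean - slope * t_mean
    return math.exp(intercept), -1.0 / slope


# Hypothetical noise-free data, not the fitted values from the cells above.
C_true, T1_true = 0.75, 4.0e4
times = [0.0, 1.0e4, 2.0e4, 3.0e4]  # s
pols = [C_true * math.exp(-t / T1_true) for t in times]
C_fit, T1_fit = fit_exponential_decay(times, pols)
half_life = T1_fit * math.log(2)  # time for the polarization to halve, s
print(C_fit, T1_fit, half_life)
```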