In [None]:
from dataclasses import dataclass
from collections.abc import Callable

import scipp as sc
import numpy as np

from scippneutron.absorption.base import compute_transmission_map
from scippneutron.absorption.cylinder import Cylinder


@dataclass
class Material:
    """Sample material parameters passed to ``compute_transmission_map``.

    This class only stores the two attributes below; how they are used is
    decided by ``compute_transmission_map``, which is not shown here.
    """

    # Material constant; presumably a number density / concentration —
    # TODO confirm. NOTE(review): the notebook's ``material`` instance passes
    # a plain int here despite the ``sc.Variable`` annotation.
    c: sc.Variable
    # Wavelength-dependent coefficient, called with a wavelength variable and
    # expected to return a variable; presumably the attenuation coefficient
    # mu(wavelength) — TODO confirm.
    mu: Callable[[sc.Variable], sc.Variable]


# Sample geometry: a cylinder with symmetry line along +y and its base
# centred at (-4, 12, 5). No units are attached to any of these quantities.
sample_shape = Cylinder(
    symmetry_line=sc.vector([0, 1, 0]),
    center_of_base=sc.vector([-4, 12, 5]),
    radius=sc.scalar(1),
    height=sc.scalar(1.5),
)
# Placeholder model for the wavelength dependence of ``mu``; not physically
# motivated — TODO replace with a real model. It is positive only for
# 0 < wav < 3, which contains the 0.5-2.5 wavelength range used below.
material = Material(c=1, mu=lambda wav: 0.4 * wav * (3 - wav))


def transmission(quadrature_kind):
    """Compute the transmission map of ``sample_shape`` made of ``material``.

    The beam travels along +x. The map is evaluated for 10 wavelengths in
    [0.5, 2.5] and for scattering directions on a 100 x 200 grid of
    spherical angles ``theta`` in [0, pi) and ``phi`` in [0, 2 pi).

    Parameters
    ----------
    quadrature_kind:
        Forwarded unchanged to ``compute_transmission_map``; this notebook
        uses ``'cheap'``, ``'medium'`` and ``'expensive'``.

    Returns
    -------
    :
        The result of ``compute_transmission_map`` — presumably a data array
        with ``wavelength``, ``theta`` and ``phi`` dimensions, judging by how
        ``show_correction_map`` slices it.
    """
    beam = sc.vector([1, 0, 0])
    # Wavelengths to compute the correction for.
    wavelengths = sc.linspace('wavelength', 0.5, 2.5, 10)
    # Scattering directions in spherical coordinates (upper endpoints excluded).
    polar = sc.linspace('theta', 0, np.pi, 100, endpoint=False, unit='rad')
    azimuthal = sc.linspace(
        'phi', 0, 2 * np.pi, 200, endpoint=False, unit='rad'
    )
    return compute_transmission_map(
        sample_shape,
        material,
        beam_direction=beam,
        wavelength=wavelengths,
        theta=polar,
        phi=azimuthal,
        quadrature_kind=quadrature_kind,
    )


def show_correction_map(
    da, *, wavelength_index=0, theta_index=50, phi_range=slice(30, 70)
):
    """Plot three views of a transmission map, stacked vertically.

    The defaults reproduce the original fixed cuts: first wavelength,
    ``theta`` index 50 (the middle of the 100-point grid used by
    ``transmission``), and ``phi`` positions 30 to 69.

    Parameters
    ----------
    da:
        Data array with ``wavelength``, ``theta`` and ``phi`` dimensions,
        e.g. the result of ``transmission``.
    wavelength_index:
        Position along ``wavelength`` at which all three views are taken.
    theta_index:
        Position along ``theta`` used for the 1-D cuts over ``phi``.
    phi_range:
        Positional slice along ``phi`` for the zoomed-in cut.

    Returns
    -------
    :
        The three figures combined with ``/`` (vertical stacking in plopp):
        the full (theta, phi) map, a cut over ``phi`` at fixed ``theta``,
        and the ``phi_range`` part of that cut.
    """
    # Slice once and reuse, instead of re-slicing the same wavelength
    # three times.
    at_wavelength = da['wavelength', wavelength_index]
    phi_cut = at_wavelength['theta', theta_index]
    return (
        at_wavelength.plot()
        / phi_cut.plot()
        / phi_cut['phi', phi_range].plot()
    )

In [None]:
# Raw transmission map from the 'cheap' quadrature (shown as the cell output).
transmission(quadrature_kind='cheap')

In [None]:
# Plots for the 'cheap' quadrature; this recomputes the map rather than
# reusing the previous cell's result.
show_correction_map(transmission(quadrature_kind='cheap'))

In [None]:
# Plots for the 'medium' quadrature, for comparison with 'cheap'.
show_correction_map(transmission(quadrature_kind='medium'))

In [None]:
# Plots for the 'expensive' quadrature, for comparison with the cheaper ones.
show_correction_map(transmission(quadrature_kind='expensive'))