# Example: cylinder as a user-defined model

This notebook shows how to implement the `cylinder` model as a user-defined model class in `ffsas`, which is then used for modelling and inversion. The `cylinder` model has four parameters: length $l$, radius $r$, angle to beam $\theta$ and rotation about beam $\phi$. 

In [None]:
# work around the OpenMP ("omp") runtime error seen on Mac
import os
os.environ.update({'KMP_DUPLICATE_LIB_OK': 'True'})

# plotting setups
import matplotlib.pyplot as plt
import matplotlib
import matplotlib.colors as colors
# global plot styling: font sizes and default line width
matplotlib.rcParams.update({
    'font.size': 12,
    'legend.fontsize': 11,
    'axes.titlesize': 12,
    'lines.linewidth': 2,
})

# create output dir
from pathlib import Path
output_dir = Path('./output/cylinder')
# the 'iterations' sub-folder holds the per-save history plots; creating it
# also creates output_dir itself
iterations_dir = output_dir / 'iterations'
iterations_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# uncomment this line to install ffsas
# !pip install ffsas

# ffsas
import torch
import ffsas
# in this example, we use float32 as the data type for torch tensors;
# this must be called before importing from ffsas.models and ffsas.system.
# float32 is faster but less accurate: switch to torch.float64 to recover
# the ground-truth distributions more accurately in the inversion below
ffsas.set_torch_dtype(torch.float32)

from ffsas.models import SASModel
from ffsas.system import SASGreensSystem

# run on a CUDA GPU when one is available, otherwise on the CPU;
# assign 'cpu' explicitly here to force CPU execution
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# math tools
import math
from scipy.special import j1

# The `cylinder` model

A user-defined model is implemented by class inheritance. The base class is `SASModel`. Four abstract class methods are to be implemented: 

* `compute_G()`: compute the Green's tensor $G$
* `compute_V()`: compute volume $V$
* `get_par_keys_G()`: get the parameter keys in the order they appear in the dimensions of $G$
* `get_par_keys_V()`: get the parameter keys in the order they appear in the dimensions of $V$

The arguments for `compute_G()` are explained below:
* `q_list`: list of $q$-vectors; `len(q_list)` can be 1 or 2, meaning that the intensity is a 1D series or a 2D image
* `par_dict`: a `dict` of model parameters, all being 1D `torch.Tensor`s
* `const_dict`: a `dict` of model constants, all being scalars
* `V`: the volume tensor $V$; if not provided, compute it in `compute_G()` if needed

The `cylinder` model is implemented below. The equation for the Green's tensor can be found [here](https://www.sasview.org/docs/user/models/cylinder.html). Avoiding `for` loops is key to achieving high performance in forward modelling.

In [None]:
class Cylinder(SASModel):
    """SasView's `cylinder` model implemented as a user-defined ffsas model.

    Model parameters (keys of ``par_dict``, all 1D tensors):

    * ``l``: cylinder length
    * ``r``: cylinder radius
    * ``theta``: angle between the cylinder axis and the beam, in radians
    * ``phi``: rotation of the cylinder axis about the beam, in radians

    Model constants (keys of ``const_dict``):

    * ``drho``: scattering-length-density contrast; it is squared together
      with the shape factor inside ``compute_G``.
    """

    @classmethod
    def compute_G(cls, q_list, par_dict, const_dict, V=None):
        """Compute the Green's tensor G for a 2D (qx, qy) intensity image.

        :param q_list: ``[qx, qy]``, two 1D tensors of q-vector components
        :param par_dict: dict of 1D tensors keyed by ``l``, ``r``, ``theta``
            and ``phi`` (angles in radians)
        :param const_dict: dict holding the scalar contrast ``drho``
        :param V: volume tensor of shape ``(len(l), len(r))``; computed with
            ``compute_V`` when not given
        :return: G with shape
            ``(len(qx), len(qy), len(l), len(r), len(theta), len(phi))``
        """
        # get parameters
        qx, qy = q_list[0], q_list[1]
        l, r = par_dict['l'], par_dict['r']
        theta, phi = par_dict['theta'], par_dict['phi']
        drho = const_dict['drho']

        # compute volume
        if V is None:
            V = cls.compute_V(par_dict)

        #############
        # Compute G #
        #############

        # step 1: split q = (qx, qy, 0) into its components along (qc) and
        # perpendicular to (qa) the unit cylinder axis
        # (sin(theta) cos(phi), sin(theta) sin(phi), cos(theta));
        # both have shape (nqx, nqy, ntheta, nphi)
        sin_theta = torch.sin(theta)
        r31 = torch.outer(sin_theta, torch.cos(phi))
        r32 = torch.outer(sin_theta, torch.sin(phi))
        qc = (qx[:, None, None] * r31[None, :, :])[:, None, :, :] + \
             (qy[:, None, None] * r32[None, :, :])[None, :, :, :]
        # clip guards against tiny negative values from rounding
        qa = torch.sqrt(torch.clip(
            (qx ** 2)[:, None, None, None] +
            (qy ** 2)[None, :, None, None] - qc ** 2, min=0.))

        # step 2: arguments of the shape-factor terms, with the parameter
        # dimension moved to dim 2:
        # qa_r: (nqx, nqy, nr, ntheta, nphi), qc_l: (nqx, nqy, nl, ntheta, nphi)
        qa_r = torch.moveaxis(qa[:, :, :, :, None] *
                              r[None, None, None, None, :], 4, 2)
        qc_l = torch.moveaxis(qc[:, :, :, :, None] *
                              l[None, None, None, None, :], 4, 2) * .5

        # step 3: shape factor
        #   sin(qc l / 2) / (qc l / 2) * 2 J1(qa r) / (qa r)
        # Each factor tends to 1 as its argument tends to 0. The exact 0 / 0
        # case produces NaN, which is replaced by that limit. The factor 2 is
        # kept on the Bessel term so that the limit of BOTH factors is 1.
        sin_qc_l = torch.nan_to_num(torch.sin(qc_l) / qc_l,
                                    nan=1., posinf=1., neginf=1.)
        # NOTE: scipy.special.j1() must be called on cpu
        j1_qa_r = torch.tensor(j1(qa_r.to('cpu').numpy()),
                               dtype=qa_r.dtype, device=qa_r.device)
        j1_qa_r = torch.nan_to_num(2. * j1_qa_r / qa_r,
                                   nan=1., posinf=1., neginf=1.)
        # shape factor: (nqx, nqy, nl, nr, ntheta, nphi)
        shape_factor = sin_qc_l[:, :, :, None, :, :] * j1_qa_r[:, :, None, :, :, :]

        # step 4: G
        G = (drho * shape_factor * V[None, None, :, :, None, None]) ** 2
        return G

    @classmethod
    def get_par_keys_G(cls):
        """Parameter keys in the order of G's parameter dimensions (2 to 5)."""
        return ['l', 'r', 'theta', 'phi']

    @classmethod
    def compute_V(cls, par_dict):
        """Cylinder volume pi * r**2 * l on the (l, r) grid.

        :param par_dict: dict with 1D tensors ``l`` and ``r``
        :return: tensor of shape ``(len(l), len(r))``
        """
        l, r = par_dict['l'], par_dict['r']
        return math.pi * l[:, None] * r[None, :] ** 2

    @classmethod
    def get_par_keys_V(cls):
        """Parameter keys in the order of V's dimensions."""
        return ['l', 'r']

Having the `Cylinder` class defined, the remainder of this notebook looks mostly the same as [ellipsoid.ipynb](ellipsoid.ipynb). 

---

# Ground truth

To do inversion, we need some intensity data as input. Here we create the data by forward modelling from some noisy parameter distributions, which serve as the ground truth.

### Ground truth of distributions of $l$, $r$, $\theta$, $\phi$


The following function creates a "crazy" distribution by adding up Gaussians and random noise.

In [None]:
def crazy_distribution(x, gaussians, noise_level, fade_start, fade_end, seed=0):
    """Create an irregular ("crazy") weight distribution over ``x``.

    The weights are a sum of Gaussian bumps plus seeded random noise. Both
    ends are faded out, and the result is normalized to sum to 1.

    :param x: 1D tensor of parameter values
    :param gaussians: iterable of ``(factor, mean, stddev)``; each adds
        ``factor * exp(-((x - mean) / stddev) ** 2)``
    :param noise_level: amplitude of the noise, which is the product of two
        uniform samples per point
    :param fade_start: number of points set to zero at each end
    :param fade_end: index (counted from each end) where the linear fade
        ramp reaches full weight; must not be smaller than ``fade_start``
    :param seed: seed passed to ``torch.random.manual_seed`` (this reseeds
        the global torch RNG)
    :return: 1D tensor of weights summing to 1
    """
    # create
    w_true = torch.zeros(x.shape)

    # add Gaussians
    for factor, mean, stddev in gaussians:
        w_true += factor * torch.exp(-((x - mean) / stddev) ** 2)

    # add noise
    torch.random.manual_seed(seed)
    w_true += noise_level * torch.rand(x.shape) * torch.rand(x.shape)

    # fade both ends to make it look nicer; the right end uses explicit
    # indices because negative slicing breaks for fade_start == 0
    # (w[-0:] is the whole array, not an empty tail)
    n = len(x)
    if n >= 3:
        w_true[:fade_start] = 0.
        w_true[fade_start:fade_end] *= torch.linspace(0, 1, fade_end - fade_start)
        w_true[n - fade_start:] = 0.
        w_true[n - fade_end:n - fade_start] *= torch.linspace(1, 0, fade_end - fade_start)

    # normalize to 1
    w_true /= torch.sum(w_true)
    return w_true

Make the parameter distributions:

In [None]:
# parameter grids; l and r in angstrom, theta and phi in degrees at this point
par_dict = dict(
    l=torch.linspace(200., 600., 18),
    r=torch.linspace(50., 90., 17),
    theta=torch.linspace(5., 60., 16),
    phi=torch.linspace(150., 240., 15),
)

# ground-truth weights per parameter: Gaussians as (factor, mean, stddev)
# plus a noise level; every distribution uses fade_start = fade_end = 1
w_true_dict = {
    key: crazy_distribution(par_dict[key], gaussians, noise_level, 1, 1)
    for key, gaussians, noise_level in (
        ('l', [(1.5, 300, 20), (1, 400, 20), (2, 500, 20)], 1),
        ('r', [(1, 60, 3), (2, 70, 4), (2, 80, 3)], 1),
        ('theta', [(4, 15, 5), (2, 35, 5), (2, 50, 5)], 2),
        ('phi', [(2, 170, 10), (2, 200, 10), (4, 220, 10)], 3),
    )
}

# plot the ground-truth distributions, one panel per parameter
# (row-major panel order: l, r, theta, phi)
fig, ax = plt.subplots(2, 2, dpi=200, figsize=(10, 6))
plt.subplots_adjust(hspace=.5, wspace=.3)
xlabels = {
    'l': r'Length, $l$ ($\AA$)',
    'r': r'Radius, $r$ ($\AA$)',
    'theta': r'Cylinder axis to beam angle, $\theta$ (degree)',
    'phi': r'Rotation about beam, $\phi$ (degree)',
}
for axis, (key, xlabel) in zip(ax.flat, xlabels.items()):
    axis.plot(par_dict[key], w_true_dict[key], label='Truth')
    axis.set_xlabel(xlabel)
    axis.set_ylabel(r'Weights, $w$')
plt.savefig(output_dir / 'true_weights.png', bbox_inches='tight', facecolor='w')
plt.show()

# degree to radian (compute_G takes the angles in radians)
par_dict['theta'] = torch.deg2rad(par_dict['theta'])
par_dict['phi'] = torch.deg2rad(par_dict['phi'])

### Ground truth of intensity

Now we compute the true intensity from the above parameter distributions, assuming true values of $\xi$ and $b$. See `example/sphere/sphere.ipynb` for the difference between `scale` in `SASView/SASModels` and $\xi$ in `ffsas`.

In [None]:
# ground truth of scale and background
scale_true = 0.15
b_true = 2.2e-4

# q vectors
qx = torch.linspace(-.5, .5, 32)
qy = torch.linspace(-.75, .75, 48)

# SLD contrast, SLD - SLD_solvent (not squared here: Cylinder.compute_G
# squares it together with the shape factor)
drho = 1.

# compute the Green's tensor
G = Cylinder.compute_G_mini_batch([qx, qy], par_dict, {'drho': drho},
                                  batch_size=8, device=device)

# define the G-based SAS system
g_sys = SASGreensSystem(G, Cylinder.get_par_keys_G(), batch_size=8, device=device,
                        log_file=output_dir / 'GSys.log', log_screen=True)

# true xi, from the volume averaged over the true (l, r) weights
V = Cylinder.compute_V(par_dict)  # shape (len(l), len(r))
V_ave = w_true_dict['l'] @ V @ w_true_dict['r']
xi_true = 1e-4 * scale_true / V_ave

# finally compute the ground truth of intensity
I_true = g_sys.compute_intensity(w_true_dict, xi_true, b_true)

Plot the intensity truth:

In [None]:
# true intensity image on a log colour scale (transposed so qy is vertical)
plt.figure(dpi=150)
plt.imshow(I_true.t(),
           extent=(qx[0], qx[-1], qy[0], qy[-1]), aspect=1., cmap='hot',
           norm=colors.LogNorm(vmin=I_true.min(), vmax=I_true.max()))
plt.xlabel(r'Scattering vector, $q_x$ ($\AA^{-1}$)')
plt.ylabel(r'Scattering vector, $q_y$ ($\AA^{-1}$)')
plt.colorbar()
plt.savefig(output_dir / 'true_intensity.png', bbox_inches='tight', facecolor='w')
plt.show()

---

#  Inversion

With the G-based SAS system defined, inversion only requires one line, taking the simulated intensity data (`I_true`) as input (for both mean and stddev). We perform 300 iterations by passing `maxiter=300` and save the results every 10 steps with `save_iter=10`.

In [None]:
# invert the G-system
res_dict = g_sys.solve_inverse(I_true, I_true, auto_scaling=True, 
                               maxiter=300, verbose=2, save_iter=10,
                               trust_options={'xtol': 0.})  # disable early stop and finish 300 iters

Plot the convergence history of the weights of $l$, $r$, $\theta$, $\phi$:

In [None]:
# scaling factor for the gray band drawn from the normalized sensitivity
# values (labelled Sens^-1) around the fitted weights
scaling_factor_sensitivity = .5
par_keys = ['l', 'r', 'theta', 'phi']
par_keys_latex = [r'$l$', r'$r$', r'$\theta$', r'$\phi$']

# plot history: one row of panels per saved iteration
for i, saved_res_dict in enumerate(res_dict['saved_res']):
    fig, ax = plt.subplots(1, 4, dpi=200, figsize=(15, 2))
    for axis, key, key_latex in zip(ax, par_keys, par_keys_latex):
        w_fit = saved_res_dict['w_dict'][key]
        # sensitivity normalized to sum to 1, then scaled into a band
        sens = saved_res_dict['sens_w_dict'][key]
        sens = sens / sens.sum()
        band = abs(scaling_factor_sensitivity * sens)
        # truth
        axis.plot(par_dict[key], w_true_dict[key], label='Truth')
        # fitted
        axis.plot(par_dict[key], w_fit, label='MLE')
        # sensitivity
        axis.fill_between(par_dict[key], w_fit - band, w_fit + band,
                          alpha=.3, color='gray', zorder=-100, label=r'Sens${}^{-1}$')
        # settings
        axis.set_xlabel(key_latex)
        axis.set_ylim(-.01, w_true_dict[key].max() * 1.1)
        axis.set_xlim(par_dict[key].min(), par_dict[key].max())
        axis.set_xticks([])
    ax[0].set_ylabel(f'iter={saved_res_dict["nit"]}')
    ax[2].legend(loc='upper right')
    plt.savefig(output_dir / f'iterations/{i}.png', bbox_inches='tight', facecolor='w')
    plt.show()

Note that the distributions of $l$ and $\theta$ are not perfectly recovered. This is because we used `ffsas.set_torch_dtype(torch.float32)` at the very beginning. Switch to `torch.float64` to
recover the ground truth accurately.

Finally, plot the fitted intensity and fitting error:

In [None]:
# true intensity, fitted intensity and misfit side by side; all three use
# the log colour range of the true intensity, so one colorbar serves them all
extent = (qx[0], qx[-1], qy[0], qy[-1])
fig, ax = plt.subplots(1, 3, dpi=200, figsize=(12, 4), sharex=True, sharey=True)
plt.subplots_adjust(wspace=-.3)
# truth
ax[0].imshow(I_true.t(), extent=extent, cmap='hot',
             norm=colors.LogNorm(vmin=I_true.min(), vmax=I_true.max()))
ax[0].set_title('True intensity')
ax[0].set_xlabel(r'Scattering vector, $q_x$ ($\AA^{-1}$)')
ax[0].set_ylabel(r'Scattering vector, $q_y$ ($\AA^{-1}$)')
# fitted
ax[1].imshow(res_dict['I'].t(), extent=extent, cmap='hot',
             norm=colors.LogNorm(vmin=I_true.min(), vmax=I_true.max()))
ax[1].set_title('Fitted intensity')
ax[1].set_xlabel(r'Scattering vector, $q_x$ ($\AA^{-1}$)')
# error; clipped away from zero so that it can be shown on a log scale
misfit = torch.clip(torch.abs(res_dict['I'].t() - I_true.t()), min=1e-10)
im = ax[2].imshow(misfit, extent=extent, cmap='hot',
                  norm=colors.LogNorm(vmin=I_true.min(), vmax=I_true.max()))
ax[2].set_title('Misfit')
ax[2].set_xlabel(r'Scattering vector, $q_x$ ($\AA^{-1}$)')

# colorbar
cbar_ax = fig.add_axes([.86, 0.15, 0.02, 0.7])
fig.colorbar(im, cax=cbar_ax)
plt.savefig(output_dir / 'intensity_fit.png', bbox_inches='tight', facecolor='w')
plt.show()

---