# Example: ellipsoid

This notebook shows how to use `ffsas` to invert for the parameter distributions of an `ellipsoid` model, including polar radius $r_p$, equatorial radius $r_e$, angle to beam $\theta$ and rotation about beam $\phi$.

In [None]:
# work around the duplicate-OpenMP-runtime error seen on macOS
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# plotting setups: one shared style for every figure in this notebook
import matplotlib
import matplotlib.colors as colors
import matplotlib.pyplot as plt
matplotlib.rcParams.update({
    'font.size': 12,
    'legend.fontsize': 11,
    'axes.titlesize': 12,
    'lines.linewidth': 2,
})

# output directory for figures, including a sub-folder for per-iteration snapshots
from pathlib import Path
output_dir = Path('./output/ellipsoid')
(output_dir / 'iterations').mkdir(parents=True, exist_ok=True)

In [None]:
# to install ffsas into this kernel's environment, uncomment the next line
# %pip install ffsas

import torch

# ffsas: the ellipsoid model and the Green's-function-based SAS system
from ffsas.models import Ellipsoid
from ffsas.system import SASGreensSystem

# compute device passed to ffsas; change to 'cuda' on a GPU machine
device = 'cpu'

# Ground truth

To perform the inversion, we need some intensity data as input. Here we create the data by forward modelling from some noisy parameter distributions, which serve as the ground truth.

### Ground truth of distributions of $r_p$, $r_e$, $\theta$, $\phi$


The following function creates a "crazy" distribution by adding up Gaussians and random noise.

In [None]:
def crazy_distribution(x, gaussians, noise_level, fade_start, fade_end, seed=0):
    """Build a normalized, deliberately irregular weight distribution over ``x``.

    The weights are a sum of Gaussian bumps plus seeded noise. Both ends are
    optionally zeroed and then linearly faded in/out.

    Args:
        x: 1-D tensor of parameter values.
        gaussians: iterable of ``(factor, mean, stddev)`` triples. Each adds
            ``factor * exp(-((x - mean) / stddev) ** 2)``.
        noise_level: amplitude of the noise term ``noise_level * U1 * U2``,
            with ``U1`` and ``U2`` drawn from ``torch.rand`` (uniform on [0, 1)).
        fade_start: number of entries zeroed at each end.
        fade_end: end of the linear fade, counted from each end. Entries with
            index in ``[fade_start, fade_end)`` (mirrored at the far end) are
            ramped. Expected to satisfy ``fade_start <= fade_end``.
        seed: seed for torch's global RNG. This reseeds it as a side effect.

    Returns:
        Tensor of the same shape as ``x`` whose entries sum to 1.

    Raises:
        ValueError: if every weight is zero, so normalization is impossible.
    """
    n = len(x)

    # sum of Gaussian bumps
    w_true = torch.zeros(x.shape)
    for factor, mean, stddev in gaussians:
        w_true += factor * torch.exp(-((x - mean) / stddev) ** 2)

    # add noise; the product of two uniforms skews the noise toward small values
    torch.random.manual_seed(seed)
    w_true += noise_level * torch.rand(x.shape) * torch.rand(x.shape)

    # Fade both ends to make it look nicer. Explicit end indices are used
    # because negative slicing breaks for fade_start == 0: w[-0:] is the
    # whole array, and w[-k:-0] is empty.
    if n >= 3:
        ramp_len = fade_end - fade_start
        w_true[:fade_start] = 0.
        w_true[fade_start:fade_end] *= torch.linspace(0, 1, ramp_len)
        w_true[n - fade_start:] = 0.
        w_true[n - fade_end:n - fade_start] *= torch.linspace(1, 0, ramp_len)

    # normalize to 1 (an all-zero vector would otherwise silently become NaN)
    total = torch.sum(w_true)
    if total == 0:
        raise ValueError('crazy_distribution: all weights are zero; cannot normalize')
    return w_true / total

Make the parameter distributions:

In [None]:
# parameter grids (theta and phi stay in degrees until after plotting)
par_dict = {
    'rp': torch.linspace(200., 600., 18),
    're': torch.linspace(50., 90., 17),
    'theta': torch.linspace(5., 60., 16),
    'phi': torch.linspace(150., 240., 15)
}

# per-parameter recipe: (list of (factor, mean, stddev) bumps, noise level);
# every distribution uses fade_start = fade_end = 1
truth_recipes = {
    'rp': ([(1.5, 300, 20), (1, 400, 20), (2, 500, 20)], 1),
    're': ([(1, 60, 3), (2, 70, 4), (2, 80, 3)], 1),
    'theta': ([(4, 15, 5), (2, 35, 5), (2, 50, 5)], 2),
    'phi': ([(2, 170, 10), (2, 200, 10), (4, 220, 10)], 3),
}
w_true_dict = {key: crazy_distribution(par_dict[key], bumps, noise, 1, 1)
               for key, (bumps, noise) in truth_recipes.items()}

# plot distributions on a 2x2 grid, row-major in the order rp, re, theta, phi
truth_xlabels = {
    'rp': r'Polar radius, $r_p$ ($\AA$)',
    're': r'Equatorial radius, $r_e$ ($\AA$)',
    'theta': r'Ellipsoid axis to beam angle, $\theta$ (degree)',
    'phi': r'Rotation about beam, $\phi$ (degree)',
}
fig, ax = plt.subplots(2, 2, dpi=200, figsize=(10, 6))
plt.subplots_adjust(hspace=.5, wspace=.3)
for panel, key in zip(ax.flat, par_dict):
    panel.plot(par_dict[key], w_true_dict[key], label='Truth')
    panel.set_xlabel(truth_xlabels[key])
    panel.set_ylabel(r'Weights, $w$')
plt.savefig(output_dir / 'true_weights.png', bbox_inches='tight', facecolor='w')
plt.show()

# convert the angle grids from degrees to radians before they are used with ffsas
for key in ('theta', 'phi'):
    par_dict[key] = torch.deg2rad(par_dict[key])

### Ground truth of intensity

Now we compute the true intensity from the above parameter distributions, assuming the true values of $\xi$ and $b$. See [sphere.ipynb](sphere.ipynb) for the difference between `scale` in [SASView/SASModels](http://www.sasview.org/docs/user/models/ellipsoid.html) and $\xi$ in `ffsas`.

In [None]:
# ground truth of the SASView-style scale and the background b
scale_true = 0.15
b_true = 2.2e-4

# q vectors (1/Å): a 32 x 30 grid in (qx, qy)
qx = torch.linspace(-.25, .25, 32)
qy = torch.linspace(-.25, .25, 30)

# (SLD - SLD_solvent) ^ 2
drho = 1.

# compute the Green's function over the (qx, qy) grid for all parameter
# combinations in par_dict, processed in mini-batches (batch_size=8)
G = Ellipsoid.compute_G_mini_batch([qx, qy],  par_dict, {'drho': drho}, batch_size=8, device=device)

# define the G-based SAS system; progress is logged to GSys.log and to the screen
g_sys = SASGreensSystem(G, Ellipsoid.get_par_keys_G(), batch_size=8, device=device, 
                        log_file=output_dir / 'GSys.log', log_screen=True)

# true xi: the SASView-style scale divided by the weight-averaged volume
V = Ellipsoid.compute_V(par_dict)
# NOTE(review): assumes V is indexed (rp, re). The inner tensordot contracts
# V's last axis with w_re, the outer one the remaining axis with w_rp; confirm against ffsas.
V_ave = torch.tensordot(torch.tensordot(V, w_true_dict['re'], dims=1), w_true_dict['rp'], dims=1)
# NOTE(review): the 1e-4 factor presumably absorbs the SASView unit convention
# (SLD in 1e-6/Å^2, intensity in 1/cm); see sphere.ipynb to confirm.
xi_true = 1e-4 * scale_true / V_ave

# finally compute the ground truth of intensity from the true weights, xi and b
I_true = g_sys.compute_intensity(w_true_dict, xi_true, b_true)

Plot the intensity truth:

In [None]:
# Plot the true intensity on a log colour scale. I_true.t() puts qy along
# rows and qx along columns. origin='lower' places row 0 (qy[0]) at the bottom
# so the image agrees with `extent`; imshow's default origin='upper' would
# draw the pattern mirrored in qy.
fig, ax = plt.subplots(dpi=150)
im = ax.imshow(I_true.t(), origin='lower',
               extent=(qx[0], qx[-1], qy[0], qy[-1]), cmap='hot',
               norm=colors.LogNorm(vmin=I_true.min(), vmax=I_true.max()))
ax.set_xlabel(r'Scattering vector, $q_x$ ($\AA^{-1}$)')
ax.set_ylabel(r'Scattering vector, $q_y$ ($\AA^{-1}$)')
fig.colorbar(im, ax=ax)
plt.savefig(output_dir / 'true_intensity.png', bbox_inches='tight', facecolor='w')
plt.show()

---

#  Inversion

With the G-based SAS system defined, the inversion takes only one line, with the simulated intensity data (`I_true`) as input (for both the mean and the stddev). We perform 300 iterations by passing `maxiter=300` and save the results every 10 iterations with `save_iter=10`.

In [None]:
# invert the G-system; the simulated intensity is passed as both the mean and
# the stddev, and a snapshot of the solution is saved every 10 iterations
res_dict = g_sys.solve_inverse(
    I_true, I_true,
    auto_scaling=True, maxiter=300, verbose=2, save_iter=10,
)

Plot the convergence history of the weights of $r_p$, $r_e$, $\theta$, $\phi$:

In [None]:
# scaling factor applied to the normalized sensitivity when drawing the gray
# band around the fitted weights (labelled as 1/sensitivity)
scaling_factor_sensitivity = .5
par_keys = ['rp', 're', 'theta', 'phi']
par_keys_latex = [r'$r_p$', r'$r_e$', r'$\theta$', r'$\phi$']

# plot history: one row of four panels per saved snapshot
for i, saved_res_dict in enumerate(res_dict['saved_res']):
    fig, ax = plt.subplots(1, 4, dpi=200, figsize=(15, 2))
    for ikey, (key, key_latex) in enumerate(zip(par_keys, par_keys_latex)):
        panel = ax[ikey]
        w_fit = saved_res_dict['w_dict'][key]
        sens_w = saved_res_dict['sens_w_dict'][key]
        # half-width of the band: sensitivity normalized to sum 1, then scaled
        sens = sens_w / sens_w.sum()
        band = abs(scaling_factor_sensitivity * sens)

        panel.plot(par_dict[key], w_true_dict[key], label='Truth')
        panel.plot(par_dict[key], w_fit, label='MLE')
        panel.fill_between(par_dict[key], w_fit - band, w_fit + band,
                           alpha=.3, color='gray', zorder=-100, label=r'Sens${}^{-1}$')

        panel.set_xlabel(key_latex)
        panel.set_ylim(-.01, w_true_dict[key].max() * 1.1)
        panel.set_xlim(par_dict[key].min(), par_dict[key].max())
        panel.set_xticks([])
    ax[0].set_ylabel(f'iter={saved_res_dict["nit"]}')
    ax[2].legend(loc='upper right')
    plt.savefig(output_dir / f'iterations/{i}.png', bbox_inches='tight', facecolor='w')
    plt.show()
    # release the figure: this loop creates one per saved snapshot, and
    # non-inline backends would otherwise keep them all open
    plt.close(fig)

Finally, plot the fitted intensity and misfit:

In [None]:
# Compare truth, fit and absolute misfit on one shared log colour scale.
# origin='lower' places row 0 of each transposed image (qy[0]) at the bottom
# so the images agree with `extent`; the default 'upper' would mirror them in qy.
intensity_extent = (qx[0], qx[-1], qy[0], qy[-1])
intensity_vmin, intensity_vmax = I_true.min(), I_true.max()
intensity_panels = [
    ('True intensity', I_true),
    ('Fitted intensity', res_dict['I']),
    ('Misfit', torch.abs(res_dict['I'] - I_true)),
]

fig, ax = plt.subplots(1, 3, dpi=200, figsize=(12, 4), sharex=True, sharey=True)
plt.subplots_adjust(wspace=.2)
for panel, (title, image) in zip(ax, intensity_panels):
    im = panel.imshow(image.t(), origin='lower', extent=intensity_extent, cmap='hot',
                      norm=colors.LogNorm(vmin=intensity_vmin, vmax=intensity_vmax))
    panel.set_title(title)
    panel.set_xlabel(r'Scattering vector, $q_x$ ($\AA^{-1}$)')
ax[0].set_ylabel(r'Scattering vector, $q_y$ ($\AA^{-1}$)')

# colorbar: a single bar at the right (all panels share vmin/vmax)
cbar_ax = fig.add_axes([.92, 0.15, 0.02, 0.7])
fig.colorbar(im, cax=cbar_ax)
plt.savefig(output_dir / 'intensity_fit.png', bbox_inches='tight', facecolor='w')
plt.show()

---