# Example: benchmark for polydisperse spheres

In this notebook, we conduct a benchmark for size distribution inversion of polydisperse spheres using synthetic data. The ground truth of the size distribution contains two populations, a Gaussian and a Boltzmann. We benchmark four codes: `Irena`, `SasView`, `McSAS` and `ffsas`. 

This notebook uses the [SASView/SASModels](http://www.sasview.org/docs/user/models/sphere.html) unit system.

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from scipy import interpolate
import torch

from ffsas.models import Sphere
from ffsas.system import SASGreensSystem

# avoid an OMP error on macOS (nothing to do with ffsas)
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

In [None]:
# reproduce figures in the paper 
reproduce_paper_fig = True
if reproduce_paper_fig:
    # this will trigger an error if LaTeX is not installed
    plt.rcParams.update({
        "text.usetex": True,
        "text.latex.preamble": r'\usepackage{bm,upgreek}',
        "font.family": "sans-serif",
        "font.serif": ["Times"]})
    # figure dir
    paper_fig_dir = Path('../paper_figs')
    paper_fig_dir.mkdir(parents=True, exist_ok=True)

# Ground truth


### Ground truth of radius distribution

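The ground truth of the radius distribution contains two populations. The first is a Gaussian centered at 500 Å with a standard deviation of 10 Å. The second is a Boltzmann (two-sided exponential) distribution $\propto e^{-|r-700|/20}$ centered at 700 Å, with a peak height of 0.7 relative to the Gaussian. The weights are defined on 500 radii between 400 and 800 Å and normalized to sum to one.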

In [None]:
# radius vector
r = torch.linspace(400., 800., 500)

# weights
gaussian = torch.exp(-(r - 500.) ** 2 / (2 * 10 ** 2))
boltzmann = .7 * torch.exp(-torch.abs(r - 700.) / 20)
w_true = gaussian + boltzmann
w_true /= w_true.sum()  # normalize

# plot
plt.figure(dpi=100)
plt.plot(r, w_true)
plt.xlabel(r'Radius, $r$ ($\AA$)')
plt.ylabel(r'Weights, $w$')
plt.show()
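The peak heights (1 vs. 0.7) understate how the number weight is actually split, because the Boltzmann peak is much wider than the Gaussian. This quick NumPy sketch, independent of `ffsas`, rebuilds both populations on the same grid and sums each one:

```python
import numpy as np

# Same grid and populations as the cell above, rebuilt in NumPy
r = np.linspace(400., 800., 500)
gaussian = np.exp(-(r - 500.) ** 2 / (2 * 10. ** 2))   # peak height 1, sigma = 10 A
boltzmann = .7 * np.exp(-np.abs(r - 700.) / 20.)       # peak height 0.7, decay length 20 A

# Share of the normalized number weight carried by each population
total = gaussian.sum() + boltzmann.sum()
frac_gaussian = gaussian.sum() / total
frac_boltzmann = boltzmann.sum() / total

# Continuous cross-check: Gaussian area = sqrt(2*pi)*sigma,
# two-sided exponential area = 2 * amplitude * decay length (truncation at 800 A ignored)
area_gaussian = np.sqrt(2 * np.pi) * 10.
area_boltzmann = 2 * .7 * 20.
frac_gaussian_cont = area_gaussian / (area_gaussian + area_boltzmann)

print(f'Gaussian share:  {frac_gaussian:.3f} (continuous estimate {frac_gaussian_cont:.3f})')
print(f'Boltzmann share: {frac_boltzmann:.3f}')
```

The Boltzmann population carries slightly more than half of the number weight. Its area, $2\times0.7\times20=28$, exceeds the Gaussian's $\sqrt{2\pi}\times10\approx25.1$.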

### Ground truth of intensity

Based on the radius distribution above, we now compute the ground truth of the intensity. First, we compute the Green's tensor `G` and use it to define the SAS system.

In [None]:
# q vector in logscale
q = 10 ** torch.linspace(-3, 0, 200)

# contrast, (SLD - SLD_solvent) ^ 2
drho = 1.

# compute the Green's tensor
G = Sphere.compute_G_mini_batch([q], {'r': r}, {'drho': drho}, log_screen=True)

# define the G-based SAS system
g_sys = SASGreensSystem(G, Sphere.get_par_keys_G(), log_screen=True)
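For intuition, here is a standalone NumPy sketch of the kind of forward model that `G` and `g_sys` encode for spheres. It assumes $I(q)=\xi\sum_i w_i G_i(q)+b$ with $G_i(q)=\left[\Delta\rho\,V_i\,\Phi(qr_i)\right]^2$ and the textbook sphere amplitude $\Phi(x)=3(\sin x-x\cos x)/x^3$. This is not the `ffsas` implementation: the helpers `sphere_G` and `intensity` are hypothetical, and options such as `log_screen` are not modeled.

```python
import numpy as np

def sphere_G(q, r, drho=1.):
    """Hypothetical helper: G[j, i] = (drho * V_i * Phi(q_j * r_i))^2 for spheres."""
    x = np.outer(q, r)                                  # shape (n_q, n_r)
    # normalized sphere amplitude, Phi -> 1 as x -> 0 (x >= 0.4 here, so no cancellation issue)
    phi = 3 * (np.sin(x) - x * np.cos(x)) / x ** 3
    V = 4. / 3. * np.pi * r ** 3                        # sphere volumes, broadcast along the r axis
    return (drho * V * phi) ** 2

def intensity(G, w, xi, b):
    """I(q) = xi * sum_i w_i * G_i(q) + b, which is linear in the weights w."""
    return xi * (G @ w) + b

# Same grids as this notebook; only the Gaussian population, to keep it short
q = 10 ** np.linspace(-3, 0, 200)
r = np.linspace(400., 800., 500)
w = np.exp(-(r - 500.) ** 2 / (2 * 10. ** 2))
w /= w.sum()

G = sphere_G(q, r)
V_ave = np.dot(4. / 3. * np.pi * r ** 3, w)
xi = 1e-4 * 2. / V_ave                                  # scale = 2, the value used for the ground truth
I = intensity(G, w, xi, b=.5)
```

Because $I$ is linear in $w$, fixing $\xi$ and $b$ turns the inversion into a linear least-squares problem in $w$, subject to $w\ge0$.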

Next, we need the ground truth of $\xi$ and $b$. Note that the parameter `scale` in [SASView/SASModels](http://www.sasview.org/docs/user/models/sphere.html) is not the same as $\xi$ in `ffsas`. For the particular unit system of [SASView/SASModels](http://www.sasview.org/docs/user/models/sphere.html), the relation between `scale` and $\xi$ is

$$\xi=10^{-4}\times\dfrac{\mathrm{scale}}{V_\text{ave}},$$

where $10^{-4}$ comes from the unit system and $V_\text{ave}=\sum_i w_i V_i$ is the number-weighted average particle volume. The `background` in [SASView/SASModels](http://www.sasview.org/docs/user/models/sphere.html) has the same definition as $b$ in `ffsas`.
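One way to see where the $10^{-4}$ comes from is a unit check. It assumes the SasView convention of SLDs in units of $10^{-6}\,\text{Å}^{-2}$, lengths in Å and intensity in $\mathrm{cm}^{-1}$, with $1\,\text{Å}=10^{-8}\,\mathrm{cm}$. Writing $[\cdot]$ for "units of":

$$
\left[\frac{\mathrm{scale}}{V_\text{ave}}\,(\Delta\rho\,V)^2\right]
=\text{Å}^{-3}\left(10^{-6}\,\text{Å}^{-2}\cdot\text{Å}^{3}\right)^2
=10^{-12}\,\text{Å}^{-1}
=10^{-12}\times10^{8}\,\mathrm{cm}^{-1}
=10^{-4}\,\mathrm{cm}^{-1}.
$$

So when $\Delta\rho$ and $V$ are entered as plain numbers in those units (as with `drho = 1.` above), the factor $10^{-4}$ returns the intensity in $\mathrm{cm}^{-1}$. This reading assumes that the $G$ computed by `ffsas` carries no unit factor of its own.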

In [None]:
# ground truth of scale and background
scale_true = 2.
b_true = .5

# compute the ground truth of xi
V = Sphere.compute_V({'r': r})
V_ave = torch.dot(V, w_true)
xi_true = 1e-4 * scale_true / V_ave

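Finally, we compute the ground truth of the intensity and assign it synthetic error bars. The intensity `I_true` itself stays noise-free; only its standard deviation `I_true_std` is generated randomly.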

In [None]:
# intensity
I_true = g_sys.compute_intensity({'r': w_true}, xi_true, b_true)

# relative error bars drawn uniformly from [20%, 30%) of the intensity
torch.random.manual_seed(0)
I_true_std = (torch.rand(len(q)) * .1 + .2) * I_true

# plot
plt.figure(dpi=100)
plt.errorbar(q, I_true, yerr=I_true_std, ecolor='gray')
plt.xscale('log')
plt.yscale('log')
plt.xlabel(r'Scattering vector, $q$ ($\AA^{-1}$)')
plt.ylabel(r'Intensity, $I$ ($\mathrm{cm}^{-1}$)')
plt.show()

# save intensity data
np.savetxt('observation.txt', 
           torch.stack([q, I_true, I_true_std]).t().numpy())
# McSAS uses nm^-1 for q (1 Å^-1 = 10 nm^-1)
np.savetxt('mcsas/observation_McSAS.txt', 
           torch.stack([q * 10, I_true, I_true_std]).t().numpy())

---

# Inversion

Now we invert for the radius distribution using `ffsas`, which takes a single call:

In [None]:
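# solve the inverse problem using the true intensity;
# xtol = gtol = 0 switches off those convergence tests, so the run is expected to stop at maxiter
result_dict = g_sys.solve_inverse(I_true, I_true_std, maxiter=200, verbose=1,
                                  trust_options={'xtol': 0., 'gtol': 0.})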
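`solve_inverse` hides the optimization behind one call. To illustrate the structure of the problem only, the sketch below fixes $\xi$ and $b$ and fits non-negative weights to error-weighted data by projected gradient descent. This is a swapped-in textbook method, not the algorithm `ffsas` uses. The function `fit_weights_pgd` and the toy kernel `A` are hypothetical.

```python
import numpy as np

def fit_weights_pgd(A, y, sigma, n_iter=20000):
    """Hypothetical helper: minimize sum(((A @ w - y) / sigma)**2) subject to w >= 0."""
    As = A / sigma[:, None]                    # error-weighted kernel
    ys = y / sigma                             # error-weighted data
    step = 1. / np.linalg.norm(As, 2) ** 2     # 1/L step for the half-gradient used below
    w = np.zeros(A.shape[1])
    for _ in range(n_iter):
        grad = As.T @ (As @ w - ys)            # half the gradient of chi^2
        w = np.maximum(w - step * grad, 0.)    # gradient step, then project onto w >= 0
    return w

# Toy problem: a random positive kernel standing in for xi * G (40 "q" x 6 "radii")
rng = np.random.default_rng(0)
A = rng.uniform(.1, 1., size=(40, 6))
w_ref = np.array([0., .5, 0., .3, .2, 0.])
y = A @ w_ref
sigma = .25 * y                                # 25% relative error bars

w_fit = fit_weights_pgd(A, y, sigma)
chi2 = np.sum(((A @ w_fit - y) / sigma) ** 2)
print(np.round(w_fit, 4), chi2)
```

The toy kernel is well conditioned, so the fit recovers `w_ref` almost exactly. For spheres on a fine radius grid, the columns of $G$ for neighboring radii are nearly collinear, so the real inversion is far less well conditioned than this example.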

---

# Visualizing results

Now we plot the results of `Irena`, `SasView`, `McSAS` and `ffsas`. Screenshots are provided in `./screenshots` for users to reproduce the results from `Irena`, `SasView` and `McSAS`.

### Radius distributions

In [None]:
# plot settings
alpha = .25
lw = 1.5
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

# each style is (color, unused placeholder, marker, markersize)
style_irena = colors[0], None, 's', 4
style_sasview = colors[1], None, 'D', 3
style_mcsas = colors[3], None, '+', 4
style_ffsas = colors[2], None, 'x', 2

fontsize = 17
plt.rcParams.update({'font.size': fontsize,
                     'legend.fontsize': fontsize,
                     'axes.titlesize': fontsize})

In [None]:
# figure
fig = plt.figure(dpi=200, figsize=(6, 3.5))
plt.xlabel(r'Radius, $r$ (\AA)')
plt.ylabel(r'Weights, $w$')
plt.title(r'(a) Comparing $w(r)$ by different methods')

# downsample to make the plot cleaner; the finer stride st1 covers indices 80-160 (about 464-528 Å, around the Gaussian peak)
st = 4
st1 = 3
wh = list(range(0, 80, st)) + list(range(80, 160, st1)) + list(range(160, len(r), st))

# truth
plot_data = plt.errorbar(r[wh], w_true[wh] * 100, yerr=None, 
                         c='k', fmt='-', zorder=-100, label='Truth')
plt.xlim(r[0], r[-1])

# Irena (outputs a volume distribution; dividing by r^3 converts it to number weights)
vol_frac = np.loadtxt('irena/output_wv(r).txt')[:len(r)]
w_irena = vol_frac / r ** 3
w_irena /= w_irena.sum()
plot_irena = plt.errorbar(r[wh], w_irena[wh] * 100, yerr=None, 
                          c=style_irena[0], fmt=style_irena[2],  
                          markersize=style_irena[3], label='Irena')

# SasView
# these numbers are from the screenshots
A_scale = 0.24676
A_radius = 500.02
B_scale = 0.75329
B_radius = 699.97
A_PD = 0.020124
B_PD = 0.028707
# absolute widths: SasView polydispersity (PD) is relative to the mean radius
A_sigma = A_radius * A_PD
B_sigma = B_radius * B_PD
# scale factors in SasView are volume-weighted; convert them to number weights
A_shape = torch.exp(-(r - A_radius) ** 2 / (2 * A_sigma ** 2))
B_shape = torch.exp(-torch.abs(r - B_radius) / B_sigma)
A_vol = (A_shape * r ** 3).sum()
B_vol = (B_shape * r ** 3).sum()
A_xi = A_scale / A_vol
B_xi = B_scale / B_vol
w_sasview = A_xi * A_shape + B_xi * B_shape
w_sasview /= w_sasview.sum()
plot_sasview = plt.errorbar(r[wh], w_sasview[wh] * 100, yerr=None, 
                            c=style_sasview[0], fmt=style_sasview[2],  
                            markersize=style_sasview[3], label='SasView')

# McSAS (radii are in m; multiplying by 1e10 converts them to Å)
mcsas_data = np.loadtxt('mcsas/output_w(r).dat', skiprows=1)
r_low = mcsas_data[:, 0] * 1e10
w_low = mcsas_data[:, 2]
std_w_low = mcsas_data[:, 3]
# interpolate the result back onto the original 500-point r grid,
# because McSAS outputs at most 200 bins
w_mcsas = interpolate.interp1d(r_low, w_low, fill_value="extrapolate")(r)
std_w_mcsas = interpolate.interp1d(r_low, std_w_low, fill_value="extrapolate")(r)
norm = w_mcsas.sum()
w_mcsas /= norm
std_w_mcsas /= norm
plot_mcsas = plt.errorbar(r[wh], w_mcsas[wh] * 100, yerr=std_w_mcsas[wh] * 100, 
                          c=style_mcsas[0], fmt=style_mcsas[2],  
                          markersize=style_mcsas[3], capsize=1, lw=.5, label='McSAS')

# ffsas
w_ffsas = result_dict['w_dict']['r']
plot_ffsas = plt.errorbar(r[wh], w_ffsas[wh] * 100, yerr=result_dict['std_w_dict']['r'][wh] * 100, 
                          c=style_ffsas[0], fmt=style_ffsas[2],  
                          markersize=style_ffsas[3], capsize=1, lw=.5, label='FFSAS')

# plot and save
plt.legend(handlelength=1, facecolor='whitesmoke')
if reproduce_paper_fig:
    plt.savefig(paper_fig_dir / 'bench_w.pdf', 
                bbox_inches='tight', facecolor='w', pad_inches=.1)
plt.show()
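The Irena and SasView branches above both convert volume-weighted quantities into number weights. Since a sphere's volume grows as $r^3$, number weights $w_i$ and volume fractions $v_i$ are related by $v_i\propto w_i r_i^3$. Here is a minimal NumPy sketch of the conversion in both directions (the helper names are hypothetical):

```python
import numpy as np

def number_to_volume(w, r):
    """Hypothetical helper: volume fractions v_i ∝ w_i * r_i^3, normalized to sum to 1."""
    v = w * r ** 3
    return v / v.sum()

def volume_to_number(v, r):
    """Hypothetical inverse: number weights w_i ∝ v_i / r_i^3, normalized to sum to 1."""
    w = v / r ** 3
    return w / w.sum()

# Ground-truth number weights, as defined at the top of the notebook
r = np.linspace(400., 800., 500)
w = np.exp(-(r - 500.) ** 2 / (2 * 10. ** 2)) + .7 * np.exp(-np.abs(r - 700.) / 20.)
w /= w.sum()

v = number_to_volume(w, r)
large = r > 600.                               # mask selecting the Boltzmann population
print(f'Boltzmann share: number {w[large].sum():.3f}, volume {v[large].sum():.3f}')
```

On the ground truth, the Boltzmann population carries about 53% of the number weight but about 75% of the volume. This is in line with the volume-weighted SasView scales `A_scale` ≈ 0.25 and `B_scale` ≈ 0.75 read from the screenshots.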

### Intensity curves

In [None]:
# figure
fig = plt.figure(dpi=200, figsize=(6, 3.5))
plt.xscale('log')
plt.yscale('log')
plt.xlabel(r'Scattering vector, $q$ (\AA$^{-1}$)')
plt.ylabel(r'Intensity, $I$ ($\mathrm{cm}^{-1}$)')
plt.title(r'(b) Comparing $I(q)$ by different methods')

# truth (error bars drawn at 3 standard deviations)
st = 2
plot_data = plt.errorbar(q[::st], I_true[::st], yerr=I_true_std[::st] * 3, capsize=1, lw=lw, elinewidth=.5,
                         c='k', fmt='-', zorder=-100, label='Truth')
plt.xlim(q[0], q[-1])

# Irena
I_irena = np.loadtxt('irena/output_I(q).txt')
I_irena += b_true  # Irena saves data without background
plot_irena = plt.errorbar(q[::st], I_irena[::st], yerr=None, c=style_irena[0], lw=0, 
                          fmt=style_irena[2], markersize=style_irena[3], label='Irena')

# SasView
I_sasview = np.loadtxt('sasview/output_I(q).txt', skiprows=1)[:, 1]
plot_sasview = plt.errorbar(q[::st], I_sasview[::st], yerr=None, c=style_sasview[0], lw=0, 
                            fmt=style_sasview[2], markersize=style_sasview[3], label='SasView')

# McSAS (q is in m^-1; dividing by 1e10 converts it to Å^-1)
mc_data = np.loadtxt('mcsas/output_I(q).dat', skiprows=1)
q_low = mc_data[:, 0] / 1e10
I_low = mc_data[:, 3]
I_mcsas = interpolate.interp1d(q_low, I_low, fill_value="extrapolate")(q)
plot_mcsas = plt.errorbar(q[::st], I_mcsas[::st], yerr=None, c=style_mcsas[0], lw=0, 
                          fmt=style_mcsas[2], markersize=style_mcsas[3], label='McSAS')

# ffsas
I_ffsas = result_dict['I']
plot_ffsas = plt.errorbar(q[::st], I_ffsas[::st], yerr=None, c=style_ffsas[0], lw=0, 
                          fmt=style_ffsas[2], markersize=style_ffsas[3], label='FFSAS')

# plot and save
plt.legend(handlelength=1, facecolor='whitesmoke')
if reproduce_paper_fig:
    plt.savefig(paper_fig_dir / 'bench_I.pdf', 
                bbox_inches='tight', facecolor='w', pad_inches=.1)
plt.show()

### Metrics

In [None]:
# chi2
def compute_chi2(I_fit):
    if not isinstance(I_fit, torch.Tensor):
        I_fit = torch.from_numpy(I_fit)
    return torch.norm((I_fit - I_true) / I_true_std) ** 2
    
print('Intensity chi2')
print(f'Irena: {compute_chi2(I_irena):.0e}')
print(f'SasView: {compute_chi2(I_sasview):.0e}')
print(f'McSAS: {compute_chi2(I_mcsas):.0e}')
print(f'FFSAS: {compute_chi2(I_ffsas):.0e}')
print()

# delta w
def compute_delta_w_norm(w_fit):
    if not isinstance(w_fit, torch.Tensor):
        w_fit = torch.from_numpy(w_fit)
    return torch.norm(w_fit - w_true)

def compute_delta_w_max(w_fit):
    if not isinstance(w_fit, torch.Tensor):
        w_fit = torch.from_numpy(w_fit)
    return torch.max(torch.abs(w_fit - w_true))


print('|delta w|')
print(f'Irena: {compute_delta_w_norm(w_irena):.0e}')
print(f'SasView: {compute_delta_w_norm(w_sasview):.0e}')
print(f'McSAS: {compute_delta_w_norm(w_mcsas):.0e}')
print(f'FFSAS: {compute_delta_w_norm(w_ffsas):.0e}')
print()

print('max(delta w)')
print(f'Irena: {compute_delta_w_max(w_irena):.0e}')
print(f'SasView: {compute_delta_w_max(w_sasview):.0e}')
print(f'McSAS: {compute_delta_w_max(w_mcsas):.0e}')
print(f'FFSAS: {compute_delta_w_max(w_ffsas):.0e}')
print()

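# wall-clock times of Irena, SasView and McSAS are hard-coded; the McSAS value is converted from minutes
print('Wall-clock time (s)')
print('Irena: 2')
print('SasView: 0.0979')
print(f'McSAS: {3.30376664797 * 60}')
print(f'FFSAS: {result_dict["wct"]}')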
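The three metrics above translate directly into NumPy (the helper names are hypothetical; the notebook's own versions use `torch`). The sketch also gives a scale for reading $\chi^2$. An exact fit gives $\chi^2=0$, while a fit that misses every point by exactly one error bar gives $\chi^2=N$, the number of $q$ points (here $N=200$).

```python
import numpy as np

def chi2(I_fit, I_obs, I_std):
    """Error-weighted misfit, sum_j ((I_fit_j - I_obs_j) / I_std_j)^2."""
    return float(np.sum(((np.asarray(I_fit) - I_obs) / I_std) ** 2))

def delta_w_norm(w_fit, w_ref):
    """Euclidean norm of the weight error, |w_fit - w_ref|."""
    return float(np.linalg.norm(np.asarray(w_fit) - w_ref))

def delta_w_max(w_fit, w_ref):
    """Largest absolute weight error, max_i |w_fit_i - w_ref_i|."""
    return float(np.max(np.abs(np.asarray(w_fit) - w_ref)))

# Toy data: 200 intensities with 20%-30% relative error bars, as in this notebook
rng = np.random.default_rng(0)
I_obs = 10 ** rng.uniform(0., 5., 200)
I_std = rng.uniform(.2, .3, 200) * I_obs

print(chi2(I_obs, I_obs, I_std))           # exact fit to noise-free data
print(chi2(I_obs + I_std, I_obs, I_std))   # every point off by one error bar
print(delta_w_norm([.5, .5], [1., 0.]), delta_w_max([.5, .5], [1., 0.]))
```

Since `I_true` is noise-free, a smaller $\chi^2$ simply means the fitted curve lies closer to the truth. With genuinely noisy data, a value near $N$ would instead indicate a fit consistent with the error bars.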

---