# Example: sphere

This notebook shows how to use `ffsas` to invert for the radius distribution of a `sphere` model. It uses the [SASView/SASModels](http://www.sasview.org/docs/user/models/sphere.html) unit system.

In [None]:
# avoid omp error on Mac
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# plotting setup
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams.update({'font.size': 12})
matplotlib.rcParams.update({'legend.fontsize': 11})
matplotlib.rcParams.update({'axes.titlesize': 12})
matplotlib.rcParams.update({'lines.linewidth': 2})
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

# create output dir
from pathlib import Path
output_dir = Path('./output/sphere')
output_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# uncomment this line to install ffsas
# !pip install ffsas

# ffsas
import torch
from ffsas.models import Sphere
from ffsas.system import SASGreensSystem

# math tools
from scipy import interpolate

# Ground truth

The inversion needs intensity data as input. Here we generate synthetic data by forward modelling. The input to that model is a noisy radius distribution, which serves as the ground truth.

### Ground truth of radius distribution

The following function creates a "crazy" distribution by adding up Gaussians and random noise.

In [None]:
def crazy_distribution(x, gaussians, noise_level, fade_start, fade_end, seed=0):
    # start from all-zero weights
    w_true = torch.zeros(x.shape)
    
    # add Gaussians
    for factor, mean, stddev in gaussians:
        w_true += factor * torch.exp(-((x - mean) / stddev) ** 2)
    
    # add noise
    torch.random.manual_seed(seed)
    w_true += noise_level * torch.rand(x.shape) * torch.rand(x.shape)
    
    # fade both ends to make it look nicer
    w_true[0:fade_start] = 0.
    w_true[fade_start:fade_end] *= torch.linspace(0, 1, fade_end - fade_start)
    w_true[-fade_start:] = 0.
    w_true[-fade_end:-fade_start] *= torch.linspace(1, 0, fade_end - fade_start)
    
    # normalize to 1
    w_true /= torch.sum(w_true)
    return w_true

Make a radius distribution:

In [None]:
# radius vector
r = torch.linspace(500., 1000., 300)

# make a crazy radius distribution from six Gaussians plus random noise
w_true = crazy_distribution(r, [(4, 580, 10), (6, 630, 20), (10, 700, 20), 
                                (12, 750, 20), (8, 850, 15), (5, 930, 15)],
                            noise_level=10, fade_start=10, fade_end=40)

# upsample weights to a higher resolution for later use
high_reso = 400
r_high = torch.linspace(r[0], r[-1], high_reso)
w_true_high = torch.tensor(interpolate.interp1d(r, w_true)(r_high))
w_true_high /= torch.sum(w_true_high)
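
The upsampling step above does two things: piecewise-linear interpolation, then renormalization so the weights sum to 1. The sketch below is dependency-free and makes both steps explicit. The helpers `linear_interp`, `linspace` and `upsample_weights` are hypothetical and are not part of `ffsas`, SciPy or PyTorch. The sketch assumes the behaviour of the default linear `scipy.interpolate.interp1d`, which raises on points outside the data range.

```python
from bisect import bisect_right

def linear_interp(x_nodes, y_nodes, x):
    """Piecewise-linear interpolation at x; x_nodes must be increasing, with at least 2 nodes."""
    if not x_nodes[0] <= x <= x_nodes[-1]:
        raise ValueError("x is outside the interpolation range")
    # j is the right end of the interval [x_nodes[j - 1], x_nodes[j]] containing x
    j = min(bisect_right(x_nodes, x), len(x_nodes) - 1)
    x0, x1 = x_nodes[j - 1], x_nodes[j]
    t = (x - x0) / (x1 - x0)
    return y_nodes[j - 1] + t * (y_nodes[j] - y_nodes[j - 1])

def linspace(start, stop, n):
    """n evenly spaced points; the last point is set to stop exactly to avoid round-off overshoot."""
    step = (stop - start) / (n - 1)
    return [start + i * step for i in range(n - 1)] + [stop]

def upsample_weights(r_nodes, w_nodes, n_high):
    """Resample weights onto a finer grid over the same interval and renormalize them to sum to 1."""
    r_high = linspace(r_nodes[0], r_nodes[-1], n_high)
    w_high = [linear_interp(r_nodes, w_nodes, x) for x in r_high]
    total = sum(w_high)
    return r_high, [w / total for w in w_high]

r_fine, w_fine = upsample_weights([0.0, 1.0, 2.0], [0.0, 1.0, 0.0], 5)
print(w_fine)  # [0.0, 0.25, 0.5, 0.25, 0.0]
```

Renormalizing after interpolation matters because interpolated values at the finer grid generally no longer sum to 1.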

### Ground truth of intensity

Now, based on the above radius distribution, we compute the ground truth of intensity. 

Note that the parameter `scale` in [SASView/SASModels](http://www.sasview.org/docs/user/models/sphere.html) is not the $\xi$ in `ffsas`. For the particular unit system of [SASView/SASModels](http://www.sasview.org/docs/user/models/sphere.html), the relation between `scale` and $\xi$ is

$$\xi=10^{-4}\times\dfrac{\mathrm{scale}}{V_\text{ave}},$$

where $10^{-4}$ comes from the unit system and $V_\text{ave}$ is the average volume. The `background` in [SASView/SASModels](http://www.sasview.org/docs/user/models/sphere.html) has the same definition as $b$ in `ffsas`.
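
As a worked example of this relation, the sketch below does three things:

- It computes sphere volumes $V=\frac{4}{3}\pi r^3$.
- It forms the weighted average $V_\text{ave}=\sum_i w_i V(r_i)$ with weights normalized to 1, which is what `torch.dot(V, w_true)` does in the next cell.
- It converts between `scale` and $\xi$ in both directions.

We assume `Sphere.compute_V` returns the same sphere volumes. The helper names are hypothetical.

```python
import math

def sphere_volume(r):
    """Volume of a sphere of radius r (Å^3 when r is in Å)."""
    return 4.0 / 3.0 * math.pi * r ** 3

def average_volume(radii, weights):
    """V_ave = sum_i w_i V(r_i), assuming the weights are normalized to sum to 1."""
    return sum(w * sphere_volume(r) for r, w in zip(radii, weights))

def scale_to_xi(scale, v_ave):
    """SASView/SASModels scale -> ffsas xi, using the 1e-4 unit factor."""
    return 1e-4 * scale / v_ave

def xi_to_scale(xi, v_ave):
    """Inverse relation: recover scale from an inverted xi."""
    return xi * v_ave / 1e-4

# two radii with equal weights
v_ave = average_volume([500.0, 1000.0], [0.5, 0.5])
xi = scale_to_xi(2.5, v_ave)
print(v_ave, xi, xi_to_scale(xi, v_ave))
```

The inverse `xi_to_scale` is the conversion applied to the inverted $\xi$ in the inversion loop.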

In [None]:
# ground truth of scale and background
scale_true = 2.5
b_true = .14

# q vector
q = torch.linspace(.001, 1, 2000)

# (SLD - SLD_solvent) ^ 2
drho = 25.

# compute the Green's function
G = Sphere.compute_G_mini_batch([q], {'r': r}, {'drho': drho})

# compute the ground truth of xi
V = Sphere.compute_V({'r': r})
V_ave = torch.dot(V, w_true)
xi_true = 1e-4 * scale_true / V_ave

# define the G-based SAS system
g_sys = SASGreensSystem(G, Sphere.get_par_keys_G(), log_screen=True)

# finally compute the ground truth of intensity
I_true = g_sys.compute_intensity({'r': w_true}, xi_true, b_true)
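
The sketch below illustrates what `compute_intensity` evaluates, in pure Python. It assumes the following form:

$$I(q)=\xi\sum_i w_i\,G(q,r_i)+b, \qquad G(q,r)=\mathrm{drho}\,V(r)^2\,\Phi(qr)^2,$$

where:

- $\Phi(x)=3(\sin x-x\cos x)/x^3$ is the standard sphere form-factor amplitude.
- `drho` is the squared SLD contrast, as commented in the cell above.

We have not checked this against the internals of `ffsas`. Treat it as an illustration, not a replacement for `Sphere.compute_G_mini_batch`. All function names are hypothetical.

```python
import math

def sphere_amplitude(x):
    """Normalized sphere amplitude Phi(x) = 3 (sin x - x cos x) / x^3, with Phi(0) = 1."""
    if abs(x) < 1e-3:
        # second-order Taylor expansion avoids cancellation near x = 0
        return 1.0 - x * x / 10.0
    return 3.0 * (math.sin(x) - x * math.cos(x)) / x ** 3

def sphere_volume(r):
    return 4.0 / 3.0 * math.pi * r ** 3

def greens_function(q, r, drho):
    # assumed form: G(q, r) = drho * V(r)^2 * Phi(q r)^2, with drho the squared SLD contrast
    return drho * sphere_volume(r) ** 2 * sphere_amplitude(q * r) ** 2

def intensity(q, radii, weights, xi, b, drho):
    # I(q) = xi * sum_i w_i G(q, r_i) + b
    return xi * sum(w * greens_function(q, r, drho) for r, w in zip(radii, weights)) + b

# tiny example: two radii, equal weights, SASView-style scale converted to xi
radii, weights = [500.0, 1000.0], [0.5, 0.5]
v_ave = sum(w * sphere_volume(r) for r, w in zip(radii, weights))
xi = 1e-4 * 2.5 / v_ave
print([intensity(q, radii, weights, xi, 0.14, 25.0) for q in (0.001, 0.01, 0.1)])
```

The weights enter linearly, through $\sum_i w_i G(q,r_i)$. This is why the problem can be set up once `G` is precomputed.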

Plot the ground truths:

In [None]:
fig, ax = plt.subplots(dpi=200, ncols=2, figsize=(15, 4))
ax[0].plot(r, w_true)
ax[0].set_xlabel(r'Radius, $r$ ($\AA$)')
ax[0].set_ylabel(r'Weights, $w$')
ax[1].plot(q, I_true)
ax[1].set_xscale('log')
ax[1].set_yscale('log')
ax[1].set_xlabel(r'Scattering vector, $q$ ($\AA^{-1}$)')
ax[1].set_ylabel(r'Intensity, $I$ ($\mathrm{cm}^{-1}$)')
plt.savefig(f'{output_dir}/truth.png', bbox_inches='tight', facecolor='w')
plt.show()

---

# Inversion

Now we solve the inverse problem at several radius resolutions. We expect to reproduce the ground truth "exactly" at resolution 300, the resolution used to create it (`r = torch.linspace(500., 1000., 300)`).
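
Only resolution 300 is expected to be exact because the inverted weights live on the grid `torch.linspace(500., 1000., reso)`. Only when that grid coincides with the ground-truth grid can the true distribution be represented without interpolation error, in general. The hypothetical check below counts the nodes each grid shares with the 300-point truth grid, using exact rational arithmetic. Resolutions 50, 100, 200 and 400 share only the two endpoints; resolution 300 shares all 300 nodes.

```python
from fractions import Fraction

def grid_nodes(n, start=500, stop=1000):
    """Exact nodes of linspace(start, stop, n), stored as Fractions to avoid rounding."""
    step = Fraction(stop - start, n - 1)
    return {start + i * step for i in range(n)}

def shared_nodes(n_truth, n_inv):
    """Number of grid points common to two equally spaced grids over the same interval."""
    return len(grid_nodes(n_truth) & grid_nodes(n_inv))

for reso in [50, 100, 200, 300, 400]:
    print(reso, shared_nodes(300, reso))
```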

In [None]:
# solve at five resolutions
resolutions = [50, 100, 200, 300, 400]
results = []

# use nu_sigma = 0.8 rather than 1.0 to avoid some drift at the lowest resolution (50)
nu_sigma = 0.8

# loop over resolutions
for reso in resolutions:
    # resampled r vector at given resolution
    r_reso = torch.linspace(r[0], r[-1], reso)
    
    # recompute G
    G_reso = Sphere.compute_G_mini_batch([q], {'r': r_reso}, {'drho': drho}, log_screen=False)
    
    # define the system
    g_sys_reso = SASGreensSystem(G_reso, par_keys=Sphere.get_par_keys_G(), 
                                 log_file=output_dir / f'resolution{reso}.log', 
                                 log_screen=False)  # log only to files, not displayed on screen
    
    # solve the inverse problem using the "true" intensity as data
    # NOTE: there is no data uncertainty (sigma), so we pass the intensity mu itself as sigma
    result_dict = g_sys_reso.solve_inverse(I_true, I_true, nu_sigma=nu_sigma,
                                           auto_scaling=True, maxiter=1000, verbose=2)
    
    # get weights from result dict
    w_reso = result_dict['w_dict']['r']
    sens_w_reso = result_dict['sens_w_dict']['r']
    
    # convert xi back to scale, approximating V_ave by its ground-truth value
    scale = result_dict['xi'] * V_ave / 1e-4

    # upsample w to high resolution so we can plot them together
    w_high = torch.tensor(interpolate.interp1d(r_reso, w_reso)(r_high))
    w_high /= torch.sum(w_high)
    sens_w_high = torch.tensor(interpolate.interp1d(r_reso, sens_w_reso)(r_high))
    sens_w_high /= torch.sum(sens_w_high)
    results.append((reso, w_reso, w_high, sens_w_high, scale, result_dict['xi'], result_dict['b'], 
                    result_dict['I'], result_dict['wct']))
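
The NOTE in the cell above passes the intensity itself as σ. The plotting cell computes $\chi^2$ by dividing each residual by `I_true ** nu_sigma`, which suggests the effective uncertainty is $\sigma_i=\mu_i^{\nu_\sigma}$. Two settings compare as follows:

- With $\nu_\sigma=1$, the residuals are purely relative.
- With $\nu_\sigma=0.8$, high-intensity (low-$q$) points carry more weight.

The pure-Python sketch below implements that metric. `weighted_chi2` is a hypothetical helper, not an `ffsas` function.

```python
def weighted_chi2(mu, fitted, nu_sigma):
    """Mean squared residual with sigma_i = mu_i ** nu_sigma, as in the plotting cell's chi^2."""
    if len(mu) != len(fitted):
        raise ValueError("mu and fitted must have the same length")
    total = 0.0
    for m, f in zip(mu, fitted):
        total += ((m - f) / m ** nu_sigma) ** 2
    return total / len(mu)

# both points are fitted to within 1%
mu, fitted = [100.0, 1.0], [99.0, 0.99]
print(weighted_chi2(mu, fitted, 1.0), weighted_chi2(mu, fitted, 0.8))
```

With $\nu_\sigma=1$, both points contribute equally. With $\nu_\sigma=0.8$, the residual of the larger-intensity point dominates.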

Finally, plot the results at the different resolutions:

In [None]:
# scaling factor to plot 1/sensitivity as uncertainty 
scaling_factor_sensitivity = 3e-6

fig, ax = plt.subplots(nrows=len(results), ncols=2, dpi=200, figsize=(16, len(results) * 3))
plt.subplots_adjust(hspace=.3)
if len(results) == 1:
    ax = [ax]

for i, (reso, w_reso, w_high, sens_w_high, scale, xi, b, I, wct) in enumerate(results):
    #######################
    # radius distribution #
    #######################
    # truth
    ax[i][0].scatter(r_high, w_true_high, lw=0, c=colors[0], s=30, label=r'Truth')
    # inverted
    ax[i][0].plot(r_high, w_high, c=colors[1], label=r'MLE')
    dw_norm = torch.norm(w_true_high - w_high) ** 2 / len(w_high)
    # 1/sensitivity
    sens_scaled = scaling_factor_sensitivity / sens_w_high
    ax[i][0].fill_between(r_high, w_high - abs(sens_scaled), w_high + abs(sens_scaled), 
                          alpha=.3, color='gray', zorder=-100, label=r'Sens${}^{-1}$')
    ax[i][0].set_title(r'Res.=%d: scale=%.2f, $b$=%.3f, $|\Delta w|^2$=%.2E' % (reso, scale, b, dw_norm))
    ax[i][0].set_ylim(-0.001, 0.01)
    ax[i][0].set_ylabel(r'Weights, $w$')
    if i != len(results) - 1:
        ax[i][0].set_xticklabels([])
    else:
        ax[i][0].set_xlabel(r'Radius, $r$ ($\AA$)')
    
    #####################
    # intensity fitting #
    #####################
    chi2 = torch.norm((I_true - I) / I_true ** nu_sigma) ** 2 / len(I_true)
    # truth
    ax[i][1].scatter(q, I_true, lw=0, c=colors[2], s=30, label=r'Truth')
    # fitted
    ax[i][1].plot(q, I, c=colors[3], label=r'Fitted')
    ax[i][1].set_title(r'Res.=%d: $\chi^2$=%.2E, wct=%.1f sec' % (reso, chi2, wct))
    ax[i][1].set_xscale('log')
    ax[i][1].set_yscale('log')
    ax[i][1].set_ylabel(r'Intensity, $I$ ($\mathrm{cm}^{-1}$)')
    if i != len(results) - 1:
        ax[i][1].set_xticklabels([])
    else:
        ax[i][1].set_xlabel(r'Scattering vector, $q$ ($\AA^{-1}$)')
    
    # add some text
    if reso == 300:
        ax[i][0].text(.97, .92, 'Exact', transform=ax[i][0].transAxes, 
                     ha='right', va='top', fontsize=18, color='k')
        ax[i][1].text(.97, .92, 'Exact', transform=ax[i][1].transAxes, 
                     ha='right', va='top', fontsize=18, color='k')

order = [1, 0, 2]
handles, labels = ax[0][0].get_legend_handles_labels()
ax[0][0].legend([handles[idx] for idx in order], [labels[idx] for idx in order])
handles, labels = ax[0][1].get_legend_handles_labels()
order = [1, 0]
ax[0][1].legend([handles[idx] for idx in order], [labels[idx] for idx in order])
plt.savefig(output_dir / 'results.png', bbox_inches='tight', facecolor='w')
plt.show()

---