# Example: sphere with I22 dataset

This notebook shows how to use `ffsas` to invert for the radius distribution of a `Sphere` model from a real SAXS dataset called "I22". We also compare the results with those from `Irena`, `SasView` and `McSAS`.

In [None]:
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from scipy import interpolate
import torch

import ffsas
from ffsas.models import Sphere
from ffsas.system import SASGreensSystem

# avoid an OMP error on MacOS (nothing to do with ffsas)
# NOTE(review): this is set after numpy/torch have already been imported above;
# if the duplicate-OpenMP-runtime error still shows up, try moving these two
# lines above those imports -- TODO confirm.
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

In [None]:
# Switch for paper-quality output: when True, text is typeset through LaTeX
# and the directory that receives the saved figures is created.
reproduce_paper_fig = True
if reproduce_paper_fig:
    # LaTeX must be installed, otherwise rendering will raise an error
    plt.rcParams["text.usetex"] = True
    plt.rcParams["text.latex.preamble"] = r'\usepackage{bm,upgreek}'
    plt.rcParams["font.family"] = "sans-serif"
    plt.rcParams["font.serif"] = ["Times"]
    # output directory for the paper figures (created if missing)
    paper_fig_dir = Path('../paper_figs')
    paper_fig_dir.mkdir(parents=True, exist_ok=True)

# Read data

Data are stored in the text file `observation_corrected.txt`, with the three columns being $q$, mean and standard deviation of the observed intensity.

In [None]:
# read the corrected observation; the three columns are q, mean and
# standard deviation of the observed intensity
fname = 'observation_corrected.txt'  # plain string: no placeholders, so no f-prefix
data = np.loadtxt(fname)

# q vector, intensity mean, intensity stddev (columns 0, 1, 2)
q, mu, sigma = (torch.tensor(data[:, col], dtype=ffsas.torch_dtype)
                for col in range(3))

# McSAS uses nm^-1 for q, hence the factor 10 on q (plotted in \AA^-1 in this
# notebook); the exported file has one row per q with columns q, mu, sigma
np.savetxt('mcsas/observation_corrected_McSAS.txt',
           torch.stack([q * 10, mu, sigma]).t().numpy())

The above data for intensity mean are not the raw data. We have applied a correction to the raw data, which is a linear transform in the log scale. Let's visualize it:

In [None]:
# read raw data (second column is used as the uncorrected intensity mean)
mu_raw = np.loadtxt('observation_raw.txt')[:, 1]

# overlay the raw and the corrected intensity mean on log-log axes
plt.figure(dpi=150, figsize=(6, 3))
plt.plot(q, mu_raw, label='Mean of $I$, raw')
plt.plot(q, mu, label='Mean of $I$, corrected')
plt.xscale('log')
plt.yscale('log')
# raw strings: '\A' and '\m' are not valid Python escape sequences, so the
# non-raw form emits a SyntaxWarning on recent Python versions; the resulting
# label text is unchanged
plt.xlabel(r'Scattering vector, $q$ (\AA$^{-1}$)')
plt.ylabel(r'Intensity, $I$ ($\mathrm{cm}^{-1}$)')
plt.legend()
plt.show()

# Inversion

The inversion takes just a few lines:

In [None]:
# radius grid: 1000 evenly spaced samples covering [400, 1200]
r = torch.linspace(start=400, end=1200, steps=1000)

# Green's tensor of the sphere model for this q vector and radius grid,
# with the constant parameter drho set to 1
G = Sphere.compute_G_mini_batch([q], {'r': r}, {'drho': 1.})

# wrap the Green's tensor into a G-system
g_sys = SASGreensSystem(G, Sphere.get_par_keys_G())

# inversion: 500 iterations, keeping an intermediate result every 100
# NOTE(review): xtol = gtol = 0 presumably disables tolerance-based early
# stopping so that all iterations are run -- confirm against ffsas docs
solver_kwargs = dict(maxiter=500, save_iter=100, verbose=1,
                     trust_options={'xtol': 0, 'gtol': 0})
results = g_sys.solve_inverse(mu, sigma, **solver_kwargs)

---

# Plot results

First, plot the `ffsas` results at different iterations.

In [None]:
# volume (up to a constant factor, which cancels in the normalization below)
v = r ** 3

# colormap
# `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
# `plt.get_cmap` returns the same registered colormap on old and new versions
cmap = plt.get_cmap('turbo_r')

# steps to plot (every 100 iters)
n_results = len(results['saved_res'])
plot_steps = range(n_results)
# avoid a zero denominator in the colour fraction when only one result is saved
color_denom = max(n_results - 1, 1)

# plot one curve per saved result; later results are drawn further back (zorder=-j)
fig = plt.figure(dpi=200, figsize=(7/1.4, 3.5/1.4))
for j, step in enumerate(plot_steps):
    saved = results['saved_res'][step]
    w = saved['w_dict']['r']
    w_hat = w * v / (w * v).sum() * 100  # x100 to percent
    plt.plot(r, w_hat, zorder=-j, c=cmap(step / color_denom),
             label=r'$w(r)$, iters=%d, wct=%.1f sec' %
             (saved['nit'], saved['wct']))

plt.xlim(r.min(), r.max())
plt.ylabel(r'Volume weight, $\hat{w}$ (\%)')
plt.xlabel(r'Radius, $r$ (\AA)')
plt.title(r'Convergence of $\hat{w}(r)$ in FFSAS')


# Gaussian approximations of populations
# From here on `w_hat` is the curve of the LAST saved result (left over from
# the loop above); each radius range below is treated as one population.
r_ranges=[[420, 800],
          [800, 1180]]

area_all_ffsas = []
for r_min, r_max in r_ranges:
    # find peak: grid indices closest to the range bounds, then the argmax inside
    i_min = int(torch.argmin(torch.abs(r - r_min)))
    i_max = int(torch.argmin(torch.abs(r - r_max)))
    i_peak = i_min + int(torch.argmax(w_hat[i_min:i_max]))
    r_top = r[i_peak]

    # find stddev: widen a window around the peak one sample per side until it
    # holds at least 68% of the population's area (gives up at 49 samples)
    area_all = torch.sum(w_hat[i_min:i_max])
    area_all_ffsas.append(area_all)
    for stddev in range(1, 50):
        # clamp the left edge: a negative start would wrap around silently
        lo = max(i_peak - stddev, 0)
        area = torch.sum(w_hat[lo:i_peak + stddev])
        if area >= area_all * .68:
            break

    # texts: half-width converted from samples to radius units
    sig = round(stddev / len(r) * (r.max() - r.min()).item())
    if r_max == 800:
        plt.text(r_top - 1, w_hat[i_peak] - 8,
                 r'$\mathcal{N}(%d,%d^2)$' % (round(r_top.item()), sig),
                 ha='left', va='center', fontsize=12, rotation=45)
    else:
        plt.text(r_top - 20, max(w_hat[i_peak] - 6, .7),
                 r'$\mathcal{N}(%d,%d^2)$' % (round(r_top.item()), sig),
                 ha='left', va='bottom', fontsize=12, rotation=45)
plt.ylim(None, 40)

# legend: a horizontal colorbar inset standing in for per-curve legend entries
# NOTE(review): ticks/boundaries are hard-coded for 100..1000 iterations, while
# the inversion cell in this notebook runs maxiter=500 and the curves are
# coloured by step / (n_results - 1) -- confirm the colorbar matches the run.
norm = matplotlib.colors.Normalize(vmin=0, vmax=1000)
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])
cbaxes = fig.add_axes([0.18, 0.73, 0.45, .025])
cb = plt.colorbar(sm, ticks=np.arange(100,1001,100),
                  boundaries=np.arange(50,1101,100),
                  cax=cbaxes, orientation='horizontal')
cb.ax.tick_params(labelsize=12)
cb.ax.set_title('Num. iterations', fontsize=12)
cb.ax.tick_params(axis='x', labelrotation = 45)
plt.show()

# volume ratio of the two populations
v1 = area_all_ffsas[0] / (area_all_ffsas[0] + area_all_ffsas[1])
print(f'FFSAS v1/v2 = {v1:.2f}:{1 - v1:.2f}')

Now compare with the `Irena`, `SasView` and `McSAS` results.

In [None]:
def _gaussian_approx(w_hat_dist, r_grid, r_min, r_max):
    """Summarize the population of a weight curve inside [r_min, r_max].

    The peak is the argmax of `w_hat_dist` between the grid points closest to
    `r_min` and `r_max`. The width is found by widening a window around the
    peak, one sample per side, until it contains at least 68% of the area in
    that range (the search stops at 49 samples).

    Parameters:
        w_hat_dist: 1-D tensor of weights sampled on `r_grid`.
        r_grid: 1-D tensor of radii.
        r_min, r_max: bounds of the radius range holding one population.

    Returns:
        (r_top, sig, area_all): radius at the peak (0-dim tensor), rounded
        half-width in radius units (int), and total area in the range.
    """
    i_min = int(torch.argmin(torch.abs(r_grid - r_min)))
    i_max = int(torch.argmin(torch.abs(r_grid - r_max)))
    i_peak = i_min + int(torch.argmax(w_hat_dist[i_min:i_max]))

    area_all = torch.sum(w_hat_dist[i_min:i_max])
    for half_width in range(1, 50):
        # clamp the left edge: a negative start would wrap around silently
        lo = max(i_peak - half_width, 0)
        area = torch.sum(w_hat_dist[lo:i_peak + half_width])
        if area >= area_all * .68:
            break
    # convert the half-width from samples to radius units
    sig = round(half_width / len(r_grid) * (r_grid.max() - r_grid.min()).item())
    return r_grid[i_peak], sig, area_all


#### Irena ####
# NOTE(review): the [250:1250] slice presumably picks the 1000 Irena bins that
# line up with the radius grid `r` -- confirm against the Irena output file
w_hat_irena = torch.tensor(np.loadtxt('irena/output_wv(r).txt')[250:1250])
w_hat_irena /= w_hat_irena.sum()

area_all_irena = []
for r_min, r_max in r_ranges:
    r_top, sig, area_all = _gaussian_approx(w_hat_irena, r, r_min, r_max)
    area_all_irena.append(area_all)
    print(f'Irena N({r_top}, {sig}^2)')

# volume ratio of the two populations
v1 = area_all_irena[0] / (area_all_irena[0] + area_all_irena[1])
print(f'Irena v1/v2 = {v1:.2f}:{1 - v1:.2f}')

#### SASView ####
# fitted parameters, transcribed from screenshots
r_mean1, PD1, scale1 = 614.15, 0.0028012, 0.002255
r_mean2, PD2, scale2 = 1058.6, 2.8571e-08, 0.0015457
scale_sasview, b_sasview = 1, -0.0045602
# standard deviation computed here as mean radius times PD
sigm1 = r_mean1 * PD1
sigm2 = r_mean2 * PD2
print(f'SasView N({r_mean1}, {sigm1}^2)')
print(f'SasView N({r_mean2}, {sigm2}^2)')
v1 = scale1 / (scale1 + scale2)
print(f'SasView v1/v2 = {v1:.2f}:{1 - v1:.2f}')


#### MCSAS ####
# read (kept under its own name so the observation array `data` is not clobbered)
mcsas_out = np.loadtxt('mcsas/output_w(r).dat', skiprows=1)
w_mcsas = mcsas_out[:, 2]

# convert to the same resolution
# NOTE(review): assumes the McSAS bins are uniformly spaced over [400, 1200] -- confirm
x = np.linspace(400, 1200, len(w_mcsas))
w_mcsas = torch.from_numpy(interpolate.interp1d(x, w_mcsas)(r))
# weight by r^3 and normalize, matching the volume weighting used for FFSAS
w_mcsas_hat = w_mcsas * r ** 3 / (w_mcsas * r ** 3).sum()

# Gaussian approximation
area_all_mcsas = []
for r_min, r_max in r_ranges:
    r_top, sig, area_all = _gaussian_approx(w_mcsas_hat, r, r_min, r_max)
    area_all_mcsas.append(area_all)
    print(f'McSAS N({r_top}, {sig}^2)')

# volume ratio of the two populations
v1 = area_all_mcsas[0] / (area_all_mcsas[0] + area_all_mcsas[1])
print(f'McSAS v1/v2 = {v1:.2f}:{1 - v1:.2f}')

Finally, we compare the intensity fit. 

In [None]:
# one font size for body text, legend entries and axes titles
plt.rcParams.update({
    'font.size': 12.5,
    'legend.fontsize': 12.5,
    'axes.titlesize': 12.5,
})

# (colour, line style) per method, colours taken from the active colour cycle;
# the SasView style is a dash pattern rather than a named line style
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
style_irena = (colors[0], '-.')
style_sasview = (colors[1], [1, 1])
style_mcsas = (colors[3], '-')
style_ffsas = (colors[2], '--')

In [None]:
# Intensity fits of the four methods against the data.
# Each method is plotted twice: once visibly, and once as a 'whitesmoke' line
# far behind everything else. The second line appears to exist only to carry
# an extra legend row (the legend face colour is also 'whitesmoke').
# NOTE(review): the Gaussian parameters in those legend rows are hard-coded
# text, not computed here -- keep them in sync with the printed values.
fig = plt.figure(dpi=200, figsize=(7.5 / 1.4, 4.1 / 1.4))

# data
plt.plot(q, mu_raw, c='gray', zorder=-200, alpha=.5)
plt.errorbar(q, mu, yerr=sigma, c='pink', ecolor='skyblue', lw=1, fmt='o',
             markersize=3, label=r'Corrected data', zorder=-300)

# Irena
# the Irena output may be shorter than q, so data are truncated to its length
I_irena = np.loadtxt('irena/output_I(q).txt')
chi2 = np.linalg.norm((mu[:len(I_irena)] - I_irena) / sigma[:len(I_irena)]) ** 2
plt.plot(q[:len(I_irena)], I_irena, c=style_irena[0], ls=style_irena[1], lw=1.5, zorder=0,
         label=r'Irena (IPG/TNNLS): $\chi^2$=%.2f (no high-$q$),' % (chi2))
plt.plot(q[:len(I_irena)], I_irena, c='whitesmoke', ls=style_irena[1], lw=1.5, zorder=-10000000,
         label=r'$\hat{w}\approx60\%\mathcal{N}(620, 39^2)+40\%\mathcal{N}(1066, 38^2)$')

# SASView
# chi2 uses the fitted intensity as read; the plotted curve has b_sasview removed
I_sasview = np.loadtxt('sasview/output_I(q).txt', skiprows=1)[:, 1]
I = I_sasview - b_sasview
chi2 = np.linalg.norm((mu - I_sasview) / sigma) ** 2
plt.plot(q, I, c=style_sasview[0], dashes=style_sasview[1], lw=1.5, zorder=0,
         label=r'SasView: $\chi^2$=%.2f,' % (chi2))
plt.plot(q, I, c='whitesmoke', dashes=style_sasview[1], lw=1.5, zorder=-10000000,
         label=r'$\hat{w}=59\%\mathcal{N}(614, 2^2)+41\%\mathcal{N}(1059, 0^2)$')

# McSAS
# the McSAS q column is divided by 1e10 and the fit is interpolated
# (extrapolating where needed) onto the q vector used in this notebook
mc_data = np.loadtxt('mcsas/output_I(q).dat', skiprows=1)
q_low = mc_data[:, 0] / 1e10
I_low = mc_data[:, 3]
I_mcsas = interpolate.interp1d(q_low, I_low, fill_value="extrapolate")(q)
chi2 = np.linalg.norm((mu - I_mcsas) / sigma) ** 2
b_mcsas = -0.00487
I = I_mcsas - b_mcsas
plt.plot(q, I, c=style_mcsas[0], ls=style_mcsas[1], lw=1.5, zorder=0,
         label=r'McSAS: $\chi^2$=%.2f,' % (chi2))
plt.plot(q, I, c='whitesmoke', ls=style_mcsas[1], lw=1.5, zorder=-10000000,
         label=r'$\hat{w}\approx58\%\mathcal{N}(629, 20^2)+42\%\mathcal{N}(1067, 15^2)$')

# FFSAS
I_ffsas = results['I']
I = I_ffsas - results['b']
chi2 = np.linalg.norm((mu - I_ffsas) / sigma) ** 2
plt.plot(q, I, c=style_ffsas[0], ls=style_ffsas[1], lw=1.5, zorder=0,
         label=r'FFSAS: $\chi^2$=%.2f,' % (chi2))
plt.plot(q, I, c='whitesmoke', ls=style_ffsas[1], lw=1.5, zorder=-10000000,
         label=r'$\hat{w}\approx58\%\mathcal{N}(630, 1^2)+42\%\mathcal{N}(1060, 1^2)$')


plt.xlim(q[0] / 1.1, q[-50])
plt.ylim(.5e-4, 10**3.5)
plt.xscale('log')
plt.yscale('log')
# raw strings: '\A' and '\m' are not valid Python escape sequences, so the
# non-raw form emits a SyntaxWarning on recent Python versions; the resulting
# label text is unchanged
plt.xlabel(r'Scattering vector, $q$ (\AA$^{-1}$)')
plt.ylabel(r'Intensity, $I$ ($\mathrm{cm}^{-1}$)')
# move the last legend entry to the front of the list
handles, labels = plt.gca().get_legend_handles_labels()
handles.insert(0, handles.pop())
labels.insert(0, labels.pop())
plt.legend(handles, labels, prop={'size': 11}, loc=[1.03, 0],
           labelspacing=.5, facecolor='whitesmoke')

# add some texts for paper
plt.axvline(0.008, c='k', lw=1, ymin=.8)
plt.axvline(0.03, c='k', lw=1, ymin=.8)
plt.text(0.004, 1e3, r'low-$q$', va='top')
plt.text(0.013, 1e3, r'mid-$q$', va='top')
plt.text(0.07, 1e3, r'high-$q$', va='top', ha='center')
plt.text(0.06, 10, r'Raw intensity mean', va='top', ha='center')
plt.plot([0.05, 0.04], [1.5, 0.1], c='k', lw=1)

# save for paper
if reproduce_paper_fig:
    plt.savefig(paper_fig_dir / 'I22.pdf', bbox_inches='tight',
                facecolor='w', pad_inches=.05)
plt.show()

---