In [1]:
import warnings
warnings.filterwarnings("ignore")
from time import time

In [2]:
from citrination_client import CitrinationClient
from saxskit.saxs_models import get_data_from_Citrination
from saxskit.saxs_models import train_classifiers, train_regressors

## Training saxskit models using data from Citrination

saxskit uses SAXS spectra to identify scatterer populations (via several binary classification models) and predict the associated scattering parameters (via regression models). The saxskit API allows the user to choose between models based on scikit-learn or models based on Citrination. The latter require a Citrination API key, which can be obtained after creating a free account on https://citrination.com. 

Parameters related to curve shape (Porod exponent, radius of gyration, mean and standard deviation of sphere size distribution) are predicted by regression models. 

Diffraction peak parameters are predicted heuristically, by identifying the most likely peak locations and fitting them locally to parabolas.
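As an illustration of the local parabola fit mentioned above, here is a minimal sketch that fits a parabola through three (q, I) points around a candidate peak and reads off the vertex. The function name `fit_parabola_peak` and the sample values are hypothetical; this is not saxskit's own peak-fitting routine, only the underlying idea.

```python
def fit_parabola_peak(q, intensity):
    """Fit I = a*q**2 + b*q + c through three points; return (q_peak, I_peak)."""
    (x0, x1, x2), (y0, y1, y2) = q, intensity
    denom = (x0 - x1) * (x0 - x2) * (x1 - x2)
    a = (x2 * (y1 - y0) + x1 * (y0 - y2) + x0 * (y2 - y1)) / denom
    b = (x2**2 * (y0 - y1) + x1**2 * (y2 - y0) + x0**2 * (y1 - y2)) / denom
    c = y0 - a * x0**2 - b * x0
    if a >= 0:
        # A parabola that opens upward has no maximum, so there is no peak here.
        raise ValueError("points do not bracket a local maximum")
    q_peak = -b / (2 * a)          # vertex position
    i_peak = c - b**2 / (4 * a)    # parabola value at the vertex
    return q_peak, i_peak


# Hypothetical window of three samples around a peak centered at q = 0.125 1/Angstrom
q_window = [0.10, 0.12, 0.14]
i_window = [5.0 - 100.0 * (q - 0.125) ** 2 for q in q_window]
q_pk, i_pk = fit_parabola_peak(q_window, i_window)
print("q_pkcenter ~ {:.4f}, I_pkcenter ~ {:.4f}".format(q_pk, i_pk))
```

Because the fit is local, it only needs a few points around the maximum and does not depend on the shape of the rest of the spectrum.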

Parameters that scale the intensity are not predicted by the models. This is because saxskit featurizes spectra in a way that is invariant with respect to intensity scaling, so that the models can be used on intensity spectra reported in any units. The same invariance does not hold for the q-axis: saxskit expects q values in units of 1/Angstrom.
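To make the scaling argument concrete, the sketch below computes a ratio-type feature (maximum intensity over mean intensity) for a spectrum and for the same spectrum multiplied by a constant. The helper names and the numbers are hypothetical, and this is not saxskit's featurization code; it only shows why a ratio of intensities does not depend on the intensity units, while q values must be converted explicitly.

```python
import math


def imax_over_imean(intensity):
    """Ratio of maximum to mean intensity; any constant scale factor cancels."""
    return max(intensity) / (sum(intensity) / len(intensity))


def q_inv_nm_to_inv_angstrom(q_values):
    """Convert q from 1/nm to 1/Angstrom (1 nm = 10 Angstrom)."""
    return [q / 10.0 for q in q_values]


intensity = [120.0, 80.0, 30.0, 9.0, 2.5]      # arbitrary units
rescaled = [1.0e3 * i for i in intensity]      # same spectrum, different units

ratio_a = imax_over_imean(intensity)
ratio_b = imax_over_imean(rescaled)
print("feature unchanged by scaling:", math.isclose(ratio_a, ratio_b))

# q reported in 1/nm must be converted before featurizing
q_angstrom = q_inv_nm_to_inv_angstrom([0.5, 1.0, 2.0])
print(q_angstrom)
```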

saxskit includes:
* four binary classifiers:
    * 'unidentified': True if the scatterers cannot be identified easily from the data.
    * 'spherical_normal': True if there are one or more normal distributions of spherical scatterers.
    * 'diffraction_peaks': True if there are one or more diffraction peaks.
    * 'guinier_porod': True if there are one or more scatterers described by a Guinier-Porod equation.
* three regression models:
    * 'r0_sphere': the mean sphere size (in Angstroms) for 'spherical_normal' scatterers
    * 'sigma_sphere': the fractional standard deviation of sphere size for 'spherical_normal' scatterers
    * 'rg_gp': the radius of gyration for 'guinier_porod' scatterers

Users with Citrination accounts can pull SAXS data from Citrination to train custom models. The SAXS records used for training must have been generated with saxskit.saxs_piftools, preferably by the same version of saxskit.

#### saxskit provides two options for training:
* training from scratch: useful for initial training or when we have a lot of new data (around 30% of the dataset or more).
* updating existing models with additional data: takes less time than training new models, especially when the existing model was trained on a large dataset. This is recommended when there are some new data, but they amount to less than about 30% of the dataset.
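The rule of thumb above can be written down as a small helper. This is only a sketch: the function name is hypothetical, and it assumes that "30% of the dataset" means the share of new records in the combined (existing plus new) dataset, which the text does not state explicitly.

```python
def suggest_training_mode(n_existing, n_new, threshold=0.30):
    """Suggest a training option from record counts (rule of thumb, not a saxskit API)."""
    total = n_existing + n_new
    if total == 0:
        raise ValueError("no records available for training")
    if n_existing == 0:
        # Nothing has been trained yet, so there is no model to update.
        return "train from scratch"
    new_fraction = n_new / total  # assumption: fraction of the combined dataset
    if new_fraction >= threshold:
        return "train from scratch"
    return "update existing models"


print(suggest_training_mode(n_existing=1000, n_new=100))
print(suggest_training_mode(n_existing=1000, n_new=600))
```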


## Training from scratch

Let's assume that initially we have two Citrination datasets: 1 and 15. We want to use them to train the models.

#### Step 1. Get data from Citrination using Citrination credentials

In [3]:
with open("../../api_key.txt", "r") as g:
    a_key = g.readline().strip()
cl = CitrinationClient(site='https://slac.citrination.com', api_key=a_key)

# dataset_id_list is a list of Citrination dataset IDs
data = get_data_from_Citrination(client=cl, dataset_id_list=[1, 15])

At this point, **data** is a pandas DataFrame that contains (for each SAXS record):

* experiment_id - The experiment ID is used for cross-validation grouping. Samples from the same experiment are often very similar, so holding out whole experiments keeps near-duplicate samples from appearing in both the training and test sets, which would otherwise give overly optimistic validation scores.

* An array of numerical features that describe the shape of the spectrum (invariant with respect to intensity scaling).
 
* Four True / False labels (for classification models): 
    - 'unidentified'
    - 'guinier_porod'
    - 'spherical_normal'
    - 'diffraction_peaks'

  Note that a record labeled 'unidentified' = True has False for all other labels.

* An array of scattering parameters (previously least-squares fit with saxskit):
    * For any record that is not 'unidentified':
        - 'I0_floor': flat noise floor intensity

    * For 'guinier_porod' scatterers:
        - 'G_gp': Guinier prefactors
        - 'rg_gp': radii of gyration
        - 'D_gp': Porod exponents
        
    * For 'spherical_normal' scatterers:
        - 'I0_sphere': Intensity scaling prefactors
        - 'r0_sphere': Mean sphere radii
        - 'sigma_sphere': Fractional standard deviations
        
    * For 'diffraction_peaks':
        - 'I_pkcenter': Intensities of the peaks at their maxima
        - 'q_pkcenter': q-values of the peak maxima
        - 'pk_hwhm': peak half-widths at half-max

Note that not every record contains a value for every parameter. For example, only samples with 'spherical_normal' populations will have values for 'sigma_sphere'.
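The two consistency rules described above (an 'unidentified' record has no other labels, and population-specific parameters are only present for the matching population) can be checked per record. The sketch below works on plain dictionaries, such as those produced by `data.to_dict("records")`; the helper names and the sample records are hypothetical, and it assumes the labels are stored in columns named exactly as listed.

```python
import math

OTHER_LABELS = ("guinier_porod", "spherical_normal", "diffraction_peaks")


def is_missing(value):
    """Treat None and NaN as 'no value stored'."""
    return value is None or (isinstance(value, float) and math.isnan(value))


def check_record(record):
    """Return a list of human-readable problems found in one SAXS record."""
    problems = []
    if record["unidentified"] and any(record[label] for label in OTHER_LABELS):
        problems.append("'unidentified' is True but another label is also True")
    if not record["spherical_normal"] and not is_missing(record.get("sigma_sphere")):
        problems.append("'sigma_sphere' is set without a 'spherical_normal' population")
    return problems


# Hypothetical records for illustration only
good = {"unidentified": False, "guinier_porod": False, "spherical_normal": True,
        "diffraction_peaks": False, "sigma_sphere": 0.05}
bad = {"unidentified": True, "guinier_porod": True, "spherical_normal": False,
       "diffraction_peaks": False, "sigma_sphere": float("nan")}

print(check_record(good))
print(check_record(bad))
```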

In [4]:
data.head(3)

| Unnamed: 0 | experiment_id | Imax_over_Imean | Imax_sharpness | I_fluctuation | logI_fluctuation | logI_max_over_std | r_fftIcentroid | r_fftImax | q_Icentroid | q_logIcentroid | ... | I0_floor | G_gp | rg_gp | D_gp | I0_sphere | r0_sphere | sigma_sphere | I_pkcenter | q_pkcenter | pk_hwhm |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 619 | R4 | 73.1647 | 2.85615 | 0.00236278 | 49.8701 | 3.9025 | 0.0876662 | 0.00185529 | 0.0928799 | -5.61471 | ... | 0.0 | | | | | | | | | |
| 250 | R1 | 18.627 | 1.03083 | 0.00109623 | 3.75752 | 2.81677 | 0.10758 | 0.00185529 | 0.0650745 | -0.718163 | ... | 0.103302 | | | | 1602.76 | 33.7763 | 0.053542 | | | |
| 296 | R2 | 4.80428 | 10.2104 | 0.00894922 | 13.5409 | 4.16846 | 0.19233 | 0.00185529 | 0.262385 | 0.201071 | ... | 4.28832e-14 | 3.49765 | 3.79767 | 4.0 | | | | | | |


#### Step 2. Train Classifiers

For training from scratch, we use train_classifiers() with hyper_parameters_search = True. This will seek a set of model hyperparameters that optimizes the model.

TODO: explain how the final set of hyperparameters is chosen (i.e. what is the error metric?)

Since samples from the same experiment are often highly correlated, saxskit uses a "Leave-N-Groups-Out" technique to estimate how well the models perform on unseen data. Two groups (experiment_ids) are left out in each training cycle. For example, if we have experiments 1 through 5:
* train the model on 1, 2, and 3; test on 4 and 5
* train the model on 1, 2, and 4; test on 3 and 5
* ... (and so on, for all 10 possible training/testing combinations)
* calculate the average accuracy over all combinations
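The splitting scheme in the list above can be sketched with the standard library alone. The generator name `leave_n_groups_out` and the sample experiment IDs are hypothetical, and this is not necessarily how saxskit builds its splits; scikit-learn's `LeavePGroupsOut` cross-validator provides the same kind of splitting.

```python
from itertools import combinations


def leave_n_groups_out(groups, n_out=2):
    """Yield (train_indices, test_indices, held_out_groups) for every choice of n_out groups."""
    unique_groups = sorted(set(groups))
    for held_out in combinations(unique_groups, n_out):
        test_idx = [i for i, g in enumerate(groups) if g in held_out]
        train_idx = [i for i, g in enumerate(groups) if g not in held_out]
        yield train_idx, test_idx, held_out


# One experiment ID per sample; five distinct experiments in total
experiment_ids = [1, 1, 2, 3, 3, 4, 5, 5]
splits = list(leave_n_groups_out(experiment_ids, n_out=2))

print("number of train/test combinations:", len(splits))
for train_idx, test_idx, held_out in splits[:2]:
    print("held out experiments", held_out, "-> test samples", test_idx)
```

Each sample lands in the test set only together with every other sample from its experiment, so correlated samples never straddle the train/test boundary.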

In [5]:
t0 = time()
train_classifiers(data, hyper_parameters_search=True)
print("Training took about {:.2f} minutes".format((time()-t0)/60))

Training took about 1.82 minutes


After train_classifiers() finishes, a set of serialized scalers and models is saved in the package's source directory at:

**saxskit/modeling_data/scalers_and_models.yml**

Note that train_classifiers() has an optional argument 'yaml_filename' which can be used to save to a different file path.

The accuracy of the trained models will also be reported in:

**saxskit/modeling_data/accuracy.txt**

TODO: what is the accuracy metric?

In [6]:
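import ast

with open("../saxskit/modeling_data/accuracy.txt", "r") as g:
    # The first line holds a dict literal; literal_eval parses it without executing code.
    accuracy = ast.literal_eval(g.readline())
print('-----trained model accuracies-----')
for model_name, acc in accuracy.items():
    print('{}: {:.4f}'.format(model_name, acc))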

-----trained model accuracies-----
unidentified: 0.9837
diffraction_peaks: 0.9608
spherical_normal: 0.9671
guinier_porod: 0.8175


#### Step 3. Train Regression models

The approach is the same as above, but for a different set of models. These are the three regression models for the scattering spectrum parameters affecting curve shape. In the current version, the regression model output is one-dimensional, so these are mostly useful for spectra containing **one** 'guinier_porod' and/or **one** 'spherical_normal' scatterer population.

In [7]:
t0 = time()
train_regressors(data, hyper_parameters_search=True)
print("Training took about {:.2f} minutes".format((time()-t0)/60))

Training took about 25.07 minutes


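As with the classifiers, the regression models are evaluated with "Leave-N-Groups-Out" cross-validation, again with N=2. The reported error for each model is the mean absolute validation error divided by the standard deviation of the training data, so lower values are better, even though the file name and the printed header refer to these values as accuracies.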
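For a single train/validation split, the normalized error described above could be computed as in the sketch below. The function name and the numbers are hypothetical, and the use of the population standard deviation (rather than the sample standard deviation) is an assumption; how saxskit aggregates the value across splits is not shown here.

```python
from statistics import pstdev


def normalized_mae(y_true, y_pred, y_train):
    """Mean absolute validation error divided by the std. dev. of the training targets."""
    mae = sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)
    return mae / pstdev(y_train)  # assumption: population standard deviation


# Hypothetical r0_sphere values (Angstrom) for illustration only
y_train = [20.0, 30.0, 40.0, 50.0]   # targets seen during training
y_val = [25.0, 45.0]                 # held-out targets
y_hat = [27.0, 41.0]                 # model predictions for the held-out samples

score = normalized_mae(y_val, y_hat, y_train)
print("normalized error: {:.4f}".format(score))
```

Dividing by the spread of the training targets makes the errors of differently scaled parameters comparable: a value near 0 means the error is small relative to how much the parameter varies, while a value near 1 means the error is as large as that variation.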

In [8]:
import ast

with open("../saxskit/modeling_data/accuracy_regression.txt", "r") as g:
    # The first line holds a dict literal; literal_eval parses it without executing code.
    accuracy = ast.literal_eval(g.readline())
print('-----trained model accuracies-----')
for model_name, acc in accuracy.items():
    print('{}: {:.4f}'.format(model_name, acc))

-----trained model accuracies-----
rg_gp: 0.2791
sigma_sphere: 0.6120
r0_sphere: 0.2917
