In [1]:
import warnings
warnings.filterwarnings("ignore")

from time import time

In [2]:
from citrination_client import CitrinationClient
from saxskit.saxs_models import get_data_from_Citrination
from saxskit.saxs_models import train_classifiers, train_regressors

## Training models using data from Citrination

#### SAXSKIT has seven pretrained models:
* four classifiers that predict "True" or "False" for:
    * the data are identifiable
    * the scatterers include one population of spherical scatterers with a normal size distribution
    * the scatterers include diffraction peaks
    * the scatterers include Guinier-Porod-like terms
* three regression models that predict:
    * the mean sphere size (in Angstroms)
    * the standard deviation (fractional), assuming a normal size distribution
    * the estimated intensity of the spherical scattering at q=0


#### SAXSKIT provides two options for training:
* training from scratch
* updating existing models using additional data

"training from scratch" is useful for initial training or when we have a lot of new data (more than 30%). It is recommended to use "hyper_parameters_search = True." 

Updating existing models is recommended when we have some new data (less than 30%). Updating existing models takes significant less time than "training from scratch"
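The rule of thumb above can be sketched as a small helper. This is only an illustration: the function name `choose_training_mode`, its arguments, and the reading of the 30% figure as the share of new samples in the combined dataset are assumptions, not part of the saxskit API.

```python
def choose_training_mode(n_existing, n_new, threshold=0.30):
    """Return 'scratch' or 'update' from the share of new samples.

    Hypothetical helper, not part of saxskit. The share is taken
    relative to the combined dataset, which is an assumption.
    """
    total = n_existing + n_new
    if n_existing == 0 or total == 0:
        # Nothing has been trained on yet: initial training starts from scratch.
        return "scratch"
    new_share = n_new / total
    return "scratch" if new_share > threshold else "update"


print(choose_training_mode(1000, 100))  # new data is about 9% of the total
print(choose_training_mode(1000, 900))  # new data is about 47% of the total
```

The threshold is exposed as a parameter because the 30% figure is a guideline rather than a hard limit.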



## Training from "scratch"

Let's assume that initially we have only two datasets: 1 and 15. We want to use them to train the models.

#### Step 1. Get data from Citrination using Citrination credentials

In [3]:
with open("../../citrination_api_key_ssrl.txt", "r") as g:
    a_key = g.readline().strip()
cl = CitrinationClient(site='https://slac.citrination.com',api_key=a_key)

data = get_data_from_Citrination(client=cl, dataset_id_list=[1, 15])  # [1, 15] is a list of dataset ids

In [4]:
data.head(3)

Unnamed: 0,experiment_id,Imax_over_Imean,Imax_sharpness,I_fluctuation,logI_fluctuation,logI_max_over_std,r_fftIcentroid,r_fftImax,q_Icentroid,q_logIcentroid,...,I0_floor,G_gp,rg_gp,D_gp,I0_sphere,r0_sphere,sigma_sphere,I_pkcenter,q_pkcenter,pk_hwhm
482,R3,31.9068,1.71153,0.00232749,41.3331,3.44558,0.0655392,0.00185529,0.0694955,0.0754474,...,,,,,,,,,,
321,R2,3.38101,1.25795,0.0179402,7.46197,1.83399,0.108849,0.00185529,0.232931,4.83148,...,0.0,2.32293,4.50959,4.0,,,,,,
256,R1,18.5818,1.02822,0.00109381,3.44362,2.80407,0.107346,0.00185529,0.0651221,-0.688958,...,0.105379,,,,1602.76,33.8187,0.051751,,,


**data** is a pandas data frame that contains:
* experiment_id - used to group samples when creating cross-validation folds during training. Samples from the same experiment are often very similar, so we should avoid having samples from the same experiment in both the training and validation sets


* Twenty features: 
       'Imax_over_Imean', 'Imax_sharpness', 'I_fluctuation',
       'logI_fluctuation', 'logI_max_over_std', 'r_fftIcentroid', 'r_fftImax',
       'q_Icentroid', 'q_logIcentroid', 'pearson_q', 'pearson_q2',
       'pearson_expq', 'pearson_invexpq', 'I0_over_Imean', 'I0_curvature',
       'q_at_half_I0', 'q_at_Iq4_min1', 'pIq4_qwidth', 'pI_qvertex',
       'pI_qwidth'.
 
 
* Four True / False labels (for classification models): 
       'unidentified', 'guinier_porod', 'spherical_normal',
       'diffraction_peaks'. 
If a sample has 'unidentified = True', it also has "False" for all other labels.


* Ten continuous labels (for regression models): 
       'I0_floor', 'G_gp', 'rg_gp', 'D_gp', 'I0_sphere',
       'r0_sphere', 'sigma_sphere', 'I_pkcenter', 'q_pkcenter', 'pk_hwhm'. 
Some samples have "None" for some of these labels. For example, only samples with 'spherical_normal = True' have a value for 'sigma_sphere'.
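The two labelling rules described above can be checked with a short sketch. It treats each sample as a plain dictionary with `None` for missing values, which is an assumption made to keep the example self-contained; in the actual data frame, missing values may appear as NaN instead. The helper `check_row` and the sample values are hypothetical and only illustrate the rules.

```python
OTHER_LABELS = ["guinier_porod", "spherical_normal", "diffraction_peaks"]


def check_row(row):
    """Return a list of rule violations for one sample (dict of column -> value)."""
    problems = []
    # Rule 1: an unidentified sample must have False for all other labels.
    if row.get("unidentified") and any(row.get(name) for name in OTHER_LABELS):
        problems.append("unidentified sample has another label set to True")
    # Rule 2: sigma_sphere is only defined for spherical_normal samples.
    if row.get("sigma_sphere") is not None and not row.get("spherical_normal"):
        problems.append("sigma_sphere is set but spherical_normal is not True")
    return problems


# Illustrative rows, not taken from the dataset.
good = {"unidentified": False, "spherical_normal": True, "sigma_sphere": 0.05}
bad = {"unidentified": True, "spherical_normal": True, "sigma_sphere": None}

print(check_row(good))
print(check_row(bad))
```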

#### Step 2. Train Classifiers

In [5]:
t0 = time()
train_classifiers(data,  hyper_parameters_search = True)
# scalers and models will be saved in 'saxskit/saxskit/modeling_data/scalers_and_models.yml'
# accuracy will be saved in 'saxskit/saxskit/modeling_data/accuracy.txt'
# We can pass yaml_filename='file_name.yml' as an additional parameter to save scalers and models to that file
print("Training took about", (time()-t0)/60, " minutes.")

Training took about 1.0713050007820129  minutes.


In [6]:
with open("../saxskit/modeling_data/accuracy.txt", "r") as g:
    accuracy = g.readline()    
accuracy

"{'unidentified': 0.98480689025875612, 'spherical_normal': 0.96939055265495433, 'guinier_porod': 0.84126978694564425, 'diffraction_peaks': 0.97664861249419055}"

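The accuracy file is read back as a single string, which is awkward to work with. A minimal sketch for turning it into a dictionary follows, assuming the first line of the file is a Python dict literal as the output above suggests. `ast.literal_eval` from the standard library evaluates literals only, so it is a safer choice than `eval`.

```python
import ast

# The line read from accuracy.txt, copied from the output above.
accuracy = "{'unidentified': 0.98480689025875612, 'spherical_normal': 0.96939055265495433, 'guinier_porod': 0.84126978694564425, 'diffraction_peaks': 0.97664861249419055}"

scores = ast.literal_eval(accuracy)  # parse the dict literal into a real dict

# List the classifiers from least to most accurate.
for label, score in sorted(scores.items(), key=lambda item: item[1]):
    print(f"{label:<20s} {score:.3f}")
```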
Since data from the same experiment are often highly correlated, the "Leave N Group Out" technique is used to calculate accuracy. Data from two experiments are excluded from training and used as the testing set. For example, if we have experiments 1, 2, 3, 4, and 5:
* train the model on 1, 2, 3; test on 4, 5
* train the model on 1, 2, 5; test on 3, 4
* try all combinations...
* calculate average accuracy
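The enumeration of splits described above can be sketched with the standard library alone. The function name `leave_n_groups_out` is hypothetical, and this is not necessarily how saxskit implements the procedure; scikit-learn offers a comparable splitter, `LeavePGroupsOut`, in `sklearn.model_selection`.

```python
from itertools import combinations


def leave_n_groups_out(groups, n=2):
    """Yield (train_groups, test_groups) for every way of holding out n groups."""
    unique = sorted(set(groups))
    for held_out in combinations(unique, n):
        train = [g for g in unique if g not in held_out]
        yield train, list(held_out)


splits = list(leave_n_groups_out([1, 2, 3, 4, 5], n=2))
for train, test in splits:
    print("train on", train, "test on", test)
print(len(splits), "splits in total")
```

With five experiments and two held out at a time there are 10 splits, so the number of models to train grows quickly as experiments are added.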

#### Step 3. Train Regression models

In [7]:
t0 = time()
train_regressors(data,  hyper_parameters_search = True)
# scalers and models will be saved in 'saxskit/saxskit/modeling_data/scalers_and_models_regression.yml'
# accuracy will be saved in 'saxskit/saxskit/modeling_data/accuracy_regression.txt'
# We can pass yaml_filename='file_name.yml' as an additional parameter to save scalers and models to that file
print("Training took about", (time()-t0)/60, " minutes.")

Training took about 14.730934969584148  minutes.


In [8]:
with open("../saxskit/modeling_data/accuracy_regression.txt", "r") as g:
    accuracy = g.readline()    
accuracy

"{'r0_sphere': 0.29201002237526336, 'sigma_sphere': 0.61203081826345462, 'rg_gp': 0.27677267393412075}"

For the regression models, the "Leave N Group Out" technique is also used. The accuracy is calculated as the mean absolute error divided by the standard deviation.
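A minimal sketch of this metric is shown below. It assumes the standard deviation is taken over the true label values and uses the population form (`statistics.pstdev`); the text does not specify either detail, and the function name `normalized_mae` and the sample numbers are illustrative only.

```python
from statistics import pstdev


def normalized_mae(y_true, y_pred):
    """Mean absolute error divided by the standard deviation of the true values."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("y_true and y_pred must be non-empty and of equal length")
    mae = sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)
    spread = pstdev(y_true)  # population standard deviation (an assumption)
    if spread == 0:
        raise ValueError("standard deviation of y_true is zero")
    return mae / spread


# Illustrative values, not taken from the dataset.
y_true = [1.0, 2.0, 3.0, 4.0, 5.0]
y_pred = [1.5, 2.5, 2.5, 4.5, 4.5]
print(normalized_mae(y_true, y_pred))
```

Because the value is an error relative to the spread of the label, lower is better: a value well below 1 means the typical prediction error is small compared with the natural variation of that label.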