In [1]:
from time import time  # wall-clock timer used to report how long each training step takes
import warnings

# Ignore every warning category for the rest of the kernel session, presumably to
# keep the cell outputs uncluttered. NOTE(review): this also hides deprecation warnings.
warnings.filterwarnings("ignore")

In [2]:
from citrination_client import CitrinationClient
from saxskit.saxs_models import get_data_from_Citrination

from saxskit.saxs_models import train_classifiers_partial, train_regressors_partial
from saxskit.saxs_models import train_classifiers, train_regressors

## Updating models using data from Citrination.

Assume that we have received a new dataset and want to update our models with it. Since training "from scratch" takes a significant amount of time (especially for the regression models), we will use train_classifiers_partial() and train_regressors_partial() instead.

#### Step 1. Get data from Citrination using Citrination credentials

In [3]:
# Read the Citrination API key from a local file instead of hard-coding it in the notebook.
API_KEY_PATH = "../../citrination_api_key_ssrl.txt"
with open(API_KEY_PATH, "r") as key_file:
    a_key = key_file.readline().strip()
if not a_key:
    # Fail fast: an empty key would otherwise presumably only fail later, inside the client.
    raise ValueError("No Citrination API key found in " + API_KEY_PATH)
cl = CitrinationClient(site='https://slac.citrination.com', api_key=a_key)

# Ids of the Citrination dataset(s) that hold the newly acquired data.
NEW_DATASET_IDS = [16]
new_data = get_data_from_Citrination(client=cl, dataset_id_list=NEW_DATASET_IDS)

#### Step 2 (optional). Get all available data from Citrination

If we want to update not just the models but also the accuracy records, we need to specify the data used to calculate the accuracy. It is recommended to use all available data for this, including the data that was used for the initial training.

In [4]:
# All available datasets (including the new one, id 16). They are used to recompute the
# accuracy records and, in Step 5, for retraining from scratch.
all_data = get_data_from_Citrination(
    client=cl,
    dataset_id_list=[1, 15, 16],
)

#### Step 3. Update Classifiers

In [5]:
# Update the classifiers incrementally with the new data only. all_training_data
# supplies the data used to recompute the classifier accuracy records (see Step 2).
t0 = time()
train_classifiers_partial(new_data, yaml_filename=None, all_training_data=all_data)
# Format explicitly: print's separator plus the " minutes." literal produced a double space.
print("Updating took about {:.2f} minutes.".format((time() - t0) / 60))

Updating took about 0.12680447896321614  minutes.


In [6]:
# Display the current classifier accuracy record (the first line of accuracy.txt).
accuracy_path = "../saxskit/modeling_data/accuracy.txt"
with open(accuracy_path) as accuracy_file:
    accuracy = accuracy_file.readline()
accuracy

"{'diffraction_peaks': 0.98178585879954106, 'guinier_porod': 0.77288930243468423, 'spherical_normal': 0.98368631600632594, 'unidentified': 0.98851100842113593}"

#### Step 4. Update regression models

In [7]:
# Update the regression models incrementally with the new data only. all_training_data
# supplies the data used to recompute the regression accuracy records (see Step 2).
t0 = time()
train_regressors_partial(new_data, yaml_filename=None, all_training_data=all_data)
# Format explicitly: print's separator plus the " minutes." literal produced a double space.
print("Updating took about {:.2f} minutes.".format((time() - t0) / 60))

Updating took about 0.6364340662956238  minutes.


In [8]:
# Display the current regression accuracy record (the first line of accuracy_regression.txt).
accuracy_path = "../saxskit/modeling_data/accuracy_regression.txt"
with open(accuracy_path) as accuracy_file:
    accuracy = accuracy_file.readline()
accuracy

"{'r0_sphere': 0.26515657127255915, 'rg_gp': 1.0723746176594846, 'sigma_sphere': 0.55702047410011657}"

#### Step 5. Compare accuracy and retrain the models if needed

If the new accuracy is worse than the accuracy we had before updating, it is recommended to retrain the models from scratch using all available data:

In [9]:
# Retrain the classifiers from scratch on all available data, with hyper-parameter search.
# NOTE(review): this runs unconditionally; compare the accuracy records above and skip
# this cell if the incremental update was good enough.
t0 = time()
train_classifiers(all_data, hyper_parameters_search=True)
print("Retraining took about {:.2f} minutes.".format((time() - t0) / 60))

In [10]:
# Display the classifier accuracy record again (the first line of accuracy.txt).
accuracy_path = "../saxskit/modeling_data/accuracy.txt"
with open(accuracy_path) as accuracy_file:
    accuracy = accuracy_file.readline()
accuracy

"{'unidentified': 0.98855489215782255, 'spherical_normal': 0.99052317155423952, 'guinier_porod': 0.82488494227704401, 'diffraction_peaks': 0.98229846864406156}"

In [11]:
# Retrain the regression models from scratch on all available data, with hyper-parameter search.
# NOTE(review): this runs unconditionally; compare the accuracy records above and skip
# this cell if the incremental update was good enough.
t0 = time()
train_regressors(all_data, hyper_parameters_search=True)
print("Retraining took about {:.2f} minutes.".format((time() - t0) / 60))

In [12]:
# Display the regression accuracy record again (the first line of accuracy_regression.txt).
accuracy_path = "../saxskit/modeling_data/accuracy_regression.txt"
with open(accuracy_path) as accuracy_file:
    accuracy = accuracy_file.readline()
accuracy

"{'r0_sphere': 0.14254575436285899, 'sigma_sphere': 0.64800471948743044, 'rg_gp': 0.23058002270411837}"