In [2]:
import ast
import warnings
from time import time

warnings.filterwarnings("ignore")

In [3]:
from citrination_client import CitrinationClient
from saxskit.saxs_models import get_data_from_Citrination

from saxskit.saxs_models import train_classifiers_partial, train_regressors_partial

## Updating models using data from Citrination

Assume that we have received a new dataset and want to update our models with it. Since training "from scratch" takes a significant amount of time (especially for the regression models), we use train_classifiers_partial() and train_regressors_partial() to update the existing models with the new data.

#### Step 0. Load the current set of models

To compare the current models with the updated ones, first load the current models from the scalers_and_models.yaml file where they are saved. **TODO:** this step is not implemented yet.

#### Step 1. Get the new data from Citrination using your Citrination credentials

In [4]:
# Read the Citrination API key (first line of a local text file) and open a
# client session against the SLAC Citrination site.
with open("../../api_key.txt", "r") as key_file:
    api_key = key_file.readline().strip()
cl = CitrinationClient(site='https://slac.citrination.com', api_key=api_key)

# Dataset 16 is treated as the new data used for the incremental update.
new_dataset_ids = [16]
new_data = get_data_from_Citrination(client=cl, dataset_id_list=new_dataset_ids)

#### Step 2 (optional). Get all available data from Citrination

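If we want to update not only the models but also the accuracy records, we must specify which data to use for computing the accuracy. It is recommended to use all available data, including the data that was used for the initial training. Note that, as written, the update cells in Steps 3 and 4 pass `all_data` as `all_training_data`, so run this cell before them; otherwise they fail with `NameError: name 'all_data' is not defined`.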

In [9]:
# Every available dataset: the new one (16) plus the data presumably used for
# the initial training (1 and 15) -- used to recompute the accuracy records.
all_dataset_ids = [1, 15, 16]
all_data = get_data_from_Citrination(client=cl, dataset_id_list=all_dataset_ids)

#### Step 3. Update Classifiers

In [11]:
# Incrementally update the classifiers with the new data and report the
# wall-clock duration; all_data (Step 2) is used to recompute accuracy.
start_time = time()
train_classifiers_partial(
    new_data,
    yaml_filename=None,
    all_training_data=all_data,
)
elapsed_minutes = (time() - start_time) / 60
print("Model updates took {:.2f} minutes".format(elapsed_minutes))

NameError: name 'all_data' is not defined

In [None]:
# Print the per-model classifier accuracy record. The file holds a one-line
# dict literal mapping model name -> accuracy score. Parse it with
# ast.literal_eval instead of eval(): literal_eval only accepts Python
# literals, so a corrupted or tampered file cannot execute arbitrary code.
with open("../saxskit/modeling_data/accuracy.txt", "r") as g:
    accuracy = ast.literal_eval(g.readline().strip())
for model_name, acc in accuracy.items():
    print('{}: {:.4f}'.format(model_name, acc))

#### Step 4. Update regression models

In [7]:
# Incrementally update the regression models with the new data and report
# the wall-clock duration; all_data (Step 2) is used to recompute accuracy.
start_time = time()
train_regressors_partial(
    new_data,
    yaml_filename=None,
    all_training_data=all_data,
)
elapsed_minutes = (time() - start_time) / 60
print("Model updates took {:.2f} minutes".format(elapsed_minutes))

Updating took about 0.6364340662956238  minutes.


In [10]:
# Print the per-model regression accuracy record. The file holds a one-line
# dict literal mapping model name -> accuracy score. Parse it with
# ast.literal_eval instead of eval(): literal_eval only accepts Python
# literals, so a corrupted or tampered file cannot execute arbitrary code.
with open("../saxskit/modeling_data/accuracy_regression.txt", "r") as g:
    accuracy = ast.literal_eval(g.readline().strip())
for model_name, acc in accuracy.items():
    print('{}: {:.4f}'.format(model_name, acc))

rg_gp: 0.2306
r0_sphere: 0.1425
sigma_sphere: 0.6480


#### Step 5. Compare accuracy and re-train the models if needed

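If the accuracy after the update is worse than before it, the pre-update models can be restored. For now, the cells below do this by re-training the classifiers and regressors from scratch on all available data, with hyperparameter search enabled.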

TODO: instead of re-training on all data, re-generate the scalers_and_models.yaml file from the models that were read in from the file at the beginning.

In [9]:
from saxskit.saxs_models import train_classifiers
from saxskit.saxs_models import train_regressors

# Re-train the classifiers on all datasets, with hyperparameter search.
train_classifiers(
    all_data,
    hyper_parameters_search=True,
)

In [10]:
# Show the raw classifier accuracy record (a one-line dict literal).
classifier_accuracy_path = "../saxskit/modeling_data/accuracy.txt"
with open(classifier_accuracy_path, "r") as fh:
    classifier_accuracy_text = fh.readline()
classifier_accuracy_text

"{'unidentified': 0.98855489215782255, 'spherical_normal': 0.99052317155423952, 'guinier_porod': 0.82488494227704401, 'diffraction_peaks': 0.98229846864406156}"

In [11]:
# Re-train the regression models on all datasets, with hyperparameter search.
train_regressors(
    all_data,
    hyper_parameters_search=True,
)

In [12]:
# Show the raw regression accuracy record (a one-line dict literal).
regression_accuracy_path = "../saxskit/modeling_data/accuracy_regression.txt"
with open(regression_accuracy_path, "r") as fh:
    regression_accuracy_text = fh.readline()
regression_accuracy_text

"{'r0_sphere': 0.14254575436285899, 'sigma_sphere': 0.64800471948743044, 'rg_gp': 0.23058002270411837}"