In [1]:
import ast
import os
import warnings
warnings.filterwarnings("ignore")
from time import time
import yaml

In [2]:
from citrination_client import CitrinationClient
from saxskit.saxs_models import get_data_from_Citrination
from saxskit.saxs_models import train_classifiers_partial, train_regressors_partial, save_models

## Updating models using data from Citrination

Suppose we have received a new dataset and want to update our models with it. Because training "from scratch" takes a significant amount of time (especially for the regression models), we will use train_classifiers_partial() and train_regressors_partial() to update the existing models with the new data.

#### Step 1. Specify the full paths to the YAML files where the models were saved.

The cross-validation errors were also saved in a .txt file with the same name, in the same directory.

In [3]:
# Saved models live in <parent of the working directory>/saxskit/modeling_data.
p = os.getcwd()
d = os.path.dirname(p)
modeling_dir = os.path.join(d, 'saxskit', 'modeling_data')
classifiers_path = os.path.join(modeling_dir, 'scalers_and_models.yml')
regressors_path = os.path.join(modeling_dir, 'scalers_and_models_regression.yml')

In [4]:
# Cross-validation scores sit next to each YAML file, in a .txt file of the same name.
modeling_dir = os.path.join(d, 'saxskit', 'modeling_data')
classifiers_err = os.path.join(modeling_dir, 'scalers_and_models.txt')
regressors_err = os.path.join(modeling_dir, 'scalers_and_models_regression.txt')

#### Accuracy before updating:

In [5]:
# The first line of the .txt file is expected to be a dict literal
# (model name -> cross-validated score). ast.literal_eval parses it without
# executing arbitrary code, unlike eval(); strip() drops surrounding whitespace.
with open(classifiers_err, "r") as err_file:
    accuracy = ast.literal_eval(err_file.readline().strip())
print('-----trained model: accuracies for classification models-----')
for model_name, acc in accuracy.items():
    print('{}: {:.4f}'.format(model_name, acc))

-----trained model: accuracies for classification models-----
diffraction_peaks: 0.9510
guinier_porod: 0.8013
spherical_normal: 0.9780
unidentified: 0.9845


In [6]:
# Same file format as the classifier scores: the first line is expected to be a
# dict literal (model name -> cross-validated score), parsed safely with
# ast.literal_eval instead of eval().
with open(regressors_err, "r") as err_file:
    accuracy_reg = ast.literal_eval(err_file.readline().strip())
print('-----trained model: accuracies for regression models-----')
for model_name, acc in accuracy_reg.items():
    print('{}: {:.4f}'.format(model_name, acc))

-----trained model: accuracies for rergession models-----
r0_sphere: 0.2921
rg_gp: 0.2768
sigma_sphere: 0.6118


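#### Step 2. Get the new data from Citrination using your Citrination credentials.

The API key is read from the first line of `api_key.txt`, located in the parent of the current working directory.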

In [7]:
# The API key is read from api_key.txt in the parent of the working directory.
# Fail fast here: continuing without the file would only crash later, with a
# less helpful error, when the next cell tries to open it.
api_key_file = os.path.join(d, 'api_key.txt')
if not os.path.isfile(api_key_file):
    raise FileNotFoundError(
        "Citrination API key file was not found: {}".format(api_key_file))

In [8]:
# The key is the first line of the file; it is never hard-coded in the notebook.
with open(api_key_file, "r") as key_file:
    a_key = key_file.readline().strip()
cl = CitrinationClient(site='https://slac.citrination.com', api_key=a_key)

# ID(s) of the Citrination dataset(s) holding the new data used for the update.
new_dataset_ids = [16]
new_data = get_data_from_Citrination(client=cl, dataset_id_list=new_dataset_ids)

#### Step 3 (optional). Get all available data from Citrination.

To know the accuracy of the updated models, it is recommended to evaluate it against the full training set. The reported accuracy is computed with the "Leave-N-Groups-Out" technique: in each cycle, data from two experiments is held out for testing and the remaining data is used for training. The average accuracy over all cycles is reported.

In [9]:
# Every available training dataset, including the new dataset 16.
all_dataset_ids = [1, 15, 16]
all_data = get_data_from_Citrination(client=cl, dataset_id_list=all_dataset_ids)

#### Step 4. Update classifiers.

In [10]:
# Incrementally update all classifiers with the new data; all_data is passed so
# the reported accuracy is estimated on the full training set.
start_time = time()
scalers, models, new_accuracy = train_classifiers_partial(
    new_data,
    classifiers_path,
    all_training_data=all_data,
    model='all',
)
elapsed_minutes = (time() - start_time) / 60
print("Model updates took {:.2f} minutes".format(elapsed_minutes))

Model updates took 0.11 minutes


train_classifiers_partial() has an optional argument 'model' that specifies which model to train. For example:

    scalers, models, accuracy = train_classifiers_partial(new_data, classifiers_path, all_training_data=all_data, model='spherical_normal')
Valid model names are "unidentified", "spherical_normal", "guinier_porod", and "diffraction_peaks"; use "all" to train all models.

#### Accuracy after updating:

In [11]:
# Scores of the updated classifiers, printed in the same format as before the update.
for name in new_accuracy:
    print('{}: {:.4f}'.format(name, new_accuracy[name]))

diffraction_peaks: 0.9802
guinier_porod: 0.7321
spherical_normal: 0.9765
unidentified: 0.9886


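Note that the guinier_porod accuracy dropped after the update (0.8013 → 0.7321). If we are not satisfied with the new accuracy, we can retrain the models "from scratch" on all available data: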

    scalers, models, new_accuracy = train_classifiers(all_data, hyper_parameters_search = True, model='all')

#### Step 5. Save updated classifiers. 

Scalers, models, sklearn version, and cross-validation errors will be saved at "classifiers_path", and the cross-validation errors are also saved in a .txt file with the same name, in the same directory. If the path is not specified, the models will be saved at 'modeling_data/custom_models/some_number.yml', and the cross-validation errors are also saved in a .txt file with the same name, in the same directory.

In [12]:
# Write the updated scalers/models and their scores to classifiers_path, the
# same path passed to train_classifiers_partial above.
# NOTE(review): this runs unconditionally, even if some scores got worse after
# the update; consider comparing new_accuracy with the pre-update values first.
save_models(scalers, models, new_accuracy, classifiers_path)

#### Step 6. Update regression models.

In [13]:
# Incrementally update all regression models with the new data; all_data is
# passed so the reported scores are estimated on the full training set.
start_time = time()
scalers, models, new_accuracy = train_regressors_partial(
    new_data,
    regressors_path,
    all_training_data=all_data,
    model='all',
)
elapsed_minutes = (time() - start_time) / 60
print("Model updates took {:.2f} minutes".format(elapsed_minutes))

Model updates took 0.64 minutes


train_regressors_partial() has an optional argument 'model' that specifies which model to train. For example:

    scalers, models, accuracy = train_regressors_partial(new_data, regressors_path, all_training_data=all_data, model='r0_sphere')
Valid model names are "r0_sphere", "sigma_sphere", and "rg_gp"; use "all" to train all models.

#### Accuracy after updating:

In [14]:
# Scores of the updated regression models, one "name: score" line per model.
for name, score in new_accuracy.items():
    line = '{}: {:.4f}'.format(name, score)
    print(line)

r0_sphere: 0.2642
rg_gp: 1.1316
sigma_sphere: 0.5594


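Note that the rg_gp score changed substantially after the update (0.2768 → 1.1316). Again, if we are not satisfied with the new results, we can retrain the models "from scratch" on all available data: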

    scalers, models, new_accuracy = train_regressors(all_data, hyper_parameters_search = True, model='all')

#### Step 7. Save updated regression models.

Scalers, models, sklearn version, and cross-validation errors will be saved at "regressors_path", and the cross-validation errors are also saved in a .txt file with the same name, in the same directory. If the path is not specified, the models will be saved at 'modeling_data/custom_models/some_number.yml', and the cross-validation errors are also saved in a .txt file with the same name, in the same directory.

In [15]:
# Write the updated regression scalers/models and their scores to
# regressors_path, the same path passed to train_regressors_partial above.
# NOTE(review): this runs unconditionally, even if some scores got worse after
# the update; consider comparing new_accuracy with the pre-update values first.
save_models(scalers, models, new_accuracy, regressors_path)