In [1]:
import warnings
warnings.filterwarnings("ignore")
from time import time
import yaml

In [2]:
from citrination_client import CitrinationClient
from saxskit.saxs_models import get_data_from_Citrination

from saxskit.saxs_models import train_classifiers_partial, train_regressors_partial

## Updating models using data from Citrination

Assume that we got a new dataset and now we want to update our models using new data. Since training "from scratch" took a significant amount of time (specially for the regression models) we will use train_classifiers_partial() and train_regressors_partial() to update the models with the new data.

#### Step 1 (Optional): Load the current set of models

In [28]:
s_and_m_file = open( "../saxskit/modeling_data/scalers_and_models.yml",'rb')
s_and_m = yaml.load(s_and_m_file)

In [36]:
s_and_m_file_regression = open( "../saxskit/modeling_data/scalers_and_models_regression.yml",'rb')
s_and_m_regresson = yaml.load(s_and_m_file_regression)

#### Accuracy before updating:

In [30]:
with open("../saxskit/modeling_data/accuracy.txt", "r") as g:
    accuracy = eval(g.readline())    
print('-----trained model: accuracies for classification models-----')
for model_name, acc in accuracy.items():
    print('{}: {:.4f}'.format(model_name,acc))

-----trained model: accuracies for classification models-----
unidentified: 0.9861
spherical_normal: 0.9913
guinier_porod: 0.8253
diffraction_peaks: 0.9663


In [37]:
with open("../saxskit/modeling_data/accuracy_regression.txt", "r") as g:
    accuracy_reg = eval(g.readline())    
print('-----trained model: accuracies for rergession models-----')
for model_name, acc in accuracy_reg.items():
    print('{}: {:.4f}'.format(model_name,acc))

-----trained model: accuracies for rergession models-----
r0_sphere: 0.1431
sigma_sphere: 0.6480
rg_gp: 0.2306


#### Step 2. Get data from Citrination using Citrination credentials

In [22]:
with open("../../api_key.txt", "r") as g:
    a_key = g.readline().strip()
cl = CitrinationClient(site='https://slac.citrination.com',api_key=a_key)

new_data = get_data_from_Citrination(client = cl, dataset_id_list= [16]) # [16] is a list of datasets ids

#### Step 3 (optional). Get all available data from Citrination

If we want to know the accuracy of the updated models, it is recommended to calculate it against the full training set. 

In [23]:
all_data = get_data_from_Citrination(client = cl, dataset_id_list= [1,15,16])

#### Step 4. Update Classifiers

In [32]:
t0 = time()
train_classifiers_partial(new_data, yaml_filename = None, all_training_data = all_data)
print("Model updates took {:.2f} minutes".format((time()-t0)/60))

Model updates took 0.11 minutes


#### Accuracy after updating:

In [33]:
with open("../saxskit/modeling_data/accuracy.txt", "r") as g:
    accuracy = eval(g.readline())
for model_name, acc in accuracy.items():
    print('{}: {:.4f}'.format(model_name,acc))

diffraction_peaks: 0.9839
guinier_porod: 0.7874
spherical_normal: 0.9925
unidentified: 0.9870


#### Step 5. Update regression models

In [38]:
t0 = time()
train_regressors_partial(new_data, yaml_filename = None, all_training_data = all_data)
print("Model updates took {:.2f} minutes".format((time()-t0)/60))

Model updates took 0.69 minutes


#### Accuracy after updating:

In [39]:
with open("../saxskit/modeling_data/accuracy_regression.txt", "r") as g:
    accuracy = eval(g.readline())
for model_name, acc in accuracy.items():
    print('{}: {:.4f}'.format(model_name,acc))

r0_sphere: 8059346029.5767
rg_gp: 0.3215
sigma_sphere: 0.5780


#### Step 6. Compare accuracy and restore previous models if needed

If the new accuracy is worse than the accuracy before the update, the models from before the update can be restored.

In [13]:
with open('../saxskit/modeling_data/scalers_and_models.yml', 'w') as yaml_file:
    yaml.dump(s_and_m, yaml_file)
    
with open('../saxskit/modeling_data/scalers_and_models_regression.yml', 'w') as yaml_file:
    yaml.dump(s_and_m_regresson, yaml_file)        

In [14]:
with open ('../saxskit/modeling_data/accuracy.txt', 'w') as txt_file:
    txt_file.write(str(accuracy))
    
with open ('../saxskit/modeling_data/scalers_and_models_regression.yml', 'w') as txt_file:
    txt_file.write(str(accuracy_reg))

In [15]:
with open("../saxskit/modeling_data/accuracy.txt", "r") as g:
    accuracy = eval(g.readline())    
print('-----trained model: accuracies for classification models-----')
for model_name, acc in accuracy.items():
    print('{}: {:.4f}'.format(model_name,acc))

-----trained model: accuracies for classification models-----
r0_sphere: 0.2658
rg_gp: 0.6377
sigma_sphere: 0.5590


In [17]:
accuracy_reg

{'r0_sphere': 0.29217359855279595,
 'rg_gp': 0.2768065209314346,
 'sigma_sphere': 0.6123823239209741}