In [1]:
import ast
import warnings
warnings.filterwarnings("ignore")
from time import time

In [2]:
from citrination_client import CitrinationClient
from saxskit.saxs_models import get_data_from_Citrination

from saxskit.saxs_models import train_classifiers_partial, train_regressors_partial
from saxskit.saxs_models import train_classifiers, train_regressors

## Updating models using data from Citrination

Suppose we have received a new dataset and want to update our models with it. Since training "from scratch" takes a significant amount of time (especially for the regression models), we use `train_classifiers_partial()` and `train_regressors_partial()` to update the existing models with the new data instead.

#### Step 1 (Optional): Load the current set of models

To compare the current models against the updated ones, load the current models from the `scalers_and_models.yaml` file where they are saved. (TODO: this step is not implemented yet.)

#### Step 2. Get the new data from Citrination using your Citrination credentials

In [3]:
# The Citrination API key is read from a local file (first line only) rather
# than being written into the notebook.
with open("../../api_key.txt", "r") as key_file:
    api_key = key_file.readline().strip()
cl = CitrinationClient(site='https://slac.citrination.com', api_key=api_key)

# Dataset 16 is the new dataset used for the incremental update.
new_dataset_ids = [16]
new_data = get_data_from_Citrination(client=cl, dataset_id_list=new_dataset_ids)

#### Step 3 (optional). Get all available data from Citrination

To measure the accuracy of the updated models, it is recommended to evaluate them against the full training set (all available datasets), not just the new data.

In [4]:
all_dataset_ids = [1, 15, 16]  # every dataset that makes up the full training set
all_data = get_data_from_Citrination(client=cl, dataset_id_list=all_dataset_ids)

#### Step 4. Update Classifiers

In [5]:
# Incrementally update the classifiers with the new dataset, passing the full
# data set from Step 3 as all_training_data, and report wall-clock minutes.
# NOTE(review): the saved output of this cell is
# "TypeError: unhashable type: 'collections.OrderedDict'", presumably raised
# from inside train_classifiers_partial; the cause is not visible from this
# notebook -- TODO investigate before relying on this step.
start = time()
train_classifiers_partial(new_data, yaml_filename=None, all_training_data=all_data)
elapsed_minutes = (time() - start) / 60
print("Model updates took {:.2f} minutes".format(elapsed_minutes))

TypeError: unhashable type: 'collections.OrderedDict'

In [None]:
# Print the per-model classifier accuracy stored in accuracy.txt.
# The first line is parsed as a Python literal (presumably a dict mapping
# model name -> accuracy; confirm against the writer in saxskit).
# ast.literal_eval only accepts literals, so unlike eval() it cannot run
# arbitrary code if the file has been tampered with.
with open("../saxskit/modeling_data/accuracy.txt", "r") as acc_file:
    accuracy = ast.literal_eval(acc_file.readline())
for model_name, acc in accuracy.items():
    print('{}: {:.4f}'.format(model_name, acc))

#### Step 5. Update regression models

In [None]:
# Incrementally update the regression models with the new dataset, passing
# the full data set from Step 3 as all_training_data, and report wall-clock minutes.
start = time()
train_regressors_partial(new_data, yaml_filename=None, all_training_data=all_data)
elapsed_minutes = (time() - start) / 60
print("Model updates took {:.2f} minutes".format(elapsed_minutes))

In [None]:
# Print the per-model regression accuracy stored in accuracy_regression.txt.
# The first line is parsed as a Python literal (presumably a dict mapping
# model name -> accuracy; confirm against the writer in saxskit).
# ast.literal_eval only accepts literals, so unlike eval() it cannot run
# arbitrary code if the file has been tampered with.
with open("../saxskit/modeling_data/accuracy_regression.txt", "r") as acc_file:
    accuracy = ast.literal_eval(acc_file.readline())
for model_name, acc in accuracy.items():
    print('{}: {:.4f}'.format(model_name, acc))

#### Step 6. Compare accuracy and restore previous models if needed

If the accuracy after the update is worse than the accuracy before it, the previous models can be restored. For now this is done by retraining from scratch on all available data (cells below), which is slow.

TODO: instead of re-training on all data, re-generate the scalers_and_models.yaml file from the models that were read in from the file at the beginning.

In [None]:
# Fallback: retrain the classifiers from scratch on all datasets, with
# hyperparameter search enabled.
train_classifiers(all_data, hyper_parameters_search=True)

In [None]:
# Print the classifier accuracy after retraining, parsed and formatted the
# same way as in Step 4 so the two results can be compared side by side.
# ast.literal_eval accepts only Python literals (presumably a dict of
# model name -> accuracy; confirm against the writer in saxskit).
with open("../saxskit/modeling_data/accuracy.txt", "r") as acc_file:
    accuracy = ast.literal_eval(acc_file.readline())
for model_name, acc in accuracy.items():
    print('{}: {:.4f}'.format(model_name, acc))

In [None]:
# Fallback: retrain the regression models from scratch on all datasets,
# with hyperparameter search enabled.
train_regressors(all_data, hyper_parameters_search=True)

In [None]:
# Print the regression accuracy after retraining, parsed and formatted the
# same way as in Step 5 so the two results can be compared side by side.
# ast.literal_eval accepts only Python literals (presumably a dict of
# model name -> accuracy; confirm against the writer in saxskit).
with open("../saxskit/modeling_data/accuracy_regression.txt", "r") as acc_file:
    accuracy = ast.literal_eval(acc_file.readline())
for model_name, acc in accuracy.items():
    print('{}: {:.4f}'.format(model_name, acc))