# Active Learning Tutorial: Zundel cation

1. Load all the modules.
2. Read the training pool.
3. Select a random initial training set from the pool (`N_INIT_TRAIN` structures, e.g. with `np.random.choice`) --> later, exclude these structures from the pool.
4. Train a committee (just to check that we can train two models).
5. Predict on the training pool and sort the structures by maximum energy error.
6. Then repeat these steps in a for loop.

## To Do 

- for loop everywhere
- avoid using scripts for MACE
- fix E0s

In [None]:
from IPython.display import Image, display

# Show the Zundel image that is stored next to the initial dataset.
zundel_image = Image(filename='../initial-datasets/zundel/zundel.png')
display(zundel_image)

## Import modules

In [None]:
import os, sys
# import warnings
# import logging
from contextlib import redirect_stdout, redirect_stderr
import multiprocessing
import numpy as np
import matplotlib.pyplot as plt
from ase.io import read, write                                      # read and write structures
from ase.visualize import view                                      # visualize structures (optional)
from mace.cli.run_train import main as mace_run_train_main          # train a MACE model
from mace.cli.eval_configs import main as mace_eval_configs_main    # evaluate a MACE model
# Seed NumPy's global random generator so that the random draws made with
# np.random further down (structure selection, model seeds) are reproducible.
np.random.seed(0)

In [None]:
# Use the Matplotlib style sheet 'notebook.mplstyle' for the plots of this notebook.
plt.style.use('notebook.mplstyle')

In [None]:
# definition of some helper functions
def extxyz2energy(file:str,keyword:str="MACE_energy"):
    """
    Read every frame of an extxyz file and collect one energy per frame.

    The value is looked up in each frame's `info` dictionary under `keyword`
    and the values are returned as a 1D float numpy array, in file order.
    """
    frames = read(file, index=':')
    return np.array([frame.info[keyword] for frame in frames], dtype=float)

def train_mace(config:str):
    """
    Train a MACE model using the provided configuration file.

    The MACE training entry point is called without arguments, so the options
    are handed over through `sys.argv`. The original `sys.argv` is restored
    afterwards (also when training raises), so the fake command line does not
    leak into the rest of the session.

    Parameters
    ----------
    config : str
        Path of the configuration file passed to MACE via `--config`.
    """
    saved_argv = sys.argv
    sys.argv = ["program", "--config", config]
    try:
        mace_run_train_main()
    finally:
        # always undo the temporary command line
        sys.argv = saved_argv
    
def eval_mace(model:str,infile:str,outfile:str):
    """
    Evaluate a MACE model.

    The MACE evaluation entry point is called without arguments, so the
    options are handed over through `sys.argv`. The original `sys.argv` is
    restored afterwards, also when the evaluation raises.

    Parameters
    ----------
    model : str
        Path of the model file, passed via `--model`.
    infile : str
        Path of the file with the structures to evaluate, passed via `--configs`.
    outfile : str
        Path of the output file, passed via `--output`.
    """
    saved_argv = sys.argv
    # use the full option name `--configs`; the shortened `--config` is only
    # accepted because argparse resolves unambiguous prefixes
    sys.argv = ["program", "--configs", infile, "--output", outfile, "--model", model]
    try:
        mace_eval_configs_main()
    finally:
        # always undo the temporary command line
        sys.argv = saved_argv

In [None]:
# Create the output folders used below; existing folders are left untouched.
for folder in ('config', 'models', 'structures'):
    os.makedirs(folder, exist_ok=True)
# creation of 'log', 'chk' and 'results' is left disabled here

In [None]:
N_INIT_TRAIN = 20   # number of structures in the initial training set
N_TEST = 50         # number of structures put aside as test set
N_COMMITTEE = 4     # number of MACE models in the committee
PARALLEL = True     # True: train the committee with a multiprocessing pool; False: train one model after another

## Select initial training structures

In [None]:
# Load all the structures of the pool from file
pool_file = '../initial-datasets/zundel/train.extxyz'
structures = read(pool_file, index=':')
print(f'Total number of structures: {len(structures)}')
# view(structures)  # optional: opens an interactive GUI window to visualize the structures

In [None]:
# Draw (without replacement) the structures of the initial training set and of the test set
n_structures = len(structures)
selected_indices = np.random.choice(n_structures, size=N_INIT_TRAIN+N_TEST, replace=False)
# every index that was not drawn remains a candidate
remaining_candidate_idcs = np.setdiff1d(np.arange(n_structures), selected_indices)

# the first N_INIT_TRAIN drawn indices are used for training, the others for testing
indices_train, indices_test = np.split(selected_indices, [N_INIT_TRAIN])
assert len(indices_train) == N_INIT_TRAIN
assert len(indices_test) == N_TEST

print(f'\nSelected indices for training: {indices_train}')
print(f'\nSelected indices for test: {indices_test}')

initial_training_set = [structures[idx] for idx in indices_train]
test_set = [structures[idx] for idx in indices_test]

# store both sets as extxyz files in the 'structures' folder
print(f"\nSaving the initial training set to 'structures/init.train.extxyz'")
write('structures/init.train.extxyz', initial_training_set, format='extxyz')

print(f"\nSaving the test set to 'structures/test.extxyz'")
write('structures/test.extxyz', test_set, format='extxyz')

## Initial Training

Hyperparameters for the committee members

In [None]:
# Write one MACE configuration file per committee member.
# All members share the same hyperparameters; only the name and the random seed differ.
os.makedirs('config', exist_ok=True)
seeds = np.random.randint(0, 2**32 - 1, size=N_COMMITTEE, dtype=np.uint32)
# loop over N_COMMITTEE (not a hard-coded 4) so that the number of config files
# always matches the number of seeds and the number of models trained later
for i in range(N_COMMITTEE):
    filename = f"config/config.{i}.yml"
    name = f"mace.com={i}"

    config_text = f"""
# You can modify the following parameters
num_channels: 16
max_L: 0            # take it larger but not smaller
max_ell: 1          # take it larger but not smaller
correlation: 1      # take it larger but not smaller
num_interactions: 2 # take it larger but not smaller

# ... but you can also modify these ones
r_max: 4.0
batch_size: 4
max_num_epochs: 100

# But please, do not modify these parameters!
model: "MACE"
name: "{name}"
model_dir: "models"
log_dir: "log"
checkpoints_dir: "checkpoints"
results_dir: "results"
train_file: "structures/init.train.extxyz"
energy_key: "REF_energy"
forces_key: "REF_forces"
E0s: "average" # to be fixed
device: cpu
swa: true
seed: {seeds[i]}
restart_latest: True
"""

    with open(filename, "w") as f:
        f.write(config_text)

    print(f"Wrote {filename}")

Train a committee of MACE models.

In [None]:
# %%capture # suppress output
# Train a committee of MACE models, one model per configuration file.
os.makedirs('models', exist_ok=True)
if PARALLEL:
    def train_single_model(n):
        """Train committee member `n` while discarding everything written to stdout/stderr."""
        config_path = f"config/config.{n}.yml"
        with open(os.devnull, 'w') as fnull:
            with redirect_stdout(fnull), redirect_stderr(fnull):
                train_mace(config_path)

    # only N_COMMITTEE jobs are submitted, so do not start more workers than that
    n_workers = max(1, min(N_COMMITTEE, multiprocessing.cpu_count()))
    with multiprocessing.Pool(processes=n_workers) as pool:
        pool.map(train_single_model, range(N_COMMITTEE))
else:
    for n in range(N_COMMITTEE):
        train_mace(f"config/config.{n}.yml")

# it should take around 25s

In [None]:
# Tidy up after training: debug logs, superfluous model files and result files.

# delete the debug logs
for entry in os.listdir('log'):
    if not entry.endswith('_debug.log'):
        continue
    os.remove(os.path.join('log', entry))

for n in range(N_COMMITTEE):
    prefix = f"models/mace.com={n}"

    # delete these model variants if they are present
    for suffix in (".model", "_compiled.model", "_stagetwo.model"):
        model_file = prefix + suffix
        if os.path.exists(model_file):
            os.remove(model_file)

    # keep the compiled stage-two model, under the name 'models/mace.n=<n>.model'
    final_model = prefix + "_stagetwo_compiled.model"
    if os.path.exists(final_model):
        os.rename(final_model, f"models/mace.n={n}.model")

# delete the '.txt' and '...stage_one.png' files from the results folder
for entry in os.listdir('results'):
    if entry.endswith(('.txt', 'stage_one.png')):
        os.remove(os.path.join('results', entry))

## Evaluation

In [None]:
# Evaluate two committee members on the whole pool.
# The cleanup cell stores the final models as 'models/mace.n=<n>.model', so these
# paths are used here; both output files are needed by the following cells.
eval_mace("models/mace.n=0.model",'../initial-datasets/zundel/train.extxyz',"eval_train_01.extxyz") # 50s
eval_mace("models/mace.n=1.model",'../initial-datasets/zundel/train.extxyz',"eval_train_02.extxyz")

In [None]:
# Energies predicted by the two evaluated models, one value per structure
E1, E2 = (extxyz2energy(path) for path in ("eval_train_01.extxyz", "eval_train_02.extxyz"))

In [None]:
# Scatter plot of the energies of the first evaluation against those of the second one
plt.scatter(E1,E2)

In [None]:
# Predicted energies of both evaluations for the structures drawn as initial training + test set
plt.plot(E1[selected_indices], label='MACE 1')
plt.plot(E2[selected_indices], label='MACE 2')

In [None]:
# Committee statistics (rows of E: structures, columns of E: models)
E = np.vstack((E1, E2)).T
mean = E.mean(axis=1)  # committee mean of each structure
std = E.std(axis=1)    # committee standard deviation of each structure

In [None]:
# Overlay the predictions of the two models and the committee mean
for series, name, colour in ((E1, 'MACE 1', 'red'), (E2, 'MACE 2', 'blue')):
    plt.plot(series, label=name, color=colour, alpha=0.5)
plt.plot(mean, label='mean', color='green')
# plt.fill_between(range(len(mean)), mean-std, mean+std, color='green', alpha=0.2)
plt.legend()

In [None]:
# Committee mean with a shaded band of one standard deviation below and above it
band_x = range(len(mean))
band_lower, band_upper = mean - std, mean + std
plt.plot(mean, label='mean', color='green')
plt.fill_between(band_x, band_lower, band_upper, color='green', alpha=0.2)
plt.legend()

In [None]:
# Select the 10 remaining candidates with the largest committee disagreement.
# np.argsort sorts in ascending order, hence the reversal; it also returns positions
# inside the `remaining_candidate_idcs` subset, which are mapped back to indices of
# the full pool so that they can be used with `std` and as x-coordinates.
order = np.argsort(std[remaining_candidate_idcs])[::-1][:10]
new_candidates = remaining_candidate_idcs[order]
plt.scatter(np.arange(len(std)),std, label='std',color='green',s=1)
plt.scatter(new_candidates,std[new_candidates],color="red",s=1)

# Select relevant training data via Query by Committee (QbC)

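In every QbC iteration the committee predicts the energies of all candidate structures, the candidates with the largest committee disagreement (standard deviation of the predicted energies) are added to the training set, and the committee is retrained on the enlarged training set.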

In [None]:
def run_qbc(committee, fn_candidates, n_iter, n_add_iter=10, recalculate_selected=False, calculator=None):
    """Main QbC loop.

    In every iteration each committee member predicts the energies of the
    structures in `fn_candidates`. The committee disagreement (standard
    deviation of the predicted energies over the members) is evaluated for the
    structures that are still in the pool, and the `n_add_iter` structures
    with the largest disagreement are moved from the pool to the training set.

    Parameters
    ----------
    committee : sequence
        One entry per committee member. Each entry is handed to `eval_mace`
        as its `model` argument (assumption: the entries are model file
        paths -- TODO confirm).
    fn_candidates : str
        File with the candidate structures; all its frames are read.
    n_iter : int
        Number of QbC iterations.
    n_add_iter : int
        Number of structures added to the training set in every iteration.
    recalculate_selected : bool
        If True, energies and forces of the selected structures are
        recalculated with `calculator`.
    calculator : ASE calculator, optional
        Required when `recalculate_selected` is True.

    Returns
    -------
    list
        The structures selected so far, in the order in which they were picked.

    Raises
    ------
    ValueError
        If `recalculate_selected` is True but no `calculator` is given.
    """
    # TODO: think about striding the candidates to make it more efficient

    print(f'Starting QbC.')
    print(f'{n_iter:d} iterations will be done in total and {n_add_iter:d} will be added every iteration.')

    if recalculate_selected and calculator is None:
        raise ValueError('If a first-principles recalculation of training data is requested, a corresponding ASE calculator must be provided.')

    candidates = read(fn_candidates, index=':')  # all frames, not only the last one
    remaining = np.arange(len(candidates))       # indices (into `candidates`) that are still in the pool
    training_set = []
    fn_eval = "eval_train_01.extxyz"             # scratch file holding the predictions of one model
    for _ in range(n_iter):

        if remaining.size == 0:
            print(f'The candidate pool is empty, stopping QbC.')
            break

        # predict sigma on all candidates
        print(f'Predicting committee disagreement across the candidate pool.')
        energies = []
        for model in committee:
            # every member is evaluated on the full candidate file, so the
            # predictions stay aligned with the indices stored in `remaining`
            eval_mace(model, fn_candidates, fn_eval)
            energies.append(extxyz2energy(fn_eval))
            #FIXME: if we could do something about passing the files back and forth, that would be great, but only if it is not too much work...
        energies = np.array(energies)  # one row per committee member, one column per structure
        # spread over the committee members (axis 0), restricted to the structures still in the pool
        disagreement = energies.std(axis=0)[remaining]
        avg_disagreement_pool = disagreement.mean()

        # pick the `n_add_iter` highest-disagreement structures
        print(f'Picking {n_add_iter:d} new highest-disagreement data points.')
        # argsort is ascending: reverse it to start from the largest disagreement
        picked = np.argsort(disagreement)[::-1][:n_add_iter]
        avg_disagreement_selected = disagreement[picked].mean()
        selected = [candidates[i] for i in remaining[picked]]
        if recalculate_selected:
            print(f'Recalculating ab initio energies and forces for new data points.')
            for structure in selected:
                structure.calc = calculator
                # store the new labels under the keys used as `energy_key` and
                # `forces_key` in the training configuration files of this notebook
                structure.info['REF_energy'] = structure.get_potential_energy()
                structure.arrays['REF_forces'] = structure.get_forces()
        training_set.extend(selected)
        remaining = np.delete(remaining, picked)

        # retrain the committee with the enriched training set
        print(f'Retraining committee.')
        for model in committee:
            # NOTE(review): still a placeholder -- the configuration file to retrain
            # with has not been defined yet, so this call fails as written.
            train_mace(...) # start from a previous set of NN parameters?
            # TODO: dump some info about current errors

        print(f'Status at the end of this QbC iteration: Disagreement (pool) [eV]    Disagreement (selected) [eV]')
        print(f'                                         {avg_disagreement_pool:06f} {avg_disagreement_selected:06f}')

    return training_set
