# Active-learning tutorial: Using committee MACE models to study protonated water clusters

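1. Load all the modules.
2. Read the training pool.
3. Select a random initial training set (20 structures) and a test set (50 structures) from the pool with `np.random.choice`; the remaining structures are kept as the candidate pool.
4. Train a committee of MACE models (4 members).
5. Predict energies on the training pool and rank the structures by committee disagreement (standard deviation of the predicted energies).
6. Repeat steps 4–5 in a loop (Query by Committee), moving the highest-disagreement structures into the training set.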

## To Do 

- for loop everywhere
- avoid using scripts for MACE
- fix E0s

In [None]:
from IPython.display import Image, display

# Show a picture of the Zundel structure used throughout this tutorial.
zundel_png = '../initial-datasets/zundel/zundel.png'
display(Image(filename=zundel_png))

## Import modules

In [None]:
import os, sys
import multiprocessing
from pathlib import Path
from tqdm.notebook import tqdm
from contextlib import redirect_stdout, redirect_stderr

import numpy as np
import matplotlib.pyplot as plt

from ase.io import read, write # read and write structures
# from ase.visualize import view # visualize structures (optional)

# import functions to run this tutorial
from myfunctions import train_mace     # train MACE model
from myfunctions import eval_mace      # evaluate MACE model
from myfunctions import extxyz2energy  # extract energy from extxyz file

In [None]:
# Reproducible randomness, the shared plotting style, and the output directories.
np.random.seed(0)
plt.style.use('notebook.mplstyle')
for folder in ('config', 'models', 'structures'):
    os.makedirs(folder, exist_ok=True)

In [None]:
# Sizes of the initial training set and of the held-out test set, number of
# committee members, and whether the members are trained in parallel processes.
n_init_train, n_test, n_committee = 20, 50, 4
parallel = False

## Select initial training structures

In [None]:
# Load every reference structure of the pool.
pool_file = '../initial-datasets/zundel/train.extxyz'
structures = read(pool_file, index=':')
n_structures = len(structures)
print(f'Total number of structures: {n_structures}')
# view(structures)  # Opens an interactive GUI window to visualize the structures

In [None]:
# Draw the initial training and test sets at random, without replacement;
# all other structures stay in the candidate pool for active learning.
n_selected = n_init_train + n_test
selected_indices = np.random.choice(len(structures), size=n_selected, replace=False)
remaining_candidate_idcs = np.delete(np.arange(len(structures)), selected_indices)

indices_train, indices_test = selected_indices[:n_init_train], selected_indices[n_init_train:]
assert len(indices_train) == n_init_train
assert len(indices_test) == n_test

print(f'\nSelected indices for training: {indices_train}')
print(f'\nSelected indices for test: {indices_test}')

initial_training_set = [structures[idx] for idx in indices_train]
test_set = [structures[idx] for idx in indices_test]
remaining_structures = [structures[idx] for idx in remaining_candidate_idcs]

# Dump the three subsets, announcing each file before it is written.
for label, path, subset in (
    ('initial training set', 'structures/init.train.extxyz', initial_training_set),
    ('test set', 'structures/test.extxyz', test_set),
    ('remaining structures', 'structures/remaining.extxyz', remaining_structures),
):
    print(f"\nSaving the {label} to '{path}'")
    write(path, subset, format='extxyz')

## Initial Training

Hyperparameters for the committee members

In [None]:
# Write one MACE config file per committee member. Besides the file and model
# names, the members differ only in their random seed.
os.makedirs('config', exist_ok=True)
seeds = np.random.randint(0, 2**32 - 1, size=n_committee, dtype=np.uint32)
for i, seed in enumerate(seeds):
    filename = f"config/config.{i}.yml"
    name = f"mace.com={i}"

    config_text = f"""
# You can modify the following parameters
num_channels: 16
max_L: 0            # take it larger but not smaller
max_ell: 1          # take it larger but not smaller
correlation: 1      # take it larger but not smaller
num_interactions: 2 # take it larger but not smaller

# ... but you can also modify these ones
r_max: 4.0
batch_size: 4
max_num_epochs: 100

# But please, do not modify these parameters!
model: "MACE"
name: "{name}"
model_dir: "models"
log_dir: "log"
checkpoints_dir: "checkpoints"
results_dir: "results"
train_file: "structures/init.train.extxyz"
energy_key: "REF_energy"
forces_key: "REF_forces"
E0s: "average" # to be fixed
device: cpu
swa: true
seed: {seed}
restart_latest: True
"""

    Path(filename).write_text(config_text)
    print(f"Wrote {filename}")

In [None]:
# Train a committee of MACE models, discarding the console output of every run.
# The `parallel` flag comes from the parameters cell; it is deliberately not
# re-assigned here, so that the setting there actually takes effect.
os.makedirs('models', exist_ok=True)
if parallel:
    def train_single_model(n):
        """Train committee member ``n`` from ``config/config.{n}.yml`` with its output silenced."""
        config_path = f"config/config.{n}.yml"
        # Discard output via os.devnull, like the sequential branch: a shared
        # regular file would be truncated and interleaved by concurrent workers.
        with open(os.devnull, 'w') as fnull:
            with redirect_stdout(fnull), redirect_stderr(fnull):
                train_mace(config_path)

    # No point in starting more workers than there are committee members.
    # NOTE(review): sending a notebook-defined function to a Pool relies on the
    # 'fork' start method; confirm this works where 'spawn' is the default.
    n_workers = min(n_committee, multiprocessing.cpu_count())
    with multiprocessing.Pool(processes=n_workers) as pool:
        pool.map(train_single_model, range(n_committee))
else:
    for n in range(n_committee):
        with open(os.devnull, 'w') as fnull:
            with redirect_stdout(fnull), redirect_stderr(fnull):
                train_mace(f"config/config.{n}.yml")

# it should take around 25s

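After training the committee of MACE models, remove the intermediate files and keep only the final (stage-two, compiled) model of each member, renamed to `models/mace.n={n}.model`.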

In [None]:
# Remove intermediate training outputs and keep a single final model per
# committee member.

# `*_debug.log` files in the log directory.
if os.path.isdir('log'):
    for filename in os.listdir('log'):
        if filename.endswith('_debug.log'):
            os.remove(os.path.join('log', filename))

for n in range(n_committee):
    # Intermediate model files; only the stage-two compiled model is kept.
    for filename in (f"models/mace.com={n}.model",
                     f"models/mace.com={n}_compiled.model",
                     f"models/mace.com={n}_stagetwo.model"):
        if os.path.exists(filename):
            os.remove(filename)

    # os.replace (unlike os.rename) also overwrites an existing target on
    # Windows, so re-running this cell after retraining does not fail.
    fn_final = f"models/mace.com={n}_stagetwo_compiled.model"
    if os.path.exists(fn_final):
        os.replace(fn_final, f"models/mace.n={n}.model")

# Text outputs and stage-one plots in the results directory.
if os.path.isdir('results'):
    for filename in os.listdir('results'):
        if filename.endswith(('.txt', 'stage_one.png')):
            os.remove(os.path.join('results', filename))

## Evaluation

In [None]:
# Predict the energies of the whole training pool with every committee member.
for member in tqdm(range(n_committee)):
    model_file = f'models/mace.n={member:d}.model'
    dump_file = f'eval_train_{member:02d}.extxyz'
    eval_mace(model_file, '../initial-datasets/zundel/train.extxyz', dump_file)

In [None]:
# Gather the predicted energies: one row per committee member.
dump_files = [f'eval_train_{member:02d}.extxyz' for member in range(n_committee)]
energies = np.array([extxyz2energy(fn_dump) for fn_dump in tqdm(dump_files)])

In [None]:
# Committee mean and spread (standard deviation) of the predicted energy,
# taken across members (axis 0), for every structure.
avg_energy = np.mean(energies, axis=0)
disagreement = np.std(energies, axis=0)

In [None]:
# Energies predicted by each committee member, with the committee mean on top.
for member, member_energy in enumerate(energies):
    plt.plot(member_energy, alpha=0.5, label=rf'$E_{member:d}$')
plt.plot(avg_energy, color='k', label=r'$\overline{E}$')
plt.legend()
plt.xlabel('Data point index')
plt.ylabel('Energy [eV]');

In [None]:
# Committee disagreement for every structure of the training pool.
ax = plt.gca()
ax.plot(disagreement)
ax.set_xlabel('Data point index')
ax.set_ylabel(r'$\sigma(E)$ [eV]');

# Select relevant training data via Query by Committee (QbC)

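In each Query-by-Committee (QbC) iteration, every committee member predicts the energy of all structures in the candidate pool. The structures on which the members disagree most (largest standard deviation of the predicted energy) are moved from the pool to the training set, and the committee is retrained on the enlarged set.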

In [None]:
def _committee_energies(fns_committee, fn_structures):
    """Return the energies predicted by every committee member, one row per member."""
    energies = []
    for n, model in enumerate(fns_committee):
        fn_dump = f'eval_train_{n:02d}.extxyz'
        eval_mace(model, fn_structures, fn_dump)  # Explicit arguments!
        # presumably one energy per structure, as in the plotting cells above
        energies.append(extxyz2energy(fn_dump))
    return np.array(energies)


def _refresh_committee_model(n, fn_model):
    """Move the freshly retrained member ``n`` to ``fn_model`` so the next iteration uses it."""
    # Same file the notebook renames to models/mace.n={n}.model after the
    # initial training (config name "mace.com={n}", model_dir "models").
    fn_trained = f"models/mace.com={n}_stagetwo_compiled.model"
    if os.path.exists(fn_trained):
        os.replace(fn_trained, fn_model)


def run_qbc(fns_committee, fn_candidates, fn_train_init, n_iter, n_add_iter=10, recalculate_selected=False, calculator=None):
    """Run the Query-by-Committee (QbC) active-learning loop.

    Each iteration predicts the energy of every candidate with every committee
    member. The ``n_add_iter`` candidates with the largest disagreement
    (standard deviation over members) are moved to the training set, and every
    member is retrained from ``config/config.{n}.yml``.

    Parameters
    ----------
    fns_committee : list of str
        Model files of the committee members. Member ``n`` is retrained from
        ``config/config.{n}.yml`` and its new model is moved back to
        ``fns_committee[n]``.
    fn_candidates : str
        extxyz file holding the initial candidate pool.
    fn_train_init : str
        extxyz file holding the initial training set. Selected structures are
        appended to it.
    n_iter : int
        Maximum number of QbC iterations. The loop stops early if the pool
        runs out.
    n_add_iter : int
        Number of structures moved to the training set per iteration.
    recalculate_selected : bool
        If True, attach ``calculator`` to the selected structures and compute
        their energies and forces before adding them.
    calculator : ASE calculator, optional
        Required when ``recalculate_selected`` is True.

    Raises
    ------
    ValueError
        If ``recalculate_selected`` is True but no calculator is given.

    Notes
    -----
    Every iteration writes ``train-iter.extxyz`` (the committee configs are
    expected to use it as ``train_file``) and ``candidates.extxyz``. At the end
    the loop writes ``train-final.extxyz`` and ``disagreement.txt``. The latter
    has one row per iteration: mean disagreement of the selected structures,
    then mean disagreement of the pool.
    """
    # TODO: think about striding the candidates to make it more efficient
    # TODO: start from training set size 0?
    # Validate up front, before any expensive model evaluation.
    if recalculate_selected and calculator is None:
        raise ValueError('If a first-principles recalculation of training data is requested, a corresponding ASE calculator must be assigned.')

    print('Starting QbC.')
    print(f'{n_iter:d} iterations will be done in total and {n_add_iter:d} will be added every iteration.')

    candidates = read(fn_candidates, index=':')
    # Seed with the initial training set so retraining never discards it.
    training_set = read(fn_train_init, index=':')
    progress_disagreement = []
    for _ in tqdm(range(n_iter)):
        if not candidates:
            print('Candidate pool is exhausted; stopping QbC early.')
            break

        # predict disagreement on all candidates
        print('Predicting committee disagreement across the candidate pool.')
        energies = _committee_energies(fns_committee, fn_candidates)
        disagreement = energies.std(axis=0)
        avg_disagreement_pool = disagreement.mean()

        # pick the `n_add_iter` highest-disagreement structures (ascending order).
        # An explicit start index is used because `[-n_add_iter:]` would
        # select the whole pool when n_add_iter == 0.
        print(f'Picking {n_add_iter:d} new highest-disagreement data points.')
        n_pick = min(n_add_iter, len(disagreement))
        idcs_selected = np.argsort(disagreement)[len(disagreement) - n_pick:]
        print(idcs_selected)
        avg_disagreement_selected = disagreement[idcs_selected].mean()
        progress_disagreement.append(np.array([avg_disagreement_selected, avg_disagreement_pool]))

        # `candidates` is a plain list, so select element-wise rather than
        # indexing it with the numpy index array.
        selected = [candidates[i] for i in idcs_selected]
        if recalculate_selected:
            print('Recalculating ab initio energies and forces for new data points.')
            # NOTE(review): results are stored on the attached calculator; confirm
            # they end up under the REF_energy / REF_forces keys the configs train on.
            for structure in selected:
                structure.calc = calculator
                structure.get_potential_energy()
                structure.get_forces()
        training_set.extend(selected)
        # Rebuild the pool in a single pass. Deleting indices one by one would
        # shift the later indices and remove the wrong structures.
        selected_set = set(idcs_selected.tolist())
        candidates = [atoms for i, atoms in enumerate(candidates) if i not in selected_set]

        # dump files with structures
        write('train-iter.extxyz', training_set, format='extxyz')
        write('candidates.extxyz', candidates, format='extxyz')

        # retrain the committee with the enriched training set
        print('Retraining committee.')
        # TODO: add multiprocessing
        # TODO: add model refinement
        # NOTE(review): the configs set `restart_latest: True`. Confirm MACE does not
        # resume from a checkpoint that already reached max_num_epochs.
        for n, model in enumerate(fns_committee):
            train_mace(f"config/config.{n}.yml")
            _refresh_committee_model(n, model)

        # the next evaluation must run on the reduced pool just written
        fn_candidates = 'candidates.extxyz'

        print('Status at the end of this QbC iteration: Disagreement (pool) [eV]    Disagreement (selected) [eV]')
        print(f'                                         {avg_disagreement_pool:06f} {avg_disagreement_selected:06f}')

    # dump final training set
    write('train-final.extxyz', training_set, format='extxyz')
    np.savetxt('disagreement.txt', progress_disagreement)

In [None]:
# Rewrite the committee configs so that retraining inside the QbC loop reads
# the training set dumped by run_qbc ("train-iter.extxyz").
# TODO: make this simpler - the only thing we need to change is the name of the training extxyz file.
# TODO: implement retraining using the refinement workflow using `foundation_model`
os.makedirs('config', exist_ok=True)
seeds = np.random.randint(0, 2**32 - 1, size=n_committee, dtype=np.uint32)
for i, seed in zip(range(n_committee), seeds):
    filename = f"config/config.{i}.yml"
    name = f"mace.com={i}"

    config_text = f"""
# You can modify the following parameters
num_channels: 16
max_L: 0            # take it larger but not smaller
max_ell: 1          # take it larger but not smaller
correlation: 1      # take it larger but not smaller
num_interactions: 2 # take it larger but not smaller

# ... but you can also modify these ones
r_max: 4.0
batch_size: 4
max_num_epochs: 100

# But please, do not modify these parameters!
model: "MACE"
name: "{name}"
model_dir: "models"
log_dir: "log"
checkpoints_dir: "checkpoints"
results_dir: "results"
train_file: "train-iter.extxyz"
energy_key: "REF_energy"
forces_key: "REF_forces"
E0s: "average" # to be fixed
device: cpu
swa: true
seed: {seed}
restart_latest: True
"""

    Path(filename).write_text(config_text)
    print(f"Wrote {filename}")

In [None]:
# Model files of the committee members, as named after the initial training.
fns_committee = [f'models/mace.n={member}.model' for member in range(n_committee)]

In [None]:
# Five QbC iterations, starting from the candidates left over after the
# initial train/test selection.
run_qbc(
    fns_committee,
    'structures/remaining.extxyz',
    'structures/init.train.extxyz',
    n_iter=5,
);

In [None]:
# QbC progress written by run_qbc: one row per iteration holding the mean
# disagreement of the selected structures and of the candidate pool.
# ndmin=2 keeps the (n_iter, 2) shape even after a single iteration, so that
# sigma[0] and sigma[1] are always per-iteration series.
sigma = np.loadtxt('disagreement.txt', ndmin=2).T

In [None]:
# Mean committee disagreement per QbC iteration, for the selected structures
# and for the whole candidate pool. The units sit outside math mode, as in
# the disagreement plot above.
plt.plot(sigma[0], '-o', label='Selected')
plt.plot(sigma[1], '-o', label='Candidates')
plt.legend()
plt.xlabel('QbC iteration')
plt.ylabel(r'$\sigma(E)$ [eV]');

## Run FHI-aims

In [None]:
from myfunctions import run_aims

In [None]:
# The first four structures of the pool, to be computed with FHI-aims below.
to_run = structures[0:4]

In [11]:
%%capture
run_aims(
    structures=to_run,
    folder='aims',
    command=f"mpirun -n 4 /home/stoccoel/codes/FHIaims-polarization/build/polarization-debug/aims.250131.scalapack.mpi.x",
    control="../aims/control.in"
)