## 1. Settings

## 1.1 Setting up the environment

In [None]:
import torch
import h5py

from tbmalt import Geometry, OrbitalInfo
from tbmalt.ml.module import Calculator
from tbmalt.physics.dftb import Dftb2
from tbmalt.physics.dftb.feeds import SkFeed, VcrSkFeed, SkfOccupationFeed, HubbardFeed
from tbmalt.io.dataset import DataSetIM

from ase.build import molecule

# HDF5 databases are not yet typecast when loaded, so double precision has to
# be made the global default before any data is read.
torch.set_default_dtype(torch.float64)

# Short alias used in type annotations.
Tensor = torch.Tensor

## 1.2 Setting up the molecular systems for training

In [None]:
# Molecules making up the training batch; each name is handed to
# `ase.build.molecule` when the geometry is constructed.
molecule_names = ['CH4', 'H2O']

# Orbital information for every element, keyed by atomic number and valued by
# the azimuthal quantum numbers of its shells:
#   {Z₁: [ℓᵢ, ℓⱼ, ..., ℓₙ], Z₂: [ℓᵢ, ℓⱼ, ..., ℓₙ], ...}
shell_dict = {
    1: [0],
    6: [0, 1],
    8: [0, 1],
}

# Reference values the model is fitted to. There is one row per entry of
# `molecule_names`; the shorter H2O row is filled up with zeros so that both
# rows have the same length.
targets = {
    'q_final_atomic': torch.tensor([
        [4.251914, 0.937022, 0.937022, 0.937022, 0.937022],
        [6.526248, 0.736876, 0.736876, 0, 0],
    ]),
}

## 1.3 Setting up the model for training

In [None]:
# NOTE: run setup.ipynb first; it downloads the parameter sets referenced here.
# Paths to the HDF5 parameter databases: the first one feeds the VCR
# Hamiltonian/overlap models, the second one the standard Slater-Koster feeds
# that serve as the untrained reference.
# NOTE(review): the first path contains a doubled `data/data` segment, unlike
# the second one — confirm it matches where setup.ipynb stores the file.
parameter_db_path = '../../data/data/example_dftb_vcr.h5'
parameter_db_path_std = '../../data/example_dftb_parameters.h5'

# Maximum number of training epochs.
number_of_epochs = 50

# Learning rates: `lr` for the compression radii, `onsite_lr` for the on-site
# energies.
lr, onsite_lr = 0.01, 1e-3

# Training stops early once the loss changes by less than this between two
# consecutive epochs.
tolerance = 1e-6

# Loss function comparing predicted and reference properties.
criterion = torch.nn.MSELoss(reduction='mean')

# Whether the DFTB Hubbard U values are shell resolved.
shell_resolved = False

## 1.4 Setting up the DFTB calculator

## 1.4.1 Input the molecular systems

In [None]:
# Construct the `Geometry` and `OrbitalInfo` objects. The former is analogous
# to the ase.Atoms object while the latter provides information about what
# orbitals are present and which atoms they belong to.
geometry = Geometry.from_ase_atoms([molecule(name) for name in molecule_names])

# Honour the `shell_resolved` setting rather than hard-coding `False`, so that
# this object agrees with the one rebuilt inside the training loop.
orbs = OrbitalInfo(geometry.atomic_numbers, shell_dict,
                   shell_resolved=shell_resolved)

# Identify which species are present. Atomic number zero marks padding in the
# batched array, so it is stripped before converting to a standard list.
species = torch.unique(geometry.atomic_numbers)
species = species[species != 0].tolist()

## 1.4.2 Loading of the DFTB parameters into their associated feed objects

In [None]:
# Feeds read from the VCR database; these are the ones driven during training.
# On-site gradients are switched off for the Hamiltonian feed on loading.
h_feed = VcrSkFeed.from_database(
    parameter_db_path, species, 'hamiltonian', requires_grad_onsite=False)
s_feed = VcrSkFeed.from_database(parameter_db_path, species, 'overlap')
o_feed = SkfOccupationFeed.from_database(parameter_db_path, species)
u_feed = HubbardFeed.from_database(parameter_db_path, species)

# Matching feeds read from the standard database: Hamiltonian, overlap,
# occupation and Hubbard-U, in that order.
h_feed_std = SkFeed.from_database(parameter_db_path_std, species, 'hamiltonian')
s_feed_std = SkFeed.from_database(parameter_db_path_std, species, 'overlap')
o_feed_std = SkfOccupationFeed.from_database(parameter_db_path_std, species)
u_feed_std = HubbardFeed.from_database(parameter_db_path_std, species)

## 1.4.3 Constructing the SCC-DFTB calculator

In [None]:
# Build two calculators with identical settings: the first from the VCR feeds
# (the one that gets trained), the second from the standard feeds.
dftb_calculator_init, dftb_calculator_init_std = (
    Dftb2(*feeds, filling_temp=None, filling_scheme=None)
    for feeds in (
        (h_feed, s_feed, o_feed, u_feed),
        (h_feed_std, s_feed_std, o_feed_std, u_feed_std),
    )
)

## 1.4.4 Constructing the machine learning object

In [None]:
def build_optim(dftb_calculator):
    """Create the trainable parameters and the Adam optimizer for VCR training.

    Args:
        dftb_calculator: calculator whose ``h_feed._on_sites`` mapping supplies
            the starting on-site values. The values are copied; the mapping
            itself is not modified here.

    Returns:
        comp_r0: parameter holding one global compression radius per species
            (three values; reported elsewhere in this file as H, C, O).
        onsite_dict: detached, gradient-enabled copies of the on-site values,
            keyed by ``(species_key, l)``.
        optimizer: Adam optimizer that steps the radii with ``lr`` and every
            on-site value with ``onsite_lr``.

    """
    # One global compression radius per atomic species.
    comp_r0 = torch.nn.parameter.Parameter(torch.tensor([3.0, 2.7, 2.3]))

    onsite_dict = {}
    onsite_groups = []
    for species_key, onsite_values in dftb_calculator.h_feed._on_sites.items():
        for ell in shell_dict[int(species_key)]:
            # The entry at index l² is taken as the value of shell l.
            shell_onsite = onsite_values[int(ell ** 2)].detach().clone()
            onsite_dict[(species_key, ell)] = shell_onsite
            onsite_groups.append(
                {'params': shell_onsite.requires_grad_(), 'lr': onsite_lr})

    param_groups = [{'params': comp_r0, 'lr': lr}, *onsite_groups]
    optimizer = torch.optim.Adam(param_groups, lr=lr)

    return comp_r0, onsite_dict, optimizer

## 2. Model training

In [None]:
def calculate_losses(calculator: Calculator, targets) -> Tensor:
    """Evaluate the training loss for the current state of the model.

    Args:
        calculator: calculator object that has already been run and from which
            the predicted ``q_final_atomic`` values are read.
        targets: mapping holding the reference data; only the
            ``'q_final_atomic'`` entry is used.

    Returns:
        loss: ``criterion`` applied to the predicted and reference values.

    """
    predicted = calculator.q_final_atomic
    reference = targets['q_final_atomic']

    # Start from zero so that terms for further properties can be added later.
    loss = 0.0 + criterion(predicted, reference)

    return loss

In [None]:
# Execution: fit the per-species compression radii and the on-site energies of
# the VCR calculator to the reference data in `targets`.
comp_r, onsite_dict, optimizer = build_optim(dftb_calculator_init)

# Detached snapshot of the starting radii, kept only for the final report.
comp_r0 = comp_r.detach().clone()
loss_old = 0

# Padding entries of the batched atomic-number array are zero. Only real atoms
# may take part in the range check on the compression radii further down.
real_atoms = geometry.atomic_numbers != 0

for epoch in range(number_of_epochs):
    orbs = OrbitalInfo(geometry.atomic_numbers, shell_dict,
                       shell_resolved=shell_resolved)

    # Expand the per-species radii into a per-atom array. Entries belonging to
    # padding atoms keep the placeholder value of one.
    this_cr = torch.ones(geometry.atomic_numbers.shape)
    for ii, iatm in enumerate(geometry.unique_atomic_numbers()):
        this_cr[iatm == geometry.atomic_numbers] = comp_r[ii]

    # Write the trainable on-site values back into the Hamiltonian feed.
    if not shell_resolved:
        for iatm in geometry.unique_atomic_numbers().tolist():
            for l in shell_dict[iatm]:
                # Shell l spans the 2l+1 entries starting at index l², which
                # is the index `build_optim` reads its value from. (The former
                # offset of `l` agreed with this only for l <= 1.)
                for idx in range(l ** 2, (l + 1) ** 2):
                    dftb_calculator_init.h_feed._on_sites[str(iatm)][idx] = onsite_dict[(str(iatm), l)]

    # Perform the forwards operation
    dftb_calculator_init.h_feed.compression_radii = this_cr
    dftb_calculator_init.s_feed.compression_radii = this_cr
    dftb_calculator_init(geometry, orbs, grad_mode="direct")

    # Calculate the loss
    loss = calculate_losses(dftb_calculator_init, targets)
    optimizer.zero_grad()
    print(epoch, loss)

    # Invoke the autograd engine
    loss.backward(retain_graph=True)
    optimizer.step()

    # Stop once the loss no longer changes noticeably between epochs.
    if torch.abs(loss_old - loss.detach()).lt(tolerance):
        break
    loss_old = loss.detach().clone()

    # Keep the compression radii inside a reasonable range. The check looks at
    # the radii used in this epoch and is restricted to real atoms: the
    # padding placeholders (1.0) would otherwise always fall below the lower
    # threshold and force the clamp in every epoch.
    used_radii = this_cr.detach()[real_atoms]
    with torch.no_grad():
        if used_radii.lt(1.75).any():
            comp_r.clamp_(min=2.0)
        if used_radii.gt(9.5).any():
            comp_r.clamp_(max=9.0)

# NOTE(review): this stores the per-species radii, whereas the loop assigns
# per-atom radii to both the Hamiltonian and the overlap feed — confirm that
# this final assignment is intended.
dftb_calculator_init.h_feed.compression_radii = comp_r

# Report initial versus optimized values. The entries of `comp_r` are reported
# in the order H, C, O.
for pos, symbol in enumerate(("H", "C", "O")):
    lead = "\n" if pos == 0 else ""
    print(f"{lead}Initial compression radius of {symbol}:", comp_r0[pos].detach(),
          f"Optimized compression radius of {symbol}:", comp_r[pos].detach())

# The on-site values of the standard feed are printed as the initial values.
for key, symbol in (("1", "H"), ("6", "C"), ("8", "O")):
    print(f"Initial onsite energy of {symbol}:", dftb_calculator_init_std.h_feed._on_sites[key].data,
          f"Optimized onsite energy of {symbol}:", dftb_calculator_init.h_feed._on_sites[key].data)