In [None]:
import pickle
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
from copy import copy

# Jax
import jax.numpy as jnp
from jax import jit, vmap
import jax
from jax.nn import softmax
from jax_md import space

jax.config.update("jax_enable_x64", True)

# ASE
from ase.io import write
from ase.optimize import BFGS,FIRE
from ase.calculators.calculator import Calculator, all_changes

# Utils
from utils import get_atoms_from_data, get_data_from_atoms, get_data_from_xyz, batch_data
from utils import get_model
from utils import LeopoldCalculator
from utils import AtomsData

# Types
from argparse import ArgumentParser, Namespace


# Pristine structure
First, we take a structure from the training set, remove all polarons from it, and relax it to obtain a pristine reference structure:

In [None]:
# Inputs: the training set (its path is also handed to the calculator) and the trained model.
TRAIN_XYZ = "../data/F-doped_TiO2/train.xyz"
MODEL_PKL = "../models/TiO2+F.pkl"
FMAX_PRISTINE = 0.001  # force-convergence threshold passed to the FIRE optimizer

# Use the first structure of the training set as the starting geometry,
# and build the model calculator from the model file plus the training set.
data = get_data_from_xyz(TRAIN_XYZ)
image = get_atoms_from_data(data)[0]
calc = LeopoldCalculator(MODEL_PKL,
                         TRAIN_XYZ)

# Pristine reference: zero the polaron label on every atom.
image.arrays['pol_state'] = image.arrays['pol_state'].at[:].set(0)

# Attach a copy so that `calc` itself is never bound to an Atoms object;
# later cells copy it again for each new structure.
image.calc = copy(calc)

# Relax the pristine geometry with FIRE.
dyn = FIRE(image)
dyn.run(fmax=FMAX_PRISTINE)

In [None]:
# Settings for the polaron scan.
N_NEIGHBORS = 6        # number of nearest neighbours displaced around each Ti site
DISTORTION = 0.05      # fraction of each centre->neighbour vector used to push the neighbour outward
FMAX_POLARON = 0.001   # force-convergence threshold passed to BFGS

# Periodic displacement function. periodic_general works in fractional
# coordinates by default, so it is fed scaled positions and returns Cartesian
# vectors. jax_md expects the lattice vectors as the columns of `box`, while
# ASE stores them as rows, hence the transpose. Passing the full cell (rather
# than only its diagonal) keeps this correct for non-orthogonal cells; for an
# orthogonal cell the result is unchanged.
displacement, shift = space.periodic_general(np.asarray(image.get_cell()).T)
# Vectorised over the second argument: one reference point against all atoms.
distances = vmap(displacement, (None, 0))

# The pristine geometry does not change during the scan, so read it only once.
scaled_positions = image.get_scaled_positions()

# Pick Ti sites by chemical symbol instead of assuming they are the first 96 atoms.
ti_indices = [atom.index for atom in image if atom.symbol == 'Ti']
assert ti_indices, "no Ti atoms found in the pristine structure"

# For every Ti site: break the local symmetry by pushing its nearest neighbours
# outward, place a polaron on that site, relax, and keep the result.
images = []
for i in ti_indices:
    # dist[j] = r_i - r_j (jax_md displacement convention), as a numpy array so
    # it can be applied in place to ASE's numpy position array.
    dist = np.asarray(distances(scaled_positions[i], scaled_positions))
    order = np.argsort(np.linalg.norm(dist, axis=-1))
    idxs = order[order != i][:N_NEIGHBORS]  # nearest neighbours, excluding atom i itself
    shifts = dist[idxs] * DISTORTION

    # Create the distorted image: r_j -> r_j + DISTORTION * (r_j - r_i).
    image_distorted = image.copy()
    image_distorted.arrays['positions'][idxs] -= shifts
    image_distorted.wrap()
    image_distorted.arrays['pol_state'] = image_distorted.arrays['pol_state'].at[:].set(0).at[i].set(1)

    # Each image gets its own copy of the calculator.
    image_distorted.calc = copy(calc)

    # Relax, then store a copy for further processing.
    # NOTE(review): ASE's Atoms.copy() does not carry the calculator over,
    # so stored images do not keep their computed results.
    dyn = BFGS(image_distorted)
    dyn.run(fmax=FMAX_POLARON)

    images.append(image_distorted.copy())