In [1]:
import pickle
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
from copy import copy

# Jax
import jax.numpy as jnp
from jax import jit, vmap
import jax
from jax.nn import softmax
from jax_md import space

jax.config.update("jax_enable_x64", True)

# ASE
from ase.io import write
from ase.optimize import BFGS,FIRE
from ase.calculators.calculator import Calculator, all_changes

# Utils
from leopold.utils import get_atoms_from_data, get_data_from_atoms, get_data_from_xyz, batch_data
from leopold.utils import get_model
from leopold.utils import LeopoldCalculator
from leopold.utils import AtomsData

# Types
from argparse import ArgumentParser, Namespace


# Pristine structure
First, we take a structure from the training set, remove any polaron (set `pol_state` to 0 on every atom), and relax it to obtain a pristine, polaron-free reference structure:

In [5]:
# Inputs for this notebook (paths are relative to the notebook's directory).
TRAIN_XYZ = "../data/F-doped_TiO2/train.xyz"
MODEL_PKL = "../models/TiO2+F.pkl"
# Force-convergence threshold passed to the optimizer as `fmax` for the pristine relaxation.
FMAX_PRISTINE = 0.001

# Load the first structure of the training set and the trained calculator.
data = get_data_from_xyz(TRAIN_XYZ)
image = get_atoms_from_data(data)[0]
calc = LeopoldCalculator(MODEL_PKL, TRAIN_XYZ)

# Remove any polaron: pol_state is set to 0 on every atom.
# `.at[...].set` is JAX's functional update, so pol_state is presumably a JAX array
# and gets replaced rather than mutated in place.
image.arrays['pol_state'] = image.arrays['pol_state'].at[:].set(0)

# Attach a shallow copy of the calculator; the original `calc` object itself is not
# attached to this Atoms object.
image.calc = copy(calc)

# Relax `image` in place with FIRE. `run` returns True once max force < fmax;
# capture it (instead of leaking a bare `True` as cell output) and check it.
dyn = FIRE(image)
converged = dyn.run(fmax=FMAX_PRISTINE)
assert converged, "pristine relaxation did not reach the requested fmax"



      Step     Time          Energy          fmax
FIRE:    0 15:44:13    -2187.956140        4.891189
FIRE:    1 15:44:13    -2192.835267        3.503501
FIRE:    2 15:44:13    -2196.351683        2.634714
FIRE:    3 15:44:14    -2198.768333        2.503214
FIRE:    4 15:44:14    -2200.381300        2.249863
FIRE:    5 15:44:14    -2201.475957        1.860856
FIRE:    6 15:44:14    -2202.248111        1.487095
FIRE:    7 15:44:14    -2202.771063        1.481042
FIRE:    8 15:44:14    -2203.064132        1.549481
FIRE:    9 15:44:14    -2203.263991        1.348672
FIRE:   10 15:44:14    -2203.599118        1.072157
FIRE:   11 15:44:15    -2203.968101        0.786485
FIRE:   12 15:44:15    -2204.275396        0.543098
FIRE:   13 15:44:15    -2204.471368        0.524780
FIRE:   14 15:44:15    -2204.571345        0.645833
FIRE:   15 15:44:15    -2204.631045        0.674260
FIRE:   16 15:44:15    -2204.710282        0.699212
FIRE:   17 15:44:15    -2204.841809        0.743243
FIRE:   18 15:

True

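# Relax all distinct polaron-F configurations
For every Ti site, we place the polaron on that site. We then break the local symmetry by displacing its six nearest neighbours by 5% of their displacement vector to the site. Finally, we relax the resulting structure with BFGS.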

In [6]:
# Polaron-seeding parameters.
# Fraction of each displacement vector used to distort the neighbours of the polaron site.
DISTORTION = 0.05
# Number of nearest neighbours of the polaron site that are distorted.
N_NEIGHBORS = 6
# Force-convergence threshold passed to the optimizer as `fmax` for each relaxation.
FMAX_POLARON = 0.001

# Minimum-image displacement on fractional coordinates. Only the diagonal of the
# cell is handed to jax_md, which is correct only for an orthogonal cell, so
# refuse to continue silently if the cell has off-diagonal components.
cell = np.asarray(image.get_cell())
assert np.allclose(cell, np.diag(np.diag(cell))), \
    "periodic_general(np.diag(cell)) requires an orthogonal cell"
displacement, shift = space.periodic_general(np.diag(cell))
# distances(R_i, R_all) evaluates displacement(R_i, R_j) for every atom j.
distances = vmap(displacement, (None, 0))

# Fractional positions of the relaxed pristine structure; identical for every
# iteration, so compute them once.
scaled_positions = image.get_scaled_positions()
# Indices of all Ti atoms, selected by symbol rather than assuming the first 96 atoms are Ti.
ti_indices = [idx for idx, symbol in enumerate(image.get_chemical_symbols()) if symbol == 'Ti']

# Loop over all Ti atoms and break the local symmetry to seed polaron localization.
images = []
for i in ti_indices:
    # Displacement vectors between Ti atom i and every atom; keep the N nearest
    # neighbours (position 0 of the sort is atom i itself, at distance 0).
    dist = distances(scaled_positions[i], scaled_positions)
    idxs = np.argsort(np.linalg.norm(dist, axis=-1))[1:N_NEIGHBORS + 1]
    shifts = dist[idxs] * DISTORTION

    # Create the distorted image. With jax_md's displacement(Ra, Rb) = Ra - Rb
    # convention, subtracting the shifts moves the neighbours away from site i.
    image_distorted = image.copy()
    image_distorted.arrays['positions'][idxs] -= shifts
    image_distorted.wrap()
    # Place exactly one polaron, on atom i.
    image_distorted.arrays['pol_state'] = image_distorted.arrays['pol_state'].at[:].set(0).at[i].set(1)

    # Each image gets its own shallow copy of the calculator.
    image_distorted.calc = copy(calc)

    # Relax in place, then store a copy for further processing.
    dyn = BFGS(image_distorted)
    dyn.run(fmax=FMAX_POLARON)

    images.append(image_distorted.copy())

      Step     Time          Energy          fmax
BFGS:    0 15:44:57    -2203.028176        1.809969
BFGS:    1 15:44:57    -2203.225383        1.033376
BFGS:    2 15:44:58    -2203.286772        0.275188
BFGS:    3 15:44:58    -2203.297296        0.089552
BFGS:    4 15:44:58    -2203.301284        0.065207
BFGS:    5 15:44:58    -2203.303207        0.055664
BFGS:    6 15:44:58    -2203.304971        0.042948
BFGS:    7 15:44:59    -2203.306249        0.043942
BFGS:    8 15:44:59    -2203.307163        0.032730
BFGS:    9 15:44:59    -2203.307733        0.023330
BFGS:   10 15:44:59    -2203.308107        0.018018
BFGS:   11 15:44:59    -2203.308324        0.012719
BFGS:   12 15:44:59    -2203.308441        0.010279
BFGS:   13 15:45:00    -2203.308512        0.007910
BFGS:   14 15:45:00    -2203.308567        0.008070
BFGS:   15 15:45:00    -2203.308609        0.006578
BFGS:   16 15:45:00    -2203.308637        0.004709
BFGS:   17 15:45:00    -2203.308657        0.004905


KeyboardInterrupt: 