In [None]:
import ase.io
import shutil
import pathlib
import numpy as np
from pycp2k import CP2K

In [None]:
# Load every frame stored in the input trajectory; index=':' makes ASE return a list of Atoms.
atoms = ase.io.read(filename='input.xyz', index=':')

In [None]:
# Polarisation vectors ("a b c" strings, as assigned to CP2K's PERIODIC_EFIELD
# Polarisation keyword) for the finite-field runs: one per signed Cartesian axis.
# Key order is "x", "y", "z", "-x", "-y", "-z".
directions = {
    f"{sign}{axis}": " ".join(f"{sign}1.0" if j == k else "0.0" for j in range(3))
    for sign in ("", "-")
    for k, axis in enumerate("xyz")
}

In [None]:
# Quick sanity check on how many frames were loaded.
print(f"Number of structures: {len(atoms)}")

In [None]:
# Put every structure into a cubic box: pad it with 15 Å of vacuum on each side,
# stretch all three box edges to the largest resulting diagonal cell entry,
# then re-centre the structure in the enlarged box.
for structure in atoms:
    structure.center(vacuum=15.0)
    box_length = max(structure.cell[axis][axis] for axis in range(3))
    structure.set_cell([box_length] * 3)
    structure.center()

In [None]:
def calculate_apt(forces, field_strength=1E-5):
    """
    Build atomic polar tensors from finite-field forces by central differences.

    Parameters
    ----------
    forces : dict
        Maps each of "x", "-x", "y", "-y", "z", "-z" to the force array obtained
        with the field applied along that direction (presumably shape
        (n_atoms, 3) — TODO confirm against the force parser's output).
    field_strength : float
        Magnitude of the applied field. The + and - runs are 2 * field_strength
        apart, which is the denominator of the central difference.

    Returns
    -------
    numpy.ndarray
        For (n_atoms, 3) inputs, an array of shape (n_atoms, 3, 3) indexed as
        [atom, field_direction, force_component].
    """
    per_axis = [forces[axis] - forces["-" + axis] for axis in "xyz"]
    # Leading axis is the field direction; move the atom axis to the front.
    stacked = np.array(per_axis) / (2 * field_strength)
    return np.swapaxes(stacked, 0, 1)

def write_apt(apt, filename='apt.txt'):
    """
    Save atomic polar tensors to a human-readable text file (CP2K style).

    Layout: a "#APT" header line, then for every atom a line
    "# ATOM <1-based index> <trace/3> " followed by its 3x3 tensor written by
    numpy.savetxt with format "%10.6f".

    Parameters
    ----------
    apt : numpy.ndarray
        Tensors of shape (n_atoms, 3, 3).
    filename : str
        Output path; an existing file is overwritten.
    """
    with open(filename, 'w') as out:
        out.write("#APT\n")
        for atom_number, tensor in enumerate(apt, start=1):
            isotropic = np.trace(tensor) / 3
            out.write(f"# ATOM {atom_number} {isotropic} \n")
            np.savetxt(out, tensor, fmt="%10.6f")

def parse_apt(filename):
    """
    Read APT tensor from a CP2k-style APT file.

    Every line starting with "# ATOM" is taken to be followed by exactly three
    lines of three whitespace-separated numbers (one tensor row each). All
    other lines are ignored.

    Parameters
    ----------
    filename : str
        Path of the file to read.

    Returns
    -------
    numpy.ndarray
        Array of shape (number_of_atoms, 3, 3); shape (0, 3, 3) when the file
        contains no "# ATOM" blocks.

    Raises
    ------
    ValueError
        If an "# ATOM" block has fewer than three following lines, or a row
        does not contain exactly three values.
    """
    with open(filename, 'r') as f:
        lines = f.readlines()

    tensors = []  # named to avoid shadowing the notebook-level `atoms`
    i = 0
    while i < len(lines):
        if lines[i].strip().startswith('# ATOM'):
            rows = lines[i + 1:i + 4]
            # A truncated file used to surface as a bare IndexError.
            if len(rows) != 3:
                raise ValueError(
                    f"Truncated APT block at line {i + 1}: expected 3 rows, got {len(rows)}"
                )
            tensor = []
            for row in rows:
                coord_line = row.strip()
                coord = list(map(float, coord_line.split()))
                if len(coord) != 3:
                    raise ValueError(f"Coordinate line expected 3 values, but got {len(coord)} values: '{coord_line}'")
                tensor.append(coord)
            tensors.append(tensor)
            i += 4  # Jump past the header and its three rows
        else:
            i += 1  # Move one line forward

    # reshape keeps (n, 3, 3) unchanged and turns the empty case into (0, 3, 3).
    return np.array(tensors, dtype=float).reshape(-1, 3, 3)

def parse_cp2k_forces(file):
    """
    Load per-atom forces from a CP2K FORCE_EVAL/PRINT/FORCES output file.

    The first 3 lines are skipped as header and the last 2 rows as footer.
    From each remaining row, whitespace-separated columns 2-5 (0-based) are
    parsed as floats, and the first three of those are returned.
    NOTE(review): this assumes rows look like "<prefix> <atom> x y z |f|"
    (the "FORCES|" layout of recent CP2K versions), with a two-line
    sum/total footer — confirm against the CP2K version in use.

    Parameters
    ----------
    file : str or file-like
        Anything numpy.genfromtxt accepts.

    Returns
    -------
    numpy.ndarray
        One row per parsed atom line, three columns (x, y, z).
    """
    numeric_columns = (2, 3, 4, 5)
    table = np.genfromtxt(
        file,
        delimiter='',
        skip_header=3,
        skip_footer=2,
        usecols=numeric_columns,
    )
    xyz = table[:, :3]  # drop the trailing column (presumably |f|)
    return xyz

def run_cp2k(project_name, atom, field_vector=None, field_strength=1E-5, use_restart=None):
    """
    Set up and run one CP2K ENERGY_FORCE calculation in directory `project_name`.

    If `<project_name>/run.out` already contains "SCF run converged in", the run
    is treated as finished and the function returns immediately. If run.out
    exists without that marker, the calculation is redone with the OT "CG"
    minimizer instead of "DIIS".

    Parameters
    ----------
    project_name : str
        Working directory (created if needed); also used as the pycp2k project name.
    atom : ase.Atoms
        Structure to compute. The caller's object is not modified; a copy with
        near-zero cell entries cleaned up is written to `<project_name>/input.xyz`.
    field_vector : str or None
        Polarisation string for a PERIODIC_EFIELD section (e.g. "1.0 0.0 0.0").
        Falsy values mean no external field.
    field_strength : float
        Value assigned to the PERIODIC_EFIELD Intensity keyword.
    use_restart : str, path-like or None
        Wavefunction file copied to `<project_name>/RESTART.wfn` before the run.
        The copy is skipped when the file does not exist or already is that
        destination.
    """
    workdir = pathlib.Path(project_name)
    output_file = workdir / 'run.out'

    # Skip finished runs; remember unconverged ones so the retry switches minimizer.
    not_converged = False
    if output_file.exists():
        with open(output_file) as f:
            if "SCF run converged in" in f.read():
                return
        not_converged = True

    workdir.mkdir(parents=True, exist_ok=True)

    if use_restart:
        restart_src = pathlib.Path(use_restart)
        restart_dst = workdir / 'RESTART.wfn'
        # shutil.copyfile raises FileNotFoundError for a missing source and
        # SameFileError when source and destination coincide (e.g. conf_0
        # "restarting" from its own not-yet-written wavefunction). Skip the copy
        # in both cases; CP2K is expected to fall back to an atomic guess when
        # no restart file is present — verify for the CP2K version in use.
        if restart_src.exists() and restart_src.resolve() != restart_dst.resolve():
            shutil.copyfile(restart_src, restart_dst)

    # Work on a copy so the caller's Atoms object is left untouched.
    structure = atom.copy()
    cell = np.array(structure.cell)
    # Zero numerical noise in the cell matrix. Compare magnitudes so genuinely
    # negative components (e.g. triclinic off-diagonals) are not wiped out.
    cell[np.abs(cell) < 0.0001] = 0
    structure.set_cell(cell)

    ase.io.write(workdir / 'input.xyz', structure)
    # Re-read so CP2K receives exactly the geometry stored on disk.
    slab = ase.io.read(workdir / "input.xyz")

    calc = CP2K()
    calc.working_directory = str(workdir.absolute())
    calc.project_name = project_name

    CP2K_INPUT = calc.CP2K_INPUT

    GLOBAL = CP2K_INPUT.GLOBAL
    GLOBAL.Run_type = "ENERGY_FORCE"
    GLOBAL.Print_level = "MEDIUM"
    GLOBAL.Preferred_diag_library = "SCALAPACK"
    GLOBAL.Extended_fft_lengths = True

    FORCE_EVAL = (
        CP2K_INPUT.FORCE_EVAL_add()
    )  # Repeatable items have to be first created
    FORCE_EVAL.Method = "QS"

    SUBSYS = FORCE_EVAL.SUBSYS

    HYDROGEN = SUBSYS.KIND_add("H")
    HYDROGEN.Basis_set = "DZVP-MOLOPT-SR-GTH-q1"
    HYDROGEN.Potential = "GTH-PBE-q1"

    OXYGEN = SUBSYS.KIND_add("O")
    OXYGEN.Basis_set = "DZVP-MOLOPT-SR-GTH-q6"
    OXYGEN.Potential = "GTH-PBE-q6"

    ### DFT Section ###
    DFT = FORCE_EVAL.DFT

    DFT.Basis_set_file_name = "BASIS_MOLOPT"
    DFT.Potential_file_name = "GTH_POTENTIALS"
    DFT.Restart_file_name = "./RESTART.wfn"

    DFT.Charge = 0

    DFT.SCF.Scf_guess = "RESTART"

    DFT.SCF.PRINT.RESTART.Filename = "=RESTART.wfn"

    DFT.SCF.Max_scf = 50
    DFT.SCF.OUTER_SCF.Max_scf = 100

    DFT.SCF.Eps_scf = 1e-10
    DFT.SCF.OUTER_SCF.Eps_scf = 1e-10

    # OT Section: a previous unconverged attempt is retried with CG instead of DIIS.
    DFT.SCF.OT.Section_parameters = "ON"
    if not_converged:
        DFT.SCF.OT.Minimizer = "CG"
    else:
        DFT.SCF.OT.Minimizer = "DIIS"
    DFT.SCF.OT.Preconditioner = "FULL_KINETIC"

    DFT.MGRID.Cutoff = 800
    DFT.MGRID.Ngrids = 5
    DFT.MGRID.Rel_cutoff = 60

    DFT.POISSON.Periodic = "NONE"
    DFT.POISSON.Poisson_solver = "WAVELET"

    ## Exchange Correlation Section
    DFT.XC.XC_FUNCTIONAL.PBE.Section_parameters = True
    DFT.XC.XC_FUNCTIONAL.PBE.Parametrization = "revPBE"

    # External field (only when a polarisation vector was given)
    if field_vector:
        field = DFT.PERIODIC_EFIELD_add()
        field.Intensity = field_strength
        field.Polarisation = field_vector

    DFT.XC.VDW_POTENTIAL.Potential_type = "PAIR_POTENTIAL"
    pair_potential = DFT.XC.VDW_POTENTIAL.PAIR_POTENTIAL_add()
    pair_potential.Type = "DFTD3"
    pair_potential.R_cutoff = 15
    pair_potential.Long_range_correction = True
    pair_potential.Reference_functional = "revPBE"
    pair_potential.Parameter_file_name = "dftd3.dat"

    # Output Section
    DFT.PRINT.MOMENTS.Section_parameters = "ON"

    FORCE_EVAL.PRINT.FORCES.Section_parameters = "ON"
    FORCE_EVAL.PRINT.FORCES.Filename = "=forces.xyz"

    FORCE_EVAL.DFT.PRINT.HIRSHFELD.Section_parameters = "ON"
    FORCE_EVAL.DFT.PRINT.HIRSHFELD.Filename = "=hirshfeld.dat"

    FORCE_EVAL.DFT.PRINT.MULLIKEN.Section_parameters = "ON"
    FORCE_EVAL.DFT.PRINT.MULLIKEN.Filename = "=mulliken.dat"

    # Housekeeping: cell and coordinates come from the re-read structure.
    calc.create_cell(SUBSYS, slab)
    calc.create_coord(SUBSYS, slab)

    calc.write_input_file(workdir / 'run.inp')
    calc.run()

In [None]:
# Run the zero-field reference calculation for every configuration.
# conf_0 starts without a copied restart file: pointing it at its own RESTART.wfn
# would fail on a fresh run (file missing) or on a rerun (copy onto itself).
# Later configurations reuse conf_0's wavefunction, once it exists, to speed up SCF.
reference_restart = 'calc/conf_0/RESTART.wfn'
for i, atom in enumerate(atoms):
    use_reference = i > 0 and pathlib.Path(reference_restart).exists()
    run_cp2k(
        f'calc/conf_{i}',
        atom,
        field_vector=None,
        use_restart=reference_restart if use_reference else None,
    )

In [None]:
# Run external field calculations for all directions and all configurations
# (the APT step below reads field results for every configuration).
# NOTE: the field strength here must match the one passed to calculate_apt.
for i, atom in enumerate(atoms):
    print(f"Configuration {i + 1} of {len(atoms)}")
    # Field runs restart from the zero-field wavefunction, so the base run must be done.
    if not pathlib.Path(f"calc/conf_{i}/forces.xyz").exists():
        print(f"  base run calc/conf_{i} has not finished yet; skipping")
        continue
    for direction, polarisation in directions.items():
        run_cp2k(
            f'calc/conf_{i}/field_{direction}',
            atom,
            field_vector=polarisation,
            field_strength=0.0005,
            use_restart=f"calc/conf_{i}/RESTART.wfn",
        )

In [None]:
# Assemble the atomic polar tensors from the finite-field forces and write one
# apt.dat per configuration. Configurations whose six field runs are not all
# finished are skipped (with a message) instead of aborting the whole cell.
apts = {}
for i in range(len(atoms)):
    force_files = {
        direction: pathlib.Path(f"calc/conf_{i}/field_{direction}/forces.xyz")
        for direction in directions
    }
    missing = [str(path) for path in force_files.values() if not path.exists()]
    if missing:
        print(f"Skipping configuration {i}: missing {', '.join(missing)}")
        continue
    forces = {direction: parse_cp2k_forces(str(path)) for direction, path in force_files.items()}
    # field_strength must equal the value used for the finite-field CP2K runs.
    apt = calculate_apt(forces, field_strength=0.0005)
    apts[i] = apt
    write_apt(apt, f"calc/conf_{i}/apt.dat")