In [1]:
import numpy as np
import matplotlib.pyplot as plt

import sys
import os

from pythtb import W90, TBModel
import pythtb
from pythtb.utils import finite_diff_coeffs
import logging
from tensorflow import constant as const
from tensorflow import complex64
import tensorflow as tf
import tensorflow.linalg as tfla

pythtb.configure_logging(level="DEBUG")

<StreamHandler stderr (NOTSET)>

In [3]:
# Adjust the path to point to your project root.
project_root = os.path.abspath('/Users/treycole/Repos/axion-pert')
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from modules.qe_read_file import *

In [4]:
def vel_fd(H_k, mu, dk_mu, order_eps, mode='central'):
    coeffs, stencil = finite_diff_coeffs(order=order_eps, mode=mode)
    print(coeffs)

    fd_sum = np.zeros_like(H_k)

    for s, c in zip(stencil, coeffs):
        fd_sum += c * np.roll(H_k, shift=-s, axis=mu)

    v = fd_sum / (dk_mu)
    return v

def berry_curvature(v_k, H_flat, occ_idxs=None):
   
    # tensorflow optimization
    H_flat_tf = const(H_flat, dtype=complex64)
    v_k_tf = const(v_k, dtype=complex64)

    evals_tf, evecs_tf = tfla.eigh(H_flat_tf)

    # swap for consistent indexing
    evecs_tf = tf.transpose(evecs_tf, perm=[0, 1, 3, 2])  # (n_kpts, n_beta, n_state, n_state)
    evecs_T_tf = tf.transpose(evecs_tf, perm=[0, 1, 3, 2])  # (n_kpts, n_beta, n_state, n_state) 

    evecs_conj_tf = tf.math.conj(evecs_tf)

    # Rotate velocity operators to eigenbasis
    v_k_rot_tf = tf.matmul(
        evecs_conj_tf[None, :, :, :, :],  # (1, n_kpts, n_beta, n_state, n_state)
        tf.matmul(
            v_k_tf,                       # (dim_k, n_kpts, n_beta, n_state, n_state)
            evecs_T_tf[None, :, :, :, :]  # (1, n_kpts, n_beta, n_state, n_state)
        )
    )  # (dim_k, n_kpts, n_beta, n_state, n_state)

    # Identify occupied bands
    n_eigs = evals_tf.shape[-1]
    if occ_idxs is None:
        occ_idxs =  np.arange(n_eigs//2)
    elif occ_idxs == 'all':
        occ_idxs =  np.arange(n_eigs)
    else:
        occ_idxs = np.array(occ_idxs)

    # Identify conduction bands
    cond_idxs = np.setdiff1d(np.arange(n_eigs), occ_idxs)  # Identify conduction bands

    # Compute energy denominators
    delta_E_tf = evals_tf[..., None, :] - evals_tf[..., :, None]
    delta_E_occ_cond_tf = tf.gather(tf.gather(delta_E_tf, occ_idxs, axis=-2), cond_idxs, axis=-1)
    delta_E_cond_occ_tf = tf.gather(tf.gather(delta_E_tf, cond_idxs, axis=-2), occ_idxs, axis=-1)
    inv_delta_E_occ_cond_tf = 1 / delta_E_occ_cond_tf
    inv_delta_E_cond_occ_tf = 1 / delta_E_cond_occ_tf

    v_occ_cond_tf = tf.gather(tf.gather(v_k_rot_tf, occ_idxs, axis=-2), cond_idxs, axis=-1)
    v_cond_occ_tf = tf.gather(tf.gather(v_k_rot_tf, cond_idxs, axis=-2), occ_idxs, axis=-1)
    v_occ_cond_tf = v_occ_cond_tf * inv_delta_E_occ_cond_tf
    v_cond_occ_tf = v_cond_occ_tf * -inv_delta_E_cond_occ_tf

    Q = tf.matmul(v_occ_cond_tf[:, None], v_cond_occ_tf[None, :])
    Q = Q.numpy()

    Omega = 1j * (Q - np.swapaxes(Q, -1, -2).conj())

    return Omega

In [8]:
MBT = W90("124968bc889", r"MnBi2Te4")
MBT_mode1 = W90("132931bc889", r"MnBi2Te4")

In [9]:
E_F = 7.5901 # eV

In [10]:
# get tb model in which some small terms are ignored
model = MBT.model(
    zero_energy=E_F, min_hopping_norm=1e-5,
    max_distance=125, ignorable_imaginary_part=None
    )

In [11]:
model_mode1 = MBT_mode1.model(
    zero_energy=E_F, min_hopping_norm=1e-5,
    max_distance=125, ignorable_imaginary_part=None
    )

In [18]:
nk = 10
k_mesh = model.k_uniform_mesh([nk, nk, nk])
d3k = 1/(nk**3)

In [13]:
n_occ = 58
batch_size = 100

H = np.zeros((k_mesh.shape[0], 2, model.norb, model.norb), dtype=np.complex128)
v_k = np.zeros((3, k_mesh.shape[0], 2, model.norb, model.norb), dtype=np.complex128)

for i in range(0, len(k_mesh), batch_size):
    print(f"{i} / {len(k_mesh)} k points processed")
    k_batch = k_mesh[i:i+batch_size]

    H_batch = model.hamiltonian(k_batch, flatten_spin_axis=True)
    H_mode1_batch = model_mode1.hamiltonian(k_batch, flatten_spin_axis=True)

    H[i:i+batch_size, 0, :, :] = H_batch
    H[i:i+batch_size, 1, :, :] = H_mode1_batch

    v_batch = model.velocity(k_batch, flatten_spin_axis=True)
    v_mode1_batch = model_mode1.velocity(k_batch, flatten_spin_axis=True)

    v_k[:, i:i+batch_size, 0,:, :] = v_batch
    v_k[:, i:i+batch_size, 1, :, :] = v_mode1_batch

0 / 1000 k points processed
100 / 1000 k points processed
200 / 1000 k points processed
300 / 1000 k points processed
400 / 1000 k points processed
500 / 1000 k points processed
600 / 1000 k points processed
700 / 1000 k points processed
800 / 1000 k points processed
900 / 1000 k points processed


In [14]:
v_beta = vel_fd(H, mu=1, dk_mu=0.01, order_eps=1, mode='forward')
v = np.concatenate((v_k, v_beta[np.newaxis, ...]), axis=0)

[-1.  1.]


In [15]:
b_curv = berry_curvature(v, H, occ_idxs=range(n_occ))

2025-12-22 12:14:29.355404: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M3 Pro
2025-12-22 12:14:29.358829: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 18.00 GB
2025-12-22 12:14:29.359684: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 6.66 GB
2025-12-22 12:14:29.360377: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2025-12-22 12:14:29.360776: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


$\partial_{B_\nu} \theta = \epsilon_{ijkl} \text{Tr}(\Omega_{ij} \Omega_{kl})$

In [16]:
from pythtb.utils import levi_civita
epsilon = levi_civita(4, 4)
chern2_density = np.einsum("ijkl, ij...mn, kl...nm->...", epsilon, b_curv, b_curv) * (1/(16*np.pi))

In [19]:
# Reshape to 3D k-mesh
chern2_density = chern2_density.reshape((nk, nk, nk, 2))

In [20]:
# Integrate over k (excluding k_i=1)
dtheta = np.sum(chern2_density[:-1, :-1, :-1], axis=(0,1,2)) * d3k
dtheta

array([-33.92484848+0.j,   9.99498484+0.j])