# Analysis of the skyrmion-crystal (SkX) phase on the triangular lattice

## Imports and helpers

In [None]:
import numpy as np
from numpy import linalg
from numba import jit
import math
import matplotlib as mpl
from matplotlib import cm
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from matplotlib.colors import SymLogNorm
import matplotlib.transforms as mtransforms
from matplotlib.ticker import (MultipleLocator, AutoMinorLocator)
from scipy.constants import pi
from scipy.interpolate import RectBivariateSpline, griddata

import os
import ast
import pandas as pd

def set_pandas_display_options() -> None:
    """Configure pandas so DataFrames are printed without truncation.

    Sets the display limits on columns, rows, column width and line width,
    as well as ``display.precision``, to ``None``, and formats floats with
    thousands separators and 8 decimal places.
    """
    # Ref: https://stackoverflow.com/a/52432757/
    unlimited_options = ("max_columns", "max_rows", "max_colwidth", "width", "precision")
    for option_name in unlimited_options:
        pd.set_option(f"display.{option_name}", None)
    pd.set_option("display.float_format", '{:,.8f}'.format)

## Data Import

Import the run data from `./data/` into a pandas DataFrame.

In [None]:
# Every directory below ./data/ (the root itself is skipped) is treated as one
# run that contains a params.json file.
files = [x[0] for x in os.walk('./data/')][1::]

frames = []
for file in files:
    data = pd.read_json(file+'/params.json',orient='index').transpose()
    # Run key = path with the 7-character './data/' prefix stripped.
    data["key"] = file[7:]

    if "tau1" in data:
        # Runs with a 'tau1' column also ship Chern numbers: store the real part
        # of each label's value in its own column; the empty label goes to "IDS".
        # allow_pickle=True unpickles an arbitrary object -- only load trusted data.
        Chern_numbers = np.load(file+"/chern_numbers_tau1tau2u1u2.npy",allow_pickle=True)[()]
        for label in Chern_numbers:
            column = "IDS" if label == "" else label
            data[column] = Chern_numbers[label].real
        frames.append(data)
    elif "n_moments" in data:
        # Runs with an 'n_moments' column are kept with their parameters only.
        # `elif` so a run matching both conditions is not added twice (duplicate
        # keys would break the later `.at` lookups).
        frames.append(data)

# Concatenate once (pd.concat inside the loop is quadratic). pd.concat rejects
# an empty list, so fall back to an empty frame in that case.
df = pd.concat(frames) if frames else pd.DataFrame()

display(df.set_index('key').sort_index(ascending=False))

In [None]:
# Run whose DOS files (./data/<key>/dos_XXXX.npy) are plotted.
key = '1700658946'

# Runs whose 'fermi' and 'q' values are overlaid as markers on the DOS plot.
gap_keys = [
    '1700142847',
    '1700143386',
    '1700143909',
    '1700144436',
    '1700144967',
    '1700145500',
]

In [None]:
# parameter
# Index the run table once instead of re-running set_index for every lookup.
runs = df.set_index('key')

tex, sys_size = runs.at[key, 'texture'], runs.at[key, 'system_sizes']
t, m = runs.at[key, 't'], runs.at[key, 'm']
shift, mag = runs.at[key, 'shift'], runs.at[key, 'mag']
# After concatenating heterogeneous runs, integer columns can be NaN-padded and
# come back as float64; np.linspace requires an integer sample count.
n_energies = int(runs.at[key, 'n_energies'])
n_moments, n_random_states = runs.at[key, 'n_moments'], runs.at[key, 'n_random_states']
scale, epsilon = 12, 0.01

# data
# Lower-half flux samples q = i/N for i = 1 .. floor(N/2)-1, followed by their
# mirror images 1-q in increasing order.
qs = np.arange(1, math.floor(sys_size/2)) / sys_size
flux = np.concatenate((qs, 1 - qs[::-1]))

emesh = scale*np.linspace(-1, 1, n_energies) * (1-epsilon)

# One DOS file per lower-half flux sample; the upper half reuses them in
# reverse order, matching the mirrored ordering of `flux`.
n_q = len(qs)
dos = np.array([np.load(f'./data/{key}/dos_{i:04d}.npy') for i in range(n_q)])
dos = np.concatenate((dos, dos[::-1]))

Emin, Emax = np.amin(emesh), np.amax(emesh)
phimin, phimax = np.amin(flux), np.amax(flux)

# Marker positions for the gap runs: 'fermi' (energy) and 'q' (flux) per key.
fermis = np.array([runs.at[k, 'fermi'] for k in gap_keys], dtype=float)
thetas = np.array([runs.at[k, 'q'] for k in gap_keys], dtype=float)

## Plot

In [None]:
# Global display configuration for pandas tables and matplotlib figures.
set_pandas_display_options()
mpl.rcParams.update({
    'figure.figsize': [12, 6],
    'savefig.facecolor': "white",
    'figure.dpi': 300,
    'axes.linewidth': 1,
    'mathtext.fontset': 'stix',
    'font.family': 'STIXGeneral',
})

# Font sizes: title, axis label, tick label, colorbar.
tfs, lfs, fs, cbarfs = 30, 24, 20, 20

# Symmetric-log colour scale for the DOS map.
linthresh = 0.07  # half-width of the range around 0 that is scaled linearly
linscale = 1      # stretch factor applied to that linear range
norm = SymLogNorm(linthresh=linthresh, linscale=linscale)

In [None]:
x_ticks = np.arange(0, 1.1, 0.1)

# DOS map: flux on the horizontal axis, energy on the vertical axis.
# NOTE(review): the columns are stretched uniformly over [0, 1], while the flux
# samples only span [phimin, phimax] and omit the sample(s) at/around 1/2;
# confirm this lines up with the gap markers, which are drawn at their actual q.
plt.imshow(dos.T, aspect='auto', norm=norm, extent=(0, 1, Emin, Emax),
           interpolation='gaussian', origin='lower', resample=True, cmap='Blues')
ax = plt.gca()
ax.set_xlabel(r"$\vartheta$", fontsize=lfs)
ax.set_ylabel(r"$E_F$", labelpad=-20, fontsize=lfs)
ax.set_xticks(x_ticks)
ax.tick_params(axis='both', which='major', labelsize=fs)
ax.set_ylim((-11, 8))

cbar = plt.colorbar(pad=0.02)
cbar.set_label(label=r'DOS (a.u.)', size=lfs)
cbar.ax.tick_params(labelsize=fs)

# Gap markers: one (q, fermi) point per gap run, each in its own Set1 colour.
# column_stack keeps shape (n, 2) even when gap_keys is empty.
gaps = np.column_stack((thetas, fermis))  # s=0,s=pi , m=0.0

gaps_col = cm.Set1(range(len(gap_keys)))
plt.scatter(gaps[:, 0], gaps[:, 1], c=gaps_col, s=40)

plt.tight_layout()
# savefig does not create missing directories.
os.makedirs("./plots/lowres", exist_ok=True)
plt.savefig("./plots/DoS_kpm_0shift0mag", dpi=300, bbox_inches='tight')
plt.savefig("./plots/lowres/DoS_kpm_0shift0mag", dpi=100, bbox_inches='tight')