In [None]:
import numpy as np
from scipy import constants
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import ipywidgets as widgets
from IPython.display import display
import asyncio
import plotly.io as pio
# pio.templates.default = "seaborn"


In [None]:
class KineticMonteCarloSimulation:
    """Vacancy-mediated kinetic Monte Carlo (KMC) model of a binary A-B alloy.

    Sites of an ``L x L`` square lattice with periodic boundaries hold ``+1``
    (A atom), ``-1`` (B atom) or ``0`` (the single vacancy). Each KMC step
    swaps the vacancy with one of its four nearest neighbours, chosen with
    probability proportional to ``exp(-E_m / (kB*T))``. The migration barrier
    is ``E_m = dE/2 + e0``, where ``dE`` is the bond-counting energy change of
    the swap and ``e0`` the moving species' intrinsic barrier.

    The constructor only builds the lattice; call :meth:`initialize_parameters`
    before stepping or computing energies.
    """

    # Boltzmann constant (eV/K), looked up once instead of on every step.
    _kB = constants.value('Boltzmann constant in eV/K')
    # Rate prefactor (1/s) turning Boltzmann factors into rates for the KMC clock.
    attempt_frequency = 1e13
    # Lattice spacing (Å) used to convert MSD from lattice units to Å².
    lattice_spacing = 3.0

    def __init__(self, L):
        self.L = L  # Lattice size (sites per side)
        self.c_B = 0.5  # Concentration of B atoms
        self.initialize_lattice()

    def _neighbors(self, i, j):
        """Return the four nearest-neighbour sites of (i, j) under periodic boundaries."""
        L = self.L
        return [((i + 1) % L, j), ((i - 1) % L, j), (i, (j + 1) % L), (i, (j - 1) % L)]

    def initialize_lattice(self):
        """Build a random A/B lattice with a vacancy at the centre and reset MSD bookkeeping.

        ``int(L**2 * c_B)`` sites become B atoms (truncated), the rest A atoms;
        the centre site is then overwritten by the vacancy.
        """
        num_B_atoms = int(self.L**2 * self.c_B)
        num_A_atoms = self.L**2 - num_B_atoms
        # An L by L array representing the lattice using +1 for A and -1 for B
        lattice_values = np.array([-1] * num_B_atoms + [1] * num_A_atoms)
        np.random.shuffle(lattice_values)
        self.lattice = lattice_values.reshape((self.L, self.L))
        # Set the center atom to be a vacancy
        self.lattice[self.L//2, self.L//2] = 0
        self.vacancy_position = (self.L//2, self.L//2)
        # Per-particle bookkeeping (indexed by initial site): species, current site,
        # and accumulated unwrapped displacement, used for the MSD calculation.
        self.atom_list = []
        self.atom_to_lattice_map = {}
        self.lattice_to_atom_map = {}
        self.atom_to_displacement_map = {}
        for index in range(self.L * self.L):
            i, j = index // self.L, index % self.L
            self.atom_list.append(self.lattice[i, j])
            self.atom_to_lattice_map[index] = (i, j)
            self.lattice_to_atom_map[(i, j)] = index
            self.atom_to_displacement_map[index] = np.array([0, 0])

    def _species_msd(self, species):
        """Return the MSD (Å²) of all particles of ``species``, or NaN if there are none."""
        squared = [np.sum(disp ** 2)
                   for index, disp in self.atom_to_displacement_map.items()
                   if self.atom_list[index] == species]
        if not squared:
            # Pure compositions have no particles of one species; avoid np.mean([]) warnings.
            return np.nan
        return np.mean(squared) * self.lattice_spacing ** 2

    def calculate_msd(self):
        """Return ``(msd_A, msd_B, msd_vac)`` in Å² from the unwrapped displacements.

        Displacements are accumulated hop by hop, so they are not affected by the
        periodic boundaries. A species that is absent yields NaN.
        """
        return self._species_msd(1), self._species_msd(-1), self._species_msd(0)

    def initialize_parameters(self, T, E_AA, E_BB, E_AB, e0_A, e0_B, c_B):
        """Set temperature (K), bond energies and intrinsic barriers (eV) and B fraction, then rebuild the lattice."""
        self.T = T # Temperature
        self.E_AA = E_AA # AA bond energy
        self.E_BB = E_BB # BB bond energy
        self.E_AB = E_AB # AB bond energy
        self.e0_A = e0_A # A atom intrinsic barrier
        self.e0_B = e0_B # B atom intrinsic barrier
        self.c_B = c_B # Concentration of B atoms
        self.initialize_lattice()
        # Total energy per site (the vacancy site is included in L**2)
        self.total_energy_per_atom = self.get_total_energy()/(self.L**2)
        self.time = 0

    def calculate_energy_contribution(self, i, j):
        """Return the summed bond energy between site (i, j) and its four neighbours.

        Bonds involving the vacancy contribute nothing, so a vacancy site returns 0.
        """
        atom_type = self.lattice[i, j]
        if atom_type == 0:
            return 0
        E = 0
        for ni, nj in self._neighbors(i, j):
            neighbor_type = self.lattice[ni, nj]
            if neighbor_type == 0:
                # Vacancy bond: must not fall through to the A-B branch below.
                continue
            if atom_type == 1 and neighbor_type == 1:
                E += self.E_AA
            elif atom_type == -1 and neighbor_type == -1:
                E += self.E_BB
            else:
                E += self.E_AB
        return E

    def get_total_energy(self):
        """Return the total bond-counting energy (eV) of the lattice."""
        E = 0
        for i in range(self.L):
            for j in range(self.L):
                E += self.calculate_energy_contribution(i, j)
        return E/2 # Each bond is counted twice

    def get_energy_change(self, vacancy_pos, neighbor_pos):
        """Return the energy change (eV) if the atom at ``neighbor_pos`` hops into the vacancy.

        Only the moving atom's bonds change (vacancy bonds are zero), so the
        difference of its own bond sum before and after suffices. The lattice
        is swapped temporarily and restored before returning.
        """
        i_vac, j_vac = vacancy_pos
        i_neigh, j_neigh = neighbor_pos

        dE_before = self.calculate_energy_contribution(i_neigh, j_neigh)
        self.lattice[i_vac, j_vac], self.lattice[i_neigh, j_neigh] = self.lattice[i_neigh, j_neigh], self.lattice[i_vac, j_vac]
        dE_after = self.calculate_energy_contribution(i_vac, j_vac)
        self.lattice[i_vac, j_vac], self.lattice[i_neigh, j_neigh] = self.lattice[i_neigh, j_neigh], self.lattice[i_vac, j_vac]
        return dE_after - dE_before

    def one_simulation_step(self):
        """Perform one KMC hop of the vacancy and advance energy and simulated time.

        The hop is chosen with probability proportional to its Arrhenius rate;
        the clock advances by an exponentially distributed residence time
        ``-ln(u) / R_total``.
        """
        i_vac, j_vac = self.vacancy_position
        neighbors = self._neighbors(i_vac, j_vac)
        rates = np.empty(len(neighbors))
        energy_changes = []
        for k, pos in enumerate(neighbors):
            dE = self.get_energy_change((i_vac, j_vac), pos)
            e0 = self.e0_A if self.lattice[pos] == 1 else self.e0_B
            migration_barrier = dE / 2 + e0
            rates[k] = np.exp(-migration_barrier / (self._kB * self.T))
            energy_changes.append(dE)
        total_rate = np.sum(rates)
        chosen_index = np.random.choice(len(neighbors), p=rates / total_rate)
        i_neigh, j_neigh = neighbors[chosen_index]
        dE = energy_changes[chosen_index]
        # Swap the vacancy with the chosen neighbor
        self._update_displacement(i_vac, j_vac, i_neigh, j_neigh)  # For MSD
        self.lattice[i_vac, j_vac], self.lattice[i_neigh, j_neigh] = self.lattice[i_neigh, j_neigh], self.lattice[i_vac, j_vac]
        self.vacancy_position = (i_neigh, j_neigh)
        self.total_energy_per_atom += dE / (self.L ** 2)
        # Residence time: -ln(u)/R with u in (0, 1]; 1 - uniform() avoids log(0).
        self.time += -np.log(1.0 - np.random.uniform()) / (total_rate * self.attempt_frequency)

    def _unwrap_positions(self, i1, j1, i2, j2):
        """Return the minimum-image displacement (dx, dy) from (i1, j1) to (i2, j2)."""
        dx = (i2 - i1 + self.L // 2) % self.L - self.L // 2
        dy = (j2 - j1 + self.L // 2) % self.L - self.L // 2
        return dx, dy

    def _update_displacement(self, i_old, j_old, i_new, j_new):
        """Record a vacancy hop from (i_old, j_old) to (i_new, j_new) in the particle maps.

        The vacancy gains (dx, dy) and the exchanged atom gains the opposite displacement.
        """
        old_atom_index = self.lattice_to_atom_map[(i_old, j_old)]
        new_atom_index = self.lattice_to_atom_map[(i_new, j_new)]
        dx, dy = self._unwrap_positions(i_old, j_old, i_new, j_new)
        self.atom_to_displacement_map[old_atom_index] += np.array([dx, dy])
        self.atom_to_displacement_map[new_atom_index] -= np.array([dx, dy])
        # Swap the positions in the map
        self.atom_to_lattice_map[old_atom_index], self.atom_to_lattice_map[new_atom_index] = (i_new, j_new), (i_old, j_old)
        self.lattice_to_atom_map[(i_old, j_old)], self.lattice_to_atom_map[(i_new, j_new)] = new_atom_index, old_atom_index

    def calculate_warren_cowley_sro(self):
        """Return the Warren-Cowley nearest-neighbour parameter ``1 - P_AB / (2 c_A c_B)``.

        ``P_AB`` is the fraction of the ``2 L**2`` lattice bonds that join an A
        and a B atom (bonds touching the vacancy count in the denominator only).
        ``c_B`` is the requested concentration. Returns NaN when ``c_B`` is 0 or 1,
        where the parameter is undefined.
        """
        denom = 2 * self.c_B * (1 - self.c_B)
        if denom == 0:
            return float('nan')
        n_total = self.L * self.L * 2  # Total number of bonds
        # Rolling by one site along each axis visits every bond exactly once;
        # a product of -1 marks an A-B pair (vacancy products are 0).
        n_AB = sum(
            np.count_nonzero(self.lattice * np.roll(self.lattice, 1, axis=axis) == -1)
            for axis in (0, 1)
        )
        p_AB = n_AB / n_total
        return 1 - p_AB / denom


In [None]:
# 3x2 dashboard grid: the lattice view spans rows 1-2 of column 1; energy,
# SRO and the two MSD panels fill the remaining cells.
fig = go.FigureWidget(
    make_subplots(
        rows=3,
        cols=2,
        specs=[
            [{"rowspan": 2}, {}],
            [None, {}],
            [{}, {}],
        ],
    )
)
fig.update_layout(width=1200, height=900)

def initialize_plot(lattice):
    """Add every dashboard trace and its axis settings to the global ``fig``.

    Traces are added in a fixed order that ``updater_plot`` relies on:
    0 = A atoms, 1 = B atoms, 2 = energy, 3 = SRO, 4 = MSD A, 5 = MSD B.

    Parameters
    ----------
    lattice : 2-D array of site types (+1 = A, -1 = B, 0 = vacancy); its shape
        sets the extent of the lattice panel.
    """
    n_rows, n_cols = lattice.shape
    type_A_x, type_A_y = np.where(lattice == 1)
    type_B_x, type_B_y = np.where(lattice == -1)
    # Offset by half a site so each marker sits in the centre of its cell.
    scatterA = go.Scatter(x=type_A_x + 0.5, y=type_A_y + 0.5, mode='markers', marker_line_width=.1,
                          marker=dict(color='#FF9D3B', size=12), name='Atom A')
    scatterB = go.Scatter(x=type_B_x + 0.5, y=type_B_y + 0.5, mode='markers', marker_line_width=.1,
                          marker=dict(color='#37BCFF', size=12), name='Atom B')
    fig.add_trace(scatterA, row=1, col=1)
    fig.add_trace(scatterB, row=1, col=1)
    # Size the lattice panel from the lattice itself rather than a global ``L``.
    fig.update_xaxes(range=[0, n_rows], row=1, col=1)
    fig.update_yaxes(range=[0, n_cols], row=1, col=1)
    # Time-series panels: (trace name, row, col, y-axis title), in trace-index order.
    for name, row, col, y_title in (
        ('Energy', 1, 2, "Energy (eV/atom)"),
        ('SRO', 2, 2, "Warren-Cowley parameter A-B bond"),
        ('MSD A', 3, 1, "MSD for A (Å²)"),
        ('MSD B', 3, 2, "MSD for B (Å²)"),
    ):
        fig.add_trace(go.Scatter(x=[], y=[], mode='lines+markers', name=name), row=row, col=col)
        fig.update_xaxes(title_text="Time (s)", row=row, col=col)
        fig.update_yaxes(title_text=y_title, row=row, col=col)


def updater_plot(lattice, step_list, energy_list, time_list, sro_list,
                 msd_A_list, msd_B_list):
    """Push the current lattice and the recorded time series into the global ``fig``.

    Trace indices follow the order in which ``initialize_plot`` adds them
    (0-1 lattice species, 2-5 time series). ``step_list`` is accepted for call
    compatibility but is not plotted.
    """
    # (rows, cols) of A sites, then of B sites
    site_coords = [np.where(lattice == species) for species in (1, -1)]
    series = [energy_list, sro_list, msd_A_list, msd_B_list]
    with fig.batch_update():
        for trace_idx, (rows, cols) in enumerate(site_coords):
            fig.data[trace_idx].x = rows + 0.5
            fig.data[trace_idx].y = cols + 0.5
        for trace_idx, values in enumerate(series, start=2):
            fig.data[trace_idx].x = time_list
            fig.data[trace_idx].y = values

# Define the animation function
async def animate_simulation(steps, T, interval, E_AA, E_BB, E_AB, e0_A, e0_B, c_B):
    """Run a KMC simulation on the global ``MC`` and stream its progress to the dashboard.

    Parameters
    ----------
    steps : number of KMC steps to perform.
    T : temperature (K).
    interval : record and redraw every ``interval`` steps; values below 1 are treated as 1.
    E_AA, E_BB, E_AB : bond energies (eV).
    e0_A, e0_B : intrinsic migration barriers (eV).
    c_B : fraction of B atoms (0..1).

    The global ``stop_animation`` flag is checked before every step, and the
    coroutine yields to the event loop each step so the Stop/Restart callbacks can run.
    """
    frame_pause = 1 / 3  # seconds to pause after each redraw
    # A zero interval would divide by zero in ``step % interval``.
    interval = max(1, int(interval))
    MC.initialize_parameters(T, E_AA, E_BB, E_AB, e0_A, e0_B, c_B)
    step_list = []
    energy_list = []
    time_list = []
    sro_list = []
    msd_A_list = []
    msd_B_list = []
    for step in range(steps + 1):
        if stop_animation:
            break
        # Step 0 records the freshly initialised lattice; each later iteration
        # performs one KMC step, so exactly ``steps`` steps are executed.
        if step > 0:
            MC.one_simulation_step()
        if step % interval == 0:
            step_list.append(step)
            energy_list.append(MC.total_energy_per_atom)
            sro = MC.calculate_warren_cowley_sro()
            sro_list.append(sro)
            time_list.append(MC.time)
            msd_A, msd_B, _ = MC.calculate_msd()
            msd_A_list.append(msd_A)
            msd_B_list.append(msd_B)
            updater_plot(MC.lattice, step_list, energy_list, time_list, sro_list, msd_A_list, msd_B_list)
            step_info.value = f"Current Step: {step}, time: {MC.time:.2e} s, Energy: {MC.total_energy_per_atom:.4f} eV/atom, Warren-Cowley A-B bond: {sro:.4f}, MSD A: {msd_A:.2e} Å², MSD B: {msd_B:.2e} Å²"
            # Pace the animation once per frame instead of sleeping a tiny amount
            # every step (timer granularity would make per-step sleeps far longer).
            await asyncio.sleep(frame_pause)
        else:
            # Yield to the event loop (no delay) so button callbacks can run.
            await asyncio.sleep(0)

# Callback function for play button
async def on_play_button_clicked(b):
    """Parse the input widgets and run the animation; ``b`` is the clicked button (unused).

    This coroutine runs as a fire-and-forget task, so its exceptions would
    otherwise go unseen. Invalid input is reported in ``step_info`` and leaves
    Play enabled. A failure during the run is also reported, re-enables Play,
    and is re-raised.
    """
    global stop_animation
    try:
        # int(float(...)) also accepts scientific notation such as "1e6".
        args = (int(float(total_steps_input.value)), float(temperature_input.value),
                int(float(interval_input.value)), float(E_AA_input.value),
                float(E_BB_input.value), float(E_AB_input.value),
                float(e0_A_input.value), float(e0_B_input.value),
                float(C_B_slider.value) / 100)
    except (ValueError, OverflowError) as err:
        step_info.value = f"Invalid input: {err}"
        return
    stop_animation = False
    play_button.disabled = True  # Disable the play button until Restart
    try:
        await animate_simulation(*args)
    except Exception as err:
        step_info.value = f"Simulation error: {err!r}"
        play_button.disabled = False
        raise

# Callback function for stop button
def on_stop_button_clicked(b):
    """Ask the running animation to halt by raising the global ``stop_animation`` flag.

    ``b`` is the button that fired the event and is not used.
    """
    global stop_animation
    stop_animation = True

# Callback function for restart button
def on_restart_button_clicked(b):
    """Halt any running animation and rebuild the lattice from the slider's B concentration.

    Clears every time-series trace, reports the reset in ``step_info`` and
    re-enables the Play button. ``b`` is the clicked button (unused).
    """
    on_stop_button_clicked(b)
    fraction_B = float(C_B_slider.value) / 100
    MC.c_B = fraction_B
    MC.initialize_lattice()
    # Redraw the fresh lattice with all six recorded series emptied.
    updater_plot(MC.lattice, *[[] for _ in range(6)])
    step_info.value = "Lattice restarted"
    play_button.disabled = False

# Create interactive widgets with layout adjustments
style = {'description_width': 'initial'}


def _text_input(value, description, width='300px'):
    """Return a Text widget with the shared description style and the given width."""
    return widgets.Text(value=value, description=description,
                        layout=widgets.Layout(width=width), style=style)


temperature_input = _text_input('700', 'Temperature (K):')
total_steps_input = _text_input(f'{int(1e6)}', 'Total Steps:')
interval_input = _text_input('500', 'Display Interval:')
E_AA_input = _text_input('-1.0', 'E_AA (eV):')
E_BB_input = _text_input('-1.0', 'E_BB (eV):')
E_AB_input = _text_input('-0.9', 'E_AB (eV):')
e0_A_input = _text_input('0.8', 'e0_A (eV):', width='450px')
e0_B_input = _text_input('0.6', 'e0_B (eV):', width='450px')
C_B_slider = widgets.IntSlider(value=50, min=0, max=100, step=1, description='Concentration B (%):',
                               layout=widgets.Layout(width='900px'), style=style)
play_button = widgets.Button(description="Play", layout=widgets.Layout(width='300px'))
stop_button = widgets.Button(description="Stop", layout=widgets.Layout(width='300px'))
restart_button = widgets.Button(description="Restart", layout=widgets.Layout(width='300px'))
step_info = widgets.Label(value="Lattice initialized.", layout=widgets.Layout(width='1000px'))

# Flag polled by the animation loop and toggled by the button callbacks;
# defined here so it exists before any callback runs.
stop_animation = False

# on_play_button_clicked is a coroutine, so schedule it on the running event loop.
play_button.on_click(lambda b: asyncio.ensure_future(on_play_button_clicked(b)))
stop_button.on_click(on_stop_button_clicked)
restart_button.on_click(on_restart_button_clicked)

# Arrange widgets in a more organized layout
inputs_box = widgets.VBox([
    widgets.HBox([temperature_input, total_steps_input, interval_input]),
    widgets.HBox([E_AA_input, E_BB_input, E_AB_input]),
    widgets.HBox([e0_A_input, e0_B_input]),
    widgets.HBox([C_B_slider]),
    widgets.HBox([play_button, stop_button, restart_button]),
    step_info,
])


# Set up Plotly figure
L = 32  # lattice size
MC = KineticMonteCarloSimulation(L)
initialize_plot(MC.lattice)
# Display widgets and figure
display(fig, inputs_box)
