# In [None]:
import numpy as np
import scipy as sp
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import ipywidgets as widgets
from IPython.display import display
import asyncio
# Initialize the lattice with (as near as possible) equal numbers of A and B atoms
L = 50  # Lattice size
def InitializeLattice(size=None):
    """Return a random square lattice of A (+1) and B (-1) atoms.

    The lattice contains exactly ``n*n // 2`` B atoms and the remaining sites
    are A atoms, arranged uniformly at random. When ``n*n`` is even this is an
    exact 50/50 split (sum == 0); when it is odd there is one extra A atom.

    Args:
        size: edge length ``n`` of the lattice. Defaults to the module-level
            ``L`` when omitted.

    Returns:
        An ``(n, n)`` integer array of +1 (A) and -1 (B).
    """
    n = L if size is None else size
    n_sites = n * n
    # Fix the composition up front and shuffle it. This samples the same
    # uniform distribution over balanced arrangements as redrawing until
    # sum == 0, but always terminates (the redraw loop never ends for odd n*n).
    flat = np.ones(n_sites, dtype=int)
    flat[: n_sites // 2] = -1
    np.random.shuffle(flat)
    return flat.reshape(n, n)

# Helper function to calculate the energy contribution of an atom and its neighbors
def CalculateEnergyContribution(lattice, i, j, E_AA, E_BB, E_AB):
    """Return the summed bond energy between site (i, j) and its 4 nearest neighbours.

    Periodic boundary conditions wrap using the lattice's own shape, so any
    2-D lattice is supported, not only the global L x L one.

    Args:
        lattice: 2-D array of atom types (+1 for A, -1 for B).
        i, j: row and column of the central site.
        E_AA, E_BB, E_AB: bond energies for AA, BB and unlike pairs.

    Returns:
        Sum of the four bond energies. A pair is AA when both sites are +1,
        BB when both are -1, and counted as AB otherwise.
    """
    n_rows, n_cols = lattice.shape
    atom_type = lattice[i, j]
    # Neighbours in 4 directions (down, up, right, left) with periodic wrap
    neighbor_positions = (
        ((i + 1) % n_rows, j),
        ((i - 1) % n_rows, j),
        (i, (j + 1) % n_cols),
        (i, (j - 1) % n_cols),
    )
    E = 0
    for ni, nj in neighbor_positions:
        neighbor_type = lattice[ni, nj]
        if atom_type == 1 and neighbor_type == 1:  # AA bond
            E += E_AA
        elif atom_type == -1 and neighbor_type == -1:  # BB bond
            E += E_BB
        else:  # unlike pair -> AB bond
            E += E_AB
    return E

# Define a function to calculate the total energy using bond counting model
# E_AA, E_BB, E_AB are the bond energies for AA, BB, and AB bonds
def GetTotalEnergy(lattice, E_AA, E_BB, E_AB):
    """Return the total bond-counting energy of a periodic 2-D lattice.

    Each nearest-neighbour bond is counted exactly once by pairing every site
    with its right and lower neighbour (with periodic wrap). This equals
    summing every site's 4-neighbour contribution and halving the result,
    because each bond appears in the contributions of both its end sites.

    Args:
        lattice: 2-D array of atom types (+1 for A, -1 for B), any shape.
        E_AA, E_BB, E_AB: bond energies for AA, BB and unlike pairs.

    Returns:
        Total energy as a float. A pair is AA when both sites are +1, BB when
        both are -1, and AB otherwise.
    """
    lattice = np.asarray(lattice)
    n_aa = n_bb = n_ab = 0
    for axis in (0, 1):
        # Neighbour at index +1 along `axis`, wrapping around the boundary
        neighbor = np.roll(lattice, -1, axis=axis)
        aa = (lattice == 1) & (neighbor == 1)
        bb = (lattice == -1) & (neighbor == -1)
        n_aa += int(np.count_nonzero(aa))
        n_bb += int(np.count_nonzero(bb))
        n_ab += int(aa.size - np.count_nonzero(aa | bb))
    return float(E_AA * n_aa + E_BB * n_bb + E_AB * n_ab)

# Define a function to calculate the energy change if two atoms are swapped
# The positions of the two atoms are (i1, j1) and (i2, j2)
def GetEnergyChange(lattice, pos1, pos2, E_AA, E_BB, E_AB):
    """Return the energy change (after - before) that swapping pos1 and pos2 would cause.

    The two sites are swapped in place to evaluate the new local energies and
    then swapped back, so on return ``lattice`` holds its original contents.

    Returns 0 when both positions name the same site or hold the same atom
    type, since such a swap leaves the lattice unchanged.

    Neighbouring sites need no special handling. Their shared bond is AB both
    before and after the swap, so it enters both sums identically and cancels
    in the difference.
    """
    (r1, c1), (r2, c2) = pos1, pos2
    first, second = lattice[r1, c1], lattice[r2, c2]
    if (r1 == r2 and c1 == c2) or first == second:
        return 0  # nothing would change

    def local_energy():
        # Bond energy touching either site (a shared bond is counted twice)
        return (CalculateEnergyContribution(lattice, r1, c1, E_AA, E_BB, E_AB)
                + CalculateEnergyContribution(lattice, r2, c2, E_AA, E_BB, E_AB))

    def exchange():
        lattice[r1, c1], lattice[r2, c2] = lattice[r2, c2], lattice[r1, c1]

    before = local_energy()
    exchange()
    after = local_energy()
    exchange()  # restore the caller's lattice
    return after - before
    

def AtomSwapAttempt(lattice, T, E_AA, E_BB, E_AB):
    """Attempt one Metropolis swap of two randomly chosen sites, in place.

    A swap that lowers the energy is always accepted. Otherwise it is accepted
    with probability exp(-dE / (kB*T)), where kB is in eV/K. At T == 0 the
    zero-temperature limit is used: only moves with dE <= 0 are accepted.
    This avoids a division by zero.

    Args:
        lattice: 2-D array of atom types; modified when the swap is accepted.
        T: temperature in K.
        E_AA, E_BB, E_AB: bond energies in eV.

    Returns:
        The energy change applied to the lattice: dE if the swap was accepted,
        otherwise 0.
    """
    kB = sp.constants.value('Boltzmann constant in eV/K')
    n_rows, n_cols = lattice.shape
    # Pick two random sites pos1 = (i1, j1) and pos2 = (i2, j2)
    i1, i2 = np.random.randint(0, n_rows, size=2)
    j1, j2 = np.random.randint(0, n_cols, size=2)

    dE = GetEnergyChange(lattice, (i1, j1), (i2, j2), E_AA, E_BB, E_AB)
    if T == 0:
        accept = dE <= 0
    else:
        accept = dE < 0 or np.random.rand() < np.exp(-dE / (kB * T))
    if accept:
        # Successful swap
        lattice[i1, j1], lattice[i2, j2] = lattice[i2, j2], lattice[i1, j1]
        return dE
    # No swap
    return 0
# Set up Plotly figure: energy history (left) and lattice snapshot (right)
fig = go.FigureWidget(make_subplots(rows=1, cols=2))
fig.update_layout(width=1000, height=500)  # Adjust height and width here
marker_size = 6  # marker size for the lattice scatter points

# Shared layout and style applied to every control widget
layout = widgets.Layout(width='300px')
style = {'description_width': 'initial'}


def _text_input(value, description):
    """Build a Text widget using the shared control ``layout`` and ``style``."""
    return widgets.Text(value=value, description=description, layout=layout, style=style)


def _button(description):
    """Build a Button widget using the shared control ``layout``."""
    return widgets.Button(description=description, layout=layout)


temperature_input = _text_input('500', 'Temperature (K):')
total_steps_input = _text_input(str(100 * L * L), 'Total Steps:')
interval_input = _text_input('100', 'Display Interval:')
E_AA_input = _text_input('-1.0', 'E_AA (eV):')
E_BB_input = _text_input('-1.0', 'E_BB (eV):')
E_AB_input = _text_input('-2.0', 'E_AB (eV):')
play_button = _button("Play")
stop_button = _button("Stop")
restart_button = _button("Restart")
step_info = widgets.Label(value="Lattice initialized.", layout=widgets.Layout(width='1000px'))

def InitializePlot(lattice):
    """Add the energy trace and the two atom-type scatters to the global ``fig``.

    The traces are added in this order:
      * ``fig.data[0]``: the initially empty energy line, in subplot (1, 1).
      * ``fig.data[1]``: type A (+1) sites, in subplot (1, 2).
      * ``fig.data[2]``: type B (-1) sites, in subplot (1, 2).

    ``UpdaterPlot`` writes to the traces by these indices. Site markers are
    placed at (row index + 0.5, column index + 0.5), inside axes spanning
    [0, L].
    """
    fig.add_trace(go.Scatter(x=[], y=[], mode='lines+markers', name='Energy'), row=1, col=1)
    fig.update_xaxes(title_text="Steps", row=1, col=1)
    fig.update_yaxes(title_text="Total Energy (eV/atom)", row=1, col=1)

    species = (
        (1, '#FF9D3B', 'Type A'),
        (-1, '#37BCFF', 'Type B'),
    )
    for atom_value, colour, label in species:
        rows, cols = np.where(lattice == atom_value)
        fig.add_trace(
            go.Scatter(x=rows + 0.5, y=cols + 0.5, mode='markers', marker_line_width=.2,
                       marker=dict(color=colour, size=marker_size), name=label),
            row=1, col=2,
        )
    fig.update_xaxes(range=[0, L], row=1, col=2)
    fig.update_yaxes(range=[0, L], row=1, col=2)

    
def UpdaterPlot(lattice, step_list, energy_list):
    """Replace the data of the first three traces of the global ``fig``.

    The traces are updated as follows:
      * ``fig.data[0]`` receives the energy history (step_list, energy_list).
      * ``fig.data[1]`` receives the cell-centred coordinates of the +1
        (type A) sites.
      * ``fig.data[2]`` receives the cell-centred coordinates of the -1
        (type B) sites.

    All assignments happen inside a single ``fig.batch_update()`` block, so
    the changes are sent together.
    """
    a_rows, a_cols = np.where(lattice == 1)
    b_rows, b_cols = np.where(lattice == -1)
    new_data = [
        (step_list, energy_list),
        (a_rows + 0.5, a_cols + 0.5),
        (b_rows + 0.5, b_cols + 0.5),
    ]
    with fig.batch_update():
        for idx, (xs, ys) in enumerate(new_data):
            fig.data[idx].x = xs
            fig.data[idx].y = ys

# Define the animation function
async def AnimateSimulation(steps, T, interval, E_AA, E_BB, E_AB):
    """Run ``steps`` Monte Carlo swap attempts on the global ``lattice`` and plot progress.

    Step 0 is the starting configuration. Each later step is one call to
    ``AtomSwapAttempt``. Every ``interval`` steps, the energy history and
    lattice snapshot are pushed to the figure and to ``step_info``. The loop
    exits early once the global ``stop_animation`` flag is set.

    Args:
        steps: number of swap attempts to perform.
        T: temperature in K, forwarded to ``AtomSwapAttempt``.
        interval: display every ``interval`` steps; values < 1 are treated as 1.
        E_AA, E_BB, E_AB: bond energies in eV.
    """
    global stop_animation
    n_sites = lattice.size
    interval = max(1, interval)  # `step % 0` would raise ZeroDivisionError
    # Energy per site, updated incrementally with the dE returned by each attempt
    energy = GetTotalEnergy(lattice, E_AA, E_BB, E_AB) / n_sites
    step_list = [0]
    energy_list = [energy]
    UpdaterPlot(lattice, step_list, energy_list)
    step_info.value = f"Current Step: 0, Energy: {energy:.4f} eV/atom"
    for step in range(1, steps + 1):
        if stop_animation:
            break
        energy += AtomSwapAttempt(lattice, T, E_AA, E_BB, E_AB) / n_sites
        if step % interval == 0:
            step_list.append(step)
            energy_list.append(energy)
            UpdaterPlot(lattice, step_list, energy_list)
            step_info.value = f"Current Step: {step}, Energy: {energy:.4f} eV/atom"
        # Yield to the event loop so Stop/Restart clicks can be handled.
        # NOTE(review): a 1 ms pause per step dominates run time for large step
        # counts, but it also throttles plot updates; confirm before shortening it.
        await asyncio.sleep(0.001)

stop_animation = False  # set True by the Stop/Restart callbacks; polled by AnimateSimulation
# Callback function for play button
async def OnPlayButtonClicked(b):
    """Read the simulation parameters from the input widgets and run the animation.

    If any field cannot be parsed as a number, the error is shown in
    ``step_info`` and the Play button stays enabled so the user can fix it.

    Args:
        b: the clicked button (unused; required by the ``on_click`` signature).
    """
    global stop_animation
    # Parse before disabling Play: a typo must not leave the button stuck
    # disabled with the ValueError lost inside a background task.
    try:
        params = (
            int(total_steps_input.value),
            float(temperature_input.value),
            int(interval_input.value),
            float(E_AA_input.value),
            float(E_BB_input.value),
            float(E_AB_input.value),
        )
    except ValueError as err:
        step_info.value = f"Invalid input: {err}"
        return
    stop_animation = False
    play_button.disabled = True  # Disable the play button
    await AnimateSimulation(*params)

# Callback function for stop button
def OnStopButtonClicked(b):
    """Request that a running animation stop.

    Sets the module-level ``stop_animation`` flag, which AnimateSimulation
    checks at the top of every step.

    Args:
        b: the clicked button (unused; required by the ``on_click`` signature).
    """
    global stop_animation
    stop_animation = True  # picked up by the animation loop on its next step

# Callback function for restart button
def OnRestartButtonClicked(b):
    """Stop any running animation and start over from a fresh random lattice.

    This method:
      * raises the stop flag,
      * rebinds the global ``lattice``,
      * redraws the plot with an empty energy history,
      * updates the status label,
      * re-enables the Play button.

    Args:
        b: the clicked button (forwarded to OnStopButtonClicked).
    """
    global lattice
    OnStopButtonClicked(b)  # halt the current run first
    lattice = InitializeLattice()
    UpdaterPlot(lattice, [], [])  # clear the energy curve, show the new lattice
    step_info.value = "Lattice restarted"
    play_button.disabled = False


lattice = InitializeLattice()
InitializePlot(lattice)

# The asyncio event loop keeps only weak references to tasks, so hold a strong
# reference to each Play task until it finishes.
_play_tasks = set()


def _on_play_task_done(task):
    """Forget a finished Play task and show any exception it raised in the status label."""
    _play_tasks.discard(task)
    if not task.cancelled() and task.exception() is not None:
        step_info.value = f"Simulation failed: {task.exception()!r}"


def _start_play(b):
    """Schedule OnPlayButtonClicked(b) as a background task and keep it alive."""
    task = asyncio.ensure_future(OnPlayButtonClicked(b))
    _play_tasks.add(task)
    task.add_done_callback(_on_play_task_done)


play_button.on_click(_start_play)
stop_button.on_click(OnStopButtonClicked)
restart_button.on_click(OnRestartButtonClicked)

# Arrange widgets: two rows of parameter fields, one row of control buttons,
# then the status label underneath
_widget_rows = (
    (temperature_input, total_steps_input, interval_input),
    (E_AA_input, E_BB_input, E_AB_input),
    (play_button, stop_button, restart_button),
)
inputs_box = widgets.VBox(
    [widgets.HBox(list(row)) for row in _widget_rows] + [step_info]
)

# Display widgets and figure
display(fig, inputs_box)