# Ising model Monte Carlo simulation with checkerboard Metropolis updates in TensorFlow

In [1]:
#@title Imports
# import numpy as np
from numpy import __version__ as np_version_
from numpy import cumsum as np_cumsum
from numpy import arange as np_arange
# import tensorflow as tf
from tensorflow import __version__ as tf_version_ #(v2.3.0)
from tensorflow import function as tf_function
from tensorflow import rank as tf_rank
from tensorflow import range as tf_range
from tensorflow import transpose as tf_transpose
from tensorflow import concat as tf_concat
from tensorflow import less as tf_less
from tensorflow import shape as tf_shape
from tensorflow import float32 as tf_float32
from tensorflow import exp as tf_exp
from tensorflow import where as tf_where
from tensorflow import cumsum as tf_cumsum
from tensorflow import reduce_mean as tf_reduce_mean
from tensorflow import constant as tf_constant
from tensorflow.sparse import to_dense as tf_sparse_to_dense
from tensorflow.sparse import SparseTensor as tf_sparse_SparseTensor
from tensorflow.random import uniform as tf_random_uniform
from tensorflow.random import Generator as tf_random_Generator
from tensorflow.math import floormod as tf_math_floormod
# import ipywidgets as widgets
from ipywidgets import __version__ as widgets_version_
from ipywidgets import Accordion as widgets_Accordion
from ipywidgets import VBox as widgets_VBox
from ipywidgets import GridBox as widgets_GridBox
from ipywidgets import Layout as widgets_Layout
from ipywidgets import Tab as widgets_Tab
from ipywidgets import interactive_output as widgets_interactive_output
from ipywidgets import Output as widgets_Output
from ipywidgets import FloatSlider as widgets_FloatSlider
from ipywidgets import IntSlider as widgets_IntSlider
from ipywidgets import Button as widgets_Button
from ipywidgets import ToggleButton as widgets_ToggleButton
from ipywidgets import ToggleButtons as widgets_ToggleButtons
from ipywidgets import Checkbox as widgets_Checkbox
from ipywidgets import Text as widgets_Text
from IPython.display import display
import time
from platform import python_version
%matplotlib ipympl
from matplotlib import __version__ as mpl_version_
from matplotlib import rcParams as mpl_rcParams
from matplotlib import rcParams as mpl_rcParams
from matplotlib import pyplot as plt
from matplotlib import pylab as plb
# Dark theme for every figure: dark grey backgrounds with light labels, ticks and text.
mpl_rcParams.update({
    'figure.facecolor': '#222222',
    'axes.facecolor': '#222222',
    'axes.labelcolor': '#edece9',
    'xtick.color': '#edece9',
    'ytick.color': '#edece9',
    'text.color': '#edece9',
})

In [2]:
#@title Requirements
# Record the interpreter and library versions this notebook was run with.
print(
    f"python-{python_version()}",
    f"tensorflow=={tf_version_}",
    f"matplotlib=={mpl_version_}",
    f"numpy=={np_version_}",
    f"ipywidgets=={widgets_version_}",
    sep="\n",
)

python-3.7.9
tensorflow==2.4.0
matplotlib==3.3.3
numpy==1.19.5
ipywidgets==7.6.3


In [3]:
#@title Reset Model
# Clear TensorFlow's default graph before the model functions below are (re)built.
from tensorflow.python.framework import ops
ops.reset_default_graph()

DIMS = 2             # lattice dimensionality
print_shape = True   # global debug flag; potential_persite has its own print_shape parameter
dims = DIMS          # lowercase alias read inside the @tf_function bodies
# One boundary factor per axis; 1 multiplies the wrapped-around neighbour by 1 (periodic).
BOUNDARY_CONDITION = [1 for _ in range(DIMS)]

In [4]:
#@title Boundary Condition
@tf_function
def shift_spin(x, axis=0, shift=0, bcaxis=1):
    """Shift the lattice ``x`` by one site along ``axis`` with a twisted boundary.

    Args:
        x: spin lattice tensor.
        axis: lattice axis to shift along (Python int). Negative values count
            from the last axis.
        shift: ``1``: the result at index ``i`` along ``axis`` holds ``x[i+1]``,
            and the slice wrapped from the front to the back is multiplied by
            ``bcaxis``. ``-1``: the result at index ``i`` holds ``x[i-1]``, and
            the slice wrapped from the back to the front is multiplied by
            ``bcaxis``. Any other value returns ``x`` unchanged.
        bcaxis: factor applied to the wrapped slice. For example, ``1`` gives
            periodic, ``-1`` antiperiodic and ``0`` open boundaries along ``axis``.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    shifted_x = x
    if shift in (1, -1):
        if axis < 0:
            axis += x.shape.rank
        # Slice along `axis` directly by keeping every leading dimension whole.
        # This replaces the former transpose + SparseTensor permutation, whose
        # SparseTensor had the repeated index [0],[0] whenever axis == 0, and
        # which also depended on the global `dims` instead of x's own rank.
        lead = (slice(None),) * axis
        if shift == 1:
            shifted_x = tf_concat([x[lead + (slice(1, None),)],
                                   bcaxis * x[lead + (slice(None, 1),)]],
                                  axis=axis)
        else:
            shifted_x = tf_concat([bcaxis * x[lead + (slice(-1, None),)],
                                   x[lead + (slice(None, -1),)]],
                                  axis=axis)
    return shifted_x

In [5]:
#@title Potential
@tf_function
def potential_persite(x, uniformJ, uniformH, bc, print_shape=False):
    """Per-site ``-(H + sum of J-weighted neighbour spins)``.

    For every axis in ``range(dims)``, both neighbours are weighted by
    ``uniformJ[axis]``. These are the shift -1 and shift +1 neighbours, each
    using boundary factor ``bc[axis]``.

    Args:
        x: spin lattice tensor.
        uniformJ: per-axis couplings, indexed by axis.
        uniformH: uniform field.
        bc: per-axis boundary factors forwarded to ``shift_spin``.
        print_shape: if true, print the result's static shape. This is a
            Python print, so it fires when the function is traced.

    Returns:
        Tensor shaped like ``x``.
    """
    potential = -uniformH + 0*x  # broadcast the field onto the lattice shape
    neighbours = [(axis, step) for axis in range(dims) for step in (-1, 1)]
    for axis, step in neighbours:
        coupling = uniformJ[axis]
        potential = potential - coupling*shift_spin(x, axis=axis, shift=step, bcaxis=bc[axis])
    if print_shape:
        print("potential:"+str(potential.shape))
    return potential

In [6]:
#@title Checkerboard update
@tf_function
def update(x, stncl, Temperature, uniformJ, uniformH, bc):
    """One Metropolis update of the sub-lattice selected by ``stncl``.

    Args:
        x: spin lattice (presumably +-1 values, as produced by
            initialize_stencil_spin).
        stncl: 0/1 mask shaped like ``x``. Only sites where it is 1 may flip.
        Temperature: temperature T used in ``exp(-dE/T)``.
        uniformJ, uniformH, bc: couplings, field and boundary factors,
            forwarded to ``potential_persite``.

    Returns:
        New lattice in which each masked site is flipped with probability
        ``min(1, exp(-dE/T))``. Unmasked sites are unchanged.
    """
    r = tf_random_uniform(tf_shape(x), minval=0, maxval=1, dtype=tf_float32)
    # Energy change of flipping each site; it is zero off the stencil.
    dE = -2*x*stncl*potential_persite(x, uniformJ=uniformJ, uniformH=uniformH, bc=bc)
    # Acceptance probability. The stencil factor makes it 0 off the active sub-lattice.
    p = tf_where(dE<=0.0,stncl,stncl*tf_exp(-dE/Temperature))
    # r is uniform on [0, 1), so `r < p` accepts with probability exactly p:
    # p == 1 always flips and p == 0 never does. The previous `p > 1-r` test
    # rejected a downhill move whenever r was exactly 0.
    return tf_where(r < p, -x, x)

In [7]:
#@title Measurement Step
@tf_function
def measure_step(x, stncl, Temperature, uniformJ, uniformH, bc, steps=1):
    """Run ``steps`` full checkerboard sweeps and return the final lattice.

    Each sweep updates the sites where ``stncl`` is 1, then the sites where
    ``1 - stncl`` is 1, using ``update`` with the given temperature,
    couplings, field and boundary factors. ``steps <= 0`` returns ``x``
    unchanged.
    """
    complement = 1 - stncl
    spinin = x
    # Iterating over a tf.range lets AutoGraph emit a single tf.while_loop.
    # A Python `range(steps)` would unroll the loop at trace time, building a
    # graph with 2*steps update calls (the UI allows steps up to 1000).
    for _ in tf_range(steps):
        half_swept = update(spinin, stncl, Temperature, uniformJ, uniformH, bc)
        spinin = update(half_swept, complement, Temperature, uniformJ, uniformH, bc)
    return spinin

In [8]:
#@title Testing
def initialize_stencil_spin(sshape, p=0.5, seed=None):
    """Draw a random +-1 spin lattice and its checkerboard stencil.

    Args:
        sshape: lattice shape, one length per axis.
        p: probability that a site starts as +1 (otherwise it starts as -1).
        seed: optional integer seed for a reproducible initial lattice. The
            default ``None`` keeps the previous non-deterministic behavior.

    Returns:
        ``(spin_input, stencil)``: float32 tensors of shape ``sshape``.
        ``stencil`` is ``(sum_k (i_k + 1)) mod 2``, a 0/1 checkerboard in which
        adjacent sites along any axis have opposite values.

    Note: with periodic wrap-around, the checkerboard only stays consistent
    across the boundary when every length is even.
    """
    gen = (tf_random_Generator.from_non_deterministic_state() if seed is None
           else tf_random_Generator.from_seed(seed))
    spin_input = tf_where(gen.uniform(shape=sshape) < p, 1.0, -1.0)
    # cumsum of ones along an axis yields (index + 1) on that axis. Summing over
    # all axes and taking the parity gives the checkerboard.
    ones = spin_input*0 + 1
    stencil = spin_input*0
    for axis in range(len(sshape)):
        stencil = tf_math_floormod(stencil + tf_cumsum(ones, axis=axis), 2)
    return spin_input, stencil

# Smoke test: one full checkerboard sweep on a 1000^DIMS lattice at T = 2.27.
spin_input,stencil = initialize_stencil_spin([1000]*DIMS)

spin_output = measure_step(x=spin_input,
                           stncl=stencil,
                           Temperature=tf_constant(2.27),
                           uniformJ=tf_constant([1.0]*DIMS),
                           uniformH=tf_constant(0.0),
                           bc=tf_constant(BOUNDARY_CONDITION, tf_float32),
                           steps=1)

# A sweep must keep the lattice shape and may only flip signs.
assert spin_output.shape == spin_input.shape
assert float(tf_reduce_mean(tf_where(spin_output**2 == 1.0, 0.0, 1.0))) == 0.0, "spins must stay +-1"

In [9]:
#@title Magnetization and Energy

@tf_function
def magnetization_per_site(x,moments=[1,2,3,4]):
    """Powers of the mean spin.

    Args:
        x: spin lattice tensor.
        moments: exponents k (only read, never mutated).

    Returns:
        List ``[m**k for k in moments]``, where ``m`` is the mean of ``x`` over all sites.
    """
    mean_spin = tf_reduce_mean(x)
    powers = []
    for k in moments:
        powers.append(mean_spin**k)
    return powers

@tf_function
def energy_per_site(x, uniformJ, uniformH, bc, moments=[1,2,3,4]):
    """Powers of the mean energy per site.

    The per-site term is ``(-H - sum_axis J[axis] * s_forward) * s``. Only the
    forward neighbour (``shift=1``) is used along each of the ``dims`` axes,
    so every bond is counted once.

    Args:
        x: spin lattice tensor.
        uniformJ: per-axis couplings, indexed by axis.
        uniformH: uniform field.
        bc: per-axis boundary factors forwarded to ``shift_spin``.
        moments: exponents k (only read, never mutated).

    Returns:
        List ``[e**k for k in moments]``, where ``e`` is the lattice average.
    """
    coefficient = -uniformH + 0*x  # broadcast the field onto the lattice shape
    for axis in range(dims):
        coefficient = coefficient - uniformJ[axis]*shift_spin(x, axis=axis, shift=1, bcaxis=bc[axis])
    mean_energy = tf_reduce_mean(coefficient*x)
    return [mean_energy**k for k in moments]

# Interactive

In [37]:
#@title Plot Update
def _draw_moment_panel(ax, steps, series, raw_toggles, avg_toggles, avg_label_open):
    """Draw the selected moments and their running averages on ``ax``.

    ``series`` has one entry per recorded MCS step, and entry ``k`` of each
    element is the k-th moment. Raw values are drawn dotted with markers.
    Running averages (cumulative means) are drawn as solid lines. A legend is
    added only if at least one curve was drawn. ``avg_label_open`` is the
    mathtext prefix used for averaged labels.
    """
    colors = ["#f58231","#3cb44b","#4363d8","#911eb4"]
    markers = ["s","D","^","v"]
    drew_any = False
    for k, toggle in enumerate(raw_toggles):
        if toggle.value:
            ax.plot(steps,[s[k] for s in series],c=colors[k],marker=markers[k],linestyle='dotted',markersize=2,label=r"$"+toggle.description+r"$")
            drew_any = True
    for k, toggle in enumerate(avg_toggles):
        if toggle.value:
            running_mean = np_cumsum([s[k] for s in series])/np_arange(1,len(series)+1)
            ax.plot(steps,running_mean,c=colors[k],label=avg_label_open+toggle.description[1:-1]+r" \rangle $")
            drew_any = True
    if drew_any:
        ax.legend()


def show_plot(**kwargs):
    """Redraw the magnetization, energy and spin-configuration panels.

    This is the callback of ``widgets_interactive_output``. The values passed
    in ``kwargs`` are ignored. The toggle widgets and the global history
    (``magnetizations``, ``energys``, ``MCS_steps``, ``spin``) are read
    directly. The figure and its axes are created on the first call and
    reused afterwards.
    """
    global fig,axM,axE,axpin,gs,spin,energys,magnetizations,MCS_steps
    # Create the figure only once. This used to rely on a bare `except:` around
    # `axM.cla()`, which also silently swallowed any unrelated error.
    if not all(name in globals() for name in ("fig", "axM", "axE", "axpin")):
        fig = plt.figure(figsize=(16,8),constrained_layout=True)
        gs = fig.add_gridspec(2, 4)
        axM = fig.add_subplot(gs[0,0:2])
        axE = fig.add_subplot(gs[1,0:2])
        axpin = fig.add_subplot(gs[:,2:])
        plt.ion()
    for ax in (axM, axE, axpin):
        ax.cla()

    axM.set_ylabel("Magnetization")
    _draw_moment_panel(axM, MCS_steps, magnetizations, show_Ms, show_Mavgs, r"$ \langle ")

    axE.set_xlabel("Monte-Carlo Steps")
    axE.set_ylabel("Energy")
    _draw_moment_panel(axE, MCS_steps, energys, show_Es, show_Eavgs, r"$\langle ")

    axpin.set_title("MCS:"+str(MCS_steps[-1]))
    # Honor the "Spin Configuration" toggle, which is wired into update_dict
    # but was previously ignored here.
    if show_spin.value:
        axpin.imshow(spin,vmin=-1,vmax=1)
    axpin.axis('off')
    fig.canvas.draw_idle()
    display(fig)

In [38]:
#@title Main Buttons
init_but = widgets_Button(description="Re/Initialize",
                                 layout={'width':'auto', 'height':'30px', 'align':'center'},
                          button_style='success')
def init_button(button_click):
    """Draw a fresh lattice and reset the measurement history.

    Reads the lattice-size sliders and the spin-up probability slider. It then
    records the step-0 magnetization and energy moments and flips the
    "Update" checkbox in update_dict so the interactive plot redraws.
    """
    global spin,stencil,energys,magnetizations,MCS_steps
    status_text.value = "Initializing... "
    spin,stencil = initialize_stencil_spin([l.value for l in L_sliders],p=S_slider.value)
    magnetizations = [magnetization_per_site(spin)]
    # Pass tensors, as the measure_step call in sim_button does. Python
    # lists/floats would make energy_per_site re-trace for every new slider value.
    energys = [energy_per_site(spin,
                               uniformJ=tf_constant([j.value for j in J_sliders]),
                               uniformH=tf_constant(H_slider.value),
                               bc=tf_constant([b.value for b in BC_sliders], tf_float32))]
    MCS_steps = [0]
    update_dict["Update"].value = not update_dict["Update"].value
    status_text.value = "Initialized! "
init_but.on_click(init_button)

sim_but = widgets_Button(description="Simulate",
                                 layout={'width':'auto', 'height':'30px', 'align':'center'},
                         button_style='success')
def sim_button(button_click):
    """Run the Monte-Carlo measurement loop configured in the Measurement tab.

    Takes ``MCS_sliders[1]`` samples. Before each sample, ``measure_step`` runs
    ``MCS_sliders[0]`` sweeps. After each sample, the magnetization and energy
    moments are appended to the global history. The plot is refreshed every
    ``plot_period`` samples (0 means only after the last one), and
    ``status_text`` reports progress.
    """
    global spin,stencil,energys,magnetizations,MCS_steps
    tStart = time.time()
    # Widget values are not expected to change while this handler runs (the
    # kernel handles widget messages between handlers). Read them and build
    # the parameter tensors once instead of on every sample.
    n_samples = MCS_sliders[1].value
    interval = MCS_sliders[0].value
    period = plot_period.value
    temperature = tf_constant(T_slider.value)
    couplings = tf_constant([j.value for j in J_sliders])
    field = tf_constant(H_slider.value)
    boundary = tf_constant([b.value for b in BC_sliders], tf_float32)
    for step in range(n_samples):
        status_text.value = f"Simulating... {int(100*step/n_samples)}% - {int(0.5+time.time()-tStart)} seconds"
        spin = measure_step(x=spin,
                            stncl=stencil,
                            Temperature=temperature,
                            uniformJ=couplings,
                            uniformH=field,
                            bc=boundary,
                            steps=interval)
        magnetizations.append(magnetization_per_site(spin))
        # Tensors (not Python lists/floats) so energy_per_site is not re-traced.
        energys.append(energy_per_site(spin,uniformJ=couplings,uniformH=field,bc=boundary))
        MCS_steps.append(MCS_steps[-1]+interval)
        if (period and (step+1)%period==0) or (step+1)==n_samples:
            update_dict["Update"].value = not update_dict["Update"].value
    status_text.value = f"Simulating...  100% - {int(0.5+time.time()-tStart)} seconds"
sim_but.on_click(sim_button)


In [39]:
#@title Buttons and Sliders
def _styled_slider(slider_cls, description, value, lo, hi, step):
    """Return a horizontal slider in this panel's shared style.

    The slider uses auto width and shows a readout. It only commits its value
    on release (continuous_update=False).
    """
    return slider_cls(value=value,
                      min=lo,
                      max=hi,
                      step=step,
                      description=description,
                      readout=True,
                      style={'description_width': 'initial'},
                      layout={'width':'auto'},
                      continuous_update=False,
                      orientation='horizontal')

# Probability that a site starts spin-up.
S_slider = _styled_slider(widgets_FloatSlider, "p", 0.5, 0.0, 1.0, 0.01)
# Temperature.
T_slider = _styled_slider(widgets_FloatSlider, "T", 2.27, 0.01, 9.99, 0.01)
# One nearest-neighbour coupling per lattice axis.
J_sliders = [_styled_slider(widgets_FloatSlider, f"J[{i}]", 1.0, -2.0, 2.0, 0.01)
             for i in range(DIMS)]
# Uniform field.
H_slider = _styled_slider(widgets_FloatSlider, "H", 0.0, -5.0, 5.0, 0.01)
# Lattice length per axis: minimum 2 and step 2, so lengths stay even.
L_sliders = [_styled_slider(widgets_IntSlider, f"L[{i}]", 100, 2, 10000, 2)
             for i in range(DIMS)]
# Boundary factor per axis, applied to the wrapped-around neighbour.
BC_sliders = [_styled_slider(widgets_FloatSlider, f"BC[{i}]", 1, -1, 1, 0.01)
              for i in range(DIMS)]

def _wide_button(description, **extra):
    """Return a Button with the shared auto-width, 30px-tall layout."""
    return widgets_Button(description=description,
                          layout={'width':'auto', 'height':'30px', 'align':'center'},
                          **extra)

def _copy_first_value(sliders):
    """Set every slider in ``sliders`` to the current value of the first one."""
    for slider in sliders:
        slider.value = sliders[0].value

rnd_but = _wide_button("p[Spin UP]=p[Spin DOWN]")
def rnd_button(button_click):
    """Set the spin-up probability back to one half."""
    S_slider.value = 0.5
rnd_but.on_click(rnd_button)

sqr_but = _wide_button("Equal Lattice dimensions: L[i]=L[0]")
def sqr_button(button_click):
    """Give every axis the length of axis 0."""
    _copy_first_value(L_sliders)
sqr_but.on_click(sqr_button)

same_but = _wide_button("Same Boundary condition: BC[i]=BC[0]")
def same_button(button_click):
    """Give every axis the boundary factor of axis 0."""
    _copy_first_value(BC_sliders)
same_but.on_click(same_button)

iso_but = _wide_button("Isotropic Interactions: J[i]=J[0]")
def iso_button(button_click):
    """Give every axis the coupling of axis 0."""
    _copy_first_value(J_sliders)
iso_but.on_click(iso_button)

h0_but = _wide_button("Zero Field")
def h0_button(button_click):
    """Switch the uniform field off."""
    H_slider.value = 0.0
h0_but.on_click(h0_button)

crt_but = _wide_button("Critical Temperature")
def crt_button(button_click):
    """Set T to 2.27. The button label calls this the critical temperature;
    it is presumably the rounded 2D Onsager value 2/ln(1+sqrt(2)) ~ 2.269."""
    T_slider.value = 2.27
crt_but.on_click(crt_button)

def_but = _wide_button("Default Parameters", button_style='info')
def def_button(button_click):
    """Restore BC=1, J=1.0, H=0.0 and T=2.27. Lattice size and p are left unchanged."""
    for slider in BC_sliders:
        slider.value = 1
    for slider in J_sliders:
        slider.value = 1.0
    H_slider.value = 0.0
    T_slider.value = 2.27
def_but.on_click(def_button)

def _toggle(description):
    """Return a full-width, 30px-tall toggle button that starts switched on."""
    return widgets_ToggleButton(value=True,
                                description=description,
                                layout={'width':'auto', 'height':'30px', 'align':'center'})

# Suffixes for the first four moments; the descriptions double as plot labels.
_POWER_SUFFIXES = ["", "^2", "^3", "^4"]
show_spin = _toggle('Spin Configuration')
show_Ms = [_toggle(f"M{suffix}") for suffix in _POWER_SUFFIXES]
show_Mavgs = [_toggle(f"<M{suffix}>") for suffix in _POWER_SUFFIXES]
show_Es = [_toggle(f"E{suffix}") for suffix in _POWER_SUFFIXES]
show_Eavgs = [_toggle(f"<E{suffix}>") for suffix in _POWER_SUFFIXES]
# Measurement controls, read by sim_button:
#   MCS_sliders[0] "Sampling Interval": sweeps run by measure_step per sample (1..1000, default 10)
#   MCS_sliders[1] "Number of Samples": samples per Simulate click (1..10000, default 100)
# The default is one hundredth of the maximum. Integer division keeps it an int
# for the IntSlider; `i/100` passed a float and relied on the trait casting it.
MCS_sliders = [widgets_IntSlider(value=max_value // 100,
                                 min=1,
                                 max=max_value,
                                 step=1,
                                 description=f"{label}",
                                 readout=True,
                                 style={'description_width': 'initial'},
                                 layout={'width':'auto', 'height':'30px', 'align':'center'},
                                 continuous_update=False,
                                 orientation='horizontal'
                                ) for label, max_value in (("Sampling Interval", 1000), ("Number of Samples", 10000))]

# How often sim_button refreshes the plot: 0 means only after the last sample,
# otherwise every N-th sample.
plot_period = widgets_ToggleButtons(
    options=[('after simulation Ends', 0),
             ('every 100th Interval', 100),
             ('every 10th Interval', 10),
             ('every Interval', 1)],
    description='Update Plot:',
    disabled=False,
    style={'button_width': 'auto'},
    layout={'width':'auto', 'height':'30px', 'align':'center'},
    button_style='warning',
    tooltips=['Recommended', 'Slow', 'Slower', 'Slowest'],
)

# Controls handed to widgets_interactive_output: any change re-runs show_plot.
# "Update" is a checkbox that the button handlers flip to force a redraw.
update_dict = {
    "spin": show_spin,
    **{f"M{k + 1}": show_Ms[k] for k in range(4)},
    **{f"Mavg{k + 1}": show_Mavgs[k] for k in range(4)},
    **{f"E{k + 1}": show_Es[k] for k in range(4)},
    **{f"Eavg{k + 1}": show_Eavgs[k] for k in range(4)},
    "Update": widgets_Checkbox(value=True,
                               description='Update',
                               disabled=False,
                               button_style=''),
}
# Read-only status line written by init_button and sim_button. The placeholder
# used to be a copy-pasted "Paste ticket description here!", which is
# meaningless on a disabled field.
status_text = widgets_Text(
    value='Initialized!',
    placeholder='Status messages appear here',
    description='Status:',
    disabled=True,
    layout={'width':'auto', 'height':'30px', 'align':'center'}
)

In [40]:
#@title Initialize for the first time
# Build the same starting state as init_button, from the widgets' current values.
spin,stencil = initialize_stencil_spin([l.value for l in L_sliders],p=S_slider.value)
magnetizations = [magnetization_per_site(spin)]
# Pass tensors rather than Python lists/floats, so energy_per_site is traced
# once and reused instead of re-traced for every new slider value.
energys = [energy_per_site(spin,
                           uniformJ=tf_constant([j.value for j in J_sliders]),
                           uniformH=tf_constant(H_slider.value),
                           bc=tf_constant([b.value for b in BC_sliders], tf_float32))]
MCS_steps = [0]

In [42]:
#@title Main display
# Initialization tab: spin-up probability, lattice size, re/initialize.
L_accordian = widgets_Accordion(children=[widgets_VBox([sqr_but,*L_sliders])])
L_accordian.set_title(0,'Lattice Dimensions')
S_accordian = widgets_Accordion(children=[widgets_VBox([rnd_but,S_slider])])
S_accordian.set_title(0,'Spin UP probability')
init_tab = widgets_VBox([S_accordian,L_accordian,init_but])

# Parameters tab: boundary conditions, couplings, field, temperature.
BC_accordian = widgets_Accordion(children=[widgets_VBox([same_but,*BC_sliders])])
BC_accordian.set_title(0,'Boundary Conditions')
J_accordian = widgets_Accordion(children=[widgets_VBox([iso_but,*J_sliders])])
J_accordian.set_title(0,'Nearest-neighbor Interactions')
H_accordian = widgets_Accordion(children=[widgets_VBox([h0_but,H_slider])])
H_accordian.set_title(0,'Uniform Field')
T_accordian = widgets_Accordion(children=[widgets_VBox([crt_but,T_slider])])
T_accordian.set_title(0,'Temperature')
param_tab = widgets_VBox([BC_accordian,J_accordian,H_accordian,T_accordian,def_but])

# Measurement tab: sampling controls, plot toggles, simulate.
MCS_accordian = widgets_Accordion(children=[widgets_VBox(MCS_sliders)])
MCS_accordian.set_title(0,'Monte-Carlo Steps')
# NOTE(review): Spin_accordian (holding the show_spin toggle) is built but not
# placed in any tab below, so that toggle is not reachable from the UI.
Spin_accordian = widgets_Accordion(children=[widgets_GridBox([show_spin],
                                                          layout=widgets_Layout(grid_template_columns="repeat(1, auto)"))])
Spin_accordian.set_title(0,'Spin')
M_accordian = widgets_Accordion(children=[widgets_GridBox([*show_Ms,*show_Mavgs],
                                                          layout=widgets_Layout(grid_template_columns="repeat(4, auto)"))])
M_accordian.set_title(0,'Magnetization')
E_accordian = widgets_Accordion(children=[widgets_GridBox([*show_Es,*show_Eavgs],
                                                          layout=widgets_Layout(grid_template_columns="repeat(4, auto)"))])
E_accordian.set_title(0,'Energy')
measr_tab = widgets_VBox([MCS_accordian,M_accordian,E_accordian,sim_but])

tabs = widgets_Tab(children=[init_tab,
                             param_tab,
                             measr_tab])

tabs.set_title(0,"Initialization")
tabs.set_title(1,"Parameters")
tabs.set_title(2,"Measurement")
# Re-run show_plot whenever any control in update_dict changes.
plot_out = widgets_interactive_output(show_plot, update_dict)
out1 = widgets_Output(layout={'border': '0px solid black'})
with out1:
    out1.clear_output()
    # A `layout={'color':'black'}` keyword used to be passed to display()
    # itself. display() is not the widget constructor, so that keyword never
    # reached the VBox; it has been dropped.
    display(widgets_VBox([tabs,
                          plot_period,
                          plot_out,
                          status_text,
                          widgets_GridBox([init_but,def_but,sim_but], layout=widgets_Layout(grid_template_columns="repeat(3, auto)")),
                         ]))

out1

Output(layout=Layout(border='0px solid black'))