In [None]:
import numpy as np
import os
import shutil
import tempfile

from matplotlib import pyplot as plt
from matplotlib.animation import FuncAnimation
from PIL import Image

In [None]:
class Ising(object):
    """
    A class for 2D grid-like Ising model. Notice that the class only stores
    model state (the current assignment of random variables); you have to
    compute the needed potential / conditional probability on your own.

    The state is a ``dim x dim`` np.ndarray whose entries are all +1 or -1.
    """
    def __init__(self, dim, Js, Jst):
        """
        Initialize a random instance of 2D grid-like Ising model.

        Parameters:
            dim (int): the dimension of the grid
                       (i.e. the model has dim*dim RVs)
            Js (float): the unary parameter $J_s$
            Jst (float): the binary parameter $J_{st}$
        """
        self._dim = dim
        self._Js = Js
        self._Jst = Jst
        self.init_state()

    def init_state(self):
        """
        Initialize the state of the model randomly.

        Each entry is drawn independently from {-1, +1} using NumPy's global
        random generator, so ``np.random.seed`` makes it reproducible.
        """
        # randint(0, 2) samples from {0, 1}; the map x -> 2x - 1 sends it
        # to {-1, +1}.
        self._state = np.random.randint(0, 2, (self._dim, self._dim))
        self._state = 2 * self._state - 1

    @property
    def dim(self):
        """int: the side length of the square grid."""
        return self._dim

    @property
    def Js(self):
        """float: the unary parameter $J_s$."""
        return self._Js

    @property
    def Jst(self):
        """float: the binary parameter $J_{st}$."""
        return self._Jst

    @property
    def state(self):
        """
        np.ndarray: the current ``dim x dim`` assignment of +1 / -1 values.

        The internal array is returned (not a copy), so in-place edits to it
        modify the model directly.
        """
        return self._state

    @state.setter
    def state(self, state):
        """
        Replace the model state.

        Parameters:
            state (list or np.ndarray): a dim x dim grid whose entries are
                all +1 or -1. The value is copied into a new np.ndarray.

        Raises:
            TypeError: if ``state`` is neither a list nor an np.ndarray.
            ValueError: if ``state`` is not dim x dim or contains an entry
                other than +1 / -1.
        """
        expected_shape = (self._dim, self._dim)
        shape_error = f"state must be a {self._dim} x {self._dim} grid"

        if isinstance(state, list):
            # Check every row (not only the first) so ragged input is
            # rejected before NumPy tries to build an array from it.
            if len(state) != self._dim or any(
                    len(row) != self._dim for row in state):
                raise ValueError(shape_error)
        elif not isinstance(state, np.ndarray):
            raise TypeError("only support list and np.ndarray")

        # np.array copies, so later changes to the caller's object do not
        # leak into the model.
        new_state = np.array(state)
        # Re-check after conversion: this also rejects lists whose cells are
        # themselves sequences (which would give a 3-D array).
        if new_state.shape != expected_shape:
            raise ValueError(shape_error)
        # Explicit raise (not assert) so validation survives `python -O`.
        if not np.all((new_state == 1) | (new_state == -1)):
            raise ValueError("every entry of state must be +1 or -1")
        self._state = new_state

## Some helper functions for visualization and submission

In [None]:
def visualize_state(state, title=None, file_path=None):
    """
    Visualize the 2D state as an image.

    Entries are drawn with the "binary" colormap over the fixed range
    [-1, 1], so -1 renders white and +1 renders black.

    Parameters:
        state (np.ndarray): the 2D state to be visualized
        title (str, optional): the title of the plot
        file_path (str, optional): the path to save the image; if omitted
            the plot is shown interactively instead

    Raises:
        TypeError: if ``state`` is not an np.ndarray.
        ValueError: if ``state`` is not two-dimensional.
    """
    # Explicit raises (not assert) so the checks survive `python -O`.
    if not isinstance(state, np.ndarray):
        raise TypeError("state must be an np.ndarray")
    if state.ndim != 2:
        raise ValueError(f"state must be 2-D, got shape {state.shape}")

    ax = plt.axes()
    # Pin the color range so a uniform grid (all +1 or all -1) keeps its
    # true color instead of being auto-normalized.
    ax.imshow(state, "binary", vmin=-1, vmax=1)
    plt.setp(ax.get_yticklabels(), visible=False)
    plt.setp(ax.get_xticklabels(), visible=False)

    if title is not None:
        plt.title(title)
    plt.axis('tight')
    if file_path is not None:
        plt.savefig(file_path)
    else:
        plt.show()
    # Clear the current figure so the next call starts from a blank canvas.
    plt.clf()

In [None]:
def generate_gif(samples, output_path, image_dir=None, duration=20):
    """
    Generate a GIF using collected samples (not required)

    Each sample is rendered to ``<image_dir>/<i>.png`` with visualize_state,
    and the frames are then combined into a single GIF.

    Parameters:
        samples (iterable of np.ndarray): the samples, one frame each, in order
        output_path: the path to save the GIF
        image_dir (optional): the directory for saving the images,
            if not provided, the images will be removed afterwards
        duration (int, optional): display time of each frame in
            milliseconds (default 20)

    Raises:
        ValueError: if ``samples`` is empty.
    """
    # Materialize once so any iterable works and emptiness can be checked
    # before any directory is created.
    samples = list(samples)
    if not samples:
        raise ValueError("samples must contain at least one state")

    rm_flag = image_dir is None
    if rm_flag:
        image_dir = tempfile.mkdtemp()
    else:
        # exist_ok avoids the check-then-create race of os.path.exists.
        os.makedirs(image_dir, exist_ok=True)

    try:
        frame_paths = []
        for i, sample in enumerate(samples):
            frame_path = os.path.join(image_dir, f"{i}.png")
            visualize_state(sample, f"Time={i}", frame_path)
            frame_paths.append(frame_path)

        images = []
        try:
            for frame_path in frame_paths:
                images.append(Image.open(frame_path))
            images[0].save(
                output_path,
                save_all=True,
                append_images=images[1:],
                duration=duration)
        finally:
            # Image.open keeps the underlying file open; release every handle
            # (some platforms refuse to delete files that are still open).
            for image in images:
                image.close()
    finally:
        # Remove the temporary directory even if rendering or saving failed.
        if rm_flag:
            shutil.rmtree(image_dir, ignore_errors=True)

In [None]:
def merge_images(filenames, n_rows, n_cols, output_path):
    """
    Merge a list of image files into single image (for submission)

    Images are placed row by row (left to right, then top to bottom). Every
    cell has the size of the first image. Only the first ``n_rows * n_cols``
    files are used; if fewer files are given, the remaining cells stay black.

    Parameters:
        filenames (list of str): the images which are going to be merged
        n_rows (int): the number of rows in the image grid
        n_cols (int): the number of columns in the image grid
        output_path (str): the output path of the merged image

    Raises:
        ValueError: if ``filenames`` is empty or the grid has no cells.
    """
    filenames = list(filenames)
    if n_rows <= 0 or n_cols <= 0:
        raise ValueError("n_rows and n_cols must be positive")
    if not filenames:
        raise ValueError("filenames must contain at least one image")

    images = []
    try:
        # Open only the files that actually fit into the grid.
        for filename in filenames[:n_rows * n_cols]:
            images.append(Image.open(filename))
        width, height = images[0].size
        new_image = Image.new('RGB', (width * n_cols, height * n_rows))
        for idx, image in enumerate(images):
            row, col = divmod(idx, n_cols)
            new_image.paste(image, (width * col, height * row))
    finally:
        # Release the file handles held by Image.open.
        for image in images:
            image.close()

    new_image.save(output_path)