In [None]:
import torch
import numpy as np

from carle.env import CARLE
import matplotlib
import matplotlib.pyplot as plt

# Global plot styling shared by every figure in this notebook.
my_fontsize = 22
matplotlib.rcParams["font.size"] = my_fontsize

# Colormap used for all grid images; my_cmap is the callable form used to pick text colours.
cmap_name = "magma"
my_cmap = plt.get_cmap(cmap_name)


In [None]:
# Preview candidate colormaps: a single-row strip of 64 samples (0, 4, ..., 252)
# rendered once per colormap.
x = np.arange(0, 256, 4)[np.newaxis, :]

for cmap_name0 in ("magma", "plasma", "inferno", "viridis"):
    plt.figure(figsize=(25, 10))
    plt.imshow(x, cmap=cmap_name0)
    plt.title(cmap_name0)
    plt.show()


In [None]:
# Moore neighborhood illustration on a 5x5 grid: the centre cell (value 100) and
# the 8 cells surrounding it (value 60).
grid = np.zeros((5, 5))
grid[1:4, 1:4] = 60
grid[2, 2] = 100

plt.figure(figsize=(6, 6))
plt.imshow(grid, vmin=0, vmax=255, cmap=cmap_name)
# Ticks on cell edges (offset -0.5) with blank labels so plt.grid() outlines each cell.
plt.xticks([k - 0.5 for k in range(5)], [""] * 5)
plt.yticks([k - 0.5 for k in range(5)], [""] * 5)
plt.grid()
plt.title("Moore Neighborhood")
plt.show()

In [None]:
def get_neighbors(grid, env=None):
    """Apply a CARLE environment's ``neighborhood`` operator to a 2-D grid.

    The simulation below compares the result against birth/survive counts, so it is
    presumably the per-cell live-neighbor count (TODO confirm against CARLE).

    Parameters
    ----------
    grid : 2-D numpy array
        Cell states. A batch and a channel axis are prepended before the call, and
        the values are converted to a float32 tensor.
    env : CARLE, optional
        Environment whose ``neighborhood`` is used. When omitted, a fresh ``CARLE()``
        is built (the original behaviour). Pass one in to reuse it across calls
        instead of constructing a new environment every time.

    Returns
    -------
    numpy array
        The operator's output with singleton dimensions squeezed out. This is
        presumably the same shape as ``grid`` (TODO confirm). Note that a grid with a
        length-1 axis would also lose that axis to ``squeeze``.
    """
    if env is None:
        env = CARLE()

    # (H, W) -> (1, 1, H, W) float tensor for the neighborhood operator.
    grid_tensor = torch.tensor(grid[np.newaxis, np.newaxis, :, :]).float()
    grid_neighbors = env.neighborhood(grid_tensor)

    return grid_neighbors.squeeze().detach().cpu().numpy()

def add_numbers(fig, grid, neighbors, subplot=0):
    """Draw ``grid`` as an image and, except in the left panel, overlay per-cell values.

    Parameters
    ----------
    fig : matplotlib Figure
        Not drawn on directly (the pyplot state machine targets the current figure).
        It is returned unchanged so callers can write ``fig = add_numbers(fig, ...)``.
    grid : 2-D numpy array
        Cell values, rendered with ``vmin=0, vmax=1.0`` and the module-level ``cmap_name``.
    neighbors : 2-D numpy array, same shape as ``grid``
        Values written into each cell, formatted with no decimals.
    subplot : int
        0 draws on the current axes; 1 or 2 selects the left or right panel of a 1x2 layout.
        Numbers are drawn for every value except 1.

    Returns
    -------
    The ``fig`` argument.
    """
    if subplot:
        plt.subplot(1, 2, subplot)

    n_rows, n_cols = grid.shape

    plt.imshow(grid, vmin=0, vmax=1.0, cmap=cmap_name)
    # Ticks on cell edges (offset -0.5) with blank labels so plt.grid() outlines cells.
    # x runs over columns and y over rows (this matters for non-square grids).
    plt.xticks([ii - 0.5 for ii in range(n_cols)], ["" for _ in range(n_cols)])
    plt.yticks([ii - 0.5 for ii in range(n_rows)], ["" for _ in range(n_rows)])
    plt.grid()

    # add numbers
    if subplot != 1:
        grid_max = np.max(grid)
        for ii in range(n_rows):
            for jj in range(n_cols):
                # Relative brightness of the cell. An all-zero grid would divide by zero
                # (and round(nan) raises), so treat every cell as dark in that case.
                level = grid[ii, jj] / grid_max if grid_max > 0 else 0.0

                # Contrast colour: bright end of the colormap on dark cells and dark end
                # on bright cells. Cast to float because matplotlib colormaps treat an
                # int argument as a lookup-table index (0 and 1 are both near-black),
                # not as a position in [0, 1].
                my_color = my_cmap(float(round(1.0 - level)))

                my_text = f"{neighbors[ii, jj]:.0f}"
                plt.text(jj - 0.08, ii + 0.05, my_text, fontsize=my_fontsize - 8, c=my_color)

    return fig

# Morley glider: for each step, show the current state (left) next to the pending
# update annotated with neighbor counts (right).
grid = np.zeros((7,7))
grid[5, 1:4] = 1.0
grid[4, 1:3] = 1.0
grid[3, 1:2] = 1.0

# Shades used only for drawing the update panel. Live cells that die keep grid * 0.9,
# survivors are drawn at survive_shade and newborn cells at birth_shade.
survive_shade = 0.4
birth_shade = 0.2

# Rule definition (loop-invariant). The axis label is derived from the lists so the
# two cannot drift apart; with these lists it is "B368/S245".
birth = [3,6,8]
survive = [2,4,5]
rule_label = f"B{''.join(map(str, birth))}/S{''.join(map(str, survive))}"

for step in range(8):
    neighbors = get_neighbors(grid)

    fig = plt.figure(figsize=(18,8))

    fig = add_numbers(fig, grid, neighbors, subplot=1)

    plt.title(f"Morley Glider State Step {step}")
    plt.xlabel(rule_label)

    # Masks of cells born into empty positions and live cells that survive.
    born = np.isin(neighbors, birth) & (grid == 0.0)
    survived = np.isin(neighbors, survive) & (grid == 1.0)

    # Update panel: follows grid's shape rather than a hard-coded (7, 7).
    grid_update = grid * 0.9
    grid_update[born] = birth_shade
    grid_update[survived] = survive_shade

    fig = add_numbers(fig, grid_update, neighbors, subplot=2)

    plt.title("Morley Glider Update")
    plt.xlabel(rule_label)
    plt.show()
    plt.close(fig)

    # Next state comes straight from the rule masks, not by decoding display shades
    # with float equality (which would break if a shade collided with 0.9 or 0).
    grid = (born | survived).astype(float)
    