In [None]:
import numpy as np
import matplotlib.pyplot as plt
from random import randint, random
import sys
import gc
import torch
import torch.nn as nn
import torch.nn.functional as F
from itertools import count

In [None]:
max_delta = 1   # largest number of Life steps between the hidden board and the target
grid_size = 5   # side length of the square board

# Run on the first GPU when one is visible, otherwise on the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

dtype = torch.float32  # floating dtype of every tensor fed to the networks

In [None]:
def grid_images(images, save=None):
    """Show images laid out on a grid with matplotlib, optionally saving the figure.

    Args:
        images: ndarray or sequence of arrays (stacked with np.stack) with 2 to 4 dims:
            (H, W)        -> a single image,
            (N, H, W)     -> one row of N images,
            (R, C, H, W)  -> an R x C grid.
        save: optional file path; when given the figure is written there.

    Raises:
        AssertionError: if `images` has fewer than 2 or more than 4 dimensions.
    """
    if not isinstance(images, np.ndarray):
        images = np.stack(images)
    assert len(images.shape) >= 2, "pas assez de dimensions"
    assert len(images.shape) <= 4, "trop de dimensions"
    # Promote to (rows, cols, H, W) by adding leading axes.
    while images.ndim < 4:
        images = np.expand_dims(images, 0)

    n_rows, n_cols = images.shape[0], images.shape[1]
    fig = plt.figure(figsize=(n_cols, n_rows))
    for i in range(n_rows):
        for j in range(n_cols):
            plt.subplot(n_rows, n_cols, i * n_cols + j + 1)
            plt.imshow(images[i, j])

    # Save BEFORE show(): inline / non-interactive backends close the figure in
    # show(), so a savefig afterwards would write an empty canvas.
    if save is not None:
        fig.savefig(save)
    plt.show()

In [None]:
"""
    Game of life functions 
"""

def get_padded_version_n(X):
    """Pad a batch of boards by one cell on every side with wrap-around values.

    Args:
        X: array of shape (batch, H, W).

    Returns:
        New array of shape (batch, H + 2, W + 2) with X's dtype. The border rows,
        columns and corners are copied from the opposite edges of each board, so
        the board behaves as a torus. The batch axis is not padded.
    """
    # mode="wrap" fills every pad cell from the opposite side of its axis.
    return np.pad(X, ((0, 0), (1, 1), (1, 1)), mode="wrap")

def nConv2d_sw_3x3(X):
    """Sum every 3x3 neighbourhood (centre cell included) with wrap-around edges.

    Args:
        X: array of shape (batch, H, W).

    Returns:
        Array of shape (batch, H, W) and X's dtype where each entry is the sum of
        the 3x3 window centred on that cell, taken on the wrap-padded board.
    """
    padded = get_padded_version_n(X)
    height, width = X.shape[-2], X.shape[-1]
    window_sum = np.zeros_like(padded[:, 1:-1, 1:-1])
    # Output cell (i, j) sits at padded (i + 1, j + 1); the window offsets
    # (dy, dx) in 0..2 therefore cover rows i..i+2 and cols j..j+2 of `padded`.
    for dy in range(3):
        for dx in range(3):
            window_sum += padded[:, dy:dy + height, dx:dx + width]
    return window_sum

def life_step(X):
    """Advance Conway's Game of Life one generation on a wrap-around board.

    Args:
        X: one board (2-D) or a batch of boards (3-D); non-zero cells are alive.

    Returns:
        uint8 array of the next generation, passed through np.squeeze (so every
        size-1 axis is removed and a single board comes back 2-D).

    A cell is alive next step when it has exactly 3 live neighbours, or when it
    is alive now and has exactly 2.
    """
    boards = np.expand_dims(X, 0) if len(X.shape) == 2 else X
    # The 3x3 window sum includes the cell itself; subtract it to count neighbours only.
    neighbours = nConv2d_sw_3x3(boards) - boards
    survives = np.logical_and(neighbours == 2, boards)
    alive_next = np.logical_or(neighbours == 3, survives)
    return np.squeeze(alive_next.astype(np.uint8))

class GRGoL_env:
    """Reverse Game of Life environment on a square wrap-around board.

    Each episode hides a board `initial_grid` and exposes `target`, obtained by
    running `delta` Life steps (1 <= delta <= max_delta) on it. The agent edits a
    candidate board one cell per step. The score is the number of cells where the
    candidate, advanced the same `delta` steps, equals `target`.

    Args:
        size: side length of the board.
        max_delta: largest number of Life steps between hidden board and target.
        max_step: maximum number of edits per episode.
        h: number of candidate boards kept in the `solution_h` history.
    """

    def __init__(self, size, max_delta, max_step, h=1):
        self.ydim, self.xdim = size, size
        self.grid_size = self.ydim * self.xdim  # number of cells == maximum score
        self.max_delta = max_delta
        self.h = h
        self.max_step = max_step

    @staticmethod
    def _toggled(grid, action):
        """Return a copy of `grid` with the cell at flat index `action` flipped (0 <-> 1)."""
        flipped = np.copy(grid)
        # .flat writes into `flipped` itself; reshape(-1) may silently return a copy.
        flipped.flat[action] = 1 - flipped.flat[action]
        return flipped

    def _score(self, board):
        """Number of cells where `board`, advanced `self.delta` Life steps, equals `self.target`."""
        evolved = board
        # Evolve as many steps as were used to build the target; a single step
        # would mis-score every episode with delta > 1.
        for _ in range(self.delta):
            evolved = life_step(evolved)
        return np.sum(evolved == self.target)

    def reset(self):
        """Start a new episode and return its first state (see get_state)."""
        valid = False
        while not valid:
            # Random board whose live-cell probability is drawn from [0.01, 0.99).
            grid = (np.random.random((self.ydim, self.xdim)) < random() * 0.98 + 0.01).astype(np.uint8)

            # Warmup: advance the random board 5 Life steps before using it.
            for _ in range(5):
                grid = life_step(grid)

            delta = randint(1, self.max_delta)

            # The hidden board, then the target `delta` steps later.
            initial_grid = grid
            for _ in range(delta):
                grid = life_step(grid)
            target = grid

            # Reject targets with no live cell.
            valid = np.sum(target) > 0
        self.initial_grid = initial_grid
        self.target = target
        self.delta = delta
        # h copies of one random 0/1 candidate board, shape (h, ydim, xdim).
        self.solution_h = np.repeat(np.expand_dims(np.random.randint(0, 2, (self.ydim, self.xdim)), 0), self.h, axis=0)
        self.best_score = self._score(self.solution_h[-1])
        self.s = 0

        return self.get_state()

    def get_state(self):
        """Return the tuple (solution_h, best_score, delta, target)."""
        return self.solution_h, self.best_score, self.delta, self.target

    def step(self, action):
        """Flip one cell of the latest candidate board.

        Args:
            action: flat index into the (ydim, xdim) board.

        Returns:
            (state, reward, done). reward is the increase of best_score (never
            negative); done is True once max_step edits were made or the target
            is matched on every cell. After that, calls change nothing and
            return reward 0.
        """
        initial_score = self.best_score

        if self.s < self.max_step and self.best_score < self.grid_size:
            new = np.expand_dims(self._toggled(self.solution_h[-1], action), 0)

            # Slide the history window: append the new board, drop the oldest.
            self.solution_h = np.append(self.solution_h, new, axis=0)
            self.solution_h = np.delete(self.solution_h, 0, 0)

            self.best_score = max(self.best_score, self._score(new[0]))

            self.s += 1

        reward = self.best_score - initial_score
        done = self.s >= self.max_step or self.best_score >= self.grid_size
        return self.get_state(), reward, done

In [None]:
# --- Playing (epsilon-greedy exploration, discounting) ---
eps_start = 0.99       # epsilon at episode 0
eps_decay = 0.999      # epsilon is multiplied by this once per episode
num_episodes = 10000
gamma = 0.999          # discount factor of the TD target
# --- Replay buffer ---
buffer_size = 10000    # ReplayBuffer capacity
min_buffer_size = 128  # no optimisation until the buffer holds this many transitions
# --- Training ---
lr = 0.001
Optimizer = torch.optim.Adam
batch_size = 32
target_update_freq = 128  # episodes between target-network syncs
# --- Model ---
n_layers = 5           # number of residual blocks
n_filters = 32         # channels of every convolution

def format_res_input(I, s, d, F):
    """Encode environment state(s) as the network input tensor.

    Arguments follow GRGoL_env.get_state(): I = candidate-board history,
    s = best score, d = delta, F = target board. For a single state I has shape
    (h, grid_size, grid_size); a batch carries an extra leading axis.

    Output channels, in order: the h history boards, one plane filled with
    s / grid_size**2, max_delta one-hot planes encoding d, and the target board.
    Every value is mapped x -> 2x - 1; the result is a `dtype` tensor on `device`.
    """
    if I.ndim == 3:
        # Unbatched state: give I, d and F a leading batch axis.
        I = np.expand_dims(I, 0)
        d = np.expand_dims(d, 0)
        F = np.expand_dims(F, 0)
    n_cells = grid_size * grid_size
    plane_shape = (grid_size, grid_size)

    score_planes = np.repeat(s / grid_size / grid_size, n_cells).reshape((-1, 1) + plane_shape)
    delta_onehot = np.eye(max_delta)[d - 1]
    delta_planes = np.repeat(delta_onehot, n_cells).reshape((-1, max_delta) + plane_shape)
    target_planes = np.expand_dims(F, axis=1)

    features = np.concatenate((I, score_planes, delta_planes, target_planes), axis=1)
    return torch.tensor(features, dtype=dtype, device=device) * 2 - 1
        
def count_parameters(model):
    """Return the number of scalar weights in `model` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

class InBlock(nn.Module):
    """Input stem: 3x3 circular-padded convolution, then BatchNorm, then ReLU."""

    def __init__(self, infilters, filters):
        super().__init__()
        self.conv = nn.Conv2d(infilters, filters, kernel_size=3, stride=1, padding=1, padding_mode='circular')
        self.bn = nn.BatchNorm2d(filters)

    def forward(self, s):
        x = self.conv(s)
        x = self.bn(x)
        return F.relu(x)

class ResBlock(nn.Module):
    """Residual block: two bias-free 3x3 circular convolutions with BatchNorm and an identity skip.

    The skip adds the raw input, so `infilters` must equal `filters`.
    """

    def __init__(self, infilters, filters):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=False, padding_mode='circular')
        self.conv1 = nn.Conv2d(infilters, filters, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(filters)
        self.conv2 = nn.Conv2d(filters, filters, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(filters)

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + x)
    
class OutBlock(nn.Module):
    """Q-value head: a 1x1 convolution mapping features to one value per board cell.

    The head is linear on purpose. A ReLU here would zero the gradient of every
    Q-value whose pre-activation turns negative, so those estimates could never be
    pushed back up ("dead" outputs). Rows of all-zero Q-values would also make
    greedy argmax always pick action 0.
    """

    def __init__(self, infilters):
        super(OutBlock, self).__init__()
        # padding_mode is irrelevant for a 1x1 kernel with no padding, so it is omitted.
        self.conv = nn.Conv2d(infilters, 1, kernel_size=1)

    def forward(self, s):
        return self.conv(s)
    
class ResNet(nn.Module):
    """Q-network: InBlock stem, `res_layers` ResBlocks, then a per-cell OutBlock head.

    The stem takes max_delta + 3 input channels, matching format_res_input's output
    (1 candidate board + 1 score plane + max_delta delta planes + 1 target plane).
    That assumes the environment history length h == 1; confirm if h changes.
    """

    def __init__(self, n_filters, res_layers):
        super().__init__()
        self.convi = InBlock(max_delta + 3, n_filters)
        self.i_blocks = nn.ModuleList([ResBlock(n_filters, n_filters) for _ in range(res_layers)])
        self.convo = OutBlock(n_filters)

    def forward(self, y):
        features = self.convi(y)
        for block in self.i_blocks:
            features = block(features)
        return self.convo(features)
    
class ReplayBuffer:
    """Fixed-capacity FIFO store of (state, action, next_state, reward) transitions.

    Implemented as a ring buffer. Once `max_size` transitions are stored, each new
    one overwrites the oldest in O(1), instead of re-allocating every array.
    All `max_size` slots are sampleable, including the most recent one.

    Args:
        max_size: capacity.
        state_shape: shape of one encoded state. Defaults to the network input
            shape (max_delta + 3, grid_size, grid_size).
    """

    def __init__(self, max_size, state_shape=None):
        if state_shape is None:
            state_shape = (max_delta + 3, grid_size, grid_size)
        self.max_size = max_size
        self.states = np.zeros((max_size,) + tuple(state_shape))
        self.actions = np.zeros(max_size)
        self.next_states = np.zeros((max_size,) + tuple(state_shape))
        self.rewards = np.zeros(max_size)
        self.cursor = 0  # slot the next transition is written to
        self.size = 0    # number of valid transitions stored (<= max_size)

    def add(self, state, action, next_state, reward):
        """Store one transition, overwriting the oldest one when the buffer is full."""
        self.states[self.cursor] = state
        self.actions[self.cursor] = action
        self.next_states[self.cursor] = next_state
        self.rewards[self.cursor] = reward
        self.cursor = (self.cursor + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size):
        """Draw `batch_size` stored transitions uniformly, with replacement.

        Returns:
            (states, actions, next_states, rewards) tensors on `device`; actions
            are int64, the others use `dtype`.

        Raises:
            ValueError: if the buffer is empty (from np.random.randint).
        """
        idxs = np.random.randint(0, self.size, batch_size)
        return (
            torch.tensor(self.states[idxs], dtype=dtype, device=device),
            torch.tensor(self.actions[idxs], dtype=torch.int64, device=device),
            torch.tensor(self.next_states[idxs], dtype=dtype, device=device),
            torch.tensor(self.rewards[idxs], dtype=dtype, device=device)
        )

    def __len__(self):
        return self.size
    
def optimize_model(buffer):
    """Run one DQN gradient step on a minibatch sampled from `buffer`.

    Uses the module-level policy_net, target_net, optimizer, batch_size, gamma,
    grid_size and min_buffer_size.

    Returns:
        The Huber loss as a Python float, or None while the buffer holds fewer
        than min_buffer_size transitions.
    """
    if len(buffer) < min_buffer_size:
        return None
    state_batch, action_batch, next_state_batch, reward_batch = buffer.sample(batch_size)
    n_actions = grid_size * grid_size

    # Q(s_t, a): the network outputs one Q-value per board cell; keep the
    # entry of the action that was actually taken.
    qs = policy_net(state_batch).view(-1, n_actions)
    state_action_values = qs.gather(1, action_batch.unsqueeze(-1))

    # max_a' Q_target(s_{t+1}, a') from the frozen target network.
    # NOTE(review): the replay buffer stores no `done` flag, so terminal
    # transitions are bootstrapped like any other instead of using a target of
    # just `reward`. Fixing this needs `done` stored in ReplayBuffer and passed
    # by the training loop.
    with torch.no_grad():
        next_state_values = target_net(next_state_batch).view(-1, n_actions).max(1)[0]
    expected_state_action_values = next_state_values * gamma + reward_batch

    # Huber loss between Q(s_t, a) and the TD target.
    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()
    # Clip each gradient element to [-1, 1]; parameters without a gradient are skipped
    # rather than crashing on `param.grad.data`.
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 1)
    optimizer.step()

    return loss.item()
    

policy_net = ResNet(n_filters, n_layers).to(device)
target_net = ResNet(n_filters, n_layers).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = Optimizer(policy_net.parameters(), lr=lr)
env = GRGoL_env(grid_size, max_delta, grid_size * grid_size)
buffer = ReplayBuffer(buffer_size)
tot_r_h = []  # env.best_score at the end of every episode
for i_episode in range(num_episodes):
    # Epsilon depends only on the episode index, so compute it once per episode.
    eps = eps_start * eps_decay ** i_episode
    # Initialize the environment and state
    state = env.reset()
    for t in count():
        # Select and perform an action (epsilon-greedy)
        nn_state = format_res_input(*state)
        if random() < eps:
            action = randint(0, grid_size ** 2 - 1)
        else:
            # Act with BatchNorm in eval mode. A batch-of-one forward in train mode
            # normalises with that single sample's statistics and feeds them into
            # the running stats that target_net later copies.
            policy_net.eval()
            with torch.no_grad():
                qs = policy_net(nn_state)
                action = qs.view(-1, grid_size * grid_size).argmax(1).item()
            policy_net.train()
        new_state, reward, done = env.step(action)

        # Store the transition in memory
        buffer.add(nn_state.cpu().numpy(), action, format_res_input(*new_state).cpu().numpy(), reward)

        # Move to the next state
        state = new_state

        # Perform one step of the optimization (on the policy network)
        optimize_model(buffer)

        if done:
            tot_r_h.append(env.best_score)
            print("Episode {}: epsilon={}, Solution MAE={}".format(i_episode, eps, 1 - env.best_score / grid_size / grid_size))
            break
    # Every target_update_freq episodes: sync the target network and checkpoint the policy.
    if i_episode % target_update_freq == target_update_freq - 1:
        print("Updating target net")
        target_net.load_state_dict(policy_net.state_dict())
        torch.save(policy_net, 'policy.pt')

In [None]:
# Final best score (cells of life_step(candidate) matching the target) per episode.
fig, ax = plt.subplots()
ax.plot(tot_r_h)
ax.set(title="Best score per episode", xlabel="episode", ylabel="cells matching target");