In [None]:
import numpy as np
from scipy.signal import convolve2d
import tensorflow as tf
import keras
from keras import layers
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from random import randint, random
import sys
import gc
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
test = True    # run the Game of Life sanity checks defined further down
max_delta = 5  # largest number of generations between a start and an end state

# Seed the numpy and torch RNGs (data sampling, weight init, generator noise) so
# that Restart & Run All reproduces the same run on CPU; CUDA kernels may still
# introduce some nondeterminism.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
def grid_images(images, save=None):
    """Display a grid of 2-D images and optionally write it to a file.

    Args:
        images: array-like of rank 2 (a single image), 3 (one row of images)
            or 4 (rows x columns of images). Anything that is not already an
            ndarray (e.g. a list of 2-D arrays) is stacked with ``np.stack``.
        save: optional path; when given, the figure is saved there.
    """
    if not isinstance(images, np.ndarray):
        images = np.stack(images)
    assert len(images.shape) >= 2, "pas assez de dimensions"
    assert len(images.shape) <= 4, "trop de dimensions"
    # Promote to (rows, cols, H, W): one image becomes a 1x1 grid, a stack a single row.
    while images.ndim < 4:
        images = np.expand_dims(images, 0)
    n_rows, n_cols = images.shape[0], images.shape[1]

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols, n_rows), squeeze=False)
    for i in range(n_rows):
        for j in range(n_cols):
            axes[i, j].imshow(images[i, j])

    # Save BEFORE showing: plt.show() may close/clear the figure, after which
    # savefig would write an empty image.
    if save is not None:
        fig.savefig(save)
    plt.show()
    # This is called repeatedly from the training loop; release the figure.
    plt.close(fig)

In [None]:
"""
    Game of life functions
"""

def get_padded_version_n(X):
    """Pad every grid of a batch by one cell per side using wrap-around values.

    Args:
        X: array of shape (batch, H, W).

    Returns:
        Array of shape (batch, H + 2, W + 2) with dtype ``X.dtype``. The
        interior is ``X``; the top padding row is the last row of ``X``, the
        bottom padding row its first row, the same for columns, and each corner
        comes from the diagonally opposite corner (toroidal topology).
    """
    untouched, one_cell = (0, 0), (1, 1)
    return np.pad(X, (untouched, one_cell, one_cell), mode="wrap")

def nConv2d_sw_3x3(X):
    """Sum every 3x3 toroidal neighbourhood (centre cell included) of a batch.

    Args:
        X: array of shape (batch, H, W).

    Returns:
        Array of shape (batch, H, W) and dtype ``X.dtype`` whose entry
        [b, y, x] is the sum of X[b] over rows y-1..y+1 and columns x-1..x+1,
        with indices wrapping around the grid edges.
    """
    height, width = X.shape[-2], X.shape[-1]
    # One cell of wrap-around padding on each spatial side.
    wrapped = np.pad(X, ((0, 0), (1, 1), (1, 1)), mode="wrap")
    window_sum = np.zeros_like(X)
    # Accumulate the nine shifted views in row-major offset order.
    for dy in range(3):
        for dx in range(3):
            window_sum += wrapped[:, dy:dy + height, dx:dx + width]
    return window_sum

def life_step(X):
    """Advance a batch of toroidal grids by one Conway's Game of Life generation.

    Args:
        X: (batch, H, W) array of 0/1 cells.

    Returns:
        uint8 array of the same shape: a cell is alive next generation if it
        has exactly 3 live neighbours, or if it is alive and has exactly 2.
    """
    # The 3x3 window sum includes the cell itself, so subtract it back out.
    live_neighbours = nConv2d_sw_3x3(X) - X
    births = live_neighbours == 3
    survivals = (X != 0) & (live_neighbours == 2)
    return (births | survivals).astype(np.uint8)

class GoL_data:
    """Random source of Game of Life (start state, end state, delta) samples.

    Start states are random soups advanced ``warmup_steps`` generations with
    ``life_step``; each end state is its start state advanced a random number
    of further generations ``delta`` in 1..``max_delta``. Samples whose end
    state is empty are discarded.
    """

    def __init__(self, size, max_delta, warmup_steps=5):
        """
        Args:
            size: (height, width) of every grid.
            max_delta: largest number of generations between start and end state.
            warmup_steps: generations applied to the random soup before it is
                used as a start state.
        """
        self.ydim, self.xdim = size
        self.grid_size = self.ydim * self.xdim
        self.max_delta = max_delta
        self.warmup_steps = warmup_steps

    def _sample(self, sample_size):
        """Draw ``sample_size`` candidates; return (starts, ends, deltas) of those with a non-empty end."""
        # Per-grid live-cell probability, uniform in [0.01, 0.99).
        probs = np.random.random(sample_size) * 0.98 + 0.01
        grids = (np.random.random((sample_size, self.ydim, self.xdim))
                 < probs[:, None, None]).astype(np.uint8)

        for _ in range(self.warmup_steps):
            grids = life_step(grids)

        # Use the configured horizon (previously hard-coded to 1..5).
        deltas = np.random.randint(1, self.max_delta + 1, sample_size)

        initial_grids = grids
        targets = np.zeros_like(grids)
        for step in range(1, self.max_delta + 1):
            grids = life_step(grids)
            step_mask = deltas == step
            targets[step_mask] = grids[step_mask]

        # Keep only samples whose end state still has live cells.
        keep_mask = targets.sum(axis=(-2, -1)) > 0
        return initial_grids[keep_mask], targets[keep_mask], deltas[keep_mask]

    def iterate_grids(self, batch_size, max_draws=100):
        """Yield an endless stream of ``(initial_grids, targets, deltas)`` batches.

        Each draw oversamples ``2 * batch_size`` candidates. Further draws are
        made until at least ``batch_size`` samples survive the non-empty filter,
        so batches have exactly ``batch_size`` samples. If ``max_draws`` draws
        still yield too few survivors, the shorter batch is yielded rather than
        looping forever.
        """
        while True:
            parts = [self._sample(2 * batch_size)]
            n_kept = len(parts[0][0])
            while n_kept < batch_size and len(parts) < max_draws:
                part = self._sample(2 * batch_size)
                parts.append(part)
                n_kept += len(part[0])
            initial_grids, targets, deltas = (
                np.concatenate(arrays)[:batch_size] for arrays in zip(*parts)
            )
            yield initial_grids, targets, deltas
            
def evaluate_gol(I, F, d, max_steps=None):
    """Re-simulate start states and measure how well they reach the expected end states.

    Args:
        I: (batch, H, W) array of start grids.
        F: (batch, H, W) array of expected end grids.
        d: (batch,) integer array; sample ``k`` is advanced ``d[k]`` generations.
        max_steps: number of generations to simulate. Defaults to ``d.max()``,
            so every delta present in ``d`` is covered. Samples whose delta is
            not in 1..max_steps keep an all-zero prediction.

    Returns:
        ``(Fh, error)``: ``Fh`` holds the simulated end grids (same shape and
        dtype as ``I``); ``error`` is the fraction of cells where ``Fh``
        differs from ``F`` (NaN for an empty batch).
    """
    d = np.asarray(d)
    if max_steps is None:
        # Derive the horizon from the data instead of a notebook-level constant.
        max_steps = int(d.max()) if d.size else 0
    Fh = np.zeros_like(I)
    state = I
    for step in range(1, max_steps + 1):
        state = life_step(state)
        reached = d == step
        Fh[reached] = state[reached]
    accuracy = (F == Fh).mean()
    return Fh, 1 - accuracy

if test:
    # Independent reference for one toroidal Game of Life step, built on
    # scipy's convolve2d rather than the hand-written padding/sum helpers.
    neighbour_kernel = np.ones((3, 3), dtype=np.int64)
    neighbour_kernel[1, 1] = 0

    def reference_life_step(grid):
        neighbours = convolve2d(grid, neighbour_kernel, mode="same", boundary="wrap")
        return ((neighbours == 3) | ((grid == 1) & (neighbours == 2))).astype(np.uint8)

    gen = GoL_data((25, 25), max_delta)
    i, f, d = next(gen.iterate_grids(1000))
    assert len(i) == len(f) == len(d) <= 1000, "iterate_grids returned inconsistent batch sizes"
    assert ((d >= 1) & (d <= max_delta)).all(), "iterate_grids produced out-of-range deltas"

    # Data generation and evaluate_gol share life_step, so this only checks the
    # start/end/delta bookkeeping of the data generator...
    _, s = evaluate_gol(i, f, d)
    assert s == 0.0, "iterate_grids is wrong"

    # ...while this validates life_step itself against the reference above.
    stepped = life_step(i[:100])
    expected = np.stack([reference_life_step(g) for g in i[:100]])
    assert (stepped == expected).all(), "life_step disagrees with the convolve2d reference"

In [None]:
"""
    GAN functions
    Forked from https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/gan/gan.py
"""

# Run on the first CUDA device when one is visible, otherwise on the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

# Floating-point type of every tensor fed to the networks.
dtype = torch.float32

# Optimisation schedule.
n_steps = 50000          # number of training iterations
batch_size = 128         # samples per iteration
lr = 0.0002              # Adam learning rate (both networks)
b1 = 0.5                 # Adam beta1
b2 = 0.999               # Adam beta2
display_frequency = 50   # print running loss/score statistics every N steps
save_frequency = 1000    # save a sample figure and the generator every N steps

# Network size.
n_layers = 12            # residual blocks in each of the generator and discriminator
n_filters = 64           # channel width of the residual trunks
latent_dim = 100         # channels of the generator's 1x1 noise input

def _delta_indicator(d, height, width):
    """Encode per-sample deltas as constant one-hot planes.

    Args:
        d: integer array of shape (batch,) with values in 1..max_delta.
        height, width: spatial size of each plane.

    Returns:
        float64 array (batch, max_delta, height, width); plane ``d[b] - 1`` of
        sample ``b`` is all ones, every other plane is zero.

    Raises:
        ValueError: if a delta lies outside 1..max_delta. Without this check,
            negative indexing would silently map ``d == 0`` onto the last plane.
    """
    d = np.asarray(d)
    if d.size and (d.min() < 1 or d.max() > max_delta):
        raise ValueError("deltas must lie in [1, {}], got values in [{}, {}]".format(
            max_delta, d.min(), d.max()))
    planes = np.zeros((len(d), max_delta, height, width))
    planes[np.arange(len(d)), d - 1] = 1.0
    return planes


def format_gen_input(F, d):
    """Build the generator's conditioning tensor.

    Args:
        F: (batch, H, W) array of end states (assumed 0/1).
        d: (batch,) integer deltas in 1..max_delta.

    Returns:
        Tensor (batch, 1 + max_delta, H, W) on ``device`` with dtype ``dtype``.
        Channel 0 is the end state and the rest are the one-hot delta planes,
        all mapped from {0, 1} to {-0.5, 0.5}.
    """
    F = np.asarray(F)
    end_state = np.expand_dims(F, axis=1).astype(np.float32)
    planes = _delta_indicator(d, F.shape[-2], F.shape[-1]).astype(np.float32)
    z = np.concatenate((end_state, planes), axis=1) - 0.5
    return torch.tensor(z, dtype=dtype, device=device)


def format_dis_input(I, d, F, noise_std=0.01):
    """Build the discriminator input stack (start state, delta planes, end state).

    Args:
        I: start states. Either a (batch, H, W) numpy array of 0/1 cells (real
            samples) or a (batch, 1, H, W) tensor that is already on the model
            scale (generator output, kept in the autograd graph).
        d: (batch,) integer deltas in 1..max_delta.
        F: (batch, H, W) array of end states (assumed 0/1).
        noise_std: std of the Gaussian noise added to channel 0 of real
            (numpy) start states.

    Returns:
        Tensor (batch, 1 + max_delta + 1, H, W). Conditioning channels are
        mapped from {0, 1} to {-0.5, 0.5}.
    """
    F = np.asarray(F)
    planes = _delta_indicator(d, F.shape[-2], F.shape[-1])
    if torch.is_tensor(I):
        # Generated states: only shift the conditioning channels, and build
        # them on the same device/dtype as the generator output.
        cond_planes = torch.tensor(planes - 0.5, dtype=I.dtype, device=I.device)
        end_state = torch.tensor(np.expand_dims(F, axis=1) - 0.5, dtype=I.dtype, device=I.device)
        return torch.cat((I, cond_planes, end_state), 1)

    stacked = np.concatenate((np.expand_dims(I, axis=1), planes, np.expand_dims(F, axis=1)), axis=1)
    ret = torch.tensor(stacked, dtype=dtype, device=device) - 0.5
    # Small instance noise on the real start states only.
    ret[:, 0] += torch.randn(len(I), I.shape[-2], I.shape[-1], dtype=dtype, device=device) * noise_std
    return ret
        
def weights_init(m):
    """Initialiser intended for ``module.apply``, matching modules by class name.

    Modules whose class name contains 'Conv' get weights ~ N(0, 0.02). Those
    whose name contains 'BatchNorm' get weights ~ N(1, 0.02) and a zero bias.
    Everything else, including conv biases, is left untouched.
    """
    name = type(m).__name__
    if 'Conv' in name:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm' in name:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
        
def count_parameters(model):
    """Return the number of trainable scalars (parameters with ``requires_grad``) in ``model``."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

class InBlock(nn.Module):
    """Stem block: 3x3 conv (same padding, with bias) -> BatchNorm -> ReLU.

    Conv weights start at N(0, 0.02) with zero bias; BatchNorm scales start at
    N(1, 0.02) with zero shift.
    """

    def __init__(self, infilters, filters):
        super().__init__()
        self.conv = nn.Conv2d(infilters, filters, 3, stride=1, padding=1)
        self.bn = nn.BatchNorm2d(filters)

        # BatchNorm is initialised before the conv; keep this order so seeded
        # runs draw the same initial weights.
        nn.init.normal_(self.bn.weight, mean=1.0, std=0.02)
        nn.init.zeros_(self.bn.bias)
        nn.init.normal_(self.conv.weight, mean=0.0, std=0.02)
        nn.init.zeros_(self.conv.bias)

    def forward(self, s):
        features = self.conv(s)
        normalised = self.bn(features)
        return F.relu(normalised)

class ResBlock(nn.Module):
    """Residual block: (conv3x3 -> BN -> ReLU -> conv3x3 -> BN) + identity, then ReLU.

    The identity shortcut requires ``infilters == filters``. Convs have no bias
    because a BatchNorm follows each one. Conv weights start at N(0, 0.02), BN
    scales at N(1, 0.02) and BN shifts at 0.
    """

    def __init__(self, infilters, filters):
        super().__init__()
        self.conv1 = nn.Conv2d(infilters, filters, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(filters)
        self.conv2 = nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(filters)

        # Both convs first, then both BNs: this order fixes the seeded RNG draws.
        for conv in (self.conv1, self.conv2):
            nn.init.normal_(conv.weight, mean=0.0, std=0.02)
        for bn in (self.bn1, self.bn2):
            nn.init.normal_(bn.weight, mean=1.0, std=0.02)
            nn.init.zeros_(bn.bias)

    def forward(self, x):
        shortcut = x
        h = F.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        return F.relu(h + shortcut)
    
class OutBlock(nn.Module):
    """1x1 conv projecting ``infilters`` channels to one, squashed by tanh into (-1, 1)."""

    def __init__(self, infilters):
        super().__init__()
        self.conv = nn.Conv2d(infilters, 1, kernel_size=1)
        nn.init.normal_(self.conv.weight, mean=0.0, std=0.02)
        nn.init.zeros_(self.conv.bias)

    def forward(self, s):
        logits = self.conv(s)
        return torch.tanh(logits)

class Generator(nn.Module):
    """Conditional generator: (end state, delta planes) -> candidate start state.

    Inside ``forward`` a fresh noise tensor (batch, latent, 1, 1) is drawn and
    upsampled by four transposed convs (spatial 1 -> 3 -> 7 -> 13 -> 25). It is
    embedded by ``convz`` and concatenated with the embedded conditioning input
    (``convi``). The result runs through ``n_res_layers`` residual blocks, and
    a 1x1 conv + tanh gives a single channel in (-1, 1).

    Args:
        n_res_layers: number of ResBlocks in the trunk.
        n_filters: trunk width. Each branch contributes ``int(n_filters / 2)``
            channels, so the trunk is ``2 * int(n_filters / 2)`` wide (equal to
            ``n_filters`` when it is even).
        in_channels: channels of the conditioning input. Defaults to 6
            (1 end state + 5 delta planes, as built by ``format_gen_input``
            with ``max_delta = 5``).
    """

    def __init__(self, n_res_layers, n_filters, in_channels=6):
        super(Generator, self).__init__()
        self.n_res_layers = n_res_layers

        ntfilters = int(n_filters / 2)
        # Width after concatenating the two branches.
        trunk_filters = 2 * ntfilters

        # Modules are created in a fixed order; it determines the RNG draws.
        self.convt1 = nn.ConvTranspose2d(latent_dim, ntfilters * 8, 3, 1, 0, bias=False)
        self.bnormt1 = nn.BatchNorm2d(ntfilters * 8)

        self.convt2 = nn.ConvTranspose2d(ntfilters * 8, ntfilters * 4, 3, 2, 0, bias=False)
        self.bnormt2 = nn.BatchNorm2d(ntfilters * 4)

        self.convt3 = nn.ConvTranspose2d(ntfilters * 4, ntfilters * 2, 3, 2, 1, bias=False)
        self.bnormt3 = nn.BatchNorm2d(ntfilters * 2)

        self.convt4 = nn.ConvTranspose2d(ntfilters * 2, ntfilters, 3, 2, 1, bias=False)
        self.bnormt4 = nn.BatchNorm2d(ntfilters)

        for convt in (self.convt1, self.convt2, self.convt3, self.convt4):
            nn.init.normal_(convt.weight.data, 0.0, 0.02)
        for bnorm in (self.bnormt1, self.bnormt2, self.bnormt3, self.bnormt4):
            nn.init.normal_(bnorm.weight.data, 1.0, 0.02)
            nn.init.constant_(bnorm.bias.data, 0)

        self.convi = InBlock(in_channels, ntfilters)
        self.convz = InBlock(ntfilters, ntfilters)

        self.blocks = nn.ModuleList(
            [ResBlock(trunk_filters, trunk_filters) for _ in range(n_res_layers)]
        )
        self.outblock = OutBlock(trunk_filters)

    def forward(self, s):
        # Draw the noise on the input's device/dtype (not notebook globals). Its
        # size comes from convt1 itself, so it always matches the layer.
        z = torch.randn(s.shape[0], self.convt1.in_channels, 1, 1, dtype=s.dtype, device=s.device)

        upsampling = (
            (self.convt1, self.bnormt1),
            (self.convt2, self.bnormt2),
            (self.convt3, self.bnormt3),
            (self.convt4, self.bnormt4),
        )
        for convt, bnorm in upsampling:
            z = F.relu(bnorm(convt(z)))
        z = self.convz(z)

        s = self.convi(s)
        s = torch.cat((s, z), 1)

        for block in self.blocks:
            s = block(s)
        return self.outblock(s)

class Discriminator(nn.Module):
    """Conditional discriminator returning the probability that a sample is real.

    Expects 7 input channels (start state, 5 one-hot delta planes, end state).
    The stem is a 3x3 conv with LeakyReLU(0.2), followed by ``n_res_layers``
    residual blocks at full resolution. Three stride-2 4x4 convs follow
    (n_filters -> 2n -> 4n, BatchNorm on the last two), then an unpadded 3x3
    conv and a sigmoid. For 25x25 inputs the spatial size goes
    25 -> 12 -> 6 -> 3 -> 1, so the output has shape (batch, 1, 1, 1).
    """

    def __init__(self, n_res_layers, n_filters):
        super().__init__()
        self.n_res_layers = n_res_layers

        # Stem and residual trunk. Modules are created in this order on
        # purpose: it fixes the RNG draws of the default initialisers.
        self.convi = nn.Conv2d(7, n_filters, 3, 1, 1, bias=False)
        self.lri = nn.LeakyReLU(0.2)
        self.blocks = nn.ModuleList(
            [ResBlock(n_filters, n_filters) for _ in range(n_res_layers)]
        )

        # Downsampling head.
        self.dconv1 = nn.Conv2d(n_filters, n_filters, 4, 2, 1, bias=False)
        self.lr1 = nn.LeakyReLU(0.2)
        self.dconv2 = nn.Conv2d(n_filters, n_filters * 2, 4, 2, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(n_filters * 2)
        self.lr2 = nn.LeakyReLU(0.2)
        self.dconv3 = nn.Conv2d(n_filters * 2, n_filters * 4, 4, 2, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(n_filters * 4)
        self.lr3 = nn.LeakyReLU(0.2)
        self.convo = nn.Conv2d(n_filters * 4, 1, 3, 1, 0, bias=False)
        self.sigmo = nn.Sigmoid()

        for conv in (self.convi, self.dconv1, self.dconv2, self.dconv3, self.convo):
            nn.init.normal_(conv.weight, mean=0.0, std=0.02)
        for bn in (self.bn2, self.bn3):
            nn.init.normal_(bn.weight, mean=1.0, std=0.02)
            nn.init.zeros_(bn.bias)

    def forward(self, s):
        h = self.lri(self.convi(s))
        for block in self.blocks:
            h = block(h)
        h = self.lr1(self.dconv1(h))
        h = self.lr2(self.bn2(self.dconv2(h)))
        h = self.lr3(self.bn3(self.dconv3(h)))
        return self.sigmo(self.convo(h))

# Loss function: binary cross-entropy on the discriminator's sigmoid output.
adversarial_loss = torch.nn.BCELoss()

# Initialize generator and discriminator (each block initialises its own weights).
generator = Generator(n_layers, n_filters).to(device)
print("Parameters of Generator: {}".format(count_parameters(generator)))
discriminator = Discriminator(n_layers, n_filters).to(device)
print("Parameters of Discriminator: {}".format(count_parameters(discriminator)))

optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

data_gen = GoL_data((25, 25), max_delta).iterate_grids(batch_size)
d_loss_h, g_loss_h, lb_h = [], [], []
for s in range(n_steps):
    Is, Fs, ds = next(data_gen)

    # Adversarial ground truths, shaped (batch,).
    valid = torch.ones(len(Is), dtype=dtype, device=device)
    fake = torch.zeros(len(Is), dtype=dtype, device=device)

    ## Generator training
    optimizer_G.zero_grad()

    # Conditioning input of the generator: end state + one-hot delta planes.
    z = format_gen_input(Fs, ds)

    # Generate a batch of candidate start states
    gen_states = generator(z)

    dis_input = format_dis_input(gen_states, ds, Fs)

    # Loss measures the generator's ability to fool the discriminator. D returns
    # (batch, 1, 1, 1); flatten it so BCELoss gets matching input/target shapes
    # (a size mismatch raises a ValueError in current PyTorch).
    g_loss = adversarial_loss(discriminator(dis_input).flatten(), valid)

    g_loss.backward()
    optimizer_G.step()

    ## Discriminator training (zero_grad also clears grads left by g_loss.backward)
    optimizer_D.zero_grad()

    # Measure discriminator's ability to classify real from generated samples
    real_loss = adversarial_loss(discriminator(format_dis_input(Is, ds, Fs)).flatten(), valid)
    fake_loss = adversarial_loss(discriminator(format_dis_input(gen_states.detach(), ds, Fs)).flatten(), fake)
    d_loss = (real_loss + fake_loss) / 2

    d_loss.backward()
    optimizer_D.step()

    d_loss_h.append(d_loss.item())
    g_loss_h.append(g_loss.item())

    # Real cells are encoded as -0.5/+0.5, so threshold the tanh output at 0,
    # then score the guessed start states by re-simulating them.
    inferred = (gen_states[:, 0].detach().cpu().numpy() > 0.0).astype(np.uint8)
    Fhs, scores = evaluate_gol(inferred, Fs, ds)
    lb_h.append(scores)

    if s % display_frequency == display_frequency - 1:
        dhl = np.array(d_loss_h[-display_frequency:])
        ghl = np.array(g_loss_h[-display_frequency:])
        lbhl = np.array(lb_h[-display_frequency:])
        print("Step {}: [D loss: {:.4f}, {:.4f}] [G loss: {:.4f}, {:.4f}] [LB: {:.4f}, {:.4f}]".format(
            s + 1,
            dhl.mean(),
            dhl.std(),
            ghl.mean(),
            ghl.std(),
            lbhl.mean(),
            lbhl.std()
        ))

    if s % save_frequency == save_frequency - 1:
        grid_images([Is[0], Fs[0], inferred[0], Fhs[0]], "sample_s{}.png".format(s))
        # Pickles the whole module; loading it back needs the Generator class definition.
        torch.save(generator, 'generator.pt')

In [None]:
fig, ax = plt.subplots()
ax.plot(d_loss_h, alpha=0.3, label="per step")
ax.plot(pd.Series(d_loss_h).rolling(display_frequency).mean(),
        label="rolling mean ({} steps)".format(display_frequency))
ax.set(title="Discriminator loss", xlabel="training step", ylabel="BCE loss (mean of real/fake)")
ax.legend();

In [None]:
fig, ax = plt.subplots()
ax.plot(g_loss_h, alpha=0.3, label="per step")
ax.plot(pd.Series(g_loss_h).rolling(display_frequency).mean(),
        label="rolling mean ({} steps)".format(display_frequency))
ax.set(title="Generator loss", xlabel="training step", ylabel="BCE loss")
ax.legend();

In [None]:
fig, ax = plt.subplots()
ax.plot(lb_h, alpha=0.3, label="per step")
ax.plot(pd.Series(lb_h).rolling(display_frequency).mean(),
        label="rolling mean ({} steps)".format(display_frequency))
ax.set(title="Re-simulation error of inferred start states",
       xlabel="training step", ylabel="fraction of mismatched cells")
ax.legend();

In [None]:
# Persist the per-step training curves (np.save appends the ".npy" extension).
for filename, history in (
    ("generator_loss_history", g_loss_h),
    ("discriminator_loss_history", d_loss_h),
    ("lb_history", lb_h),
):
    np.save(filename, np.array(history))