In [1]:
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch.optim as optim
from torch.autograd import Variable
import torch
from tqdm import tqdm
import sys

In [2]:
sys.path.insert(0,'..')
import bitmap
import itertools
from forward_prediction import forward_model

# Generate data

## Training data

In [3]:
# Generate training data
N = 25600
train_data_gen = bitmap.generate_train_set(N, 41, min_delta=1, max_delta=1)
deltas, start_boards, stop_boards = map(np.array, zip(*list(train_data_gen)))

## Test data

In [4]:
# Generate testing data
N_test = 25600
test_data_gen = bitmap.generate_train_set(N, 42, min_delta=1, max_delta=1)
deltas, test_start_boards, test_stop_boards = map(np.array, zip(*list(test_data_gen)))

# Model trainer

In [37]:
def train(model, X, y, optim, criterion, output_path, num_epochs=50, batch_size=128):
    X_ = X.cuda()
    y_ = y.cuda()
    model.cuda()
    print(model)
    # Set optimizer
    optimizer = optim(model.parameters())
    for epoch in tqdm(range(num_epochs)): 
        permutation = torch.randperm(X_.size()[0])
        running_loss = 0.0
        for i in range(0, X_.size()[0], batch_size):
            indices = permutation[i:i+batch_size]
            batch = X_[indices]
            target = y_[indices]
        
            optimizer.zero_grad()
            outputs = model(batch)
            loss = criterion(outputs, target)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 4000 == 0 and i > 0:
                print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 10))
            running_loss = 0.0
    # Save model state_dict
    torch.save(model.state_dict(), output_path)

# Train a forward network

In [28]:
class ForwardNet(nn.Module):
    def __init__(self):
        super(ForwardNet, self).__init__()
        # in channels, out channels, kernel size
        self.conv1 = nn.Conv2d(1, 16, (3, 3), padding=(1, 1), padding_mode='circular')
        self.activ1 = nn.ReLU()
        self.conv2 = nn.Conv2d(16, 8, (3, 3), padding=(1, 1), padding_mode='circular')
        self.activ2 = nn.PReLU()
        self.conv3 = nn.Conv2d(8, 4, (3, 3), padding=(1, 1), padding_mode='circular')
        self.activ3 = nn.PReLU()
        self.conv4 = nn.Conv2d(4, 1, (3, 3), padding=(1, 1), padding_mode='circular')
        
    def forward(self, x):
        x = self.activ1(self.conv1(x))
        x = self.activ2(self.conv2(x))
        x = self.activ3(self.conv3(x))
        x = torch.sigmoid(self.conv4(x))
        return x

In [29]:
forward_net = ForwardNet()
criterion = nn.BCELoss()

In [31]:
X = Variable(torch.tensor(start_boards).view(N, 1, 25, 25).float(), requires_grad=True)
y = Variable(torch.tensor(stop_boards).view(N, 1, 25, 25).float())

In [32]:
train(forward_net, X, y, optim.Adam, criterion, "../models/johnson/vanilla_forward.pkl")

  0%|          | 0/50 [00:00<?, ?it/s]

ForwardNet(
  (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=circular)
  (activ1): ReLU()
  (conv2): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=circular)
  (activ2): PReLU(num_parameters=1)
  (conv3): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=circular)
  (activ3): PReLU(num_parameters=1)
  (conv4): Conv2d(4, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=circular)
)
[1, 16001] loss: 0.033


  2%|▏         | 1/50 [00:10<08:27, 10.36s/it]

[2, 16001] loss: 0.028


  4%|▍         | 2/50 [00:20<08:15, 10.33s/it]

[3, 16001] loss: 0.016


  6%|▌         | 3/50 [00:31<08:06, 10.35s/it]

[4, 16001] loss: 0.004


  8%|▊         | 4/50 [00:41<07:56, 10.37s/it]

[5, 16001] loss: 0.001


 10%|█         | 5/50 [00:51<07:46, 10.37s/it]

[6, 16001] loss: 0.000


 12%|█▏        | 6/50 [01:02<07:36, 10.38s/it]

[7, 16001] loss: 0.000


 14%|█▍        | 7/50 [01:12<07:26, 10.38s/it]

[8, 16001] loss: 0.000


 16%|█▌        | 8/50 [01:22<07:16, 10.38s/it]

[9, 16001] loss: 0.000


 18%|█▊        | 9/50 [01:33<07:06, 10.40s/it]

[10, 16001] loss: 0.000


 20%|██        | 10/50 [01:43<06:57, 10.44s/it]

[11, 16001] loss: 0.000


 22%|██▏       | 11/50 [01:54<06:47, 10.44s/it]

[12, 16001] loss: 0.000


 24%|██▍       | 12/50 [02:04<06:37, 10.46s/it]

[13, 16001] loss: 0.000


 26%|██▌       | 13/50 [02:15<06:27, 10.48s/it]

[14, 16001] loss: 0.000


 28%|██▊       | 14/50 [02:25<06:17, 10.49s/it]

[15, 16001] loss: 0.000


 30%|███       | 15/50 [02:36<06:07, 10.49s/it]

[16, 16001] loss: 0.000


 32%|███▏      | 16/50 [02:46<05:57, 10.50s/it]

[17, 16001] loss: 0.000


 34%|███▍      | 17/50 [02:57<05:45, 10.48s/it]

[18, 16001] loss: 0.000


 36%|███▌      | 18/50 [03:07<05:35, 10.49s/it]

[19, 16001] loss: 0.000


 38%|███▊      | 19/50 [03:18<05:25, 10.49s/it]

[20, 16001] loss: 0.000


 40%|████      | 20/50 [03:28<05:13, 10.46s/it]

[21, 16001] loss: 0.000


 42%|████▏     | 21/50 [03:39<05:04, 10.48s/it]

[22, 16001] loss: 0.000


 44%|████▍     | 22/50 [03:49<04:54, 10.51s/it]

[23, 16001] loss: 0.000


 46%|████▌     | 23/50 [04:00<04:44, 10.53s/it]

[24, 16001] loss: 0.000


 48%|████▊     | 24/50 [04:11<04:34, 10.55s/it]

[25, 16001] loss: 0.000


 50%|█████     | 25/50 [04:21<04:24, 10.56s/it]

[26, 16001] loss: 0.000


 52%|█████▏    | 26/50 [04:32<04:13, 10.57s/it]

[27, 16001] loss: 0.000


 54%|█████▍    | 27/50 [04:42<04:03, 10.57s/it]

[28, 16001] loss: 0.000


 56%|█████▌    | 28/50 [04:53<03:52, 10.58s/it]

[29, 16001] loss: 0.000


 58%|█████▊    | 29/50 [05:04<03:42, 10.58s/it]

[30, 16001] loss: 0.000


 60%|██████    | 30/50 [05:14<03:31, 10.57s/it]

[31, 16001] loss: 0.000


 62%|██████▏   | 31/50 [05:25<03:20, 10.56s/it]

[32, 16001] loss: 0.000


 64%|██████▍   | 32/50 [05:35<03:10, 10.56s/it]

[33, 16001] loss: 0.000


 66%|██████▌   | 33/50 [05:46<02:59, 10.55s/it]

[34, 16001] loss: 0.000


 68%|██████▊   | 34/50 [05:56<02:48, 10.55s/it]

[35, 16001] loss: 0.000


 70%|███████   | 35/50 [06:07<02:38, 10.57s/it]

[36, 16001] loss: 0.000


 72%|███████▏  | 36/50 [06:17<02:28, 10.57s/it]

[37, 16001] loss: 0.000


 74%|███████▍  | 37/50 [06:28<02:17, 10.55s/it]

[38, 16001] loss: 0.000


 76%|███████▌  | 38/50 [06:38<02:06, 10.56s/it]

[39, 16001] loss: 0.000


 78%|███████▊  | 39/50 [06:49<01:56, 10.56s/it]

[40, 16001] loss: 0.000


 80%|████████  | 40/50 [07:00<01:45, 10.56s/it]

[41, 16001] loss: 0.000


 82%|████████▏ | 41/50 [07:10<01:34, 10.55s/it]

[42, 16001] loss: 0.000


 84%|████████▍ | 42/50 [07:21<01:24, 10.55s/it]

[43, 16001] loss: 0.000


 86%|████████▌ | 43/50 [07:31<01:13, 10.56s/it]

[44, 16001] loss: 0.000


 88%|████████▊ | 44/50 [07:42<01:03, 10.56s/it]

[45, 16001] loss: 0.000


 90%|█████████ | 45/50 [07:52<00:52, 10.58s/it]

[46, 16001] loss: 0.000


 92%|█████████▏| 46/50 [08:03<00:42, 10.58s/it]

[47, 16001] loss: 0.000


 94%|█████████▍| 47/50 [08:14<00:31, 10.56s/it]

[48, 16001] loss: 0.000


 96%|█████████▌| 48/50 [08:24<00:21, 10.56s/it]

[49, 16001] loss: 0.000


 98%|█████████▊| 49/50 [08:35<00:10, 10.56s/it]

[50, 16001] loss: 0.000


100%|██████████| 50/50 [08:45<00:00, 10.52s/it]


In [33]:
def get_forward_mae(start_boards, stop_boards, n):
    start_boards_tensor = torch.tensor(start_boards).view(n, 1, 25, 25).float().cuda()
    stop_boards_tensor = torch.tensor(stop_boards).view(n, 1, 25, 25)
    with torch.no_grad():
        forward_net.eval()
        # Make prediction
        predicted_stop_board = (forward_net(start_boards_tensor) > 0.5).int().cpu()
        error = torch.sum(predicted_stop_board != stop_boards_tensor)
        # print(predicted_stop_board)
        # print(stop_boards_tensor)
        return error / (n * 25 * 25)

In [34]:
# Training data MAE
train_mae = get_forward_mae(start_boards, stop_boards, N)
print("The training data MAE is {:.6f}.".format(train_mae))

# Test data MAE
test_mae = get_forward_mae(test_start_boards, test_stop_boards, N_test)
print("The test data MAE is {:.6f}.".format(test_mae))

The training data MAE is 0.000000.
The test data MAE is 0.000000.


# Relax starting boards

In [35]:
# Modify starting boards 
def relax_boards(boards):
    return np.abs(np.random.rand(*boards.shape) / 2 - boards)

In [36]:
relaxed_start_boards = relax_boards(start_boards)

In [38]:
X_relaxed = Variable(torch.tensor(relaxed_start_boards).view(N, 1, 25, 25).float(), requires_grad=True)
y = Variable(torch.tensor(stop_boards).view(N, 1, 25, 25).float())

In [None]:
train(forward_net, X_relaxed, y, optim.Adam, criterion, "../models/johnson/forward_with_relaxation.pkl")

  0%|          | 0/50 [00:00<?, ?it/s]

ForwardNet(
  (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=circular)
  (activ1): ReLU()
  (conv2): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=circular)
  (activ2): PReLU(num_parameters=1)
  (conv3): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=circular)
  (activ3): PReLU(num_parameters=1)
  (conv4): Conv2d(4, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=circular)
)
[1, 16001] loss: 0.045


  2%|▏         | 1/50 [00:10<08:39, 10.60s/it]

[2, 16001] loss: 0.032


  4%|▍         | 2/50 [00:21<08:29, 10.61s/it]

[3, 16001] loss: 0.025


  6%|▌         | 3/50 [00:32<08:21, 10.66s/it]

[4, 16001] loss: 0.025


  8%|▊         | 4/50 [00:42<08:11, 10.68s/it]

[5, 16001] loss: 0.023


 10%|█         | 5/50 [00:53<08:01, 10.70s/it]

[6, 16001] loss: 0.021


 12%|█▏        | 6/50 [01:04<07:51, 10.71s/it]

[7, 16001] loss: 0.021


 14%|█▍        | 7/50 [01:15<07:43, 10.77s/it]

[8, 16001] loss: 0.021


 16%|█▌        | 8/50 [01:25<07:31, 10.76s/it]

[9, 16001] loss: 0.021


 18%|█▊        | 9/50 [01:36<07:22, 10.78s/it]

[10, 16001] loss: 0.020


 20%|██        | 10/50 [01:47<07:08, 10.72s/it]

[11, 16001] loss: 0.020


 22%|██▏       | 11/50 [01:58<07:00, 10.78s/it]

[12, 16001] loss: 0.020


 24%|██▍       | 12/50 [02:09<06:52, 10.87s/it]

[13, 16001] loss: 0.018


 26%|██▌       | 13/50 [02:20<06:45, 10.96s/it]

[14, 16001] loss: 0.020


 28%|██▊       | 14/50 [02:31<06:36, 11.01s/it]

[15, 16001] loss: 0.021


 30%|███       | 15/50 [02:42<06:28, 11.09s/it]

[16, 16001] loss: 0.021


 32%|███▏      | 16/50 [02:54<06:18, 11.13s/it]

[17, 16001] loss: 0.020


 34%|███▍      | 17/50 [03:05<06:07, 11.15s/it]

[18, 16001] loss: 0.018


 36%|███▌      | 18/50 [03:16<05:57, 11.16s/it]

[19, 16001] loss: 0.019


 38%|███▊      | 19/50 [03:27<05:46, 11.18s/it]

[20, 16001] loss: 0.020


 40%|████      | 20/50 [03:39<05:36, 11.23s/it]

# Reverse model

In [4]:
class ReverseNet(nn.Module):
    def __init__(self):
        super(ReverseNet, self).__init__()
        # in channels, out channels, kernel size
        self.conv1 = nn.Conv2d(1, 16, (5, 5), padding=(2, 2), padding_mode='circular')
        self.activ1 = nn.PReLU()
        self.conv2 = nn.Conv2d(16, 8, (5, 5), padding=(2, 2), padding_mode='circular')
        self.activ2 = nn.PReLU()
        self.conv3 = nn.Conv2d(8, 4, (3, 3), padding=(1, 1), padding_mode='circular')
        self.activ3 = nn.PReLU()
        self.conv4 = nn.Conv2d(4, 1, (3, 3), padding=(1, 1), padding_mode='circular')

    def forward(self, x):
        x = self.reverse(x)
        return x

    def reverse(self, x):
        x = self.activ1(self.conv1(x))
        x = self.activ2(self.conv2(x))
        x = self.activ3(self.conv3(x))
        x = torch.sigmoid(self.conv4(x))
        return x

In [5]:
net = ReverseNet()
print(net)

ReverseNet(
  (conv1): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), padding_mode=circular)
  (activ1): PReLU(num_parameters=1)
  (conv2): Conv2d(16, 8, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), padding_mode=circular)
  (activ2): PReLU(num_parameters=1)
  (conv3): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=circular)
  (activ3): PReLU(num_parameters=1)
  (conv4): Conv2d(4, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=circular)
)


In [6]:
criterion = nn.BCELoss()
optimizer = optim.Adam(net.parameters()) #, lr=0.001, momentum=0.9)

In [10]:
X = Variable(torch.tensor(stop_boards).view(N, 1, 25, 25).float(), requires_grad=True)
y = Variable(torch.tensor(start_boards).view(N, 1, 25, 25).float())
num_epochs = 50
batch_size = 128
for epoch in tqdm(range(num_epochs)): 
    permutation = torch.randperm(X.size()[0])
    running_loss = 0.0
    for i in range(0, X.size()[0], batch_size):
        indices = permutation[i:i+batch_size]
        batch = X[indices]
        target = y[indices]
        
        optimizer.zero_grad()
        outputs = net(batch)
        # loss = criterion(outputs, batch)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 1000 == 0 and i > 0:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 10))
        running_loss = 0.0

  0%|          | 0/50 [00:00<?, ?it/s]

[1, 16001] loss: 0.037


  2%|▏         | 1/50 [00:17<14:31, 17.78s/it]

[2, 16001] loss: 0.034


  4%|▍         | 2/50 [00:35<14:17, 17.87s/it]

[3, 16001] loss: 0.034


  6%|▌         | 3/50 [00:54<14:06, 18.02s/it]

[4, 16001] loss: 0.035


  8%|▊         | 4/50 [01:12<13:50, 18.06s/it]

[5, 16001] loss: 0.037


 10%|█         | 5/50 [01:30<13:35, 18.13s/it]

[6, 16001] loss: 0.038


 12%|█▏        | 6/50 [01:48<13:18, 18.16s/it]

[7, 16001] loss: 0.035


 14%|█▍        | 7/50 [02:07<13:00, 18.15s/it]

[8, 16001] loss: 0.031


 16%|█▌        | 8/50 [02:25<12:43, 18.18s/it]

[9, 16001] loss: 0.033


 18%|█▊        | 9/50 [02:43<12:25, 18.18s/it]

[10, 16001] loss: 0.035


 20%|██        | 10/50 [03:01<12:06, 18.17s/it]

[11, 16001] loss: 0.033


 22%|██▏       | 11/50 [03:19<11:49, 18.19s/it]

[12, 16001] loss: 0.034


 24%|██▍       | 12/50 [03:38<11:30, 18.18s/it]

[13, 16001] loss: 0.035


 26%|██▌       | 13/50 [03:56<11:12, 18.18s/it]

[14, 16001] loss: 0.037


 28%|██▊       | 14/50 [04:14<10:54, 18.18s/it]

[15, 16001] loss: 0.035


 30%|███       | 15/50 [04:32<10:35, 18.15s/it]

[16, 16001] loss: 0.037


 32%|███▏      | 16/50 [04:50<10:16, 18.13s/it]

[17, 16001] loss: 0.034


 34%|███▍      | 17/50 [05:08<09:59, 18.17s/it]

[18, 16001] loss: 0.034


 36%|███▌      | 18/50 [05:27<09:42, 18.22s/it]

[19, 16001] loss: 0.035


 38%|███▊      | 19/50 [05:45<09:24, 18.21s/it]

[20, 16001] loss: 0.033


 40%|████      | 20/50 [06:03<09:04, 18.16s/it]

[21, 16001] loss: 0.036


 42%|████▏     | 21/50 [06:21<08:46, 18.14s/it]

[22, 16001] loss: 0.036


 44%|████▍     | 22/50 [06:39<08:26, 18.10s/it]

[23, 16001] loss: 0.038


 46%|████▌     | 23/50 [06:57<08:09, 18.11s/it]

[24, 16001] loss: 0.034


 48%|████▊     | 24/50 [07:15<07:50, 18.11s/it]

[25, 16001] loss: 0.037


 50%|█████     | 25/50 [07:33<07:32, 18.12s/it]

[26, 16001] loss: 0.037


 52%|█████▏    | 26/50 [07:51<07:13, 18.06s/it]

[27, 16001] loss: 0.036


 54%|█████▍    | 27/50 [08:09<06:55, 18.08s/it]

[28, 16001] loss: 0.039


 56%|█████▌    | 28/50 [08:28<06:39, 18.16s/it]

[29, 16001] loss: 0.035


 58%|█████▊    | 29/50 [08:46<06:23, 18.27s/it]

[30, 16001] loss: 0.036


 60%|██████    | 30/50 [09:05<06:06, 18.32s/it]

[31, 16001] loss: 0.035


 62%|██████▏   | 31/50 [09:23<05:48, 18.33s/it]

[32, 16001] loss: 0.038


 64%|██████▍   | 32/50 [09:41<05:29, 18.32s/it]

[33, 16001] loss: 0.034


 66%|██████▌   | 33/50 [10:00<05:10, 18.27s/it]

[34, 16001] loss: 0.038


 68%|██████▊   | 34/50 [10:18<04:52, 18.25s/it]

[35, 16001] loss: 0.035


 70%|███████   | 35/50 [10:36<04:32, 18.18s/it]

[36, 16001] loss: 0.037


 72%|███████▏  | 36/50 [10:54<04:13, 18.12s/it]

[37, 16001] loss: 0.036


 74%|███████▍  | 37/50 [11:12<03:55, 18.13s/it]

[38, 16001] loss: 0.035


 76%|███████▌  | 38/50 [11:30<03:36, 18.07s/it]

[39, 16001] loss: 0.039


 78%|███████▊  | 39/50 [11:48<03:18, 18.09s/it]

[40, 16001] loss: 0.035


 80%|████████  | 40/50 [12:06<03:01, 18.12s/it]

[41, 16001] loss: 0.036


 82%|████████▏ | 41/50 [12:24<02:43, 18.15s/it]

[42, 16001] loss: 0.038


 84%|████████▍ | 42/50 [12:42<02:24, 18.11s/it]

[43, 16001] loss: 0.035


 86%|████████▌ | 43/50 [13:01<02:06, 18.12s/it]

[44, 16001] loss: 0.033


 88%|████████▊ | 44/50 [13:19<01:48, 18.08s/it]

[45, 16001] loss: 0.034


 90%|█████████ | 45/50 [13:37<01:30, 18.12s/it]

[46, 16001] loss: 0.034


 92%|█████████▏| 46/50 [13:55<01:12, 18.14s/it]

[47, 16001] loss: 0.035


 94%|█████████▍| 47/50 [14:13<00:54, 18.16s/it]

[48, 16001] loss: 0.036


 96%|█████████▌| 48/50 [14:33<00:37, 18.74s/it]

[49, 16001] loss: 0.036


 98%|█████████▊| 49/50 [14:53<00:19, 19.16s/it]

[50, 16001] loss: 0.037


100%|██████████| 50/50 [15:14<00:00, 18.28s/it]


In [11]:
# Generate testing data
N_test = 25600
# test_data_gen = bitmap.generate_test_set(N_test, 45, max_delta=1)
test_data_gen = bitmap.generate_test_set(N_test, 45, min_delta=5, max_delta=5)
delta, test_stop_boards = map(list, zip(*list(test_data_gen)))

In [17]:
def get_mae(stop_boards, n, delta=1):
    stop_boards_tensor = torch.tensor(stop_boards).view(n, 1, 25, 25).float()
    with torch.no_grad():
        net.eval()
        # Make prediction
        output = (net(stop_boards_tensor) > 0.5).int()
        # Evolve one step forward
        predicted_stop_board = output
        for _ in range(delta):
            predicted_stop_board = (forward_model.forward(predicted_stop_board) > 0.5).int()
        error = torch.sum(predicted_stop_board != stop_boards_tensor.int())
        return error / (n * 25 * 25)

In [13]:
# Training data MAE
train_mae = get_mae(stop_boards, N, 5).item()
print("The training data MAE is {:.6f}.".format(train_mae))

# Test data MAE
test_mae = get_mae(test_stop_boards, N_test, 5).item()
print("The test data MAE is {:.6f}.".format(test_mae))

The training data MAE is 0.137085.
The test data MAE is 0.136620.


In [18]:
# Test data MAE
test_mae = get_mae(test_stop_boards, N_test, 5).item()
print("The test data MAE is {:.6f}.".format(test_mae))

The test data MAE is 0.140959.
