In [106]:
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch.optim as optim
from torch.autograd import Variable
import torch
from tqdm import tqdm
import sys

In [107]:
sys.path.insert(0,'..')
import bitmap
import itertools
from forward_prediction import forward_model

In [108]:
class ReverseNet(nn.Module):
    def __init__(self):
        super(ReverseNet, self).__init__()
        # in channels, out channels, kernel size
        self.conv1 = nn.Conv2d(1, 16, (5, 5), padding=(2, 2), padding_mode='circular')
        self.activ1 = nn.PReLU()
        self.conv2 = nn.Conv2d(16, 8, (5, 5), padding=(2, 2), padding_mode='circular')
        self.activ2 = nn.PReLU()
        self.conv3 = nn.Conv2d(8, 4, (3, 3), padding=(1, 1), padding_mode='circular')
        self.activ3 = nn.PReLU()
        self.conv4 = nn.Conv2d(4, 1, (3, 3), padding=(1, 1), padding_mode='circular')

    def forward(self, x):
        x = self.reverse(x)
        return x

    def reverse(self, x):
        x = self.activ1(self.conv1(x))
        x = self.activ2(self.conv2(x))
        x = self.activ3(self.conv3(x))
        x = torch.sigmoid(self.conv4(x))
        return x

In [109]:
net = ReverseNet()
print(net)

ReverseNet(
  (conv1): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), padding_mode=circular)
  (activ1): PReLU(num_parameters=1)
  (conv2): Conv2d(16, 8, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), padding_mode=circular)
  (activ2): PReLU(num_parameters=1)
  (conv3): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=circular)
  (activ3): PReLU(num_parameters=1)
  (conv4): Conv2d(4, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=circular)
)


In [110]:
criterion = nn.BCELoss()
optimizer = optim.Adam(net.parameters()) #, lr=0.001, momentum=0.9)

In [133]:
# Generate training data
N = 25600
# train_data_gen = bitmap.generate_train_set(N, 41, max_delta=1)
train_data_gen = bitmap.generate_train_set(N, 41, min_delta=5, max_delta=5)
deltas, start_boards, stop_boards = map(list, zip(*list(train_data_gen)))

In [134]:
X = Variable(torch.tensor(stop_boards).view(N, 1, 25, 25).float(), requires_grad=True)
y = Variable(torch.tensor(start_boards).view(N, 1, 25, 25).float())
num_epochs = 100
batch_size = 128
for epoch in tqdm(range(num_epochs)): 
    permutation = torch.randperm(X.size()[0])
    running_loss = 0.0
    for i in range(0, X.size()[0], batch_size):
        indices = permutation[i:i+batch_size]
        batch = X[indices]
        target = y[indices]
        
        optimizer.zero_grad()
        outputs = net(batch)
        # loss = criterion(outputs, batch)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 1000 == 0 and i > 0:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 10))
        running_loss = 0.0

  0%|          | 0/100 [00:00<?, ?it/s]

[1, 16001] loss: 0.035


  1%|          | 1/100 [00:18<30:47, 18.66s/it]

[2, 16001] loss: 0.034


  2%|▏         | 2/100 [00:35<29:48, 18.26s/it]

[3, 16001] loss: 0.039


  3%|▎         | 3/100 [00:54<29:47, 18.43s/it]

[4, 16001] loss: 0.035


  4%|▍         | 4/100 [01:14<30:17, 18.94s/it]

[5, 16001] loss: 0.036


  5%|▌         | 5/100 [01:34<30:06, 19.02s/it]

[6, 16001] loss: 0.036


  6%|▌         | 6/100 [01:52<29:15, 18.68s/it]

[7, 16001] loss: 0.033


  7%|▋         | 7/100 [02:09<28:23, 18.31s/it]

[8, 16001] loss: 0.037


  8%|▊         | 8/100 [02:28<28:24, 18.53s/it]

[9, 16001] loss: 0.037


  9%|▉         | 9/100 [02:45<27:32, 18.16s/it]

[10, 16001] loss: 0.040


 10%|█         | 10/100 [03:03<27:09, 18.10s/it]

[11, 16001] loss: 0.036


 11%|█         | 11/100 [03:21<26:42, 18.00s/it]

[12, 16001] loss: 0.035


 12%|█▏        | 12/100 [03:39<26:28, 18.05s/it]

[13, 16001] loss: 0.036


 13%|█▎        | 13/100 [03:58<26:23, 18.20s/it]

[14, 16001] loss: 0.034


 14%|█▍        | 14/100 [04:18<26:52, 18.75s/it]

[15, 16001] loss: 0.037


 15%|█▌        | 15/100 [04:37<26:47, 18.91s/it]

[16, 16001] loss: 0.037


 16%|█▌        | 16/100 [04:57<26:54, 19.22s/it]

[17, 16001] loss: 0.037


 17%|█▋        | 17/100 [05:17<26:57, 19.49s/it]

[18, 16001] loss: 0.039


 18%|█▊        | 18/100 [05:37<26:40, 19.52s/it]

[19, 16001] loss: 0.034


 19%|█▉        | 19/100 [05:56<26:20, 19.51s/it]

[20, 16001] loss: 0.036


 20%|██        | 20/100 [06:16<26:03, 19.54s/it]

[21, 16001] loss: 0.037


 21%|██        | 21/100 [06:35<25:45, 19.56s/it]

[22, 16001] loss: 0.036


 22%|██▏       | 22/100 [06:55<25:21, 19.50s/it]

[23, 16001] loss: 0.037


 23%|██▎       | 23/100 [07:16<25:30, 19.88s/it]

[24, 16001] loss: 0.037


 24%|██▍       | 24/100 [07:35<25:05, 19.82s/it]

[25, 16001] loss: 0.033


 25%|██▌       | 25/100 [07:54<24:22, 19.50s/it]

[26, 16001] loss: 0.037


 26%|██▌       | 26/100 [08:13<23:53, 19.38s/it]

[27, 16001] loss: 0.039


 27%|██▋       | 27/100 [08:33<23:51, 19.61s/it]

[28, 16001] loss: 0.034


 28%|██▊       | 28/100 [08:54<23:57, 19.96s/it]

[29, 16001] loss: 0.038


 29%|██▉       | 29/100 [09:15<23:53, 20.19s/it]

[30, 16001] loss: 0.037


 30%|███       | 30/100 [09:34<23:15, 19.94s/it]

[31, 16001] loss: 0.035


 31%|███       | 31/100 [09:54<22:51, 19.87s/it]

[32, 16001] loss: 0.037


 32%|███▏      | 32/100 [10:13<22:17, 19.67s/it]

[33, 16001] loss: 0.037


 33%|███▎      | 33/100 [10:32<21:51, 19.57s/it]

[34, 16001] loss: 0.035


 34%|███▍      | 34/100 [10:52<21:41, 19.72s/it]

[35, 16001] loss: 0.035


 35%|███▌      | 35/100 [11:11<20:50, 19.24s/it]

[36, 16001] loss: 0.035


 36%|███▌      | 36/100 [11:33<21:38, 20.30s/it]

[37, 16001] loss: 0.035


 37%|███▋      | 37/100 [11:56<22:11, 21.13s/it]

[38, 16001] loss: 0.037


 38%|███▊      | 38/100 [12:19<22:23, 21.66s/it]

[39, 16001] loss: 0.038


 39%|███▉      | 39/100 [12:41<22:11, 21.83s/it]

[40, 16001] loss: 0.034


 40%|████      | 40/100 [13:03<21:46, 21.78s/it]

[41, 16001] loss: 0.037


 41%|████      | 41/100 [13:21<20:13, 20.56s/it]

[42, 16001] loss: 0.034


 42%|████▏     | 42/100 [13:39<19:04, 19.73s/it]

[43, 16001] loss: 0.037


 43%|████▎     | 43/100 [13:57<18:21, 19.33s/it]

[44, 16001] loss: 0.034


 44%|████▍     | 44/100 [14:15<17:39, 18.92s/it]

[45, 16001] loss: 0.036


 45%|████▌     | 45/100 [14:33<17:10, 18.73s/it]

[46, 16001] loss: 0.033


 46%|████▌     | 46/100 [14:52<16:47, 18.65s/it]

[47, 16001] loss: 0.037


 47%|████▋     | 47/100 [15:09<16:11, 18.33s/it]

[48, 16001] loss: 0.035


 48%|████▊     | 48/100 [15:28<15:56, 18.40s/it]

[49, 16001] loss: 0.034


 49%|████▉     | 49/100 [15:49<16:19, 19.21s/it]

[50, 16001] loss: 0.035


 50%|█████     | 50/100 [16:10<16:22, 19.65s/it]

[51, 16001] loss: 0.034


 51%|█████     | 51/100 [16:30<16:18, 19.97s/it]

[52, 16001] loss: 0.038


 52%|█████▏    | 52/100 [16:52<16:16, 20.34s/it]

[53, 16001] loss: 0.037


 53%|█████▎    | 53/100 [17:13<16:05, 20.55s/it]

[54, 16001] loss: 0.037


 54%|█████▍    | 54/100 [17:32<15:35, 20.33s/it]

[55, 16001] loss: 0.036


 55%|█████▌    | 55/100 [17:52<14:58, 19.96s/it]

[56, 16001] loss: 0.036


 56%|█████▌    | 56/100 [18:11<14:26, 19.69s/it]

[57, 16001] loss: 0.034


 57%|█████▋    | 57/100 [18:30<14:03, 19.61s/it]

[58, 16001] loss: 0.035


 58%|█████▊    | 58/100 [18:49<13:33, 19.38s/it]

[59, 16001] loss: 0.035


 59%|█████▉    | 59/100 [19:08<13:07, 19.22s/it]

[60, 16001] loss: 0.035


 60%|██████    | 60/100 [19:28<13:00, 19.52s/it]

[61, 16001] loss: 0.038


 61%|██████    | 61/100 [19:47<12:37, 19.41s/it]

[62, 16001] loss: 0.037


 62%|██████▏   | 62/100 [20:07<12:21, 19.51s/it]

[63, 16001] loss: 0.036


 63%|██████▎   | 63/100 [20:27<12:07, 19.66s/it]

[64, 16001] loss: 0.034


 64%|██████▍   | 64/100 [20:46<11:46, 19.61s/it]

[65, 16001] loss: 0.031


 65%|██████▌   | 65/100 [21:06<11:25, 19.58s/it]

[66, 16001] loss: 0.032


 66%|██████▌   | 66/100 [21:26<11:08, 19.67s/it]

[67, 16001] loss: 0.038


 67%|██████▋   | 67/100 [21:46<10:54, 19.82s/it]

[68, 16001] loss: 0.038


 68%|██████▊   | 68/100 [22:06<10:33, 19.79s/it]

[69, 16001] loss: 0.034


 69%|██████▉   | 69/100 [22:26<10:18, 19.96s/it]

[70, 16001] loss: 0.036


 70%|███████   | 70/100 [22:47<10:03, 20.12s/it]

[71, 16001] loss: 0.031


 71%|███████   | 71/100 [23:07<09:43, 20.11s/it]

[72, 16001] loss: 0.034


 72%|███████▏  | 72/100 [23:26<09:12, 19.75s/it]

[73, 16001] loss: 0.033


 73%|███████▎  | 73/100 [23:44<08:44, 19.44s/it]

[74, 16001] loss: 0.035


 74%|███████▍  | 74/100 [24:05<08:36, 19.86s/it]

[75, 16001] loss: 0.033


 75%|███████▌  | 75/100 [24:25<08:17, 19.89s/it]

[76, 16001] loss: 0.035


 76%|███████▌  | 76/100 [24:44<07:48, 19.53s/it]

[77, 16001] loss: 0.037


 77%|███████▋  | 77/100 [25:02<07:22, 19.22s/it]

[78, 16001] loss: 0.034


 78%|███████▊  | 78/100 [25:21<06:59, 19.06s/it]

[79, 16001] loss: 0.038


 79%|███████▉  | 79/100 [25:39<06:36, 18.89s/it]

[80, 16001] loss: 0.038


 80%|████████  | 80/100 [25:58<06:15, 18.80s/it]

[81, 16001] loss: 0.036


 81%|████████  | 81/100 [26:17<05:55, 18.73s/it]

[82, 16001] loss: 0.036


 82%|████████▏ | 82/100 [26:35<05:36, 18.68s/it]

[83, 16001] loss: 0.036


 83%|████████▎ | 83/100 [26:54<05:17, 18.66s/it]

[84, 16001] loss: 0.032


 84%|████████▍ | 84/100 [27:12<04:57, 18.59s/it]

[85, 16001] loss: 0.036


 85%|████████▌ | 85/100 [27:31<04:38, 18.53s/it]

[86, 16001] loss: 0.035


 86%|████████▌ | 86/100 [27:49<04:19, 18.56s/it]

[87, 16001] loss: 0.033


 87%|████████▋ | 87/100 [28:08<04:00, 18.49s/it]

[88, 16001] loss: 0.036


 88%|████████▊ | 88/100 [28:26<03:42, 18.58s/it]

[89, 16001] loss: 0.034


 89%|████████▉ | 89/100 [28:45<03:25, 18.64s/it]

[90, 16001] loss: 0.037


 90%|█████████ | 90/100 [29:04<03:06, 18.66s/it]

[91, 16001] loss: 0.033


 91%|█████████ | 91/100 [29:22<02:47, 18.64s/it]

[92, 16001] loss: 0.033


 92%|█████████▏| 92/100 [29:41<02:28, 18.54s/it]

[93, 16001] loss: 0.037


 93%|█████████▎| 93/100 [29:59<02:10, 18.59s/it]

[94, 16001] loss: 0.039


 94%|█████████▍| 94/100 [30:19<01:52, 18.75s/it]

[95, 16001] loss: 0.035


 95%|█████████▌| 95/100 [30:38<01:34, 18.82s/it]

[96, 16001] loss: 0.034


 96%|█████████▌| 96/100 [30:56<01:15, 18.82s/it]

[97, 16001] loss: 0.035


 97%|█████████▋| 97/100 [31:15<00:56, 18.79s/it]

[98, 16001] loss: 0.033


 98%|█████████▊| 98/100 [31:34<00:37, 18.81s/it]

[99, 16001] loss: 0.036


 99%|█████████▉| 99/100 [31:53<00:18, 18.84s/it]

[100, 16001] loss: 0.037


100%|██████████| 100/100 [32:11<00:00, 19.32s/it]


In [139]:
# Generate testing data
N_test = 12800
# test_data_gen = bitmap.generate_test_set(N_test, 45, max_delta=1)
test_data_gen = bitmap.generate_test_set(N_test, 45, min_delta=5, max_delta=5)
delta, test_stop_boards = map(list, zip(*list(test_data_gen)))

In [141]:
def get_mae(stop_boards, n, delta=1):
    stop_boards_tensor = torch.tensor(stop_boards).view(n, 1, 25, 25).float()
    with torch.no_grad():
        net.eval()
        # Make prediction
        output = (net(stop_boards_tensor) > 0.5).int()
        # Evolve one step forward
        predicted_stop_board = output
        for _ in range(delta):
            predicted_stop_board = (forward_model.forward(predicted_stop_board) > 0.5).int()
        error = torch.sum(predicted_stop_board != stop_boards_tensor.int())
        return error / (n * 25 * 25)

In [142]:
# Training data MAE
train_mae = get_mae(stop_boards, N, 5).item()
print("The training data MAE is {:.6f}.".format(train_mae))

# Test data MAE
test_mae = get_mae(test_stop_boards, N_test, 5).item()
print("The test data MAE is {:.6f}.".format(test_mae))

The training data MAE is 0.139022.
The test data MAE is 0.138522.


In [132]:
predicted_stop_board

NameError: name 'predicted_stop_board' is not defined