In [69]:
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch.optim as optim
from torch.autograd import Variable
import torch
from tqdm import tqdm
import sys

In [70]:
from torch.utils.tensorboard import SummaryWriter

In [71]:
sys.path.insert(0,'..')
import bitmap
import itertools

# Generate data

## Training data

In [72]:
REGENERATE_DATA = False

In [73]:
# Generate or load training data
N = 51200

if REGENERATE_DATA:
    train_data_gen = bitmap.generate_train_set(N, 41, min_delta=1, max_delta=1)
    deltas, start_boards, stop_boards = map(np.array, zip(*list(train_data_gen)))
    # Save training data
    np.save('../../data/training_start_boards', start_boards)
    np.save('../../data/training_stop_boards', stop_boards)
else:
    start_boards = np.load('../../data/training_start_boards.npy')
    stop_boards = np.load('../../data/training_stop_boards.npy')

## Validation data

In [74]:
# Generate or load validation data
N_valid = 12800

if REGENERATE_DATA:
    valid_data_gen = bitmap.generate_train_set(N_valid, 1024, min_delta=1, max_delta=1)
    deltas, valid_start_boards, valid_stop_boards = map(np.array, zip(*list(valid_data_gen)))
    # Save validation data
    np.save('../../data/valid_start_boards', valid_start_boards)
    np.save('../../data/valid_stop_boards', valid_stop_boards)
else:
    valid_start_boards = np.load('../../data/valid_start_boards.npy')
    valid_stop_boards = np.load('../../data/valid_stop_boards.npy')

In [75]:
X_valid = Variable(torch.tensor(valid_start_boards).view(N_valid, 1, 25, 25).float())
y_valid = Variable(torch.tensor(valid_stop_boards).view(N_valid, 1, 25, 25).float())

## Test data

In [76]:
# Generate or load test data
N_test = 25600

if REGENERATE_DATA:
    test_data_gen = bitmap.generate_train_set(N_test, 42, min_delta=1, max_delta=1)
    deltas, test_start_boards, test_stop_boards = map(np.array, zip(*list(test_data_gen)))
    # Save test data
    np.save('../../data/test_start_boards', test_start_boards)
    np.save('../../data/test_stop_boards', test_stop_boards)
else:
    test_start_boards = np.load('../../data/test_start_boards.npy')
    test_stop_boards = np.load('../../data/test_stop_boards.npy')

# Forward evolver

In [77]:
def actual_forward(x, use_cuda=False): 
    # Weights for layer 1
    weight1 = torch.tensor([[[1, 1, 1], [1, 0.1, 1], [1, 1, 1]],
                            [[1, 1, 1], [1, 1, 1], [1, 1, 1]]]).view(2, 1, 3, 3).float()
    b1 = torch.tensor([-3, -2]).float()
    # Weights for layer 2
    weight2 = torch.tensor([-10, 1]).view(1, 2, 1, 1).float()
    # Weights for layer 3
    s = 20
    weight3 = torch.tensor([2*s]).view(1, 1, 1, 1).float()
    b3 = torch.tensor([-s]).float()
    
    if use_cuda:
        torch.backends.cudnn.deterministic = True
        weight1 = weight1.cuda()
        b1 = b1.cuda()
        weight2 = weight2.cuda()
        weight3 = weight3.cuda()
        b3 = b3.cuda()

    x = F.pad(x.float(), (1, 1, 1, 1), mode='circular')
    x = F.relu(F.conv2d(x, weight1, b1))
    x = F.relu(F.conv2d(x, weight2))
    x = torch.sigmoid(F.conv2d(x, weight3, b3))
    return x

# Model trainer

In [106]:
def train(model, X, y, X_valid, y_valid, 
          optim, criterion, output_path, num_epochs=50, batch_size=128):
    # Release CUDA memory
    # torch.cuda.empty_cache()

    # Set optimizer
    optimizer = optim(model.parameters())
    
    # Setup Tensorboard (https://pytorch.org/docs/stable/tensorboard.html)
    writer = SummaryWriter()
    # writer.add_graph(model.cpu(), X)
    model.cuda()

    # Best validation MAE
    best_valid_mae = 1
    
    # Train
    n_iter = 0
    for epoch in range(num_epochs): 
        permutation = torch.randperm(X.size()[0])
        running_loss = 0.0
        pbar = tqdm(range(0, X.size()[0], batch_size))
        for i in pbar:
            n_iter += 1
            indices = permutation[i:i+batch_size]
            batch = X[indices].cuda()
            target = y[indices].cuda()
        
            optimizer.zero_grad()
            outputs = model(batch)
            loss = criterion(outputs, target)
            loss.backward()
            optimizer.step()
            
            # Calculate MAE
            if hasattr(model, "reverse_net"):
                pred_start_boards = model.reverse_net(batch)
                outputs = actual_forward((pred_start_boards > 0.5).int(), True)
            output_boards = (outputs > 0.5).int()
            mae = torch.sum(output_boards != target).float() / (batch_size * 25 * 25)
            
            # Write data to Tensorboard
            writer.add_scalar('Loss/train', loss.item(), n_iter)
            writer.add_scalar('MAE/train', mae.item(), n_iter)
            
            pbar.set_description("[{:d}, {:5d}] loss: {:.6f} | train MAE {:.6f} | best MAE: {:.6f}".format(epoch + 1, i + 1, loss.item(), mae, best_valid_mae))
            
            # Write boards and validation results to Tensorboard every 50 batches
            if n_iter % 50 == 0:
                with torch.no_grad():
                    model.eval()
                    valid_loss = 0
                    valid_mae = 0
                    m = 0
                    for j in range(0, X_valid.size()[0], batch_size):
                        m += 1
                        valid_batch = X_valid[j:j+batch_size].cuda()
                        valid_target = y_valid[j:j+batch_size].cuda()
                        valid_outputs = model(valid_batch)
                        valid_loss += criterion(valid_outputs, valid_target)
                        if hasattr(model, "reverse_net"):
                            valid_start_boards = model.reverse_net(valid_batch)
                            valid_outputs = actual_forward((valid_start_boards > 0.5).int(), True)
                        valid_boards = (valid_outputs > 0.5).int()
                        valid_mae += torch.sum(valid_boards != valid_target).float()
                    valid_loss /= m
                    valid_mae /= (X_valid.size()[0] * 25 * 25)
                    writer.add_image('predicted stop board', valid_boards[-1], n_iter)
                    writer.add_image('actual stop board', y_valid[-1], n_iter)
                    if hasattr(model, "reverse_net"):
                        pred_start_board = (model.reverse_net(X_valid[-1].view(1, 1, 25, 25).cuda()) > 0.5).int()
                        writer.add_image('predicted start board', pred_start_board[-1], n_iter)
                    writer.add_scalar('Loss/valid', valid_loss.item(), n_iter)
                    writer.add_scalar('MAE/valid', valid_mae.item(), n_iter)
                    
                if valid_mae < best_valid_mae:
                    best_valid_mae = valid_mae
                    # Save model if we have the lastest best MAE
                    torch.save(model.state_dict(), output_path)
    writer.close()
    print("The best validation MAE: {}".format(best_valid_mae))

# MAE Evaluation

In [79]:
def get_forward_mae(model, weight_path, input_boards, output_boards, n):
    # Release CUDA memory
    torch.cuda.empty_cache()
    # Load model
    model.load_state_dict(torch.load(weight_path))
    model.cuda()
    # Convert boards to tensor
    input_boards_tensor = torch.tensor(input_boards[:n]).view(n, 1, 25, 25).float().cuda()
    output_boards_tensor = torch.tensor(output_boards[:n]).view(n, 1, 25, 25)
    with torch.no_grad():
        model.eval()
        # Make prediction
        if hasattr(model, "reverse_net"):
            predicted_start_board = model.reverse_net(input_boards_tensor)
            predicted_output_board = (actual_forward.forward((predicted_start_board > 0.5).int()) > 0.5).int()
        else:
            predicted_output_board = (model(input_boards_tensor) > 0.5).int()
        error = torch.sum(predicted_output_board.cpu() != output_boards_tensor)
        # print(predicted_stop_board)
        # print(stop_boards_tensor)
        return error / (n * 25 * 25)

# Relax starting boards

In [80]:
# Modify starting boards 
def relax_boards(boards):
    np.random.seed(41)
    return np.abs(np.random.rand(*boards.shape) / 2 - boards)

In [81]:
relaxed_start_boards = relax_boards(start_boards)
relaxed_valid_start_boards = relax_boards(valid_start_boards)

In [82]:
class RelaxedForwardNet(nn.Module):
    def __init__(self):
        super(RelaxedForwardNet, self).__init__()
        # in channels, out channels, kernel size
        self.conv0 = nn.Conv2d(1, 8, (1, 1))
        self.activ0 = nn.ReLU()
        self.conv1 = nn.Conv2d(8, 16, (3, 3), padding=(1, 1), padding_mode='circular')
        self.activ1 = nn.PReLU()
        self.conv2 = nn.Conv2d(16, 8, (3, 3), padding=(1, 1), padding_mode='circular')
        self.activ2 = nn.PReLU()
        self.conv3 = nn.Conv2d(8, 4, (3, 3), padding=(1, 1), padding_mode='circular')
        self.activ3 = nn.PReLU()
        self.conv4 = nn.Conv2d(4, 1, (3, 3), padding=(1, 1), padding_mode='circular')
        
    def forward(self, x):
        x = self.activ0(self.conv0(x))
        x = self.activ1(self.conv1(x))
        x = self.activ2(self.conv2(x))
        x = self.activ3(self.conv3(x))
        x = torch.sigmoid(self.conv4(x))
        return x

In [120]:
relaxed_forward_net = RelaxedForwardNet()
criterion = nn.BCELoss()

relaxed_forward_model_path = "../models/johnson/relaxed_forward.pkl"

# Reverse model

## Reverse model version A

In [121]:
class ReverseNetA(nn.Module):
    def __init__(self):
        super(ReverseNetA, self).__init__()
        # in channels, out channels, kernel size
        self.conv0 = nn.Conv2d(1, 4, (1, 1))
        self.activ0 = nn.ReLU()
        self.conv1_7 = nn.Conv2d(4, 4, (7, 7), padding=(3, 3), padding_mode='circular')
        self.conv1_5 = nn.Conv2d(4, 4, (5, 5), padding=(2, 2), padding_mode='circular')
        self.conv1_3 = nn.Conv2d(4, 4, (3, 3), padding=(1, 1), padding_mode='circular')
        self.conv1_1 = nn.Conv2d(4, 4, (1, 1))
        self.activ1 = nn.PReLU()
        self.conv2_5 = nn.Conv2d(16, 4, (5, 5), padding=(2, 2), padding_mode='circular')
        self.conv2_3 = nn.Conv2d(16, 4, (3, 3), padding=(1, 1), padding_mode='circular')
        self.activ2 = nn.PReLU()
        self.conv3 = nn.Conv2d(8, 4, (3, 3), padding=(1, 1), padding_mode='circular')
        self.activ3 = nn.PReLU()
        self.conv4 = nn.Conv2d(4, 1, (3, 3), padding=(1, 1), padding_mode='circular')

    def forward(self, x):
        x = self.activ0(self.conv0(x))
        x = self.activ1(torch.cat((self.conv1_1(x), self.conv1_3(x), 
                                   self.conv1_5(x), self.conv1_7(x)), 1))
        x = self.activ2(torch.cat((self.conv2_3(x), self.conv2_5(x)), 1))
        x = self.activ3(self.conv3(x))
        x = torch.sigmoid(self.conv4(x))
        return x

## Reverse model version B

In [122]:
class ReverseNetB(nn.Module):
    def __init__(self):
        super(ReverseNetB, self).__init__()
        # in channels, out channels, kernel size
        self.conv0 = nn.Conv2d(1, 4, (1, 1))
        self.activ0 = nn.ReLU()
        self.conv1_5 = nn.Conv2d(4, 8, (5, 5), padding=(2, 2), padding_mode='circular')
        self.conv1_3 = nn.Conv2d(4, 8, (3, 3), padding=(1, 1), padding_mode='circular')
        self.activ1 = nn.PReLU()
        self.conv2_5 = nn.Conv2d(16, 4, (5, 5), padding=(2, 2), padding_mode='circular')
        self.conv2_3 = nn.Conv2d(16, 4, (3, 3), padding=(1, 1), padding_mode='circular')
        self.activ2 = nn.PReLU()
        self.conv3 = nn.Conv2d(8, 4, (3, 3), padding=(1, 1), padding_mode='circular')
        self.activ3 = nn.PReLU()
        self.conv4 = nn.Conv2d(4, 1, (3, 3), padding=(1, 1), padding_mode='circular')

    def forward(self, x):
        x = self.activ0(self.conv0(x))
        x = self.activ1(torch.cat((self.conv1_3(x), self.conv1_5(x)), 1))
        x = self.activ2(torch.cat((self.conv2_3(x), self.conv2_5(x)), 1))
        x = self.activ3(self.conv3(x))
        x = torch.sigmoid(self.conv4(x))
        return x

## Reverse model version C

In [123]:
class ReverseNetC(nn.Module):
    def __init__(self):
        super(ReverseNetC, self).__init__()
        # in channels, out channels, kernel size
        self.conv0 = nn.Conv2d(1, 8, (1, 1))
        self.activ0 = nn.ReLU()
        self.conv1 = nn.Conv2d(8, 16, (3, 3), padding=(1, 1), padding_mode='circular')
        self.activ1 = nn.PReLU()
        self.conv2 = nn.Conv2d(16, 8, (3, 3), padding=(1, 1), padding_mode='circular')
        self.activ2 = nn.PReLU()
        self.conv3 = nn.Conv2d(8, 4, (3, 3), padding=(1, 1), padding_mode='circular')
        self.activ3 = nn.PReLU()
        self.conv4 = nn.Conv2d(4, 1, (3, 3), padding=(1, 1), padding_mode='circular')

    def forward(self, x):
        x = self.activ0(self.conv0(x))
        x = self.activ1(self.conv1(x))
        x = self.activ2(self.conv2(x))
        x = self.activ3(self.conv3(x))
        x = torch.sigmoid(self.conv4(x))
        return x

## Forward-Reverse net

In [124]:
class ReverseForwardNet(nn.Module):
    def __init__(self, ForwardNet, ReverseNet):
        super(ReverseForwardNet, self).__init__()
        self.reverse_net = ReverseNet()
        # freeze the weights of the forward net
        self.forward_net = ForwardNet()
        
    def forward(self, x):
        x = self.reverse_net(x)
        x = self.forward_net(x)
        return x

In [125]:
MODEL_VERSION = 'C'

if MODEL_VERSION == 'A':
    ReverseNet = ReverseNetA
elif MODEL_VERSION == 'B':
    ReverseNet = ReverseNetB
elif MODEL_VERSION == 'C':
    ReverseNet = ReverseNetC

rf_net = ReverseForwardNet(RelaxedForwardNet, ReverseNet)

In [126]:
criterion = nn.BCELoss()

In [127]:
X_rf = Variable(torch.tensor(stop_boards).view(N, 1, 25, 25).float(), requires_grad=True)
y_rf = Variable(torch.tensor(stop_boards).view(N, 1, 25, 25).float())
X_valid_rf = y_valid
y_valid_rf = y_valid

rf_model_path = "../models/johnson/reverse_forward.pkl"

# Main training loop

In [128]:
# relaxed_forward_net.load_state_dict(torch.load(relaxed_forward_model_path))
# rf_net.load_state_dict(torch.load(rf_model_path))

In [129]:
X_relaxed = Variable(torch.tensor(relaxed_start_boards).view(N, 1, 25, 25).float(), requires_grad=True)
X_valid_relaxed = Variable(torch.tensor(relaxed_valid_start_boards).view(N_valid, 1, 25, 25).float())

In [137]:
NUM_EPOCHS = 5
for i in range(NUM_EPOCHS):
    # Evolved the relaxed boards
    y_relaxed = (actual_forward((X_relaxed > 0.5).int()) > 0.5).float()
    y_valid_relaxed = (actual_forward((X_valid_relaxed > 0.5).int()) > 0.5).float()
    
    # Load relaxed forward net
    relaxed_forward_net = RelaxedForwardNet()
    relaxed_forward_net.load_state_dict(torch.load(relaxed_forward_model_path))
    # Train relaxed forward net
    criterion = nn.BCELoss()
    train(relaxed_forward_net, X_relaxed, y_relaxed, X_valid_relaxed, y_valid_relaxed, 
          optim.Adam, criterion, relaxed_forward_model_path, batch_size=128, num_epochs=3)
    
    # Load reverse-forward net
    rf_net = ReverseForwardNet(RelaxedForwardNet, ReverseNet)
    rf_net.load_state_dict(torch.load(rf_model_path))
    # Load relaxed forward net weights and freeze them
    rf_net.forward_net.load_state_dict(torch.load(relaxed_forward_model_path.format(i)))
    for param in rf_net.forward_net.parameters():
        param.requires_grad = False
    rf_net.cuda()
        
    # Train reverse net
    criterion = nn.BCELoss()
    train(rf_net, X_rf, y_rf, X_valid_rf, y_valid_rf, 
          optim.Adam, criterion, rf_model_path, batch_size=128, num_epochs=3)
    
    # Create new X_relaxed and y_relaxed
    X_relaxed = rf_net.reverse_net.cpu()(y_relaxed)
    X_valid_relaxed = rf_net.reverse_net.cpu()(y_valid_relaxed)
    
    relaxed_forward_net.load_state_dict(torch.load(relaxed_forward_model_path.format(i)))

  0%|          | 0/400 [00:00<?, ?it/s]


RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

In [None]:
# Training data MAE
train_mae = get_forward_mae(rf_net, rf_model_path, stop_boards, stop_boards, N)
print("The training data MAE is {:.6f}.".format(train_mae))

# Test data MAE
test_mae = get_forward_mae(rf_net, rf_model_path, test_stop_boards, test_stop_boards, N_test)
print("The test data MAE is {:.6f}.".format(test_mae))