In [89]:
from collections import deque
from copy import deepcopy
from queue import Queue

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [90]:
# Grid dimensions: grids are indexed [x][y] with x in range(W) and y in range(H).
W=10
H=10
# NOTE(review): N and inf are not referenced in any of the visible cells —
# confirm they are still needed.
N=0
inf=1<<30

# Move table: deltaPos[orientation][move] = (dx, dy, new_orientation).
# Orientation 0 covers the single cell (x, y), orientation 1 covers (x, y) and
# (x+1, y), orientation 2 covers (x, y) and (x, y+1) (see State.occupied).
# The four entries of each row are ordered as noted on the next line.
deltaPos = [ # up, down, left, right
    [(-2, 0, 1), (1, 0, 1), (0, -2, 2), (0, 1, 2)],
    [(-1, 0, 0), (2, 0, 0), (0, -1, 1), (0, 1, 1)],
    [(-1, 0, 2), (1, 0, 2), (0, -1, 0), (0, 2, 0)]
]

In [91]:
class State:
    """Position and orientation of the piece on the grid.

    ``dir`` is the orientation: 0 covers the single cell ``(x, y)``,
    1 covers ``(x, y)`` and ``(x+1, y)``, 2 covers ``(x, y)`` and ``(x, y+1)``.

    Instances compare by value and are hashable, so they can be stored in
    sets or used as dict keys. Do not mutate a State while it is in a
    hash-based container.
    """

    def __init__(self, x, y, dir):
        self.x = x
        self.y = y
        self.dir = dir

    def move(self, dir: int):
        """Return the State reached by move ``dir`` (0=up, 1=down, 2=left, 3=right).

        The offset and the resulting orientation come from the module-level
        ``deltaPos`` table. No bounds or validity check is done here.
        """
        dx, dy, new_dir = deltaPos[self.dir][dir]
        return State(self.x + dx, self.y + dy, new_dir)

    def occupied(self):
        """Return the list of ``(x, y)`` cells covered in this orientation."""
        cells = [(self.x, self.y)]
        if self.dir == 1:
            cells.append((self.x + 1, self.y))
        elif self.dir == 2:
            cells.append((self.x, self.y + 1))
        return cells

    def to_map(self):
        """Return a W x H nested list with 1 on covered cells and 0 elsewhere."""
        grid = [[0 for _ in range(H)] for _ in range(W)]
        for x, y in self.occupied():
            grid[x][y] = 1
        return grid

    def copy(self):
        """Return an independent State with the same position and orientation."""
        return State(self.x, self.y, self.dir)

    def _key(self):
        # Single source of truth for equality and hashing.
        return (self.x, self.y, self.dir)

    def __eq__(self, other):
        # Returning NotImplemented lets Python fall back to its default
        # handling, so `state == None` is False instead of raising.
        if not isinstance(other, State):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None; restore hashability
        # in a way that is consistent with __eq__.
        return hash(self._key())
    

In [92]:
def decode(map):
    """Split a raw grid into the arguments expected by ``Board``.

    The first W x H cells are scanned row by row for the first ``-1``.
    That cell becomes the exit: it is rewritten to 1 in a copy of the grid
    and returned as an orientation-0 ``State``.

    Args:
        map: nested sequence indexable as ``map[i][j]`` for ``i < W``, ``j < H``.
            The input is not modified.

    Returns:
        ``(grid, exit_state)`` where ``grid`` is a NumPy array copy of ``map``
        with the exit cell set to 1.

    Raises:
        ValueError: if no cell equals ``-1``. Previously this fell through and
            returned ``None``, which made ``Board(*decode(m))`` fail with an
            unrelated unpacking error.
    """
    m = deepcopy(map)
    for i in range(W):
        for j in range(H):
            if m[i][j] == -1:
                m[i][j] = 1
                return (np.array(m), State(i, j, 0))
    raise ValueError("map contains no exit cell (-1)")

def bound(x, y):
    """Tell whether cell ``(x, y)`` lies inside the W x H grid."""
    return 0 <= x < W and 0 <= y < H

In [93]:
class Board:
    """A grid together with the BFS distance from the exit to every piece state."""

    def __init__(self, map: np.ndarray, exit: State):
        """Copy the grid and the exit, then fill the distance table.

        Args:
            map: grid indexable as ``map[x][y]``; a value of 0 marks a cell the
                piece may not cover.
            exit: the State from which distances are measured.
        """
        self.map = deepcopy(map)
        self.exit = exit.copy()
        # dist[x][y][orientation] = number of moves from the exit state;
        # -1 means the state was never reached by the search.
        self.dist = np.full((W,H,3), -1, dtype=int)
        self.bfs(exit)

    def valid(self, state: State):
        """Return True when every covered cell is in bounds and non-zero on the grid."""
        for x, y in state.occupied():
            if not bound(x, y) or self.map[x][y] == 0: return False
        return True

    def _reached(self, st: State):
        # Callers must check valid() first so the indices are in range.
        return self.dist[st.x][st.y][st.dir] != -1

    def bfs(self, s: State):
        """Breadth-first search from ``s`` that fills ``self.dist``.

        States are marked when they are enqueued, so each valid state enters
        the queue at most once. A plain deque is used because the search is
        single-threaded and needs no synchronised queue.
        """
        if not self.valid(s) or self._reached(s):
            return
        self.dist[s.x][s.y][s.dir] = 0
        pending = deque([s])
        while pending:
            st = pending.popleft()
            next_d = self.dist[st.x][st.y][st.dir] + 1
            for dir in range(4):
                nxt = st.move(dir)
                # valid() must come first: it guards the dist indexing.
                if self.valid(nxt) and not self._reached(nxt):
                    self.dist[nxt.x][nxt.y][nxt.dir] = next_d
                    pending.append(nxt)

    def get_adj(self, state: State):
        """Score the four moves available from ``state``.

        Returns a list with one entry per move (same order as ``deltaPos``):
        -1 if the move leads to an invalid state, otherwise
        ``max(0, dist(state) - dist(neighbour))``. The score is positive only
        when the neighbour has a smaller distance value than ``state``.
        """
        ret = []
        dc = self.dist[state.x][state.y][state.dir]
        for dir in range(4):
            st = state.move(dir)
            if self.valid(st): ret.append(max(0, dc-self.dist[st.x][st.y][st.dir]))
            else: ret.append(-1)
        return ret


In [94]:
class Model(nn.Module):
    """CNN that maps a 3-channel W x H input to four output values.

    Two 3x3 convolutions (padding 1, so the spatial size is kept) are followed
    by three fully connected layers. ReLU follows every layer except the last.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(in_features=32*W*H, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=64)
        self.fc3 = nn.Linear(in_features=64, out_features=4)

    def forward(self, x):
        """Run the network; the result has one row of 4 values per input sample."""
        relu = nn.functional.relu
        features = relu(self.conv2(relu(self.conv1(x))))
        # Flatten the 32 feature maps into one vector per sample.
        flat = features.view(-1, 32*W*H)
        hidden = relu(self.fc2(relu(self.fc1(flat))))
        return self.fc3(hidden)

In [95]:
class Train:
    """Owns the model, optimizer and loss; trains or scores the model one board at a time."""

    def __init__(self):
        # Keep the device on the instance so tensors and checkpoints always
        # end up on the same device as the model.
        self.device = device
        self.model = Model().to(self.device)
        self.board = None  # Board currently being trained / scored on
        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-4)
        self.loss = nn.MSELoss()

    def load(self, path):
        """Load model weights from ``path``."""
        # map_location lets a checkpoint written on a GPU be restored on a
        # CPU-only machine (and vice versa).
        self.model.load_state_dict(torch.load(path, map_location=self.device))

    def save(self, path):
        """Write the model weights to ``path``."""
        torch.save(self.model.state_dict(), path)

    def train_all(self, maps, predict=False):
        """Process every grid in ``maps`` and print one result line per grid.

        Each grid is decoded into a fresh Board. The printed value is whatever
        ``train_cur`` returns: mean loss when training, mean score when
        ``predict`` is True.
        """
        for i, grid in enumerate(maps):
            self.board = Board(*decode(grid))
            print(f'{i}/{len(maps)}: {self.train_cur(predict=predict)}')

    def train_cur(self, predict=False):
        """Run ``train`` on every valid state of the current board.

        Returns the mean of the per-state results, or 0.0 when the board has
        no valid state at all.
        """
        count, total = 0, 0
        for x in range(W):
            for y in range(H):
                for orientation in range(3):
                    st = State(x, y, orientation)
                    if not self.board.valid(st): continue
                    count += 1
                    total += self.train(st, predict=predict)
        return total / count if count else 0.0

    def train(self, state: "State", predict=False):
        """Train on, or score, a single state of the current board.

        The network input stacks three planes: the board grid, the cells
        covered by ``state`` and the cells covered by the exit. The target is
        ``Board.get_adj(state)``.

        Returns:
            When ``predict`` is False: the MSE loss (float) after one
            optimizer step.
            When ``predict`` is True: 1 if the highest-scoring move has target
            1, or if no move is valid (all targets are -1) and the highest
            score is negative; otherwise 0.
        """
        # Build the input through NumPy first: creating a tensor from a list
        # that mixes ndarrays and nested lists is slow.
        planes = np.array(
            [self.board.map, state.to_map(), self.board.exit.to_map()],
            dtype=np.float32,
        )
        inp = torch.tensor(planes, device=self.device).unsqueeze(0)
        ans = torch.tensor(self.board.get_adj(state), dtype=torch.float32, device=self.device)

        if not predict:
            out = self.model(inp)
            # Give the target the same (1, 4) shape as the output so MSELoss
            # does not have to broadcast (and warn).
            loss = self.loss(out, ans.unsqueeze(0))
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            return loss.item()

        # Scoring needs no gradients.
        with torch.no_grad():
            out = self.model(inp)
        # .cpu() is required before .numpy() when the model runs on a GPU.
        scores = out.squeeze(0).cpu().numpy()
        best = int(np.argmax(scores))
        # Element-wise test; comparing the tensor to a Python list never
        # matched, so this case was never counted as correct before.
        no_valid_move = bool((ans == -1).all())
        if scores[best] < 0 and no_valid_move: return 1
        if ans[best].item() == 1: return 1
        return 0

In [96]:
class Parser:
    """Reads integer grids from a text file in which grids are separated by blank lines."""

    def __init__(self, path):
        self.path=path

    def parse(self):
        """Return the list of grids stored in the file.

        Each grid is a list of rows and each row a list of ints parsed from
        whitespace-separated tokens. Empty chunks and blank lines are skipped.
        The last grid is therefore kept whether or not the file ends with a
        blank line, and stray blank lines never produce empty rows.

        Raises:
            OSError: if the file cannot be opened.
            ValueError: if a token is not an integer.
        """
        with open(self.path, 'r') as f:
            text = f.read()
        maps = []
        for chunk in text.split("\n\n"):
            rows = [
                [int(token) for token in line.split()]
                for line in chunk.split("\n")
                if line.strip()
            ]
            if rows:
                maps.append(rows)
        return maps

In [97]:
# Parse every grid stored in gen4.txt (path is relative to the working directory).
maps = Parser('gen4.txt').parse()

In [98]:
# Build the trainer; this creates a newly initialised Model and its optimizer.
train = Train()

In [99]:
# Training pass over every map. Left commented out; a later cell loads
# weights from model.pt instead.
# train.train_all(maps)

In [100]:
# Load weights from model.pt (the file must exist in the working directory).
train.load('model.pt')

In [101]:
# predict=True: no optimizer step is taken. For each of the first 10 maps this
# prints the mean of the 0/1 scores returned by Train.train over all valid states.
train.train_all(maps[:10], predict=True)

0/10: 0.6810344827586207
1/10: 0.5853658536585366
2/10: 0.6702127659574468
3/10: 0.5555555555555556
4/10: 0.6507936507936508
5/10: 0.759090909090909
6/10: 0.7040816326530612
7/10: 0.6234567901234568
8/10: 0.7130434782608696
9/10: 0.7033898305084746


In [102]:
# NOTE(review): exact repeat of the previous cell, and the printed output is
# identical — one of the two cells can probably be removed.
train.train_all(maps[:10], predict=True)

0/10: 0.6810344827586207
1/10: 0.5853658536585366
2/10: 0.6702127659574468
3/10: 0.5555555555555556
4/10: 0.6507936507936508
5/10: 0.759090909090909
6/10: 0.7040816326530612
7/10: 0.6234567901234568
8/10: 0.7130434782608696
9/10: 0.7033898305084746


In [103]:
# Write the current weights to model.pt. Left commented out, so running the
# notebook does not overwrite the existing file.
# train.save('model.pt')