In [None]:
import socket
import torch
import torch.nn as nn
import numpy as np

In [None]:
#env param
# Address of the Unity environment server that Player connects to.
ENV_IP = "127.0.0.1"
ENV_PORT = 5000
# Side length of the play area; observations are (AREA_SIZE + 2) x (AREA_SIZE + 2)
# grids (presumably the area plus a one-cell border on each side — TODO confirm).
AREA_SIZE = 15
# Number of distinct cell values. Player divides each cell by NUM_STATE - 1,
# which maps values 0..NUM_STATE-1 onto [0, 1] (assumes cells use that range).
NUM_STATE = 5

# Fall back to CPU so the notebook still runs on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
#Player
class Player():
    """TCP client for the Unity environment server.

    Connects to ``(ENV_IP, ENV_PORT)`` on construction. Each call to
    :meth:`step` sends one command string, blocks until a full reply has
    arrived, and parses it into ``(state, reward, playing)``.
    """

    # Characters that follow the grid in a reply: the reward text, one more
    # character, then the trailing "T"/"F" flag. 8 reproduces the original
    # fixed reply length of 297 (= 17 ** 2 + 8) for AREA_SIZE == 15.
    # NOTE(review): assumes the server pads the reward to a fixed width — confirm.
    _TRAILER_LEN = 8

    def __init__(self):
        self.client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.client.connect((ENV_IP, ENV_PORT))
        print("Unity server connected")

    def close(self):
        """Close the socket if it is still open; safe to call more than once."""
        # getattr guards against a partially constructed instance (e.g. if
        # socket creation raised before ``self.client`` was assigned).
        client = getattr(self, "client", None)
        if client is not None:
            print("Unity server disconnecting")
            client.close()
            self.client = None

    def __del__(self):
        self.close()

    def step(self, cmd):
        """Send ``cmd`` to the server and return the parsed reply.

        Args:
            cmd: command string, sent encoded with ``str.encode()``.

        Returns:
            ``(state, reward, playing)``:
            ``state`` is a float32 array of shape
            ``(1, AREA_SIZE + 2, AREA_SIZE + 2)`` built from one digit per
            cell and divided by ``NUM_STATE - 1``.
            ``reward`` is the float parsed from the text between the grid and
            the last two characters.
            ``playing`` is True when the reply's last character is "T".

        Raises:
            ConnectionError: if the server closes the connection before a
                complete reply has been received.
        """
        side = AREA_SIZE + 2
        grid_len = side ** 2
        msg_len = grid_len + self._TRAILER_LEN

        self.client.sendall(cmd.encode())
        # Accumulate raw bytes and decode once, so a multi-byte character can
        # never be split across recv() chunks.
        buf = bytearray()
        while len(buf) < msg_len:
            chunk = self.client.recv(512)
            if not chunk:
                # recv() returns b"" once the peer has closed; without this
                # check the loop would spin forever.
                raise ConnectionError(
                    "Unity server closed the connection mid-message "
                    f"({len(buf)}/{msg_len} bytes received)"
                )
            buf += chunk
        self.result = buf.decode()

        # Each grid character is a single digit; reshape in C (row-major) order.
        self.state = np.array([int(c) for c in self.result[:grid_len]], dtype=np.float32)
        self.state = self.state.reshape(1, side, side)
        self.state /= NUM_STATE - 1
        self.reward = float(self.result[grid_len:-2])
        self.playing = self.result[-1] == "T"
        return (self.state, self.reward, self.playing)

In [None]:
#Dueling Network
class Dueling_N(nn.Module):
    """Dueling Q-network over a single-channel square grid observation.

    Three conv layers feed two fully connected heads: a state-value head
    ``val_fc`` (1 output) and an advantage head ``adv_fc`` (``num_actions``
    outputs). They are combined as ``Q = V + A - mean(A)``.

    Args:
        input_size: side length of the square input grid. The default 17
            reproduces the original hard-coded 4608 flattened features.
        num_actions: number of Q-values produced (default 4).

    Submodule names (``conv1``..``conv3``, ``val_fc``, ``adv_fc``) are
    unchanged, so existing state_dict checkpoints still load.
    """

    def __init__(self, input_size=17, num_actions=4):
        print("Making Network")
        super().__init__()
        self.ReLU = nn.ReLU()
        self.flatten = nn.Flatten()
        self.conv1 = nn.Conv2d(in_channels=1,  out_channels=32,  kernel_size=3, stride=1, padding=1)  # conv layer 1
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64,  kernel_size=4, stride=2, padding=1)  # conv layer 2
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=1, padding=1)  # conv layer 3

        # Derive the flattened feature count from the conv geometry instead of
        # hard-coding it. Per layer: out = (in + 2*pad - kernel) // stride + 1
        # (dilation is 1). For input_size=17: 17 -> 17 -> 8 -> 6, so 128*6*6 = 4608.
        side = input_size
        for conv in (self.conv1, self.conv2, self.conv3):
            side = (side + 2 * conv.padding[0] - conv.kernel_size[0]) // conv.stride[0] + 1
        if side < 1:
            raise ValueError(f"input_size={input_size} is too small for the conv stack")
        flat_features = self.conv3.out_channels * side * side

        self.val_fc = nn.Sequential(nn.Linear(flat_features, 512), nn.ReLU(), nn.Linear(512, 1))            # state-value head V
        self.adv_fc = nn.Sequential(nn.Linear(flat_features, 512), nn.ReLU(), nn.Linear(512, num_actions))  # advantage head A

    def forward(self, x):
        """Return Q-values of shape ``(batch, num_actions)``.

        Assumes ``x`` is a float tensor of shape
        ``(batch, 1, input_size, input_size)`` — TODO confirm for other callers.
        """
        x = self.ReLU(self.conv1(x))
        x = self.ReLU(self.conv2(x))
        x = self.ReLU(self.conv3(x))
        x = self.flatten(x)
        x_val = self.val_fc(x)
        x_adv = self.adv_fc(x)
        # Dueling aggregation Q = V + A - mean(A). The mean is detached, so no
        # gradient flows through the centering term. Forward values are the
        # same as the usual formulation, but training gradients differ.
        # NOTE(review): kept as-is to preserve training behaviour — confirm the
        # detach is intentional.
        q = x_val + x_adv - x_adv.mean(dim=1, keepdim=True).detach()
        return q

In [None]:
#eval
# Load trained weights and run one greedy episode against the Unity server.
# map_location lets a checkpoint saved on GPU load onto whatever `device` is.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (newer torch versions accept weights_only=True).
target_net = Dueling_N().to(device)
target_net.load_state_dict(torch.load("20000.pth", map_location=device))
target_net.eval()

player = Player()
# Network output index -> command string sent to the server.
cmds = ["L", "R", "U", "D"]
try:
    # First command is "A" (presumably starts/resets the episode — TODO confirm).
    state, reward, playing = player.step("A")
    # Inference only: no_grad avoids building an autograd graph every step.
    with torch.no_grad():
        while playing:
            obs = torch.from_numpy(state.reshape(1, 1, AREA_SIZE + 2, AREA_SIZE + 2)).to(device)
            q_values = target_net(obs)
            action = q_values.argmax().item()
            state, reward, playing = player.step(cmds[action])
            print(q_values, reward)
finally:
    # Drop the reference so Player.__del__ can close the socket even if the loop raised.
    del player