In [None]:
import random
import socket
import math
import torch
import torch.nn as nn
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm
from collections import deque

In [None]:
#env param
ENV_IP = "127.0.0.1"  # address Player connects to
ENV_PORT = 5000       # TCP port Player connects to
AREA_SIZE = 15        # the observed grid is (AREA_SIZE + 2) cells per side
NUM_STATE = 5         # cell values are scaled by dividing by NUM_STATE - 1

#learning param
NUM_EPISODES = 20000
BATCH_SIZE = 256      # transitions sampled per optimisation step
MEMORY_SIZE = 100000  # replay memory capacity
SYNC_INTERVAL = 500   # episodes between target-network syncs
SAVE_INTERVAL = 1000  # episodes between weight checkpoints
TAU = 0.5             # share of the policy weights blended into the target on each sync
GAMMA = 0.99          # discount applied to the next-state Q value

#e-greedy param
EPS_START = 0.9
EPS_END = 0.1
EPS_DECAY = 5000      # episode scale of the exponential decay from EPS_START to EPS_END

# Use the first GPU when one is available; otherwise fall back to the CPU so
# the notebook still runs on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
#ReplayMemory
class ReplayMemory:
    """Bounded FIFO buffer of transitions with uniform random sampling."""

    def __init__(self, size):
        # A deque with maxlen drops its oldest entry once `size` items are held.
        self.memory = deque([], maxlen=size)

    def push(self, transition):
        """Store one transition at the newest end of the buffer."""
        self.memory.append(transition)

    def sample(self, size):
        """Return a list of `size` stored transitions drawn without replacement."""
        drawn = random.sample(self.memory, k=size)
        return drawn

    def __len__(self):
        """Number of transitions currently held."""
        stored = len(self.memory)
        return stored

In [None]:
#Player
class Player():
    """TCP client that sends commands to the Unity server and parses its replies."""

    # Minimum number of characters in one complete reply: the grid digits
    # followed by the reward text and the trailing "T"/"F" playing flag.
    RESPONSE_LENGTH = 297

    def __init__(self):
        self.client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.client.connect((ENV_IP, ENV_PORT))
        print("Unity server connected")

    def __del__(self):
        print("Unity server disconnecting")
        # `client` is missing when socket creation itself failed in __init__.
        client = getattr(self, "client", None)
        if client is not None:
            client.close()

    def step(self, cmd):
        """Send one command and return the parsed reply.

        Args:
            cmd: command string sent to the server as-is.

        Returns:
            (state, reward, playing): `state` is a float32 array of shape
            (1, AREA_SIZE + 2, AREA_SIZE + 2) with each cell divided by
            NUM_STATE - 1, `reward` is a float, and `playing` is True when
            the reply ends with "T".

        Raises:
            ConnectionError: the server closed the connection before a
                complete reply arrived.
        """
        self.client.sendall(cmd.encode())
        self.result = ""
        while len(self.result) < self.RESPONSE_LENGTH:
            chunk = self.client.recv(512)
            if not chunk:
                # recv() returns b"" once the peer has closed the connection;
                # without this check the loop would spin forever.
                raise ConnectionError("Unity server closed the connection before a full reply was received")
            self.result += chunk.decode()
        grid_cells = (AREA_SIZE + 2) ** 2
        # One digit per grid cell, in row order.
        self.state = np.array(list(map(int, self.result[:grid_cells])), dtype=np.float32)
        self.state = self.state.reshape(1, AREA_SIZE + 2, AREA_SIZE + 2)
        self.state /= NUM_STATE - 1
        # Everything between the grid and the last two characters is the reward text.
        self.reward = float(self.result[grid_cells:-2])
        self.playing = self.result[-1] == "T"
        return (self.state, self.reward, self.playing)

In [None]:
#Dueling Network
class Dueling_N(nn.Module):
    """Dueling Q-network: three conv layers feeding a value and an advantage stream.

    The two fully connected streams expect 4608 flattened features, which is
    what the conv stack produces for a 1x17x17 input
    (17x17 -> conv1 -> 17x17 -> conv2 -> 8x8 -> conv3 -> 6x6, 128 * 6 * 6 = 4608).
    """

    def __init__(self):
        print("Making Network")
        super().__init__()
        self.ReLU = nn.ReLU()
        self.flatten = nn.Flatten()
        self.conv1 = nn.Conv2d(in_channels=1,  out_channels=32,  kernel_size=3, stride=1, padding=1) # convolution layer 1
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64,  kernel_size=4, stride=2, padding=1) # convolution layer 2
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=1, padding=1) # convolution layer 3
        self.val_fc = nn.Sequential(nn.Linear(4608, 512), nn.ReLU(), nn.Linear(512, 1)) # state-value stream V(s)
        self.adv_fc = nn.Sequential(nn.Linear(4608, 512), nn.ReLU(), nn.Linear(512, 4)) # advantage stream A(s, a), one output per action

    def forward(self, x):
        """Return Q values of shape (batch, 4) for input `x` of shape (batch, 1, H, W)."""
        x = self.ReLU(self.conv1(x))
        x = self.ReLU(self.conv2(x))
        x = self.ReLU(self.conv3(x))
        x = self.flatten(x)
        x_val = self.val_fc(x)
        x_adv = self.adv_fc(x)
        # Q = V + (A - mean(A)). The mean stays in the autograd graph so the
        # gradient matches the forward function: Q does not change when every
        # advantage is shifted by the same amount, so that shift must receive
        # zero gradient. Detaching the mean would break this.
        q = x_val + x_adv - x_adv.mean(dim=1, keepdim=True)
        return q #output Q value

In [None]:
#Train
class Train():
    """Dueling-DQN trainer that plays the Unity game through `Player`.

    Holds a policy network (optimised once per episode), a target network
    (blended towards the policy network every SYNC_INTERVAL episodes), a
    replay memory and a TensorBoard writer.
    """

    def __init__(self):
        self.target_net = Dueling_N().to(device) #Target Network
        self.policy_net = Dueling_N().to(device) #Policy Network
        # The two networks are initialised independently at random. Start the
        # target as an exact copy of the policy so the TAU blend in
        # network_sync() never averages the weights of two unrelated networks.
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.optimizer = torch.optim.AdamW(self.policy_net.parameters(), lr=1e-4, weight_decay=1e-3, amsgrad=True)
        self.writer = SummaryWriter("./")
        self.memory = ReplayMemory(MEMORY_SIZE) #Replay Memory
        self.cmds = ["L", "R", "U", "D"]
        self.game_step = 0
        print("Ready for training")

    def train_model(self, epoch):
        """Run one optimisation step on a random replay batch and log the loss.

        Does nothing until the replay memory holds at least BATCH_SIZE
        transitions. `epoch` is only used as the x value of the logged loss.
        """
        if len(self.memory) < BATCH_SIZE:
            return
        batch = self.memory.sample(BATCH_SIZE)
        state_batch, action_batch, next_state_batch, reward_batch, playing_batch = zip(*batch)
        state_batch = torch.from_numpy(np.array(state_batch)).to(device)
        action_batch = torch.tensor(action_batch).unsqueeze(1).to(device)
        next_state_batch = torch.from_numpy(np.array(next_state_batch)).to(device)
        reward_batch = torch.tensor(reward_batch, dtype=torch.float32).unsqueeze(1).to(device)
        # 1.0 while the game continued after the transition, 0.0 when it ended.
        playing_batch = torch.tensor(playing_batch, dtype=torch.float32).unsqueeze(1).to(device)
        state_q_values = self.policy_net(state_batch).gather(1, action_batch)
        with torch.no_grad():
            next_state_q_values = self.target_net(next_state_batch).max(1)[0].unsqueeze(1)
        # No future return exists after the final step of an episode, so the
        # bootstrap term is masked out there.
        # Assumes `playing == False` marks a terminal state - confirm with the environment.
        target_q_values = reward_batch + GAMMA * next_state_q_values * playing_batch
        loss = nn.SmoothL1Loss()(state_q_values, target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.writer.add_scalar("Loss", loss.item(), epoch)

    def network_sync(self):
        """Blend the policy weights into the target network: target = TAU * policy + (1 - TAU) * target."""
        target_net_dict = self.target_net.state_dict()
        policy_net_dict = self.policy_net.state_dict()
        for key in policy_net_dict:
            target_net_dict[key] = policy_net_dict[key] * TAU + target_net_dict[key] * (1 - TAU)
        # Load once, after every entry has been blended.
        self.target_net.load_state_dict(target_net_dict)

    def save_weight(self, fname):
        """Save the target network weights to "<fname>.pth"."""
        torch.save(self.target_net.state_dict(), str(fname) + ".pth")

    def e_greedy(self, state, i):
        """Pick an action index for `state` with an epsilon-greedy rule.

        Epsilon decays exponentially with the episode index `i` from
        EPS_START to EPS_END. On top of that it grows linearly with the
        current episode's step count, reaching +0.5 at 1000 steps and staying
        there, so the sum can exceed 1 and force a random action.
        """
        eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp((-1 * i) / EPS_DECAY)
        eps_threshold += min(self.game_step, 1000) / 2000
        if random.random() > eps_threshold: #greedy
            with torch.no_grad():
                q_values = self.policy_net(torch.from_numpy(state).clone().unsqueeze(0).to(device))
                return q_values.argmax().item()
        else: #random
            return random.randrange(len(self.cmds))

    def train(self, num):
        """Play `num + 1` episodes, training once after each one.

        The target network is synced every SYNC_INTERVAL episodes and its
        weights are saved every SAVE_INTERVAL episodes (both including
        episode 0).
        """
        self.player = Player()
        try:
            for i in tqdm(range(num + 1)):
                self.game_step = 0
                # "A" is sent before every episode; presumably it starts a new game - confirm.
                past_state, _, _ = self.player.step("A")
                while True:
                    self.game_step += 1
                    cmd = self.e_greedy(past_state, i)
                    state, reward, playing = self.player.step(self.cmds[cmd])
                    # Small bonus that grows with the number of steps played.
                    reward += self.game_step / 1000
                    self.memory.push((past_state, cmd, state, reward, playing))
                    past_state = state
                    if not playing:
                        break
                self.train_model(i)
                if i % SYNC_INTERVAL == 0:
                    print("Target Net Sync")
                    self.network_sync()
                if i % SAVE_INTERVAL == 0:
                    print("Save Weight")
                    self.save_weight(i)
        finally:
            # Runs on errors and interrupts too, so logged losses reach disk
            # and the Player (and its socket) is released.
            self.writer.flush()
            del self.player

In [None]:
# Use the episode count from the configuration cell instead of repeating the literal.
train = Train()
train.train(NUM_EPISODES)

In [None]:
#Visualize Network
from torchinfo import summary

# Summarise the network for a batch of one single-channel grid. The side
# length comes from the configuration (AREA_SIZE + 2 = 17) instead of a
# repeated literal.
model = Dueling_N()
summary(model, input_size=(1, 1, AREA_SIZE + 2, AREA_SIZE + 2))