In [1]:
import torch.nn as nn
import torch.nn.functional as F
import gym_pikachu_volleyball
import gym
import time
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torch
import torch.optim as optim
from itertools import count
from collections import deque, namedtuple
import random, math

In [2]:
class Memory :
    def __init__(self, capacity=10000) :
        self.memory = deque([], maxlen=capacity)
    def push(self, trans) :
        self.memory.append(trans)
    def sample(self, num) :
        return random.sample(self.memory, num)
    def __len__(self) :
        return len(self.memory)

In [3]:
class Q_net(nn.Module) :
    def __init__(self):
        super(Q_net, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1)
        self.fc_actor = nn.Linear(1120, 256)
        self.out_actor = nn.Linear(256, 18)

    def forward(self, x):
        x = nn.functional.relu(self.conv1(x))
        x = nn.functional.max_pool2d(x, kernel_size=2)
        x = nn.functional.relu(self.conv2(x))
        x = nn.functional.max_pool2d(x, kernel_size=2)
        x = nn.functional.relu(self.conv3(x))
        x = nn.functional.max_pool2d(x, kernel_size=2)
        x = x.view(-1, 1120)
        actor_out = nn.functional.relu(self.fc_actor(x))
        actor_out = self.out_actor(actor_out)
        return actor_out

In [4]:
class Agent :
    def __init__(self, device, Transition, BATCH=32, GAMMA=0.99, TAU=0.005, LR=1e-4,
                EPS_START=0.9, EPS_END=0.05, EPS_DECAY=1000, capacity=100000) :
        self.device = device
        self.Transition = Transition
        self.BATCH = BATCH
        self.GAMMA = GAMMA
        self.TAU = TAU
        self.LR = LR
        self.EPS_START = EPS_START
        self.EPS_END = EPS_END
        self.EPS_DECAY = EPS_DECAY
        self.memory = Memory(capacity)
        self.n_observations = 4
        self.n_actions = 2
        self.time_step = 0
        self.policy_net = Q_net().to(device)
        #self.policy_net.apply(self.init_weights)
        self.target_net = Q_net().to(device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.optimizer = optim.AdamW(self.policy_net.parameters(), lr=LR, amsgrad=True)
    def init_weights(self, m):
        if type(m) == nn.Linear:
            torch.nn.init.xavier_normal_(m.weight)
            m.bias.data.fill_(0)
    def select_action(self, state) :
        self.time_step += 1
        sample = random.random()
        eps_threshold = self.EPS_END + (self.EPS_START - self.EPS_END) * math.exp(-1 * self.time_step / self.EPS_DECAY)
        if sample > eps_threshold :
            with torch.no_grad() :
                return self.policy_net(state).max(1)[1].view(1, 1)
        else :
            return torch.tensor([[random.randrange(0,2)]], device=self.device, dtype=torch.long)
    def optimize(self) :
        if len(self.memory) < self.BATCH * 100:
            return
        transitions = self.memory.sample(self.BATCH)
        batch = self.Transition(*zip(*transitions))
        non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), device=self.device, dtype=torch.bool)
        non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)
        #print(state_batch.shape)
        state_action_values = self.policy_net(state_batch).gather(1, action_batch)
        next_state_values = torch.zeros(self.BATCH, device=self.device)
        with torch.no_grad() :
            next_state_values[non_final_mask] = self.target_net(non_final_next_states).max(1)[0]
        expected_state_action_values = (next_state_values * self.GAMMA) + reward_batch
        criterion = nn.SmoothL1Loss()
        loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_value_(self.policy_net.parameters(), 100)
        self.optimizer.step()
    def update(self) :
        target_net_state_dict = self.target_net.state_dict()
        policy_net_state_dict = self.policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key]*self.TAU + target_net_state_dict[key]*(1-self.TAU)
        self.target_net.load_state_dict(target_net_state_dict)

In [6]:
def convert_color(s) :
    nonblack = (s != [0, 0, 0]).any(axis=2)
    s[nonblack] = [255, 255, 255]
    result_img = Image.fromarray(s)
    result_img = np.array(result_img)
    result_img = 255 - result_img
    result_img = Image.fromarray(result_img)
    result_img = transform(result_img)
    #result_img = result_img / 255.0
    return result_img

In [None]:
transform = transforms.ToTensor()
env = gym.make("PikachuVolleyball-v0", render_mode = None)
option={'is_player1_serve' : True, 'is_player2_serve' : True}
device = torch.device('cpu')

Transition = namedtuple('Transition', 
                       ('state', 'action', 'next_state', 'reward'))
dqn_agent_1 = Agent(device, Transition)
dqn_agent_2 = Agent(device, Transition)

num_episode = 1000
scores = []

for i_episode in range(num_episode) :
    state = env.reset(options=option)
    state = convert_color(state).unsqueeze(0)
    i_episode_score = 0
    for t in count() :
        i_episode_score += 1
        action_1 = dqn_agent_1.select_action(state)
        action_2 = dqn_agent_2.select_action(state)
        observation, reward, terminated, _ = env.step([action_1.item(), action_2.item()])
        reward = torch.tensor([reward], device=device)
        reward = torch.tensor([1], device=device)
        done = terminated
        if terminated:
            next_state = None
        else:
            next_state =  convert_color(observation).unsqueeze(0)
        dqn_agent_1.memory.push(Transition(state, action_1, next_state, reward))
        dqn_agent_2.memory.push(Transition(state, action_2, next_state, reward))
        state = next_state
        dqn_agent_1.optimize()
        dqn_agent_1.update()
        dqn_agent_2.optimize()
        dqn_agent_2.update()

        if done :
            print(str(i_episode+1) + "번째 episode : " + str(i_episode_score))
            scores.append(i_episode_score)
            break

1번째 episode : 22
2번째 episode : 22
3번째 episode : 22
4번째 episode : 22
5번째 episode : 22
6번째 episode : 22
7번째 episode : 22
8번째 episode : 22
9번째 episode : 22
10번째 episode : 22
11번째 episode : 22
12번째 episode : 22
13번째 episode : 22
14번째 episode : 22
15번째 episode : 22
16번째 episode : 22
17번째 episode : 22
18번째 episode : 22
19번째 episode : 22
20번째 episode : 22
21번째 episode : 22
22번째 episode : 22
23번째 episode : 22
24번째 episode : 22
25번째 episode : 22
26번째 episode : 22
27번째 episode : 22
28번째 episode : 22
29번째 episode : 22
30번째 episode : 22
31번째 episode : 22
32번째 episode : 22
33번째 episode : 22
34번째 episode : 22
35번째 episode : 22
36번째 episode : 22
37번째 episode : 22
38번째 episode : 22
39번째 episode : 22
40번째 episode : 22
41번째 episode : 22
42번째 episode : 22
43번째 episode : 22
44번째 episode : 22
45번째 episode : 22
46번째 episode : 22
47번째 episode : 22
48번째 episode : 22
49번째 episode : 22
50번째 episode : 22
51번째 episode : 22
52번째 episode : 22
53번째 episode : 22
54번째 episode : 22
55번째 episode : 22
56번째 episode : 22
5