In [103]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
from collections import namedtuple, deque
from matplotlib import pyplot as plt
from pettingzoo.atari import space_invaders_v2
from pettingzoo.sisl import pursuit_v4
import gym
from gym.wrappers import RecordVideo

# Testing the parameter-sharing agent (a single network shared by all pursuers) with free evaders

In [104]:
class DQN(nn.Module):
    """Q-network for the parameter-sharing agent: MLP with two 64-unit hidden layers.

    Maps a batch of flattened observations of size ``n_states`` to one Q-value
    per discrete action (``n_actions`` outputs). The attribute names
    ``layer1``/``layer2``/``out`` double as the ``state_dict`` keys that the
    saved checkpoint is loaded into, so they must not be renamed.
    """

    def __init__(self, n_states, n_actions):
        super().__init__()
        width = 64
        # Creation order is significant: it fixes parameter init order under a seeded RNG.
        self.layer1 = nn.Linear(n_states, width)
        self.layer2 = nn.Linear(width, width)
        self.out = nn.Linear(width, n_actions)

    def forward(self, x):
        """Return raw (un-normalized) Q-values for each row of ``x``."""
        for hidden in (self.layer1, self.layer2):
            x = F.relu(hidden(x))
        return self.out(x)

In [105]:
# Pursuit environment with free-moving evaders (freeze_evaders=False). Frames are
# rendered as RGB arrays so the rollout below can collect them into a GIF.
# NOTE(review): gym's RecordVideo (imported at the top) wraps gym envs, not
# PettingZoo AEC envs, so frames are captured manually instead.
env = pursuit_v4.env(
    render_mode='rgb_array',
    max_cycles=500,
    x_size=16,
    y_size=16,
    n_pursuers=6,
    n_evaders=25,
    freeze_evaders=False,
    obs_range=7,
    n_catch=2,
    surround=True,
    shared_reward=True,
    tag_reward=0.01,
    catch_reward=5.0,
    urgency_reward=0,
    constraint_window=1.0,
)

In [106]:
# Load the parameter-sharing policy: every pursuer queries this single DQN.
# Input size 3*7*7 presumably matches a flattened obs_range=7 observation with
# 3 channels (TODO confirm); 5 is the number of discrete actions.
policy_net = DQN(3*7*7, 5)
# map_location="cpu": inference below runs on CPU tensors, so the checkpoint must
# load on a CPU-only machine even if it was saved from GPU parameters.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
policy_net.load_state_dict(torch.load("DQN_combined_500", map_location="cpu"))
policy_net.eval()

env.reset()
frames = []  # RGB frame after every agent step; turned into a GIF in a later cell
with torch.no_grad():  # pure inference, so no autograd graph is needed
    for agent in env.agent_iter():
        observation, _, termination, truncation, _ = env.last()
        if termination or truncation:
            # PettingZoo's AEC API expects a None action for agents that are done.
            action = None
        else:
            obs = torch.tensor(observation, dtype=torch.float32).reshape(1, -1)  # batch of one
            action = policy_net(obs).argmax(dim=1).item()  # greedy action
        env.step(action)
        frames.append(env.render())
env.close()

In [107]:
from PIL import Image

In [109]:
# Convert captured RGB arrays to PIL images under a new name, so re-running this
# cell does not feed PIL images back into Image.fromarray.
gif_frames = [Image.fromarray(frame) for frame in frames]
if not gif_frames:
    raise ValueError("no frames captured; run the rollout cell first")
# Pillow's `duration` is milliseconds per frame. GIF delays have 10 ms resolution,
# so sub-10 ms values are stored as 0 and viewers substitute their own default.
gif_frames[0].save("DQN_combined_500.gif", format="GIF",
                   append_images=gif_frames[1:],  # frame 0 is the base image; don't repeat it
                   save_all=True, duration=100, loop=0)

# Testing the parameter-sharing agent (a single network shared by all pursuers) with frozen evaders

In [116]:
from PIL import Image

# Same environment as the free-evader test, except that evaders are frozen in place.
env = pursuit_v4.env(max_cycles=500, x_size=16, y_size=16, shared_reward=True, n_evaders=25,
                     n_pursuers=6, obs_range=7, n_catch=2, freeze_evaders=True, tag_reward=0.01,
                     catch_reward=5.0, urgency_reward=0, surround=True, constraint_window=1.0,
                     render_mode='rgb_array')

# Parameter-sharing policy from the "frozen" checkpoint: one DQN queried by every pursuer.
policy_net = DQN(3*7*7, 5)
# map_location="cpu": inference below runs on CPU tensors, so the checkpoint must
# load on a CPU-only machine even if it was saved from GPU parameters.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
policy_net.load_state_dict(torch.load("DQN_combined_500_frozen", map_location="cpu"))
policy_net.eval()

env.reset()
frames = []  # RGB frame after every agent step; turned into a GIF in the next cell
with torch.no_grad():  # pure inference, so no autograd graph is needed
    for agent in env.agent_iter():
        observation, _, termination, truncation, _ = env.last()
        if termination or truncation:
            # PettingZoo's AEC API expects a None action for agents that are done.
            action = None
        else:
            obs = torch.tensor(observation, dtype=torch.float32).reshape(1, -1)  # batch of one
            action = policy_net(obs).argmax(dim=1).item()  # greedy action
        env.step(action)
        frames.append(env.render())
env.close()

In [117]:
# Convert captured RGB arrays to PIL images under a new name, so re-running this
# cell does not feed PIL images back into Image.fromarray.
gif_frames = [Image.fromarray(frame) for frame in frames]
if not gif_frames:
    raise ValueError("no frames captured; run the rollout cell first")
# Pillow's `duration` is milliseconds per frame. GIF delays have 10 ms resolution,
# so sub-10 ms values are stored as 0 and viewers substitute their own default.
gif_frames[0].save("DQN_combined_500_frozen.gif", format="GIF",
                   append_images=gif_frames[1:],  # frame 0 is the base image; don't repeat it
                   save_all=True, duration=100, loop=0)

# Testing the separate-network agents (one network per pursuer) with free evaders

In [118]:
class DQN(nn.Module):
    """Q-network for the separate-network agents: MLP with hidden widths 128 -> 64 -> 64.

    NOTE(review): this redefines (and shadows) the 64-64 ``DQN`` defined earlier
    in the notebook; cells executed after this one get this architecture.
    Attribute names ``layer1``/``layer2``/``layer3``/``out`` are the
    ``state_dict`` keys of the per-pursuer checkpoints and must not change.
    """

    def __init__(self, n_states, n_actions):
        super().__init__()
        # Created in this order so parameter init consumes a seeded RNG identically.
        self.layer1 = nn.Linear(n_states, 128)
        self.layer2 = nn.Linear(128, 64)
        self.layer3 = nn.Linear(64, 64)
        self.out = nn.Linear(64, n_actions)

    def forward(self, x):
        """Return raw Q-values (one per action) for each row of ``x``."""
        for hidden in (self.layer1, self.layer2, self.layer3):
            x = F.relu(hidden(x))
        return self.out(x)

In [122]:
N_PURSUERS = 6  # one independently trained network per pursuer

env = pursuit_v4.env(max_cycles=500, x_size=16, y_size=16, shared_reward=True, n_evaders=25,
                     n_pursuers=N_PURSUERS, obs_range=7, n_catch=2, freeze_evaders=False,
                     tag_reward=0.01, catch_reward=5.0, urgency_reward=0, surround=True,
                     constraint_window=1.0, render_mode='rgb_array')

# policy_net[i] is pursuer i's network, loaded from checkpoint "DQN_separate<i>".
policy_net = [DQN(3*7*7, 5) for _ in range(N_PURSUERS)]
for i, net in enumerate(policy_net):
    # map_location="cpu": inference below is CPU-only.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    net.load_state_dict(torch.load("DQN_separate" + str(i), map_location="cpu"))
    net.eval()

env.reset()
frames = []  # RGB frame after every agent step; turned into a GIF in the next cell
with torch.no_grad():  # pure inference, so no autograd graph is needed
    for agent in env.agent_iter():
        observation, _, termination, truncation, _ = env.last()
        if termination or truncation:
            # PettingZoo's AEC API expects a None action for agents that are done.
            action = None
        else:
            # Assumes agent names end in "_<index>" (e.g. "pursuer_3"); TODO confirm.
            # Parsing the whole suffix, not just the last character, keeps this
            # correct for 10 or more pursuers.
            agent_idx = int(str(agent).rsplit("_", 1)[-1])
            obs = torch.tensor(observation, dtype=torch.float32).reshape(1, -1)  # batch of one
            action = policy_net[agent_idx](obs).argmax(dim=1).item()  # greedy action
        env.step(action)
        frames.append(env.render())
env.close()

In [123]:
# Convert captured RGB arrays to PIL images under a new name, so re-running this
# cell does not feed PIL images back into Image.fromarray.
gif_frames = [Image.fromarray(frame) for frame in frames]
if not gif_frames:
    raise ValueError("no frames captured; run the rollout cell first")
# Pillow's `duration` is milliseconds per frame. GIF delays have 10 ms resolution,
# so sub-10 ms values are stored as 0 and viewers substitute their own default.
gif_frames[0].save("DQN_separate.gif", format="GIF",
                   append_images=gif_frames[1:],  # frame 0 is the base image; don't repeat it
                   save_all=True, duration=100, loop=0)