In [2]:
import torch as th
from torch import nn
import torch.nn.functional as F
from icecream import ic 
from types import SimpleNamespace as SN
import os

In [3]:
class MLPAgent(nn.Module):
    """Three-layer ReLU MLP mapping a per-agent input to one output per action.

    Args:
        input_shape: width of the (possibly proxy-augmented) input fed to ``fc1``.
        args: namespace providing ``hidden_dim`` and ``n_actions``.
    """

    def __init__(self, input_shape, args) -> None:
        super().__init__()
        self.args = args
        # Layer names/order define the state_dict keys and parameters() order; keep them stable.
        self.fc1 = nn.Linear(input_shape, args.hidden_dim)
        self.fc2 = nn.Linear(args.hidden_dim, args.hidden_dim)
        self.fc3 = nn.Linear(args.hidden_dim, args.n_actions)

    def forward(self, inputs, proxy_z=None):
        """Return the ``fc3`` output (one value per action) for ``inputs``.

        When ``proxy_z`` is given it must share every leading dimension with
        ``inputs``; it is concatenated onto the last axis before ``fc1``.
        """
        features = inputs
        if proxy_z is not None:
            assert inputs.shape[:-1] == proxy_z.shape[:-1], ic(inputs.shape, proxy_z.shape)
            features = th.cat((inputs, proxy_z), dim=-1)
        hidden = F.relu(self.fc2(F.relu(self.fc1(features))))
        return self.fc3(hidden)

In [4]:
class MLPNSAgent(nn.Module):
    """Non-shared-parameter team: one independent ``MLPAgent`` per controlled agent.

    Args:
        input_shape: input width given to every ``MLPAgent`` (observation width,
            plus the ``proxy_z`` width when a proxy code is concatenated in ``forward``).
        args: namespace; ``n_agents`` and ``n_control`` are read here. On the batched
            path of ``forward``, ``obs_shape`` is used if present (otherwise derived from
            ``input_shape``) and ``proxy_z_dim`` is read when ``proxy_z`` is given.
        train_teammate: when True, ``n_control`` is overridden with ``n_agents`` so a
            network is built for every agent; when False, ``args.n_control`` is used.
    """

    def __init__(self, input_shape, args, train_teammate=True) -> None:
        super().__init__()
        self.args = args
        self.n_agents = args.n_agents
        self.n_control = args.n_control
        if train_teammate:
            self.n_control = self.n_agents
        self.input_shape = input_shape
        # ModuleList registers each sub-network's parameters under "agents.<i>".
        self.agents = th.nn.ModuleList([MLPAgent(input_shape, args) for _ in range(self.n_control)])

    def _obs_dim(self, with_proxy):
        """Width of the raw observation slice used to reshape batched inputs."""
        obs_dim = getattr(self.args, "obs_shape", None)
        if obs_dim is None:
            # MLPAgent concatenates proxy_z onto the observation before fc1, so
            # input_shape == obs_dim (+ proxy_z_dim when a proxy is used).
            obs_dim = self.input_shape - (self.args.proxy_z_dim if with_proxy else 0)
        return obs_dim

    def forward(self, inputs, proxy_z=None):
        """Run each controlled agent on its own slice of ``inputs``.

        Two layouts are accepted:
          * ``inputs.size(0) == n_control``: row ``i`` belongs to agent ``i``; the
            result is the per-agent outputs concatenated along dim 0.
          * otherwise ``inputs`` must be viewable as ``(-1, n_control, obs_dim)``;
            the result has shape ``(batch * n_control, n_actions)`` with rows
            ordered batch-major, agent-minor.

        ``proxy_z`` (optional) must share all leading dims with ``inputs`` and is
        split per agent the same way.
        """
        use_proxy = proxy_z is not None
        if use_proxy:
            assert inputs.shape[:-1] == proxy_z.shape[:-1], (
                f"inputs {tuple(inputs.shape)} and proxy_z {tuple(proxy_z.shape)} "
                "must share leading dimensions")

        if inputs.size(0) == self.n_control:
            # One row per agent: feed row i to agent i as a batch of one.
            qs = [
                agent(inputs[i].unsqueeze(0), proxy_z[i].unsqueeze(0) if use_proxy else None)
                for i, agent in enumerate(self.agents)
            ]
            return th.cat(qs)

        # Batched path: reshape once (loop-invariant) into (batch, n_control, width).
        inputs = inputs.view(-1, self.n_control, self._obs_dim(use_proxy))
        if use_proxy:
            proxy_z = proxy_z.view(-1, self.n_control, self.args.proxy_z_dim)
        qs = [
            agent(inputs[:, i], proxy_z[:, i] if use_proxy else None)
            for i, agent in enumerate(self.agents)
        ]
        # (batch, n_control, n_actions) -> (batch * n_control, n_actions).
        return th.stack(qs, dim=1).view(-1, qs[0].size(-1))

    def cuda(self, device=None):
        """Move every sub-agent to ``device`` (``args.device`` when None) and return self.

        Returning ``self`` matches ``nn.Module.cuda`` so ``model = model.cuda()`` works;
        an explicit device index of 0 is honoured rather than treated as "unset".
        """
        if device is None:
            device = self.args.device
        for agent in self.agents:
            agent.cuda(device=device)
        return self

    def freeze(self):
        """Disable gradients on all parameters so optimizers leave the team unchanged."""
        for param in self.parameters():
            param.requires_grad = False

In [4]:
lbf_args = SN(hidden_dim=64, n_actions=6, n_agents=4, n_control=4)

# Full pretrained LBF team (train_teammate=True builds one network per agent).
# input_shape must match the checkpoint's fc1 width or load_state_dict raises.
load_team = MLPNSAgent(input_shape=27, args=lbf_args)
load_path = "pretrain_checkpoint/lbf_6x6_4p5f/pretrain_teammate_path/01/agent.th"
load_team.load_state_dict(th.load(load_path, map_location=lambda storage, loc: storage))

# Controllable subset = first 2 agents of the team. Use a copy of the args so the
# namespace shared with load_team.args keeps n_control=4.
lbf_control_args = SN(**{**vars(lbf_args), "n_control": 2})
save_control_team = MLPNSAgent(input_shape=27, args=lbf_control_args, train_teammate=False)
for i, agent in enumerate(save_control_team.agents):
    agent.load_state_dict(load_team.agents[i].state_dict())
assert all(th.equal(p, q) for p, q in zip(save_control_team.parameters(), load_team.parameters()))

In [5]:
save_path = "debug/performance_drop/lbf"
# th.save does not create parent directories; create them like the mpe_stag cell does.
os.makedirs(save_path, exist_ok=True)
th.save(save_control_team.state_dict(), "{}/controllable_agent.th".format(save_path))
th.save(load_team.state_dict(), "{}/team.th".format(save_path))

In [6]:
# Full pretrained Simple Tag (1 good, 3 adversaries) team: one network per agent.
st_args = SN(hidden_dim=64, n_actions=5, n_agents=3, n_control=3)
load_path = "pretrain_checkpoint/SimpleTag-1good-3adv/pretrain_teammate_path/01/agent.th"

load_team = MLPNSAgent(input_shape=16, args=st_args)
# Map every stored tensor onto the CPU regardless of the device it was saved from.
load_team.load_state_dict(th.load(load_path, map_location="cpu"))

<All keys matched successfully>

In [7]:
# Controllable subset = first 2 Simple Tag agents. Use a copy of the args so the
# namespace shared with load_team.args is not mutated (re-running stays idempotent).
st_control_args = SN(**{**vars(st_args), "n_control": 2})
save_control_team = MLPNSAgent(input_shape=16, args=st_control_args, train_teammate=False)
for i, agent in enumerate(save_control_team.agents):
    agent.load_state_dict(load_team.agents[i].state_dict())
assert all(th.equal(p, q) for p, q in zip(save_control_team.parameters(), load_team.parameters()))

In [9]:
# Persist the controllable subset and the full team side by side.
save_path = "debug/performance_drop/mpe_stag"
os.makedirs(save_path, exist_ok=True)  # th.save will not create missing directories
th.save(save_control_team.state_dict(), f"{save_path}/controllable_agent.th")
th.save(load_team.state_dict(), f"{save_path}/team.th")

In [10]:
class Player:
    """Mutable per-player record: placement, level, score and an optional controller."""

    def __init__(self):
        # Everything is unset until setup()/set_controller() are called.
        self.controller = self.position = self.level = self.field_size = self.score = None
        self.reward = 0
        self.history = self.current_step = None
        self.active = False

    def setup(self, position, level, field_size):
        """Place the player and reset its score and history (a fresh list each call)."""
        self.history, self.position, self.level, self.field_size, self.score = (
            [], position, level, field_size, 0)

    def set_controller(self, controller):
        """Attach the object whose ``_step`` decides this player's action."""
        self.controller = controller

    def step(self, obs):
        """Delegate to the controller's ``_step`` and return its result."""
        controller = self.controller
        return controller._step(obs)

In [11]:
x = [Player() for _ in range(3)]  # three distinct Player objects
x[0].active = True  # only the first one starts active

In [12]:
# Index-based access (loop var renamed so the builtin `id` is not shadowed globally).
for idx in range(len(x)):
    print(x[idx].active)

True
False
False


In [13]:
# Iterating the list directly yields the same objects as indexing into it.
for is_active in (p.active for p in x):
    print(is_active)

True
False
False


In [14]:
# The loop variable aliases each list element, so this mutates the Players in `x`.
for p in x:
    p.active = True

In [15]:
# Re-read by index to confirm the mutation through the loop variable stuck
# (loop var renamed so the builtin `id` is not shadowed globally).
for idx in range(len(x)):
    print(x[idx].active)

True
True
True


In [24]:
import numpy as np
class Test_class:
    """Toy counter starting at a random integer in [0, 10) drawn from the global NumPy RNG."""

    def __init__(self) -> None:
        draw = np.random.choice(range(10), 1)
        self.cnt = int(draw[0])

    def step(self):
        """Decrement the counter by one."""
        self.cnt = self.cnt - 1

In [25]:
# Two independent instances (constructed a first, then b), each with its own count.
a, b = Test_class(), Test_class()
a.cnt, b.cnt

(1, 9)

In [29]:
# Binding a second name does not copy: stepping through `alias` changes `a`.
# (A distinct name is used so the Player list `x` is not clobbered.)
alias = a
cnt_before = a.cnt
alias.step()
assert alias is a and a.cnt == cnt_before - 1
a.cnt

-3