Add code

shariqiqbal2810 committed Oct 17, 2018
1 parent 5c7da94 commit f82b34d8503eccfb0d555e691501cc0f69187ef9
@@ -0,0 +1,107 @@
# Repo Specific
models
fig_data
notebooks
multi_run*

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
@@ -0,0 +1,281 @@
import torch
import torch.nn.functional as F
from torch.optim import Adam
from utils.misc import soft_update, hard_update, enable_gradients, disable_gradients, sep_clip_grad_norm
from utils.agents import AttentionAgent
from utils.critics import AttentionCritic

MSELoss = torch.nn.MSELoss()

class AttentionSAC(object):
"""
Wrapper class for SAC agents with a central attention critic in a multi-agent task
"""
def __init__(self, agent_init_params, sa_size,
gamma=0.95, tau=0.01, attend_tau=0.002, pi_lr=0.01, q_lr=0.01,
reward_scale=10.,
pol_hidden_dim=128,
critic_hidden_dim=128, attend_heads=4,
**kwargs):
"""
Inputs:
agent_init_params (list of dict): List of dicts with parameters to
initialize each agent
num_in_pol (int): Input dimensions to policy
num_out_pol (int): Output dimensions to policy
sa_size (list of (int, int)): Size of state and action space for
each agent
gamma (float): Discount factor
tau (float): Target update rate for policies and the critic's
non-attention parameters
attend_tau (float): Target update rate for the critic's attention
parameters
pi_lr (float): Learning rate for policy
q_lr (float): Learning rate for critic
reward_scale (float): Scaling for reward (has effect of optimal
policy entropy)
pol_hidden_dim (int): Number of hidden dimensions for policy networks
critic_hidden_dim (int): Number of hidden dimensions for the critic
attend_heads (int): Number of attention heads in the critic
"""
self.nagents = len(sa_size)

self.agents = [AttentionAgent(lr=pi_lr,
hidden_dim=pol_hidden_dim,
**params)
for params in agent_init_params]
self.critic = AttentionCritic(sa_size, hidden_dim=critic_hidden_dim,
attend_heads=attend_heads)
self.target_critic = AttentionCritic(sa_size, hidden_dim=critic_hidden_dim,
attend_heads=attend_heads)
hard_update(self.target_critic, self.critic)
self.critic_optimizer = Adam(self.critic.q_parameters(), lr=q_lr,
weight_decay=1e-3)
self.agent_init_params = agent_init_params
self.gamma = gamma
self.tau = tau
self.attend_tau = attend_tau
self.pi_lr = pi_lr
self.q_lr = q_lr
self.reward_scale = reward_scale
self.pol_dev = 'cpu' # device for policies
self.critic_dev = 'cpu' # device for critics
self.trgt_pol_dev = 'cpu' # device for target policies
self.trgt_critic_dev = 'cpu' # device for target critics
self.niter = 0

@property
def policies(self):
return [a.policy for a in self.agents]

@property
def target_policies(self):
return [a.target_policy for a in self.agents]

def step(self, observations, explore=False):
"""
Take a step forward in environment with all agents
Inputs:
observations: List of observations for each agent
Outputs:
actions: List of actions for each agent
"""
return [a.step(obs, explore=explore) for a, obs in zip(self.agents,
observations)]

def update_critic(self, sample, soft=True, logger=None, **kwargs):
"""
Update central critic for all agents
"""
obs, acs, rews, next_obs, dones = sample
# Q loss
next_acs = []
next_log_pis = []
for pi, ob in zip(self.target_policies, next_obs):
curr_next_ac, curr_next_log_pi = pi(ob, return_log_pi=True)
next_acs.append(curr_next_ac)
next_log_pis.append(curr_next_log_pi)
trgt_critic_in = list(zip(next_obs, next_acs))
critic_in = list(zip(obs, acs))
next_qs = self.target_critic(trgt_critic_in)
critic_rets = self.critic(critic_in, regularize=True,
logger=logger, niter=self.niter)
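# Soft Bellman target for each agent i, computed in the loop below:
#   y_i = r_i + gamma * (1 - done_i) * Q'_i(o', a') - log pi_i(a'_i | o'_i) / reward_scale
# where Q'_i is the target critic over all agents' next observations and
# sampled next actions; the entropy term is dropped when soft=False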
q_loss = 0
for a_i, nq, log_pi, (pq, regs) in zip(range(self.nagents), next_qs,
next_log_pis, critic_rets):
target_q = (rews[a_i].view(-1, 1) +
self.gamma * nq *
(1 - dones[a_i].view(-1, 1)))
if soft:
target_q -= log_pi / self.reward_scale
q_loss += MSELoss(pq, target_q.detach())
for reg in regs:
q_loss += reg # regularizing attention
q_loss.backward()
sep_clip_grad_norm(self.critic.q_parameters(), 0.5)
self.critic_optimizer.step()
self.critic_optimizer.zero_grad()

if logger is not None:
logger.add_scalar('losses/q_loss', q_loss, self.niter)
self.niter += 1

def update_policies(self, sample, soft=True, logger=None, **kwargs):
obs, acs, rews, next_obs, dones = sample
samp_acs = []
all_probs = []
all_log_pis = []
all_pol_regs = []

for a_i, pi, ob in zip(range(self.nagents), self.policies, obs):
curr_ac, probs, log_pi, pol_regs, ent = pi(
ob, return_all_probs=True, return_log_pi=True,
regularize=True, return_entropy=True)
if logger is not None: logger.add_scalar('agent%i/policy_entropy' % a_i, ent, self.niter)
samp_acs.append(curr_ac)
all_probs.append(probs)
all_log_pis.append(log_pi)
all_pol_regs.append(pol_regs)

critic_in = list(zip(obs, samp_acs))
critic_rets = self.critic(critic_in, return_all_q=True)
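# The critic returns, per agent, Q for the sampled actions and Q over all of
# that agent's actions, so a baseline V = E_{a~pi}[Q] can be formed below and
# subtracted from Q to get an advantage-style target for the policy update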
for a_i, probs, log_pi, pol_regs, (q, all_q) in zip(range(self.nagents), all_probs,
all_log_pis, all_pol_regs,
critic_rets):
curr_agent = self.agents[a_i]
v = (all_q * probs).sum(dim=1, keepdim=True)
pol_target = q - v
pol_target = (pol_target - pol_target.mean()) / pol_target.std()
if soft:
pol_loss = (log_pi * (log_pi / self.reward_scale - pol_target).detach()).mean()
else:
pol_loss = (log_pi * (-pol_target).detach()).mean()
for reg in pol_regs:
pol_loss += 1e-3 * reg # policy regularization
# don't want critic to accumulate gradients from policy loss
disable_gradients(self.critic)
pol_loss.backward()
enable_gradients(self.critic)

sep_clip_grad_norm(curr_agent.policy.parameters(), 0.5)
curr_agent.policy_optimizer.step()
curr_agent.policy_optimizer.zero_grad()

if logger is not None:
logger.add_scalar('agent%i/losses/pol_loss' % a_i,
pol_loss, self.niter)


def update_all_targets(self):
"""
Update all target networks (called after normal updates have been
performed for each agent)
"""
for target_param, param in zip(self.target_critic.nonattend_parameters(),
self.critic.nonattend_parameters()):
target_param.data.copy_(target_param.data * (1.0 - self.tau) + param.data * self.tau)
for target_param, param in zip(self.target_critic.attend_parameters(),
self.critic.attend_parameters()):
target_param.data.copy_(target_param.data * (1.0 - self.attend_tau) + param.data * self.attend_tau)
for a in self.agents:
soft_update(a.target_policy, a.policy, self.tau)

def prep_training(self, device='gpu'):
self.critic.train()
self.target_critic.train()
for a in self.agents:
a.policy.train()
a.target_policy.train()
if device == 'gpu':
fn = lambda x: x.cuda()
else:
fn = lambda x: x.cpu()
if not self.pol_dev == device:
for a in self.agents:
a.policy = fn(a.policy)
self.pol_dev = device
if not self.critic_dev == device:
self.critic = fn(self.critic)
self.critic_dev = device
if not self.trgt_pol_dev == device:
for a in self.agents:
a.target_policy = fn(a.target_policy)
self.trgt_pol_dev = device
if not self.trgt_critic_dev == device:
self.target_critic = fn(self.target_critic)
self.trgt_critic_dev = device

def prep_rollouts(self, device='cpu'):
for a in self.agents:
a.policy.eval()
if device == 'gpu':
fn = lambda x: x.cuda()
else:
fn = lambda x: x.cpu()
# only need main policy for rollouts
if not self.pol_dev == device:
for a in self.agents:
a.policy = fn(a.policy)
self.pol_dev = device

def save(self, filename):
"""
Save trained parameters of all agents into one file
"""
self.prep_training(device='cpu') # move parameters to CPU before saving
save_dict = {'init_dict': self.init_dict,
'agent_params': [a.get_params() for a in self.agents],
'critic_params': {'critic': self.critic.state_dict(),
'target_critic': self.target_critic.state_dict(),
'critic_optimizer': self.critic_optimizer.state_dict()}}
torch.save(save_dict, filename)

@classmethod
def init_from_env(cls, env, gamma=0.95, tau=0.01, attend_tau=0.002,
pi_lr=0.01, q_lr=0.01,
reward_scale=10.,
pol_hidden_dim=128, critic_hidden_dim=128, attend_heads=4,
**kwargs):
"""
Instantiate instance of this class from multi-agent environment
env: Multi-agent Gym environment
gamma: discount factor
tau: rate of update for target networks (attend_tau for the critic's
attention parameters)
pi_lr, q_lr: learning rates for the policy and critic networks
pol_hidden_dim, critic_hidden_dim: number of hidden dimensions for the
policy and critic networks
attend_heads: number of attention heads in the critic
"""
agent_init_params = []
sa_size = []
for acsp, obsp in zip(env.action_space,
env.observation_space):
agent_init_params.append({'num_in_pol': obsp.shape[0],
'num_out_pol': acsp.n})
sa_size.append((obsp.shape[0], acsp.n))

init_dict = {'gamma': gamma, 'tau': tau, 'attend_tau': attend_tau,
'pi_lr': pi_lr, 'q_lr': q_lr,
'reward_scale': reward_scale,
'pol_hidden_dim': pol_hidden_dim,
'critic_hidden_dim': critic_hidden_dim,
'attend_heads': attend_heads,
'agent_init_params': agent_init_params,
'sa_size': sa_size}
instance = cls(**init_dict)
instance.init_dict = init_dict
return instance

@classmethod
def init_from_save(cls, filename, load_critic=False):
"""
Instantiate instance of this class from file created by 'save' method
"""
save_dict = torch.load(filename)
instance = cls(**save_dict['init_dict'])
instance.init_dict = save_dict['init_dict']
for a, params in zip(instance.agents, save_dict['agent_params']):
a.load_params(params)

if load_critic:
critic_params = save_dict['critic_params']
instance.critic.load_state_dict(critic_params['critic'])
instance.target_critic.load_state_dict(critic_params['target_critic'])
instance.critic_optimizer.load_state_dict(critic_params['critic_optimizer'])
return instance
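
For orientation, here is a minimal sketch of how this class might be driven in a training loop, using only the methods defined above. The environment factory, replay buffer, and loop constants are assumptions standing in for code elsewhere in the repo, not part of this commit.

# Hedged usage sketch for AttentionSAC; make_parallel_env and ReplayBuffer
# are hypothetical stand-ins for helpers defined elsewhere in the repo.
import torch
from torch.utils.tensorboard import SummaryWriter

env = make_parallel_env()                        # hypothetical env factory
model = AttentionSAC.init_from_env(env, reward_scale=100.)
buffer = ReplayBuffer()                          # hypothetical replay buffer
logger = SummaryWriter('logs')                   # anything with add_scalar works
batch_size, steps_per_update, total_steps = 1024, 100, 50000  # illustrative values

obs = env.reset()
for t in range(total_steps):
    model.prep_rollouts(device='cpu')
    torch_obs = [torch.tensor(o, dtype=torch.float32) for o in obs]  # per-agent batches
    actions = model.step(torch_obs, explore=True)
    next_obs, rewards, dones, infos = env.step(actions)
    buffer.push(obs, actions, rewards, next_obs, dones)
    obs = next_obs
    if len(buffer) >= batch_size and t % steps_per_update == 0:
        model.prep_training(device='gpu' if torch.cuda.is_available() else 'cpu')
        sample = buffer.sample(batch_size)       # -> (obs, acs, rews, next_obs, dones)
        model.update_critic(sample, logger=logger)
        model.update_policies(sample, logger=logger)
        model.update_all_targets()
model.save('model.pt')
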
@@ -0,0 +1,7 @@
import imp
import os.path as osp


def load(name):
pathname = osp.join(osp.dirname(__file__), name)
return imp.load_source('', pathname)
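
For context, a hedged example of how a loader like this is typically used with particle-environment scenarios; the scenario filename and the Scenario/make_world API are assumptions about code outside this commit.

# Hypothetical usage: import a scenario module by filename and build its world.
scenario = load('fullobs_collect_treasure.py').Scenario()
world = scenario.make_world()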

2 comments on commit f82b34d

@YaozuGen

replied Apr 11, 2019

utils/env_wrappers.py confused me.
The worker() method takes an env_fn_wrapper argument whose wrapped callable constructs a multiagent.environment.MultiAgentEnv.
MultiAgentEnv inherits from gym.Env, which declares abstract methods that MultiAgentEnv does not override (it implements _step() and _reset() instead). So in worker(), the code below works:

def worker(remote, parent_remote, env_fn_wrapper):
parent_remote.close()
env = env_fn_wrapper.x()
while True:
cmd, data = remote.recv()
if cmd == 'step':
ob, reward, done, info = env._step(data)
if all(done):
ob = env._reset()
remote.send((ob, reward, done, info))
elif cmd == 'reset':
ob = env._reset()  # 2019/4/10: changed from reset() to _reset()
remote.send(ob)
elif cmd == 'reset_task':
ob = env.reset_task()  # MultiAgentEnv does not have this function
remote.send(ob)
elif cmd == 'close':
remote.close()
break
elif cmd == 'get_spaces':
remote.send((env.observation_space, env.action_space))
elif cmd == 'get_agent_types':
if all([hasattr(a, 'adversary') for a in env.agents]):
remote.send(['adversary' if a.adversary else 'agent' for a in
env.agents])
else:
remote.send(['agent' for _ in env.agents])
else:
raise NotImplementedError

@YaozuGen

replied Apr 11, 2019

The logs directory ends up looking like this: .\models\fullobs_collect_treasure\fct\run1\logs\agent7\attention\head3_entropy,
which seems unreasonable.
Also, could you upload a policy-network evaluation script that can load the "model.pt" file?
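
A minimal sketch of what such an evaluation script could look like, using only the methods defined in the class above; make_env, the scenario name, and the episode length are assumptions, not part of the committed code.

# Hedged evaluation sketch: load a saved AttentionSAC model and roll out
# its policies greedily. make_env and the scenario name are hypothetical.
import torch

model = AttentionSAC.init_from_save('model.pt')
model.prep_rollouts(device='cpu')
env = make_env('fullobs_collect_treasure')       # hypothetical env factory
obs = env.reset()
for t in range(100):                             # arbitrary episode length
    torch_obs = [torch.tensor(o, dtype=torch.float32).unsqueeze(0) for o in obs]
    actions = model.step(torch_obs, explore=False)
    obs, rewards, dones, infos = env.step([a.squeeze(0).detach().numpy() for a in actions])
    env.render()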
