In [2]:
from unittest.util import _count_diff_hashable
import numpy as np
import os
import collections
from os.path import dirname, abspath
from copy import deepcopy
from sacred import Experiment, SETTINGS
from sacred.observers import FileStorageObserver
from sacred.utils import apply_backspaces_and_linefeeds
import sys
import torch as th
from utils.logging import get_logger
import yaml

from run import run

from run import args_sanity_check
from types import SimpleNamespace as SN

from learners import REGISTRY as le_REGISTRY
from runners import REGISTRY as r_REGISTRY
from controllers import REGISTRY as mac_REGISTRY
from components.episode_buffer import ReplayBuffer
from components.episode_buffer import Prioritized_ReplayBuffer

from envs import REGISTRY as env_REGISTRY
from functools import partial

from components.episodic_memory_buffer import Episodic_memory_buffer
from components.transforms import OneHot


SETTINGS['CAPTURE_MODE'] = "fd" # set to "no" if you want to see stdout/stderr in console
logger = get_logger()

ex = Experiment('pymarl', interactive=True)
ex.logger = logger
ex.captured_out_filter = apply_backspaces_and_linefeeds

pygame 2.3.0 (SDL 2.24.2, Python 3.7.16)
Hello from the pygame community. https://www.pygame.org/contribute.html


[DEBUG 17:13:36] git.cmd Popen(['git', 'version'], cwd=/home/hellone/Project/emc/pymarl/src, universal_newlines=False, shell=None, istream=None)
[DEBUG 17:13:36] git.cmd Popen(['git', 'version'], cwd=/home/hellone/Project/emc/pymarl/src, universal_newlines=False, shell=None, istream=None)
[DEBUG 17:13:36] git.util Failed checking if running in CYGWIN due to: FileNotFoundError(2, "No such file or directory: '/usr/bin/uname'")
[DEBUG 17:13:36] git.cmd Popen(['git', 'diff', '--cached', '--abbrev=40', '--full-index', '--raw'], cwd=/home/hellone/Project/emc, universal_newlines=False, shell=None, istream=None)
[DEBUG 17:13:36] git.cmd Popen(['git', 'diff', '--abbrev=40', '--full-index', '--raw'], cwd=/home/hellone/Project/emc, universal_newlines=False, shell=None, istream=None)
[DEBUG 17:13:36] git.cmd Popen(['git', 'cat-file', '--batch-check'], cwd=/home/hellone/Project/emc, universal_newlines=False, shell=None, istream=<valid stream>)
[DEBUG 17:13:36] git.cmd Popen(['git', 'diff', '--cache

In [2]:
__file__ = '/home/hellone/Project/ReinforcementLearning/EpisodicCuriosity/EMC/pymarl/src/'
results_path = os.path.join(dirname(dirname(abspath('/home/hellone/Project/ReinforcementLearning/EpisodicCuriosity/EMC/pymarl/src/'))), "results_debug")

In [3]:
argvs = ['debug.py', '--config=EMC_toygame', '--env-config=gridworld_reversed', 'with', 'env_args.map_name=reversed']
params = deepcopy(argvs)
params

['debug.py',
 '--config=EMC_toygame',
 '--env-config=gridworld_reversed',
 'with',
 'env_args.map_name=reversed']

In [4]:
def _get_config_env(params, arg_name, subfolder):
    config_name = None
    for _i, _v in enumerate(params):
        if _v.split("=")[0] == arg_name:
            config_name = _v.split("=")[1]
            del params[_i]
            break

    if config_name is not None:
        with open(os.path.join(os.path.dirname(__file__), "config", subfolder, "{}.yaml".format(config_name)), "r") as f:
            try:
                config_dict = yaml.load(f)
            except yaml.YAMLError as exc:
                assert False, "{}.yaml error: {}".format(config_name, exc)
        return config_dict

def _get_config_alg(params, arg_name, subfolder,map_name):
    config_name = None
    for _i, _v in enumerate(params):
        if _v.split("=")[0] == arg_name:
            config_name = _v.split("=")[1]
            del params[_i]
            break

    if map_name=="3s5z_vs_3s6z":
        config_name="EMC_sc2_3s5z_vs_3s6z"
    elif map_name=="6h_vs_8z":
        config_name="EMC_sc2_6h_vs_8z"
    elif map_name=="corridor":
        config_name="EMC_sc2_corridor"
    elif map_name=="origin":
        config_name="EMC_toygame"
    elif map_name=="reversed":
        config_name="EMC_toygame"


    

    if config_name is not None:
        with open(os.path.join(os.path.dirname(__file__), "config", subfolder, "{}.yaml".format(config_name)), "r") as f:
            try:
                config_dict = yaml.load(f)
            except yaml.YAMLError as exc:
                assert False, "{}.yaml error: {}".format(config_name, exc)
        return config_dict


def recursive_dict_update(d, u):
    for k, v in u.items():
        if isinstance(v, collections.Mapping):
            d[k] = recursive_dict_update(d.get(k, {}), v)
        else:
            d[k] = v
    return d


def config_copy(config):
    if isinstance(config, dict):
        return {k: config_copy(v) for k, v in config.items()}
    elif isinstance(config, list):
        return [config_copy(v) for v in config]
    else:
        return deepcopy(config)

In [5]:
with open(os.path.join(os.path.dirname(__file__), "config", "default.yaml"), "r") as f:
    try:
        config_dict = yaml.load(f)
    except yaml.YAMLError as exc:
        assert False, "default.yaml error: {}".format(exc)

env_config= _get_config_env(params, "--env-config", "envs")
config_dict = recursive_dict_update(config_dict, env_config)
map_name="3m"
for _i, _v in enumerate(params):
    if _v.split("=")[0] == "env_args.map_name":
        map_name = _v.split("=")[1]

print("Map Name:",map_name)
alg_config = _get_config_alg(params, "--config", "algs",map_name)
# config_dict = {**config_dict, **env_config, **alg_config}

config_dict = recursive_dict_update(config_dict, alg_config)

Map Name: reversed




In [6]:
config_dict['seed'] = 427095566

In [7]:
config = config_copy(config_dict)
np.random.seed(config["seed"])
th.manual_seed(config["seed"])
config['env_args']['seed'] = config["seed"]

In [8]:
args = SN(**config)
args.use_cuda = False
args.device = "cuda" if args.use_cuda else "cpu"
set_device = os.getenv('SET_DEVICE')
if args.use_cuda and set_device != '-1':
    if set_device is None:
        args.device = "cuda"
    else:
        args.device = f"cuda:{set_device}"
else:
    args.device = "cpu"

In [9]:
args.buffer_cpu_only = True

In [10]:
    # Init runner so we can get env info
runner = r_REGISTRY[args.runner](args=args, logger=logger)

# Set up schemes and groups here
env_info = runner.get_env_info()
args.episode_limit = env_info["episode_limit"]
args.n_agents = env_info["n_agents"]
args.n_actions = env_info["n_actions"]
args.state_shape = env_info["state_shape"]
args.unit_dim = env_info["unit_dim"]

# Default/Base scheme
scheme = {
    "state": {"vshape": env_info["state_shape"]},
    "obs": {"vshape": env_info["obs_shape"], "group": "agents"},
    "actions": {"vshape": (1,), "group": "agents", "dtype": th.long},
    "avail_actions": {"vshape": (env_info["n_actions"],), "group": "agents", "dtype": th.int},
    "reward": {"vshape": (1,)},
    "terminated": {"vshape": (1,), "dtype": th.uint8},
}
groups = {
    "agents": args.n_agents
}
preprocess = {
    "actions": ("actions_onehot", [OneHot(out_dim=args.n_actions)])
}

env_name = args.env

print(scheme)


{'state': {'vshape': 92}, 'obs': {'vshape': 46, 'group': 'agents'}, 'actions': {'vshape': (1,), 'group': 'agents', 'dtype': torch.int64}, 'avail_actions': {'vshape': (5,), 'group': 'agents', 'dtype': torch.int32}, 'reward': {'vshape': (1,)}, 'terminated': {'vshape': (1,), 'dtype': torch.uint8}}


In [11]:
buffer = Prioritized_ReplayBuffer(scheme, groups, args.buffer_size, env_info["episode_limit"] + 1,
                                        args.prioritized_buffer_alpha,
                                        preprocess=preprocess,
                                        device="cpu" if args.buffer_cpu_only else args.device)
ec_buffer=Episodic_memory_buffer(args,scheme)
mac = mac_REGISTRY[args.mac](buffer.scheme, groups, args)
# Give runner the scheme
# runner = episode
runner.setup(scheme=scheme, groups=groups, preprocess=preprocess, mac=mac)
# Learner
learner = le_REGISTRY[args.learner](mac, buffer.scheme, logger, args, groups=groups)

In [12]:
args.gener_goal_interval = 2

In [13]:
if hasattr(args, "save_buffer") and args.save_buffer:
    learner.buffer = buffer

if args.use_cuda:
    learner.cuda()


# start training
episode = 0
last_test_T = -args.test_interval - 1
last_log_T = 0
model_save_time = 0

In [None]:
getattr(args, "use_emdqn", False)

In [27]:
episode_batch = runner.run(test_mode=False)

  v = th.tensor(v, dtype=dtype, device=self.device)


In [47]:
buffer.insert_episode_batch(episode_batch)

  v = th.tensor(v, dtype=dtype, device=self.device)


In [52]:
print(args.is_save_buffer)
print(args.is_prioritized_buffer)
print(args.agent_output_type)
print(hasattr(args, 'use_individual_Q') and args.use_individual_Q)
print(args.rnn_hidden_dim)
print('n_actions: ', args.n_actions)

False
True
q
False
64
n_actions:  5


In [53]:
actions = mac.select_actions(episode_batch, t_ep=0, t_env=0, test_mode=False)

In [61]:
mac_out = mac.forward(episode_batch, 1)

In [66]:
# 输出的n=2个智能体的各个动作的Q值
mac_out
# shape (1,2,5)

tensor([[[ 0.0864, -0.1351, -0.0448,  0.1170, -0.1308],
         [ 0.0347, -0.1361, -0.0332,  0.0919, -0.1402]]],
       grad_fn=<ViewBackward>)

In [63]:
actions

tensor([[1, 2]])

In [68]:
mac_out[:, :-1]

tensor([[[ 0.0864, -0.1351, -0.0448,  0.1170, -0.1308]]],
       grad_fn=<SliceBackward>)

In [67]:
chosen_action_qvals = th.gather(mac_out[:, :-1], dim=2, index=actions).squeeze(3) 
chosen_action_qvals

RuntimeError: Index tensor must have the same number of dimensions as input tensor

In [46]:
for i in episode_batch['actions'][0]:
    print(i[0][0], ',')
# episode_batch['actions'][0][20][0][0]

tensor(4) ,
tensor(4) ,
tensor(4) ,
tensor(4) ,
tensor(4) ,
tensor(3) ,
tensor(2) ,
tensor(1) ,
tensor(4) ,
tensor(0) ,
tensor(4) ,
tensor(4) ,
tensor(4) ,
tensor(1) ,
tensor(0) ,
tensor(1) ,
tensor(0) ,
tensor(1) ,
tensor(3) ,
tensor(2) ,
tensor(0) ,
tensor(3) ,
tensor(4) ,
tensor(1) ,
tensor(3) ,
tensor(4) ,
tensor(3) ,
tensor(3) ,
tensor(2) ,
tensor(2) ,
tensor(2) ,


In [26]:
runner.batch_size

1

In [22]:
args.batch_size

32

In [19]:
data = {
    "first": [1, 2 ,3],
    "second": [4, 5, 6],
}

In [20]:
data.items()

dict_items([('first', [1, 2, 3]), ('second', [4, 5, 6])])

In [21]:
for k , v in data.items():
# data.items()
    print(k, '\t', v)

first 	 [1, 2, 3]
second 	 [4, 5, 6]


In [None]:
from envs import REGISTRY as env_REGISTRY
from functools import partial

In [None]:
class stu:
    def __init__(self) -> None:
        self.__mess = 6
        self.__name = 1

    @property
    def mess(self):
        return self.__mess
    
    @property
    def name(self):
        return self.__name

In [None]:
a = stu()

In [None]:
a._name

In [None]:
a.name

In [None]:
a._name = 2
a.name

In [107]:
x = th.tensor([1,2,3.0], requires_grad=True).float()
y = x**2

In [117]:
loss = 2*x[1]-1
loss2 = 4 * x[1] - 1
# loss = 0.

In [112]:
x.grad.zero_()

tensor([0., 0., 0.])

In [121]:
th.tensor([[[5,5,5,6]]]).shape

torch.Size([1, 1, 4])

In [120]:
loss

tensor(3., grad_fn=<SubBackward0>)

In [118]:
loss.backward()

In [119]:
x.grad

tensor([0., 2., 0.])

In [25]:
y.backward()

RuntimeError: grad can be implicitly created only for scalar outputs

In [24]:
x.grad

tensor([2., 4., 6.])

In [12]:
x = th.tensor([[[1,2,3],[4,5,6]]])
a = th.randn((1,31,2,46))

In [15]:
b = a[:,1]
b.shape

torch.Size([1, 2, 46])

In [27]:
c = th.cat([_.reshape(2, 1, -1) for _ in b], dim = 2)
c.shape

torch.Size([2, 1, 46])

In [23]:
for i in b:
    print(i.shape)
    print(i.reshape(2,1,-1).shape)

torch.Size([2, 46])
torch.Size([2, 1, 46])


In [20]:
z = [_.reshape(2, 1, -1) for _ in b]
z

[tensor([[[-1.3974,  0.9128,  0.9448,  1.1784, -0.1011, -1.9543, -0.8169,
           -0.8084, -0.1665,  0.1685, -1.5195, -0.1413,  0.8248,  1.8848,
            0.6537,  0.4770, -0.1100, -0.4075, -0.8264,  1.4410,  1.1888,
            0.8314,  0.4867, -0.5320, -0.5498, -0.6408, -0.1479, -0.3711,
            1.2912,  0.2493, -0.0408, -1.3887, -1.0742, -0.3676, -0.2206,
           -0.2998, -0.4120,  2.1785, -0.3528,  2.2399,  0.7986,  0.4189,
            0.3701, -0.4737, -0.3823, -1.6351]],
 
         [[-0.3844, -1.0168, -0.5389, -1.0042,  0.5496,  1.4545, -0.9078,
            0.2225, -0.8199,  2.0154, -0.7226, -0.5918,  0.4451,  0.5402,
           -1.1657,  0.5409,  2.6998,  1.0856, -1.1007,  0.6558, -2.2412,
            1.1272, -2.3806,  2.3751, -1.4456, -0.0570,  0.3554,  0.4483,
           -0.4134, -1.3400, -1.2609,  0.9387,  0.1290, -0.3728,  0.3171,
            0.7606, -0.7222,  1.8649,  0.3903,  0.9736,  0.0995,  0.3403,
           -0.1093, -0.6419, -0.0509,  0.9853]]])]

In [30]:
v = th.randn((1,1,46))
dest = th.randn(1,1,92)

In [31]:
idx = len(v.shape) - 1
for s in dest.shape[::-1]:
    if v.shape[idx] != s:
        if s != 1:
            raise ValueError("Unsafe reshape of {} to {}".format(v.shape, dest.shape))
    else:
        idx -= 1

ValueError: Unsafe reshape of torch.Size([1, 1, 46]) to torch.Size([1, 1, 92])

In [14]:
g = th.randn((3,5,4))
m = th.randn((3,5,5))

o = th.cat([g, m], dim=-1)
o.shape

torch.Size([3, 5, 9])

In [15]:
g

tensor([[[ 1.0280, -1.3871, -0.6722,  0.2010],
         [ 2.2553, -0.4990, -0.6995, -0.3495],
         [-0.2049, -0.3213,  0.2025,  0.2779],
         [-2.2110, -0.8507,  0.2646,  0.8108],
         [ 0.5326, -1.1670, -0.9495,  1.5615]],

        [[-0.2686, -1.7691,  0.6596, -0.0980],
         [ 0.7310, -0.7556,  0.0057,  0.1665],
         [-0.6148, -1.0748, -1.0956, -2.6119],
         [ 0.9488, -0.1117, -0.2040,  0.5007],
         [ 0.8175,  3.0204,  0.6222, -0.1982]],

        [[ 1.5877,  0.0244, -1.0106, -2.1215],
         [-0.6011,  1.6328,  1.6232, -0.2276],
         [-0.1663,  0.9752, -0.3509, -0.7005],
         [ 0.9254,  0.7500, -1.1478,  1.2592],
         [ 1.5677,  0.2073, -0.2870,  0.2015]]])

In [17]:
gg = g.reshape(15,4).reshape(3,5,4)
gg

tensor([[[ 1.0280, -1.3871, -0.6722,  0.2010],
         [ 2.2553, -0.4990, -0.6995, -0.3495],
         [-0.2049, -0.3213,  0.2025,  0.2779],
         [-2.2110, -0.8507,  0.2646,  0.8108],
         [ 0.5326, -1.1670, -0.9495,  1.5615]],

        [[-0.2686, -1.7691,  0.6596, -0.0980],
         [ 0.7310, -0.7556,  0.0057,  0.1665],
         [-0.6148, -1.0748, -1.0956, -2.6119],
         [ 0.9488, -0.1117, -0.2040,  0.5007],
         [ 0.8175,  3.0204,  0.6222, -0.1982]],

        [[ 1.5877,  0.0244, -1.0106, -2.1215],
         [-0.6011,  1.6328,  1.6232, -0.2276],
         [-0.1663,  0.9752, -0.3509, -0.7005],
         [ 0.9254,  0.7500, -1.1478,  1.2592],
         [ 1.5677,  0.2073, -0.2870,  0.2015]]])

In [8]:
th.max(g, dim=-1)

torch.return_types.max(
values=tensor([[1.7202],
        [0.4227],
        [0.7932]]),
indices=tensor([[2],
        [2],
        [2]]))

In [2]:
c = th.randn((2,2))
# c[1,] = th.argmax(g[1,])

In [80]:
oo = th.cat([th.max(o[:,1,:3],-1)[1], th.max(o[:,1,3:8],-1)[1]], dim=-1)

In [82]:
th.max(o[:,1,:3],-1)[1].shape

torch.Size([3])

In [7]:
ts = th.zeros((1,1,46))
ts

tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]])

In [4]:
c

tensor([[ 0.1384, -1.1889],
        [-0.9080,  0.0974]])

In [6]:
import torch.nn.functional as F

In [47]:
a = th.tensor([[1., 2.],[3., 4.]])
m = th.mean(a, 1)
m.shape

torch.Size([2])

In [60]:
c = th.randn((32,30, 1))
d = th.mean(c,1, True)
d * 10

tensor([[[-3.7337]],

        [[ 1.0963]],

        [[ 0.9674]],

        [[ 3.3269]],

        [[-1.0490]],

        [[ 2.6757]],

        [[-2.0014]],

        [[ 0.0713]],

        [[-2.2090]],

        [[-2.0173]],

        [[ 3.2781]],

        [[ 0.1343]],

        [[-0.9378]],

        [[-0.1153]],

        [[-0.8183]],

        [[-2.6410]],

        [[ 0.8053]],

        [[ 4.0464]],

        [[ 0.0970]],

        [[-0.3055]],

        [[-1.5841]],

        [[-2.7927]],

        [[-0.8880]],

        [[-2.6001]],

        [[ 0.2720]],

        [[-2.6267]],

        [[ 0.8737]],

        [[ 2.8862]],

        [[-0.1407]],

        [[ 2.8194]],

        [[ 0.7722]],

        [[-6.1921]]])

In [63]:
th.mean(c)

tensor(-0.0267)

In [64]:
c.sum() / 32 / 30

tensor(-0.0267)

In [62]:
e = d.multiply(10)
e

tensor([[[-3.7337]],

        [[ 1.0963]],

        [[ 0.9674]],

        [[ 3.3269]],

        [[-1.0490]],

        [[ 2.6757]],

        [[-2.0014]],

        [[ 0.0713]],

        [[-2.2090]],

        [[-2.0173]],

        [[ 3.2781]],

        [[ 0.1343]],

        [[-0.9378]],

        [[-0.1153]],

        [[-0.8183]],

        [[-2.6410]],

        [[ 0.8053]],

        [[ 4.0464]],

        [[ 0.0970]],

        [[-0.3055]],

        [[-1.5841]],

        [[-2.7927]],

        [[-0.8880]],

        [[-2.6001]],

        [[ 0.2720]],

        [[-2.6267]],

        [[ 0.8737]],

        [[ 2.8862]],

        [[-0.1407]],

        [[ 2.8194]],

        [[ 0.7722]],

        [[-6.1921]]])

In [19]:
b = th.tensor([[2.], [5.]])
b.shape

torch.Size([2, 1])

In [18]:
c = F.mse_loss(a, b)
c

tensor(2.5000)

In [21]:
d = (a - b) ** 2
d.sum()

tensor(5.)

In [28]:
a1 = th.randn((2,3,1))
b1 = th.randn((2,3,1))

c1 = (F.mse_loss(a1, b1))
c2 = ((a1 - b1) ** 2).sum()

In [35]:
# a1[:, 1:] = a1[:, :-1]
# a1

tensor([[[0.],
         [0.],
         [0.]],

        [[0.],
         [0.],
         [0.]]])

In [32]:
a1[:, 0] = 0

In [33]:
a1

tensor([[[ 0.0000],
         [-1.0212],
         [-0.9089]],

        [[ 0.0000],
         [ 0.3320],
         [-1.4326]]])

In [30]:
a2 = a1.reshape(6,1)
b2 = b1.reshape(6,1)

c3 = F.mse_loss(a2, b2)
c3

tensor(1.3761)

In [24]:
c2

tensor(9.4873)

In [100]:
b ** 2

tensor([[ 2.2996, 10.1026],
        [ 5.9469, 11.9811]])

In [99]:
(b ** 2).sum()

tensor(30.3303)

In [84]:
c / 2

tensor([[-0.4422],
        [ 1.0000]])

In [18]:
a = g.reshape((-1,))
a

tensor([-0.7967,  1.8744, -0.2290, -0.7707, -1.8532, -0.3989])

In [20]:
import numpy.matlib

In [35]:
a = np.matlib.ones((2,2))
b = np.array([[1,2],[2,1]])

In [51]:
a = np.array([[2, 3]])
print(np.std(a, axis=1))

[0.5]


In [54]:
np.linalg.norm([3,4])

5.0

In [95]:
l = []
l.append(2)
l[0] += 3
sum(l)

5

In [13]:
out = th.tensor([[[ 3.3945,  1.3915,  4.4290,  1.3794,  3.4985],
         [-0.0163, -0.7975,  0.4944, -0.7929, -0.2515]]])
out

SyntaxError: invalid syntax (2806840858.py, line 2)

In [7]:
out.shape

torch.Size([1, 2, 5])

In [10]:
out.max(dim=2)[1]

tensor([[2, 2]])

In [5]:
random_numbers = th.rand_like(out[:, :, 0])

In [6]:
random_numbers

tensor([[0.1386, 0.5706]])

In [9]:
pick_random = (random_numbers < 0).long()
pick_random

tensor([[0, 0]])

In [12]:
picked_action = (1 - pick_random) * out.max(dim=2)[1]
picked_action

tensor([[2, 2]])

In [104]:
o = th.tensor([[[[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
          [[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]
          ],
          [[[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
          [[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]
          ]
          ])

In [96]:
o.shape

torch.Size([2, 2, 2, 46])

In [98]:
print(th.max(o[:,:,0,:11], -1)[1], th.max(o[:,:,0,11:23], -1)[1], th.max(o[:,:,1,:11], -1)[1], th.max(o[:,:,1,11:23], -1)[1])

tensor([[2, 1],
        [2, 1]]) tensor([[0, 1],
        [0, 1]]) tensor([[8, 7],
        [8, 7]]) tensor([[10, 11],
        [10, 11]])


In [99]:
c= th.cat([th.max(o[:,:,0,:11], -1)[1], th.max(o[:,:,0,11:23], -1)[1], th.max(o[:,:,1,:11], -1)[1], th.max(o[:,:,1,11:23], -1)[1]], dim=0)
c.reshape(2,2,-1)

tensor([[[ 2,  1,  2,  1],
         [ 0,  1,  0,  1]],

        [[ 8,  7,  8,  7],
         [10, 11, 10, 11]]])

In [105]:
a1 = th.max(o[:,:,0,:11], -1)[1].reshape(2,2,-1)
a2 = th.max(o[:,:,0,11:23], -1)[1].reshape(2,2,-1)
a3 = th.max(o[:,:,1,:11], -1)[1].reshape(2,2,-1)
a4 = th.max(o[:,:,1,11:23], -1)[1].reshape(2,2,-1)

In [106]:
th.cat([a1, a2, a3, a4], dim=-1)

tensor([[[ 2,  0,  8, 10],
         [ 1,  1,  7, 11]],

        [[ 1,  0,  8, 10],
         [ 1,  1,  7, 11]]])