In [1]:
#SAC implementation

In [305]:
import socket
from typing import Any, Dict, Generator, List, Optional, Union
from typing import Any, Callable, Dict, List, NamedTuple, Tuple, Union
import warnings

from torch.nn.utils import clip_grad_norm_
import argparse
import os
import random
import subprocess
import time
from distutils.util import strtobool
from typing import List

import asr_pb2

import numpy as np
import pandas as pd
import torch.nn.functional as F
from collections import defaultdict
import torch
from torch import nn
from torch import optim
import torch as th
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter
from gym.spaces import MultiDiscrete
from stable_baselines3.common.vec_env import VecEnvWrapper, VecMonitor, VecVideoRecorder

from gevent import monkey
import copy

from gym_microrts import microrts_ai
from gym_microrts.envs.vec_env import (
    MicroRTSGridModeSharedMemVecEnv as MicroRTSGridModeVecEnv,
)
from stable_baselines3.common.vec_env import VecEnvWrapper, VecMonitor, VecVideoRecorder
from stable_baselines3.common.buffers import BaseBuffer

from gym import spaces

from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape
# from stable_baselines3.common.type_aliases import (
#     DictReplayBufferSamples,
#     DictRolloutBufferSamples,
#     ReplayBufferSamples,
#     RolloutBufferSamples,
# )
from stable_baselines3.common.vec_env import VecNormalize

try:
    # Check memory used by replay buffer when possible
    import psutil
except ImportError:
    psutil = None

In [212]:
# ALGO LOGIC: initialize agent here:
class CategoricalMasked(Categorical):
    def __init__(self, probs=None, logits=None, validate_args=None, masks=None, mask_value=None):
        if masks is None:
            masks = []
        logits = torch.where(masks.bool(), logits, mask_value)
        super().__init__(probs, logits, validate_args)

class Transpose(nn.Module):
    def __init__(self, permutation):
        super().__init__()
        self.permutation = permutation

    def forward(self, x):
        return x.permute(self.permutation)


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer

class MicroRTSStatsRecorder(VecEnvWrapper):
    def __init__(self, env, gamma=0.99) -> None:
        super().__init__(env)
        self.gamma = gamma

    def reset(self):
        obs = self.venv.reset()
        self.raw_rewards = [[] for _ in range(self.num_envs)]
        self.ts = np.zeros(self.num_envs, dtype=np.float32)
        self.raw_discount_rewards = [[] for _ in range(self.num_envs)]
        return obs

    def step_wait(self):
        obs, rews, dones, infos = self.venv.step_wait()
        newinfos = list(infos[:])
        for i, done in enumerate(dones):
            self.raw_rewards[i] += [infos[i]["raw_rewards"]]
            self.raw_discount_rewards[i] += [
                (self.gamma ** self.ts[i])
                * np.concatenate((infos[i]["raw_rewards"], infos[i]["raw_rewards"].sum()), axis=None)
            ]
            self.ts[i] += 1
            if done:
                info = infos[i].copy()
                raw_returns = np.array(self.raw_rewards[i]).sum(0)
                raw_names = [str(rf) for rf in self.rfs]
                raw_discount_returns = np.array(self.raw_discount_rewards[i]).sum(0)
                raw_discount_names = ["discounted_" + str(rf) for rf in self.rfs] + ["discounted"]
                info["microrts_stats"] = dict(zip(raw_names, raw_returns))
                info["microrts_stats"].update(dict(zip(raw_discount_names, raw_discount_returns)))
                self.raw_rewards[i] = []
                self.raw_discount_rewards[i] = []
                self.ts[i] = 0
                newinfos[i] = info
        return obs, rews, dones, newinfos

    
if __name__ == "__main__":


    #parameters
    seed = 9
    torch_deterministic = True
    num_selfplay_envs = 2
    num_bot_envs = 0
    partial_obs  = False
    train_maps = ["maps/16x16/basesWorkers16x16A.xml"]
    gamma = 0.99
    q_lr = 3e-4
    policy_lr = 2.5e-4
    eps=1e-5
    buffer_size = int(1e6)
    total_timesteps = 5000000
    learning_starts = int(2e4)
    update_frequency  = 4
    batch_size = 64
    mapsize = 16 * 16
    target_network_frequency = 8000
    tau = 1.0
    
    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = torch_deterministic

    envs = MicroRTSGridModeVecEnv(
            num_selfplay_envs=num_selfplay_envs,
            num_bot_envs=num_bot_envs,
            partial_obs=partial_obs,
            max_steps=2000,
            render_theme=2,
            ai2s=[microrts_ai.coacAI for _ in range(num_bot_envs - 6)]
            + [microrts_ai.randomBiasedAI for _ in range(min(num_bot_envs, 2))]
            + [microrts_ai.lightRushAI for _ in range(min(num_bot_envs, 2))]
            + [microrts_ai.workerRushAI for _ in range(min(num_bot_envs, 2))],
            map_paths=[train_maps[0]],
            reward_weight=np.array([10.0, 1.0, 1.0, 0.2, 1.0, 4.0]),
            cycle_maps=train_maps,
        )
    envs = MicroRTSStatsRecorder(envs, gamma)
    envs = VecMonitor(envs)

In [None]:
#https://github.com/timoklein/cleanrl/blob/sac-discrete/cleanrl/sac_atari.py
#https://github.com/pranz24/pytorch-soft-actor-critic
#https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/sac_continuous_action.py

In [170]:
class ReplayBufferSamples(NamedTuple):
    observations: th.Tensor
    actions: th.Tensor
    masks: th.Tensor
    next_observations: th.Tensor
    dones: th.Tensor
    rewards: th.Tensor
        
        

#class that handles masks
class ReplayBufferActs(BaseBuffer):
    def __init__(
        self,
        buffer_size: int,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        masks_space: spaces.Space,
        device: Union[th.device, str] = "auto",
        n_envs: int = 1,
        optimize_memory_usage: bool = False,
        handle_timeout_termination: bool = True,
    ):
        super().__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)

        # Adjust buffer size
        self.buffer_size = max(buffer_size // n_envs, 1)

        # Check that the replay buffer can fit into the memory
        if psutil is not None:
            mem_available = psutil.virtual_memory().available

        # there is a bug if both optimize_memory_usage and handle_timeout_termination are true
        # see https://github.com/DLR-RM/stable-baselines3/issues/934
        if optimize_memory_usage and handle_timeout_termination:
            raise ValueError(
                "ReplayBuffer does not support optimize_memory_usage = True "
                "and handle_timeout_termination = True simultaneously."
            )
        self.optimize_memory_usage = optimize_memory_usage

        self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=observation_space.dtype)

        if optimize_memory_usage:
            # `observations` contains also the next observation
            self.next_observations = None
        else:
            self.next_observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=observation_space.dtype)

        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=action_space.dtype)
        
        
        
        self.masks = np.zeros((self.buffer_size, self.n_envs) +  masks_space, dtype=np.float32)
        
        
        
        self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
        # Handle timeouts termination properly if needed
        # see https://github.com/DLR-RM/stable-baselines3/issues/284
        self.handle_timeout_termination = handle_timeout_termination
        self.timeouts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)

        if psutil is not None:
            total_memory_usage = self.observations.nbytes + self.actions.nbytes + self.masks.nbytes +  self.rewards.nbytes + self.dones.nbytes

            if self.next_observations is not None:
                total_memory_usage += self.next_observations.nbytes

            if total_memory_usage > mem_available:
                # Convert to GB
                total_memory_usage /= 1e9
                mem_available /= 1e9
                warnings.warn(
                    "This system does not have apparently enough memory to store the complete "
                    f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB"
                )

    def add(
        self,
        obs: np.ndarray,
        next_obs: np.ndarray,
        action: np.ndarray,
        mask: np.ndarray,
        reward: np.ndarray,
        done: np.ndarray,
        infos: List[Dict[str, Any]],
    ) -> None:

        # Reshape needed when using multiple envs with discrete observations
        # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
        if isinstance(self.observation_space, spaces.Discrete):
            obs = obs.reshape((self.n_envs,) + self.obs_shape)
            next_obs = next_obs.reshape((self.n_envs,) + self.obs_shape)

        # Same, for actions
        action = action.reshape((self.n_envs, self.action_dim))

        # Copy to avoid modification by reference
        self.observations[self.pos] = np.array(obs).copy()

        if self.optimize_memory_usage:
            self.observations[(self.pos + 1) % self.buffer_size] = np.array(next_obs).copy()
        else:
            self.next_observations[self.pos] = np.array(next_obs).copy()

        self.actions[self.pos] = np.array(action).copy()
        self.masks[self.pos] = np.array(mask).copy()
        self.rewards[self.pos] = np.array(reward).copy()
        self.dones[self.pos] = np.array(done).copy()

        if self.handle_timeout_termination:
            self.timeouts[self.pos] = np.array([info.get("TimeLimit.truncated", False) for info in infos])

        self.pos += 1
        if self.pos == self.buffer_size:
            self.full = True
            self.pos = 0

    def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
        """
        Sample elements from the replay buffer.
        Custom sampling when using memory efficient variant,
        as we should not sample the element with index `self.pos`
        See https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274
        :param batch_size: Number of element to sample
        :param env: associated gym VecEnv
            to normalize the observations/rewards when sampling
        :return:
        """
        if not self.optimize_memory_usage:
            return super().sample(batch_size=batch_size, env=env)
        # Do not sample the element with index `self.pos` as the transitions is invalid
        # (we use only one array to store `obs` and `next_obs`)
        if self.full:
            batch_inds = (np.random.randint(1, self.buffer_size, size=batch_size) + self.pos) % self.buffer_size
        else:
            batch_inds = np.random.randint(0, self.pos, size=batch_size)
        return self._get_samples(batch_inds, env=env)

    def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> ReplayBufferSamples:
        # Sample randomly the env idx
        env_indices = np.random.randint(0, high=self.n_envs, size=(len(batch_inds),))

        if self.optimize_memory_usage:
            next_obs = self._normalize_obs(self.observations[(batch_inds + 1) % self.buffer_size, env_indices, :], env)
        else:
            next_obs = self._normalize_obs(self.next_observations[batch_inds, env_indices, :], env)

        data = (
            self._normalize_obs(self.observations[batch_inds, env_indices, :], env),
            self.actions[batch_inds, env_indices, :],
            self.masks[batch_inds, env_indices, :],
            next_obs,
            # Only use dones that are not due to timeouts
            # deactivated by default (timeouts is initialized as an array of False)
            (self.dones[batch_inds, env_indices] * (1 - self.timeouts[batch_inds, env_indices])).reshape(-1, 1),
            self._normalize_reward(self.rewards[batch_inds, env_indices].reshape(-1, 1), env),
        )
        return ReplayBufferSamples(*tuple(map(self.to_torch, data)))

In [227]:

class Actor(nn.Module):
    def __init__(self, envs, mapsize=16 * 16):
        super().__init__()
        self.mapsize = mapsize
        h, w, c = envs.observation_space.shape
        self.encoder = nn.Sequential(
            Transpose((0, 3, 1, 2)),
            layer_init(nn.Conv2d(c, 32, kernel_size=3, padding=1)),
            nn.MaxPool2d(3, stride=2, padding=1),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, kernel_size=3, padding=1)),
            nn.MaxPool2d(3, stride=2, padding=1),
            nn.ReLU(),
        )
        self.actor = nn.Sequential(
            layer_init(nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1)),
            nn.ReLU(),
            layer_init(nn.ConvTranspose2d(32, 78, 3, stride=2, padding=1, output_padding=1)),
            Transpose((0, 2, 3, 1)),
        )
        self.register_buffer("mask_value", torch.tensor(-1e8))

    def get_action(self, x, action=None, invalid_action_masks= None, envs=None, device=None):
        hidden = self.encoder(x)
        logits = self.actor(hidden)
        grid_logits = logits.reshape(-1, envs.action_plane_space.nvec.sum())
        split_logits = torch.split(grid_logits, envs.action_plane_space.nvec.tolist(), dim=1)

        if action is None:
            invalid_action_masks = invalid_action_masks.view(-1, invalid_action_masks.shape[-1])
            split_invalid_action_masks = torch.split(invalid_action_masks, envs.action_plane_space.nvec.tolist(), dim=1)
            multi_categoricals = [
                CategoricalMasked(logits=logits, masks=iam, mask_value=self.mask_value)
                for (logits, iam) in zip(split_logits, split_invalid_action_masks)
            ]
            action = torch.stack([categorical.sample() for categorical in multi_categoricals])
        else:
            invalid_action_masks = invalid_action_masks.view(-1, invalid_action_masks.shape[-1])
            action = action.view(-1, action.shape[-1]).T
            split_invalid_action_masks = torch.split(invalid_action_masks, envs.action_plane_space.nvec.tolist(), dim=1)
            multi_categoricals = [
                CategoricalMasked(logits=logits, masks=iam, mask_value=self.mask_value)
                for (logits, iam) in zip(split_logits, split_invalid_action_masks)
            ]
    
        
        logprob = torch.stack([categorical.log_prob(a) for a, categorical in zip(action, multi_categoricals)])
        entropy = torch.stack([categorical.entropy() for categorical in multi_categoricals])
        num_predicted_parameters = len(envs.action_plane_space.nvec)
        logprob = logprob.T.view(-1, self.mapsize, num_predicted_parameters)
        entropy = entropy.T.view(-1, self.mapsize, num_predicted_parameters)
        action = action.T.view(-1, self.mapsize, num_predicted_parameters)
        return action, logprob.sum(1).sum(1), entropy.sum(1).sum(1), invalid_action_masks


    
class SoftQNetwork(nn.Module):
    def __init__(self, envs, mapsize=16 * 16):
        super().__init__()
        self.mapsize = mapsize
        h, w, c = envs.observation_space.shape
        self.encoder = nn.Sequential(
            Transpose((0, 3, 1, 2)),
            layer_init(nn.Conv2d(c, 32, kernel_size=3, padding=1)),
            nn.MaxPool2d(3, stride=2, padding=1),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, kernel_size=3, padding=1)),
            nn.MaxPool2d(3, stride=2, padding=1),
            nn.ReLU(),
        )

        self.critic = nn.Sequential(
            nn.Flatten(),
            layer_init(nn.Linear(64 * 4 * 4, 128)),
            nn.ReLU(),
            layer_init(nn.Linear(128, 1), std=1),
        )
        self.register_buffer("mask_value", torch.tensor(-1e8))

    def get_value(self, x):
        return self.critic(self.encoder(x))
    

def build_models():
    actor = Actor(envs).to(device)
    qf1 = SoftQNetwork(envs).to(device)
    qf2 = SoftQNetwork(envs).to(device)
    qf1_target = SoftQNetwork(envs).to(device)
    qf2_target = SoftQNetwork(envs).to(device)
    qf1_target.load_state_dict(qf1.state_dict())
    qf2_target.load_state_dict(qf2.state_dict())
    q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=q_lr , eps=eps)
    actor_optimizer = optim.Adam(list(actor.parameters()), lr=policy_lr)
    
    
    #self.obs_shape, dtype=observation_space.dtype)
    
    #add autotune
    alpha = 0.2
    
    envs.observation_space.dtype = np.float32
    rb = ReplayBufferActs(
        buffer_size,
        envs.observation_space,
        envs.action_space,
        (mapsize, envs.action_plane_space.nvec.sum()),
        device,
        handle_timeout_termination=True,
    )
    start_time = time.time()
    # TRY NOT TO MODIFY: start the game
    obs = envs.reset()
    for global_step in range(total_timesteps):
        invalid_mask = torch.tensor(envs.get_action_mask()).to(device)
        if global_step < learning_starts:
            actions = np.array([envs.action_space.sample() for _ in range(envs.num_envs)])
            actions = actions.reshape(envs.num_envs, 256, 7)
            actions = torch.Tensor(actions).to(torch.int64).to(device)
        else:
            actions, _, _, _ = actor.get_action(torch.Tensor(obs).to(device), envs=envs, invalid_action_masks=invalid_mask, device=device)
            #action shape torch.Size([8, 256, 7])
        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards, dones, infos = envs.step(actions.cpu().numpy().reshape(envs.num_envs, -1))
        
        
        
        if "episode" in infos[0].keys():
            print(f"global_step={global_step}, episodic_return={infos[0]['episode']['r']}")
        real_next_obs = next_obs.copy()
        
#         for idx, d in enumerate(dones):
#             if d:
#                 print(infos[idx])
#                 real_next_obs[idx] = infos[idx]["terminal_observation"]
        
        #ugly hack because cannot support a single env for now!
        rb.add(obs[0], real_next_obs[0], actions[0].reshape(1, -1), invalid_mask[0], rewards[0], dones[0], [infos[0]])
        
        obs = next_obs
        if global_step > learning_starts and global_step % update_frequency == 0:
            data = rb.sample(batch_size)
            
            # CRITIC training
            with torch.no_grad():
                next_state_actions, next_state_log_pi, _, _= actor.get_action(data.next_observations, envs=envs, invalid_action_masks=data.masks, device=device)
                qf1_next_target = qf1_target.get_value(data.next_observations)
                qf2_next_target = qf2_target.get_value(data.next_observations)
                 #we can use the action probabilities instead of MC sampling to estimate the expectation
                min_qf_next_target = (
                    torch.min(qf1_next_target, qf2_next_target) - alpha * next_state_log_pi
                )
                # adapt Q-target for discrete Q-function
                min_qf_next_target = min_qf_next_target.sum(dim=1)[..., None]
                next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * gamma * (min_qf_next_target).view(-1)
            
            # Use Q-values only for the taken actions
            qf1_a_values = qf1.get_value(data.observations).view(-1)
            qf2_a_values = qf2.get_value(data.observations).view(-1)
            qf1_loss = F.mse_loss(qf1_a_values, next_q_value, reduction="mean")
            qf2_loss = F.mse_loss(qf2_a_values, next_q_value, reduction="mean")
            qf_loss = qf1_loss + qf2_loss
            q_optimizer.zero_grad()
            qf_loss.backward()
            q_optimizer.step()
            
            
            #actor trainign
            state_actions, state_log_pi, _, _= actor.get_action(data.observations, envs=envs, invalid_action_masks=data.masks, device=device)
            with torch.no_grad():
                qf1_pi = qf1.get_value(data.observations)
                qf2_pi = qf2.get_value(data.observations)
                min_qf_pi = torch.min(qf1_pi, qf2_pi)
            actor_loss = ((alpha * state_log_pi) - min_qf_pi).mean()
            actor_optimizer.zero_grad()
            actor_loss.backward()
            actor_optimizer.step()
            
            
            #print(global_step, target_network_frequency)
            
            # update the target networks
            if global_step % target_network_frequency == 0:
                for param, target_param in zip(qf1.parameters(), qf1_target.parameters()):
                    target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
                for param, target_param in zip(qf2.parameters(), qf2_target.parameters()):
                    target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
                print("HAS BEEN UPDATED")
            

build_models()



global_step=1866, episodic_return=1.0
global_step=3866, episodic_return=0.20000000298023224
global_step=5866, episodic_return=0.20000000298023224
global_step=7866, episodic_return=4.0
global_step=9866, episodic_return=5.0
global_step=11866, episodic_return=1.2000000476837158
global_step=13866, episodic_return=4.199999809265137
global_step=15866, episodic_return=1.0
global_step=17866, episodic_return=6.0
global_step=19866, episodic_return=0.20000000298023224
torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.4844, 0.5156, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.8255e-01, -1.2006e-01, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        ...,
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.4945, 0.5055, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.6923e-01, -1.4713e-01, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        ...,
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+

tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.]], grad_fn=<SWhereBackward>)
torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.5105, 0.4895, 0.0000, 0.0000, 0.0000, 0.0000],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [ 1.8497e-01,  1.4299e-01, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        ...,
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+

tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.]], grad_fn=<SWhereBackward>)
torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667

tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.]], grad_fn=<SWhereBackward>)
torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.166

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.]], grad_fn=<SWhereBackward>)
torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.5441, 0.4559, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.]], grad_fn=<SWhereBackward>)
torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.]])
torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_f

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.]], grad_fn=<SWhereBackward>)
torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.3152, 0.3700, 0.3148, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        ...,
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-8.8379e-02,  7.1738e-02, -8.9617e-02, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+

tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.]], grad_fn=<SWhereBackward>)
torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667

tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.]], grad_fn=<SWhereBackward>)
torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.5361, 0.4639, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        ...,
        [ 3.0099e-02, -1.1468e-01, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.5375, 0.4625, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        ...,
        [ 1.6907e-02, -1.3355e-01, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.4773, 0.5227, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-4.0791e-01, -3.1697e-01, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        ...,
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.]], grad_fn=<SWhereBackward>)
torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.2941, 0.3885, 0.3174, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        ...,
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.4460e-01,  1.3363e-01, -6.8233e-02, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+

tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.]], grad_fn=<SWhereBackward>)
torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.166

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.]])
torch.Size([64, 16,

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.]], grad_fn=<SWhereBackward>)
torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667

tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.]], grad_fn=<SWhereBackward>)
torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.166

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         

tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.]], grad_fn=<SWhereBackward>)
torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.166

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.]], grad_fn=<SWhereBackward>)
torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.]], grad_fn=<SWhereBackward>)
torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.]], grad_fn=<SWhereBackward>)
torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.166

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

tensor([[-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [ 3.7358e-02, -7.7647e-01, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        ...,
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08]], grad_fn=<SWhereBackward>)
torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.166

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.]], grad_fn=<SWhereBackward>)
torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.166

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.]], grad_fn=<SWhereBackward>)
torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.166

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.6435, 0.3565, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.6138e-01, -7.5190e-01, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        ...,
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.6684, 0.3316, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-8.5745e-02, -7.8662e-01, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        ...,
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.6454, 0.3546, 0.0000, 0.0000, 0.0000, 0.0000],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [ 5.8136e-01, -1.7453e-02, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        ...,
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.6755, 0.3245, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-2.8346e-02, -7.6147e-01, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        ...,
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+

tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.2634e-01, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        ...,
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.00

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.6613, 0.3387, 0.0000, 0.0000, 0.0000, 0.0000],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [ 5.9172e-01, -7.7331e-02, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        ...,
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+

tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

tensor([[-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [ 6.1693e-01, -7.2775e-02, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        ...,
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08],
        [-1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08, -1.0000e+08,
         -1.0000e+08]], grad_fn=<SWhereBackward>)
torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.166

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]])
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -1000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.]], grad_fn=<SWhereBackward>)
torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667

torch.Size([64, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([16384, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000

torch.Size([2, 16, 16, 27])
tensor([[0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        ...,
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       grad_fn=<SoftmaxBackward>)
torch.Size([512, 6])
tensor([[-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        ...,
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -100000000., -100000000.,
         -100000000.],
        [-100000000., -100000000., -100000000., -10000000

KeyboardInterrupt: 

In [None]:

class Actor(nn.Module):
    def __init__(self, envs, mapsize=16 * 16):
        super().__init__()
        self.mapsize = mapsize
        h, w, c = envs.observation_space.shape
        self.encoder = nn.Sequential(
            Transpose((0, 3, 1, 2)),
            layer_init(nn.Conv2d(c, 32, kernel_size=3, padding=1)),
            nn.MaxPool2d(3, stride=2, padding=1),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, kernel_size=3, padding=1)),
            nn.MaxPool2d(3, stride=2, padding=1),
            nn.ReLU(),
        )
        self.actor = nn.Sequential(
            layer_init(nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1)),
            nn.ReLU(),
            layer_init(nn.ConvTranspose2d(32, 78, 3, stride=2, padding=1, output_padding=1)),
            Transpose((0, 2, 3, 1)),
        )
        self.register_buffer("mask_value", torch.tensor(-1e8))

    def get_action(self, x, action=None, invalid_action_masks= None, envs=None, device=None):
        hidden = self.encoder(x)
        logits = self.actor(hidden)
        grid_logits = logits.reshape(-1, envs.action_plane_space.nvec.sum())
        split_logits = torch.split(grid_logits, envs.action_plane_space.nvec.tolist(), dim=1)

        if action is None:
            invalid_action_masks = invalid_action_masks.view(-1, invalid_action_masks.shape[-1])
            split_invalid_action_masks = torch.split(invalid_action_masks, envs.action_plane_space.nvec.tolist(), dim=1)
            multi_categoricals = [
                CategoricalMasked(logits=logits, masks=iam, mask_value=self.mask_value)
                for (logits, iam) in zip(split_logits, split_invalid_action_masks)
            ]
            action = torch.stack([categorical.sample() for categorical in multi_categoricals])
        else:
            invalid_action_masks = invalid_action_masks.view(-1, invalid_action_masks.shape[-1])
            action = action.view(-1, action.shape[-1]).T
            split_invalid_action_masks = torch.split(invalid_action_masks, envs.action_plane_space.nvec.tolist(), dim=1)
            multi_categoricals = [
                CategoricalMasked(logits=logits, masks=iam, mask_value=self.mask_value)
                for (logits, iam) in zip(split_logits, split_invalid_action_masks)
            ]
        
        logprob = torch.stack([categorical.log_prob(a) for a, categorical in zip(action, multi_categoricals)])
        probs = [categorical.probs for categorical in multi_categoricals]
        
        
        entropy = torch.stack([categorical.entropy() for categorical in multi_categoricals])
        num_predicted_parameters = len(envs.action_plane_space.nvec)
        logprob = logprob.T.view(-1, self.mapsize, num_predicted_parameters)
        entropy = entropy.T.view(-1, self.mapsize, num_predicted_parameters)
        action = action.T.view(-1, self.mapsize, num_predicted_parameters)
        probs = [prob.T.reshape(-1, self.mapsize * prob.shape[1]) for prob in probs]
        return action, logprob.sum(1).sum(1), entropy.sum(1).sum(1), invalid_action_masks, probs


    
class SoftQNetwork(nn.Module):
    def __init__(self, envs, mapsize=16 * 16):
        super().__init__()
        self.mapsize = mapsize
        h, w, c = envs.observation_space.shape
        self.encoder = nn.Sequential(
            Transpose((0, 3, 1, 2)),
            layer_init(nn.Conv2d(c, 32, kernel_size=3, padding=1)),
            nn.MaxPool2d(3, stride=2, padding=1),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, kernel_size=3, padding=1)),
            nn.MaxPool2d(3, stride=2, padding=1),
            nn.ReLU(),
        )

        self.critic = nn.Sequential(
            nn.Flatten(),
            layer_init(nn.Linear(64 * 4 * 4, 128)),
            nn.ReLU(),
            layer_init(nn.Linear(128, 1), std=1),
        )
        self.register_buffer("mask_value", torch.tensor(-1e8))

    def get_value(self, x):
        return self.critic(self.encoder(x))

def learn_CQL_SAC():
    actor = Actor(envs).to(device)
    qf1 = SoftQNetwork(envs).to(device)
    qf2 = SoftQNetwork(envs).to(device)
    qf1_target = SoftQNetwork(envs).to(device)
    qf2_target = SoftQNetwork(envs).to(device)
    qf1_target.load_state_dict(qf1.state_dict())
    qf2_target.load_state_dict(qf2.state_dict())
    q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=q_lr , eps=eps)
    actor_optimizer = optim.Adam(list(actor.parameters()), lr=policy_lr)
    
    clip_grad_param = 1
    #alpha
    learning_rate_alpha = 5e-4
    action_size = len(envs.action_space.sample())
    target_entropy = -action_size
    log_alpha = torch.tensor([0.0], requires_grad=True)
    alpha = log_alpha.exp().detach()
    alpha_optimizer = optim.Adam(params=[log_alpha], lr=learning_rate_alpha) 
    
    envs.observation_space.dtype = np.float32
    rb = ReplayBufferActs(
        buffer_size,
        envs.observation_space,
        envs.action_space,
        (mapsize, envs.action_plane_space.nvec.sum()),
        device,
        handle_timeout_termination=True,
    )
    start_time = time.time()
    # TRY NOT TO MODIFY: start the game
    obs = envs.reset()
    for global_step in range(total_timesteps):
        invalid_mask = torch.tensor(envs.get_action_mask()).to(device)
        if global_step < learning_starts:
            actions = np.array([envs.action_space.sample() for _ in range(envs.num_envs)])
            actions = actions.reshape(envs.num_envs, 256, 7)
            actions = torch.Tensor(actions).to(torch.int64).to(device)
        else:
            actions, _, _, _, _ = actor.get_action(torch.Tensor(obs).to(device), envs=envs, invalid_action_masks=invalid_mask, device=device)
            #action shape torch.Size([8, 256, 7])
        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards, dones, infos = envs.step(actions.cpu().numpy().reshape(envs.num_envs, -1))
        
        
        
        if "episode" in infos[0].keys():
            print(f"global_step={global_step}, episodic_return={infos[0]['episode']['r']}")
        real_next_obs = next_obs.copy()
        
        #ugly hack because cannot support a single env for now!
        rb.add(obs[0], real_next_obs[0], actions[0].reshape(1, -1), invalid_mask[0], rewards[0], dones[0], [infos[0]])
        
        obs = next_obs
        if global_step > learning_starts and global_step % update_frequency == 0:
            data = rb.sample(batch_size)
            current_alpha = copy.deepcopy(alpha)
            _, log_pis, _, _, action_probs = actor.get_action(data.observations, envs=envs, invalid_action_masks=data.masks, device=device)
            q1 = qf1.get_value(data.observations)
            q2 = qf2.get_value(data.observations)
            min_Q = torch.min(q1,q2)
            actor_loss = 0
            log_action_pi= []
            for act_prob in action_probs:
                actor_loss += (act_prob * (alpha.to(device) * log_pis.unsqueeze(1) - min_Q )).sum(1).mean()
                log_action_pi.append(torch.sum(log_pis.unsqueeze(1) * act_prob, dim=1))
            
            actor_optimizer.zero_grad()
            actor_loss.backward()
            actor_optimizer.step()
            
            #compute alpha loss
            alpha_loss = - (log_alpha.exp() * (log_pis.cpu() + target_entropy).detach().cpu()).mean()
            alpha_optimizer.zero_grad()
            alpha_loss.backward()
            alpha_optimizer.step()
            alpha = log_alpha.exp().detach()
            
            #update critic
            with torch.no_grad():
                #mask is not correct!
                _, next_state_log_pis, _, _, next_action_probs = actor.get_action(data.next_observations, envs=envs, invalid_action_masks=data.masks, device=device)
                qf1_next_target = qf1_target.get_value(data.next_observations)
                qf2_next_target = qf2_target.get_value(data.next_observations)
                #print((torch.min(qf1_next_target, qf2_next_target) - alpha.to(device) * next_state_log_pis.unsqueeze(1)).shape)
                Q_targets = []
                for act_prob_next in next_action_probs:
                    Q_target_next = act_prob_next * (torch.min(qf1_next_target, qf2_next_target) - alpha.to(device) * next_state_log_pis.unsqueeze(1))
                    Q_targets.append(data.rewards + (gamma * (1 - data.dones) * Q_target_next.sum(dim=1).unsqueeze(-1))) 
                
                Q_targets = torch.stack(Q_targets, dim=1).squeeze(2)
            
            q1_ = qf1.get_value(data.observations)
            q2_ = qf2.get_value(data.observations)
            
            
            critic_loss1 = 0.5 * F.mse_loss(q1_, Q_targets, reduction='none').mean(1).mean(0)
            critic_loss2 = 0.5 * F.mse_loss(q2_, Q_targets, reduction='none').mean(1).mean(0)
            
            cql1_scaled_loss = torch.logsumexp(q1_, dim=1).mean() - q1_.mean()
            cql2_scaled_loss = torch.logsumexp(q2_, dim=1).mean() - q2_.mean()
            
            total_c1_loss = critic_loss1 + cql1_scaled_loss
            total_c2_loss = critic_loss2 + cql2_scaled_loss
            
            q_optimizer.zero_grad()
            total_c1_loss.backward(retain_graph=True)
            clip_grad_norm_(qf1.parameters(), clip_grad_param)
            # critic 2
            total_c2_loss.backward()
            clip_grad_norm_(qf2.parameters(), clip_grad_param)
            q_optimizer.step()
            
            
            if global_step % target_network_frequency == 0:
                for param, target_param in zip(qf1.parameters(), qf1_target.parameters()):
                    target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
                for param, target_param in zip(qf2.parameters(), qf2_target.parameters()):
                    target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
                print("HAS BEEN UPDATED")
            
#             qf1_a_values = qf1.get_value(data.observations).view(-1)
#             qf2_a_values = qf2.get_value(data.observations).view(-1)
#             qf1_loss = F.mse_loss(qf1_a_values, next_q_value, reduction="mean")
#             qf2_loss = F.mse_loss(qf2_a_values, next_q_value, reduction="mean")
            
            
            
            
#             qf1_next_target = qf1_target.get_value(data.next_observations)
#             qf2_next_target = qf2_target.get_value(data.next_observations)
#                  #we can use the action probabilities instead of MC sampling to estimate the expectation
#                 min_qf_next_target = (
#                     torch.min(qf1_next_target, qf2_next_target) - alpha * next_state_log_pi
#                 )
#                 # adapt Q-target for discrete Q-function
#                 min_qf_next_target = min_qf_next_target.sum(dim=1)[..., None]
#                 next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * gamma * (min_qf_next_target).view(-1)
            
# #             raise("COOl")

learn_CQL_SAC()



global_step=1919, episodic_return=2.200000047683716
global_step=3919, episodic_return=1.2000000476837158
global_step=5919, episodic_return=4.0
global_step=7919, episodic_return=2.0
global_step=9919, episodic_return=1.0
global_step=11919, episodic_return=1.0
global_step=13919, episodic_return=1.2000000476837158
global_step=15919, episodic_return=0.20000000298023224
global_step=17919, episodic_return=2.0
global_step=19919, episodic_return=1.2000000476837158
global_step=21919, episodic_return=26.200000762939453
global_step=23919, episodic_return=7.199999809265137
global_step=25919, episodic_return=68.0
global_step=27919, episodic_return=21.0
global_step=29919, episodic_return=51.0
global_step=31919, episodic_return=78.0
global_step=33919, episodic_return=40.0
global_step=35919, episodic_return=25.0


In [112]:
#remarks
- Currently actor and critic do not share a common encoder (in original PPO they share a common encoder)
But according to authors of the SAC paper, sharing a CNN encoder between Actor and Critics is not recommended for SAC

- Autotune for alpha not implemented

- Does not handle 'terminal_observation' (not supported by env), check how is it handle by PPO

- Memory does not save yet the invalid masks!

- #ugly hack because cannot support a single env for now!
    rb.add(obs[0], real_next_obs[0], actions[0].reshape(1, -1), rewards[0], dones[0], [infos[0]])
    Change that to support a single env!
    
- critic should use actions too ? not just states? + recheck code !

ERRUER IMPORTANTE (ou pas)
- FOR NEXT STATE ACTOR, SHOULD USE NEXT MASK??? (store next mask too!!!)

- USE ACTION PROBS


/var/folders/b_/wlgny6454n9bk39_zzmgsm9m0000gn/T/ipykernel_58771/1710395406.py:183: UserWarning: Using a target size (torch.Size([64, 7])) that is different to the input size (torch.Size([64, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
            
            update of critic is likely to be wrong (actor should be okay , especally uses wrong masks!)
            
            
BUG OF CRITIC !!! (check with same update of critic as in original PPO)

SyntaxError: invalid syntax (784648057.py, line 2)

In [None]:
Original paper (discrete SAC)
https://arxiv.org/pdf/1910.07207.pdf
https://github.com/BY571/SAC_discrete/blob/main/agent.py
https://github.com/twni2016/pomdp-baselines/blob/main/policies/rl/sacd.py
https://github.com/BY571/CQL/blob/main/CQL-SAC-discrete/agent.py
    
    
Multidiscrete sac
https://github.com/Maggern3/SAC/blob/AltSAC/sac.py
https://github.com/BY571/SAC_discrete/blob/main/agent.py
https://github.com/twni2016/pomdp-baselines/blob/main/policies/rl/sacd.py
https://github.com/BY571/CQL/blob/main/CQL-SAC-discrete/agent.py