In [1]:
# %%
import os
import sys

from matplotlib import animation
sys.path.append("/home/mbortkie/repos/crl_subgoal/src")
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# %%
import functools
import os
import distrax


import jax
import jax.numpy as jnp
from jax import random
from typing import Tuple, Dict, Any
from dataclasses import dataclass
import chex
from flax import struct
from absl import app, flags
from ml_collections import config_flags
from impls.agents import agents
from config import SRC_ROOT_DIR
from envs.block_moving_env import *
from train import *
from impls.utils.checkpoints import restore_agent, save_agent
from config import Config, ExpConfig
from envs import legal_envs
import matplotlib.pyplot as plt
from impls.utils.networks import GCDiscreteActor
import copy
import numpy as np




In [2]:
# %%
RANGE_GENERALIZATION = [1,2,3,4,5,6,7,9,11]
EPISODE_LENGTH = 100
NUM_ENVS = 1024
CHECKPOINT = 12
RUN_NAME = f"GCQIL_{CHECKPOINT}_ckpt_short_more_data"
MODEL_PATH = "/home/mbortkie/repos/crl_subgoal/experiments/test_generalization_sc_20250822_221136/runs/same_reward_0.5_expectile_next_state_remove_tgts_FIX_relabeling_future_moving_boxes_5_grid_5_range_3_7_alpha_0.3/"
EPOCHS = 101
EVAL_EVERY = 10
FIGURES_PATH = f"/home/mbortkie/repos/crl_subgoal/notebooks/figures/{RUN_NAME}"
GIF_PATH = f"{FIGURES_PATH}/gifs"
os.makedirs(FIGURES_PATH, exist_ok=True)
os.makedirs(GIF_PATH, exist_ok=True)


In [3]:
# %%
config = Config(
    exp=ExpConfig(seed=0, name="test"),
    env=BoxPushingConfig(
        grid_size=5,
        number_of_boxes_min=3,
        number_of_boxes_max=7,
        number_of_moving_boxes_max=5
    )
)

# %%
env = create_env(config.env)
env = AutoResetWrapper(env)
key = random.PRNGKey(config.exp.seed)
env.step = jax.jit(jax.vmap(env.step))
env.reset = jax.jit(jax.vmap(env.reset))
partial_flatten = functools.partial(flatten_batch, get_next_obs=config.agent.use_next_obs)
jitted_flatten_batch = jax.jit(jax.vmap(partial_flatten, in_axes=(None, 0, 0)), static_argnums=(0,))
dummy_timestep = env.get_dummy_timestep(key)


In [4]:
replay_buffer = jit_wrap(
    TrajectoryUniformSamplingQueue(
        max_replay_size=config.exp.max_replay_size,
        dummy_data_sample=dummy_timestep,
        sample_batch_size=config.exp.batch_size,
        num_envs=config.exp.num_envs,
        episode_length=config.env.episode_length,
    )
)
buffer_state = jax.jit(replay_buffer.init)(key)

example_batch = {
    'observations':dummy_timestep.grid.reshape(1, -1),  # Add batch dimension 
    'next_observations': dummy_timestep.grid.reshape(1, -1),
    'actions': jnp.ones((1,), dtype=jnp.int8) * (env._env.action_space-1), # TODO: make sure it should be the maximal value of action space  # Single action for batch size 1
    'rewards': dummy_timestep.reward.reshape(1, -1),
    'masks': 1.0 - dummy_timestep.reward.reshape(1, -1), 
    'value_goals': dummy_timestep.grid.reshape(1, -1),
    'actor_goals': dummy_timestep.grid.reshape(1, -1),
}

# %%
agent, config = restore_agent(example_batch, MODEL_PATH, CHECKPOINT)

# %%
keys = random.split(random.PRNGKey(0), NUM_ENVS)
state, info = env.reset(keys)

# %%
dummy_timestep = env.get_dummy_timestep(key)

replay_buffer = jit_wrap(
    TrajectoryUniformSamplingQueue(
        max_replay_size=config.exp.max_replay_size,
        dummy_data_sample=dummy_timestep,
        sample_batch_size=config.exp.batch_size,
        num_envs=config.exp.num_envs,
        episode_length=config.env.episode_length,
    )
)
buffer_state = jax.jit(replay_buffer.init)(key)

Restored from /home/mbortkie/repos/crl_subgoal/experiments/test_generalization_sc_20250822_221136/runs/same_reward_0.5_expectile_next_state_remove_tgts_FIX_relabeling_future_moving_boxes_5_grid_5_range_3_7_alpha_0.3//params_12.pkl


In [5]:
@jax.jit
def make_batch(buffer_state, key):
    key, batch_key, double_batch_key = jax.random.split(key, 3)
    # Sample and process transitions
    buffer_state, transitions = replay_buffer.sample(buffer_state)
    batch_keys = jax.random.split(batch_key, transitions.grid.shape[0])
    state, next_state, future_state, goal_index = jitted_flatten_batch(config.exp.gamma, transitions, batch_keys)

    state, actions, next_state, future_state, goal_index = get_single_pair_from_every_env(state, next_state, future_state, goal_index, double_batch_key, use_double_batch_trick=config.exp.use_double_batch_trick)
    if not config.exp.use_targets:
        state = state.replace(grid=GridStatesEnum.remove_targets(state.grid))
        next_state = next_state.replace(grid=GridStatesEnum.remove_targets(next_state.grid))
        future_state = future_state.replace(grid=GridStatesEnum.remove_targets(future_state.grid))
    # Create valid batch
    batch = {
        'observations': state.grid.reshape(state.grid.shape[0], -1),
        'next_observations': next_state.grid.reshape(next_state.grid.shape[0], -1),
        'actions': actions.squeeze(),
        'rewards': state.reward.reshape(state.reward.shape[0], -1),
        'masks': 1.0 - state.reward.reshape(state.reward.shape[0], -1), # TODO: add success and reward separately
        'value_goals': future_state.grid.reshape(future_state.grid.shape[0], -1),
        'actor_goals': future_state.grid.reshape(future_state.grid.shape[0], -1),
    }
    return buffer_state, batch

def value_transform(x):
    return jnp.log(jnp.maximum(x, 1e-6))

# Data creation

In [6]:
data_key = random.PRNGKey(0)
_, _, timesteps = collect_data(agent, data_key, env, config.exp.num_envs, config.env.episode_length, use_targets=config.exp.use_targets)
buffer_state = replay_buffer.insert(buffer_state, timesteps)

2025-08-22 22:52:47.221164: W external/xla/xla/service/gpu/autotuning/dot_search_space.cc:200] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs?Working around this by using the full hints set instead.


In [7]:
batch_keys = jax.random.split(data_key, config.exp.num_envs)

In [8]:
batch_keys.shape

(1024, 2)

In [9]:
timesteps = jax.tree_util.tree_map(lambda x: x.swapaxes(1, 0), timesteps)
timesteps.grid.shape

(1024, 100, 5, 5)

In [10]:
state, next_state, future_state, goal_index = jitted_flatten_batch(config.exp.gamma, timesteps, batch_keys)

In [11]:
state.grid.shape, next_state.grid.shape, future_state.grid.shape

((1024, 99, 5, 5), (1024, 99, 5, 5), (1024, 99, 5, 5))

In [12]:
for i in range(30):
    print("________________")
    print(state.grid[0, i])
    print(next_state.grid[0, i])
    print(future_state.grid[0, i])

________________
[[ 0 10  2  0  0]
 [ 2  1  1  0  0]
 [ 6  0  0  0  0]
 [ 2  0  1  2  0]
 [ 0  0  0  1  1]]
[[ 0 10  2  0  0]
 [ 2  1  1  0  0]
 [ 2  3  0  0  0]
 [ 2  0  1  2  0]
 [ 0  0  0  1  1]]
[[ 0 10 10  0  0]
 [ 2  1  0  0  0]
 [ 2  0  0  0  0]
 [ 2  0  0 10  0]
 [ 0  0  4  0  1]]
________________
[[ 0 10  2  0  0]
 [ 2  1  1  0  0]
 [ 2  3  0  0  0]
 [ 2  0  1  2  0]
 [ 0  0  0  1  1]]
[[ 0 10  2  0  0]
 [ 2  1  1  0  0]
 [ 2  0  3  0  0]
 [ 2  0  1  2  0]
 [ 0  0  0  1  1]]
[[ 0 10  2  0  0]
 [ 2  1  1  0  0]
 [ 2  0  0  3  0]
 [ 2  0  1  2  0]
 [ 0  0  0  1  1]]
________________
[[ 0 10  2  0  0]
 [ 2  1  1  0  0]
 [ 2  0  3  0  0]
 [ 2  0  1  2  0]
 [ 0  0  0  1  1]]
[[ 0 10  2  0  0]
 [ 2  1  1  0  0]
 [ 2  0  0  3  0]
 [ 2  0  1  2  0]
 [ 0  0  0  1  1]]
[[ 0 10  2  0  0]
 [ 2  1  1  0  0]
 [ 2  0  0  0  0]
 [ 2  0  0 10  0]
 [ 0  0  0  5  1]]
________________
[[ 0 10  2  0  0]
 [ 2  1  1  0  0]
 [ 2  0  0  3  0]
 [ 2  0  1  2  0]
 [ 0  0  0  1  1]]
[[ 0 10  2  0  0]
 [ 2

In [13]:
jnp.sum(next_state.grid == GridStatesEnum.BOX_ON_TARGET, axis=(-1,-2))[:2]

Array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 1, 2, 1, 1, 1, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]], dtype=int32)

In [14]:
(jnp.sum(next_state.grid == GridStatesEnum.BOX_ON_TARGET, axis=(-1,-2))
+ jnp.sum(next_state.grid == GridStatesEnum.AGENT_ON_TARGET_WITH_BOX, axis=(-1,-2))
+ jnp.sum(next_state.grid == GridStatesEnum.AGENT_ON_TARGET_WITH_BOX_CARRYING_BOX, axis=(-1,-2)))[:2]

Array([[1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]], dtype=int32)

# Reward relabelling

In [15]:
boxes_on_targets_next = (
    jnp.sum(next_state.grid == GridStatesEnum.BOX_ON_TARGET,  axis=(-1,-2))
    + jnp.sum(next_state.grid == GridStatesEnum.AGENT_ON_TARGET_WITH_BOX,  axis=(-1,-2))
    + jnp.sum(next_state.grid == GridStatesEnum.AGENT_ON_TARGET_WITH_BOX_CARRYING_BOX, axis=(-1,-2))
)
boxes_on_targets_future = (
    jnp.sum(future_state.grid == GridStatesEnum.BOX_ON_TARGET,  axis=(-1,-2))
    + jnp.sum(future_state.grid == GridStatesEnum.AGENT_ON_TARGET_WITH_BOX,  axis=(-1,-2))
    + jnp.sum(future_state.grid == GridStatesEnum.AGENT_ON_TARGET_WITH_BOX_CARRYING_BOX,  axis=(-1,-2))
)

In [16]:
boxes_on_targets_future

Array([[3, 1, 2, ..., 3, 3, 3],
       [0, 2, 0, ..., 2, 2, 2],
       [3, 2, 3, ..., 4, 4, 4],
       ...,
       [2, 4, 3, ..., 6, 6, 6],
       [3, 5, 3, ..., 6, 6, 6],
       [4, 1, 3, ..., 5, 5, 5]], dtype=int32)

In [17]:
jnp.array((boxes_on_targets_future-boxes_on_targets_next)[:2]==0, dtype=jnp.float16)

Array([[0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 1.,
        0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
        0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1.],
       [1., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 1., 1., 1.,
        0., 1., 0., 0., 1., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1.]], dtype=float16)

# GCQIL investigation

In [18]:
def expectile_loss(adv, diff, expectile):
    """Compute the expectile loss."""
    weight = jnp.where(adv >= 0, expectile, (1 - expectile))
    return weight * (diff**2)

def value_loss(agent, batch, grad_params):
    """Compute the IQL value loss."""
    q1, q2 = agent.network.select('target_critic')(batch['observations'], batch['value_goals'], batch['actions'])
    q = jnp.minimum(q1, q2)
    v = agent.network.select('value')(batch['observations'], batch['value_goals'], params=grad_params)
    value_loss = agent.expectile_loss(q - v, q - v, agent.config['expectile']).mean()

    return value_loss, {
        'value_loss': value_loss,
        'v_mean': v.mean(),
        'v_max': v.max(),
        'v_min': v.min(),
    }

def critic_loss(agent, batch, grad_params):
    """Compute the IQL critic loss."""
    next_v = agent.network.select('value')(batch['next_observations'], batch['value_goals'])
    q = batch['rewards'] + agent.config['discount'] * batch['masks'] * next_v

    q1, q2 = agent.network.select('critic')(
        batch['observations'], batch['value_goals'], batch['actions'], params=grad_params
    )
    critic_loss = ((q1 - q) ** 2 + (q2 - q) ** 2).mean()

    return critic_loss, {
        'critic_loss': critic_loss,
        'q_mean': q.mean(),
        'q_max': q.max(),
        'q_min': q.min(),
    }

In [19]:
key = jax.random.PRNGKey(2)
buffer_state, batch = make_batch(buffer_state, key)

In [20]:
for k, v in batch.items():
    print(f"{k}: {v.shape}")

actions: (1024,)
actor_goals: (1024, 25)
masks: (1024, 1)
next_observations: (1024, 25)
observations: (1024, 25)
rewards: (1024, 1)
value_goals: (1024, 25)


In [21]:
batch['rewards'].squeeze()[:100]

Array([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int8)

In [22]:
value_loss(agent, batch, None)

2025-08-22 22:53:02.938828: W external/xla/xla/service/gpu/autotuning/dot_search_space.cc:200] All configs were filtered out because none of them sufficiently match the hints. Maybe the hints set does not contain a good representative set of valid configs?Working around this by using the full hints set instead.


(Array(2.953126e-11, dtype=float32),
 {'value_loss': Array(2.953126e-11, dtype=float32),
  'v_mean': Array(0.77267396, dtype=float32),
  'v_max': Array(0.7726959, dtype=float32),
  'v_min': Array(0.7726529, dtype=float32)})

In [47]:
batch['observations'][3].reshape(5,5)

Array([[0, 0, 0, 0, 0],
       [0, 0, 0, 0, 1],
       [1, 0, 0, 0, 0],
       [3, 0, 0, 0, 0],
       [1, 0, 0, 1, 0]], dtype=int8)

In [48]:
batch['next_observations'][3].reshape(5,5)

Array([[0, 0, 0, 0, 0],
       [0, 0, 0, 0, 1],
       [1, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [5, 0, 0, 1, 0]], dtype=int8)

In [25]:
q1, q2 = agent.network.select('target_critic')(batch['observations'], batch['value_goals'], batch['actions'])

In [26]:
q1.mean(), q1.max(), q1.min()

(Array(0.7726741, dtype=float32),
 Array(0.772676, dtype=float32),
 Array(0.77267134, dtype=float32))

In [27]:
q2.mean(), q2.max(), q2.min()

(Array(0.77267534, dtype=float32),
 Array(0.7727233, dtype=float32),
 Array(0.77264476, dtype=float32))

In [28]:
next_v = agent.network.select('value')(batch['next_observations'], batch['value_goals'])
next_v.mean(), next_v.max(), next_v.min()

(Array(0.7726742, dtype=float32),
 Array(0.7727013, dtype=float32),
 Array(0.77264297, dtype=float32))

In [29]:
batch['masks']

Array([[1.],
       [1.],
       [1.],
       ...,
       [1.],
       [1.],
       [1.]], dtype=float32, weak_type=True)

In [30]:
agent.config['discount'] * batch['masks'] * next_v

Array([[0.76494855, 0.76494765, 0.7649525 , ..., 0.76495105, 0.7649447 ,
        0.76493776],
       [0.76494855, 0.76494765, 0.7649525 , ..., 0.76495105, 0.7649447 ,
        0.76493776],
       [0.76494855, 0.76494765, 0.7649525 , ..., 0.76495105, 0.7649447 ,
        0.76493776],
       ...,
       [0.76494855, 0.76494765, 0.7649525 , ..., 0.76495105, 0.7649447 ,
        0.76493776],
       [0.76494855, 0.76494765, 0.7649525 , ..., 0.76495105, 0.7649447 ,
        0.76493776],
       [0.76494855, 0.76494765, 0.7649525 , ..., 0.76495105, 0.7649447 ,
        0.76493776]], dtype=float32)

In [31]:
(batch['rewards'] + agent.config['discount'] * batch['masks'] * next_v).max()

Array(1., dtype=float32)

In [32]:
(batch['rewards']==1).any(), (batch['masks']==1).any()

(Array(True, dtype=bool), Array(True, dtype=bool))

In [33]:
((batch['masks']==1) & (batch['rewards']==1)).any()

Array(False, dtype=bool)

# Trajectory investigation

In [34]:
use_targets = False


timesteps_tr = jax.lax.cond(
    use_targets,
    lambda: timesteps.replace(),
    lambda: timesteps.replace(
        grid=GridStatesEnum.remove_targets(timesteps.grid),
        goal=GridStatesEnum.remove_targets(timesteps.goal)
    )
)

tr_states = timesteps_tr.grid
tr_actions = timesteps_tr.action
tr_goals = timesteps_tr.goal
tr_states.shape, tr_goals.shape, tr_actions.shape

((1024, 100, 5, 5), (1024, 100, 5, 5), (1024, 100))

In [35]:
use_targets = False


timesteps_tr = jax.lax.cond(
    use_targets,
    lambda: timesteps.replace(),
    lambda: timesteps.replace(
        grid=GridStatesEnum.remove_targets(timesteps.grid),
        goal=GridStatesEnum.remove_targets(timesteps.goal)
    )
)

tr_states = timesteps_tr.grid
tr_actions = timesteps_tr.action
tr_goals = timesteps_tr.goal
tr_states.shape, tr_goals.shape, tr_actions.shape

((1024, 100, 5, 5), (1024, 100, 5, 5), (1024, 100))

In [36]:
tr_states, tr_goals, tr_actions = tr_states.reshape((-1, tr_states.shape[-2],tr_states.shape[-1])), tr_goals.reshape((-1, tr_goals.shape[-2],tr_goals.shape[-1])), tr_actions.reshape((-1,))
tr_states.shape, tr_goals.shape, tr_actions.shape

((102400, 5, 5), (102400, 5, 5), (102400,))

In [37]:
tr_states, tr_goals = tr_states.reshape(tr_states.shape[0], -1), tr_goals.reshape(tr_goals.shape[0], -1)
tr_states.shape, tr_goals.shape, tr_actions.shape

((102400, 25), (102400, 25), (102400,))

In [38]:
timesteps.grid[0,0]

Array([[ 0, 10,  2,  0,  0],
       [ 2,  1,  1,  0,  0],
       [ 6,  0,  0,  0,  0],
       [ 2,  0,  1,  2,  0],
       [ 0,  0,  0,  1,  1]], dtype=int8)

In [39]:
tr_states, tr_goals, tr_actions = tr_states.reshape((-1, tr_states.shape[-2],tr_states.shape[-1])), tr_goals.reshape((-1, tr_goals.shape[-2],tr_goals.shape[-1])), tr_actions.reshape((-1,))
tr_states.shape, tr_goals.shape, tr_actions.shape

((1, 102400, 25), (1, 102400, 25), (102400,))

In [40]:
tr_states, tr_goals = tr_states.reshape(tr_states.shape[0], -1), tr_goals.reshape(tr_goals.shape[0], -1)
tr_states.shape, tr_goals.shape, tr_actions.shape

((1, 2560000), (1, 2560000), (102400,))

In [41]:
timesteps.grid[0,0]

Array([[ 0, 10,  2,  0,  0],
       [ 2,  1,  1,  0,  0],
       [ 6,  0,  0,  0,  0],
       [ 2,  0,  1,  2,  0],
       [ 0,  0,  0,  1,  1]], dtype=int8)