# Trainer NNX scratchpad

Generate training data: play 100 Connect 4 games between two random players and save their trajectories to disk.
```
python rgi/main.py --game connect4 --player1 random --player2 random --num_games 100 --save_trajectories
```

In [1]:
import os
from typing import Any

from collections import Counter

import jax
import jax.numpy as jnp
from flax import nnx

available_devices = jax.devices()
print(available_devices)
# This scratchpad is meant to train on a GPU; fail fast otherwise.
assert available_devices[0].platform == 'gpu'

# Widen printed lines so arrays don't wrap so early (numpy's default linewidth is 75).
jnp.set_printoptions(linewidth=150)

[CudaDevice(id=0)]


# Load Trajectories

In [2]:
from rgi.core import trajectory

game_name = "connect4"
trajectories_glob = os.path.join("..", "data", "trajectories", game_name, "*.trajectory.npy")
trajectories = trajectory.load_trajectories(trajectories_glob)

print(f'trajectories_glob: {trajectories_glob}')
print(f'num_trajectories: {len(trajectories)}')

trajectories_glob: ../data/trajectories/connect4/*.trajectory.npy
num_trajectories: 100


In [3]:
def print_trajectory_stats(trajs=None, num_actions: int = 7):
    """Print win/loss counts grouped by each trajectory's first action.

    For every trajectory, the first action (``t.actions[0]``) is paired with the
    first entry of ``t.final_rewards``. A reward of exactly ``1.0`` counts as a
    win and ``-1.0`` as a loss. Any other reward (e.g. a draw) is not counted.

    Args:
        trajs: Trajectories to summarize. Defaults to the notebook-global
            ``trajectories`` so the original zero-argument call keeps working.
        num_actions: Actions are assumed to be numbered ``1..num_actions``
            (7 columns for Connect 4).
    """
    if trajs is None:
        trajs = trajectories  # notebook global loaded in the cell above
    c = Counter((t.actions[0].item(), t.final_rewards[0].item()) for t in trajs)
    for action in range(1, num_actions + 1):
        win_count, loss_count = c[(action, 1.0)], c[(action, -1.0)]
        decided = win_count + loss_count
        # An opening move that never occurred (or only drew) would otherwise divide by zero.
        win_pct = f'{win_count / decided * 100:.2f}%' if decided else 'n/a'
        print(f'action: {action}, win: {win_count:3d}, loss: {loss_count:3d}, win_pct: {win_pct}')

# Outcome table (wins/losses) grouped by each loaded trajectory's first action.
print_trajectory_stats()

# num_trajectories: 100
# action: 1, win:  14, loss:   9, win_pct: 60.87%
# action: 2, win:   6, loss:   6, win_pct: 50.00%
# action: 3, win:   8, loss:   4, win_pct: 66.67%
# action: 4, win:  11, loss:   3, win_pct: 78.57%
# action: 5, win:  11, loss:   5, win_pct: 68.75%
# action: 6, win:   8, loss:   5, win_pct: 61.54%
# action: 7, win:   5, loss:   5, win_pct: 50.00%

2024-10-15 12:56:16.635965: W external/xla/xla/service/gpu/nvptx_compiler.cc:893] The NVIDIA driver's CUDA version is 12.4 which is older than the PTX compiler version 12.6.77. Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


action: 1, win:  14, loss:   9, win_pct: 60.87%
action: 2, win:   6, loss:   6, win_pct: 50.00%
action: 3, win:   8, loss:   4, win_pct: 66.67%
action: 4, win:  11, loss:   3, win_pct: 78.57%
action: 5, win:  11, loss:   5, win_pct: 68.75%
action: 6, win:   8, loss:   5, win_pct: 61.54%
action: 7, win:   5, loss:   5, win_pct: 50.00%


In [4]:
from rgi.core.trajectory import EncodedTrajectory

# define trajectory NamedTuple
from typing import NamedTuple
class TrajectoryStep(NamedTuple):
    """One transition taken from an encoded trajectory, as produced by ``unroll_trajectory``."""
    move_index: int  # 0-based index of this step within its trajectory
    state: jax.Array  # encoded state at index ``move_index``
    action: jax.Array  # action at index ``move_index`` (the move leading to ``next_state``)
    next_state: jax.Array  # encoded state at index ``move_index + 1``
    reward: jax.Array  # ``state_rewards[move_index]`` — TODO confirm whether this reward belongs to ``state`` or ``next_state``


def unroll_trajectory(encoded_trajectories: list[EncodedTrajectory]):
    """Lazily yield a ``TrajectoryStep`` for each consecutive pair of states.

    A trajectory with ``length`` entries contributes ``length - 1`` steps.
    Steps are yielded trajectory by trajectory, in order.
    """
    for traj in encoded_trajectories:
        states = traj.states
        for step_index in range(traj.length - 1):
            yield TrajectoryStep(
                move_index=step_index,
                state=states[step_index],
                action=traj.actions[step_index],
                next_state=states[step_index + 1],
                reward=traj.state_rewards[step_index],
            )

# Flatten all trajectories into a single list of per-move steps.
# (The misspelled name is kept because later cells reference it.)
all_trajecoty_steps = [*unroll_trajectory(trajectories)]
print(f'num_trajectory_steps: {len(all_trajecoty_steps)}')

num_trajectory_steps: 2193


In [5]:
# Gather every step's state and action into dense arrays for batched training.
state_batch = jnp.array([step.state for step in all_trajecoty_steps])
action_batch = jnp.array([step.action for step in all_trajecoty_steps])
batch_full = dict(state=state_batch, action=action_batch)

# Tiny batches holding only the first 1 or 2 steps, for quick tests.
batch_1 = {key: value[:1] for key, value in batch_full.items()}
batch_2 = {key: value[:2] for key, value in batch_full.items()}

# The same two steps repeated 1000 times: trivially learnable (100% accuracy is easy).
batch_repeated = {
    'state': jnp.tile(batch_2['state'], (1000, 1)),
    'action': jnp.tile(batch_2['action'], (1000,)),
}


# Define models

In [6]:
from rgi.games import connect4
# Connect 4 game and its state serializer.
game, serializer = connect4.Connect4Game(), connect4.Connect4Serializer()

In [7]:
class Connect4StateEmbedder(nnx.Module):
    """Embed a batch of encoded Connect 4 states into ``embedding_dim``-dim vectors.

    Pipeline:
    1. Drop the last element of each encoded state.
    2. Reshape the rest to a ``num_rows x num_cols`` board.
    3. Apply two 3x3 convolutions (32 then 64 channels, ReLU).
    4. Flatten, then apply a ``hidden_dim`` ReLU layer and a final linear projection.

    Assumes each encoded state is ``num_rows * num_cols`` board cells followed by
    one extra trailing value. TODO confirm against ``Connect4Serializer``.
    """
    # Class-level defaults; ``embedding_dim``/``hidden_dim`` may be overridden per instance in ``__init__``.
    embedding_dim: int = 64
    hidden_dim: int = 256
    # Board geometry assumed by ``_state_to_array``.
    num_rows: int = 6
    num_cols: int = 7

    def _state_to_array(self, encoded_state_batch: jax.Array) -> jax.Array:
        """Convert a (batch, encoded_len) array into a (batch, num_rows, num_cols) board array."""
        # The trailing element of each encoded state is not part of the board grid.
        board_array = encoded_state_batch[:, :-1].reshape([-1, self.num_rows, self.num_cols])
        return board_array

    def __init__(self, *, rngs: nnx.Rngs, embedding_dim: int | None = None, hidden_dim: int | None = None):
        """Create the conv and linear layers.

        Args:
            rngs: RNG streams used to initialize all layers.
            embedding_dim: Optional override of the output embedding size (default 64).
            hidden_dim: Optional override of the hidden layer size (default 256).
        """
        if embedding_dim is not None:
            self.embedding_dim = embedding_dim
        if hidden_dim is not None:
            self.hidden_dim = hidden_dim
        conv_out_features = 64
        self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
        self.conv2 = nnx.Conv(32, conv_out_features, kernel_size=(3, 3), rngs=rngs)
        # nnx.Conv defaults to 'SAME' padding, so the spatial size stays num_rows x num_cols.
        self.linear1 = nnx.Linear(self.num_rows * self.num_cols * conv_out_features, self.hidden_dim, rngs=rngs)
        self.linear2 = nnx.Linear(self.hidden_dim, self.embedding_dim, rngs=rngs)

    def __call__(self, encoded_state_batch: jax.Array):
        """Return a (batch, embedding_dim) array of state embeddings."""
        x = self._state_to_array(encoded_state_batch)
        x = x[..., None]  # add a single input-channel dimension (NHWC layout)
        x = nnx.relu(self.conv1(x))
        x = nnx.relu(self.conv2(x))
        x = x.reshape(x.shape[0], -1)  # flatten everything except the batch dimension
        x = nnx.relu(self.linear1(x))
        x = self.linear2(x)

        return x


In [19]:
def test_state_embedder():
    """Smoke test: embed the first state of the first trajectory with a fresh embedder and print it."""
    embedder = Connect4StateEmbedder(rngs=nnx.Rngs(0))
    first_state = trajectories[0].states[:1]
    print(embedder(first_state))

test_state_embedder()

[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
  0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]


In [10]:
class Connect4ActionEmbedder(nnx.Module):
    """Learned embedding table for Connect 4 column actions.

    Actions are 1-indexed (``1..num_actions``). They are shifted to the 0-indexed
    rows of an ``nnx.Embed`` table before lookup.
    """
    # Class-level defaults; may be overridden per instance in ``__init__``.
    embedding_dim: int = 64
    num_actions: int = 7

    def __init__(self, *, rngs: nnx.Rngs, embedding_dim: int | None = None, num_actions: int | None = None):
        """Create the embedding table.

        Args:
            rngs: RNG streams used to initialize the table.
            embedding_dim: Optional override of the embedding size (default 64).
            num_actions: Optional override of the number of actions (default 7).
        """
        if embedding_dim is not None:
            self.embedding_dim = embedding_dim
        if num_actions is not None:
            self.num_actions = num_actions
        self.embedding = nnx.Embed(num_embeddings=self.num_actions, features=self.embedding_dim, rngs=rngs)

    def __call__(self, action: jax.Array) -> jax.Array:
        """Look up embeddings for 1-indexed ``action`` values."""
        # Convert 1-indexed column actions to 0-indexed table rows.
        return self.embedding(action - 1)

    def all_action_embeddings(self):
        """Return the embeddings of every action ``1..num_actions``, one row per action, in order."""
        return self(jnp.arange(1, self.num_actions + 1))

In [21]:
def test_action_embedder():
    """Smoke test: embed the first action of the first trajectory with a fresh embedder and print it."""
    embedder = Connect4ActionEmbedder(rngs=nnx.Rngs(0))
    first_action = trajectories[0].actions[:1]
    print(embedder(first_action))

test_action_embedder()

[[ 6.66427687e-02  1.09368816e-01 -1.89152256e-01  8.17345604e-02 -7.85972252e-02 -1.82855129e-01  1.57606214e-01  1.16470948e-01 -7.91573077e-02
   1.19286872e-01 -3.06055341e-02 -2.94262134e-02 -2.04068035e-01  2.47045040e-01 -2.14104000e-02 -1.46370962e-01  3.72717567e-02  6.29766062e-02
  -4.52014022e-02 -2.30931133e-01  1.21301249e-01  2.88546652e-01 -3.97264175e-02 -1.19131254e-02  1.30750760e-01  1.29211560e-01  3.51449512e-02
   1.26254661e-02  7.07970336e-02  1.36475870e-02  3.25598754e-02  5.28224790e-03 -7.92904049e-02  2.16124147e-01  1.23112321e-01 -5.16295396e-02
  -5.90516999e-03 -7.26998299e-02  1.06619947e-01  4.65621985e-02  1.09553948e-01 -3.07245195e-01  6.44185543e-02 -5.80371059e-02  1.00223973e-01
  -4.86722216e-02  2.83067394e-02  2.01851264e-01  3.05441283e-02 -9.47859734e-02 -8.36363509e-02 -1.35540087e-02 -5.33435494e-02  7.62056708e-02
   8.15421045e-02  1.95111567e-03 -1.68878466e-01 -1.05658814e-01 -2.38340721e-02  6.14980869e-02  7.95577615e-02 -1.5707267

In [22]:
class StateToActionModel(nnx.Module):
    """Score every action for a batch of states by embedding dot products.

    ``logits[b, a]`` is the dot product of state ``b``'s embedding with action
    ``a``'s embedding (actions in order ``1..num_actions``).
    """
    state_embedder: Connect4StateEmbedder
    action_embedder: Connect4ActionEmbedder

    # NOTE(review): these two class attributes are not read anywhere in this class.
    embedding_dim: int = 64
    num_actions: int = 7

    def __init__(self, state_embedder: Connect4StateEmbedder, action_embedder: Connect4ActionEmbedder, *, rngs: nnx.Rngs):
        # ``rngs`` is accepted but unused: this model has no parameters of its own.
        self.state_embedder = state_embedder
        self.action_embedder = action_embedder

    def __call__(self, state: jax.Array) -> jax.Array:
        """Alias for ``logits``."""
        return self.logits(state)

    def logits(self, state_batch: jax.Array) -> jax.Array:
        """Return unnormalized action scores with shape (batch, num_actions)."""
        state_vectors = self.state_embedder(state_batch)
        return state_vectors @ self.action_embedder.all_action_embeddings().T

    def probs(self, state_batch: jax.Array) -> jax.Array:
        """Return the softmax of ``logits`` over the last (action) axis."""
        return jax.nn.softmax(self.logits(state_batch))


In [26]:
def test_state_action_model():
    """Print logits and softmax probabilities for the first two states under a freshly initialized model."""
    model = StateToActionModel(
        Connect4StateEmbedder(rngs=nnx.Rngs(0)),
        Connect4ActionEmbedder(rngs=nnx.Rngs(1)),
        rngs=nnx.Rngs(2),
    )
    first_two_states = state_batch[:2]
    print('logits:\n', model.logits(first_two_states))
    print('\nprobs:\n', model.probs(first_two_states))

test_state_action_model()


logits:
 [[ 0.          0.          0.          0.          0.          0.          0.        ]
 [-0.05365274  0.04191583 -0.05556455 -0.0902497  -0.04111161  0.00694049  0.01304063]]

probs:
 [[0.14285715 0.14285715 0.14285715 0.14285715 0.14285715 0.14285715 0.14285715]
 [0.13876376 0.15267958 0.13849871 0.13377723 0.14051497 0.14743187 0.14833395]]


# Train Model

In [63]:

# learning_rate = 0.005
# momentum = 0.9

# optimizer = nnx.Optimizer(state_embedder, optax.adamw(learning_rate, momentum))
# metrics = nnx.MultiMetric(accuracy=nnx.metrics.Accuracy(), loss=nnx.metrics.Average('loss'))


In [32]:
import optax

@nnx.jit
def l2_loss(model, alpha=0.001):
    """Return ``alpha * sum(w**2)`` over every ``nnx.Param`` leaf of ``model``.

    NOTE(review): ``loss_fn`` multiplies this result by its own ``l2_weight``
    (default 1e-4). With both defaults the effective L2 coefficient is 1e-7;
    confirm the double scaling is intended.
    """
    param_state = nnx.state(model, nnx.Param)
    total = 0
    for leaf in jax.tree.leaves(param_state):
        total = total + alpha * jnp.sum(leaf ** 2)
    return total

@nnx.jit
def loss_fn(state_embedder: Connect4StateEmbedder, action_embedder: Connect4ActionEmbedder, batch, l2_weight: float = 1e-4):
    """Compute the cross-entropy loss plus L2 penalty for predicting ``batch['action']``.

    Returns:
        ``(total_loss, logits)``. ``logits`` holds state-embedding · action-embedding
        scores, one column per action in order ``1..num_actions``.
    """
    logits = state_embedder(batch['state']) @ action_embedder.all_action_embeddings().T

    # Actions are 1-indexed; cross-entropy expects 0-indexed class ids.
    data_loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch['action'] - 1
    ).mean()

    # NOTE(review): l2_loss already scales by its own alpha (default 0.001).
    regularization = l2_weight * (l2_loss(state_embedder) + l2_loss(action_embedder))

    return data_loss + regularization, logits

In [34]:
def test_loss_fn():
    """Print the loss and logits of freshly initialized embedders on ``batch_2``."""
    loss, logits = loss_fn(
        Connect4StateEmbedder(rngs=nnx.Rngs(0)),
        Connect4ActionEmbedder(rngs=nnx.Rngs(1)),
        batch_2,
    )
    for label, value in (('loss: ', loss), ('logits: ', logits)):
        print(label, value)

test_loss_fn()

loss:  1.927142
logits:  [[ 0.          0.          0.          0.          0.          0.          0.        ]
 [-0.05365274  0.04191583 -0.05556455 -0.0902497  -0.04111161  0.00694049  0.01304064]]


In [35]:
@nnx.jit
def train_step(state_embedder: Connect4StateEmbedder, action_embedder: Connect4ActionEmbedder, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
    """Run one optimization step on ``batch`` and accumulate training metrics in place.

    ``nnx.value_and_grad`` differentiates w.r.t. its first argument by default,
    so only ``state_embedder`` receives gradients; ``action_embedder`` stays fixed.
    ``optimizer`` is presumably built over ``state_embedder`` (as in ``test_train_step``).
    """
    value_and_grad = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = value_and_grad(state_embedder, action_embedder, batch)
    zero_indexed_labels = batch['action'] - 1  # match the 0-indexed logit columns
    metrics.update(loss=loss, logits=logits, labels=zero_indexed_labels)
    optimizer.update(grads)

@nnx.jit
def eval_step(state_embedder: Connect4StateEmbedder, action_embedder: Connect4ActionEmbedder, metrics: nnx.MultiMetric, batch):
    """Evaluate ``batch`` without updating parameters and accumulate metrics in place.

    Actions in ``batch['action']`` are 1-indexed while logit columns are 0-indexed.
    The labels are therefore shifted by one, exactly as in ``train_step`` and ``loss_fn``.
    """
    loss, logits = loss_fn(state_embedder, action_embedder, batch)
    metrics.update(loss=loss, logits=logits, labels=batch['action'] - 1)  # In-place updates.

In [52]:
def test_train_step(num_steps: int = 1000, log_every: int = 100, num_probe_states: int = 5):
    """Train a fresh model on ``batch_full`` and periodically print progress.

    Every ``log_every`` steps, print the cumulative training metrics (they are never
    reset here). Then, for each of the first ``num_probe_states`` states in ``batch_full``,
    print the model's action probabilities (percent) next to the empirical action
    counts for that exact state in the data.
    """
    state_embedder = Connect4StateEmbedder(rngs=nnx.Rngs(10))
    action_embedder = Connect4ActionEmbedder(rngs=nnx.Rngs(1))
    state_action_model = StateToActionModel(state_embedder, action_embedder, rngs=nnx.Rngs(2))
    # train_step only differentiates w.r.t. state_embedder, so the optimizer wraps just that module.
    optimizer = nnx.Optimizer(state_embedder, optax.adamw(learning_rate=0.005))
    metrics = nnx.MultiMetric(accuracy=nnx.metrics.Accuracy(), loss=nnx.metrics.Average('loss'))

    states, actions = batch_full['state'], batch_full['action']
    probe_states = states[:num_probe_states]
    # Empirical action counts for every row that exactly equals each probe state (vectorized match).
    state_action_counts = [
        Counter(actions[(states == probe).all(axis=-1)].tolist())
        for probe in probe_states
    ]

    for step in range(num_steps + 1):
        train_step(state_embedder, action_embedder, optimizer, metrics, batch_full)
        if step % log_every != 0:
            continue
        print(f'train_step: {step}: {metrics.compute()}')

        # Distinct loop variable so the training-step counter is not shadowed.
        for probe_index, probe in enumerate(probe_states):
            probs = state_action_model.probs(probe[None, :])
            print(probe_index, jnp.array_str(probs * 100, precision=4, max_line_width=100, suppress_small=True), state_action_counts[probe_index])
        print()

test_train_step()

# train_step: 0: {'accuracy': Array(0.13269493, dtype=float32), 'loss': Array(1.9836767, dtype=float32)}
# 0 [[14.1977 11.6685 13.5787 15.1997 17.1362 14.6427 13.5763]] Counter({1: 23, 5: 16, 4: 14, 6: 13, 3: 12, 2: 12, 7: 10})
# 1 [[11.3556  3.2975  9.4457 16.5656 36.5498 13.3997  9.3861]] Counter({7: 3, 6: 3, 3: 3, 5: 2, 1: 1, 2: 1})
# 2 [[ 7.1729  0.8654  4.9158 14.0405 57.5077 10.0624  5.4355]] Counter({4: 1, 3: 1, 1: 1})
# 3 [[ 4.6876  0.2901  2.6779 10.9177 70.7835  7.5228  3.1204]] Counter({5: 1})
# 4 [[ 1.9264  0.0413  1.0377  6.3201 85.5749  4.0096  1.0899]] Counter({5: 1})

# train_step: 100: {'accuracy': Array(0.22706813, dtype=float32), 'loss': Array(1.9329642, dtype=float32)}
# 0 [[22.3287 10.9658 13.5952 12.6839 11.7041 13.5523 15.1701]] Counter({1: 23, 5: 16, 4: 14, 6: 13, 3: 12, 2: 12, 7: 10})
# 1 [[ 5.4769  7.0353 24.7796  4.1279 10.7706 23.9065 23.9032]] Counter({7: 3, 6: 3, 3: 3, 5: 2, 1: 1, 2: 1})
# 2 [[12.5476  6.2633 10.5911 20.2539 10.8778 19.9631 19.5032]] Counter({4: 1, 3: 1, 1: 1})
# 3 [[28.0764  3.5381  5.1928  9.8187 20.605   3.5495 29.2195]] Counter({5: 1})
# 4 [[25.1313  7.6401 17.2355 16.7269  8.0618 11.8552 13.3493]] Counter({5: 1})

# train_step: 200: {'accuracy': Array(0.44298345, dtype=float32), 'loss': Array(1.4044574, dtype=float32)}
# 0 [[23.3568 12.2055 11.3946 13.931  16.6376 13.0367  9.4378]] Counter({1: 23, 5: 16, 4: 14, 6: 13, 3: 12, 2: 12, 7: 10})
# 1 [[ 7.5375  7.4087 21.2666  0.0822 14.9679 23.1842 25.5529]] Counter({7: 3, 6: 3, 3: 3, 5: 2, 1: 1, 2: 1})
# 2 [[32.8782  0.1298 29.2799 36.4475  0.0006  0.0172  1.2468]] Counter({4: 1, 3: 1, 1: 1})
# 3 [[ 1.8262  0.0037  0.      1.5528 96.6167  0.0005  0.0001]] Counter({5: 1})
# 4 [[ 0.0001  0.      0.3623  0.0428 99.493   0.0952  0.0065]] Counter({5: 1})


train_step: 0: {'accuracy': Array(0.14363885, dtype=float32), 'loss': Array(1.9785088, dtype=float32)}
0 [[14.4563 14.4022 12.5562 17.7009 13.1094 15.1485 12.6265]] Counter({1: 23, 5: 16, 4: 14, 6: 13, 3: 12, 2: 12, 7: 10})
1 [[11.6371 12.0297  5.7556 42.9646  5.8367 16.224   5.5523]] Counter({7: 3, 6: 3, 3: 3, 5: 2, 1: 1, 2: 1})
2 [[ 7.8842  8.5035  2.2254 63.7363  2.6249 12.7118  2.3138]] Counter({4: 1, 3: 1, 1: 1})
3 [[ 3.7658  4.0274  0.5173 82.5937  0.7845  7.7762  0.5352]] Counter({5: 1})
4 [[ 0.9663  1.1183  0.0607 94.4268  0.131   3.2404  0.0565]] Counter({5: 1})

train_step: 100: {'accuracy': Array(0.29418987, dtype=float32), 'loss': Array(1.8454852, dtype=float32)}
0 [[22.5386 17.2494 11.2021 11.8529  9.2233 17.2992 10.6345]] Counter({1: 23, 5: 16, 4: 14, 6: 13, 3: 12, 2: 12, 7: 10})
1 [[ 7.4677  8.4804 27.2941  4.0218  5.8395 24.1607 22.7358]] Counter({7: 3, 6: 3, 3: 3, 5: 2, 1: 1, 2: 1})
2 [[14.6115  9.4252 23.7354 23.2455  6.8535 10.8455 11.2835]] Counter({4: 1, 3: 1, 1: 1

In [66]:
# TODO: How do we stream a bunch of different batches?
# train_batches = [(i, batch) for i in range(2000)]
# test_batches = [(i, batch) for i in range(100)]

#train_batches = [(i, batch_repeated) for i in range(2000)]
#test_batches = [(i, batch_repeated) for i in range(100)]

# NUM_TRAIN_BATCHES = 2000
NUM_TRAIN_BATCHES = 1
train_batches = [(i, batch_full) for i in range(NUM_TRAIN_BATCHES)]
test_batches = [(i, batch_full) for i in range(100)]

train_steps = 10000  # upper bound on the number of training steps taken below
eval_every = 200
batch_size = 32  # NOTE(review): unused -- every batch above is the full dataset.

# Build the model, optimizer and metrics in this cell so it runs from a fresh kernel.
# Elsewhere these names only exist as locals inside the test_* helpers.
state_embedder = Connect4StateEmbedder(rngs=nnx.Rngs(0))
action_embedder = Connect4ActionEmbedder(rngs=nnx.Rngs(1))
state_action_model = StateToActionModel(state_embedder, action_embedder, rngs=nnx.Rngs(2))
# train_step only differentiates w.r.t. state_embedder, so the optimizer wraps just that module.
optimizer = nnx.Optimizer(state_embedder, optax.adamw(learning_rate=0.005))
metrics = nnx.MultiMetric(accuracy=nnx.metrics.Accuracy(), loss=nnx.metrics.Average('loss'))

metrics_history = {
    'train_loss': [],
    'train_accuracy': [],
    'test_loss': [],
    'test_accuracy': [],
}


def _record_metrics(split: str) -> None:
    """Append the current ``metrics`` values to ``metrics_history`` under ``split`` and reset them."""
    for metric, value in metrics.compute().items():
        metrics_history[f'{split}_{metric}'].append(value)
    metrics.reset()


steps_to_run = train_batches[:train_steps]
last_step = len(steps_to_run) - 1
for step, batch in steps_to_run:
    # with jax.disable_jit():
    #   train_step(state_embedder, action_embedder, optimizer, metrics, batch)
    train_step(state_embedder, action_embedder, optimizer, metrics, batch)

    # Evaluate every `eval_every` steps and always after the final step, even when
    # there are fewer batches than `train_steps` (including a single-batch run).
    if step == last_step or (step > 0 and step % eval_every == 0):
        _record_metrics('train')  # training metrics accumulated since the previous reset

        for _, test_batch in test_batches:
            eval_step(state_embedder, action_embedder, metrics, test_batch)
        _record_metrics('test')

        print(
            f"[train] step: {step}, "
            f"loss: {metrics_history['train_loss'][-1]}, "
            f"accuracy: {metrics_history['train_accuracy'][-1] * 100}"
        )
        # print(
        #   f"[test] step: {step}, "
        #   f"loss: {metrics_history['test_loss'][-1]}, "
        #   f"accuracy: {metrics_history['test_accuracy'][-1] * 100}"
        # )


# ## Outputs with embedding normalization.
# [train] step: 200, loss: 1.4493087530136108, accuracy: 41.806243896484375
# [train] step: 400, loss: 0.495466947555542, accuracy: 77.71910858154297
# [train] step: 600, loss: 0.4014337956905365, accuracy: 80.45576477050781
# [train] step: 800, loss: 0.3858652412891388, accuracy: 80.7986831665039
# [train] step: 1000, loss: 0.38055431842803955, accuracy: 80.97492218017578
# [train] step: 1200, loss: 0.37459635734558105, accuracy: 81.28226470947266
# [train] step: 1400, loss: 0.3812824487686157, accuracy: 81.351806640625
# [train] step: 1600, loss: 0.366107702255249, accuracy: 81.67510986328125
# [train] step: 1800, loss: 0.3646901547908783, accuracy: 81.75855255126953
# [train] step: 2000, loss: 0.3648184537887573, accuracy: 81.75786590576172
# [train] step: 2200, loss: 0.35771697759628296, accuracy: 82.1707763671875
# [train] step: 2400, loss: 0.3564577102661133, accuracy: 82.21227264404297
# [train] step: 2600, loss: 0.3553982973098755, accuracy: 82.25535583496094
# [train] step: 2800, loss: 0.3551204204559326, accuracy: 82.25627136230469
# [train] step: 3000, loss: 0.35472187399864197, accuracy: 82.2571792602539
# [train] step: 3200, loss: 0.35463690757751465, accuracy: 82.25672912597656
# [train] step: 3400, loss: 0.35420480370521545, accuracy: 82.25741577148438
# [train] step: 3600, loss: 0.35807478427886963, accuracy: 82.2020034790039
# [train] step: 3800, loss: 0.35097551345825195, accuracy: 82.39557647705078
# [train] step: 4000, loss: 0.3513658940792084, accuracy: 82.39649200439453
# [train] step: 4200, loss: 0.3515341877937317, accuracy: 82.39580535888672
# [train] step: 4400, loss: 0.35152217745780945, accuracy: 82.39762878417969
# [train] step: 4600, loss: 0.35154837369918823, accuracy: 82.39649200439453
# [train] step: 4800, loss: 0.3515065014362335, accuracy: 82.40766143798828
# [train] step: 5000, loss: 0.3419311046600342, accuracy: 83.08686828613281
# ...
# [train] step: 19200, loss: 0.21814590692520142, accuracy: 90.15048217773438
# [train] step: 19400, loss: 0.2181217074394226, accuracy: 90.15048217773438
# [train] step: 19600, loss: 0.21812711656093597, accuracy: 90.15048217773438
# [train] step: 19800, loss: 0.21816301345825195, accuracy: 90.14934539794922


# logits: [[-0.12943102 -0.18300387 -0.31407255 -0.26985747 -0.03317889  0.8657205  -0.06745078]
#          [-0.3108766  -0.11808072 -0.00877056 -0.12999986 -0.18233624 -0.03913596   0.86498916]]
# probs: [[0.11739378 0.11127014 0.09760144 0.10201374 0.12925485 0.31756592  0.1249001 ]
#         [0.09565745 0.11599759 0.12939627 0.1146232  0.10877853 0.12552617  0.31002077]]


In [67]:
# Compare each of the first 10 states' empirical action counts (rows of batch_full
# that exactly match the state) with the model's predicted action probabilities (%).
for i in range(10):
    state = batch_full['state'][i]
    same_state_rows = (batch_full['state'] == state).all(axis=-1)
    counts = Counter(batch_full['action'][same_state_rows].tolist())
    probs = state_action_model.probs(state[None, :])
    print(i, jnp.array_str(probs * 100, precision=4, max_line_width=100, suppress_small=True), counts)

## No training.

# 0 [[16.0473 15.4123 13.6238 13.4007 12.6302 15.335  13.5506]] Counter({1: 23, 5: 16, 4: 14, 6: 13, 3: 12, 2: 12, 7: 10})
# 1 [[29.2187 19.2536 10.1736  7.9687  5.0374 18.8691  9.4789]] Counter({7: 3, 6: 3, 3: 3, 5: 2, 1: 1, 2: 1})
# 2 [[39.9711 18.2044  6.5298  5.4635  2.1863 20.3188  7.3262]] Counter({4: 1, 3: 1, 1: 1})
# 3 [[48.2277 17.6604  3.9698  3.1657  1.1044 21.3496  4.5223]] Counter({5: 1})
# 4 [[52.5742 19.8617  1.9912  1.4634  0.372  21.7475  1.9901]] Counter({5: 1})
# 5 [[59.2366 18.4123  1.1171  0.8116  0.1682 19.1115  1.1429]] Counter({7: 1})
# 6 [[62.148  21.6044  0.7341  0.471   0.061  14.4665  0.515 ]] Counter({1: 1})
# 7 [[63.5058 21.5263  0.4813  0.303   0.0317 13.7824  0.3696]] Counter({7: 1})
# 8 [[76.6878 13.1025  0.1152  0.0922  0.006   9.8978  0.0985]] Counter({1: 1})
# 9 [[79.2133 11.9533  0.0701  0.0548  0.0032  8.6434  0.0619]] Counter({1: 1})

0 [[16.0473 15.4123 13.6238 13.4007 12.6302 15.335  13.5506]] Counter({1: 23, 5: 16, 4: 14, 6: 13, 3: 12, 2: 12, 7: 10})
1 [[29.2187 19.2536 10.1736  7.9687  5.0374 18.8691  9.4789]] Counter({7: 3, 6: 3, 3: 3, 5: 2, 1: 1, 2: 1})
2 [[39.9711 18.2044  6.5298  5.4635  2.1863 20.3188  7.3262]] Counter({4: 1, 3: 1, 1: 1})
3 [[48.2277 17.6604  3.9698  3.1657  1.1044 21.3496  4.5223]] Counter({5: 1})
4 [[52.5742 19.8617  1.9912  1.4634  0.372  21.7475  1.9901]] Counter({5: 1})
5 [[59.2366 18.4123  1.1171  0.8116  0.1682 19.1115  1.1429]] Counter({7: 1})
6 [[62.148  21.6044  0.7341  0.471   0.061  14.4665  0.515 ]] Counter({1: 1})
7 [[63.5058 21.5263  0.4813  0.303   0.0317 13.7824  0.3696]] Counter({7: 1})
8 [[76.6878 13.1025  0.1152  0.0922  0.006   9.8978  0.0985]] Counter({1: 1})
9 [[79.2133 11.9533  0.0701  0.0548  0.0032  8.6434  0.0619]] Counter({1: 1})
