In [324]:
from src.match3 import *
from src.utils import InputEmbedCausalTransformer
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import Levenshtein
import json
import numpy as np
import os 
import pandas as pd
import random
import torch
import torch.nn.functional as F

In [3]:
MODEL_PATH = "/home/loki/projects/filler_tokens/output_dir/2024-08-14-22-matchdata-checkpoint-final/model_weights.pt"
CONFIG_FILE = "/home/loki/projects/filler_tokens/misc/llama_d384l4h6.json"

In [4]:
train_df = pd.read_csv('data/matchdata_trainset_2024-08-13.csv', header=None, names=["text"])
test_df = pd.read_csv('data/matchdata_testset_2024-08-13.csv', header=None, names=["text"])

In [314]:
train_set = Match3VectorDataset(train_df, 3, 10, 10, 'P')

validate encodings
raw input 0  339 234 230 125 222 811 686 534 369 258 P 1- 5 2- 6 3- 4 4- 5 0- 1 0- 9 0- 8 0- 9 0- 8 1- 6 1- 9 4- 6 5- 4 6- 8 1- 7 1- 9 1- 4 3- 5 4- 2 2- 0 2- 6 7- 4 2- 9 9- 8 4- 7 3- 9 6- 7 7- 6 3- 4 3- 7 5- 3 4- 0 4- 7 4- 1 4- 4 6- 9 7- 3 5- 0 9- 0 7- 1 8- 9 6- 8 8- 8 9- 8 8- 1 9- 1 A False
encoded sample 0 {'input_ids': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.]], dtype=torch.float16), 'labels': tensor([-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,    6,   30,
           7,   31,    8,   29,    9,   30,    5,   26,    5,   34,    5,   33,
           5,   34,    5,   33,    6,   31,    6,   34,    9,   31,   10,   29,
          11,   33,    6,   32,    6,   34,    6,   29,    8,   30,    9,   27,
           7,   25,    7,   31,   12,   29,    7,  

In [5]:
test_set = Match3VectorDataset(test_df, 3, 10, 10, 'P')
print(test_set.input_dim)

validate encodings
raw input 0  433 450 421 129 107 924 489 711 540 034 P 0- 3 0- 4 3- 2 4- 0 5- 5 0- 8 7- 1 0- 7 9- 4 1- 8 3- 5 1- 7 1- 3 6- 8 1- 1 8- 9 9- 4 3- 5 4- 2 5- 5 6- 8 2- 2 2- 6 9- 5 3- 2 5- 0 6- 5 3- 0 8- 9 9- 5 4- 2 6- 8 7- 1 8- 4 4- 3 6- 3 5- 6 8- 4 9- 8 7- 9 8- 9 6- 1 7- 2 7- 7 9- 7 9- 5 A False
encoded sample 0 {'input_ids': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.]], dtype=torch.float16), 'labels': tensor([-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,    5,   28,
           5,   29,    8,   27,    9,   25,   10,   30,    5,   33,   12,   26,
           5,   32,   14,   29,    6,   33,    8,   30,    6,   32,    6,   28,
          11,   33,    6,   26,   13,   34,   14,   29,    8,   30,    9,   27,
          10,   30,   11,   33,    7,   27,    7,  

In [6]:
test_set.tensorize_inputs_worker(pd.Series({"text":test_df.iloc[-1]}))

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 1.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.]]], dtype=torch.float16)

In [7]:
def string_to_input_tensors(input_string):
    return {"input_ids":test_set.tensorize_inputs_worker({"text":pd.Series([input_string], index=['text'], name='1999')}).squeeze()}

In [8]:
string_to_input_tensors("1 P A")

{'input_ids': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.]], dtype=torch.float16)}

In [9]:
def string_to_label_tensors(input_string):
    return {"labels":test_set.tensorize_labels_worker({"text":pd.Series([input_string], index=['text'], name='1999')})}

In [10]:
string_to_input_tensors("100 P A")

{'input_ids': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.]], dtype=torch.float16)}

In [387]:
class BlockOutputWrapper(torch.nn.Module):
    def __init__(self, block, unembed_matrix, norm):
        super().__init__()
        self.block = block
        self.unembed_matrix = unembed_matrix
        self.norm = norm
        self.block_output_unembedded = None

    def forward(self, *args, **kwargs):
        output = self.block(*args, **kwargs)
        if isinstance(output, tuple):
            self.block_output_unembedded = self.unembed_matrix(self.norm(output[0]))
            return output
        else:
            self.block_output_unembedded = self.unembed_matrix(self.norm(output))
            return output

    def reset_block_output(self):
        self.block_output_unembedded = None

class LlamaHelper:
    def __init__(self, config_file, model_path, test_set):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        config = AutoConfig.from_pretrained(config_file)
        model = InputEmbedCausalTransformer(AutoModelForCausalLM.from_config(config), test_set.input_dim)
        state_dict = torch.load(model_path)
        model.load_state_dict(state_dict, strict=False)
        model = model.to(self.device)
        self.model = model
        self.word_index_map = test_set.word_index_map
        self.data_len = test_set.data_len
        self.mod = test_set.mod
        self.input_dim = test_set.input_dim
        
        for i, layer in enumerate(self.model.base_model.model.layers):
            self.model.base_model.model.layers[i] = BlockOutputWrapper(layer, self.model.base_model.lm_head, self.model.base_model.model.norm)

    def decode_tensors(self, sequence):
        decoded_sequence = []
        marker_found = False

        for token in sequence:
            token = token.item()
            if token == -100:
                if not marker_found:
                    decoded_sequence.append("[MASK]")
                continue
            elif token == 0:
                decoded_sequence.append("[EOS]")
                break  # Stop decoding after EOS
            elif token < len(self.word_index_map):
                # Regular word
                word = list(self.word_index_map.keys())[list(self.word_index_map.values()).index(token)]
                decoded_sequence.append(word)
                if word in ["A", "P"]:
                    marker_found = True
            else:
                # Handle digit labels
                offset = len(self.word_index_map)
                if token < offset + self.data_len * 2:
                    # Tuple index encoding
                    idx = (token - offset) % self.data_len
                    tuple_pos = (token - offset) // self.data_len
                    decoded_sequence.append(f"{tuple_pos}-{idx}")
                else:
                    # Single digit or digit in tuple
                    char_pos = (token - offset - self.data_len * 2) // self.mod
                    digit = (token - offset - self.data_len * 2) % self.mod
                    if char_pos == 0 or len(decoded_sequence) == 0 or not decoded_sequence[-1][-1].isdigit():
                        decoded_sequence.append(str(digit))
                    else:
                        decoded_sequence[-1] += str(digit)

        return " ".join(decoded_sequence)


    def set_add_attn_output(self, layer, add_output):
        self.model.base_model.model.layers[layer].attn_add_tensor(add_output)

    def get_attn_activations(self, layer):
        return self.model.base_model.model.layers[layer].get_attn_activations()

    def reset_all_layers(self):
        for layer in self.model.base_model.model.layers:
            layer.reset_block_output()
            
    @staticmethod
    def get_tokens(model, layer_idx, input_ids, decode_tensors, num_layers, rank=1, device="cuda", skip_idx=None, skip_random=False):
        with torch.no_grad():
            outputs = model(input_ids.float().unsqueeze(0))
            logits = outputs.logits
        last_token_logits = logits[0, -1, :]
        if layer_idx < num_layers:
            layer = model.base_model.model.layers[layer_idx]
            if layer.block_output_unembedded is not None:
                last_token_logits = layer.block_output_unembedded[0, -1, :]

        # Get top k values and indices, where k is max(rank + 10, 100) to allow for skipping and random selection
        k = max(rank + 10, 100)
        val, idx = torch.topk(last_token_logits, k)

        # If skip_idx is provided, remove it from consideration
        if skip_idx is not None:
            mask = ~torch.isin(idx, torch.tensor(skip_idx, device=device))
            idx = idx[mask]
            val = val[mask]

        if skip_random and skip_idx is not None:
            # Randomly select from top 10 non-skipped tokens
            random_idx = torch.randint(0, min(10, len(idx)), (1,))
            selected_idx = random_idx.item()
        else:
            # Select the token at the specified rank (subtracting 1 because rank is 1-indexed)
            selected_idx = min(rank - 1, len(idx) - 1)

        token = decode_tensors(idx[selected_idx].unsqueeze(-1)).strip()
        return token, idx[selected_idx].item()


    def create_new_token_input(self, token_id):
        new_input = torch.zeros(1, self.input_dim, dtype=torch.float16)
        if token_id < len(self.word_index_map):
            new_input[0, token_id] = 1
        else:
            # Handle digit sequences
            offset = len(self.word_index_map)
            if token_id < offset + self.data_len * 2:
                # Tuple index encoding
                idx = (token_id - offset) % self.data_len
                tuple_pos = (token_id - offset) // self.data_len
                new_input[0, offset + tuple_pos * self.data_len + idx] = 1
            else:
                # Single digit or digit in tuple
                char_pos = (token_id - offset - self.data_len * 2) // self.mod
                digit = (token_id - offset - self.data_len * 2) % self.mod
                new_input[0, offset + self.data_len * 2 + char_pos * self.mod + digit] = 1
        return new_input
    

    def print_logit_progression(self, inputs,
                                max_new_tokens=len(test_set[0]['labels']),
                                layer_number=None,
                                rank=1,
                                skip_idx=None,
                                input_length=None,
                                skip_random=False):

        self.reset_all_layers()
        num_layers = len(self.model.base_model.model.layers)
        result_dict = {f"h{i}_out": [] for i in range(num_layers)}
        result_dict["h_out"] = []
        input_ids = inputs['input_ids'].to(self.device)
        if input_length:
            input_ids = input_ids[:input_length]
        generated_sequence = input_ids.clone()
        for _ in range(max_new_tokens):
            self.reset_all_layers()

            if layer_number is not None:
                if layer_number > num_layers:
                    print(f"Error: Layer number {layer_number} is out of range. Max layer is {num_layers}.")
                    return {}
                token, token_id = self.get_tokens(self.model, layer_number, generated_sequence, self.decode_tensors, num_layers, rank, self.device, skip_idx, skip_random)
                layer_name = f"h{layer_number}_out" if layer_number < num_layers else "h_out"
                result_dict[layer_name].append(token)
            else:
                for i in range(num_layers + 1):
                    token, token_id = self.get_tokens(self.model, i, generated_sequence, self.decode_tensors, num_layers, rank, self.device, skip_idx, skip_random)
                    layer_name = f"h{i}_out" if i < num_layers else "h_out"
                    result_dict[layer_name].append(token)

            if token in ["[EOS]","True","False"]: break
            try:
                new_token_input = self.create_new_token_input(token_id).to(self.device)

                generated_sequence = torch.cat([generated_sequence, new_token_input], dim=0)
            except:
                generated_sequence = torch.cat([generated_sequence], dim=0)
                break
        # Print results
        if layer_number is not None:
            layer_name = f"h{layer_number}_out" if layer_number < num_layers else "h_out"
            print(f"{rank}th highest logit for {layer_name}:")
            print(" ".join(result_dict[layer_name]))
        else:
            print(f"{rank}th highest logit:")
            for layer_name, tokens in result_dict.items():
                print(f"{layer_name:<5}: " + " ".join(tokens))
        return result_dict

    def get_layer_logits(self, inputs, layer_idx):
        self.reset_all_layers()
        num_layers = len(self.model.base_model.model.layers)
        
        if layer_idx > num_layers:
            raise ValueError(f"Error: Layer number {layer_idx} is out of range. Max layer is {num_layers}.")
        
        input_ids = inputs['input_ids'].to(self.device)
        
        with torch.no_grad():
            outputs = self.model(input_ids.float().unsqueeze(0))
            logits = outputs.logits
        
        if layer_idx < num_layers:
            layer = self.model.base_model.model.layers[layer_idx]
            if layer.block_output_unembedded is not None:
                logits = layer.block_output_unembedded
        
        return logits


In [388]:
model = LlamaHelper(CONFIG_FILE, MODEL_PATH, test_set)

In [390]:
model.print_logit_progression(test_set[0], rank=rank, layer_number=4, skip_idx=3, input_length=11, skip_random=True)

1th highest logit for h_out:
0-4 3 0-9 8 0-3 4 0-4 6 0-0 0 0-5 4 0-9 1 0-6 9 0-0 8 0-9 9 0-8 0-5 0-0 1 0-2 5 0-4 7 0-8 8 0-8 4 0-3 0 0-1 7 0-0 1 0-6 0-3 3 0-9 3 0-9 4 0-6 3 5 0-8 0 0-1 3 0-3 7 7 0-0 5 0-3 5 0-9 3 0-9 2 0-1 9 0-4 5 0-2 0-9 0-9 9 0-0 9 0-6 7 0-3 0-0 0-9 5 1 0-0 0-8 5 0-3 3 0-2 0-3 1 0-8 3 0-9 9 A 0-0 0-9 3 0-5 1 0-2 0-6 0-3 0-7 [EOS]


{'h0_out': [],
 'h1_out': [],
 'h2_out': [],
 'h3_out': [],
 'h_out': ['0-4',
  '3',
  '0-9',
  '8',
  '0-3',
  '4',
  '0-4',
  '6',
  '0-0',
  '0',
  '0-5',
  '4',
  '0-9',
  '1',
  '0-6',
  '9',
  '0-0',
  '8',
  '0-9',
  '9',
  '0-8',
  '0-5',
  '0-0',
  '1',
  '0-2',
  '5',
  '0-4',
  '7',
  '0-8',
  '8',
  '0-8',
  '4',
  '0-3',
  '0',
  '0-1',
  '7',
  '0-0',
  '1',
  '0-6',
  '0-3',
  '3',
  '0-9',
  '3',
  '0-9',
  '4',
  '0-6',
  '3',
  '5',
  '0-8',
  '0',
  '0-1',
  '3',
  '0-3',
  '7',
  '7',
  '0-0',
  '5',
  '0-3',
  '5',
  '0-9',
  '3',
  '0-9',
  '2',
  '0-1',
  '9',
  '0-4',
  '5',
  '0-2',
  '0-9',
  '0-9',
  '9',
  '0-0',
  '9',
  '0-6',
  '7',
  '0-3',
  '0-0',
  '0-9',
  '5',
  '1',
  '0-0',
  '0-8',
  '5',
  '0-3',
  '3',
  '0-2',
  '0-3',
  '1',
  '0-8',
  '3',
  '0-9',
  '9',
  'A',
  '0-0',
  '0-9',
  '3',
  '0-5',
  '1',
  '0-2',
  '0-6',
  '0-3',
  '0-7',
  '[EOS]']}

In [389]:
model.print_logit_progression(test_set[0], rank=rank, layer_number=4, skip_idx=3, input_length=11)

1th highest logit for h_out:
0-0 8 0-2 8 0-0 5 0-0 3 0-5 7 0-6 2 0-0 4 0-0 7 0-0 6 0-2 8 0-3 9 0-4 5 0-1 7 0-1 9 0-7 1 0-8 9 0-1 4 0-3 5 0-2 2 0-5 4 0-6 0 0-2 2 0-2 6 0-9 5 0-4 2 0-3 3 0-3 8 0-3 3 0-8 6 0-9 1 0-4 1 0-6 8 0-4 8 0-4 4 0-4 1 0-6 3 0-5 5 0-5 4 0-5 8 0-6 1 0-6 9 0-6 4 0-7 1 0-9 7 0-8 5 0-9 5 A False


{'h0_out': [],
 'h1_out': [],
 'h2_out': [],
 'h3_out': [],
 'h_out': ['0-0',
  '8',
  '0-2',
  '8',
  '0-0',
  '5',
  '0-0',
  '3',
  '0-5',
  '7',
  '0-6',
  '2',
  '0-0',
  '4',
  '0-0',
  '7',
  '0-0',
  '6',
  '0-2',
  '8',
  '0-3',
  '9',
  '0-4',
  '5',
  '0-1',
  '7',
  '0-1',
  '9',
  '0-7',
  '1',
  '0-8',
  '9',
  '0-1',
  '4',
  '0-3',
  '5',
  '0-2',
  '2',
  '0-5',
  '4',
  '0-6',
  '0',
  '0-2',
  '2',
  '0-2',
  '6',
  '0-9',
  '5',
  '0-4',
  '2',
  '0-3',
  '3',
  '0-3',
  '8',
  '0-3',
  '3',
  '0-8',
  '6',
  '0-9',
  '1',
  '0-4',
  '1',
  '0-6',
  '8',
  '0-4',
  '8',
  '0-4',
  '4',
  '0-4',
  '1',
  '0-6',
  '3',
  '0-5',
  '5',
  '0-5',
  '4',
  '0-5',
  '8',
  '0-6',
  '1',
  '0-6',
  '9',
  '0-6',
  '4',
  '0-7',
  '1',
  '0-9',
  '7',
  '0-8',
  '5',
  '0-9',
  '5',
  'A',
  'False']}

In [None]:
results_list = []
for rank in [1, 2, 3]:
    results_df = []
    for idx in tqdm(range(len(test_df))):
        # Without skip
        result = model.print_logit_progression(test_set[idx], rank=rank, input_length=11)
        
        # With skip
        result_skip = model.print_logit_progression(test_set[idx], rank=rank, layer_number=4, skip_idx=3, input_length=11)
        
        # With skip_random
        result_skip_random = model.print_logit_progression(test_set[idx], rank=rank, layer_number=4, skip_idx=3, input_length=11, skip_random=True)
        
        # Combine results
        for k, v in result_skip.items():
            result[k + "_skip"] = v
        
        for k, v in result_skip_random.items():
            result[k + "_skip_random"] = v
        
        results_df.append(result)
    
    results_list.append({
        'rank': rank,
        'results': results_df
    })

  0%|          | 0/2000 [00:00<?, ?it/s]

1th highest logit:
h0_out: . [EOS] A 0-0 0-0 [EOS] [EOS] [EOS] [EOS] [EOS] A [EOS] 0 0 0-2 0-9 4 4 0 9 0-0 0-0 1 1 [EOS] 2 A 0 [EOS] [EOS] 9 True 7 0-2 0-2 A A [EOS] [EOS] [EOS] A . . [EOS] A A A [EOS] [EOS] [EOS] [EOS] 3 4 1 0-0 0-0 5 9 9 [EOS] [EOS] [EOS] A 1 6 9 [EOS] [EOS] [EOS] [EOS] A 4 [EOS] 6 6 [EOS] [EOS] [EOS] [EOS] A 5 [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] A 2 2 [EOS] [EOS] A
h1_out: . 6 4 4 4 4 4 4 2 2 A A A A 2 2 4 4 3 3 2 2 4 4 4 A A A A A 8 2 2 A A A A 3 8 8 A A A 6 6 6 A 8 8 9 1 1 5 A 5 8 8 9 9 A [EOS] A A 0 2 2 2 A A A A A A A 5 A A 1 0 0 5 8 8 8 A A A 5 A A [EOS] [EOS] [EOS] [EOS] [EOS] A A 0 9 A [EOS] False
h2_out: . . . 4 4 . [EOS] . . . . . . . . . True True [EOS] A . . . . . . . . . . . . . . . . . . . . . A A . . . . . . . 8 . . . . . . . . . . . A . . . . . A A . . . A . . [EOS] . A . . . . . [EOS] . . . . A [EOS] [EOS] [EOS] [EOS] [EOS] A . . . A [EOS] False
h3_out: . . . . . . . . . . . . . . . . . . . . . . 

  0%|          | 1/2000 [00:02<1:10:15,  2.11s/it]

1th highest logit for h_out:
0-4 0-7 3 0-6 6 0-6 5 0-5 8 9 0-6 1 0-8 3 0-8 5 0-4 9 0-3 5 0-3 8 3 1 0-5 4 0-9 6 0-9 0 0-1 9 0-9 1 0-6 0 0-8 9 0-0 5 0-7 5 0-4 2 0-4 9 0-5 1 0-3 0 0-9 7 0-4 9 0-7 3 0-8 3 0-3 4 0-9 0 0-0 9 0-7 7 0-4 5 0-3 4 0-0 0-9 0-2 7 0-2 0 0-0 9 0-5 5 0-1 9 0-6 4 0-1 2 0-2 7 0-1 9 0-2 6 [EOS]
1th highest logit:
h0_out: . [EOS] [EOS] 0 6 A 1 0-6 [EOS] [EOS] [EOS] [EOS] [EOS] 1 A A [EOS] [EOS] A 9 0 0 0-0 0-0 0-0 0-0 . A A A 9 1 1 1 1 A [EOS] 9 9 9 A A [EOS] [EOS] A [EOS] A A A A A 0 0 5 0-0 0-0 . 1 A [EOS] A A 1 6 6 6 6 0 [EOS] [EOS] [EOS] A 1 A A [EOS] [EOS] [EOS] [EOS] . . [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS]
h1_out: . A 2 2 8 7 6 6 8 8 9 A A A 8 8 9 9 1 1 A 8 3 3 4 5 5 5 5 5 2 2 2 A 9 A 3 3 5 5 A A A 1 1 1 A 4 4 A 2 2 9 A 2 2 2 4 A A 2 2 A 2 4 4 4 A A A A A A A A A A 9 8 8 A 0 A A 8 A A [EOS] [EOS] A A A [EOS] [EOS] [EOS] A A 2 A A [EOS] False
h2_out: . . . A . . A A . . A . A . 

  0%|          | 2/2000 [00:04<1:12:12,  2.17s/it]

1th highest logit for h_out:
0-6 0-5 0-1 9 0-9 8 0-9 3 0-7 7 0-5 8 0-6 3 0-8 3 0-5 0 0-3 1 0-8 9 0-3 0 0-2 6 0-9 7 0-5 9 0-4 1 0-1 8 0-1 8 0-2 5 0-3 5 0-9 3 0-6 3 0-5 9 0-8 7 0-4 1 0-4 5 0-4 9 0-6 4 0-4 2 0-9 2 0-0 8 0-5 1 0-5 1 0-7 8 0-6 0-8 0-9 8 0-5 8 0-5 7 0-7 7 0-9 9 0-8 2 0-4 6 0-2 7 0-7 7 0-9 3 7 1 2 0-9 9 0-3 0-6 0-6 5 0-3 [EOS]
1th highest logit:
h0_out: . [EOS] 9 4 7 2 [EOS] [EOS] 2 2 4 4 6 8 6 6 0-0 0-0 1 A [EOS] [EOS] [EOS] [EOS] 1 1 2 A A A 1 A A A A 8 6 A A A A 6 6 A A 5 A . . [EOS] [EOS] [EOS] [EOS] 8 [EOS] [EOS] [EOS] 9 [EOS] A 1 1 A 1 6 [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] 9 1 A A A [EOS] A 5 5 6 [EOS] [EOS] [EOS] [EOS] A A 6 [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] 6 8 [EOS] [EOS] False
h1_out: . A A A 2 2 A A A A 0 0 A A A A 0 0 A A A 7 0 0 A 2 2 2 2 A 8 0 A A A 6 6 6 1 1 8 4 4 9 2 2 A 2 2 1 5 A 8 A 4 4 0 A 5 A 4 A A 9 A A A A A A A A A A 2 2 A 2 2 2 A 1 1 1 1 A A 4 4 A A A A [EOS] [EOS] A A A 4 [EOS] [EOS] 5
h2_out: . A A A 0-5 0-5 . . A . . . . . . . . 

  0%|          | 3/2000 [00:06<1:07:37,  2.03s/it]

1th highest logit for h_out:
0-0 1 0-0 3 0-0 4 0-4 5 0-0 0 0-6 8 0-7 4 0-8 4 0-9 9 0-2 4 0-3 5 0-4 8 0-1 0-9 0-1 1 0-1 5 0-8 5 0-9 2 0-3 5 0-2 0 0-2 0 0-2 7 0-7 9 0-2 6 0-9 4 0-4 7 0-5 0 0-3 5 0-7 2 0-8 8 0-9 2 0-4 7 0-4 1 0-4 1 0-4 1 0-4 9 0-6 7 0-5 5 0-5 5 0-5 4 0-6 4 0-6 4 0-6 9 0-8 8 0-9 5 0-9 8 0-9 8 A True
1th highest logit for h_out:
0-8 0 0-2 9 0-1 1 0-4 8 0-6 3 0-9 9 0-6 4 0-4 7 0-5 2 0-9 0 0-4 True
1th highest logit:
h0_out: . [EOS] 9 4 4 7 9 2 6 6 1 1 6 9 A A 6 3 4 1 1 0-0 [EOS] 4 [EOS] [EOS] 5 0 0 3 [EOS] 9 9 . . 3 A 2 7 7 6 . . 6 A A 0 9 9 6 9 9 9 A 6 6 6 . A 0 0 0 6 0 A 0 9 [EOS] [EOS] [EOS] A A A A A A [EOS] [EOS] 6 0 6 [EOS] [EOS] [EOS] [EOS] 6 6 [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] 6 [EOS] [EOS] [EOS] False
h1_out: . 5 7 7 6 6 2 2 8 0 0 A A 2 8 0 0 0 A A A 0 0 0 0 7 7 7 7 A 8 8 8 2 2 4 4 4 9 9 1 6 6 9 3 3 A 3 3 9 4 4 1 A 4 4 8 A 7 7 0 0 A 2 A A A A A A A A 3 A 0 0 1 1 1 1 2 A A A A A A 2 2 A 2 2 [EOS] [EOS] [EOS] A A A A A [EOS] False
h2_out: . A

  0%|          | 4/2000 [00:08<1:08:28,  2.06s/it]

1th highest logit for h_out:
0-0 2 0-5 3 0-6 1 0-5 5 0-0 2 0-1 0-9 0-8 0-9 0-8 3 0-3 4 0-2 5 0-4 0-9 0-6 8 0-5 3 0-6 3 0-1 5 0-6 0-7 0-8 2 0-4 8 0-1 4 0-3 9 0-6 9 0-4 9 0-2 6 0-3 0-9 0-9 3 0-2 8 0-0 7 0-9 8 0-5 3 0-1 8 0-5 2 0-6 0-7 0-7 5 0-9 0 0-7 9 0-6 2 0-9 7 0-2 7 0-7 9 0-8 5 0-6 6 0-8 3 0-2 5 0-7 5 0-4 0-4 0-0 3 A False
1th highest logit:
h0_out: . [EOS] 6 6 6 A 9 0 A A [EOS] [EOS] A . 8 8 5 5 2 8 0-7 0-1 0-0 0-0 0-7 3 A 3 3 7 A 2 A A A A A A A A A 9 9 9 A A A . . . . 2 A [EOS] 0-0 0-0 . . 0 2 A A 6 [EOS] 2 2 2 2 [EOS] [EOS] A . [EOS] 3 A [EOS] [EOS] 3 [EOS] 1 . [EOS] [EOS] [EOS] A [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] A
h1_out: . 2 6 6 7 7 A A 2 2 A A A A 2 2 9 9 8 8 2 2 5 5 9 9 9 9 9 9 3 3 3 2 2 A 0 0 5 5 A 8 7 8 8 8 A 8 8 A 3 6 6 A 7 7 7 A 2 2 2 2 A A A A A A A A A A 6 A A A A A 6 6 5 9 9 9 A A A A A A A A A [EOS] [EOS] A A A 9 A [EOS] False
h2_out: . . . . . . A . . . . . A . . . . . . . . . . . . . . . . . A A A . . . 

  0%|          | 5/2000 [00:10<1:08:57,  2.07s/it]

1th highest logit for h_out:
0-7 0 0-9 4 0-0 7 0-4 4 0-1 0-6 0-9 0-7 7 0-2 3 3 0-2 5 0-4 6 0-4 6 0-9 7 0-7 3 0-6 3 0-5 4 0-3 2 0-4 7 0-5 1 0-6 9 0-7 4 0-7 6 0-9 9 0-4 6 0-0 2 0-1 3 0-3 5 0-7 3 0-1 7 0-8 3 0-7 0 0-6 7 0-3 8 0-1 9 0-6 8 0-7 7 0-1 1 0-8 3 0-7 8 0-7 0 0-3 6 0-7 8 0-5 6 0-7 7 0-3 9 0-1 4 0-9 3 9 [EOS]
1th highest logit:
h0_out: 0-0 [EOS] [EOS] 2 0 [EOS] 9 4 8 8 [EOS] [EOS] [EOS] 9 0-0 0-0 0-0 0-0 9 9 A A 0-0 0-0 0-0 A . . . . 6 A A 1 1 [EOS] [EOS] [EOS] [EOS] [EOS] 9 . . 9 7 [EOS] . . . . 4 . A 0-2 0-0 0-0 . . A 5 3 . 4 8 . 2 2 A [EOS] [EOS] [EOS] [EOS] [EOS] A A [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] 0
h1_out: . 6 2 2 9 9 8 8 9 9 A A A 8 9 9 9 9 5 5 6 9 9 9 8 9 9 9 9 9 0 2 2 6 6 2 2 2 2 2 A 3 8 A A A A 8 8 9 2 2 A A 9 9 9 6 8 A A A A 1 A A A A A A A A A A 2 2 A A 8 8 9 8 8 A A A A A A A A A A A A A A A A A [EOS] False
h2_out: . . A A . . . . 9 . 0-9 6 A . . 

  0%|          | 6/2000 [00:12<1:09:12,  2.08s/it]

1th highest logit for h_out:
0-5 3 0-5 2 0-2 7 0-0 5 0-2 5 0-0 5 0-5 2 0-5 0-8 0-9 8 0-2 4 0-0 2 0-2 8 0-0 1 0-1 3 0-5 8 0-9 6 0-2 4 0-9 7 0-7 2 0-6 3 0-8 0 0-5 7 0-5 8 0-1 0-5 0-1 4 0-3 0-0 0-9 0-7 0-2 2 0-9 0 0-4 8 0-2 1 0-3 7 0-2 4 0-1 0-1 0-6 0-6 0-2 3 0-5 6 0-3 1 0-7 0 0-7 6 0-0 6 0-2 0-9 0-1 0-8 0-4 0 0-0 6 0-2 0-8 [EOS]
1th highest logit:
h0_out: 1-6 A A [EOS] 4 A 5 0 3 3 [EOS] [EOS] 2 6 8 4 1 1 A A 5 5 0-9 0-9 5 5 A [EOS] [EOS] [EOS] 2 6 6 A A A 4 A A A A . . 5 [EOS] [EOS] [EOS] 9 [EOS] [EOS] 1 [EOS] [EOS] 1 [EOS] [EOS] [EOS] 3 3 A 9 A 6 1 1 9 [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] . . 0 [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] 2 [EOS] [EOS] [EOS] False
h1_out: . 4 9 9 2 2 3 3 9 9 6 6 A A 9 9 5 6 A A A 9 6 6 5 5 5 5 5 4 9 A A A 4 A 7 A A A A A 2 A 9 9 A 6 6 4 6 6 7 A 1 1 A 8 0 A 0 0 A 2 9 9 9 A A A A A A A A A A A A A A A A A 8 9 A 6 6 A A A [EOS] [EOS] [EOS] A A 2 A [EOS] [EOS] False
h2_out

  0%|          | 7/2000 [00:14<1:09:20,  2.09s/it]

1th highest logit for h_out:
0-7 9 0-7 0 0-6 9 0-9 0 0-0 0 0-8 0-9 0-3 7 0-7 6 0-5 7 0-2 6 0-9 5 0-9 2 0-0 3 0-2 0 0-3 9 0-2 6 0-2 8 0-6 8 0-0 8 0-3 1 0-8 6 0-5 9 0-5 4 0-6 2 0-2 4 0-9 9 0-4 2 0-6 0-6 0-4 5 0-1 8 0-9 2 0-9 0-9 0-0 9 0-8 1 0-0 0-4 0-4 2 0-6 1 0-3 0 0-6 2 0-0 0-8 0-0 0-6 0-1 0-3 0-0 0-0 0-9 3 0-4 6 0-1 0-9 7 [EOS]
1th highest logit:
h0_out: . [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] 1 6 6 [EOS] [EOS] 8 1 2 2 True True 2 9 A 5 0-0 0-0 0-0 0-0 5 . . . . 8 8 4 . A A A A A A . . 6 6 0 . 9 9 . A . . A 9 0 . 1 1 6 3 9 [EOS] [EOS] [EOS] 1 4 [EOS] [EOS] [EOS] A 6 [EOS] [EOS] 5 [EOS] [EOS] 4 A 3 3 8 [EOS] [EOS] A A . . . [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] 8 8 8 [EOS] [EOS] A
h1_out: . 0 6 6 8 8 2 2 0 0 4 4 A A 0 2 6 6 A A A 2 4 4 4 0 2 2 2 2 5 5 5 2 6 4 4 A 3 3 A 0 1 7 8 8 A 4 4 6 8 8 1 3 6 6 6 2 A A 2 2 A 0 A A A A A A A A 4 4 0 0 A A 2 2 5 A A A 8 A A 4 A A 5 A [EOS] [EOS] [EOS] A A A A A [EOS] False
h2_out: . . 6 6 . . . . . . . . . . . . . . A A . . . . . A . . . . . . . . 

  0%|          | 8/2000 [00:16<1:09:29,  2.09s/it]

1th highest logit for h_out:
0-1 4 0-2 7 0-4 2 0-4 1 0-8 0-7 0-9 0-8 9 0-7 5 0-8 4 1 0-2 2 0-6 8 0-6 2 0-2 8 0-3 4 0-6 2 0-1 2 0-7 3 0-9 4 0-8 6 0-2 6 0-0 0 0-9 9 0-2 1 0-6 5 0-0 3 0-9 2 0-7 6 0-6 1 0-0 5 0-6 3 0-7 4 0-3 8 0-4 5 0-0 0-9 0-5 6 0-6 3 0-2 0-9 0-7 4 0-0 0-9 0-6 6 0-7 8 0-8 0 0-1 2 0-5 0-5 0 7 0-9 6 7 3 [EOS]
1th highest logit:
h0_out: 0-0 [EOS] 4 6 4 9 9 9 A A A A 3 2 [EOS] [EOS] A A 0 0 [EOS] [EOS] 0-8 0-2 0-2 1 6 A A A 8 0 0 6 6 9 [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] 3 [EOS] 3 . [EOS] [EOS] . 3 3 8 0-2 0-2 . 6 6 [EOS] . 7 7 6 6 6 3 [EOS] [EOS] [EOS] [EOS] 5 [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] . . 0 [EOS] [EOS] [EOS] [EOS] 6 A [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] A
h1_out: . 8 5 5 6 6 2 2 1 1 9 9 A A 1 1 3 3 A A 7 1 9 9 9 4 4 4 4 4 5 5 A 2 2 9 9 8 6 6 6 1 1 A A A A 6 6 A 8 8 A A 8 7 A A 7 A 2 A A 9 6 6 6 A A A A A A A 5 5 A 2 5 5 0 9 9 9 A A A 2 2 A 0 0 A [EOS] [EOS] A A A A A [EOS] False
h2_out: . . . 5 9 . A A . 

  0%|          | 9/2000 [00:18<1:10:06,  2.11s/it]

1th highest logit for h_out:
0-2 9 0-8 6 0-0 8 0-5 0-9 0-5 0 0-1 8 0-2 7 0-4 4 0-6 0-7 0-8 0 0-9 0-9 0-0 0-2 0 0-8 9 0-3 2 0-8 0-2 0-8 0-7 5 0-1 0-9 0-2 7 0-1 7 0-2 5 0-8 7 0-1 6 0-5 0 0-6 2 0-7 0 0-5 3 0-1 5 0-5 7 0-1 0-1 0-2 4 0-4 8 0-0 5 0-2 9 0-1 1 0-5 8 0-0 9 0-5 4 0-5 8 0-4 8 0-8 0 0-6 0-9 0-4 0-9 0-2 7 0-2 5 0-6 6 A False
1th highest logit:
h0_out: 0-1 [EOS] [EOS] [EOS] [EOS] A 7 0-0 A A [EOS] [EOS] 9 8 3 3 0 1 3 9 0-7 0-0 0-2 0-2 [EOS] A 1 . . . . A A 6 6 A A A 8 A A . . A A [EOS] . . . . . . . [EOS] [EOS] [EOS] [EOS] A A 6 A A A 5 3 A 4 [EOS] [EOS] [EOS] [EOS] 4 [EOS] [EOS] A [EOS] [EOS] [EOS] . 4 6 [EOS] [EOS] A A [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] 2 2 [EOS] [EOS] False
h1_out: . 4 A 9 2 A 9 9 4 4 [EOS] [EOS] A 9 4 4 5 5 0 A 4 4 5 5 A 4 4 4 4 4 7 7 A 8 8 9 9 9 5 A A 6 6 1 2 2 A 3 3 2 A 4 7 A 0 2 0 8 A A A A A 7 A A A A A A A A A A 8 8 A 7 A A 9 A A A A A A 9 A A 9 A [EOS] [EOS] [EOS] A A A A A [EOS] False
h2_out: . . . . A A . . . . . . A 

  0%|          | 10/2000 [00:20<1:10:14,  2.12s/it]

1th highest logit for h_out:
0-0 5 0-2 4 0-3 3 0-7 3 0-3 8 0-5 0 0-8 1 0-0 5 0-0 8 0-7 6 0-4 4 0-4 6 0-6 0-6 0-1 8 0-5 8 0-9 7 0-6 7 0-8 6 0-7 0-9 0-9 1 0-4 1 0-9 8 0-1 8 0-2 0 0-7 4 0-4 6 0-8 8 0-9 0-9 9 0-3 2 0-7 0-5 0-4 0-0 0-7 0-0 0-6 0 7 0-9 6 0-6 7 0-9 0 0-2 0-8 0-7 6 0-0 0-7 0-8 0-8 5 0-5 7 0-7 9 0-9 3 0-8 2 7 0-9 4 7 0 9 0 0-7 9 0 0-4 [EOS]
1th highest logit:
h0_out: . [EOS] A 2 2 8 8 5 3 3 3 3 . . . . 8 8 A 9 0-2 0-2 0-0 0-0 5 5 1 3 3 3 A . A 8 8 [EOS] 0 A A A A . 7 6 6 6 . 6 0 . . 6 [EOS] 6 0-0 0-0 . . 5 9 . . 3 5 6 6 5 [EOS] [EOS] [EOS] 3 . . . . [EOS] [EOS] [EOS] 6 . 3 [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] 3 [EOS] [EOS] [EOS] A
h1_out: . 5 A A 7 7 3 3 9 9 6 6 A 3 9 9 6 6 6 6 9 9 3 6 2 2 2 2 2 A 5 5 A A 7 4 4 4 9 9 A A 2 7 7 7 A 4 4 A 6 7 4 4 8 8 8 A 8 8 5 5 A A A A A A A A A A A A 7 A 2 2 7 9 0 A A A 0 8 A 8 8 A A A 0 A A A A A A A [EOS] False
h2_out: . . . . . . . . . . A A . . . . A A A A . . A A . . . . . . [EOS] 

  1%|          | 11/2000 [00:23<1:09:39,  2.10s/it]

1th highest logit for h_out:
0-8 8 0-4 7 0-7 2 0-8 0-9 0-8 0-9 0-5 0 0-5 8 0-1 8 0-2 1 0-7 0-1 4 0-0 0-1 6 0-4 4 0-7 0-9 0-6 2 0-5 5 0-7 3 0-9 0 0-9 4 0-6 4 0-6 4 0-5 2 0-8 0 0-2 0 0-2 5 0-2 9 0-8 8 0-0 9 0-0 0-2 0-3 0-6 0-3 6 0-7 2 0-4 7 0-0 1 0-4 5 0-2 7 0-7 8 0-1 4 0-6 9 0-1 0 0-3 True
1th highest logit:
h0_out: . [EOS] [EOS] [EOS] [EOS] 6 A 0 1 1 [EOS] [EOS] A 7 6 6 [EOS] [EOS] 3 9 0-0 0-0 . 0-2 A A 9 A [EOS] [EOS] A 6 6 4 4 A [EOS] A 5 8 A . . 3 A [EOS] [EOS] [EOS] [EOS] . 1 2 6 [EOS] [EOS] [EOS] [EOS] . . 9 A A [EOS] [EOS] [EOS] 1 4 6 [EOS] [EOS] [EOS] [EOS] 5 1 7 [EOS] [EOS] [EOS] . . 5 [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] A [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] [EOS] 5 0 [EOS] [EOS] A
h1_out: . A A 4 0 0 2 2 8 8 2 2 A 2 8 8 2 2 2 2 A 8 2 2 A A 5 7 7 A 1 1 1 A 6 6 6 6 5 5 7 6 6 3 3 3 A 1 1 0 7 7 6 A 5 5 5 A 3 A 6 6 A 9 9 9 9 A A A A A A A 5 5 A 8 4 4 A 9 9 9 A A A A A A A A A [EOS] [EOS] A A 9 A A [EOS] False
h2_out: . A . . . . . . . . . . A . . . . . . . . 8 . . A 0-2

  1%|          | 12/2000 [00:25<1:10:14,  2.12s/it]

1th highest logit for h_out:
0-5 2 7 2 0-0 1 0-7 2 0-9 0-9 0-2 0-5 0-9 0-8 4 0-4 1 0 0-5 3 0-1 1 0-4 9 0-1 4 0-2 4 0-3 3 0-5 6 0-9 5 0-0 2 0-4 0 0-1 0 0-1 0-4 0-9 0-1 6 0-2 9 0-8 4 0-0 0 0-9 3 0-4 0-9 0-6 4 0-8 0 0-4 0-8 0-6 9 0-9 1 0-1 0-6 0-2 8 0-5 6 0-8 6 0-8 5 9 0-4 3 0-7 1 0-1 0-2 3 0-9 9 0-5 1 0-9 3 7 5 0-9 6 3 0-7 3 9 3 7 0-8 0-2 0-2 2 3 0-7 6 0-6 1 0-2 0-2 7 0-7 0-4 0-7
