In [None]:
# Copyright 2025 Victor Semionov
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [None]:
import os
from inspect import signature

import numpy as np
import seaborn as sns
import yaml

from xlabml.datamodules import XLabDataModule
from xlabml import ROOT_DIR, CONF_DIR

sns.set_theme()

In [None]:
# Switch the process working directory to xlabml's ROOT_DIR.
# NOTE(review): presumably so relative paths used by the project/config resolve
# the same way regardless of where the notebook is launched — confirm.
os.chdir(ROOT_DIR)

In [None]:
# Read the default configuration. The file is decoded as UTF-8 explicitly so
# the result does not depend on the platform's locale encoding, and an empty
# file or a missing/null `data` section is treated as "no settings" instead
# of failing on None.
with open(CONF_DIR / 'defaults.yaml', encoding='utf-8') as f:
    config = yaml.safe_load(f) or {}
data_config = config.get('data') or {}

# Keep only the `data` entries whose key matches a parameter name of
# XLabDataModule.__init__ (in signature order).
init_parameters = signature(XLabDataModule.__init__).parameters
data_kwargs = {
    name: data_config[name]
    for name in init_parameters
    if name in data_config
}

In [None]:
# Override (or add) these two settings on top of whatever was taken from the
# configuration file.
data_kwargs['concatenate'] = True
data_kwargs['pad_incomplete'] = False

In [None]:
# Build the data module from the collected keyword arguments, then run its
# prepare_data() step. What that step does is defined in xlabml, not here.
datamodule = XLabDataModule(**data_kwargs)
datamodule.prepare_data()

In [None]:
# The 'train' entry of the data module's datasets; print_random_attention()
# takes its length and indexes into it.
sequence_dataset = datamodule.datasets['train']

In [None]:
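# Uncomment to seed NumPy's global RNG so that the random sequence/token picked
# by print_random_attention() is the same on every run:
# np.random.seed(42)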

In [None]:
def print_random_attention(seq_idx=None, word_idx=None):
    """Print one sequence with a single token's mask row highlighted.

    The selected token is printed in red, and every token whose entry in the
    selected token's mask row is truthy is underlined (ANSI escape codes).
    Tokens for which ``tokenizer.processor.is_byte`` is true are printed as
    their decoded text and are never highlighted.

    Reads the notebook globals ``sequence_dataset`` and ``datamodule``.

    Args:
        seq_idx: Index into ``sequence_dataset``. When ``None`` (default) it
            is drawn with ``np.random.randint``.
        word_idx: Position of the token inside the sequence; negative values
            count from the end. When ``None`` (default) it is drawn with
            ``np.random.randint``.

    Raises:
        IndexError: If an explicit ``word_idx`` lies outside the sequence.
    """
    # ANSI escape codes, see
    # https://stackoverflow.com/questions/287871/how-do-i-print-colored-text-to-the-terminal
    red = '\033[91m'
    underline = '\033[4m'
    reset = '\033[0m'

    # The sequence is drawn before the token, so an argument-less call consumes
    # NumPy's global RNG in a fixed order and is reproducible after seeding.
    if seq_idx is None:
        seq_idx = np.random.randint(len(sequence_dataset))
    indices, _, mask = sequence_dataset[seq_idx]

    n_tokens = len(indices)
    if word_idx is None:
        word_idx = np.random.randint(n_tokens)
    elif not -n_tokens <= word_idx < n_tokens:
        raise IndexError(
            f'word_idx {word_idx} is out of range for a sequence of {n_tokens} tokens'
        )
    else:
        # Normalise negative positions so the equality test below can match.
        word_idx %= n_tokens
    # Row of the mask belonging to the selected token.
    # NOTE(review): presumably the attention mask (which positions the selected
    # token may attend to) — confirm against the dataset implementation.
    word_attention = mask[word_idx]

    tokenizer = datamodule.tokenizer
    pieces = []
    for position, index in enumerate(indices):
        token_id = int(index)
        if tokenizer.processor.is_byte(token_id):
            # Byte tokens are shown as decoded text, without highlighting.
            pieces.append(tokenizer.decode([token_id]))
            continue
        piece = tokenizer.get_token(token_id)
        if position == word_idx:
            piece = f'{red}{piece}{reset}'
        if word_attention[position]:
            piece = f'{underline}{piece}{reset}'
        pieces.append(piece)
    # '▁' (U+2581) is shown as a plain space.
    # NOTE(review): presumably the tokenizer's whitespace marker — confirm.
    print(''.join(pieces).replace('▁', ' '))

In [None]:
# Print one randomly chosen, highlighted sample; re-run the cell for another.
print_random_attention()