In [1]:
import torch
from bertviz import head_view, model_view

from transformer_implementation import Transformer, Tokenizer, TransformerConfig

## Init

In [2]:
# init tokenizer
tokenizer = Tokenizer()

In [3]:
# init config
config = TransformerConfig(
    tokenizer,
    block_size = 256,
    batch_size = 12,
    n_layer = 3, # 6,
    n_head = 8,
    # n_embd = 512,
    max_iters = 2000,
    eval_iters = 50,
    eval_interval = 100,
    visualize = True,
)
print(config)

TransformerConfig(
	self.tokenizer=<transformer_implementation.Tokenizer.Tokenizer object at 0x0000020113801050>,
	self.block_size=256,
	self.batch_size=12,
	self.n_layer=3,
	self.n_head=8,
	self.n_embd=256,
	self.dropout=0.1,
	self.bias=False,
	self.device='cuda',
	self.learning_rate=0.0003,
	self.max_iters=2000,
	self.eval_interval=100,
	self.eval_iters=50,
	self.visualize=True,
)


In [4]:
# Create model
model = Transformer(config)
model.load_model("./out/transformer-train.pth")
model.eval()
model = model.to(config.device)

Number of Encoder parameters: 28.03M
number of Decoder parameters: 28.82M
Total number of parameters: 56.85M


In [21]:
def translate(sentences, tokenizer, model, config):
    """
    This function tokenizes input sentences, translates them using the provided model,
    and decodes the output into human-readable text. It also returns the attention dictionary from the model.

    Args:
        - sentences (list[str]): List of sentences to be translated.
        - tokenizer (Tokenizer): Tokenizer used for encoding and decoding sequences.
        - model (Transformer): The model used for translation.
        - config (Config): The configuration object that defines parameters like block_size.

    Returns:
        - decode_output (list[str]): List of translated sentences.
        - attn (dict): Dictionary containing attention information from the last layer of the model.
    """
    # Tokenize sentences
    tknzr = tokenizer.encoder
    sequences = []
    masks =  []

    # Encode each sentence and add it to the list of sequences
    for sentence in sentences:
        sequence = tokenizer.sequence_padding(tknzr.encode(sentence), config.block_size).unsqueeze(dim=0)
        mask = tokenizer.generate_padding_mask(sequence)
        sequences.append(sequence)
        masks.append(mask)

    # Concatenate the sequences into a tensor
    sequences = torch.cat(sequences, dim=0)
    masks = torch.cat(masks, dim=0)

    # Set the model to evaluation mode and translate sentences
    model.eval()
    outputs, attn = model.translate_beam_search(
        sequences.to(config.device),
        top_k=200,
        temperature=0.75,
        src_mask=masks.to(config.device)
    )

    # Initialize a list to store the decoded sentences
    decode_output = []
    print( outputs)
    # Decode each output sequence and add it to the list of decoded outputs
    for output in outputs:
        output = tokenizer.sequence_cleaner(output)
        decode_output += [tknzr.decode(output)]

    # Return the decoded sentences and the attention dictionary
    return decode_output, attn

In [27]:
input = ['Should the judiciary be deprived of power?']
# expected_output = ['Je suis un professeur.']

In [28]:
outputs, attentions = translate(input, tokenizer, model, config)

199/256tensor([[100264, 100264, 100264, 100264, 100264, 100264, 100264, 100264, 100264,
         100264, 100264, 100264, 100264, 100264, 100264, 100264, 100264, 100264,
         100264, 100264, 100264, 100264, 100264, 100264, 100264, 100264, 100264,
         100264, 100264, 100264, 100264, 100264, 100264, 100264, 100264, 100264,
         100264, 100264, 100264, 100264, 100264, 100264, 100264, 100264, 100264,
         100264, 100264, 100264, 100264, 100264, 100264, 100264, 100264, 100264,
         100264, 100264, 100264, 100264, 100264, 100264, 100264, 100264, 100264,
         100264, 100264, 100264, 100264, 100264, 100264, 100264, 100264, 100264,
         100264, 100264, 100264, 100264, 100264, 100264, 100264, 100264, 100264,
         100264, 100264, 100264, 100264, 100264, 100264, 100264, 100264, 100264,
         100264, 100264, 100264, 100264, 100264, 100264, 100264, 100264, 100264,
         100264, 100264, 100264, 100264, 100264, 100264, 100264, 100264, 100264,
         100264, 1002

In [29]:
outputs

[' ainsi ainsiur ( ainsiicic cetteH au développement..']

In [30]:
def format_attn(input, output, attentions, batch: int = 0):
    """
    This function formats the attention outputs and tokenized inputs and outputs for easier interpretation and visualization.

    Args:
        - input (str): The original input sentence.
        - output (str): The translated output sentence.
        - attentions (dict): A dictionary containing the attention information from the model.
        - batch (int, optional): The batch index to format. Defaults to 0.

    Returns:
        - tokens_input (list[str]): The tokenized input sentence, padded to max_len.
        - tokens_output (list[str]): The tokenized output sentence, padded to max_len.
        - tensor_encoder_attn (torch.Tensor): The attention tensor for the encoder, trimmed and reshaped.
        - tensor_cross_attn (torch.Tensor): The cross-attention tensor, trimmed and reshaped.
        - tensor_decoder_attn (torch.Tensor): The attention tensor for the decoder, trimmed and reshaped.
    """

    # Stack the attention tensors along a new dimension
    tensor_encoder_attn = torch.stack(attentions['encoder_attn'], dim=0)
    tensor_cross_attn = torch.stack(attentions['cross_attn'], dim=0)
    tensor_decoder_attn = torch.stack(attentions['decoder_attn'], dim=0)

    # Tokenize the input and output sentences
    tokens_input = tokenizer.tokenize_from_str(input[batch])
    tokens_output = tokenizer.tokenize_from_str(output[batch])

    # Find the maximum length of the input and output tokens
    max_len = min(len(tokens_input), len(tokens_output))

    # If the input tokens are shorter than the max length, pad with empty strings
    if len(tokens_input) < max_len:
        tokens_input = tokens_input + [''] * (max_len - len(tokens_input))
    # Otherwise, pad the output tokens with empty strings
    else:
        tokens_output = tokens_output + [''] * (max_len - len(tokens_output))

    # Trim and reshape the attention tensors
    tensor_encoder_attn = tensor_encoder_attn[:, batch:batch+1, :, 0:max_len, 0:max_len] # layers, batch, heads, seq_len, seq_len
    tensor_cross_attn = tensor_cross_attn[:, batch:batch+1, :, 0:max_len, 0:max_len] # layers, batch, heads, seq_len, seq_len
    tensor_decoder_attn = tensor_decoder_attn[:, batch:batch+1, :, 0:max_len, 0:max_len] # layers, batch, heads, seq_len, seq_len

    # Return the formatted tokens and attention tensors
    return tokens_input, tokens_output, tensor_encoder_attn, tensor_cross_attn, tensor_decoder_attn

In [31]:
tokens_input,\
tokens_output,\
tensor_encoder_attn,\
tensor_cross_attn,\
tensor_decoder_attn = format_attn(input, outputs, attentions)

In [32]:
html_model_view = model_view(
    encoder_attention=tensor_encoder_attn,
    decoder_attention=tensor_decoder_attn,
    cross_attention=tensor_cross_attn,
    encoder_tokens=tokens_input[0:tensor_decoder_attn.size(-1)],
    decoder_tokens=tokens_output[0:tensor_decoder_attn.size(-1)],
    html_action='return'
)
with open("./out/model_view.html", 'w') as file:
    file.write(html_model_view.data)

In [33]:
html_head_view = head_view(
    encoder_attention=tensor_encoder_attn,
    decoder_attention=tensor_decoder_attn,
    cross_attention=tensor_cross_attn,
    encoder_tokens=tokens_input[0:tensor_decoder_attn.size(-1)],
    decoder_tokens=tokens_output[0:tensor_decoder_attn.size(-1)],
    html_action='return'
)
with open("./out/head_view.html", 'w') as file:
    file.write(html_head_view.data)