# Transformer Training

Requirements:
 - Python 3.7+
 - PyTorch
 - transformers (Hugging Face)
 - datasets (Hugging Face)
 - tqdm
 - seaborn
 - matplotlib

In [None]:
# Import cell
import os
from datasets import load_dataset
import torch
import transformers
import seaborn as sns
import matplotlib.pyplot as plt
from transformers import (AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig,
                          DataCollatorForSeq2Seq, Seq2SeqTrainingArguments,
                          Seq2SeqTrainer, pipeline)

In [None]:

# Runtime setup.
# Set TOKENIZERS_PARALLELISM=false for the Hugging Face tokenizers library.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Only let transformers log at WARNING level and above.
transformers.logging.set_verbosity_warning()

# Pick the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("device", device)

## English-to-French translation with T5

### Inference with a pretrained model: en-fr translation

In [None]:
# ...

### Fine-tuning a pretrained model

In [None]:
# Load dataset (Opus Books) and split into train and validation
# The loaded dataset is indexed by split name and only its 'train' split is
# kept here. The train/validation split itself happens in the elided code
# below. A later cell indexes translation_ds['train'], so translation_ds is
# presumably rebuilt as a mapping of splits there - confirm.
translation_ds = load_dataset("opus_books", "en-fr")['train']
# ...

In [None]:
# Tokenizer


# Transformer


# Data collator (pads the sequences dynamically)


In [None]:
# Sanity check: run the tokenizer on two sample sentences and display the result.
tokenizer(
    ["I like Transformers.", "The movie right?"]
)

In [None]:
# Pre-process dataset for T5
def preprocess_ds(tokenizer, dataset, input_key, output_key, prefix='',
                  max_length=128):
    """Pre-process ``dataset`` for T5 training.

    NOTE(review): the body is elided in this copy of the notebook, so the
    parameter roles below are assumptions to confirm against the full version.

    Args:
        tokenizer: tokenizer used to encode the texts (presumably the T5
            tokenizer created in the tokenizer cell).
        dataset: dataset to process.
        input_key: key of the source-text field in each example
            (presumably the English side of the pair).
        output_key: key of the target-text field (presumably the French side).
        prefix: string presumably prepended to every input as a T5 task
            prefix; defaults to ''.
        max_length: maximum token length, presumably used for truncation;
            defaults to 128.

    Returns:
        input_dataset: the processed dataset. As visible here it is never
        assigned, so the elided body must define it.
    """
    # ...
    return input_dataset

# ...

# Inspect the field names of the first training example.
translation_ds['train'][0].keys()


In [None]:
def train_model(model):
    """Train ``model`` on the translation data (body elided in this copy).

    The notebook calls it on both the pretrained ``transformer`` and the
    randomly initialized ``new_transformer``.

    NOTE(review): presumably builds ``Seq2SeqTrainingArguments`` and a
    ``Seq2SeqTrainer`` (both imported at the top of the notebook) and runs
    training. Confirm against the full version.
    """
    # ...

In [None]:
# Fine-tune the pretrained model.
train_model(model=transformer)

In [None]:
# checkpoint = torch.load('results/pytorch_model.bin')
# transformer.load_state_dict(checkpoint)

In [None]:
# `pipeline` takes an integer device index: 0 when CUDA is available, -1 otherwise.
device_num = -1
if torch.cuda.is_available():
    device_num = 0

pipeline_translation = pipeline(
    "translation_en_to_fr",
    model=transformer,
    tokenizer=tokenizer,
    device=device_num,
)

# Translate two sample sentences and keep only the generated text of each result.
outputs = pipeline_translation(["I like Transformers.", "The movie right?"])
output_text = [result['translation_text'] for result in outputs]
print(output_text)

### Train a model from random initialization

In [None]:
# Get the base config of the required architecture

# Change the configuration (optional)

# Build the model



In [None]:
# Train the model built from a fresh config (random initialization).
train_model(model=new_transformer)

### Attention visualization

In [None]:
def draw(data, x, y, ax):
    """Draw ``data`` as a square heatmap on the matplotlib axes ``ax``.

    ``x`` supplies the column tick labels and ``y`` the row tick labels.
    The colour scale is fixed to [0.0, 1.0] with the "cool" colormap and a
    colour bar, so callers are expected to normalize ``data`` into that range.
    """
    heatmap_style = dict(square=True, vmin=0.0, vmax=1.0, cbar=True, cmap="cool")
    sns.heatmap(data, xticklabels=x, yticklabels=y, ax=ax, **heatmap_style)

def remove_underscore(tokens):
    """Return a new list with every U+2581 ('▁') character stripped from each token.

    U+2581 (``chr(9601)``) is the word-start marker carried by SentencePiece
    tokens such as T5's; removing it makes the tokens readable as plot labels.
    """
    # Mapping a code point to None makes str.translate delete that character.
    deletion_table = {0x2581: None}
    return [tok.translate(deletion_table) for tok in tokens]

# Build readable token labels for the attention plots: special tokens are
# skipped and the U+2581 word markers are stripped.
input_tokens = remove_underscore(
    tokenizer.convert_ids_to_tokens(input_ids[0], skip_special_tokens=True)
)
output_tokens = remove_underscore(
    tokenizer.convert_ids_to_tokens(
        output_with_attention.sequences[0], skip_special_tokens=True
    )
)

print(input_tokens)
print(output_tokens)
print('input len:', len(input_tokens))
print('output len', len(output_tokens))

#### Cross-attention

In [None]:
# Cross-attention: average the decoder-to-encoder attention over layers and
# heads, then plot it as an (output token x input token) heatmap.
attention_raw = None  # TODO: presumably output_with_attention.cross_attentions (the output self-attention cell reads .decoder_attentions from the same object) - confirm
attention = None # TODO: build the (layer, head, output_token, input_token) tensor from attention_raw

# attentions: (output_token, layer, sentence, head, 1, input_token)
# i.e. attention_raw is presumably nested as [output_token][layer], each entry
# a tensor of shape (sentence, head, 1, input_token) - confirm.

# ...

# attentions: (layer, head, output_token, input_token)
# NOTE: while the TODO placeholders above are left as None, this line raises
# AttributeError.
print(attention.shape)

# mean attention over layers and heads
# Averaging dims 0 (layer) and 1 (head) leaves (output_token, input_token).
mean_attention = torch.mean(attention, dim=[0, 1])
print(mean_attention.shape)

# Rescale so the strongest weight is 1, matching draw()'s fixed vmin=0.0 / vmax=1.0.
normalized_mean_attention = mean_attention / torch.max(mean_attention)

# plot the heatmap of the mean attention
# Columns are labelled with the input tokens, rows with the output tokens.
_, ax = plt.subplots(1, 1, figsize=(15, 15))
draw(
    normalized_mean_attention.detach().cpu().numpy(),
    input_tokens,
    output_tokens,
    ax=ax,
    )
# NOTE(review): savefig raises if the "ressources/" directory does not exist.
plt.savefig("ressources/cross_attention.png", facecolor='white')
# The trailing ';' keeps the notebook from echoing the call's return value.
plt.show();

#### Input self-attention

#### Output self-attention

In [None]:
# Output self-attention: for each generation step, average the decoder
# self-attention over layers and heads and store it as one row of an (n, n)
# lower-triangular matrix (row i = step i, columns = attended positions 0..i).
attention_raw = output_with_attention.decoder_attentions
n = len(attention_raw)

attention = torch.zeros((n, n), device=device)
for i, step_attention in enumerate(attention_raw):
    # Stack the per-layer tensors, keep batch item 0 and the single query row
    # of this step -> presumably (layer, head, i + 1).
    # NOTE(review): assumes generation used the key/value cache, so step i has
    # exactly one query position - confirm.
    att = torch.stack(step_attention)[:, 0, :, 0, :]
    mean_att = torch.mean(att, dim=[0, 1])
    # Step i covers positions 0..i. A plain slice fills those columns without
    # building a CPU index tensor for a matrix that may live on the GPU.
    attention[i, :i + 1] = mean_att

# Rescale so the strongest weight is 1, matching draw()'s fixed [0, 1] colour range.
normalized_mean_attention = attention / torch.max(attention)
