<a href="https://colab.research.google.com/github/ostix360/ai-research/blob/main/encoder_to_decoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers peft torch
!pip install -U git+https://github.com/huggingface/datasets

In [None]:
from torch import nn
import torch
from transformers import AutoModel, AutoModelForCausalLM, GPT2LMHeadModel
from transformers.models.bert.modeling_bert import BertModel
from transformers.models.gpt2.modeling_gpt2 import GPT2Model

class EncDec(nn.Module):
    def __init__(self, enc_model: str, dec_model: str) -> None:
        super().__init__()
        self.encoder: BertModel = AutoModel.from_pretrained(enc_model)
        self.decoder: GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained(dec_model, add_cross_attention=True)
        self.adapter = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)

    def forward(self, input_ids, attention_mask, labels=None, enc_input_ids=None):
        # Pass input through encoder
        encoder_outputs = self.encoder(input_ids=enc_input_ids, attention_mask=attention_mask)

        # Adapter brings the encoder outputs to the correct dimension for the decoder
        encoder_hidden_states = self.adapter(encoder_outputs.last_hidden_state)

        # Pass adapter outputs and decoder_input_ids to the decoder
        # In this case, "encoder_hidden_states" will be used as cross-attention "encoder_attention_mask"
        # You have to manage them according to your use-case
        # check len label and input_ids
        if labels is not None:
            if len(labels) != len(input_ids):
                print(len(labels), len(input_ids))
                raise ValueError("Input_ids and labels should have the same length")

        decoder_outputs = self.decoder(input_ids=input_ids, encoder_hidden_states=encoder_hidden_states,
                                       labels=labels)
        return decoder_outputs
        
    
    def _get_name(self):
        return f"{self.decoder._get_name()}"

enc_model = "bert-base-uncased"
dec_model = "gpt2"
model = EncDec(enc_model, dec_model)

In [11]:
import datasets
from transformers import AutoTokenizer
t_dataset = datasets.load_dataset("wikipedia", "20220301.simple", split="train[:1000]")
e_dataset = datasets.load_dataset("wikipedia", "20220301.simple", split="train[-100:]")

tokenizer = AutoTokenizer.from_pretrained(enc_model)
enc_tokenizer = AutoTokenizer.from_pretrained(enc_model)
dec_tokenizer = AutoTokenizer.from_pretrained(dec_model)
dec_tokenizer.pad_token = dec_tokenizer.eos_token

cutoff_len = 512

tokenizer(t_dataset["text"][0], text_target=t_dataset["text"][0], truncation=True, max_length=cutoff_len)
enc_input_ids = enc_tokenizer("hello", truncation=True, max_length=cutoff_len)["input_ids"]
enc_input_ids = torch.tensor(enc_input_ids)
dec_input_ids = dec_tokenizer("hello", truncation=True, max_length=cutoff_len)["input_ids"]
# dec_input_ids = [dec_tokenizer.pad_token] * (cutoff_len - len(dec_input_ids)) + dec_input_ids
labels = torch.tensor([1] * len(dec_input_ids))
dec_input_ids = torch.tensor(dec_input_ids)
model.decoder.train()
model.decoder(input_ids=dec_input_ids, labels=labels)

CausalLMOutputWithCrossAttentions(loss=tensor(nan, grad_fn=<NllLossBackward0>), logits=tensor([[-76.3027, -75.9184, -80.7080,  ..., -86.5434, -84.0965, -77.6158]],
       grad_fn=<MmBackward0>), past_key_values=((tensor([[[[-4.4836e-01,  1.9620e+00,  1.8307e-01, -1.5441e-01,  1.5597e+00,
           -3.0622e-02, -3.1441e-01,  1.0669e-01, -1.3469e+00,  1.0243e+00,
            5.8403e-01,  7.5574e-01,  3.2267e-01,  8.0160e-01, -1.0708e-01,
           -7.8357e-01, -4.3006e-01,  1.3127e+00,  2.0068e+00, -2.3264e-01,
           -5.3094e-01,  3.5875e-01,  7.4015e-03, -1.0099e+00, -3.1829e-01,
           -9.9604e-01, -1.9868e-01, -4.5447e-01, -3.8226e-01, -1.3607e+00,
            2.3592e+00, -3.3755e-01, -3.3643e-01,  9.8411e-02, -2.2562e-01,
            1.0410e-01,  9.1688e-01,  1.2518e-01,  4.4715e-02,  1.5624e+00,
            1.9267e-02,  2.8267e-02, -1.7194e+00, -4.0587e-01, -4.9830e-02,
           -3.5938e-01,  1.6348e+00, -6.2045e-01, -8.1223e-01, -4.6455e-01,
           -4.9509e-01,  6.

In [None]:
from peft import get_peft_model
from transformers import AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, TrainingArguments, Trainer, \
    DataCollatorForLanguageModeling

enc_tokenizer = AutoTokenizer.from_pretrained(enc_model)
dec_tokenizer = AutoTokenizer.from_pretrained(dec_model)
dec_tokenizer.pad_token = dec_tokenizer.eos_token

cutoff_len = 512

def tokenize(prompt):
    enc_input_ids = enc_tokenizer(prompt, truncation=True, max_length=cutoff_len)["input_ids"]
    enc_input_ids = torch.tensor(enc_input_ids)
    dec_input_ids = dec_tokenizer(prompt, truncation=True, max_length=cutoff_len)["input_ids"]
    # TODO Check if this is needed
    # dec_input_ids = [dec_tokenizer.pad_token] * (cutoff_len - len(dec_input_ids)) + dec_input_ids
    labels = [1] * len(dec_input_ids)
    dec_input_ids = torch.tensor(dec_input_ids)
    return {
        "input_ids": dec_input_ids,
        "labels": labels,
        "attention_mask": enc_input_ids.ne(enc_tokenizer.pad_token_id),
        "enc_input_ids": enc_input_ids,
    }

def tokenize_func(data):
  return tokenize(data["text"])

tokenized_datasets = t_dataset.map(tokenize_func, remove_columns=["text", "title", "id", "url"])
e_tokenized_datasets = e_dataset.map(tokenize_func, remove_columns=["text", "title", "id", "url"])


print(f"The column names are: {list(tokenized_datasets.features.keys())}")

training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=1,  # Batch problems with more than 1
    per_device_eval_batch_size=1,
    evaluation_strategy="epoch",
    logging_strategy="steps",
    logging_steps=50,
    num_train_epochs=1,
    save_steps=5000,
    eval_steps=2000,
    warmup_steps=2,
    learning_rate=2e-5,
    save_total_limit=1,
    remove_unused_columns=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets,
    eval_dataset=e_tokenized_datasets,
    tokenizer=tokenizer,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

def freeze_params(model):
    for param in model.parameters():
        param.requires_grad = False
    # unfreeze attention layers
    for n, p in model.named_parameters():
        if "crossattention" in n or "c_proj" in n or "c-attn" in n:
            p.requires_grad = True

freeze_params(model.encoder)
freeze_params(model.decoder)
nb_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of trainable parameters: {nb_trainable_params}")
# trainer.train()


# Results

| Training Type               | Number of training parameters | Batch Size | Steps eq | Validation Loss |
|-----------------------------|-------------------------------|------------|----------|-----------------|
| Full decoder training       | 153 397 248                   | 5          | 20 000   | 6.08            |
| Full model enc dec training | 261 707 520                   | 5          | 20 000   | 6.00            |
| GPT2 model training         | 124 439 808                   | 5          | 20 000   | 2.94            |

where steps eq is the number of steps times the batch size to get the number of examples seen by the model.
By adding the encoder to the decoder, we get a 2x increase in the number of parameters, but we also get a 4% decrease in validation loss when training the hole model.
And by training the decoder with encoder (freeze) we get a 3% decrease in validation loss compared to training the decoder alone.


In [None]:
from transformers import AutoTokenizer

bert_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2")

In [None]:
bert_tokenizer("Hello, my dog is cute", truncation=True, max_length=512)

In [None]:
text = """
A or a is the first letter of the English alphabet. The small letter, a or α, is used as a lower case vowel. 

When it is spoken,  ā is said as a long a, a diphthong of ĕ and y. A is similar to  alpha of the Greek alphabet. That is not surprising, because it stands for the same sound. 

"Alpha and omega" (the last letter of the Greek alphabet) means from beginning to the end. In musical notation, the letter A is the symbol of a note in the scale, below B and above G. In binary numbers, the letter A is 01000001.
"""

In [None]:
gpt2_tokenizer(text, truncation=True, max_length=512)

# Custom Trainer

In [25]:
from datasets import load_dataset
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification, BertTokenizerFast, DataCollatorForLanguageModeling

import datasets
from transformers import AutoTokenizer
t_dataset = datasets.load_dataset("wikipedia", "20220301.simple", split="train[:2]")
e_dataset = datasets.load_dataset("wikipedia", "20220301.simple", split="train[-2:]")

tokenizer = AutoTokenizer.from_pretrained(enc_model)
enc_tokenizer = AutoTokenizer.from_pretrained(enc_model)
dec_tokenizer = AutoTokenizer.from_pretrained(dec_model)
dec_tokenizer.pad_token = dec_tokenizer.eos_token

def datasets_post_process(tokenized_dataset):
    tokenized_dataset.set_format(type="torch")
    # Debug
    print(tokenized_dataset.column_names)
    return tokenized_dataset

def tokenize(prompt):
    enc_input_ids = enc_tokenizer(prompt, truncation=True, max_length=cutoff_len)["input_ids"]
    enc_input_ids = torch.tensor(enc_input_ids)
    dec_input_ids = dec_tokenizer(prompt, truncation=True, max_length=cutoff_len)["input_ids"]
    # dec_input_ids = [dec_tokenizer.pad_token] * (cutoff_len - len(dec_input_ids)) + dec_input_ids
    labels = [1] * len(dec_input_ids)
    dec_input_ids = torch.tensor(dec_input_ids, dtype=torch.long)
    return {
        "input_ids": dec_input_ids,
        "labels": labels,
        "attention_mask": enc_input_ids.ne(enc_tokenizer.pad_token_id),
        "enc_input_ids": enc_input_ids,
    }

def tokenize_func(data):
  return tokenize(data["text"])

def debug_data_processing(train_dataloader):
    batch = None
    for batch in train_dataloader:
        break
    print({k: v.shape for k, v in batch.items()})
    return batch


t_tokenized_datasets = t_dataset.map(tokenize_func, remove_columns=["text", "title", "id", "url"])
e_tokenized_datasets = e_dataset.map(tokenize_func, remove_columns=["text", "title", "id", "url"])

t_final_tokenized_datasets = datasets_post_process(t_tokenized_datasets)
e_final_tokenized_datasets = datasets_post_process(e_tokenized_datasets)

data_collator = DataCollatorForLanguageModeling(tokenizer=dec_tokenizer, mlm=False)

Map:   0%|          | 0/2 [00:00<?, ? examples/s]

Map:   0%|          | 0/2 [00:00<?, ? examples/s]

['input_ids', 'labels', 'attention_mask', 'enc_input_ids']
['input_ids', 'labels', 'attention_mask', 'enc_input_ids']


In [26]:
t_final_tokenized_datasets[0]

{'input_ids': tensor([16784,   318,   262,  5544,  1227,   286,   262,   614,   287,   262,
         18322,   290,  8547, 22618, 50215,    11,   290,  2058,  1022,  2805,
           290,  1737,    13,   632,   318,   530,   286,  1440,  1933,   284,
           423,  1542,  1528,    13,   198,   198, 16784,  1464,  6140,   319,
           262,   976,  1110,   286,  1285,   355,  2901,    11,   290, 36527,
            11,  3269,   287, 16470,   812,    13,  3035,  1464,  5645,   319,
           262,   976,  1110,   286,   262,  1285,   355,  3426,    13,   198,
           198, 16784,   338, 12734,   389,   262, 15335,  2631,    64,   290,
         40355,    13,  6363,  4082,  6440,   318,   262, 15291,    13,   383,
          3616,   286,   262, 15291,   318, 24211,    13,   198,   198,   464,
         16061,   220,   198,   198, 16784,  2058,  1022,  2805,   290,  1737,
            11,  1642,   340,   262,  5544,  1227,   286,   262,   614,    13,
           632,   635,  2058,   717,   

In [27]:
train_dataloader = DataLoader(
    t_final_tokenized_datasets, batch_size=1, shuffle=True, collate_fn=data_collator
)
eval_dataloader = DataLoader(
    e_final_tokenized_datasets, batch_size=1, collate_fn=data_collator
)
debug_data_processing(train_dataloader)

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'input_ids': torch.Size([1, 512]), 'labels': torch.Size([1, 512]), 'attention_mask': torch.Size([1, 512]), 'enc_input_ids': torch.Size([1, 512])}


{'input_ids': tensor([[17908,   357, 12512,  2014,   318,   262, 16974,  1227,   286,   262,
           614,   287,   262,  8547, 22618, 11845,    11,  2406,  1022,  2901,
           290,  2693,    13,   632,   468,  3261,  1528,    13,   632,   318,
          3706,   706,   262,  7993, 23129, 48339, 24088,    13,   198,   198,
         17908,   857,   407,  2221,   319,   262,   976,  1110,   286,   262,
          1285,   355,   597,   584,  1227,   287,  2219,   812,    11,   475,
          6140,   319,   262,   976,  1110,   286,   262,  1285,   355,  3945,
           287, 16470,   812,    13,  2932,  1464,  5645,   319,   262,   976,
          1110,   286,   262,  1285,   355,  3389,    13,   198,   198,   464,
         16061,   220,   198,   198,  1212,  1227,   373,   717,  1444,  1001,
           742,   346,   271,   287,  9133,    11,   780,   340,   373,   262,
         11695,  1227,   287,   262,  1468,  7993, 11845,    13,   383,  7993,
         11845,  2540,   287,  2805,  