<a href="https://colab.research.google.com/github/ostix360/ai-research/blob/main/encoder_to_decoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install dependencies with the %pip magic so the packages land in the
# environment of the running kernel (a plain `!pip` runs in a subshell and may
# target a different interpreter).
%pip install transformers peft torch
# `datasets` is installed from the main branch of its GitHub repository.
%pip install -U git+https://github.com/huggingface/datasets

Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable
Collecting git+https://github.com/huggingface/datasets
  Cloning https://github.com/huggingface/datasets to c:\users\grenouillon\appdata\local\temp\pip-req-build-xvi5_q3h
  Resolved https://github.com/huggingface/datasets to commit c65315e4a8308f04fcb025039afe2a2e43b5684e
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Getting requirements to build wheel: started
  Getting requirements to build wheel: finished with status 'done'
  Preparing metadata (pyproject.toml): started
  Preparing metadata (pyproject.toml): finished with status 'done'


  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/datasets 'C:\Users\grenouillon\AppData\Local\Temp\pip-req-build-xvi5_q3h'


In [4]:
from torch import nn
import torch
from transformers import AutoModel, AutoModelForCausalLM, GPT2LMHeadModel
from transformers.models.bert.modeling_bert import BertModel
from transformers.models.gpt2.modeling_gpt2 import GPT2Model

class EncDec(nn.Module):
    """Pretrained encoder and causal-LM decoder joined by a linear adapter.

    The encoder's last hidden states are projected by ``adapter`` to the
    decoder's hidden size and handed to the decoder as
    ``encoder_hidden_states`` for its cross-attention layers (the decoder is
    loaded with ``add_cross_attention=True``).
    """

    def __init__(self, enc_model: str, dec_model: str) -> None:
        """Load both checkpoints and build the adapter.

        Args:
            enc_model: Name or path given to ``AutoModel.from_pretrained``.
            dec_model: Name or path given to
                ``AutoModelForCausalLM.from_pretrained``.
        """
        super().__init__()
        self.encoder: BertModel = AutoModel.from_pretrained(enc_model)
        self.decoder: GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained(dec_model, add_cross_attention=True)
        # Maps encoder hidden size -> decoder hidden size.
        self.adapter = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)

    def forward(self, input_ids, attention_mask, labels=None):
        """Encode ``input_ids`` and decode the same ids with cross-attention.

        Args:
            input_ids: Token ids; passed to both the encoder and the decoder.
            attention_mask: Mask for ``input_ids``; passed to the encoder, to
                the decoder's self-attention, and to the decoder's
                cross-attention (as ``encoder_attention_mask``).
            labels: Optional targets forwarded unchanged to the decoder.

        Returns:
            Whatever the decoder returns for this call.
        """
        # NOTE(review): the same ids feed both models, so they must be valid
        # indices for the encoder and the decoder vocabularies — confirm this
        # is intended when the two checkpoints use different tokenizers.
        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

        # Adapter brings the encoder outputs to the correct dimension for the decoder
        encoder_hidden_states = self.adapter(encoder_outputs.last_hidden_state)

        # The decoder reads the same sequence the encoder saw, so one mask
        # serves both purposes: it hides padding from the decoder's
        # self-attention and from its cross-attention over the encoder states.
        decoder_outputs = self.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=attention_mask,
            labels=labels,
        )
        return decoder_outputs

    def _get_name(self):
        """Report the decoder's module name as this module's name."""
        return self.decoder._get_name()

# Checkpoint names for the encoder and the decoder, then the combined model.
enc_model, dec_model = "bert-base-cased", "gpt2"
model = EncDec(enc_model=enc_model, dec_model=dec_model)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.11.crossattention.c_proj.bias', 'h.6.crossattention.c_attn.bias', 'h.0.cr

In [5]:
import datasets
from transformers import AutoTokenizer
# Training split: first 1000 rows; evaluation split: last 100 rows of the
# "20220301.simple" configuration of the "wikipedia" dataset.
t_dataset = datasets.load_dataset("wikipedia", "20220301.simple", split="train[:1000]")
e_dataset = datasets.load_dataset("wikipedia", "20220301.simple", split="train[-100:]")

# The tokenizer comes from the encoder checkpoint.
tokenizer = AutoTokenizer.from_pretrained(enc_model)

# Maximum number of tokens kept per example.
cutoff_len = 512

# Sanity check on one example. Index the row first and the column second:
# `t_dataset["text"][0]` would load the whole "text" column to read one entry.
tokenizer(t_dataset[0]["text"], text_target=t_dataset[0]["text"], truncation=True, max_length=cutoff_len)

{'input_ids': [101, 1364, 1110, 1103, 2223, 2370, 1104, 1103, 1214, 1107, 1103, 5916, 1105, 18123, 1811, 8729, 1116, 117, 1105, 2502, 1206, 1345, 1105, 1318, 119, 1135, 1110, 1141, 1104, 1300, 1808, 1106, 1138, 1476, 1552, 119, 1364, 1579, 3471, 1113, 1103, 1269, 1285, 1104, 1989, 1112, 1351, 117, 1105, 19148, 117, 1356, 1107, 13660, 1201, 119, 1364, 1579, 3769, 1113, 1103, 1269, 1285, 1104, 1103, 1989, 1112, 1382, 119, 1364, 112, 188, 4637, 1132, 1103, 7643, 153, 4490, 1105, 11291, 119, 2098, 20665, 4793, 1110, 1103, 9883, 119, 1109, 2764, 1104, 1103, 9883, 1110, 15025, 119, 1109, 17545, 1364, 2502, 1206, 1345, 1105, 1318, 117, 1543, 1122, 1103, 2223, 2370, 1104, 1103, 1214, 119, 1135, 1145, 2502, 1148, 1107, 1103, 1214, 1149, 1104, 1103, 1300, 1808, 1115, 1138, 1476, 1552, 117, 1112, 1340, 117, 1347, 1105, 1379, 1132, 1224, 1107, 1103, 1214, 119, 1364, 3471, 1113, 1103, 1269, 1285, 1104, 1103, 1989, 1112, 1351, 1451, 1214, 1105, 1113, 1103, 1269, 1285, 1104, 1103, 1989, 1112, 1356, 1

In [None]:
from peft import get_peft_model
from transformers import AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments, TrainingArguments, Trainer, \
    DataCollatorForLanguageModeling

# Tokenizer loaded from the encoder checkpoint (same assignment as in the
# previous cell).
tokenizer = AutoTokenizer.from_pretrained(enc_model)

# Fixed length, in tokens, that `tokenize` truncates to and pads up to.
cutoff_len = 512

def tokenize(prompt):
    """Tokenize ``prompt`` and left-pad it to exactly ``cutoff_len`` tokens.

    Uses the notebook-level ``tokenizer`` and ``cutoff_len``.

    Args:
        prompt: Text to encode.

    Returns:
        A dict with:
        - "input_ids": 1-D tensor, padding ids first, then the token ids.
        - "labels": list of the same length holding the token ids, with -100
          on the padding positions.
        - "attention_mask": boolean tensor, False wherever ``input_ids``
          equals the pad id.
    """
    token_ids = list(tokenizer(prompt, truncation=True, max_length=cutoff_len)["input_ids"])
    pad_len = cutoff_len - len(token_ids)

    padded_ids = [tokenizer.pad_token_id] * pad_len + token_ids
    # Targets are the tokens themselves; -100 marks padding so the loss skips
    # those positions (the previous constant `1` label carried no signal).
    labels = [-100] * pad_len + token_ids

    input_ids = torch.tensor(padded_ids)
    return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": input_ids.ne(tokenizer.pad_token_id),
        }

def tokenize_func(data):
    """Apply ``tokenize`` to the "text" field of one dataset record."""
    text = data["text"]
    return tokenize(text)

# Tokenize the training split, dropping the raw columns so that only the
# fields produced by `tokenize_func` remain.
tokenized_datasets = t_dataset.map(
    tokenize_func,
    remove_columns=["text", "title", "id", "url"],
)
# Same treatment for the evaluation split.
e_tokenized_datasets = e_dataset.map(
    tokenize_func,
    remove_columns=["text", "title", "id", "url"],
)

print(f"The column names are: {list(tokenized_datasets.features.keys())}")

# Hyper-parameters for a single short training epoch; checkpoints and logs go
# under ./results.
training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=2,
    evaluation_strategy="epoch",
    logging_strategy="steps",
    logging_steps=50,
    num_train_epochs=1,
    save_steps=5000,
    # NOTE(review): presumably unused while evaluation_strategy is "epoch" — confirm.
    eval_steps=2000,
    warmup_steps=2,
    learning_rate=2e-5,
    save_total_limit=1,
    # NOTE(review): presumably needed because `model` is a plain nn.Module, so
    # the Trainer should not drop dataset columns on its own — confirm.
    remove_unused_columns=False,
)

# Wire the EncDec model, the arguments above and the tokenized train/eval
# splits into a Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets,
    eval_dataset=e_tokenized_datasets,
    tokenizer=tokenizer,
    # mlm=False selects the non-masked (causal) language-modeling collation.
    # NOTE(review): this collator appears to build its own "labels" from
    # "input_ids" — verify against the transformers documentation.
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

def freeze_params(model):
    """Freeze every parameter of ``model`` except the cross-attention ones.

    A parameter is left trainable only when its qualified name contains
    "crossattention"; every other parameter gets ``requires_grad = False``.
    The module is modified in place and nothing is returned.
    """
    for name, param in model.named_parameters():
        # Single pass: the flag is True exactly for cross-attention weights.
        param.requires_grad = "crossattention" in name

# Freeze the encoder and the decoder except for parameters whose name contains
# "crossattention". `model.adapter` is not passed to freeze_params, so its
# requires_grad flags are left as they are.
freeze_params(model.encoder)
freeze_params(model.decoder)
# Count the elements of all parameters that still require gradients.
nb_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of trainable parameters: {nb_trainable_params}")
# Start training with the Trainer configured above.
trainer.train()

The column names are: ['input_ids', 'labels', 'attention_mask']
Number of trainable parameters: 28939008


You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss
