<a href="https://colab.research.google.com/github/ostix360/ai-research/blob/main/encoder_to_decoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers datasets apache_beam peft torch

In [2]:
from torch import nn
import torch
from transformers import AutoModel, AutoModelForCausalLM
from transformers.models.bert.modeling_bert import BertModel
from transformers.models.gpt2.modeling_gpt2 import GPT2Model

class EncDec(nn.Module):
    """Pretrained encoder coupled to a pretrained causal decoder via a linear adapter.

    The encoder's last hidden states are projected by ``adapter`` to the
    decoder's hidden size and given to the decoder as cross-attention memory.
    """

    def __init__(self, enc_model: str, dec_model: str) -> None:
        """Load both checkpoints and build the adapter between them.

        Args:
            enc_model: Name or path passed to ``AutoModel.from_pretrained``.
            dec_model: Name or path passed to
                ``AutoModelForCausalLM.from_pretrained``.
        """
        super().__init__()
        self.encoder: BertModel = AutoModel.from_pretrained(enc_model)
        # A causal LM only accepts ``encoder_hidden_states`` when it is built
        # with cross-attention layers, so request them through the config.
        # These layers are not part of the checkpoint and start from random
        # weights, so they need training before the outputs are meaningful.
        self.decoder: nn.Module = AutoModelForCausalLM.from_pretrained(
            dec_model,
            is_decoder=True,
            add_cross_attention=True,
        )
        self.adapter = nn.Linear(
            self.encoder.config.hidden_size,
            self.decoder.config.hidden_size,
        )
        # The decoder keeps its token and position embeddings: forward() feeds
        # it token ids, which it must embed itself. (Assigning ``wpe``/``wte``
        # on the causal-LM wrapper, as done previously, did not remove them;
        # for GPT-2 they live on the wrapped ``transformer`` module.)

    def forward(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask=None,
        labels=None,
    ):
        """Encode the source sequence and run the decoder conditioned on it.

        Args:
            input_ids: Token ids for the encoder.
            attention_mask: Padding mask for ``input_ids``; also reused as the
                decoder's cross-attention mask.
            decoder_input_ids: Token ids for the decoder.
            decoder_attention_mask: Optional padding mask for
                ``decoder_input_ids``.
            labels: Optional targets forwarded to the decoder so that it can
                compute a loss.

        Returns:
            Whatever the decoder returns for this call.
        """
        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

        # The adapter maps encoder features to the decoder's hidden size.
        encoder_hidden_states = self.adapter(encoder_outputs.last_hidden_state)

        # Forward the encoder padding mask as well, otherwise cross-attention
        # would also attend to padded encoder positions.
        return self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=attention_mask,
            labels=labels,
        )

# Checkpoint names for the encoder and decoder halves of the combined model.
enc_model, dec_model = "bert-base-cased", "gpt2"
model = EncDec(enc_model=enc_model, dec_model=dec_model)

Downloading (…)lve/main/config.json: 100%|██████████| 570/570 [00:00<00:00, 2.20MB/s]
Downloading model.safetensors: 100%|██████████| 436M/436M [00:04<00:00, 92.7MB/s]
Downloading (…)lve/main/config.json: 100%|██████████| 665/665 [00:00<00:00, 1.84MB/s]
Downloading model.safetensors: 100%|██████████| 548M/548M [00:06<00:00, 86.2MB/s]


: 

In [None]:
import datasets
# Load only the first ten training rows of the "20220301.simple" Wikipedia config.
dataset = datasets.load_dataset(
    path="wikipedia",
    name="20220301.simple",
    split="train[:10]",
)


In [None]:
def encode(text, add_bos_token):
    """Tokenize ``text`` into a list of ids with at most one leading BOS.

    Uses the module-level ``tokenizer`` and truncates to ``cutoff_len``.

    Args:
        text: The string to tokenize.
        add_bos_token: When false, a leading BOS id is stripped from the result.

    Returns:
        The list of token ids; it may be empty.
    """
    result = tokenizer.encode(text, truncation=True, max_length=cutoff_len)
    bos = tokenizer.bos_token_id

    # Collapse a doubled BOS at the start. The slice comparison already fails
    # for lists shorter than two items, so no separate length check is needed.
    if result[:2] == [bos, bos]:
        result = result[1:]

    # Guard against an empty result before looking at the first id.
    if not add_bos_token and result and result[0] == bos:
        result = result[1:]
    return result

def tokenize(prompt, append_eos_token=False):
    """Tokenize ``prompt`` and left-pad it to ``cutoff_len``.

    Uses the module-level ``tokenizer`` and ``cutoff_len``.

    Args:
        prompt: The text to tokenize.
        append_eos_token: When true, append the tokenizer's EOS id if it is
            defined, not already last, and there is room below ``cutoff_len``.

    Returns:
        A dict with:
            ``input_ids``: long tensor of token ids, left-padded with the pad id.
            ``labels``: list with the same token ids, using ``-100`` at padded
                positions so they are ignored by the loss.
            ``attention_mask``: bool tensor, true at non-padded positions.
    """
    input_ids = encode(prompt, True)

    eos = tokenizer.eos_token_id
    if (
        append_eos_token
        and eos is not None  # some tokenizers define no EOS id
        and len(input_ids) < cutoff_len
        and (not input_ids or input_ids[-1] != eos)
    ):
        input_ids.append(eos)

    pad_len = max(cutoff_len - len(input_ids), 0)

    # Real targets rather than a constant placeholder: nothing in this file
    # replaces the labels later, so they must already hold the token ids.
    labels = [-100] * pad_len + input_ids

    # Build the mask from the padding length, not by comparing against the pad
    # id, so a real token that equals the pad id is never masked out.
    attention_mask = torch.tensor(
        [False] * pad_len + [True] * len(input_ids), dtype=torch.bool
    )
    padded_ids = torch.tensor(
        [tokenizer.pad_token_id] * pad_len + input_ids, dtype=torch.long
    )
    return {
        "input_ids": padded_ids,
        "labels": labels,
        "attention_mask": attention_mask,
    }


In [None]:
from transformers import AutoTokenizer, Seq2SeqTrainer, Seq2SeqTrainingArguments


tokenizer = AutoTokenizer.from_pretrained(enc_model)

# Every sample is padded to cutoff_len, so it must not exceed the maximum
# length the encoder's tokenizer reports; longer inputs would overflow the
# encoder's position embeddings.
cutoff_len = min(1024, tokenizer.model_max_length)

# tokenize() works on a single string, so map example by example over the
# "text" column instead of passing whole column batches. The raw columns are
# dropped so that only the tokenized fields remain.
tokenized_datasets = dataset.map(
    lambda example: tokenize(example["text"]),
    remove_columns=dataset.column_names,
)

training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    # No eval_dataset is given to the trainer below, so periodic evaluation
    # is disabled rather than failing once the first evaluation is due.
    evaluation_strategy="no",
    logging_steps=2,
    save_steps=2,
    eval_steps=2,
    warmup_steps=2,
    save_total_limit=1,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets,
    tokenizer=tokenizer,
)

def freeze_params(model):
    """Turn off gradient tracking for every parameter yielded by ``model.parameters()``."""
    for weight in model.parameters():
        setattr(weight, "requires_grad", False)

# Freeze both pretrained sub-models; parameters outside them stay trainable.
for frozen_part in (model.encoder, model.decoder):
    freeze_params(frozen_part)

# Count the parameter elements that still require gradients.
nb_trainable_params = sum(
    weight.numel() for weight in model.parameters() if weight.requires_grad
)