<a href="https://colab.research.google.com/github/ostix360/ai-research/blob/main/encoder_to_decoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers peft torch
!pip install -U git+https://github.com/huggingface/datasets

In [None]:
from torch import nn
import torch
from transformers import AutoModel, AutoModelForCausalLM
from transformers.models.bert.modeling_bert import BertModel
from transformers.models.gpt2.modeling_gpt2 import GPT2Model

class EncDec(nn.Module):
    """Feed a pretrained encoder's hidden states into a pretrained causal LM.

    The encoder's last hidden states are projected by a linear ``adapter`` to
    the decoder's hidden size and then handed to the decoder as its
    ``inputs_embeds``: they replace the decoder's token embeddings as its
    input sequence. No cross-attention is involved.

    Args:
        enc_model: Hub id or local path of the encoder checkpoint (loaded
            with ``AutoModel``).
        dec_model: Hub id or local path of the decoder checkpoint (loaded
            with ``AutoModelForCausalLM``).
    """

    def __init__(self, enc_model: str, dec_model: str) -> None:
        super().__init__()
        self.encoder: BertModel = AutoModel.from_pretrained(enc_model)
        # AutoModelForCausalLM returns an LM-head model (GPT2LMHeadModel for
        # "gpt2"), not a bare GPT2Model, so it can compute a loss from labels.
        self.decoder = AutoModelForCausalLM.from_pretrained(dec_model)
        self.adapter = nn.Linear(
            self.encoder.config.hidden_size, self.decoder.config.hidden_size
        )
        # The decoder's token embeddings are intentionally kept. Passing
        # inputs_embeds already bypasses them, and for GPT-2 the LM head is
        # weight-tied to them by default. On an LM-head model they live at
        # `decoder.transformer.wte`, so assigning `decoder.wte = None` would
        # only create an unrelated attribute and remove nothing.

    def forward(
        self,
        input_ids,
        attention_mask=None,
        labels=None,
        token_type_ids=None,
    ):
        """Run encoder -> adapter -> decoder.

        ``labels`` must appear in this signature: ``Trainer`` drops dataset
        columns that are not named in ``forward``, and without labels the
        decoder returns no ``loss``.

        Args:
            input_ids: Token ids for the encoder.
            attention_mask: Optional padding mask. It is reused for the
                decoder, whose input sequence has the same positions as the
                encoder output.
            labels: Optional target ids forwarded to the decoder so that it
                also returns a ``loss``.
            token_type_ids: Optional segment ids, forwarded to the encoder
                only when given (not every encoder accepts them).

        Returns:
            The decoder's output object (logits, plus ``loss`` when
            ``labels`` is given).
        """
        encoder_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
        if token_type_ids is not None:
            encoder_kwargs["token_type_ids"] = token_type_ids
        encoder_outputs = self.encoder(**encoder_kwargs)

        # Project encoder states to the decoder's hidden size.
        encoder_hidden_states = self.adapter(encoder_outputs.last_hidden_state)

        # NOTE(review): in this notebook the labels are produced by the
        # encoder's tokenizer but scored against the decoder's vocabulary.
        # Also, a bidirectional encoder's state at position t has already
        # seen token t+1, so next-token prediction may leak the target.
        # Confirm this setup is intended.
        return self.decoder(
            inputs_embeds=encoder_hidden_states,
            attention_mask=attention_mask,
            labels=labels,
        )

# Hugging Face Hub checkpoint ids for the encoder and the decoder halves.
enc_model, dec_model = "bert-base-cased", "gpt2"
model = EncDec(enc_model=enc_model, dec_model=dec_model)

In [None]:
import datasets

# Recent `datasets` releases (installed from main above) no longer execute
# dataset loading scripts, which the legacy "wikipedia" dataset relies on.
# "wikimedia/wikipedia" is the Parquet-backed version with the same
# id/url/title/text columns.
dataset = datasets.load_dataset(
    "wikimedia/wikipedia", "20231101.simple", split="train[:10]"
)

In [None]:
from transformers import (
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)


tokenizer = AutoTokenizer.from_pretrained(enc_model)

# Truncate to the longest sequence BOTH halves can embed. bert-base-cased has
# only 512 position embeddings, so a cutoff of 1024 made the encoder fail on
# long articles.
cutoff_len = min(
    model.encoder.config.max_position_embeddings,
    model.decoder.config.max_position_embeddings,
)


def tokenize_func(data):
    """Tokenize a batch of articles, using the same text as input and labels."""
    return tokenizer(
        data["text"],
        text_target=data["text"],
        truncation=True,
        max_length=cutoff_len,
    )


tokenized_datasets = dataset.map(tokenize_func, batched=True, batch_size=1)

# The Trainer's default collator pads only the tokenizer's input keys and
# cannot batch ragged "labels". DataCollatorForSeq2Seq also pads labels,
# using -100 so padded positions are ignored by the loss.
data_collator = DataCollatorForSeq2Seq(tokenizer)

training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # Evaluation is left off: there is no eval_dataset, and EncDec has no
    # generate() for predict_with_generate.
    logging_steps=2,
    save_steps=2,
    warmup_steps=2,
    save_total_limit=1,
    # EncDec is a plain nn.Module whose decoder shares lm_head/embedding
    # weights; safetensors refuses to serialize tensors that share memory.
    save_safetensors=False,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

def freeze_params(model):
    """Switch off gradient tracking for every parameter of *model*, in place."""
    for weight in model.parameters():
        weight.requires_grad_(False)

# Freeze both pretrained halves; only parameters outside model.encoder and
# model.decoder (i.e. the adapter) keep requires_grad=True.
for submodule in (model.encoder, model.decoder):
    freeze_params(submodule)

# Count the parameters that still require gradients.
nb_trainable_params = sum(
    weight.numel() for weight in model.parameters() if weight.requires_grad
)
nb_trainable_params
