# How to save the t5 Encoder Decoder separately 

In [1]:
from transformers import T5Tokenizer, T5Model, T5Config
import torch

In [2]:
model = T5Model.from_pretrained('t5-small')

Some weights of T5Model were not initialized from the model checkpoint at t5-small and are newly initialized: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
model

T5Model(
  (shared): Embedding(32128, 512)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 512)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 8)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseReluDense(
              (wi): Linear(in_features=512, out_features=2048, bias=False)
              (wo): Linear(in_features=2048, out_features=512, bias=False)
              (dropout): Dropout(p=0.1, inplace

In [4]:
PATH = "my-model/"

## Save Encoder

In [5]:
torch.save(model.encoder.state_dict(), PATH + "encoder")

In [6]:
model.encoder.load_state_dict(torch.load(PATH + "encoder"))

<All keys matched successfully>

## Save Decoder

In [7]:
torch.save(model.decoder.state_dict(), PATH + "decoder")

In [8]:
model.decoder.load_state_dict(torch.load(PATH + "decoder"))

<All keys matched successfully>

# T5 Encoder Output

In [1]:
from pathlib import Path
import torch
import re
import time

In [2]:
BATCH_SIZE = 16

SHUFFEL_SIZE = 1024

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

learning_rate = 3e-5

model_size = "t5-small"

In [3]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained(model_size)
model = T5ForConditionalGeneration.from_pretrained(model_size).to(device)

task_specific_params = model.config.task_specific_params
if task_specific_params is not None:
    model.config.update(task_specific_params.get("summarization", {}))
    

optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate, weight_decay=0.0001)

In [4]:
task_specific_params["translation_en_to_de"] = {'early_stopping': True,
  'max_length': 500,
  'num_beams': 4,
  'prefix': 'translate English to German: '}

task_specific_params["translation_de_to_en"] = {'early_stopping': True,
  'max_length': 500,
  'num_beams': 4,
  'prefix': 'translate German to English: '}

In [5]:
def read_files(name):
    article_path = "../data/%s/article" % name
    highlights_path = "../data/%s/highlights" % name
    
    articles = [x.rstrip() for x in open(article_path).readlines()]
    highlights = [x.rstrip() for x in open(highlights_path).readlines()]
    
    assert len(articles) == len(highlights)
    return articles, highlights

In [6]:
model.config.prefix

'summarize: '

In [7]:
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, articles, highlights):
        self.x = articles
        self.y = highlights
        
    def __getitem__(self, index):
        x = tokenizer.encode_plus(model.config.prefix + self.transfrom(self.x[index]), max_length=512, return_tensors="pt", padding='max_length', truncation=True)
        y = tokenizer.encode(self.transfrom(self.y[index]), max_length=150, return_tensors="pt", padding='max_length',  truncation=True)
        return x['input_ids'].view(-1), x['attention_mask'].view(-1), y.view(-1)
    
    @staticmethod
    def transfrom(x):
        x = x.lower()
        x = re.sub("'(.*)'", r"\1", x)
        return x
    
    def __len__(self):
        return len(self.x)

In [8]:
def get_dataset(name):
    articles, highlights = read_files(name)
    return MyDataset(articles, highlights)

In [9]:
train_ds = get_dataset("train")
test_ds = get_dataset("test")
val_ds = get_dataset("val")

In [10]:
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE)
val_loader = torch.utils.data.DataLoader(val_ds, batch_size=BATCH_SIZE)
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=BATCH_SIZE)

In [11]:
pad_token_id = tokenizer.pad_token_id
def step(input_ids, attention_mask, y):
    y_ids = y[:, :-1].contiguous()
    lm_labels = y[:, 1:].clone()
    lm_labels[y[:, 1:] == pad_token_id] = -100
    print(input_ids.device, attention_mask.device, y_ids.device, lm_labels.device)
#     output = model.my_forward(
#         inputs_ids, 
#         attention_mask=attention_mask, 
#         decoder_input_ids=y_ids, 
#         decoder_input_ids_translate=y_ids, 
#         labels=lm_labels)
#     output = model(
#         inputs_ids, 
#         attention_mask=attention_mask, 
#         decoder_input_ids=y_ids,  
#         labels=lm_labels)
    output = model.encoder(inputs_ids, attention_mask=attention_mask)
    encoder_hiddenstate = output[0]
    print(encoder_hiddenstate)
    print(encoder_hiddenstate.shape)
    return encoder_hiddenstate # output[0] # loss

In [12]:
EPOCHS = 1
log_interval = 200
train_loss = []
val_loss = []
for epoch in range(EPOCHS):
    model.train() 
    start_time = time.time()
    for i, (inputs_ids, attention_mask, y) in enumerate(train_loader):
        inputs_ids = inputs_ids.to(device)
        attention_mask = attention_mask.to(device)
        y = y.to(device)
        
        
        optimizer.zero_grad()
        loss = step(inputs_ids, attention_mask, y)
        
#         train_loss.append(loss.item())
#         loss.backward()
#         torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
#         optimizer.step()
        break

cuda:0 cuda:0 cuda:0 cuda:0
tensor([[[ 3.5751e-03,  1.8350e-01,  0.0000e+00,  ...,  2.1707e-02,
          -1.8738e-01, -9.7750e-02],
         [ 7.4263e-03,  2.1323e-01, -3.2261e-02,  ...,  5.0878e-02,
          -1.6820e-01, -1.3898e-01],
         [ 6.6919e-02,  2.1176e-01, -4.5208e-02,  ...,  1.3854e-02,
          -1.6082e-01, -9.4041e-02],
         ...,
         [ 1.1192e-01,  4.4799e-02,  0.0000e+00,  ..., -2.0247e-01,
           7.3583e-02, -6.3583e-02],
         [ 1.2644e-01, -1.4333e-02,  4.4274e-02,  ..., -1.4826e-01,
          -3.8766e-02, -1.3667e-01],
         [ 1.7332e-01,  2.6289e-02,  1.9515e-01,  ..., -2.4653e-01,
           8.8466e-02,  0.0000e+00]],

        [[-4.5673e-03,  0.0000e+00,  0.0000e+00,  ...,  7.8243e-02,
          -1.8001e-01, -7.1644e-02],
         [-1.5151e-02,  1.9559e-01, -1.3489e-01,  ...,  8.7894e-02,
           0.0000e+00, -1.0954e-01],
         [ 5.3341e-02,  2.1092e-01, -1.5833e-01,  ...,  8.0666e-02,
          -1.3503e-01, -1.0059e-01],
         ..

## Now we will try to replace the original Forward Method 