# mT5 sandbox

In [1]:
from transformers import MT5ForConditionalGeneration, MT5TokenizerFast
from transformers import DataCollatorForSeq2Seq
import torch
import torch.optim as optim

In [81]:
# Load pretrained mT5-small plus its fast tokenizer, and set up the optimizer.
MODEL_NAME = "google/mt5-small"
model = MT5ForConditionalGeneration.from_pretrained(MODEL_NAME)
tokenizer = MT5TokenizerFast.from_pretrained(MODEL_NAME)
# Use the GPU when present; fall back to CPU instead of crashing without one.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# Created after the device move, as the torch.optim docs recommend.
optimizer = optim.AdamW(model.parameters(), lr=1e-5)



In [3]:
# Seq2seq collator: pads every batch to its longest sequence, and (because the
# model is supplied) derives decoder_input_ids from the labels.
collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding="longest",
)

In [4]:
# Toy data for an autoencoding setup: each sentence is the encoder input and
# also its own target sequence.
text = ["那是一隻狗", "鳳凰台上鳳凰遊"]
batch = []
for txt in text:
    # `text_target=` tokenizes the labels in target mode within the same call;
    # it supersedes the deprecated `with tokenizer.as_target_tokenizer():` form.
    # No return_tensors: DataCollatorForSeq2Seq pads plain lists of ids and
    # converts them to tensors itself.
    enc = tokenizer(txt, text_target=txt)
    batch.append({
        "input_ids": enc["input_ids"],
        "attention_mask": enc["attention_mask"],
        "labels": enc["labels"],
    })

['▁', '這是', '一個', '故事', '</s>']

In [5]:
# Pad the examples into a single batch; the collator also adds decoder_input_ids
# built from the labels. Move the batch to whatever device the model is on
# rather than hard-coding "cuda".
batch = collator(batch).to(model.device)

{'input_ids': tensor([[   259, 222077,  71285,  66204,      1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])}

In [None]:
# Repeatedly fit one fixed batch while the decoder may only cross-attend to a
# single encoder position (a one-vector bottleneck), then predict the labels.
# NOTE: `from_pretrained` returns the model in eval mode and this loop does not
# call `model.train()`, so dropout stays disabled during these updates. Call
# `model.train()` first if ordinary dropout training is wanted.
N_STEPS = 1000
LOG_EVERY = 100
# Encoder position kept as the bottleneck. The reason for choosing position 2
# is not stated here; TODO confirm it is the intended token for every example.
BOTTLENECK_POS = 2

for idx in range(N_STEPS):
    optimizer.zero_grad()
    encode_out = model.encoder(input_ids=batch["input_ids"],
                               attention_mask=batch["attention_mask"],
                               output_hidden_states=True)
    # Alternative pooling: average the last four hidden-state layers per example.
    # Use torch.stack, not torch.vstack: vstack concatenates along the batch
    # axis, so averaging its result would mix different examples together.
    # encode_out.last_hidden_state = torch.stack(encode_out.hidden_states[-4:]).mean(dim=0)
    # Keep a single encoder position -> shape (batch, 1, hidden).
    encode_out.last_hidden_state = (
        encode_out.last_hidden_state[:, BOTTLENECK_POS, :].unsqueeze(1)
    )
    # Because encoder_outputs is supplied, the model skips its own encoder, so
    # an `inputs_embeds=` argument would be ignored and is not passed. The
    # encoder attention mask is omitted as well: it no longer matches the
    # length-1 encoder sequence.
    decode_out = model(decoder_input_ids=batch["decoder_input_ids"],
                       labels=batch["labels"],
                       encoder_outputs=encode_out)
    if idx % LOG_EVERY == 0:
        print(idx, decode_out.loss.item())
        print(tokenizer.batch_decode(decode_out.logits.argmax(2)))
    decode_out.loss.backward()
    optimizer.step()