In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import T5ForConditionalGeneration, T5Tokenizer
from datasets import load_dataset, Dataset

In [2]:
# Load the pretrained FLAN-T5 (large) checkpoint and its matching tokenizer.
# Both are resolved by name through `from_pretrained`; the resulting
# `t5_model` / `t5_tokenizer` globals are used by every later cell.
model_name = "google/flan-t5-large"
t5_model = T5ForConditionalGeneration.from_pretrained(model_name)
t5_tokenizer = T5Tokenizer.from_pretrained(model_name)

config.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.13G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [3]:
# Fine-tuning prompts. Every record carries a single "input" field made of the
# instruction prefix followed by one informal sentence; no target text is stored.
_instruction_prefix = "Rewrite in a formal tone: "
_informal_sentences = [
    "Gotta get this done ASAP!",
    "U coming to the party?",
    "This place is lit!",
    "Can’t believe this dude won!",
    "Lemme know when u free.",
    "That’s my bad.",
    "Yo, what’s up?",
    "Gimme a sec.",
    "I ain’t got time for this.",
    "U gotta check this out!",
]
custom_dataset = [
    {"input": _instruction_prefix + sentence} for sentence in _informal_sentences
]

# Wrap the list of records in a `datasets.Dataset` so it can be iterated row by row.
dataset = Dataset.from_list(custom_dataset)

In [4]:
def test_model(input_text):
    formatted_input = f"Rewrite in a formal tone: {input_text}"
    inputs = t5_tokenizer(formatted_input, return_tensors="pt", padding=True, truncation=True)
    outputs = t5_model.generate(**inputs, max_length=50)
    return t5_tokenizer.decode(outputs[0], skip_special_tokens=True)

# Informal sentences passed to `test_model` (which adds the instruction prefix
# itself) to compare the model's output before and after fine-tuning.
test_sentences = [
    "Wanna grab a coffee?",
    "Hey dude, what’s up?",
    "No way that actually happened!",
    "Gimme a min to check."
]

## Before training

In [5]:
# Baseline: show what the pretrained model produces for each test sentence.
for sent in test_sentences:
    print(f"Informal: {sent}")
    formal_text = test_model(sent)
    print(f"Formal: {formal_text}\n---------------------")

Informal: Wanna grab a coffee?
Formal: Wanna grab a coffee?
---------------------
Informal: Hey dude, what’s up?
Formal: Hey dude, what’s up?
---------------------
Informal: No way that actually happened!
Formal: No way that happened!
---------------------
Informal: Gimme a min to check.
Formal: I'll give you a minute to check.
---------------------


In [6]:
# REINFORCE-style fine-tuning loop.
#
# For every prompt the model generates a rewrite, the rewrite receives a scalar
# reward, and the loss is the negative mean token log-probability of that
# rewrite multiplied by the reward. The model and tokenizer are saved to the
# "fine_tuned" directory once all epochs have run.
optimizer = optim.Adam(t5_model.parameters(), lr=5e-6)
num_epochs = 4
steps_per_example = 100  # optimisation steps taken on each prompt per epoch
log_every = 20           # print a sample generation every N steps

for epoch in range(num_epochs):
    total_loss = 0
    for data in dataset:
        # The prompt and its tokenisation do not change between steps, so they
        # are prepared once per example instead of once per step.
        input_text = data["input"]
        inputs = t5_tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)

        for i in range(steps_per_example):
            outputs = t5_model.generate(**inputs, max_length=50)
            generated_text = t5_tokenizer.decode(outputs[0], skip_special_tokens=True)

            if ((i+1) % log_every) == 0:
                print(f"Informal Text: {input_text}\nGenerated Formal Text: {generated_text}")

            # Placeholder reward. The intended source was an interactive rating:
            # reward_score = float(input("Rate the output (1-5): "))
            # NOTE(review): with a constant reward of 1 the objective only raises
            # the likelihood of whatever the model already generates, so it
            # cannot steer the output towards a more formal tone.
            reward_score = 1

            # Score the token ids that were actually generated instead of a
            # re-tokenisation of the decoded text. For encoder-decoder models
            # `generate` returns sequences that begin with the decoder start
            # token, so dropping the last id gives correctly shifted decoder
            # inputs and dropping the first id gives the targets (including the
            # first generated token and the end-of-sequence token, if emitted).
            decoder_inputs = outputs[:, :-1]
            target_ids = outputs[:, 1:]

            logits = t5_model(**inputs, decoder_input_ids=decoder_inputs).logits
            log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

            # Pick the log-probability assigned to each generated token.
            seq_log_probs = torch.gather(log_probs, 2, target_ids.unsqueeze(2)).squeeze(2)
            log_prob = seq_log_probs.mean()

            loss = -log_prob * reward_score

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

    # The reported value is the sum of per-step losses over the epoch.
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss:.4f}")

t5_model.save_pretrained("fine_tuned")
t5_tokenizer.save_pretrained("fine_tuned")

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Informal Text: Rewrite in a formal tone: Gotta get this done ASAP!
Generated Formal Text: Gotta get this done ASAP!
Informal Text: Rewrite in a formal tone: Gotta get this done ASAP!
Generated Formal Text: Gotta get this done ASAP!
Informal Text: Rewrite in a formal tone: Gotta get this done ASAP!
Generated Formal Text: Gotta get this done ASAP!
Informal Text: Rewrite in a formal tone: Gotta get this done ASAP!
Generated Formal Text: Gotta get this done ASAP!
Informal Text: Rewrite in a formal tone: Gotta get this done ASAP!
Generated Formal Text: Gotta get this done ASAP!
Informal Text: Rewrite in a formal tone: U coming to the party?
Generated Formal Text: U coming to the party?
Informal Text: Rewrite in a formal tone: U coming to the party?
Generated Formal Text: U coming to the party?
Informal Text: Rewrite in a formal tone: U coming to the party?
Generated Formal Text: U coming to the party?
Informal Text: Rewrite in a formal tone: U coming to the party?
Generated Formal Text: U c

('fine_tuned/tokenizer_config.json',
 'fine_tuned/special_tokens_map.json',
 'fine_tuned/spiece.model',
 'fine_tuned/added_tokens.json')

In [7]:
# Reload the model and tokenizer from the "fine_tuned" directory, then run the
# same test sentences again to inspect the outputs after training.
checkpoint_dir = "fine_tuned"
t5_model = T5ForConditionalGeneration.from_pretrained(checkpoint_dir)
t5_tokenizer = T5Tokenizer.from_pretrained(checkpoint_dir)

for sent in test_sentences:
    print(f"Informal: {sent}")
    formal_text = test_model(sent)
    print(f"Formal: {formal_text}\n")

Informal: Wanna grab a coffee?
Formal: Wanna grab a coffee?

Informal: Hey dude, what’s up?
Formal: Hey dude, what’s up?

Informal: No way that actually happened!
Formal: No way that actually happened!

Informal: Gimme a min to check.
Formal: Gimme a min to check.

