In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from transformers import T5Tokenizer, T5ForConditionalGeneration
from datasets import load_dataset, Dataset

In [2]:
MODEL_NAME = "google/flan-t5-large"
t5_tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
t5_model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)

tokenizer_config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


config.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.13G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

In [3]:
custom_dataset = [
    {"input": "Gotta get this done ASAP!"},
    {"input": "U coming to the party?"},
    {"input": "Can't believe this dude won!"},
    {"input": "Lemme know when u free."},
    {"input": "That's my bad."},
    {"input": "Gimme a sec."},
    {"input": "I ain't got time for this."},
    {"input": "BRB, just gotta grab something."},
    {"input": "Yo, what's up?"}, 
    {"input": "Can't even right now"},
    {"input": "I'm dead, this is too funny."},
    {"input": "Let's chill sometime."},
    {"input": "That's crazy, bro!"}, 
    {"input": "What's the move tonight?"}, 
    {"input": "Hit me up when you're free."},
    {"input": "Lowkey, I'm not feeling this."}
]

dataset = Dataset.from_list(custom_dataset)

In [4]:
abb_words = [
    "gimme", "lemme", "dude", "u", "lol", "brb", "idk", "omg", "btw", 
    "wanna", "gonna", "kinda", "cuz", "aint", "yolo", "fomo", "tbh", "smh", 
    "lmao", "rofl", "lit", "srsly", "thx", "pls", "yup", "nah", "xoxo", "wth", "asap",
    "i'm", "i'll", "it's", "you're", "they're", "we're", "isn't", "wasn't", "don't", "doesn't", "can't", "won't",
    "let's", "that's", "what's", "jk", "np", "imo", "fyi", "btw", "smh", "lmfao", "asap", "omw", "ppl", "ttyl", "brb",
    "bff", "g2g", "cya", "tmi", "l8r", "wdym", "cuz", "yass", "idc", "roflmao", "ttyl", "lmk"
]

def get_reward(input_text, generated_text):
    for word in abb_words:
        if word.lower() in generated_text.lower():
            return -0.4
    return 0.4

In [5]:
def test_model(input_text, tokenizer, model):
    prompt = f"""Please rewrite the following text without any abbreviations: {input_text}"""
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True)
    outputs = model.generate(**inputs)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

test_sentences = [
    "Wanna grab a coffee?",
    "Hey dude, what's up?",
    "I don't get what you mean.",
    "Gimme a min to check."
]

## Before train

In [6]:
for sent in test_sentences:
    print(f"Original: {sent}")
    print(f"Normalized: {test_model(sent, t5_tokenizer, t5_model)}\n---------------------")

Original: Wanna grab a coffee?
Normalized: Wanna grab a coffee?
---------------------
Original: Hey dude, what's up?
Normalized: Hey dude, what's up?
---------------------
Original: I don't get what you mean.
Normalized: I don't get what you mean.
---------------------
Original: Gimme a min to check.
Normalized: Give me a minute to check.
---------------------


## Train

In [7]:
optimizer = optim.Adam(t5_model.parameters(), lr=5e-6)
num_epochs = 10
previous_generated_text = ""

for epoch in range(num_epochs):
    total_loss = 0
    for data in dataset:
        for i in range(10):
            input_text = data["input"]
            prompt = f"""Please rewrite the following text without any abbreviations: {input_text}"""

            inputs = t5_tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
            outputs = t5_model.generate(**inputs)
            generated_text = t5_tokenizer.decode(outputs[0], skip_special_tokens=True)

            if generated_text != previous_generated_text:
                print(f"----------\nOriginal Text: {input_text}\nNormalized Text: {generated_text}")
                previous_generated_text = generated_text
                
            reward_score = get_reward(input_text, generated_text)
    
            labels = t5_tokenizer(generated_text, return_tensors="pt", padding=True, truncation=True).input_ids
            decoder_inputs = labels[:, :-1]

            logits = t5_model(**inputs, decoder_input_ids=decoder_inputs).logits
            log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

            seq_log_probs = torch.gather(log_probs, 2, labels[:, 1:].unsqueeze(2)).squeeze(2)
            log_prob = seq_log_probs.mean()

            loss = -log_prob * reward_score
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    
            total_loss += loss.item()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss:.4f}")

t5_tokenizer.save_pretrained("fine_tuned")
t5_model.save_pretrained("fine_tuned")

----------
Original Text: Gotta get this done ASAP!
Normalized Text: Gotta get this done ASAP!


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


----------
Original Text: U coming to the party?
Normalized Text: U coming to the party?
----------
Original Text: U coming to the party?
Normalized Text: You coming to the party?
----------
Original Text: Can't believe this dude won!
Normalized Text: Can't believe this guy won!
----------
Original Text: Can't believe this dude won!
Normalized Text: Can't believe this guy won.
----------
Original Text: Lemme know when u free.
Normalized Text: Let me know when you are free.
----------
Original Text: Lemme know when u free.
Normalized Text: Let's know when you're free.
----------
Original Text: That's my bad.
Normalized Text: That's my fault.
----------
Original Text: Gimme a sec.
Normalized Text: I'll give you a sec.
----------
Original Text: I ain't got time for this.
Normalized Text: I don't have time for this.
----------
Original Text: BRB, just gotta grab something.
Normalized Text: I just gotta grab something.
----------
Original Text: BRB, just gotta grab something.
Normalized Tex

## After train

In [8]:
t5_test_tokenizer = T5Tokenizer.from_pretrained("fine_tuned")
t5_test_model = T5ForConditionalGeneration.from_pretrained("fine_tuned")

for sent in test_sentences:
    print(f"Original: {sent}")
    print(f"Normalized: {test_model(sent, t5_test_tokenizer, t5_test_model)}\n")

Original: Wanna grab a coffee?
Normalized: Wanna grab a coffee?

Original: Hey dude, what's up?
Normalized: Hey dude, what's up?

Original: I don't get what you mean.
Normalized: I don't get what you mean.

Original: Gimme a min to check.
Normalized: Gimme a minute to check.

