In [None]:
import evaluate
import numpy as np
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    Trainer,
    TrainingArguments,
    pipeline,
)

In [2]:
# Environment sanity check: the 4-bit (bitsandbytes) load and bf16 settings below expect a CUDA GPU.
print(f"torch {torch.__version__} | CUDA available: {torch.cuda.is_available()}")


In [None]:
model_name = "google/flan-t5-xl"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# 4-bit NF4 quantization with nested (double) quantization; compute runs in bf16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Pass the config object itself: with only a bare `load_in_4bit=True` kwarg,
# none of the NF4 / double-quant / bf16 settings in `bnb_config` are applied.
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
)
# PEFT's recommended preparation step before attaching adapters to a k-bit model.
model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=16,  # rank of the LoRA update matrices
    lora_alpha=32,  # LoRA scaling numerator (update scaled by lora_alpha / r)
    target_modules=["q", "v"],  # T5 attention query/value projection layers
    lora_dropout=0.05,
    bias="none",
    task_type="SEQ_2_SEQ_LM",
)

model = get_peft_model(model, lora_config)
# Confirms that only the small adapter fraction of parameters is trainable.
model.print_trainable_parameters()

In [None]:
# Load every row of the published "train" split, then hold out a reproducible
# 10% slice (fixed seed) as the evaluation set.
raw_dataset = load_dataset("Hritshhh/T5-Dataset", split="train[:100%]").train_test_split(
    test_size=0.1,
    seed=42,
)
train_dataset, eval_dataset = raw_dataset["train"], raw_dataset["test"]

In [None]:
def preprocess_function(
    examples,
    max_source_length=128,
    max_target_length=128,
    prefix="grammar correction: ",
):
    """Tokenize a batch of (sentence, correction) pairs for seq2seq training.

    Args:
        examples: Batched mapping as passed by ``Dataset.map(batched=True)``;
            reads the ``"sentence"`` (source) and ``"corrections"`` (target) columns.
        max_source_length: Truncation length for the prefixed source sentences.
        max_target_length: Truncation length for the target corrections.
        prefix: Task prefix prepended to every source sentence. Inference prompts
            must use the same prefix.

    Returns:
        The tokenizer's encoding of the sources, with an added ``"labels"`` entry
        holding the target token ids.
    """
    inputs = [prefix + sentence for sentence in examples["sentence"]]
    # NOTE(review): assumes each "corrections" entry is one corrected string,
    # not a list of alternative corrections — confirm against the dataset schema.
    targets = examples["corrections"]
    model_inputs = tokenizer(inputs, max_length=max_source_length, truncation=True)
    # `text_target=` replaces the deprecated `as_target_tokenizer()` context manager.
    labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [None]:
# Drop the raw text columns so each tokenized split holds only the tokenizer
# outputs plus "labels" — string columns cannot be padded by the collator.
tokenized_train_dataset = train_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=train_dataset.column_names,
)
tokenized_eval_dataset = eval_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=eval_dataset.column_names,
)

In [None]:
# Dynamically pads inputs and labels per batch; given the model, it can also
# build decoder inputs from the labels.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    label_pad_token_id=-100,  # library default, spelled out: padded label positions are ignored by the loss
)

In [None]:
# Corpus-level SacreBLEU, used by compute_metrics to score generated corrections.
metric = evaluate.load(path="sacrebleu")

def compute_metrics(eval_pred):
    """Score generated corrections against reference corrections with SacreBLEU.

    Args:
        eval_pred: ``(predictions, labels)`` pair from the trainer. With
            ``predict_with_generate=True`` the predictions are generated token ids;
            they may arrive wrapped in a tuple whose first element is the ids.

    Returns:
        dict: ``{"sacrebleu": <corpus BLEU score>}``.
    """
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    # -100 is the ignore index used for padding. It can appear in labels and, when
    # batches of different lengths are concatenated, in generated ids too. It is
    # not a valid token id, so map it back to the pad token before decoding.
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = [text.strip() for text in tokenizer.batch_decode(predictions, skip_special_tokens=True)]
    decoded_labels = [text.strip() for text in tokenizer.batch_decode(labels, skip_special_tokens=True)]
    # SacreBLEU takes a list of references per prediction; there is one reference each here.
    result = metric.compute(
        predictions=decoded_preds,
        references=[[ref] for ref in decoded_labels],
    )
    return {"sacrebleu": result["score"]}

In [None]:
# T5 checkpoints are known to overflow (NaN loss) under fp16. Use bf16 where the GPU
# supports it, which also matches the bf16 compute dtype of the 4-bit config; otherwise
# train in fp32.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

# Seq2SeqTrainingArguments is required: `predict_with_generate` and
# `generation_max_length` do not exist on plain TrainingArguments.
training_args = Seq2SeqTrainingArguments(
    output_dir="./t5_grammar_model",
    eval_strategy="epoch",  # spelled `evaluation_strategy` in transformers < 4.41
    learning_rate=3e-4,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    save_total_limit=2,
    num_train_epochs=3,
    predict_with_generate=True,
    # Match the target truncation length; the model's default generation length may be
    # shorter and would cut corrections off before scoring.
    generation_max_length=128,
    bf16=use_bf16,
)

In [None]:
# Seq2SeqTrainer (not Trainer) honors predict_with_generate, so compute_metrics
# receives generated token ids rather than raw logits.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_eval_dataset,
    tokenizer=tokenizer,  # renamed `processing_class` in newer transformers releases
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
# Fine-tune the LoRA adapters. The returned TrainOutput is kept for later
# inspection and is also displayed as the cell's result.
train_output = trainer.train()
train_output

In [None]:
# Fold the LoRA weights into the base model and drop the PEFT wrapper.
# Guarded so that re-running the cell is a no-op rather than an AttributeError:
# after the first run `model` is the unwrapped model, which has no merge_and_unload.
# NOTE(review): the base model was loaded in 4-bit. Merging into quantized weights may
# be lossy or unsupported depending on peft/bitsandbytes versions. Consider saving
# the adapter alone, or reloading the base in bf16 before merging.
if hasattr(model, "merge_and_unload"):
    model = model.merge_and_unload()

In [None]:
# Raw string: in a normal literal, "\S", "\T" and "\F" are invalid escape sequences
# (a SyntaxWarning on Python 3.12+). One constant keeps both saves pointing at the same place.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
SAVE_DIR = r"C:\Sarayu\T5 Grammarator\Fine Tuning"
model.save_pretrained(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)

In [None]:
# Use the fine-tuned in-memory `model`; passing the Hub id "google/flan-t5-xl" would
# load the untuned base checkpoint. No `device=` argument: the model was placed with
# device_map="auto", and pipelines reject `device` for such accelerate-dispatched models.
grammar_pipe = pipeline(
    "text2text-generation",
    model=model,
    tokenizer=tokenizer,
)

test_sentence = "She no went to school yesterday"
# Same task prefix as in training; cap new tokens so longer corrections are not truncated.
print(grammar_pipe(f"grammar correction: {test_sentence}", max_new_tokens=128)[0]["generated_text"])