In [1]:
# # Version	            Python version	   Compiler	    Build tools	  cuDNN	  CUDA
# # tensorflow_gpu-2.10.0	   3.7-3.10	       MSVC 2019	Bazel 5.1.1	   8.1	  11.2

# # %pip install nvidia-pyindex
# # %pip install tensorflow-gpu==2.10.0
# # %pip install torch
# # %pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# # %pip install transformers
# # %pip install transformers[torch]
# # %pip install accelerate>=0.26.0
# %pip install matplotlib seaborn

In [2]:
# import torch
# print("PyTorch Version:", torch.__version__)
# print("CUDA Available:", torch.cuda.is_available())
# print("CUDA Version:", torch.version.cuda)

# x = torch.rand(5, 3)
# print(x)

1. Read the Dataset

In [3]:
import pandas as pd
from sklearn.model_selection import train_test_split
import ast

# Load your dataset
df = pd.read_csv("datasets/bert_data.csv")  # Ensure your dataset is in the correct CSV format

# Safely convert stringified lists back to Python lists
df['attention_mask'] = df['attention_mask'].apply(ast.literal_eval)
small_df = df.sample(n=1000, random_state=42).reset_index(drop=True)


Phase 1: Spell Correction

In [4]:
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from torch.utils.data import Dataset, DataLoader
import torch

# 1. Split the data into training and validation sets (80% train, 20% validation)
train_df, val_df = train_test_split(small_df, test_size=0.2, random_state=42)

# 2. Initialize the tokenizer
tokenizer = BertTokenizer.from_pretrained("sagorsarker/bangla-bert-base")

# 3. Custom Dataset for BERT
class BertDataset(Dataset):
    def __init__(self, df, tokenizer, max_length=64):
        self.sentences = df['sentence'].tolist()
        self.labels = df['label'].tolist()  # Add label support
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        sentence = self.sentences[idx]
        label = self.labels[idx]  # Extract label
        
        encoding = self.tokenizer.encode_plus(       # Tokenize and encode the sentence
            sentence,
            add_special_tokens=True,      # Add [CLS] and [SEP] tokens
            max_length=self.max_length,   # Set maximum length
            padding='max_length',         # Pad sequences to max_length
            truncation=True,              # Truncate sequences if needed
            return_attention_mask=True,   # Return attention mask
            return_tensors='pt'           # Return PyTorch tensors
        )
        
        return {
            'input_ids': encoding['input_ids'].squeeze(0),           # Token IDs
            'attention_mask': encoding['attention_mask'].squeeze(0), # Attention mask
            'labels': torch.tensor(label, dtype=torch.long)          # Label as tensor
        }

# 4. Prepare datasets
train_dataset = BertDataset(train_df, tokenizer)
val_dataset = BertDataset(val_df, tokenizer)

# Example DataLoader
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=True)


  from .autonotebook import tqdm as notebook_tqdm


In [5]:

# # Inspect the batch
# for batch in train_loader:
#     for k, v in batch.items():
#         print(f"{k}: {v.shape}")  # Expected: [batch_size, seq_length]
#     break

In [6]:
# 5. Load pre-trained BERT model
model = BertForSequenceClassification.from_pretrained("sagorsarker/bangla-bert-base", num_labels=2)

# 6. Define training arguments
training_args = TrainingArguments(
    output_dir='./results',          # Output directory for model and logs
    num_train_epochs=3,              # Number of training epochs
    per_device_train_batch_size=8,   # Batch size for training
    per_device_eval_batch_size=8,    # Batch size for evaluation
    warmup_steps=500,                # Number of warmup steps
    weight_decay=0.01,               # Strength of weight decay
    logging_dir='./logs',            # Directory for logs
    logging_steps=10,                # Log every 10 steps
    evaluation_strategy="epoch",     # Evaluate every epoch
    save_strategy="epoch",           # Save the model after every epoch
    load_best_model_at_end=True      # Load the best model at the end
)

# 7. Initialize the Trainer
trainer = Trainer(
    model=model,                     # Model to train
    args=training_args,              # Training arguments
    train_dataset=train_dataset,     # Training dataset
    eval_dataset=val_dataset,        # Validation dataset
    tokenizer=tokenizer              # Tokenizer used to process input
)

trainer.train()

results = trainer.evaluate()
print(results)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at sagorsarker/bangla-bert-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(
  3%|▎         | 10/300 [00:35<14:37,  3.03s/it]

{'loss': 0.6733, 'grad_norm': 7.260274410247803, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.1}


  7%|▋         | 20/300 [01:05<14:44,  3.16s/it]

{'loss': 0.6703, 'grad_norm': 5.957991600036621, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.2}


 10%|█         | 30/300 [01:37<13:55,  3.09s/it]

{'loss': 0.631, 'grad_norm': 7.557470798492432, 'learning_rate': 3e-06, 'epoch': 0.3}


 13%|█▎        | 40/300 [02:07<13:14,  3.06s/it]

{'loss': 0.5709, 'grad_norm': 6.978739261627197, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.4}


 17%|█▋        | 50/300 [02:39<12:58,  3.12s/it]

{'loss': 0.5008, 'grad_norm': 4.75669527053833, 'learning_rate': 5e-06, 'epoch': 0.5}


 20%|██        | 60/300 [03:11<12:44,  3.19s/it]

{'loss': 0.4174, 'grad_norm': 5.441647052764893, 'learning_rate': 6e-06, 'epoch': 0.6}


 23%|██▎       | 70/300 [03:42<11:52,  3.10s/it]

{'loss': 0.3493, 'grad_norm': 15.724018096923828, 'learning_rate': 7.000000000000001e-06, 'epoch': 0.7}


 27%|██▋       | 80/300 [04:13<11:20,  3.09s/it]

{'loss': 0.3972, 'grad_norm': 4.687841892242432, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.8}


 30%|███       | 90/300 [04:44<10:49,  3.09s/it]

{'loss': 0.257, 'grad_norm': 3.990324020385742, 'learning_rate': 9e-06, 'epoch': 0.9}


 33%|███▎      | 100/300 [05:15<10:20,  3.10s/it]

{'loss': 0.1922, 'grad_norm': 9.479243278503418, 'learning_rate': 1e-05, 'epoch': 1.0}


                                                 
 33%|███▎      | 100/300 [05:30<10:20,  3.10s/it]

{'eval_loss': 0.1646597683429718, 'eval_runtime': 15.271, 'eval_samples_per_second': 13.097, 'eval_steps_per_second': 1.637, 'epoch': 1.0}


Saving vocabulary to ./results\checkpoint-100\vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to ./results\checkpoint-100\vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to ./results\checkpoint-100\vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to ./results\checkpoint-100\vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to ./results\checkpoint-100\vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to ./results\checkpoint-100\vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to ./results\checkpoint-100\vocab.txt: vocabulary indices are not consecutive. Please check th

{'loss': 0.116, 'grad_norm': 5.357348918914795, 'learning_rate': 1.1000000000000001e-05, 'epoch': 1.1}


 40%|████      | 120/300 [06:55<09:23,  3.13s/it]

{'loss': 0.2777, 'grad_norm': 14.274627685546875, 'learning_rate': 1.2e-05, 'epoch': 1.2}


 43%|████▎     | 130/300 [07:27<08:54,  3.15s/it]

{'loss': 0.1875, 'grad_norm': 17.907670974731445, 'learning_rate': 1.3000000000000001e-05, 'epoch': 1.3}


 47%|████▋     | 140/300 [07:58<08:19,  3.12s/it]

{'loss': 0.1702, 'grad_norm': 0.2626342177391052, 'learning_rate': 1.4000000000000001e-05, 'epoch': 1.4}


 50%|█████     | 150/300 [08:31<08:04,  3.23s/it]

{'loss': 0.2557, 'grad_norm': 3.4803872108459473, 'learning_rate': 1.5e-05, 'epoch': 1.5}


 53%|█████▎    | 160/300 [09:05<07:54,  3.39s/it]

{'loss': 0.1029, 'grad_norm': 33.490211486816406, 'learning_rate': 1.6000000000000003e-05, 'epoch': 1.6}


 57%|█████▋    | 170/300 [09:40<07:26,  3.43s/it]

{'loss': 0.3043, 'grad_norm': 68.2522964477539, 'learning_rate': 1.7000000000000003e-05, 'epoch': 1.7}


 60%|██████    | 180/300 [10:17<07:10,  3.59s/it]

{'loss': 0.1937, 'grad_norm': 5.008024215698242, 'learning_rate': 1.8e-05, 'epoch': 1.8}


 63%|██████▎   | 190/300 [10:52<06:51,  3.74s/it]

{'loss': 0.3404, 'grad_norm': 35.58915328979492, 'learning_rate': 1.9e-05, 'epoch': 1.9}


 67%|██████▋   | 200/300 [11:27<05:39,  3.39s/it]

{'loss': 0.0512, 'grad_norm': 30.20867347717285, 'learning_rate': 2e-05, 'epoch': 2.0}


                                                 
 67%|██████▋   | 200/300 [11:44<05:39,  3.39s/it]

{'eval_loss': 0.13203367590904236, 'eval_runtime': 16.4823, 'eval_samples_per_second': 12.134, 'eval_steps_per_second': 1.517, 'epoch': 2.0}


Saving vocabulary to ./results\checkpoint-200\vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to ./results\checkpoint-200\vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to ./results\checkpoint-200\vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to ./results\checkpoint-200\vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to ./results\checkpoint-200\vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to ./results\checkpoint-200\vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to ./results\checkpoint-200\vocab.txt: vocabulary indices are not consecutive. Please check th

{'loss': 0.0945, 'grad_norm': 0.0517057329416275, 'learning_rate': 2.1e-05, 'epoch': 2.1}


 73%|███████▎  | 220/300 [16:36<04:37,  3.47s/it]

{'loss': 0.1235, 'grad_norm': 0.06930798292160034, 'learning_rate': 2.2000000000000003e-05, 'epoch': 2.2}


 77%|███████▋  | 230/300 [17:11<03:57,  3.40s/it]

{'loss': 0.0018, 'grad_norm': 0.02909289486706257, 'learning_rate': 2.3000000000000003e-05, 'epoch': 2.3}


 80%|████████  | 240/300 [17:45<03:22,  3.37s/it]

{'loss': 0.1196, 'grad_norm': 0.0242856964468956, 'learning_rate': 2.4e-05, 'epoch': 2.4}


 83%|████████▎ | 250/300 [18:19<02:49,  3.38s/it]

{'loss': 0.1721, 'grad_norm': 0.015851564705371857, 'learning_rate': 2.5e-05, 'epoch': 2.5}


 87%|████████▋ | 260/300 [18:54<02:15,  3.39s/it]

{'loss': 0.0285, 'grad_norm': 0.02523624338209629, 'learning_rate': 2.6000000000000002e-05, 'epoch': 2.6}


 90%|█████████ | 270/300 [19:29<01:46,  3.55s/it]

{'loss': 0.0254, 'grad_norm': 0.01997709833085537, 'learning_rate': 2.7000000000000002e-05, 'epoch': 2.7}


 93%|█████████▎| 280/300 [20:04<01:08,  3.41s/it]

{'loss': 0.12, 'grad_norm': 0.12705987691879272, 'learning_rate': 2.8000000000000003e-05, 'epoch': 2.8}


 97%|█████████▋| 290/300 [20:38<00:34,  3.45s/it]

{'loss': 0.0777, 'grad_norm': 21.48626708984375, 'learning_rate': 2.9e-05, 'epoch': 2.9}


100%|██████████| 300/300 [21:13<00:00,  3.40s/it]

{'loss': 0.1687, 'grad_norm': 0.11523807048797607, 'learning_rate': 3e-05, 'epoch': 3.0}


Saving vocabulary to ./results\checkpoint-300\vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to ./results\checkpoint-300\vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to ./results\checkpoint-300\vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to ./results\checkpoint-300\vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to ./results\checkpoint-300\vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to ./results\checkpoint-300\vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to ./results\checkpoint-300\vocab.txt: vocabulary indices are not consecutive. Please check th

{'eval_loss': 0.13677270710468292, 'eval_runtime': 15.673, 'eval_samples_per_second': 12.761, 'eval_steps_per_second': 1.595, 'epoch': 3.0}


Saving vocabulary to ./results\checkpoint-300\vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to ./results\checkpoint-300\vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to ./results\checkpoint-300\vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to ./results\checkpoint-300\vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to ./results\checkpoint-300\vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to ./results\checkpoint-300\vocab.txt: vocabulary indices are not consecutive. Please check that the vocabulary is not corrupted!
Saving vocabulary to ./results\checkpoint-300\vocab.txt: vocabulary indices are not consecutive. Please check th

{'train_runtime': 1944.4213, 'train_samples_per_second': 1.234, 'train_steps_per_second': 0.154, 'train_loss': 0.25302942625557384, 'epoch': 3.0}


100%|██████████| 300/300 [32:25<00:00,  6.48s/it]
100%|██████████| 25/25 [00:14<00:00,  1.74it/s]

{'eval_loss': 0.13203367590904236, 'eval_runtime': 14.7981, 'eval_samples_per_second': 13.515, 'eval_steps_per_second': 1.689, 'epoch': 3.0}





Evaluate Model Performance on Validation Set

In [7]:
# Evaluate the model on the validation set
results = trainer.evaluate()
print("Evaluation Results:", results)


100%|██████████| 25/25 [00:14<00:00,  1.75it/s]

Evaluation Results: {'eval_loss': 0.13203367590904236, 'eval_runtime': 14.9007, 'eval_samples_per_second': 13.422, 'eval_steps_per_second': 1.678, 'epoch': 3.0}





Make Predictions on New Data

In [8]:
# Sample new sentences for inference
new_sentences = [
    "গাজীপুরের কালিয়াকৈর উপজেলার তেলিরচালা এলাকায়",  # Correct sentence
    "গাজীপ৳ুরের ালিয়াকৈর উপজেলংার তেলিচরালা এলাকা়"  # Sentence with errors
]

# Preprocess and tokenize new sentences
inputs = tokenizer(new_sentences, return_tensors='pt', padding=True, truncation=True, max_length=128)

# Move inputs to model's device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
inputs = {key: val.to(device) for key, val in inputs.items()}

# Make predictions
model.eval()
with torch.no_grad():
    outputs = model(**inputs)
    predictions = torch.argmax(outputs.logits, axis=-1)

# Interpret predictions
for sentence, pred in zip(new_sentences, predictions):
    label = "Correct" if pred == 0 else "Incorrect"
    print(f"Sentence: {sentence}\nPrediction: {label}\n")


Sentence: গাজীপুরের কালিয়াকৈর উপজেলার তেলিরচালা এলাকায়
Prediction: Correct

Sentence: গাজীপ৳ুরের ালিয়াকৈর উপজেলংার তেলিচরালা এলাকা়
Prediction: Incorrect



Save the Trained Model

In [14]:
# Save the model and tokenizer
model.save_pretrained("./brammerly")
tokenizer.save_pretrained("./brammerly")
print("Model and tokenizer saved successfully!")


Model and tokenizer saved successfully!


Reload the model

In [15]:
from transformers import BertForSequenceClassification, BertTokenizer

model = BertForSequenceClassification.from_pretrained("./brammerly")
tokenizer = BertTokenizer.from_pretrained("./brammerly")


Inspect Misclassifications

In [16]:
# batch = {k: v.to(device) for k, v in batch.items()}

# # Debug shapes
# print("Input IDs shape:", batch['input_ids'].shape)
# print("Attention Mask shape:", batch['attention_mask'].shape)

# # Forward pass
# with torch.no_grad():
#     outputs = model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
#     logits = outputs.logits
#     preds = torch.argmax(logits, dim=-1)

# print("Predictions:", preds)


In [None]:
from sklearn.metrics import confusion_matrix, classification_report

# Create predictions for validation data
model.eval()
predictions = []
true_labels = []

for batch in val_dataset:  # val_dataloader is your DataLoader for validation data
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
        logits = outputs.logits
        preds = torch.argmax(logits, dim=-1)

    predictions.extend(preds.cpu().numpy())
    true_labels.extend(batch['labels'].cpu().numpy())

# Convert to DataFrame for easy inspection
val_df['predicted_label'] = predictions
val_df['true_label'] = true_labels

# Inspect misclassified examples
misclassified = val_df[val_df['predicted_label'] != val_df['true_label']]
print("Misclassified Examples:")
print(misclassified[['sentence', 'true_label', 'predicted_label']])

Confusion Matrix

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

# Generate confusion matrix
conf_matrix = confusion_matrix(true_labels, predictions)

# Plot confusion matrix
plt.figure(figsize=(6, 5))
sns.heatmap(conf_matrix, annot=True, fmt="d", cmap="Blues", xticklabels=["Correct", "Incorrect"], yticklabels=["Correct", "Incorrect"])
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("Confusion Matrix")
plt.show()


Compute Accuracy, Precision, Recall, and F1 Score

In [None]:
from sklearn.metrics import classification_report

# Generate a classification report
report = classification_report(true_labels, predictions, target_names=["Correct", "Incorrect"])
print("Classification Report:")
print(report)


Phase 2: Grammar Correction

In [None]:
# Install GECToR dependencies (done once)
# !pip install allennlp allennlp-models

from allennlp.predictors import Predictor
from allennlp_models.pretrained import load_predictor

# Load GECToR pre-trained model
predictor = load_predictor("https://storage.googleapis.com/allennlp-public-models/roberta-base-2020.06.09.tar.gz")

# Example prediction
example_sentence = "তোমরা ভুলগুলো চিহ্নিত করনি।"
prediction = predictor.predict(sentence=example_sentence)

# Output
print("Original:", example_sentence)
print("Corrected:", prediction["tokens"])


GECToR Fine-Tuning for Bangla Grammar Correction

In [None]:
# Clone GECToR Repo
git clone https://github.com/grammarly/gector.git
cd gector

# Install Dependencies
pip install -r requirements.txt

# Train GECToR
python train.py \
    --train_set "data/train.txt" \
    --dev_set "data/dev.txt" \
    --model_dir "output/bangla_model" \
    --pretrained_transformer "sagorsarker/bangla-bert-base" \
    --vocab_path "data/vocab" \
    --batch_size 32


Fine-Tune T5 for Sentence-Level Grammar Correction

In [None]:
from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments
from datasets import Dataset

# Initialize tokenizer and model
tokenizer = T5Tokenizer.from_pretrained("t5-small")  # Replace with Bangla-supported T5
model = T5ForConditionalGeneration.from_pretrained("t5-small")

# Prepare data
train_data = {"input_text": df['synth_spelling_errors'].tolist(), 
              "target_text": df['cleaned_text'].tolist()}

train_dataset = Dataset.from_dict(train_data)

def preprocess_function(examples):
    inputs = ["correct: " + text for text in examples['input_text']]
    targets = examples['target_text']
    model_inputs = tokenizer(inputs, max_length=128, truncation=True, padding="max_length")
    labels = tokenizer(targets, max_length=128, truncation=True, padding="max_length")
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# Tokenize data
tokenized_train = train_dataset.map(preprocess_function, batched=True)

# Training arguments
training_args = TrainingArguments(
    output_dir="./t5_bangla_results",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train
)

# Start training
trainer.train()

# Save model
model.save_pretrained("./t5_bangla_finetuned")


Phase 3: Sentence Correction

In [None]:
from transformers import T5ForConditionalGeneration, T5Tokenizer

# Load T5 model and tokenizer
model_name = "google/mt5-small"  # Use T5 pre-trained model for multilingual tasks
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

# Tokenize data for T5
def preprocess_function(examples):
    inputs = [f"grammar: {i}" for i in examples["incorrect_sentence"]]
    outputs = examples["cleaned_text"]
    model_inputs = tokenizer(inputs, text_target=outputs, max_length=128, truncation=True)
    return model_inputs

# Tokenize train and validation sets
from datasets import Dataset
train_dataset = Dataset.from_pandas(train_data)
val_dataset = Dataset.from_pandas(val_data)

train_dataset = train_dataset.map(preprocess_function, batched=True)
val_dataset = val_dataset.map(preprocess_function, batched=True)

# Fine-tune the model
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    predict_with_generate=True,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

trainer.train()
trainer.save_model("t5_sentence_correction_model")


Integrate GECToR, T5, and BERT

In [None]:
def correction_pipeline(sentence, gector_model, t5_model, bert_model, tokenizer):
    # Step 1: GECToR Token Correction
    gector_corrected_sentence = gector_model.correct(sentence)

    # Step 2: T5 Sentence-Level Correction
    input_text = "correct: " + gector_corrected_sentence
    inputs = tokenizer(input_text, return_tensors="pt")
    t5_outputs = t5_model.generate(**inputs)
    t5_corrected_sentence = tokenizer.decode(t5_outputs[0], skip_special_tokens=True)

    # Step 3: BERT Error Verification
    inputs = tokenizer(t5_corrected_sentence, return_tensors="pt", padding="max_length", truncation=True, max_length=128)
    outputs = bert_model(**inputs)
    logits = outputs.logits
    prediction = torch.argmax(logits, dim=-1)

    return {"original": sentence, 
            "gector": gector_corrected_sentence, 
            "t5": t5_corrected_sentence, 
            "bert_label": "Correct" if prediction.item() == 1 else "Incorrect"}

# Example usage
sentence = "গাজীপ৳ুরের ালিয়াকৈর উপজেলংার তেলিচরালা এলাকা়..."
result = correction_pipeline(sentence, gector_model, t5_model, bert_model, tokenizer)
print(result)
