In [5]:
import torch
from transformers import MarianMTModel, MarianTokenizer, Trainer, TrainingArguments
from torch.utils.data import Dataset
from tqdm import tqdm

# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Free up memory
torch.cuda.empty_cache()

# Load the tokenizer and model for MarianMT
model_name = "Helsinki-NLP/opus-mt-hi-en"  # Using a Hindi to English model for demonstration purposes
tokenizer = MarianTokenizer.from_pretrained(model_name)
model = MarianMTModel.from_pretrained(model_name).to(device)

# Load the datasets
def read_sentences(src_path, tgt_path, limit=None):
    with open(src_path, "r", encoding="utf-8") as src_file, open(tgt_path, "r", encoding="utf-8") as tgt_file:
        incorrect_sentences = [line.strip() for line in tqdm(src_file.readlines(), desc="Reading Incorrect Sentences")]
        correct_sentences = [line.strip() for line in tqdm(tgt_file.readlines(), desc="Reading Correct Sentences")]
    if limit:
        incorrect_sentences = incorrect_sentences[:limit]
        correct_sentences = correct_sentences[:limit]
    return incorrect_sentences, correct_sentences

# Paths for training, validation, and test datasets
train_src_path = "wikiExtractsData/data/train_merge.src"
train_tgt_path = "wikiExtractsData/data/train_merge.tgt"
valid_src_path = "wikiExtractsData/data/valid.src"
valid_tgt_path = "wikiExtractsData/data/valid.tgt"
test_src_path = "Wiki-edits/hiwiki.extracted.clean.src"
test_tgt_path = "Wiki-edits/hiwiki.extracted.clean.trg"

# Load Training, Validation, and Test Data (Limited subset for faster training initially)
train_incorrect, train_correct = read_sentences(train_src_path, train_tgt_path, limit=1000000)
valid_incorrect, valid_correct = read_sentences(valid_src_path, valid_tgt_path, limit=200000)
test_incorrect, test_correct = read_sentences(test_src_path, test_tgt_path, limit=2500)

# Define Dataset class for sentence correction
class SentenceCorrectionDataset(Dataset):
    def __init__(self, incorrect_sentences, correct_sentences, tokenizer, max_len=128):
        self.incorrect_sentences = incorrect_sentences
        self.correct_sentences = correct_sentences
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.incorrect_sentences)

    def __getitem__(self, idx):
        incorrect = self.incorrect_sentences[idx]
        correct = self.correct_sentences[idx]

        # Add a prefix to indicate the task type
        input_text = incorrect
        target_text = correct

        # Tokenize input and target texts
        inputs = self.tokenizer(
            input_text, 
            max_length=self.max_len, 
            padding="max_length", 
            truncation=True, 
            return_tensors="pt"
        )

        targets = self.tokenizer(
            target_text, 
            max_length=self.max_len, 
            padding="max_length", 
            truncation=True, 
            return_tensors="pt"
        )

        # Squeeze tensors to remove unnecessary dimensions
        input_ids = inputs.input_ids.squeeze()
        attention_mask = inputs.attention_mask.squeeze()
        labels = targets.input_ids.squeeze()

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }

# Create datasets using the small subset
train_dataset = SentenceCorrectionDataset(train_incorrect, train_correct, tokenizer)
valid_dataset = SentenceCorrectionDataset(valid_incorrect, valid_correct, tokenizer)
test_dataset = SentenceCorrectionDataset(test_incorrect, test_correct, tokenizer)

# Define Training Arguments
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,  # Standard learning rate for fine-tuning
    per_device_train_batch_size=8,  # Adjust batch size to fit in GPU memory
    per_device_eval_batch_size=8,
    num_train_epochs=5,  # More epochs for better fine-tuning
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=100,
    save_total_limit=2,
    load_best_model_at_end=True,
    fp16=torch.cuda.is_available(),  # Mixed precision training if CUDA is available
)

# Define the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset
)

# Train the Model using the subset
trainer.train()

# Evaluate on the validation dataset
print("Evaluating on validation dataset:")
validation_results = trainer.evaluate()
print(validation_results)

# Evaluate on the test dataset
print("Evaluating on test dataset:")
test_results = trainer.evaluate(eval_dataset=test_dataset)
print(test_results)

# Save the Model
model.save_pretrained("./sentence_correction_model")
tokenizer.save_pretrained("./sentence_correction_model")

# Make Predictions with the Trained Model
model = MarianMTModel.from_pretrained("./sentence_correction_model").to(device)
tokenizer = MarianTokenizer.from_pretrained("./sentence_correction_model")

# Predict for a new incorrect sentence
incorrect_sentence = "उसकी प्रतिभा की गहराई किसी अनजाने समुद्र जैसा है"
inputs = tokenizer(incorrect_sentence, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(device)

# Generate corrected sentence
with torch.no_grad():
    output_ids = model.generate(input_ids=inputs["input_ids"], max_length=128, num_beams=5, early_stopping=True)

# Decode the output to get the corrected sentence
corrected_sentence = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(f"Corrected Sentence: {corrected_sentence}")


Using device: cuda


Reading Incorrect Sentences: 100%|██████████| 2607757/2607757 [00:00<00:00, 3385905.95it/s]
Reading Correct Sentences: 100%|██████████| 2607757/2607757 [00:00<00:00, 3517563.21it/s]
Reading Incorrect Sentences: 100%|██████████| 521552/521552 [00:00<00:00, 3303649.17it/s]
Reading Correct Sentences: 100%|██████████| 521552/521552 [00:00<00:00, 3271656.98it/s]
Reading Incorrect Sentences: 100%|██████████| 13187/13187 [00:00<00:00, 2960143.80it/s]
Reading Correct Sentences: 100%|██████████| 13187/13187 [00:00<00:00, 3077923.59it/s]
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Epoch,Training Loss,Validation Loss
1,2.0838,0.386179
2,0.4277,0.224713
3,0.268,0.192019
4,0.1624,0.180972
5,0.1423,0.174762


There were missing keys in the checkpoint model loaded: ['model.encoder.embed_tokens.weight', 'model.encoder.embed_positions.weight', 'model.decoder.embed_tokens.weight', 'model.decoder.embed_positions.weight', 'lm_head.weight'].


Evaluating on validation dataset:


{'eval_loss': 0.17476209998130798, 'eval_runtime': 0.5792, 'eval_samples_per_second': 345.283, 'eval_steps_per_second': 43.16, 'epoch': 5.0}
Evaluating on test dataset:
{'eval_loss': 0.21930748224258423, 'eval_runtime': 0.5763, 'eval_samples_per_second': 347.06, 'eval_steps_per_second': 43.383, 'epoch': 5.0}
Corrected Sentence: उसके प्रतिभा का गहराई किसी अज्ञाते समुद्र जैसे है .
