<a href="https://colab.research.google.com/github/yashvoladoddi37/movie-title-ocr-corrector/blob/main/ocr_text_correction_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers torch datasets pandas scikit-learn tqdm

Collecting sympy==1.13.1 (from torch)
  Downloading sympy-1.13.1-py3-none-any.whl.metadata (12 kB)
Downloading sympy-1.13.1-py3-none-any.whl (6.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.2/6.2 MB[0m [31m40.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: sympy
  Attempting uninstall: sympy
    Found existing installation: sympy 1.13.3
    Uninstalling sympy-1.13.3:
      Successfully uninstalled sympy-1.13.3
Successfully installed sympy-1.13.1


In [None]:
#DATA PREPARATION
import pandas as pd
from sklearn.model_selection import train_test_split
def load_and_prepare_data(test_size=0.2, random_state=42, sample_fraction=0.5):
    """Load the OCR title pairs, subsample them, and split into train/validation.

    Reads ``imdb_title_ocr_variations.csv`` from the current working directory,
    drops rows where either title is missing, keeps a random ``sample_fraction``
    of the remaining rows, and splits the result with ``train_test_split``.

    Args:
        test_size: Forwarded to ``train_test_split`` as the validation share.
        random_state: Seed used for both the subsampling and the split.
        sample_fraction: Fraction of rows kept before splitting.

    Returns:
        ``(train_df, val_df)`` where ``ocr_generated_title`` has been renamed to
        ``incorrect_text`` and ``original_title`` to ``correct_text``.

    Raises:
        KeyError: If the CSV lacks one of the two required title columns.
    """
    column_map = {
        'ocr_generated_title': 'incorrect_text',
        'original_title': 'correct_text',
    }

    full_df = pd.read_csv('imdb_title_ocr_variations.csv')

    # Fail here with a clear message instead of later inside the dataset class.
    missing_columns = [name for name in column_map if name not in full_df.columns]
    if missing_columns:
        raise KeyError(f"CSV is missing required column(s): {missing_columns}")

    # Measure the real size up front rather than back-computing it from the
    # sample (which printed a float estimate and divided by zero for
    # sample_fraction=0).
    original_size = len(full_df)

    # Rows with a missing title cannot be tokenized, so drop them, and make sure
    # the tokenizer always receives text for the two title columns.
    title_columns = list(column_map)
    full_df = full_df.dropna(subset=title_columns)
    full_df[title_columns] = full_df[title_columns].astype(str)

    # Reduce dataset size
    df = full_df.sample(frac=sample_fraction, random_state=random_state)

    # Split into train and validation sets
    train_df, val_df = train_test_split(
        df,
        test_size=test_size,
        random_state=random_state
    )

    print(f"Original dataset size: {original_size}")
    print(f"Reduced dataset size: {len(df)}")
    print(f"Training samples: {len(train_df)}")
    print(f"Validation samples: {len(val_df)}")

    train_df = train_df.rename(columns=column_map)
    val_df = val_df.rename(columns=column_map)

    return train_df, val_df

In [None]:
#DEFINE THE TRAINING CLASS
import torch
from torch import nn
from torch.utils.data import DataLoader
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from torch.cuda.amp import autocast, GradScaler
import torch.nn.functional as F
from tqdm import tqdm

class OCRTrainer:
    """Fine-tunes a pretrained seq2seq model for OCR text correction.

    Holds the model, its tokenizer, an AdamW optimizer and (optionally) a
    gradient scaler for mixed-precision training, and offers a manual training
    loop, a validation-loss helper, single-text inference and saving.
    """

    def __init__(self, device, use_amp=True, model_name="t5-base"):
        """Load the model and tokenizer and set up the optimizer.

        Args:
            device: Device (or device string) the model is moved to.
            use_amp: Request mixed-precision training. Only honoured when
                ``device`` is a CUDA device, because the CUDA autocast/scaler
                used here just disable themselves (with warnings) elsewhere.
            model_name: Name or path passed to ``from_pretrained``.
        """
        self.device = device
        self.use_amp = bool(use_amp) and torch.device(device).type == 'cuda'
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.scaler = GradScaler() if self.use_amp else None
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=2e-5)

    def _forward_loss(self, batch):
        """Move one batch to the device and return the model's loss tensor."""
        outputs = self.model(
            input_ids=batch['input_ids'].to(self.device),
            attention_mask=batch['attention_mask'].to(self.device),
            labels=batch['labels'].to(self.device)
        )
        return outputs.loss

    def train(self, train_dataset, val_dataset, epochs=3, batch_size=32):
        """Run the training loop and print train/validation loss per epoch.

        Args:
            train_dataset: Dataset yielding dicts with ``input_ids``,
                ``attention_mask`` and ``labels`` tensors.
            val_dataset: Dataset of the same form used for validation.
            epochs: Number of passes over ``train_dataset``.
            batch_size: Batch size for both loaders.
        """
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size)

        for epoch in range(epochs):
            self.model.train()
            total_loss = 0
            train_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}')

            for batch in train_pbar:
                self.optimizer.zero_grad()

                if self.use_amp:
                    # Forward under autocast; the scaler guards the backward
                    # pass and optimizer step against fp16 underflow.
                    with autocast():
                        loss = self._forward_loss(batch)

                    self.scaler.scale(loss).backward()
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    loss = self._forward_loss(batch)
                    loss.backward()
                    self.optimizer.step()

                batch_loss = loss.item()
                total_loss += batch_loss
                train_pbar.set_postfix({'loss': batch_loss})

            avg_loss = total_loss / len(train_loader)
            print(f'\nEpoch {epoch+1} - Average loss: {avg_loss:.4f}')

            # Validation
            val_loss = self.evaluate(val_loader)
            print(f'Validation loss: {val_loss:.4f}')

    def evaluate(self, val_loader):
        """Return the mean per-batch loss over ``val_loader``.

        Args:
            val_loader: Iterable of batches in the same format used by ``train``.

        Returns:
            The average batch loss as a float, or ``nan`` when the loader
            yields no batches (instead of raising ``ZeroDivisionError``).
        """
        self.model.eval()
        total_loss = 0
        num_batches = 0

        with torch.no_grad():
            for batch in val_loader:
                total_loss += self._forward_loss(batch).item()
                num_batches += 1

        if num_batches == 0:
            return float('nan')
        return total_loss / num_batches

    def correct_text(self, text):
        """Generate the corrected version of ``text`` and return it as a string.

        Args:
            text: The OCR text to correct.

        Returns:
            The decoded first generated sequence, special tokens removed.
        """
        self.model.eval()
        inputs = self.tokenizer(text, return_tensors="pt", padding=True).to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=128
            )

        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    def save_model(self, path):
        """Save both the model and the tokenizer to ``path``."""
        self.model.save_pretrained(path)
        self.tokenizer.save_pretrained(path)

In [None]:
#CUSTOM DATASET CLASS
from torch.utils.data import Dataset

class OCRCorrectionDataset(Dataset):
    """Torch dataset that tokenizes (incorrect_text, correct_text) pairs.

    Each item is a dict with ``input_ids`` and ``attention_mask`` for the
    OCR text and ``labels`` for the correct text, all padded or truncated to
    ``max_length``.
    """

    def __init__(self, dataframe, tokenizer, max_length=128):
        """
        Args:
            dataframe: Frame with ``incorrect_text`` and ``correct_text`` columns.
            tokenizer: Callable tokenizer used for both source and target text.
            max_length: Length every sequence is padded/truncated to.
        """
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of rows in the underlying frame."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize the row at position ``idx`` and return model-ready tensors."""
        row = self.data.iloc[idx]
        incorrect_text = row['incorrect_text']
        correct_text = row['correct_text']

        inputs = self.tokenizer(
            incorrect_text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )

        # `text_target` is the replacement for the deprecated
        # `as_target_tokenizer()` context manager.
        targets = self.tokenizer(
            text_target=correct_text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )

        # squeeze(0) removes only the batch dimension; a bare squeeze() would
        # also collapse the sequence dimension when max_length is 1.
        label_ids = targets['input_ids'].squeeze(0)

        # Replace padding positions with -100 so the loss ignores them;
        # otherwise padding dominates the loss for short titles.
        pad_token_id = getattr(self.tokenizer, 'pad_token_id', None)
        if pad_token_id is not None:
            label_ids = label_ids.masked_fill(label_ids == pad_token_id, -100)

        return {
            'input_ids': inputs['input_ids'].squeeze(0),
            'attention_mask': inputs['attention_mask'].squeeze(0),
            'labels': label_ids
        }

In [None]:
#HUGGINGFACE LOGIN -> IMPORTANT BEFORE TRAINING
from huggingface_hub import login

# Never hard-code an access token in the notebook: read it from the HF_TOKEN
# environment variable, or prompt for it without echoing. Any token that was
# ever committed to this notebook must be treated as leaked and revoked.
import os
from getpass import getpass

hf_token = os.environ.get("HF_TOKEN") or getpass("Hugging Face token: ")
login(token=hf_token)

In [None]:
#TRAINING
import torch
from google.colab import drive
from huggingface_hub import notebook_login
from transformers import TrainingArguments, Trainer
import matplotlib.pyplot as plt

def main():
    """Fine-tune the OCR-correction model with the Hugging Face Trainer.

    Mounts Google Drive, reports the available device, loads the train and
    validation frames, builds the datasets, trains (resuming from the newest
    checkpoint in the output directory when one exists), pushes the result to
    the Hugging Face Hub and plots the logged training loss.
    """
    import inspect
    import os

    from transformers.trainer_utils import get_last_checkpoint

    # Mount Google Drive (if needed)
    drive.mount('/content/drive')

    # Check GPU
    if torch.cuda.is_available():
        print(f'GPU: {torch.cuda.get_device_name(0)}')
        print(f'GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB')
    else:
        print('No GPU available, using CPU')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load and prepare data
    train_df, val_df = load_and_prepare_data()
    print(f"\nTraining set size: {len(train_df)}")
    print(f"Validation set size: {len(val_df)}")
    print("\nColumn names:", train_df.columns.tolist())

    # OCRTrainer is used here only to load the model and tokenizer.
    trainer_obj = OCRTrainer(device=device, use_amp=True)

    repo_name = "movie-title-OCR-corrector-t5"  # Replace with your desired repository name

    # Create datasets
    train_dataset = OCRCorrectionDataset(train_df, trainer_obj.tokenizer)
    val_dataset = OCRCorrectionDataset(val_df, trainer_obj.tokenizer)

    # transformers is installed unpinned, and `evaluation_strategy` was renamed
    # to `eval_strategy`; use whichever name this version accepts.
    accepted_args = inspect.signature(TrainingArguments.__init__).parameters
    eval_strategy_key = (
        "eval_strategy" if "eval_strategy" in accepted_args else "evaluation_strategy"
    )

    # load_best_model_at_end is left off: it requires the save and evaluation
    # strategies to match, and here checkpoints are step-based while evaluation
    # runs once per epoch.
    training_args = TrainingArguments(
        output_dir=repo_name,
        per_device_train_batch_size=64,
        per_device_eval_batch_size=64,
        num_train_epochs=1,
        logging_dir="./logs",
        logging_steps=10,  # Log every 10 steps
        save_strategy="steps",  # Save checkpoints regularly
        save_steps=500,  # Save a checkpoint every 500 steps
        load_best_model_at_end=False,
        push_to_hub=True,  # Push to Hugging Face Hub
        **{eval_strategy_key: "epoch"},
    )

    # Hugging Face Trainer
    trainer = Trainer(
        model=trainer_obj.model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset
    )

    # resume_from_checkpoint=True raises when the output directory holds no
    # checkpoint (e.g. on the first run), so look for one and pass None if absent.
    last_checkpoint = (
        get_last_checkpoint(training_args.output_dir)
        if os.path.isdir(training_args.output_dir)
        else None
    )
    trainer.train(resume_from_checkpoint=last_checkpoint)

    # The Trainer was not given the tokenizer, so write it next to the model;
    # NOTE(review): push_to_hub is expected to upload the output directory,
    # which would publish the tokenizer alongside the model - confirm.
    trainer_obj.tokenizer.save_pretrained(training_args.output_dir)

    # Save the model to Hugging Face Hub
    trainer.save_model()
    trainer.push_to_hub(commit_message="Final model after training")

    # The logged losses live in trainer.state.log_history (one entry per
    # logging step); the object returned by train() has no `history` attribute.
    loss_entries = [entry for entry in trainer.state.log_history if "loss" in entry]
    if loss_entries:
        steps = [entry.get("step", index) for index, entry in enumerate(loss_entries)]
        losses = [entry["loss"] for entry in loss_entries]
        plt.plot(steps, losses, label="Training Loss")
        plt.xlabel("Iteration")
        plt.ylabel("Loss")
        plt.title("Training Loss Curve")
        plt.legend()
        plt.show()
    else:
        print("Training loss data not available for plotting.")

if __name__ == "__main__":
    main()


In [None]:
# Inference Test Cell
import os

import torch
from google.colab import drive
# Imported here so this cell also runs on a fresh kernel without the training cells.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Mount Google Drive (if not already mounted)
drive.mount('/content/drive')

# Load the trained model and tokenizer
model_name = "yashvoladoddi37/movie-title-OCR-corrector-t5"  # Replace with your model name on Hugging Face

model_path = '/content/drive/MyDrive/ocr_correction_model'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Prefer the tokenizer saved on Drive when that directory exists; otherwise
# load it from the same Hub repository as the model.
tokenizer_source = model_path if os.path.isdir(model_path) else model_name
tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)

model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
model.eval()

# Test texts
test_texts = [
    "The Godfthr",
    "Pulp Ficton",
    "Star Wars: The Las Jed",
    "The Lord of the Rigns",
    "Did you watch Avend3rs: Endgame or Star Warts? My favourites are The Dark Knigt, The Godfthr, and The Shawshank Redemtion. I also liked The Lord of the Rigns, Pulp Ficton, Inceptionn, and Interestellar. These are all classiccs.",
    "Th3 Godfthr Star Warts The Dark Kn1gt The Shawshank Redemtion The Lord of the Rigns Pulp Ficton Av3nders: Endgame Inceptionn"
]

# Perform inference, one text at a time
for text in test_texts:
    inputs = tokenizer(text, return_tensors="pt", padding=True).to(device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=256
        )

    corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"Original: {text}")
    print(f"Corrected: {corrected_text}\n")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Original: The Godfthr
Corrected: The Godfather

Original: Pulp Ficton
Corrected: Pulp Fiction

Original: Star Wars: The Las Jed
Corrected: Star Wars: The Last Jedi

Original: The Lord of the Rigns
Corrected: The Lord of the Rings

Original: Did you watch Avend3rs: Endgame or Star Warts? My favourites are The Dark Knigt, The Godfthr, and The Shawshank Redemtion. I also liked The Lord of the Rigns, Pulp Ficton, Inceptionn, and Interestellar. These are all classiccs.
Corrected: Did you watch Avengers: Endgame or Star Wars? My favourites are The Dark Knigt, The Godfather, and The Shawshank Redemonstration

Original: Th3 Godfthr Star Warts The Dark Kn1gt The Shawshank Redemtion The Lord of the Rigns Pulp Ficton Av3nders: Endgame Inceptionn
Corrected: The Godfather Star Wars The Dark Knight The Shawshank Redemontion The Lord of the Rings Pulp Fiction Avengers: Endg