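## Pipeline
This is the main notebook. It classifies an image as clear or blurry, extracts the text from clear images with EasyOCR, restores punctuation and capitalization with a fine-tuned T5 model, and converts the result to speech with TTS.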

### Imports

In [1]:
import os
import warnings
warnings.filterwarnings("ignore")
os.environ["PYTHONWARNINGS"] = "ignore"


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
import numpy as np
import pytesseract
from PIL import Image
from torchvision import transforms,models
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import easyocr
import re
from TTS.api import TTS
from IPython.display import Audio

## Define a function to classify an image as clear or blurry


In [None]:
from PIL import Image
import torch
from torchvision import transforms, models
import torch.nn as nn

# Pick the compute device: MPS first, then CUDA, otherwise the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Evaluation-time preprocessing; must be the same validation transform used during training.
val_transform = transforms.Compose([
    transforms.Resize((160, 160)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Index -> label lookup; the order must match the class indices of the training dataset.
class_names = ['blur', 'clear']
# Load model architecture
def load_model():
    model = models.resnet18(pretrained=False)
    model.fc = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(model.fc.in_features, 2)  
    )
    model.load_state_dict(torch.load("best_resnet_blur_classifier.pth", map_location=device))
    model.to(device)
    model.eval()
    return model

# Classify image function
def classify_image(image_path):
    model = load_model()
    
    image = Image.open(image_path).convert('RGB')
    image = val_transform(image).unsqueeze(0).to(device)  

    with torch.no_grad():
        output = model(image)
        pred = torch.argmax(output, dim=1).item()
    
    return class_names[pred]


## Define OCR function for clear images


In [4]:
# Initialize EasyOCR reader 
reader = easyocr.Reader(['en'], gpu=True)

def normalize_symbols(text: str) -> str:
    
    
    text = re.sub(r'\s*–\s*', ' – ', text)  # checking proper spacing around dashes
    text = re.sub(r'\s+', ' ', text).strip()  # Normalize extra spaces
    return text

def run_ocr(image_path: str) -> str:
    """
    Extract printed or overlaid text from an image using EasyOCR.

    Args:
        image_path (str): Path to the image file.

    Returns:
        str: Cleaned and normalized extracted text
    """
    results = reader.readtext(image_path, detail=0)
    raw_text = " ".join(results)
    cleaned_text = normalize_symbols(raw_text)
    return cleaned_text

## Load punctuation restoration model and tokenizer


In [5]:
model_dir = "t5_punct_model"
tokenizer = T5Tokenizer.from_pretrained(model_dir)
punctuation_model = T5ForConditionalGeneration.from_pretrained(model_dir)
punctuation_model.eval()
punctuation_model.to(device)


T5ForConditionalGeneration(
  (shared): Embedding(32128, 512)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 512)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 8)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=512, out_features=2048, bias=False)
              (wo): Linear(in_features=2048, out_features=512, bias=False)
              (dropout): Drop

## function to restore punctuation and capitalization on OCR text


In [6]:
def restore_punctuation(text):
    prompt = "restore punctuation and capitalization: " + text.lower().strip()
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True, max_length=64)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        output = punctuation_model.generate(**inputs, max_length=64)
    return tokenizer.decode(output[0], skip_special_tokens=True)


## Full pipeline function: classify → OCR → punctuation restoration → print


In [7]:

def process_image(image_path):
    prediction=classify_image(image_path)
    return prediction



In [8]:
def tts(text, output_path="output.wav"):
    """
    Convert text to speech using TTS.
    
    Args:
        text (str): Text to convert to speech.
        output_path (str): Path to save the output audio file.
    """
    tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC", progress_bar=False)
    tts.tts_to_file(text=text, file_path=output_path)
    

## Run the pipeline on a test image


In [11]:
def pipeline(image_path):
    prediction=process_image(image_path)
    if prediction == 'clear':
        print("Image classified as CLEAR.")
        extracted_text = run_ocr(image_path)
        if extracted_text:
            print("Extracted Text:", extracted_text)
            punctuated_text = restore_punctuation(extracted_text)
            print("Punctuated Text:", punctuated_text)
            tts(punctuated_text, output_path="output.wav")
            print("Text-to-Speech conversion completed. Audio saved as 'output.wav'.")
        else:
            print("No text found in the image.")


image_path = "text_overlay_dataset_test/img_00002.jpg"  
pipeline(image_path)
output_path = "output.wav"  
Audio(output_path, autoplay=True) 


Image classified as CLEAR.
Extracted Text: lanterns glow on the narrow softly path
Punctuated Text: Lanterns glow on the narrow, softly path.
 > tts_models/en/ljspeech/tacotron2-DDC is already downloaded.
 > vocoder_models/en/ljspeech/hifigan_v2 is already downloaded.
 > Using model: Tacotron2
 > Setting up Audio Processor...
 | > sample_rate:22050
 | > resample:False
 | > num_mels:80
 | > log_func:np.log
 | > min_level_db:-100
 | > frame_shift_ms:None
 | > frame_length_ms:None
 | > ref_level_db:20
 | > fft_size:1024
 | > power:1.5
 | > preemphasis:0.0
 | > griffin_lim_iters:60
 | > signal_norm:False
 | > symmetric_norm:True
 | > mel_fmin:0
 | > mel_fmax:8000.0
 | > pitch_fmin:1.0
 | > pitch_fmax:640.0
 | > spec_gain:1.0
 | > stft_pad_mode:reflect
 | > max_norm:4.0
 | > clip_norm:True
 | > do_trim_silence:True
 | > trim_db:60
 | > do_sound_norm:False
 | > do_amp_to_db_linear:True
 | > do_amp_to_db_mel:True
 | > do_rms_norm:False
 | > db_level:None
 | > stats_path:None
 | > base:2.71828

## Website

https://finalproject-ztfhsc3heavh3eappjmp5wc.streamlit.app/