In [None]:
import os
import re
import cv2
import pandas as pd
import numpy as np
import easyocr
from PIL import Image
import logging
from tqdm import tqdm

### Key Components:

- **entity_allowed_units**: Dictionary containing allowed units for each entity (width, depth, height, weight, voltage, etc.).
- **unit_normalization_map**: Mapping of unit abbreviations and alternative forms to standard units.
- **entity_keywords**: Dictionary of keywords used to identify specific entities (e.g., 'width', 'weight') in the extracted text.

In [None]:
# Append INFO-and-above records for this run to a local log file.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    filename='ocr_pipeline_part1.log',
    filemode='a',
)

# One shared English EasyOCR reader, requested on the GPU; it is reused for every image.
reader = easyocr.Reader(['en'], gpu=True)

# Canonical units that a prediction for each entity may carry.
# Keys are the entity names used in the input CSV's `entity_name` column.
entity_allowed_units = dict(
    width={'centimetre', 'foot', 'inch', 'metre', 'millimetre', 'yard'},
    depth={'centimetre', 'foot', 'inch', 'metre', 'millimetre', 'yard'},
    height={'centimetre', 'foot', 'inch', 'metre', 'millimetre', 'yard'},
    item_weight={'gram', 'kilogram', 'microgram', 'milligram', 'ounce', 'pound', 'ton'},
    maximum_weight_recommendation={'gram', 'kilogram', 'microgram', 'milligram', 'ounce', 'pound', 'ton'},
    voltage={'kilovolt', 'millivolt', 'volt'},
    wattage={'kilowatt', 'watt'},
    item_volume={
        'centilitre', 'cubic foot', 'cubic inch', 'cup', 'decilitre', 'fluid ounce', 'gallon',
        'imperial gallon', 'litre', 'microlitre', 'millilitre', 'pint', 'quart',
    },
)

# Maps lower-cased unit spellings (abbreviations, plurals, US/UK spellings) to the
# canonical unit names used in `entity_allowed_units`. Every key must be lower case,
# because normalize_unit() lower-cases its input before looking it up here.
unit_normalization_map = {
    'cm': 'centimetre', 'centimeter': 'centimetre', 'centimetres': 'centimetre', 'centimeters': 'centimetre',
    'm': 'metre', 'meter': 'metre', 'metres': 'metre', 'meters': 'metre',
    'mm': 'millimetre', 'millimeter': 'millimetre', 'millimetres': 'millimetre', 'millimeters': 'millimetre',
    'ft': 'foot', 'feet': 'foot', 'foot': 'foot',
    'in': 'inch', 'inch': 'inch', 'inches': 'inch',
    'yd': 'yard', 'yard': 'yard', 'yards': 'yard',
    'g': 'gram', 'gram': 'gram', 'grams': 'gram',
    'kg': 'kilogram', 'kilogram': 'kilogram', 'kilograms': 'kilogram',
    'mg': 'milligram', 'milligram': 'milligram', 'milligrams': 'milligram',
    'μg': 'microgram', 'microgram': 'microgram', 'micrograms': 'microgram',
    'lb': 'pound', 'lbs': 'pound', 'pound': 'pound', 'pounds': 'pound',
    'oz': 'ounce', 'ounce': 'ounce', 'ounces': 'ounce',
    'ton': 'ton', 'tons': 'ton',
    'kv': 'kilovolt', 'kilovolt': 'kilovolt', 'kilovolts': 'kilovolt',
    'mv': 'millivolt', 'millivolt': 'millivolt', 'millivolts': 'millivolt',
    'v': 'volt', 'volt': 'volt', 'volts': 'volt',
    'kw': 'kilowatt', 'kilowatt': 'kilowatt', 'kilowatts': 'kilowatt',
    'w': 'watt', 'watt': 'watt', 'watts': 'watt',
    'l': 'litre', 'liter': 'litre', 'litre': 'litre', 'liters': 'litre', 'litres': 'litre',
    'ml': 'millilitre', 'millilitre': 'millilitre', 'milliliter': 'millilitre', 'milliliters': 'millilitre', 'millilitres': 'millilitre',
    'cl': 'centilitre', 'centilitre': 'centilitre', 'centiliter': 'centilitre', 'centiliters': 'centilitre', 'centilitres': 'centilitre',
    'dl': 'decilitre', 'decilitre': 'decilitre', 'deciliter': 'decilitre', 'deciliters': 'decilitre', 'decilitres': 'decilitre',
    'μl': 'microlitre', 'microlitre': 'microlitre', 'microliter': 'microlitre', 'microliters': 'microlitre', 'microlitres': 'microlitre',
    'gal': 'gallon', 'gallon': 'gallon', 'gallons': 'gallon',
    'cup': 'cup', 'cups': 'cup',
    'pt': 'pint', 'pint': 'pint', 'pints': 'pint',
    'qt': 'quart', 'quart': 'quart', 'quarts': 'quart',
    'fl oz': 'fluid ounce', 'fluid ounce': 'fluid ounce', 'fluid ounces': 'fluid ounce',
    'cu ft': 'cubic foot', 'cubic foot': 'cubic foot', 'cubic feet': 'cubic foot',
    'cu in': 'cubic inch', 'cubic inch': 'cubic inch', 'cubic inches': 'cubic inch',
    'imperial gallon': 'imperial gallon', 'imperial gallons': 'imperial gallon',
    # The micro sign U+00B5 ('µ') is a different code point from the Greek small
    # letter mu U+03BC ('μ') used above, and str.lower() does not map one to the
    # other, so both spellings need their own keys.
    '\u00b5g': 'microgram', '\u00b5l': 'microlitre',
    # Other abbreviations that appear on product labels.
    'mcg': 'microgram',
    'kgs': 'kilogram', 'gm': 'gram', 'gms': 'gram', 'ozs': 'ounce',
    'floz': 'fluid ounce', 'cc': 'millilitre', 'ltr': 'litre',
}

# Union of the permitted units of every entity; normalize_unit() only returns
# unit names that are members of this set.
allowed_units = set().union(*entity_allowed_units.values())

# Words that, when found in OCR text, introduce a value for the given entity.
# extract_entity_values() tries them in list order and matches whole words
# case-insensitively.
entity_keywords = dict(
    width=['width', 'breadth', 'wide'],
    depth=['depth'],
    height=['height', 'length'],
    item_weight=['weight', 'wt'],
    maximum_weight_recommendation=[
        'maximum weight', 'recommended weight', 'max weight', 'max wt', 'maximum wt',
    ],
    voltage=['voltage', 'volt'],
    wattage=['wattage', 'watt'],
    item_volume=['capacity', 'volume', 'size'],
)

### Functions:

- `normalize_unit(unit)`:
  Normalize units to a standard form if they exist in the normalization map.
  
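- `preprocess_image(img, max_size=1024)`:
  Downscale the image, preserving its aspect ratio, so that its longer side is at most `max_size` pixels before OCR. Images that are already small enough are returned unchanged.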
  
- `extract_numerical_values(text, entity)`:
  Extract numerical values and associated units from the text, based on the provided entity (e.g., 'width', 'weight').
  
- `extract_dimensions_pattern(text)`:
  Extract length, width, and height dimensions from the text using a specific pattern (e.g., "12 x 24 x 36 cm").
  
- `extract_entity_values(text, query_entities)`:
  Extract values for multiple entities (e.g., weight, volume) from the provided text by looking for keywords and associated numerical values.
  
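- `ocr_extract_entities(img_path, query_entities)`:
  Read an image, run EasyOCR on it, and extract the requested entities from the recognized text. If the image cannot be read or OCR fails, every requested entity maps to an empty list.

- `main()`:
  The entry point. It reads the test CSV and runs `ocr_extract_entities` for each row's image and entity. It formats the first value that has a recognized unit as the prediction (e.g. "12.5 centimetre") and writes the `index`/`prediction` columns to `output.csv`.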

In [None]:
def normalize_unit(unit):
    """Return the canonical name of *unit*, or None if it is not a supported unit.

    The input is lower-cased and stripped, then translated through
    ``unit_normalization_map``; strings missing from the map pass through
    unchanged. The result is accepted only if it belongs to ``allowed_units``.
    """
    key = unit.lower().strip()
    canonical = unit_normalization_map.get(key, key)
    return canonical if canonical in allowed_units else None

def preprocess_image(img, max_size=1024):
    """Shrink *img* so that its longer side is at most *max_size* pixels.

    The aspect ratio is preserved and area interpolation is used for the
    downscale. An image that already fits is returned as the same object,
    without copying.
    """
    height, width = img.shape[:2]
    longest_side = max(height, width)
    if longest_side <= max_size:
        return img
    scale = max_size / float(longest_side)
    return cv2.resize(img, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA)

def extract_numerical_values(text, entity):
    """Return every ``(number, canonical_unit)`` pair in *text* valid for *entity*.

    A pair is a decimal number ("12" or "12.5"), optional whitespace, and then
    any spelling of one of the entity's allowed units: the canonical name or any
    alias in ``unit_normalization_map`` that maps to it. Matching is
    case-insensitive, and the unit must end on a word boundary.

    Args:
        text: Free text, typically OCR output.
        entity: Entity name; a key of ``entity_allowed_units``.

    Returns:
        A list of ``(float, str)`` tuples in order of appearance. The list is
        empty if *entity* is unknown or nothing matches.
    """
    units = entity_allowed_units.get(entity, set())
    if not units:
        return []

    # Accept canonical names plus every alias that normalizes to one of them.
    unit_spellings = set(units)
    unit_spellings.update(
        alias for alias, canonical in unit_normalization_map.items() if canonical in units
    )

    # Longest spellings first, so the alternation prefers the most specific unit
    # in a deterministic order instead of relying on set iteration order.
    ordered = sorted(unit_spellings, key=lambda u: (-len(u), u))
    alternation = '|'.join(re.escape(u) for u in ordered)
    # The decimal point must be escaped: an unescaped '.' matches any character,
    # so "12x5 cm" used to capture "12x5", which float() rejects, losing "5 cm".
    pattern = r'(\d+(?:\.\d+)?)\s*(' + alternation + r')\b'

    results = []
    for number_str, unit in re.findall(pattern, text, re.IGNORECASE):
        normalized_unit = normalize_unit(unit)
        if not normalized_unit:
            continue
        try:
            number = float(number_str)
        except ValueError:
            continue
        results.append((number, normalized_unit))
    return results

def extract_dimensions_pattern(text):
    """Parse the first ``A x B x C <unit>`` dimension triple in *text*.

    The separator may be 'x', 'X', '×' or '*', with optional surrounding
    whitespace. This covers "12 x 24 x 36 cm", "12x24x36cm" and
    "12 × 24 × 36 in". Triples whose trailing word does not normalize to a
    supported unit are skipped.

    Returns:
        ``{'length': (A, unit), 'width': (B, unit), 'height': (C, unit)}`` for
        the first triple with a recognized unit, or ``{}`` if there is none.
    """
    # The '.' is escaped so it only matches a literal decimal point.
    number = r'(\d+(?:\.\d+)?)'
    # The original required exactly one whitespace char around 'x', missing
    # common compact forms such as "12x24x36cm".
    separator = r'\s*[x×*]\s*'
    pattern = number + separator + number + separator + number + r'\s*([a-zA-Z]+)'

    for length_str, width_str, height_str, unit in re.findall(pattern, text, re.IGNORECASE):
        normalized_unit = normalize_unit(unit)
        if normalized_unit:
            return {
                'length': (float(length_str), normalized_unit),
                'width': (float(width_str), normalized_unit),
                'height': (float(height_str), normalized_unit),
            }
    return {}

def _unit_fits_entity(entity, unit):
    """Return True if *unit* is allowed for *entity*, or if the entity has no unit list."""
    allowed = entity_allowed_units.get(entity)
    return not allowed or unit in allowed


def extract_entity_values(text, query_entities):
    """Collect candidate values for each requested entity from *text*.

    First, a dimension triple ("A x B x C unit") fills in 'width' and 'height'
    when they are requested and the unit suits them. Any entity still without a
    value is then searched by keyword: each keyword from ``entity_keywords``
    (whole word, case-insensitive) followed within 20 characters by a number and
    an optional unit word. A unit that is unknown, or that belongs to a
    different quantity (e.g. "kg" for a height), is recorded as None.

    Args:
        text: Free text, typically OCR output.
        query_entities: Iterable of entity names. Duplicates are ignored.

    Returns:
        A dict mapping each entity that produced at least one candidate to a
        list of ``(float, unit_or_None)`` tuples. Entities with no candidates
        are omitted.
    """
    # Dict keys de-duplicate query_entities while preserving their order.
    entity_matches = {entity: [] for entity in query_entities}

    for entity, (number, unit) in extract_dimensions_pattern(text).items():
        if entity in entity_matches and _unit_fits_entity(entity, unit):
            entity_matches[entity].append((number, unit))

    # Iterate over a snapshot because entities without matches are deleted below.
    for entity in list(entity_matches):
        if not entity_matches[entity]:
            for keyword in entity_keywords.get(entity, []):
                # '.{0,20}?' is an intentional any-character gap; the decimal point
                # inside the number group is escaped so it matches only '.'.
                pattern = (
                    r'\b' + re.escape(keyword)
                    + r'\b.{0,20}?(\d+(?:\.\d+)?)(?:\s*([a-zA-Z]+))?'
                )
                for number_str, unit in re.findall(pattern, text, re.IGNORECASE):
                    normalized_unit = normalize_unit(unit) if unit else None
                    if normalized_unit and not _unit_fits_entity(entity, normalized_unit):
                        # A unit of another quantity must not be reported as this entity's unit.
                        normalized_unit = None
                    try:
                        number = float(number_str)
                    except ValueError:
                        continue
                    entity_matches[entity].append((number, normalized_unit))

        if not entity_matches[entity]:
            del entity_matches[entity]

    return entity_matches

def ocr_extract_entities(img_path, query_entities):
    """Run OCR on the image at *img_path* and extract the requested entities.

    The image is loaded with OpenCV, downscaled by ``preprocess_image`` and
    converted from BGR to RGB. EasyOCR's paragraph strings are joined with
    spaces, and the result is passed to ``extract_entity_values``.

    On any failure (unreadable image, OCR or parsing error) the problem is
    logged, and a dict mapping every requested entity to ``[]`` is returned.
    """
    def empty_result():
        return {entity: [] for entity in query_entities}

    try:
        image = cv2.imread(img_path)
        if image is None:
            logging.warning(f"Could not read image {img_path}")
            return empty_result()

        rgb_image = cv2.cvtColor(preprocess_image(image), cv2.COLOR_BGR2RGB)
        paragraphs = reader.readtext(rgb_image, detail=0, paragraph=True)
        return extract_entity_values(' '.join(paragraphs), query_entities)
    except Exception as e:
        logging.error(f"OCR extraction failed for {img_path}: {e}")
        return empty_result()

In [None]:
def _format_prediction(entity_values):
    """Render the first ``(number, unit)`` pair that has a unit as ``"<number> <unit>"``.

    Whole numbers are printed without a fractional part. Other values are rounded
    to two decimals with trailing zeros removed. Returns ``''`` when no pair has
    a unit.
    """
    for number, unit in entity_values:
        if not unit:
            continue
        if float(number).is_integer():
            number_formatted = f"{int(number)}"
        else:
            number_formatted = f"{number:.2f}".rstrip('0').rstrip('.')
        return f"{number_formatted} {unit}"
    return ''


def main(test_csv_path='/kaggle/input/testfile2/test.csv',
         images_folder='/kaggle/input/the-dataset/images/test',
         output_file='output.csv'):
    """Run OCR-based entity extraction over every row of the test CSV.

    Each row must provide ``index``, ``image_link`` and ``entity_name``. The
    image is looked up in *images_folder* by the basename of ``image_link``.
    The output CSV has ``index`` and ``prediction`` columns, in input order.

    Args:
        test_csv_path: Input CSV path.
        images_folder: Directory containing the downloaded images.
        output_file: Path of the CSV to write.
    """
    df = pd.read_csv(test_csv_path)
    logging.info(f"Loaded {len(df)} entries from {os.path.basename(test_csv_path)}")
    logging.info(f"Processing all {len(df)} rows")

    # Key the cache by (image, entity). The extracted dict depends on the entity
    # requested, so a filename-only key returned another entity's results
    # (and therefore an empty prediction) when an image appeared in several rows.
    ocr_cache = {}
    output_list = []

    for _, row in tqdm(df.iterrows(), total=len(df)):
        entity_name = row['entity_name']
        image_filename = os.path.basename(row['image_link'])
        cache_key = (image_filename, entity_name)

        if cache_key not in ocr_cache:
            image_path = os.path.join(images_folder, image_filename)
            ocr_cache[cache_key] = ocr_extract_entities(image_path, [entity_name])

        entity_values = ocr_cache[cache_key].get(entity_name, [])
        output_list.append({'index': row['index'], 'prediction': _format_prediction(entity_values)})

    # Rows were appended in input order, so no reindexing is needed. The
    # explicit columns keep a valid header even when the input CSV has no rows.
    output_df = pd.DataFrame(output_list, columns=['index', 'prediction'])
    output_df.to_csv(output_file, index=False)

    logging.info(f"Processing completed. Results saved to {output_file}")

# Run the pipeline only when executed as a script, not when imported.
if __name__ == '__main__':
    main()