In [2]:
!pip install pytesseract transformers torch torchvision spacy scikit-learn tqdm
!python -m spacy download en_core_web_sm





[notice] A new release of pip is available: 23.3.1 -> 24.2
[notice] To update, run: python.exe -m pip install --upgrade pip


Collecting en-core-web-sm==3.7.1
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl (12.8 MB)
     ---------------------------------------- 0.0/12.8 MB ? eta -:--:--
     ---------------------------------------- 0.1/12.8 MB 1.7 MB/s eta 0:00:08
      --------------------------------------- 0.2/12.8 MB 2.9 MB/s eta 0:00:05
     - -------------------------------------- 0.6/12.8 MB 4.4 MB/s eta 0:00:03
     --- ------------------------------------ 1.0/12.8 MB 5.7 MB/s eta 0:00:03
     ---- ----------------------------------- 1.5/12.8 MB 6.8 MB/s eta 0:00:02
     ------ --------------------------------- 2.0/12.8 MB 7.7 MB/s eta 0:00:02
     -------- ------------------------------- 2.6/12.8 MB 8.3 MB/s eta 0:00:02
     --------- ------------------------------ 3.1/12.8 MB 8.6 MB/s eta 0:00:02
     ----------- ---------------------------- 3.6/12.8 MB 8.9 MB/s eta 0:00:02
     ------------ -----------------------


[notice] A new release of pip is available: 23.3.1 -> 24.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [2]:
import os
import pandas as pd
import time
from tqdm import tqdm
from sklearn.metrics import f1_score
from src.utils import download_images
from src.preprocessing import preprocess_image
from src.model import ocr_nlp_pipeline, save_predictions_to_csv

In [3]:
# ---- Configuration ----
# All paths are relative to the notebook's current working directory.
TRAIN_CSV = 'dataset/train.csv'     # training split (currently unused by the steps below)
TEST_CSV = 'dataset/test.csv'       # rows whose images are downloaded and predicted on
IMAGE_DIR = 'images/'               # destination folder for downloaded images
OUTPUT_CSV = 'output/test_out.csv'  # where the final predictions are written

# Rows (or files) handled per batch in every step; lower it if memory is tight.
BATCH_SIZE = 500

In [5]:
# ---- Step 1: Download Images in Batches ----

def download_images_batch(test_csv, image_dir, batch_size=500, delay=1.0):
    """
    Download the images referenced in ``test_csv`` in fixed-size batches.

    The CSV is read with pandas and split into consecutive row slices of
    ``batch_size`` rows. Each slice is handed to ``download_images`` together
    with ``image_dir``, and the function pauses ``delay`` seconds between
    consecutive batches.

    Parameters:
    - test_csv: str, path to the CSV file containing the image URLs.
    - image_dir: str, folder to save the downloaded images in (created if missing).
    - batch_size: int, number of CSV rows passed to ``download_images`` per call; must be > 0.
    - delay: float, seconds to sleep between batches (default 1, the original pause).
      Use 0 to disable.

    Raises:
    - ValueError: if ``batch_size`` is not a positive integer.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    test_data = pd.read_csv(test_csv)
    total_images = len(test_data)

    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(image_dir, exist_ok=True)

    for i in tqdm(range(0, total_images, batch_size), desc="download batches"):
        batch_data = test_data.iloc[i:i + batch_size]
        # NOTE(review): the whole DataFrame slice is passed, not a list of URLs.
        # Confirm download_images' expected input (a URL column's .tolist()?).
        # The saved traceback ends in multiprocessing's Windows _exhaustive_wait,
        # which suggests the worker pool inside download_images may be too large
        # for Windows. TODO check its pool size.
        download_images(batch_data, image_dir)
        # Pause only *between* batches; sleeping after the last one just wastes time.
        if delay and i + batch_size < total_images:
            time.sleep(delay)  # small delay to avoid overwhelming the server

# Fetch the test-set images into IMAGE_DIR, BATCH_SIZE CSV rows at a time
download_images_batch(test_csv=TEST_CSV, image_dir=IMAGE_DIR, batch_size=BATCH_SIZE)

  0%|          | 0/263 [00:00<?, ?it/s]Exception in thread Thread-10 (_handle_workers):
Traceback (most recent call last):
  File "c:\Users\DELL\AppData\Local\Programs\Python\Python312\Lib\threading.py", line 1052, in _bootstrap_inner
    self.run()
  File "c:\Users\DELL\AppData\Local\Programs\Python\Python312\Lib\threading.py", line 989, in run
    self._target(*self._args, **self._kwargs)
  File "c:\Users\DELL\AppData\Local\Programs\Python\Python312\Lib\multiprocessing\pool.py", line 522, in _handle_workers
    cls._wait_for_updates(current_sentinels, change_notifier)
  File "c:\Users\DELL\AppData\Local\Programs\Python\Python312\Lib\multiprocessing\pool.py", line 502, in _wait_for_updates
    wait(sentinels, timeout=timeout)
  File "c:\Users\DELL\AppData\Local\Programs\Python\Python312\Lib\multiprocessing\connection.py", line 1066, in wait
    ready_handles = _exhaustive_wait(waithandle_to_obj.keys(), timeout)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  

In [4]:
# ---- Step 2: Preprocess Images in Batches ----

def preprocess_images_batch(image_dir, batch_size=500):
    """
    Run ``preprocess_image`` on every regular file in ``image_dir``, in batches.

    Files are visited in sorted filename order, so the returned list has a
    reproducible order across runs and platforms (``os.listdir`` order is
    arbitrary). Sub-directories are skipped.

    Parameters:
    - image_dir: str, path where the images are stored.
    - batch_size: int, number of files per progress-bar step; must be > 0.

    Returns:
    - preprocessed_images: list, one ``preprocess_image`` result per file, in
      sorted filename order. What each element is depends on
      ``preprocess_image`` (TODO confirm: a path or an image array).

    Raises:
    - ValueError: if ``batch_size`` is not a positive integer.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    image_files = sorted(
        name for name in os.listdir(image_dir)
        if os.path.isfile(os.path.join(image_dir, name))
    )
    total_images = len(image_files)

    preprocessed_images = []
    for i in tqdm(range(0, total_images, batch_size), desc="preprocess batches"):
        for img_file in image_files[i:i + batch_size]:
            img_path = os.path.join(image_dir, img_file)
            preprocessed_images.append(preprocess_image(img_path))

    return preprocessed_images

# Apply preprocess_image to everything in IMAGE_DIR.
# NOTE(review): preprocessed_images is not passed to the prediction step below,
# which reads from IMAGE_DIR directly. Confirm whether this step is needed.
preprocessed_images = preprocess_images_batch(image_dir=IMAGE_DIR, batch_size=BATCH_SIZE)

100%|██████████| 1/1 [00:01<00:00,  1.17s/it]


In [5]:
# ---- Step 3: Model Prediction ----

def generate_predictions(image_dir, test_csv, entity_type, batch_size=500):
    """
    Feed ``test_csv`` to the OCR + NLP pipeline one row-slice at a time.

    Parameters:
    - image_dir: str, folder holding the downloaded images (forwarded as-is).
    - test_csv: str, CSV whose rows are handed to the pipeline.
    - entity_type: str, entity to extract (e.g. "item_weight"); forwarded as-is
      to every ocr_nlp_pipeline call.
    - batch_size: int, number of CSV rows per ocr_nlp_pipeline call.

    Returns:
    - DataFrame built from all pipeline outputs, concatenated in batch order,
      with the columns ``index`` and ``prediction``.
    """
    frame = pd.read_csv(test_csv)
    collected = []

    # NOTE(review): the saved run fails on the first batch with
    # "TypeError: argument of type 'method' is not iterable". The `in` check
    # that raises it is most likely inside ocr_nlp_pipeline. TODO inspect it.
    for start in tqdm(range(0, len(frame), batch_size)):
        chunk = frame.iloc[start:start + batch_size]
        collected.extend(ocr_nlp_pipeline(image_dir, chunk, entity_type))

    return pd.DataFrame(collected, columns=['index', 'prediction'])

# Generate predictions for every test row.
# NOTE(review): one entity_type is used for all rows. Confirm whether the test
# CSV carries a per-row entity column that should drive this instead.
entity_type = "item_weight"  # Modify this depending on your use case
output_df = generate_predictions(IMAGE_DIR, TEST_CSV, entity_type, BATCH_SIZE)

# Save predictions to CSV. Nothing above creates the output folder, so make
# sure it exists first, as Step 1 does for IMAGE_DIR. exist_ok makes this
# safe to re-run.
os.makedirs(os.path.dirname(OUTPUT_CSV) or ".", exist_ok=True)
save_predictions_to_csv(output_df, OUTPUT_CSV)

  0%|          | 0/263 [00:00<?, ?it/s]


TypeError: argument of type 'method' is not iterable