Data preparation pipeline

In [1]:
from abc import ABC, abstractmethod
class BaseSaveDriver(ABC):
    """
    Contract shared by every save driver.

    A driver buffers processed documents in ``current_batch`` and hands full
    batches to a concrete storage backend (local, cloud, etc.).
    """

    def __init__(self, batch_size: int = 100):
        """
        Set up the bookkeeping common to all drivers.

        Args:
            batch_size: Number of documents per batch
        """
        # Counters start at zero; subclasses update them as documents arrive.
        self.documents_processed = 0
        self.batch_count = 0
        self.current_batch = []
        self.batch_size = batch_size

    @abstractmethod
    def add_document(self, document):
        """Queue one document in the current batch."""

    @abstractmethod
    def finalize(self):
        """Persist whatever is still buffered and report statistics."""

    @abstractmethod
    def get_statistics(self):
        """Report the statistics gathered so far."""

    @abstractmethod
    def _save_current_batch(self):
        """Write the current batch to the driver's storage backend."""


In [9]:
import json
import os
import re
import tempfile
import time
import uuid
from typing import Any, Dict, Iterator, List, Optional

# `pa` / `pq` are referenced by CloudParquetSaveDriver (pa.Table.from_pylist, pq.read_table).
import pyarrow as pa
import pyarrow.parquet as pq
import spacy
from spacy.pipeline import EntityRuler

class SpacyJSONGenerator:
    def __init__(self, batch_size: int = 100, n_process: int = 1, require_gpu: bool = False):
        """
        Initialize the generator with batching capabilities.

        Args:
            batch_size: Number of texts (sentences) to process in each batch
            n_process: Number of processes for parallel processing (use -1 for all cores)
        """
        # Load the transformer model
        if require_gpu:
            spacy.require_gpu()
        self.nlp = spacy.load("en_core_web_trf", disable=["lemmatizer"])
        self.batch_size = batch_size
        self.n_process = n_process

        # Add EntityRuler for NLE extraction (BEFORE NER for better integration)
        ruler = self.nlp.add_pipe("entity_ruler", before="ner")
        self._setup_nle_patterns(ruler)

    def _setup_nle_patterns(self, ruler: EntityRuler):
        """Setup patterns for Nonlinguistic Entity extraction using EntityRuler."""
        patterns = [
            # Phone patterns
            {"label": "PHONE", "pattern": [{"TEXT": {"REGEX": r"\(?[0-9]{3}\)?[-.\s]?[0-9]{3}[-.\s]?[0-9]{4}"}}]},
            {"label": "PHONE", "pattern": [{"TEXT": {"REGEX": r"\+[1-9]\d{1,14}"}}]},

            # Address patterns
            {"label": "ADDRESS", "pattern": [{"IS_DIGIT": True}, {"IS_ALPHA": True, "OP": "+"}, {"LOWER": {"IN": ["st", "street", "ave", "avenue", "rd", "road", "blvd", "boulevard", "dr", "drive", "ln", "lane", "ct", "court", "pl", "place"]}}]},
            {"label": "ADDRESS", "pattern": [{"LOWER": "p"}, {"TEXT": "."}, {"LOWER": "o"}, {"TEXT": "."}, {"LOWER": "box"}, {"IS_DIGIT": True}]},

            # IP Address patterns
            {"label": "IP_ADDRESS", "pattern": [{"TEXT": {"REGEX": r"\b(?:[0-9]{1,3}\.){3}[0-9]{1,3}\b"}}]},
            {"label": "IP_ADDRESS", "pattern": [{"TEXT": {"REGEX": r"\b(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}\b"}}]},

            # SSN patterns
            {"label": "SSN", "pattern": [{"IS_DIGIT": True, "LENGTH": 3}, {"TEXT": "-"}, {"IS_DIGIT": True, "LENGTH": 2}, {"TEXT": "-"}, {"IS_DIGIT": True, "LENGTH": 4}]},
            {"label": "SSN", "pattern": [{"IS_DIGIT": True, "LENGTH": 3}, {"IS_SPACE": True}, {"IS_DIGIT": True, "LENGTH": 2}, {"IS_SPACE": True}, {"IS_DIGIT": True, "LENGTH": 4}]},

            # URL and Email patterns (using built-ins)
            {"label": "URL", "pattern": [{"LIKE_URL": True}]},
            {"label": "EMAIL", "pattern": [{"LIKE_EMAIL": True}]}
        ]

        ruler.add_patterns(patterns)


    def _extract_punctuation_spans(self, text: str) -> List[Dict[str, Any]]:
        """Extract punctuation spans from text."""
        punct_spans = []
        punct_pattern = r'[^\w\s]'  # Match non-word, non-space characters

        for match in re.finditer(punct_pattern, text):
            punct_spans.append({
                "start": match.start(),
                "end": match.end(),
                "value": match.group()
            })

        return punct_spans

    def _extract_special_tags_from_doc(self, doc) -> List[Dict[str, Any]]:
        """Extract special tags from spaCy doc (NLEs are now in doc.ents)."""
        special_tags = []

        # Filter NLE entities (non-standard NER labels)
        nle_labels = {"PHONE", "ADDRESS", "IP_ADDRESS", "SSN", "URL", "EMAIL"}

        for ent in doc.ents:
            if ent.label_ in nle_labels:
                special_tags.append({
                    "start": ent.start_char,
                    "end": ent.end_char,
                    "type": ent.label_,
                    "value": ent.text
                })

        return special_tags

    def _get_sentence_spans(self, doc, text: str) -> List[Dict[str, int]]:
        """Extract sentence spans."""
        sent_spans = []
        for sent in doc.sents:
            sent_spans.append({
                "start": sent.start_char,
                "end": sent.end_char
            })
        return sent_spans

    def process_single_doc(self, doc, original_text: str, sentence_id: str) -> Dict[str, Any]:
        """Process a single spaCy doc and return the JSON structure."""

        # Extract sentence spans
        sent_spans = self._get_sentence_spans(doc, original_text)

        # Extract punctuation spans
        punct_spans = self._extract_punctuation_spans(original_text)

        # Extract special tags (NLEs) from doc.ents
        special_tags = self._extract_special_tags_from_doc(doc)

        # Extract named entity spans with entity IDs (standard NER only)
        ner_spans = []
        nle_labels = {"PHONE", "ADDRESS", "IP_ADDRESS", "SSN", "URL", "EMAIL"}

        for ent in doc.ents:
            # Only include standard NER entities, not NLEs
            if ent.label_ not in nle_labels:
                ner_spans.append({
                    "entity_id": f"{ent.label_}-{str(ent).upper().replace(' ', '_').replace('-', '_')}",
                    "start": ent.start_char,
                    "end": ent.end_char,
                    "label": ent.label_
                })

        # Extract POS tokens and tags
        pos_tokens = []
        pos_tags = []
        ner_iob = []

        for token in doc:
            # Skip whitespace-only tokens
            if not token.text.strip():
                continue

            pos_tokens.append(token.text)
            pos_tags.append(token.pos_)

            # Determine IOB tag
            if token.ent_iob_ == 'B':
                ner_iob.append(f"B-{token.ent_type_}")
            elif token.ent_iob_ == 'I':
                ner_iob.append(f"I-{token.ent_type_}")
            else:
                ner_iob.append("O")

        # Build the final JSON structure
        result = {
            "id": sentence_id,
            "text": original_text,
            "sent_spans": sent_spans,
            "punct_spans": punct_spans,
            "special_tags": special_tags,
            "ner_spans": ner_spans,
            "pos_tokens": pos_tokens,
            "pos_tags": pos_tags,
            "ner_iob": ner_iob
        }

        return result

    def process_sentences_batch(self, sentences: List[str], sentence_ids: List[str] = None):
        """Process a batch of sentences efficiently."""
        if sentence_ids is None:
            sentence_ids = [f"sent_{str(uuid.uuid4())}" for _ in range(len(sentences))]

        # Process batch with spaCy
        docs = list(self.nlp.pipe(sentences, batch_size=self.batch_size, n_process=self.n_process))

        # Process each doc
        results = []
        for doc, original_text, sent_id in zip(docs, sentences, sentence_ids):
            result = self.process_single_doc(doc, original_text, sent_id)
            results.append(result)

        return results


    def process_sentences_streaming(self, sentences: Iterator[str],
                                   sentence_id_generator: Iterator[str] = None) -> Iterator[Dict[str, Any]]:
        """Process sentences in streaming fashion with batching."""
        sentence_batch = []
        id_batch = []

        for i, sentence in enumerate(sentences):
            sentence_batch.append(sentence)

            if sentence_id_generator:
                id_batch.append(next(sentence_id_generator))
            else:
                id_batch.append(f"sent_{i:07d}")

            # Process batch when it reaches batch_size
            if len(sentence_batch) >= self.batch_size:
                results = self.process_sentences_batch(sentence_batch, id_batch)
                for result in results:
                    yield result

                # Clear batches
                sentence_batch = []
                id_batch = []

        # Process remaining sentences
        if sentence_batch:
            results = self.process_sentences_batch(sentence_batch)
            for result in results:
                yield result

    def process_and_save(self, dataset, save_driver: BaseSaveDriver, num_batches=None, resume_from_progress=True):
        """
        Process dataset using Hugging Face map() function with configurable save driver.
        Includes detailed timing measurements and bottleneck analysis.

        Args:
            dataset: Hugging Face dataset
            save_driver: SaveDriver instance for handling storage (local, cloud, etc.)
            num_batches: Number of batches to process (None = process all)
            resume_from_progress: Whether to resume from existing progress (if available)

        Returns:
            BaseSaveDriver: The save driver instance with statistics
        """

        print(f"🚀 Starting HF map() optimized processing with {save_driver.__class__.__name__}...")

        # Check if we should resume from existing progress
        documents_to_skip = 0
        initial_batch_count = 0
        if resume_from_progress and hasattr(save_driver, 'progress_data'):
            progress = save_driver.progress_data
            if progress['documents_processed'] > 0:
                documents_to_skip = progress['documents_processed']
                initial_batch_count = progress['batch_count']

                print(f"🔄 Resuming from previous progress:")
                print(f"   📄 Documents already processed: {progress['documents_processed']}")
                print(f"   📦 Batches already created: {progress['batch_count']}")
                print(f"⏭️  Skipping first {documents_to_skip} documents...")

        def process_batch_texts(batch):
            """Process a batch of texts with spaCy using HF map."""

            texts = [text for text in batch['text'] if len(text) >= 10]

            if not texts:
                return {'processed': [None] * len(batch['text'])}

            try:
                processed_docs = self.process_sentences_batch(texts)
                return {'processed': processed_docs}
            except Exception as e:
                print(f"❌ Error processing batch: {e}")
                return {'processed': [None] * len(batch['text'])}

        print(f"Skipping documents (if needed) and adding mapping")
        processed_dataset = dataset['train'].skip(documents_to_skip * self.batch_size).map(
            process_batch_texts,
            batched=True,
            batch_size=self.batch_size,
            remove_columns=['text']
        )

        print("💾 Processing and saving data...")

        processed_count = 0

        try:
            for example in processed_dataset:


                save_driver.add_document(example['processed'])
                processed_count += 1

                # Check batch count more frequently to respect num_batches limit
                current_batch_count = save_driver.batch_count
                new_batches_created = current_batch_count - initial_batch_count

                # Check if we've processed enough NEW batches (check after each document)
                if num_batches is not None and new_batches_created >= num_batches:
                    print(f"🛑 Reached target of {num_batches} new batches. Stopping...")
                    print(f"   📊 Total batches: {current_batch_count}, New batches this run: {new_batches_created}")
                    break



        except KeyboardInterrupt:
            print("\n⚠️  Processing interrupted by user. Progress saved.")
            if hasattr(save_driver, '_save_progress'):
                save_driver._save_progress()
            raise
        except Exception as e:
            print(f"\n❌ Processing failed: {e}")
            print("💾 Progress saved. You can resume later.")
            if hasattr(save_driver, '_save_progress'):
                save_driver._save_progress()
            raise


        # Finalize and get statistics
        batch_count, documents_processed = save_driver.finalize()

        # Calculate total time and performance metrics

        print(f"\n🎉 Processing completed!")
        print(f"📊 Performance Summary:")
        print(f"   📄 Documents processed: {documents_processed}")
        print(f"   📦 Batches created: {batch_count}")

        return save_driver

class CloudSaveDriver(BaseSaveDriver):
    """
    Google Cloud Storage (GCS) implementation for saving processed batches.
    """

    def __init__(self, bucket_name=None, project_id=None, batch_size=100, progress_file="gcs_processing_progress.json"):
        """
        Initialize the CloudSaveDriver with GCS support.

        Loads any saved progress, imports the GCS client libraries and the
        project ``Config``, builds a storage client from the service-account
        credentials returned by ``Config.get_gcs_credentials()``, and checks
        that the target bucket exists.

        Args:
            bucket_name: GCS bucket name. Stored as given; this method does not
                fall back to a value from Config when it is None.
            project_id: GCP project passed to the storage client
            batch_size: Number of documents per batch file
            progress_file: File to store processing progress for resumption

        Raises:
            ImportError: If google-cloud-storage or the Config module cannot be imported
            ValueError: If GCS configuration validation or credential lookup fails
            RuntimeError: If the client cannot be created or the bucket is not accessible
        """
        super().__init__(batch_size)
        self.progress_file = progress_file
        # Progress is read before anything touches GCS; the counters are restored
        # from it at the end of this method.
        self.progress_data = self._load_progress()

        # Import here to avoid dependency issues if GCS not installed
        try:
            from google.cloud import storage
            from google.cloud.exceptions import GoogleCloudError

            # Project-local configuration module; a failure to import it is
            # reported through the same ImportError below.
            from Config import Config

        except ImportError as e:
            raise ImportError(f"GCS dependencies not installed. Run: pip install google-cloud-storage. Error: {e}")

        # Store imports for use in methods
        self.storage = storage
        self.GoogleCloudError = GoogleCloudError
        self.config = Config

        # Validate and get GCS configuration
        try:
            self.config.validate_gcs_config()
            self.bucket_name = bucket_name
            self.project_id = project_id
            self.credentials = self.config.get_gcs_credentials()
        except Exception as e:
            raise ValueError(f"GCS configuration error: {e}")

        # Initialize GCS client
        try:

            # Credentials from file path
            self.client = self.storage.Client.from_service_account_json(
                self.credentials,
                project=self.project_id
            )

            # Get bucket reference
            self.bucket = self.client.bucket(self.bucket_name)

            # Test bucket access
            if not self.bucket.exists():
                raise ValueError(f"Bucket '{self.bucket_name}' does not exist or is not accessible")

        except Exception as e:
            # Note: this also wraps the ValueError raised just above, so callers
            # see a RuntimeError for a missing bucket.
            raise RuntimeError(f"Failed to initialize GCS client: {e}")

        print(f"☁️  CloudSaveDriver (GCS) initialized:")
        print(f"  - Bucket: {self.bucket_name}")
        print(f"  - Project: {self.project_id}")
        print(f"  - Batch size: {self.batch_size}")

        # Restore state from progress file
        if self.progress_data['documents_processed'] > 0:
            self.documents_processed = self.progress_data['documents_processed']
            self.batch_count = self.progress_data['batch_count']
            print(f"  - Resuming from: {self.documents_processed} docs, {self.batch_count} batches")

    def _load_progress(self):
        """Load existing progress if available."""
        if os.path.exists(self.progress_file):
            try:
                with open(self.progress_file, 'r') as f:
                    progress = json.load(f)
                    print(f"📋 Loaded existing progress: {progress['documents_processed']} docs, {progress['batch_count']} batches")
                    return progress
            except Exception as e:
                print(f"⚠️  Could not load progress file: {e}")

        return {
            'documents_processed': 0,
            'batch_count': 0,
            'start_time': time.time()
        }

    def _save_progress(self):
        """Save current progress to file."""
        self.progress_data.update({
            'documents_processed': self.documents_processed,
            'batch_count': self.batch_count,
            'last_save_time': time.time()
        })

        try:
            with open(self.progress_file, 'w') as f:
                json.dump(self.progress_data, f, indent=2)
        except Exception as e:
            print(f"⚠️  Could not save progress: {e}")

    def add_document(self, document):
        """
        Add a document to the current batch.

        Args:
            document: Processed document to add
        """
        if document is not None:
            self.current_batch.append(document)
            self.documents_processed += 1

            # Save batch when it reaches the desired size
            if len(self.current_batch) >= self.batch_size:
                self._save_current_batch()

    def _save_current_batch(self):
        """
        Save the current batch to GCS bucket.
        """
        if not self.current_batch:
            return

        save_start = time.time()
        self.batch_count += 1

        # Create filename with timestamp and batch number
        timestamp = int(time.time())
        filename = f"batch_{self.batch_count:06d}_{timestamp}.json"

        try:
            # Create temporary file for JSON data
            with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as temp_file:
                json.dump(self.current_batch, temp_file, ensure_ascii=False)
                temp_file_path = temp_file.name

            blob = self.bucket.blob(filename)

            for attempt in range(self.config.GCS_RETRY_ATTEMPTS):
                try:
                    blob.upload_from_filename(temp_file_path, content_type='application/json')

                    break
                except self.GoogleCloudError as e:
                    if attempt == self.config.GCS_RETRY_ATTEMPTS - 1:
                        raise
                    print(f"⚠️  Upload attempt {attempt + 1} failed, retrying... Error: {e}")
                    time.sleep(2 ** attempt)  # Exponential backoff

                # Configure upload settings for large files
                # blob.chunk_size = self.config.GCS_UPLOAD_CHUNK_SIZE # default is 100MB

                # Upload with retry logic

            os.unlink(temp_file_path)  # Clean up

            save_time = time.time() - save_start
            file_size_mb = len(json.dumps(self.current_batch)) / 1024 / 1024

            print(f"☁️  Saved batch {self.batch_count} with {len(self.current_batch)} documents to gs://{self.bucket_name}/{filename}")
            print(f"   ⏱️  Upload time: {save_time:.3f}s, Size: {file_size_mb:.1f} MB, Rate: {len(self.current_batch)/save_time:.1f} docs/sec")

        except Exception as e:
            print(f"❌ Failed to save batch {self.batch_count} to GCS: {e}")
            # Clean up temp file if it exists
            try:
                if 'temp_file_path' in locals():
                    os.unlink(temp_file_path)
            except:
                pass
            raise

        # Save progress after each batch
        self._save_progress()

        # Clear current batch to free memory
        self.current_batch = []

    def finalize(self):
        """
        Save any remaining documents and return statistics.
        """
        finalize_start = time.time()

        # Save remaining documents if any
        if self.current_batch:
            print(f"🔄 Finalizing: saving remaining {len(self.current_batch)} documents to GCS...")
            self._save_current_batch()

        finalize_time = time.time() - finalize_start
        print(f"✅ GCS finalization completed in {finalize_time:.3f}s")
        print(f"📊 CloudSaveDriver (GCS) completed:")
        print(f"  - Total batches: {self.batch_count}")
        print(f"  - Total documents: {self.documents_processed}")
        print(f"  - Bucket: gs://{self.bucket_name}")

        # Clean up progress file on successful completion
        # if os.path.exists(self.progress_file):
        #     os.remove(self.progress_file)
        #     print("🧹 Cleaned up progress file")

        return self.batch_count, self.documents_processed

    def get_statistics(self):
        """Get current statistics."""
        return {
            'batches_created': self.batch_count,
            'documents_processed': self.documents_processed,
            'current_batch_size': len(self.current_batch),
            'storage_type': 'gcs',
            'bucket_name': self.bucket_name,
            'project_id': self.project_id
        }

    def list_batches(self):
        """
        List all batch files in the GCS bucket.

        Returns:
            list: List of blob objects representing batch files
        """
        try:
            blobs = list(self.bucket.list_blobs(prefix="batch_"))
            return sorted(blobs, key=lambda x: x.name)
        except Exception as e:
            print(f"❌ Failed to list batches from GCS: {e}")
            return []

    def load_batch(self, blob):
        """
        Load a batch from GCS.

        Args:
            blob: GCS blob object or blob name

        Returns:
            list: List of processed documents
        """
        try:
            if isinstance(blob, str):
                blob = self.bucket.blob(blob)

            # Download to temporary file
            with tempfile.NamedTemporaryFile(mode='w+', suffix='.json', delete=False) as temp_file:
                blob.download_to_filename(temp_file.name)

                # Load JSON data
                with open(temp_file.name, 'r', encoding='utf-8') as f:
                    data = json.load(f)

                # Clean up
                import os
                os.unlink(temp_file.name)

                return data

        except Exception as e:
            print(f"❌ Failed to load batch from GCS: {e}")
            return []

class CloudParquetSaveDriver(CloudSaveDriver):
    """
    Google Cloud Storage (GCS) implementation for saving processed batches to parquet files.
    """
    def __init__(self, bucket_name=None, project_id=None, batch_size=100, progress_file="gcs_processing_progress.json"):
        """Forward every argument unchanged to CloudSaveDriver's initializer."""
        super().__init__(
            bucket_name=bucket_name,
            project_id=project_id,
            batch_size=batch_size,
            progress_file=progress_file,
        )

    def _save_current_batch(self):
        """
        Save the current batch to GCS bucket.
        """
        if not self.current_batch:
            return

        self.batch_count += 1

        # Create filename with timestamp and batch number
        timestamp = int(time.time())
        filename = f"batch_{self.batch_count:06d}_{timestamp}.parquet"
        save_start = time.time()
        try:
            # Create temporary file for Parquet data
            with tempfile.NamedTemporaryFile(mode='w', suffix='.parquet', delete=False) as temp_file:
                pa.Table.from_pylist(self.current_batch).to_pandas().to_parquet(temp_file.name)
                temp_file_path = temp_file.name

                blob = self.bucket.blob(filename)

                for attempt in range(self.config.GCS_RETRY_ATTEMPTS):
                    try:
                        blob.upload_from_filename(temp_file_path, content_type='application/parquet')

                        break
                    except self.GoogleCloudError as e:
                        if attempt == self.config.GCS_RETRY_ATTEMPTS - 1:
                            raise
                        print(f"⚠️  Upload attempt {attempt + 1} failed, retrying... Error: {e}")
                        time.sleep(2 ** attempt)  # Exponential backoff

                os.unlink(temp_file_path)  # Clean up

                save_time = time.time() - save_start
                file_size_mb = len(json.dumps(self.current_batch)) / 1024 / 1024

                print(f"☁️  Saved batch {self.batch_count} with {len(self.current_batch)} documents to gs://{self.bucket_name}/{filename}")
                print(f"   ⏱️  Upload time: {save_time:.3f}s, Size: {file_size_mb:.1f} MB, Rate: {len(self.current_batch)/save_time:.1f} docs/sec")

                self._save_progress()
        except Exception as e:
            print(f"❌ Failed to save batch {self.batch_count} to GCS: {e}")
            # Clean up temp file if it exists
            try:
                if 'temp_file_path' in locals():
                    os.unlink(temp_file_path)
            except:
                pass
            raise


    def load_batch(self, blob):
        """
        Load a batch from GCS.
        """
        try:
            if isinstance(blob, str):
                blob = self.bucket.blob(blob)
            else:
                # Download to temporary file
                with tempfile.NamedTemporaryFile(mode='w+', suffix='.parquet', delete=False) as temp_file:
                    blob.download_to_filename(temp_file.name)

                    # Load Parquet data
                    table = pq.read_table(temp_file.name)

                    # Clean up
                    os.unlink(temp_file.name)

                    return table.to_pylist()

        except Exception as e:
            print(f"❌ Failed to load batch from GCS: {e}")
            return []




In [3]:
from datasets import load_dataset

In [4]:
# Fetch the transformer model that SpacyJSONGenerator loads with spacy.load("en_core_web_trf").
# The installer output asks for a runtime restart before the package can be loaded.
!python -m spacy download en_core_web_trf

Collecting en-core-web-trf==3.8.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_trf-3.8.0/en_core_web_trf-3.8.0-py3-none-any.whl (457.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m457.4/457.4 MB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_trf')
[38;5;3m⚠ Restart to reload dependencies[0m
If you are in a Jupyter or Colab notebook, you may need to restart Python in
order to load all the package's dependencies. You can do this by selecting the
'Restart kernel' or 'Restart runtime' option.


In [5]:
!pip install datasets==3.6.0



In [6]:
# spaCy-backed generator: 50 texts per pipeline batch, single process, GPU required.
generator = SpacyJSONGenerator(
    batch_size=50,
    n_process=1,
    require_gpu=True,
)

# Stream OpenWebText instead of downloading it up front.
# NOTE(review): trust_remote_code=True allows the dataset repository's loading
# script to run locally — confirm the source is trusted before enabling it.
dataset = load_dataset(
    "Skylion007/openwebtext",
    streaming=True,
    trust_remote_code=True,
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [7]:
# Parquet-writing GCS driver; the target bucket and project are fixed here.
gcs_save_driver = CloudParquetSaveDriver(
    project_id="eastern-bridge-472408-d3",
    bucket_name="parquet_v1_openwebtext-with-pos-ner",
    batch_size=50,  # Small batch size for testing
)

NameError: name 'CloudParquetSaveDriver' is not defined

In [None]:
# Run the pipeline end to end and keep the driver (with its statistics) for inspection.
result_driver = generator.process_and_save(
            dataset=dataset,
            save_driver=gcs_save_driver,
            num_batches=1  # Stop after 1 newly written batch; pass None to process the whole dataset
        )