In [None]:
!pip install iterative-stratification

In [32]:
from collections import Counter
from pathlib import Path

import numpy as np
import spacy
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
from spacy.tokens import DocBin

# Split the annotated corpus into train/dev with iterative (multilabel)
# stratification, which aims to keep each entity type's share of documents
# similar in both splits.
#
# NOTE(review): the entity-count cells below read data/train.spacy and
# data/dev.spacy, so the split is written into data/. It previously went to the
# working directory, which left those cells reading files this cell never
# produced. Confirm data/ is the intended output location.
SOURCE_PATH = "../../data/train.spacy"  # full, unsplit corpus
SPLIT_DIR = Path("data")                 # destination of train.spacy / dev.spacy
DEV_FRACTION = 0.2                       # share of docs held out for dev (80/20 split)
SEED = 42                                # fixes the shuffle so the split is reproducible

nlp = spacy.blank("en")
doc_bin = DocBin().from_disk(SOURCE_PATH)
docs = list(doc_bin.get_docs(nlp.vocab))

# Stable (sorted) list of every entity label occurring anywhere in the corpus.
entity_types = sorted({ent.label_ for doc in docs for ent in doc.ents})
n_types = len(entity_types)
type_index = {etype: j for j, etype in enumerate(entity_types)}

# Multi-hot presence matrix: y[i, j] == 1 iff doc i has at least one entity of
# type entity_types[j]. This is the label matrix the stratifier balances.
y = np.zeros((len(docs), n_types))
for i, doc in enumerate(docs):
    for label in {ent.label_ for ent in doc.ents}:
        y[i, type_index[label]] = 1

# X is only a placeholder (one zero column per doc); the labels in y drive
# the stratification.
msss = MultilabelStratifiedShuffleSplit(
    n_splits=1, test_size=DEV_FRACTION, random_state=SEED
)
train_index, dev_index = next(msss.split(np.zeros((len(docs), 1)), y))

# The two index sets must partition the corpus: every doc in exactly one split.
assert len(train_index) + len(dev_index) == len(docs), "split lost or duplicated docs"
assert not set(train_index) & set(dev_index), "train/dev overlap"

train_docs = [docs[i] for i in train_index]
dev_docs = [docs[i] for i in dev_index]

# Make sure the output directory exists before writing the split.
SPLIT_DIR.mkdir(parents=True, exist_ok=True)
train_doc_bin = DocBin(docs=train_docs)
dev_doc_bin = DocBin(docs=dev_docs)
train_doc_bin.to_disk(SPLIT_DIR / "train.spacy")
dev_doc_bin.to_disk(SPLIT_DIR / "dev.spacy")

In [6]:
# Load data/train.spacy and count entity mentions (spans, not documents) per label.
doc_bin = DocBin().from_disk("data/train.spacy")
docs = list(doc_bin.get_docs(nlp.vocab))

entity_counter = Counter(
    ent.label_
    for doc in docs
    for ent in doc.ents
)

entity_counter

Counter({'PERS': 3254,
         'ORG': 2942,
         'LOC': 1770,
         'JOB': 1048,
         'DATE': 955,
         'MON': 496,
         'PERIOD': 317,
         'ART': 313,
         'MISC': 306,
         'QUANT': 238,
         'PCT': 137,
         'DOC': 81,
         'TIME': 21})

In [7]:
# Load data/dev.spacy and count entity mentions (spans, not documents) per label.
doc_bin = DocBin().from_disk("data/dev.spacy")
docs = list(doc_bin.get_docs(nlp.vocab))

entity_counter = Counter(
    ent.label_
    for doc in docs
    for ent in doc.ents
)

entity_counter

Counter({'PERS': 772,
         'ORG': 700,
         'LOC': 391,
         'JOB': 272,
         'DATE': 251,
         'MON': 122,
         'PERIOD': 85,
         'ART': 85,
         'MISC': 65,
         'QUANT': 55,
         'PCT': 34,
         'DOC': 20,
         'TIME': 5})