In [1]:
# ==== 0) Imports & setup ====
import os, json, math, random
from typing import List, Dict, Any, Tuple
from tqdm import tqdm

import torch
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

from gliner import GLiNER
from transformers import AutoTokenizer

# Device (MPS on Apple Silicon, else CUDA, else CPU)
device = torch.device("mps" if torch.backends.mps.is_available()
                      else "cuda" if torch.cuda.is_available()
                      else "cpu")

# ==== 1) Label map (Kaggle -> GLiNER strings) ====
PII_LABELS = [
    "name", "email", "username", "id number", "phone number", "url", "street address"
]
label_map = {
    "NAME_STUDENT": "name",
    "EMAIL": "email",
    "USERNAME": "username",
    "ID_NUM": "id number",
    "PHONE_NUM": "phone number",
    "URL_PERSONAL": "url",
    "STREET_ADDRESS": "street address",
}

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ==== 2) Helpers: reconstruct text and convert BIO -> char spans ====
def reconstruct_text(tokens: List[str], whitespaces: List[bool]) -> str:
    # faithful reconstruction per dataset spec
    return "".join([t + (" " if ws else "") for t, ws in zip(tokens, whitespaces)]).strip()

def bio_to_char_spans(tokens: List[str], whitespaces: List[bool], bio: List[str], text: str) -> List[Dict[str, Any]]:
    spans = []
    start_char = 0
    current = None  # {"label": str, "start": int}

    for tok, tag, ws in zip(tokens, bio, whitespaces):
        end_char = start_char + len(tok)

        if tag.startswith("B-"):
            # close previous
            if current is not None:
                spans.append({"start": current["start"], "end": current["end"], "label": current["label"]})
            raw = tag[2:]
            if raw in label_map:
                current = {"label": label_map[raw], "start": start_char, "end": end_char}
            else:
                current = None

        elif tag.startswith("I-"):
            if current is not None:
                current["end"] = end_char
            # else: stray I- (ignore)

        else:  # "O"
            if current is not None:
                spans.append({"start": current["start"], "end": current["end"], "label": current["label"]})
                current = None

        start_char = end_char + (1 if ws else 0)

    if current is not None:
        spans.append({"start": current["start"], "end": current["end"], "label": current["label"]})

    # Optional: sanity clamp within text length
    L = len(text)
    clean = []
    for s in spans:
        if 0 <= s["start"] < s["end"] <= L:
            clean.append(s)
    return clean


In [3]:

# ==== 3) Load Kaggle JSON and convert to GLiNER format ====
def load_kaggle_and_convert(path: str) -> List[Dict[str, Any]]:
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    gliner_samples = []
    for ex in data:
        tokens = ex["tokens"]
        ws = ex["trailing_whitespace"]
        bio = ex["labels"]
        # prefer faithful reconstruction over provided full_text
        text = reconstruct_text(tokens, ws)
        spans = bio_to_char_spans(tokens, ws, bio, text)
        # drop empty ones (optional)
        if len(spans) == 0:
            continue
        gliner_samples.append({"text": text, "entities": spans})
    return gliner_samples


In [4]:

# ==== 4) Train/val split ====
def split_train_val(samples: List[Dict[str, Any]], val_ratio=0.2, seed=42):
    random.seed(seed)
    idx = list(range(len(samples)))
    random.shuffle(idx)
    cut = int((1 - val_ratio) * len(idx))
    train_idx, val_idx = idx[:cut], idx[cut:]
    train = [samples[i] for i in train_idx]
    val = [samples[i] for i in val_idx]
    return train, val

In [5]:

# ==== 5) GLiNER Dataset for PyTorch DataLoader (expects 'ner' and 'labels') ====
class GLiNERSpanDataset(Dataset):
    def __init__(self, data, label_list):
        self.data = data
        self.labels = label_list

    def __len__(self): return len(self.data)

    def __getitem__(self, i):
        ex = self.data[i]
        return {
            "text": ex["text"],
            "ner": [(e["start"], e["end"], e["label"]) for e in ex["entities"]],
            "labels": self.labels,
        }


In [6]:
# 0) One-time
import nltk; nltk.download("punkt")
from nltk.tokenize import PunktSentenceTokenizer

MAX_TOKENS = 512

def sentence_window_examples(samples, tokenizer, max_tokens=512, keep_empty_windows=True):
    """Split each sample (text, entities) into sentence-packed windows ≤ max_tokens.
       Keeps only entities fully inside each window and remaps char offsets."""
    out = []
    sent_tok = PunktSentenceTokenizer()

    for ex in samples:
        text = ex["text"]
        ents = ex["entities"]  # [{"start":int,"end":int,"label":str}, ...]

        spans = list(sent_tok.span_tokenize(text))  # [(s_start, s_end), ...]
        i = 0
        while i < len(spans):
            j = i
            win_start = spans[i][0]
            win_end = win_start

            # pack sentences until token budget hit
            while j < len(spans):
                cand_end = spans[j][1]
                cand_text = text[win_start:cand_end]
                if len(tokenizer.encode(cand_text, add_special_tokens=False)) <= max_tokens:
                    win_end = cand_end
                    j += 1
                else:
                    break

            # ensure at least one sentence (hard trim if single sent too long)
            if j == i:
                s0, s1 = spans[i]
                hi = s1
                while hi > s0 and len(tokenizer.encode(text[s0:hi], add_special_tokens=False)) > max_tokens:
                    hi = s0 + (hi - s0) * 9 // 10
                win_start, win_end = s0, max(s0 + 1, hi)
                j = i + 1

            win_text = text[win_start:win_end]

            # keep entities fully inside; shift to window coords
            kept = []
            for e in ents:
                if e["start"] >= win_start and e["end"] <= win_end:
                    kept.append({
                        "start": e["start"] - win_start,
                        "end":   e["end"]   - win_start,
                        "label": e["label"]
                    })

            if kept or keep_empty_windows:
                out.append({"text": win_text, "entities": kept})

            i = j
    return out


[nltk_data] Downloading package punkt to /Users/valencia/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [7]:
# ==== 6) Build everything and train ====


def finetune_gliner(kaggle_json_path: str,
                    out_dir: str = "./gliner_finetuned",
                    lr: float = 1e-5,
                    batch_size: int = 2,
                    epochs: int = 3):

    os.makedirs(out_dir, exist_ok=True)

    # a) Convert data
    samples = load_kaggle_and_convert(kaggle_json_path)

    # c) Model & tokenizer
    model = GLiNER.from_pretrained("knowledgator/gliner-pii-base-v1.0")
    tokenizer = AutoTokenizer.from_pretrained("knowledgator/gliner-pii-base-v1.0")
    model.to(device)
    model.config.max_length = 512  # quiets the warning

    windowed = sentence_window_examples(samples, tokenizer, max_tokens=MAX_TOKENS, keep_empty_windows=True)


    # b) Label space
    all_labels = sorted({e["label"] for ex in windowed for e in ex["entities"]})
    # ensure matches your target set (optional: intersect with PII_LABELS)
    # all_labels = sorted(set(all_labels).intersection(set(PII_LABELS))) or keep all



    train_data, val_data = split_train_val(windowed, val_ratio=0.2, seed=42)
    # d) Datasets & DataLoaders
    train_ds = GLiNERSpanDataset(train_data, all_labels)
    val_ds = GLiNERSpanDataset(val_data, all_labels)


    # # dataset items must be: {"text": str, "ner": [(start,end,label), ...], "labels": all_labels}
    # collate_fn = lambda batch: model.data_processor.collate_fn(batch, prepare_labels=True, config=model.config)

        # Version-safe collate
    def gliner_collate(batch):
        # batch items from your Dataset: {"text": str, "ner": [(s,e,l), ...], "labels": all_labels}
        texts     = [b["text"] for b in batch]
        entities  = [b["ner"]  for b in batch]
        labels    = batch[0]["labels"]                   # shared label space
        cls2id    = {c: i for i, c in enumerate(labels)}
        id2cls    = {i: c for c, i in cls2id.items()}

        # GLiNER processor expects lists indexed by sample id
        collated = {
            "tokens": texts,                              # raw texts (name 'tokens' is historical)
            "entities": entities,                         # list of spans per sample
            "classes_to_id": [cls2id] * len(texts),       # ← list, not dict
            "id_to_classes": [id2cls] * len(texts),       # optional but some versions read it
        }
        return model.data_processor.collate_fn(collated, prepare_labels=True)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=gliner_collate)
    val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, collate_fn=gliner_collate)


    # e) Optimizer
    optimizer = optim.AdamW(model.parameters(), lr=lr)

    # f) Train loop
    for ep in range(1, epochs + 1):
        model.train()
        running = 0.0
        pbar = tqdm(train_loader, desc=f"Epoch {ep}/{epochs}")
        for batch in pbar:
            # move tensors to device
            for k, v in batch.items():
                if isinstance(v, torch.Tensor):
                    batch[k] = v.to(device)

            out = model(**batch)           # expects inputs from DataCollator
            loss = out.loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running += loss.item()
            pbar.set_postfix(loss=f"{loss.item():.4f}")

        avg_train = running / max(1, len(train_loader))
        # quick val loss
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for vb in val_loader:
                for k, v in vb.items():
                    if isinstance(v, torch.Tensor):
                        vb[k] = v.to(device)
                out = model(**vb)
                val_loss += out.loss.item()
        avg_val = val_loss / max(1, len(val_loader))
        print(f"📉 train_loss={avg_train:.4f} | ✅ val_loss={avg_val:.4f}")

    # g) Save
    model.save_pretrained(out_dir)
    print(f"💾 Saved to: {out_dir}")

    # h) Sanity check on one val sample
    sample_text = val_data[0]["text"]
    preds = model.predict_entities(sample_text, all_labels, threshold=0.35)  # lower threshold boosts recall
    print("\n🔎 Sanity check predictions:")
    for p in preds:
        print(f"- {p['label']} → '{sample_text[p['start']:p['end']]}'")

In [8]:

# ==== 7) Run ====
# Replace with your Kaggle file path (in current dir or use os.path.join)
finetune_gliner(".data/mixtral-8x7b-v1.json",
                out_dir=".models/gliner_finetuned_kaggle",
                lr=1e-5,
                batch_size=8,
                epochs=3)

Fetching 13 files: 100%|██████████| 13/13 [00:00<00:00, 115033.65it/s]
Epoch 1/3:   0%|          | 0/532 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Epoch 1/3:   0%|          | 0/532 [00:04<?, ?it/s]


RuntimeError: MPS backend out of memory (MPS allocated: 14.20 GiB, other allocations: 3.83 GiB, max allowed: 20.40 GiB). Tried to allocate 5.58 GiB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).