In [4]:
import json
import string
import subprocess
import sys
from collections import Counter, defaultdict
from pathlib import Path

import spacy
from datasets import load_dataset
from tqdm import tqdm
from transformers import GPT2TokenizerFast

In [5]:
# Pilot-scale configuration, sized so the run stays cheap on a CPU.

# Corpus processing stops once this many single-token words were counted.
NUM_WORDS_TO_PROCESS = 50_000

# A token id needs at least this many observations before a POS is assigned.
MIN_OCCURRENCES = 3

# JSON file name; presumably where the final mapping is saved
# (no cell shown here writes to it — confirm against the rest of the notebook).
OUTPATH = "token_pos_map.json"

In [6]:
# Instantiate the pretrained GPT-2 tokenizer by name and report the size of
# its vocabulary as a quick sanity check.
tokenizer_name = "gpt2"
print(f"Loading tokenizer: {tokenizer_name}")

tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_name)

print(f"✓ Tokenizer loaded. Vocab size: {tokenizer.vocab_size:,}")

Loading tokenizer: gpt2
✓ Tokenizer loaded. Vocab size: 50,257


In [7]:
# Name of the spaCy pipeline used for POS tagging. It is loaded in the next
# cell with NER, parser and lemmatizer disabled to keep it lightweight.
spacy_model = "en_core_web_sm"

print(f"Loading spaCy model: {spacy_model}")

Loading spaCy model: en_core_web_sm


In [8]:
# Pipeline components that are not needed for POS tagging. Defined once so the
# first attempt and the post-download retry always load the same pipeline.
disabled_components = ["ner", "parser", "lemmatizer"]

try:
    nlp = spacy.load(spacy_model, disable=disabled_components)
except OSError:
    # An OSError from spacy.load is treated as "model package not installed":
    # download it with the interpreter running this kernel, then retry once.
    # check=True makes a failed download raise instead of failing silently.
    print(f"Model not found. Downloading {spacy_model}...")
    subprocess.run([sys.executable, "-m", "spacy", "download", spacy_model], check=True)
    nlp = spacy.load(spacy_model, disable=disabled_components)
    print("✓ spaCy model downloaded and loaded")
else:
    # Only reached when the first load succeeded without a download.
    print("✓ spaCy model loaded (POS tagger only)")

✓ spaCy model loaded (POS tagger only)


In [10]:
# Short example sentence reused by the tagging and tokenization demo cells.
test_sentence = "I love machine learning and neural networks."

print(f"Test sentence: '{test_sentence}'\n")

Test sentence: 'I love machine learning and neural networks.'



In [15]:
# Run the example sentence through the spaCy pipeline and list every token
# next to its coarse POS label (token.pos_), each padded to ten characters.
doc = nlp(test_sentence)

print("spaCy POS tags:")
for token in doc:
    print(f"  {token.text:10s} → {token.pos_:10s}")

spaCy POS tags:
  I          → PRON      
  love       → VERB      
  machine    → NOUN      
  learning   → NOUN      
  and        → CCONJ     
  neural     → ADJ       
  networks   → NOUN      
  .          → PUNCT     


In [16]:
# Show how GPT-2 encodes each word of the example sentence when the word is
# prefixed with a single space.
print("\nGPT-2 tokenization (with space prefix):")
for token in doc:
    # Whitespace and punctuation tokens are left out of the demo.
    if token.is_space or token.is_punct:
        continue

    word = token.text
    word_with_space = " " + word
    token_ids = tokenizer(word_with_space, add_special_tokens=False)["input_ids"]
    # Decode the ids again to make the leading space visible in the output.
    decoded = tokenizer.decode(token_ids)
    print(f"  '{word}' → tokens: {token_ids} → decoded: '{decoded}'")


GPT-2 tokenization (with space prefix):
  'I' → tokens: [314] → decoded: ' I'
  'love' → tokens: [1842] → decoded: ' love'
  'machine' → tokens: [4572] → decoded: ' machine'
  'learning' → tokens: [4673] → decoded: ' learning'
  'and' → tokens: [290] → decoded: ' and'
  'neural' → tokens: [17019] → decoded: ' neural'
  'networks' → tokens: [7686] → decoded: ' networks'


In [18]:
# Corpus used for collecting POS statistics: the training split of the
# dataset/config pair named below.
dataset_name = "wikitext"
dataset_config = "wikitext-2-raw-v1"
print(f"Loading dataset: {dataset_name}/{dataset_config}")

ds = load_dataset(dataset_name, dataset_config, split="train")

# len(ds) counts rows ("examples"), not words.
print(f"✓ Dataset loaded. Size: {len(ds):,} examples")

Loading dataset: wikitext/wikitext-2-raw-v1
✓ Dataset loaded. Size: 36,718 examples


In [19]:
# Preview the first three rows of the dataset. Rows that are empty after
# stripping are not printed, so fewer than three previews may appear.
print("\nFirst 3 examples:")
for i in range(min(3, len(ds))):
    text = ds[i]["text"].strip()
    if not text:
        continue

    # Long rows are cut to 100 characters and marked with an ellipsis.
    if len(text) > 100:
        preview = text[:100] + "..."
    else:
        preview = text
    print(f"\n[{i}]: {preview}")


First 3 examples:

[1]: = Valkyria Chronicles III =


In [20]:
# Tallies filled in by the corpus pass.
# token_pos_counts maps a GPT-2 token id to a Counter of POS tag -> occurrences.
token_pos_counts = defaultdict(Counter)

# Running totals: counted single-token words, blank rows, multi-token words.
processed_words = skipped_empty = skipped_multitoken = 0

In [21]:
# Process dataset: tag every row with spaCy and, for each word that GPT-2
# encodes as exactly one token (with a leading space), count which POS tag
# that token id received. Stops after NUM_WORDS_TO_PROCESS counted words.

# Reset the tallies here so that re-running this cell on its own starts from
# zero instead of continuing from the totals of a previous run.
token_pos_counts = defaultdict(Counter)
processed_words = 0
skipped_empty = 0
skipped_multitoken = 0

# Characters regarded as non-linguistic when a word consists only of them.
symbol_chars = frozenset(string.punctuation)

for ex in tqdm(ds, desc="Processing dataset"):
    text = ex["text"].strip()
    if not text:
        skipped_empty += 1
        continue

    # Process text with spaCy
    doc = nlp(text)

    for token in doc:
        # Skip whitespace, punctuation, and symbols
        if token.is_space or token.is_punct or token.pos_ == "SYM":
            continue

        word = token.text
        pos = token.pos_

        # Skip obvious non-linguistic content: blank strings and words made
        # up solely of ASCII punctuation. The check is per character on
        # purpose: `word in string.punctuation` would be a substring test and
        # would let runs such as "==" or "$$" through.
        if not word.strip() or all(ch in symbol_chars for ch in word):
            continue

        # Tokenize with space prefix (GPT-2 convention)
        word_with_space = " " + word
        toks = tokenizer(word_with_space, add_special_tokens=False)
        ids = toks["input_ids"]

        # ONLY map single-token words; a word split into several BPE pieces
        # has no single token id to attach the POS tag to.
        if len(ids) == 1:
            tid = ids[0]
            token_pos_counts[tid][pos] += 1
            processed_words += 1
        else:
            skipped_multitoken += 1

        if processed_words >= NUM_WORDS_TO_PROCESS:
            break

    # Propagate the inner break so the pass stops at the word budget.
    if processed_words >= NUM_WORDS_TO_PROCESS:
        break

print("\n✓ Processing complete!")
print(f"  Single-token words processed: {processed_words:,}")
print(f"  Multi-token words skipped: {skipped_multitoken:,}")
print(f"  Empty lines skipped: {skipped_empty}")

Processing dataset:   3%|▎         | 1196/36718 [00:23<11:43, 50.47it/s]


✓ Processing complete!
  Single-token words processed: 50,000
  Multi-token words skipped: 5,155
  Empty lines skipped: 420





In [24]:
# Spot-check the raw tallies: take the first five token ids in insertion order
# and list their POS counts from most to least frequent.
sample_tokens = list(token_pos_counts.items())[:5]

for tid, counter in sample_tokens:
    # Strip the decoded text so the leading space does not show in the label.
    token_str = tokenizer.decode([tid]).strip()
    print(f"\nToken ID {tid} ('{token_str}'):")

    for pos, count in counter.most_common():
        print(f"  {pos}: {count} times")



Token ID 17740 ('Chronicles'):
  PROPN: 39 times

Token ID 6711 ('III'):
  PROPN: 17 times

Token ID 645 ('no'):
  DET: 36 times
  INTJ: 6 times
  ADV: 2 times

Token ID 513 ('3'):
  NUM: 40 times

Token ID 4960 ('Japanese'):
  ADJ: 5 times
  PROPN: 1 times


In [25]:
## 7. Build Final Mapping with Confidence Filter

# %%
# token_to_pos: token id -> majority POS tag, for tokens that pass both filters.
# low_count_tokens: number of token ids rejected by either filter.
token_to_pos = {}
low_count_tokens = 0

for tid, counter in token_pos_counts.items():
    total_count = sum(counter.values())

    # Filter 1: too few observations to trust any label.
    if total_count < MIN_OCCURRENCES:
        low_count_tokens += 1
        continue

    most_common_pos, count = counter.most_common(1)[0]

    # Share of observations that agree with the most frequent tag.
    confidence = count / total_count

    # Filter 2: the top tag must hold a strict majority (more than 50%).
    if confidence <= 0.5:
        low_count_tokens += 1
        continue

    token_to_pos[int(tid)] = most_common_pos

print(f"Final mapping statistics:")
print(f"  Tokens with POS assigned: {len(token_to_pos):,}")
print(f"  Tokens filtered (low count/confidence): {low_count_tokens:,}")

# %% [markdown]
## 8. Visualize Results by POS Category

# %%
# Invert the mapping: POS tag -> list of token ids assigned that tag.
pos_to_tokens = defaultdict(list)
for tid, pos in token_to_pos.items():
    pos_to_tokens[pos].append(tid)

# Report the size of each category, in alphabetical order of the tag.
print("\nTokens per POS category:")
for pos in sorted(pos_to_tokens):
    count = len(pos_to_tokens[pos])
    print(f"  {pos:10s}: {count:,} tokens")

# %%
# Print a few decoded example tokens for each POS category.
print("\n" + "="*60)
print("Sample tokens by POS category:")
print("="*60)

# Only the first 10 tags (alphabetical order) are shown, with up to 8 token
# ids each, taken in the order they were added to the category.
for pos in sorted(pos_to_tokens)[:10]:
    tokens = pos_to_tokens[pos][:8]
    # The tag and its examples share one output line.
    print(f"\n{pos:12s}:", end=" ")
    examples = [tokenizer.decode([tid]).strip() for tid in tokens]
    print(", ".join(examples))

Final mapping statistics:
  Tokens with POS assigned: 2,729
  Tokens filtered (low count/confidence): 4,284

Tokens per POS category:
  ADJ       : 307 tokens
  ADP       : 59 tokens
  ADV       : 123 tokens
  AUX       : 22 tokens
  CCONJ     : 8 tokens
  DET       : 27 tokens
  NOUN      : 999 tokens
  NUM       : 157 tokens
  PART      : 5 tokens
  PRON      : 41 tokens
  PROPN     : 405 tokens
  SCONJ     : 25 tokens
  VERB      : 550 tokens
  X         : 1 tokens

Sample tokens by POS category:

ADJ         : Japanese, tactical, third, same, real, first, military, secret

ADP         : of, as, outside, by, for, in, during, against

ADV         : commonly, also, more, partially, freely, very, directly, once

AUX         : is, are, was, would, can, be, has, will

CCONJ       : and, but, or, either, But, nor, Yet, And

DET         : no, the, a, The, both, A, an, each

NOUN        : role, video, game, series, time, gameplay, predecessors, story

NUM         : 3, 2011, 2010, 2014, 4, o