# Dataset creation notebook

Purpose: create a dataset for a mountain-name NER model, using the Gemini API for batch-based sample generation.

Structure:
1. Setup & Configuration
2. Verify setup
3. BIO tags assignment
4. Generating and saving the whole dataset

### Todo (future improvements)
1. Problem: duplicated samples! Warning: risk of data leakage during training
2. Implement checkpoint saves for dataset creation (in case of error)
3. Tagging validation
4. Add docstrings to essential functions

### Setup & Configuration

In [1]:
import os
import json
import time
import random
import pandas as pd
from tqdm import tqdm
from getpass import getpass
from google import genai
from google.genai import types
from sklearn.model_selection import train_test_split

In [2]:
# Prompt for POSITIVE samples: asks the model for 20 JSON Lines records, each
# with a "text" sentence that mentions at least one specific mountain name and
# an "entities" list holding every mountain name that appears in that sentence.
# NOTE(review): the prompt requests JSON Lines, while the API call in this
# notebook sets response_mime_type="application/json" — confirm the model keeps
# returning one object per line under that setting.
PROMPT_POSITIVE = """
Generate 20 unique items in valid JSON Lines format. 
Output **only** JSON objects, one per line, with no additional commentary, explanation, or text. 
Each line must be a dictionary with keys:
"text" — a sentence containing one or more specific mountain names
"entities" — a list of all mountain names mentioned in the sentence, including alternative names

"""

# Prompt for NEGATIVE samples: same record layout, but "entities" must be empty.
# The sentences deliberately use generic terrain words and real non-mountain
# place names, so the resulting samples contain mountain-like vocabulary
# without any mountain entity.
PROMPT_NEGATIVE = """
Generate 20 unique items in valid JSON Lines format. 
Output **only** JSON objects, one per line, with no commentary.
Each line must be a dictionary with:
"text" — a sentence about geography, hiking, or nature
"entities" — an empty list []
Use generic terms such as "hill", "ridge", "peak", "valley", "cliff" and include real geographic names that are NOT mountains, like rivers, lakes, deserts, islands, or regions.
"""

In [3]:
def setup_api():
    """
        Build and return a Gemini API client.

        The key is taken from the GEMINI_API_KEY environment variable. When that
        variable is unset or empty, the user is prompted for the key via getpass.
    """
    api_key = os.getenv("GEMINI_API_KEY") or getpass("Enter Gemini API Key: ")
    return genai.Client(api_key=api_key)

# Maximum number of attempts for a single API call before giving up on that batch
MAX_RETRIES = 5


def _parse_json_records(raw_text):
    """
        Parse the model's raw text into a list of dict records.

        Accepts either one JSON document (an array of objects, or a single
        object) or JSON Lines (one object per line). Malformed lines are
        reported and skipped so that one bad line does not discard the rest of
        the batch. Anything that is not a dict is dropped.
    """
    stripped = raw_text.strip()
    if not stripped:
        return []

    try:
        # A single JSON document: typically an array when the API is asked
        # for application/json output.
        parsed = json.loads(stripped)
    except json.JSONDecodeError:
        # Several documents in one string -> treat as JSON Lines.
        parsed = []
        for line in stripped.splitlines():
            line = line.strip()
            if not line:
                continue
            try:
                parsed.append(json.loads(line))
            except json.JSONDecodeError as e:
                print(f"Skipping malformed line ({e}): {line[:80]!r}")
    else:
        if not isinstance(parsed, list):
            parsed = [parsed]

    return [record for record in parsed if isinstance(record, dict)]


def generate_batch_data(client, prompt, batch_count=1):
    """
        Request `batch_count` batches of samples from Gemini and return them
        as one flat list of dicts.

        Args:
            client: object exposing `client.models.generate_content(...)`
                (the value returned by `setup_api()`).
            prompt: prompt text sent unchanged for every batch.
            batch_count: number of API requests (batches) to make.

        Returns:
            list of parsed records. Batches that still fail after MAX_RETRIES
            attempts, or whose response cannot be read, are reported and
            skipped, so the list may be shorter than expected. No exception is
            raised for them.
    """
    dataset = []

    # The request configuration is the same for every call, so build it once
    generation_config = types.GenerateContentConfig(
        temperature=0.8,
        response_mime_type="application/json"
    )

    # Tqdm wrapper makes a progress bar
    for i in tqdm(range(batch_count), desc="Generating batches"):
        # Reset for every batch: without this, a batch whose retries all fail
        # would silently re-parse the PREVIOUS batch's response and append
        # duplicated samples to the dataset.
        response = None

        for attempt in range(MAX_RETRIES):
            try:
                response = client.models.generate_content(
                    model="gemini-2.5-flash",
                    contents=prompt,
                    config=generation_config
                )
                break # success
            except Exception as e:
                print(f"Attempt {attempt+1} failed: {e}")
                # Linear backoff: 10s, 20s, 30s, ... (no wait after the last attempt)
                if attempt + 1 < MAX_RETRIES:
                    wait_time = 10 * (attempt + 1)
                    print(f"Retrying in {wait_time}s …")
                    time.sleep(wait_time)

        if response is None:
            print(f"Error on batch {i}: all {MAX_RETRIES} attempts failed. Skipping batch.")
            time.sleep(10)  # cool down before the next batch
            continue

        try:
            # Extract the raw generated text
            raw_text = response.candidates[0].content.parts[0].text
            batch_data = _parse_json_records(raw_text)
        except Exception as e:
            # Unexpected response shape (e.g. no candidates or empty content)
            print(f"Error on batch {i}: {e}. Skipping batch.")
            time.sleep(10)
            continue

        # Add to overall dataset
        dataset.extend(batch_data)

        time.sleep(5)  # rate limit pause

    return dataset


### Verify setup

In [4]:
# Smoke test: request a single batch to confirm that the API key, the model
# call and the response parsing work before generating the whole dataset.
try:
    client = setup_api()
    test_response = generate_batch_data(client, PROMPT_POSITIVE, 1)
    print(f"Total samples returned: {len(test_response)}")
    if test_response:
        print("Preview of first sample:")
        print(test_response[:1])
    else:
        # generate_batch_data reports and skips failed batches instead of
        # raising, so an empty result is the only sign that the call failed.
        print("Smoke test failed: no samples were returned")
except Exception as e:
    print(f"Smoke test failed: {e}")


Generating batches: 100%|██████████| 1/1 [00:15<00:00, 15.95s/it]

Total samples returned: 20
Preview of first 2 samples sample:
[{'text': 'Climbing Mount Everest, also known as Sagarmatha in Nepal and Chomolungma in Tibet, is a lifelong dream for many alpinists.', 'entities': ['Mount Everest', 'Sagarmatha', 'Chomolungma']}]





In [5]:
import re
import json

def assign_bio_tags(entry):
    """
        Converts entity lists to BIO tags.

        Args:
            entry: dict with a "text" string and an optional "entities" list
                of entity strings that occur in the text.

        Returns:
            dict with "tokens" (words and single punctuation marks) and
            "ner_tags" (one of "O", "B-MNT", "I-MNT" per token).

        Matching is exact and case-sensitive at token level. Every occurrence
        of an entity is tagged. Longer entities are matched first and a span
        is never tagged twice, so "Everest" cannot overwrite part of an
        already tagged "Mount Everest".
    """
    token_pattern = r'\w+|[^\w\s]'
    text = entry['text']

    # Robust regex tokenization
    tokens = re.findall(token_pattern, text)
    tags = ["O"] * len(tokens)

    # Tokenize the entities with the exact same regex. Skip non-strings and
    # entities with no tokens (e.g. ""): a zero-length window would match at
    # every position and tag the whole sentence.
    entity_token_lists = []
    for entity in entry.get('entities') or []:
        if not isinstance(entity, str):
            continue
        entity_tokens = re.findall(token_pattern, entity)
        if entity_tokens:
            entity_token_lists.append(entity_tokens)

    # Longest first, so nested names resolve in favor of the full name
    entity_token_lists.sort(key=len, reverse=True)

    for entity_tokens in entity_token_lists:
        entity_len = len(entity_tokens)

        # Sliding window match
        for i in range(len(tokens) - entity_len + 1):
            if tokens[i:i+entity_len] != entity_tokens:
                continue
            # Do not overwrite a span that already belongs to another entity
            if any(tag != "O" for tag in tags[i:i+entity_len]):
                continue
            tags[i] = "B-MNT"
            for j in range(1, entity_len):
                tags[i+j] = "I-MNT"

    return {"tokens": tokens, "ner_tags": tags}

### Bio-tags assignment test

In [6]:
# Test Case 1: Positive Sample
test_data_1 = {
    "text": "The climb to Mount Fitz Roy was difficult near K2!",
    "entities": ["Mount Fitz Roy", "K2"]
}

# Test Case 2: Negative Sample
test_data_2 = {
    "text": "The river flows near the high ridge, far from the city center.",
    "entities": []
}

# --- Execute Tests ---
print("\nTest 1:")
result_1 = assign_bio_tags(test_data_1)
print(json.dumps(result_1, indent=4))

print("\nTest 2:")
result_2 = assign_bio_tags(test_data_2)
print(json.dumps(result_2, indent=4))

# --- Check Results ---
# Pin the expected tagging so a regression fails loudly on "Run All" instead
# of depending on someone reading the printed output.
assert result_1["ner_tags"] == [
    "O", "O", "O", "B-MNT", "I-MNT", "I-MNT", "O", "O", "O", "B-MNT", "O"
], "positive sample: unexpected BIO tags"
assert set(result_2["ner_tags"]) == {"O"}, "negative sample must contain only 'O' tags"
for checked in (result_1, result_2):
    assert len(checked["tokens"]) == len(checked["ner_tags"]), "tokens and tags must align"


Test 1:
{
    "tokens": [
        "The",
        "climb",
        "to",
        "Mount",
        "Fitz",
        "Roy",
        "was",
        "difficult",
        "near",
        "K2",
        "!"
    ],
    "ner_tags": [
        "O",
        "O",
        "O",
        "B-MNT",
        "I-MNT",
        "I-MNT",
        "O",
        "O",
        "O",
        "B-MNT",
        "O"
    ]
}

Test 2:
{
    "tokens": [
        "The",
        "river",
        "flows",
        "near",
        "the",
        "high",
        "ridge",
        ",",
        "far",
        "from",
        "the",
        "city",
        "center",
        "."
    ],
    "ner_tags": [
        "O",
        "O",
        "O",
        "O",
        "O",
        "O",
        "O",
        "O",
        "O",
        "O",
        "O",
        "O",
        "O",
        "O"
    ]
}


# Generating a whole dataset

### Save & Load functions

In [7]:
def save_jsonl(data, filename):
    """
        Save jsonl data to file

        Writes one JSON document per line, overwriting any existing file.
        Missing parent directories are created first, so freshly generated
        data is not lost just because the output folder does not exist yet.

        Args:
            data: iterable of JSON-serializable entries.
            filename: destination path.
    """
    parent_dir = os.path.dirname(filename)
    if parent_dir:  # empty when the file goes to the current directory
        os.makedirs(parent_dir, exist_ok=True)

    # Explicit encoding: the result must not depend on the platform default
    with open(filename, 'w', encoding='utf-8') as f:
        for entry in data:
            json.dump(entry, f)
            f.write('\n')

def load_jsonl(filename):
    """
        Load jsonl data from file

        Args:
            filename: path of a file holding one JSON document per line.

        Returns:
            list with the parsed value of every non-blank line, in file order.
    """
    # Explicit encoding: reading must not depend on the platform default
    with open(filename, 'r', encoding='utf-8') as f:
        # Blank / whitespace-only lines are skipped
        return [json.loads(line) for line in f if line.strip()]

### Caution: This section generates data using an external API.

In [8]:
# Generate raw data
#
# Every batch is one API request; the prompts ask for 20 samples per request.
POSITIVE_BATCH_COUNT = 6
NEGATIVE_BATCH_COUNT = 6

try:
    client = setup_api()

    # Positive samples: sentences that mention mountains
    print("\nGenerating Positive Samples...")
    pos_data = generate_batch_data(
        client,
        PROMPT_POSITIVE,
        batch_count=POSITIVE_BATCH_COUNT,
    )
    save_jsonl(pos_data, "data/raw/raw_positive.jsonl")

    # Negative samples: sentences with no mountain entity
    print("\nGenerating Negative Samples...")
    neg_data = generate_batch_data(
        client,
        PROMPT_NEGATIVE,
        batch_count=NEGATIVE_BATCH_COUNT,
    )
    save_jsonl(neg_data, "data/raw/raw_negative.jsonl")

except Exception as api_error:
    print(f"\nAPI generation error: {api_error}")


Generating Positive Samples...


Generating batches:  83%|████████▎ | 5/6 [01:27<00:19, 19.30s/it]

Attempt 1 failed: 503 UNAVAILABLE. {'error': {'code': 503, 'message': 'The model is overloaded. Please try again later.', 'status': 'UNAVAILABLE'}}
Retrying in 10s …


Generating batches: 100%|██████████| 6/6 [01:54<00:00, 19.08s/it]



Generating Negative Samples...


Generating batches: 100%|██████████| 6/6 [01:08<00:00, 11.44s/it]


### Processing and Validating

In [9]:
# Load raw data from disk
raw_data = load_jsonl("data/raw/raw_positive.jsonl") + load_jsonl("data/raw/raw_negative.jsonl")

# Shuffle with a dedicated, seeded generator. An unseeded shuffle changes the
# sample order on every run, so the split made later would differ between runs
# even though train_test_split itself uses a fixed random_state.
SHUFFLE_SEED = 42
random.Random(SHUFFLE_SEED).shuffle(raw_data)

# DEDUPLICATION
print(f"Original Raw Sample Count: {len(raw_data)}")

unique_data = []
processed_texts = set()

for sample in raw_data:
    # Skip records that cannot be tagged at all (wrong type, missing or empty text)
    if not isinstance(sample, dict):
        continue
    text = sample.get('text')
    if not isinstance(text, str) or not text.strip():
        continue

    # Compare on a normalized key so that sentences differing only in letter
    # case or whitespace count as duplicates. The stored sample is unchanged.
    dedup_key = " ".join(text.split()).casefold()
    if dedup_key not in processed_texts:
        processed_texts.add(dedup_key)
        unique_data.append(sample)

raw_data = unique_data # Overwrite the raw_data list with unique samples

print(f"Unique Raw Sample Count: {len(raw_data)}")

processed_dataset = []
for entry in tqdm(raw_data, desc="Applying BIO Tags"):
    try:
        tagged_entry = assign_bio_tags(entry)
        # TODO: validate the BIO tags (e.g. with a validate_bio_tags helper) and
        # skip entries that fail, printing the offending text.
        processed_dataset.append(tagged_entry)
    except Exception as e:
        print(f"Tagging failed for an entry: {e}")
        continue

print(f"Tagging complete. Total samples: {len(processed_dataset)}")

Original Raw Sample Count: 240
Unique Raw Sample Count: 236


Applying BIO Tags: 100%|██████████| 236/236 [00:00<00:00, 78622.38it/s]

Tagging complete. Total samples: 236





### Split and Export

In [12]:
# 3. SPLIT AND FINAL EXPORT
# First cut keeps 60% for training; the held-out 40% is then halved into
# validation and test.
train_data, remain_data = train_test_split(
    processed_dataset,
    test_size=0.4,
    random_state=42,
)
val_data, test_data = train_test_split(
    remain_data,
    test_size=0.5,
    random_state=42,
)

# Write each split to its own JSONL file
for split_rows, split_path in (
    (train_data, "data/final/train.jsonl"),
    (val_data, "data/final/validation.jsonl"),
    (test_data, "data/final/test.jsonl"),
):
    save_jsonl(split_rows, split_path)