# I. **Setup & Configuration**

This cell performs the complete environment setup for the Word2Vec GPU-based implementation. The setup process consists of four main steps:

## **Setup Steps**

### **Step 1: Numba-CUDA Installation Check**
- Checks whether `numba-cuda` can be imported in the current environment (version 0.4.0 is the recommended install; the version itself is not enforced)
- If not installed, provides instructions for manual installation using `uv pip install`
- This is a prerequisite for GPU-accelerated training

### **Step 2: Numba-CUDA Configuration**
- Configures numba-cuda settings for the Google Colab environment
- Enables `CUDA_ENABLE_PYNVJITLINK` for CUDA 12.x compatibility
- Disables low occupancy warnings for cleaner output
- Performs a CUDA kernel test to verify functionality

### **Step 3: Required Package Installation**
Automatically installs all necessary Python packages for the Word2Vec implementation:

- **`pynvjitlink-cu12`**: CUDA 12.x JIT linker support for numba-cuda
- **`numpy>=1.20.0`**: Numerical computing library
- **`gensim>=4.0.0`**: Word embedding evaluation and comparison tools
- **`scikit-learn>=1.0.0`**: Machine learning utilities for evaluation
- **`matplotlib>=3.5.0`**: Plotting and visualization library
- **`seaborn>=0.11.0`**: Statistical data visualization
- **`tqdm>=4.60.0`**: Progress bars for long-running operations
- **`requests>=2.25.0`**: HTTP library for downloading datasets
- **`pynvml>=11.0.0`**: NVIDIA Management Library for GPU monitoring

The installation process uses `uv pip` (if available) or falls back to standard `pip`, with automatic error handling for each package.

### **Step 4: GPU and CUDA Availability Check**
- Verifies NVIDIA GPU presence using `nvidia-smi`
- Checks CUDA availability through numba-cuda
- Displays GPU device information and memory capacity
- Provides clear status indicators for all checks

## **Output**
After execution, a summary report displays the status of all setup steps:
- ✅ Green checkmarks indicate successful completion
- ❌ Red X marks indicate failures or missing components
- Warnings are provided for non-critical issues

In [None]:
import os
import subprocess
import sys

def install_package_with_uv(package: str, quiet: bool = True) -> bool:
    """
    Install a single package with ``uv pip``, falling back to ``pip`` if uv is missing.

    Args:
        package: Requirement specifier passed verbatim to the installer
                 (e.g. "numpy>=1.20.0").
        quiet: If True, pass ``-q`` and capture uv's output; the captured output
               is printed only when the install fails. If False, installer output
               streams straight to the console and a success message is printed.

    Returns:
        True if the installer exited successfully, False otherwise.
    """
    cmd = ["uv", "pip", "install"]
    if quiet:
        cmd.append("-q")
    # --system targets the interpreter's own environment rather than a venv
    cmd.extend(["--system", package])

    try:
        subprocess.run(cmd, capture_output=quiet, text=True, check=True)
    except subprocess.CalledProcessError as e:
        print(f"❌ Failed to install {package}: {e}")
        # Output is only captured (non-None) in quiet mode. In verbose mode it was
        # already streamed to the console, so there is nothing extra to show.
        if e.stdout:
            print(f"stdout: {e.stdout}")
        if e.stderr:
            print(f"stderr: {e.stderr}")
        return False
    except FileNotFoundError:
        # The `uv` executable itself is missing -> fall back to the current
        # interpreter's pip.
        print(f"⚠️ uv not found, trying regular pip for {package}...")
        pip_cmd = [sys.executable, "-m", "pip", "install"]
        if quiet:
            pip_cmd.append("-q")
        pip_cmd.append(package)
        try:
            subprocess.check_call(pip_cmd)
        except (subprocess.CalledProcessError, OSError) as e2:
            print(f"❌ Failed to install {package} with pip: {e2}")
            return False
        if not quiet:
            print(f"✅ {package} installed successfully (via pip)")
        return True

    if not quiet:
        print(f"✅ {package} installed successfully")
    return True

def check_numba_cuda_installed():
    """
    STEP 1: report whether numba and its CUDA submodule can be imported.

    Only importability is checked. The numba version (and the ``numba_cuda``
    package version, when that module is importable) is printed for information
    but is not compared against the recommended 0.4.0.

    Returns:
        True if ``from numba import cuda`` succeeds, False on ImportError.
    """
    print("\n" + "=" * 60)
    print("STEP 1: Checking numba-cuda installation")
    print("=" * 60)
    print("Checking if numba-cuda is installed...")

    try:
        import numba
        from numba import cuda  # noqa: F401  (the import itself is the check)
    except ImportError:
        print("❌ numba-cuda is NOT installed")
        print("⚠️ Please install manually first:")
        print("!uv pip install -q --system numba-cuda==0.4.0")
        return False

    print("✅ numba-cuda is already installed")
    print(f"  numba version: {getattr(numba, '__version__', 'unknown')}")

    # Optional extra detail: the standalone numba-cuda package version.
    try:
        import numba_cuda
    except ImportError:
        pass
    else:
        print(f"  numba-cuda version: {getattr(numba_cuda, '__version__', 'unknown')}")

    return True

def setup_numba_cuda_config():
    """
    STEP 2: apply numba-cuda config overrides, then smoke-test a CUDA kernel.

    Sets ``CUDA_ENABLE_PYNVJITLINK = 1`` and ``CUDA_LOW_OCCUPANCY_WARNINGS = 0``
    on ``numba.config`` (following googlecolab/colabtools issue #5081). Then
    launches a tiny kernel that adds 1 to every element of a 10-element float32
    array and checks the result.

    Returns:
        True only if the config was applied, CUDA is available and the kernel
        produced the expected array; False otherwise (the reason is printed).
    """
    rule = "=" * 60
    header_lines = (
        "\n" + rule,
        "STEP 2: Configuring numba-cuda",
        rule,
        "Setting up numba-cuda (Official Solution)",
        "Based on: https://github.com/googlecolab/colabtools/issues/5081",
        "",
        "Configuring numba-cuda...",
    )
    for text in header_lines:
        print(text)

    overrides = (("CUDA_ENABLE_PYNVJITLINK", 1), ("CUDA_LOW_OCCUPANCY_WARNINGS", 0))
    try:
        from numba import config
        for option_name, option_value in overrides:
            setattr(config, option_name, option_value)
    except ImportError:
        print("❌ numba not installed - cannot configure")
        return False
    except Exception as e:
        print(f"❌ Failed to configure numba-cuda: {e}")
        return False

    print("✅ numba-cuda configuration set")
    for option_name, option_value in overrides:
        print(f"  - {option_name} = {option_value}")

    print("\nTesting CUDA functionality...")
    try:
        from numba import cuda
        import numpy as np

        if not cuda.is_available():
            print("❌ CUDA not available")
            return False

        print(f"CUDA available: {cuda.get_current_device().name}")

        @cuda.jit
        def add_one_kernel(values):
            idx = cuda.grid(1)
            # More threads are launched than elements, so guard the index.
            if idx < values.size:
                values[idx] += 1

        host_values = np.zeros(10, dtype=np.float32)
        add_one_kernel[16, 16](host_values)

        passed = np.allclose(host_values, np.ones(10, dtype=np.float32))
        print("✅ CUDA kernel test passed!" if passed else "❌ CUDA kernel test failed")
        return bool(passed)

    except Exception as e:
        print(f"❌ CUDA test failed: {e}")
        return False

def install_all_requirements(packages=None):
    """
    STEP 3: install the Python packages the Word2Vec notebook needs, one by one.

    Args:
        packages: Optional list of requirement specifiers to install. Defaults to
                  the notebook's standard requirement list (pynvjitlink-cu12,
                  numpy, gensim, scikit-learn, matplotlib, seaborn, tqdm,
                  requests, pynvml).

    Returns:
        True if every package installed successfully (vacuously True for an
        empty list), False if any install failed.
    """
    print("\n" + "=" * 60)
    print("STEP 3: Installing all required packages")
    print("=" * 60)
    print("Installing required packages for Google Colab...")

    if packages is None:
        packages = [
            "pynvjitlink-cu12",
            "numpy>=1.20.0",
            "gensim>=4.0.0",
            "scikit-learn>=1.0.0",
            "matplotlib>=3.5.0",
            "seaborn>=0.11.0",
            "tqdm>=4.60.0",
            "requests>=2.25.0",
            "pynvml>=11.0.0",
        ]

    success_count = 0
    failed_packages = []

    for package in packages:
        # Status mark is printed on the same line as the package name
        print(f"Installing {package}...", end=" ", flush=True)
        if install_package_with_uv(package, quiet=True):
            print("✅")
            success_count += 1
        else:
            print("❌")
            failed_packages.append(package)

    print(f"\nInstalled {success_count}/{len(packages)} packages successfully")

    if failed_packages:
        print(f"⚠️  Failed packages: {', '.join(failed_packages)}")
        return False

    return True

def check_gpu(timeout: float = 30.0):
    """
    Check for an NVIDIA GPU by running ``nvidia-smi`` and echoing its output.

    Args:
        timeout: Seconds to wait for ``nvidia-smi`` before giving up. It can hang
                 on a wedged driver, and this check must not block setup forever.

    Returns:
        True if ``nvidia-smi`` ran and exited with status 0, False otherwise.
    """
    print("\n" + "=" * 60)
    print("Checking GPU availability...")
    print("=" * 60)

    try:
        result = subprocess.run(["nvidia-smi"], capture_output=True, text=True, timeout=timeout)
    except FileNotFoundError:
        print("❌ nvidia-smi not found")
        return False
    except (subprocess.TimeoutExpired, OSError) as e:
        # e.g. permission denied, or the tool hung past `timeout`
        print(f"❌ nvidia-smi could not be run: {e}")
        return False

    if result.returncode == 0:
        print("✅ NVIDIA GPU detected:")
        print(result.stdout)
        return True

    print("❌ No NVIDIA GPU detected")
    return False

def check_cuda():
    """
    Check CUDA availability through numba and report the device name and memory.

    Total memory is read via NVML (pynvml) for device index 0. Any NVML failure
    (including pynvml not being installed) is reported but does not fail the check.
    NOTE(review): this assumes NVML index 0 is the same GPU numba selected;
    confirm on multi-GPU hosts.

    Returns:
        True if ``numba.cuda.is_available()``; False if it is not or numba is
        missing.
    """
    print("\n" + "=" * 60)
    print("Checking CUDA availability...")
    print("=" * 60)

    try:
        from numba import cuda
        if cuda.is_available():
            device = cuda.get_current_device()
            print(f"✅ CUDA available: {device.name}")

            try:
                import pynvml
                pynvml.nvmlInit()
                try:
                    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
                    memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
                finally:
                    # Release the NVML session opened by nvmlInit()
                    pynvml.nvmlShutdown()
                total_memory = memory_info.total / 1024**3
                print(f"  Memory: {total_memory:.1f} GB")
            except Exception as e:
                # Best effort: memory info is informational only
                print(f"  Device: {device.name}")
                print(f"  (Memory info unavailable: {e})")

            return True
        else:
            print("❌ CUDA not available")
            return False

    except ImportError:
        print("❌ Numba not installed")
        return False

def main():
    """
    Run all setup steps in order and print a status summary.

    A failed step only prints a warning; the remaining steps still run, so the
    summary always reports all five checks.

    Returns:
        0 unconditionally; the outcome is reported only through printed output.
    """
    print("=" * 60)
    print("  Word2Vec Implementation - Complete Google Colab Setup")
    print("=" * 60)
    print("\nThis script combines all setup steps:")
    print("  1. Check numba-cuda installation")
    print("  2. Configure numba-cuda")
    print("  3. Install all required packages")
    print("  4. Check GPU and CUDA availability")

    results = {
        "numba_cuda_installed": False,
        "numba_cuda_configured": False,
        "requirements_installed": False,
        "gpu_available": False,
        "cuda_available": False
    }

    results["numba_cuda_installed"] = check_numba_cuda_installed()

    if not results["numba_cuda_installed"]:
        print("\n⚠️  Warning: numba-cuda is not installed. Please install it first:")
        print("   !uv pip install -q --system numba-cuda==0.4.0")
        print("   Continuing with other setup steps...")

    results["numba_cuda_configured"] = setup_numba_cuda_config()

    if not results["numba_cuda_configured"]:
        print("\n⚠️  Warning: Failed to configure numba-cuda. Continuing anyway...")

    results["requirements_installed"] = install_all_requirements()

    results["gpu_available"] = check_gpu()

    results["cuda_available"] = check_cuda()

    summary_rows = (
        ("numba_cuda_installed", "numba-cuda installed"),
        ("numba_cuda_configured", "numba-cuda configured"),
        ("requirements_installed", "Requirements installed"),
        ("gpu_available", "GPU available"),
        ("cuda_available", "CUDA available"),
    )
    print("\n" + "=" * 60)
    print("  SETUP SUMMARY")
    print("=" * 60)
    for key, label in summary_rows:
        # The trailing mark is the only status indicator. The previous fixed
        # leading "✅" made failed steps look successful at a glance.
        print(f"  {label}: {'✅' if results[key] else '❌'}")
    print("=" * 60)

    if not results['requirements_installed']:
        print("\n⚠️  Some packages failed to install - see the STEP 3 output above")

    if results['gpu_available'] and results['cuda_available']:
        print("\nSetup complete! Ready to run Word2Vec training")
        print("\nTo run the full pipeline:")
        print("  !python run_all.py")

    elif results['numba_cuda_installed'] and results['numba_cuda_configured']:
        print("\nSetup completed successfully")
        print("⚠️ Note: GPU/CUDA may not be available, but CPU training is still possible")
        print("\nTo run the full pipeline:")
        print("  !python run_all.py")

    else:
        print("\nSetup completed with some warnings")
        print("Some features may not work correctly")
        print("\nTo run anyway:")
        print("  !python run_all.py")

    return 0


# Run the setup when executed as a script (or directly in a notebook cell).
if __name__ == "__main__":
    main()



# II. **Common Utilities**

This cell contains the core utility functions for the Word2Vec implementation: vocabulary processing, data loading, weight initialization, and the support structures for both Hierarchical Softmax and Negative Sampling.

In [None]:
import json
import math
import os
import pathlib
import re
import time
import hashlib
import pickle
from collections import defaultdict
from typing import List, Tuple, Dict, Any, Optional

from numba import cuda
import numpy as np
from numpy import linalg, ndarray


W2V_VERSION = "1.0"  # version tag for this implementation; its consumers are not visible in this cell
BLANK_TOKEN = "<BLANK>"  # reserved vocab entry placed at index 0 by sort_vocab; its weight rows are zeroed

# Constants for Hierarchical Softmax and Exp Table
EXP_TABLE_SIZE = 1000  # number of precomputed sigmoid entries (see create_exp_table)
MAX_EXP = 6  # the sigmoid table covers inputs x in [-MAX_EXP, MAX_EXP)
MAX_CODE_LENGTH = 40  # maximum Huffman code / path length per word (see create_huffman_tree)


def build_vocab(data_path: str) -> List[Tuple[str, int, int]]:
    """
    Count word occurrences across the data files in ``data_path``.

    Only files whose names start with "0" are read. Each line is tokenized the
    same way ``read_all_data_files_ever`` does it: split on runs of spaces and
    periods, empty tokens dropped. Lines with fewer than two tokens are skipped.
    This makes the vocabulary describe exactly the tokens training will see.
    Previously a plain space split kept tokens such as "end.", which training
    never produces.

    Returns:
        Unordered list of (word, total_count, sentence_count). total_count is the
        number of occurrences; sentence_count is the number of lines containing
        the word at least once.
    """
    token_splitter = re.compile(r"[ .]+")
    files = [fn for fn in os.listdir(data_path) if fn.startswith("0")]
    sentences_per_word = defaultdict(int)
    totals_per_word = defaultdict(int)

    for file in files:
        with open(os.path.join(data_path, file), encoding="utf-8") as f:
            for line in f:
                words = [w for w in token_splitter.split(line.strip()) if w]
                if len(words) < 2:
                    continue
                for word in words:
                    totals_per_word[word] += 1
                # Count each word at most once per sentence
                for word in set(words):
                    sentences_per_word[word] += 1

    return [(word, total, sentences_per_word[word]) for word, total in totals_per_word.items()]


def sort_vocab(my_vocab: List[Tuple[str, int, int]]) -> List[Tuple[str, int, int]]:
    """
    Order the vocabulary for index assignment.

    The reserved ``BLANK_TOKEN`` entry (with counts 0, 0) goes first so it gets
    index 0. The remaining entries follow by descending total count, with ties
    broken alphabetically by word.
    """
    ranked = sorted(my_vocab, key=lambda entry: (-entry[1], entry[0]))
    ranked.insert(0, (BLANK_TOKEN, 0, 0))
    return ranked


def prune_vocab(min_occrs: int, my_vocab: List[Tuple[str, int, int]]) -> List[Tuple[str, int]]:
    """
    Drop rare words and strip the sentence counts.

    When ``min_occrs > 1``, a word is kept only if it appears in at least
    ``min_occrs`` sentences; ``BLANK_TOKEN`` is always kept. Otherwise every
    entry is kept.

    Returns:
        List of (word, total_count) in the input order.
    """
    if min_occrs <= 1:
        return [(word, total) for word, total, _ in my_vocab]

    kept = []
    for word, total, n_sentences in my_vocab:
        if n_sentences >= min_occrs or word == BLANK_TOKEN:
            kept.append((word, total))
    return kept


def bias_freq_counts(vocab: List[Tuple[str, int]], exponent: float) -> List[Tuple[str, float]]:
    """
    Turn raw counts into a probability distribution, optionally smoothed.

    Each count is divided by the total. If ``exponent`` is not 1.0, those
    probabilities are raised to ``exponent`` and renormalized to sum to 1.

    Returns:
        List of (word, probability) in the input order.
    """
    words = [word for word, _ in vocab]
    counts = [count for _, count in vocab]
    total_count = sum(counts)
    probs = [c / total_count for c in counts]

    if exponent != 1.0:
        probs = [math.pow(p, exponent) for p in probs]
        norm = sum(probs)
        probs = [p / norm for p in probs]

    return list(zip(words, probs))


def _get_vocab_cache_key(data_path: str, min_occurs_by_sentence: int, freq_exponent: float) -> str:
    """
    Generate cache key based on vocabulary parameters
    """
    key_string = f"{data_path}_{min_occurs_by_sentence}_{freq_exponent}"
    return hashlib.md5(key_string.encode()).hexdigest()


def _get_vocab_cache_path(cache_key: str) -> str:
    """
    Get path to vocabulary cache file
    """
    cache_dir = "./output/vocab_cache"
    os.makedirs(cache_dir, exist_ok=True)
    return os.path.join(cache_dir, f"vocab_{cache_key}.pkl")


def _save_vocab_cache(vocab: List[Tuple[str, float]], w_to_i: Dict[str, int], word_counts: List[int], cache_path: str):
    """
    Save vocabulary to cache file
    """
    cache_data = {
        'vocab': vocab,
        'w_to_i': w_to_i,
        'word_counts': word_counts
    }
    with open(cache_path, 'wb') as f:
        pickle.dump(cache_data, f)


def _load_vocab_cache(cache_path: str) -> Optional[Tuple[List[Tuple[str, float]], Dict[str, int], List[int]]]:
    """
    Load vocabulary from cache file -> Returns None if cache doesn't exist or is invalid
    """
    try:
        if not os.path.exists(cache_path):
            return None
        with open(cache_path, 'rb') as f:
            cache_data = pickle.load(f)
        return (cache_data['vocab'], cache_data['w_to_i'], cache_data['word_counts'])
    except Exception:
        return None


def handle_vocab(data_path: str, min_occurs_by_sentence: int, freq_exponent: float, use_cache: bool = True):
    """
    Complete vocabulary handling pipeline with optional caching -> Returns: (biased_vocab, w_to_i, word_counts)
    - biased_vocab: List of (word, probability) from bias_freq_counts (used for negative sampling)
    - w_to_i: Dictionary mapping word to index (index 0 is BLANK_TOKEN)
    - word_counts: List of raw word counts in index order (for Huffman tree construction)

    Args:
        use_cache: If True, try to load from cache first, and save to cache after building.
                   The cache key is derived from data_path, min_occurs_by_sentence and freq_exponent,
                   so changing epochs or embed_dim does not invalidate the cache.
    """
    # Resolve the cache location once. It is used both for the lookup and for
    # the save below, so both refer to the same key.
    cache_path = None
    if use_cache:
        cache_key = _get_vocab_cache_key(data_path, min_occurs_by_sentence, freq_exponent)
        cache_path = _get_vocab_cache_path(cache_key)
        cached_vocab = _load_vocab_cache(cache_path)
        if cached_vocab is not None:
            return cached_vocab

    # Build vocabulary: count -> sort (BLANK first) -> prune rare words
    vocab: List[Tuple[str, int, int]] = build_vocab(data_path)
    sorted_vocab: List[Tuple[str, int, int]] = sort_vocab(vocab)
    pruned_vocab: List[Tuple[str, int]] = prune_vocab(min_occurs_by_sentence, sorted_vocab)

    # Keep raw counts before biasing; the Huffman tree needs counts, not probabilities
    word_counts = [count for _, count in pruned_vocab]
    biased_vocab: List[Tuple[str, float]] = bias_freq_counts(pruned_vocab, freq_exponent)
    w_to_i: Dict[str, int] = {word: idx for idx, (word, _) in enumerate(biased_vocab)}

    if cache_path is not None:
        _save_vocab_cache(biased_vocab, w_to_i, word_counts, cache_path)

    return biased_vocab, w_to_i, word_counts


def get_subsampling_weights_and_negative_sampling_array(vocab: List[Tuple[str, float]], t: float,
                                                       arr_len: Optional[int] = None) -> Tuple[ndarray, ndarray]:
    """
    Calculate subsampling weights and create the negative sampling array.

    Subsampling: for each word, freq = weight / sum(weights) and the value is
    max(0, 1 - sqrt(t / freq)); words with freq 0 get 0. The frequencies come
    from ``vocab`` as given. NOTE(review): if the caller passes handle_vocab's
    biased vocab, subsampling uses the biased (not raw) frequencies. Confirm
    that this is intended.

    Negative sampling array: word index i appears round(freq_i * arr_len) times
    (round half to even). By default arr_len scales with vocabulary size:
    - For small vocabs (< 10k): 1M
    - For medium vocabs (10k-100k): 10M
    - For large vocabs (>= 100k): 100M (same as word2vec.c)

    Args:
        vocab: List of (word, weight) pairs in index order.
        t: Subsampling threshold.
        arr_len: Optional explicit negative-sampling table length, overriding
                 the size-based default.

    Returns:
        (probs, neg_arr): float32 array of length len(vocab) and int32 array of
        word indices.
    """
    # Subsampling weights
    tot_wgt = sum(c for _, c in vocab)
    freqs: List[float] = [c / tot_wgt for _, c in vocab]

    # Clamp negative probabilities to zero
    probs: List[float] = [max(0.0, 1 - math.sqrt(t / freq)) if freq > 0 else 0.0 for freq in freqs]

    vocab_size = len(vocab)

    if arr_len is None:
        if vocab_size < 10000:
            arr_len = 1000000  # 1M for small vocabs
        elif vocab_size < 100000:
            arr_len = 10000000  # 10M for medium vocabs
        else:
            arr_len = 100000000  # 100M for large vocabs (same as word2vec.c)

    print(f"Creating negative sampling array with size {arr_len:,} for vocab size {vocab_size:,}")

    freq_arr = np.asarray(freqs, dtype=np.float64)
    # np.rint rounds half to even, exactly like Python's round() on floats
    slots = np.rint(freq_arr * arr_len).astype(np.int64)

    # Only real words (freq > 0) that round to 0 slots count as excluded. Zero-
    # frequency entries such as BLANK_TOKEN are expected to be absent and
    # previously triggered this warning on every call.
    excluded_count = int(np.count_nonzero((freq_arr > 0) & (slots == 0)))
    if excluded_count > 0:
        print(f"⚠️ WARNING: {excluded_count} words have frequency too low and will be excluded from negative sampling")
        print("⚠️ Consider increasing arr_len or reducing min_occurs threshold")

    # Repeat each index by its slot count in C instead of growing a Python list
    # (which for 100M entries costs several GB and a long build time).
    neg_arr = np.repeat(np.arange(vocab_size, dtype=np.int32), np.clip(slots, 0, None))

    actual_arr_size = len(neg_arr)
    print(f"Negative sampling array created: {actual_arr_size:,} entries ({actual_arr_size/1e6:.2f}M)")

    return np.asarray(probs, dtype=np.float32), neg_arr.astype(np.int32, copy=False)


def get_data_file_names(path: str, seed: int) -> List[str]:
    """
    List the data files in ``path`` (names starting with "0") in a seeded random order.

    Names are sorted before shuffling, so the result depends only on the set of
    files and ``seed``, not on the order ``os.listdir`` happens to return.
    """
    candidates = sorted(name for name in os.listdir(path) if name.startswith("0"))
    np.random.default_rng(seed=seed).shuffle(candidates)
    return candidates


def read_all_data_files_ever(dat_path: str, file_names: List[str], w_to_i: Dict[str, int],
                             max_words: Optional[int] = None) -> Tuple[List[int], List[int], List[int]]:
    """
    Read data files and convert every line (sentence) to vocabulary indices.

    Lines are split on runs of spaces and periods. Lines with fewer than two
    tokens are skipped. Tokens missing from ``w_to_i`` are dropped, so a kept
    sentence may end up with fewer than two indices (even zero).

    Args:
        dat_path: Path to data directory
        file_names: List of file names to read, in order
        w_to_i: Word to index mapping
        max_words: Maximum number of indices to collect (None = all). Reading stops
                   once the limit is reached; the sentence that crosses the limit
                   is truncated.

    Returns:
        Tuple of (inps, offs, lens) where:
        - inps: Flat list of word indices for all sentences
        - offs: Start offset of each sentence in inps
        - lens: Length of each sentence
    """
    start = time.time()
    token_splitter = re.compile(r"[ .]+")  # compiled once, reused for every line
    inps, offs, lens = [], [], []
    offset_total = 0
    stats = defaultdict(int)
    total_words_read = 0
    stopped_early = False

    for fn in file_names:
        fp = os.path.join(dat_path, fn)
        ok_lines = 0
        too_short_lines = 0
        with open(fp, encoding="utf-8") as f:
            for line in f:
                # Stop before reading further if the limit is already reached
                if max_words is not None and total_words_read >= max_words:
                    stopped_early = True
                    break

                words = [word for word in token_splitter.split(line.strip()) if word]
                if len(words) < 2:
                    too_short_lines += 1
                    continue

                idcs = [w_to_i[w] for w in words if w in w_to_i]
                if max_words is not None:
                    # The guard above guarantees at least one slot is left, so this
                    # only shortens the sentence that crosses the limit. Otherwise
                    # the slice is a no-op.
                    idcs = idcs[:max_words - total_words_read]
                le = len(idcs)

                ok_lines += 1
                offs.append(offset_total)
                lens.append(le)
                inps.extend(idcs)
                offset_total += le
                total_words_read += le

                # Break if we've reached the limit exactly
                if max_words is not None and total_words_read >= max_words:
                    stopped_early = True
                    break

        stats["file_read_lines_ok"] += ok_lines
        stats["one_word_sentence_lines_which_were_ignored"] += too_short_lines

        # Break outer loop if we've reached the limit
        if stopped_early:
            break

    print(f"read_all_data_files_ever() STATS: {stats}")
    if max_words is not None and stopped_early:
        print(f"⚠️ Stopped early: reached max_words limit of {max_words:,} words")
    tot_tm = time.time() - start
    # Guard the per-file average against an empty file list (was ZeroDivisionError)
    avg_tm = tot_tm / len(file_names) if file_names else 0.0
    print(f"read_all_data_files_ever() Total time {tot_tm} s for {len(file_names)} files (avg {avg_tm} s/file)")
    return inps, offs, lens


def init_weight_matrices(vocab_size: int, embed_dim: int, seed: int) -> Tuple[ndarray, ndarray]:
    """
    Create the two (vocab_size, embed_dim) float32 weight matrices.

    Entries are drawn from N(0, 1/embed_dim) using a generator seeded with
    ``seed``; the first matrix is drawn before the second. Row 0 of both is
    zeroed because index 0 is the blank token.
    """
    generator = np.random.default_rng(seed=seed)
    scale = math.sqrt(1.0 / embed_dim)
    matrices = []
    for _ in range(2):
        m = scale * generator.standard_normal(size=(vocab_size, embed_dim), dtype=np.float32)
        m[0, :] = 0.0
        matrices.append(m)
    return matrices[0], matrices[1]


def print_norms(weights_cuda):
    """
    Print the 2.5th percentile, median, mean and 97.5th percentile of per-row L2 norms.

    Args:
        weights_cuda: Device array, i.e. anything whose ``copy_to_host()`` returns
                      a 2-D host array. Each row is treated as one vector.
    """
    w = weights_cuda.copy_to_host()
    if len(w) == 0:
        # Percentiles and the mean are undefined for an empty matrix
        print("Vector norms (count 0): no vectors")
        return
    # One vectorized call instead of a Python-level loop over every row
    norms = linalg.norm(w, axis=1)
    a, med, b = np.percentile(norms, [2.5, 50, 97.5])
    # Accumulate in float64 so large float32 matrices don't lose precision
    avg = float(norms.mean(dtype=np.float64))
    print(f"Vector norms (count {len(norms)}) 2.5% median mean 97.5%: {a:0.4f}  {med:0.4f}  {avg:0.4f}  {b:0.4f}")


def write_vectors(weights_cuda, vocab: List[Tuple[str, float]], out_path: str):
    """
    Write the embeddings as a word2vec text file.

    The weights are first copied to host with ``copy_to_host()``. The header is
    "<rows - 1> <dim>". Each following line is "<word> <v1> <v2> ..." for rows
    1..N-1, labelled with ``vocab[i]``'s word. Row 0 (the all-zero blank token)
    is skipped. Missing parent directories of ``out_path`` are created.
    """
    host = weights_cuda.copy_to_host()
    pathlib.Path(os.path.dirname(out_path)).mkdir(parents=True, exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as out:
        out.write(f"{len(host)-1} {len(host[0])}\n")
        for idx in range(1, len(host)):
            values = " ".join(map(str, host[idx]))
            word, _ = vocab[idx]
            out.write(word + " " + values + "\n")


def write_json(to_jsonify: Dict[str, Any], json_path: str):
    """
    Overwrite ``json_path`` with ``to_jsonify`` serialized as one JSON line plus a newline.
    """
    with open(json_path, "w", encoding="utf-8") as fh:
        # Closing the file at the end of the with-block flushes it
        fh.write(json.dumps(to_jsonify) + "\n")


def create_exp_table(exp_table_size: int = EXP_TABLE_SIZE, max_exp: float = MAX_EXP) -> ndarray:
    """
    Precompute a sigmoid lookup table, as in word2vec.c from the Word2Vec paper's source code.

    Entry i holds sigmoid(x_i) = exp(x_i) / (exp(x_i) + 1), where
    x_i = (i / exp_table_size * 2 - 1) * max_exp. The table therefore samples
    x uniformly over [-max_exp, max_exp).

    Args:
        exp_table_size: Number of entries (default: EXP_TABLE_SIZE).
        max_exp: Half-width of the covered input range (default: MAX_EXP).

    Returns:
        float32 array of length ``exp_table_size``.
    """
    def _sigmoid_at(i):
        # Computed in double precision, then stored as float32 by np.array below
        e = math.exp((i / exp_table_size * 2 - 1) * max_exp)
        return e / (e + 1)

    return np.array([_sigmoid_at(i) for i in range(exp_table_size)], dtype=np.float32)


def init_hs_weight_matrix(vocab_size: int, embed_dim: int) -> ndarray:
    """
    Allocate the Hierarchical Softmax weight matrix (syn1), as in word2vec.c.

    A binary tree over ``vocab_size`` leaves has ``vocab_size - 1`` internal
    nodes, so there is one zero-initialized float32 row of length ``embed_dim``
    per internal node.
    """
    internal_nodes = vocab_size - 1
    return np.zeros(shape=(internal_nodes, embed_dim), dtype=np.float32)


def create_huffman_tree(word_counts: List[int], max_code_length: Optional[int] = None) -> Tuple[ndarray, ndarray, ndarray]:
    """
    Create a binary Huffman tree from word counts (port of word2vec.c CreateBinaryTree).

    Frequent words get short binary codes. As in word2vec.c, the two-pointer
    merge treats the end of ``word_counts`` as the smallest counts, i.e. it
    expects counts in descending order.

    Args:
        word_counts: Count for each word, in vocabulary index order.
        max_code_length: Maximum code length (None -> MAX_CODE_LENGTH, i.e. 40).

    Returns:
        Tuple of (codes_array, points_array, code_lengths):
        - codes_array: (vocab_size, max_code_length) binary codes from root to leaf, padded with -1
        - points_array: (vocab_size, max_code_length) internal-node indices (rows of the
          (vocab_size - 1)-row syn1 matrix) along the root-to-leaf path, padded with -1.
          Entry 0 is always the root, vocab_size - 2.
        - code_lengths: (vocab_size,) code length for each word

    Raises:
        ValueError: if some word's path is longer than ``max_code_length``.
    """
    if max_code_length is None:
        max_code_length = MAX_CODE_LENGTH

    vocab_size = len(word_counts)
    root = vocab_size * 2 - 2

    # Nodes 0..vocab_size-1 are leaves; vocab_size.. are internal nodes, in creation order
    count = np.zeros(vocab_size * 2 + 1, dtype=np.int64)
    binary = np.zeros(vocab_size * 2 + 1, dtype=np.int32)
    parent_node = np.zeros(vocab_size * 2 + 1, dtype=np.int64)

    count[:vocab_size] = word_counts
    count[vocab_size:vocab_size * 2] = int(1e15)  # "not yet created" sentinel for internal nodes

    # pos1 walks leaves from the smallest count upward; pos2 walks internal nodes
    pos1 = vocab_size - 1
    pos2 = vocab_size

    for a in range(vocab_size - 1):
        # Find the two smallest nodes
        if pos1 >= 0:
            if count[pos1] < count[pos2]:
                min1i = pos1
                pos1 -= 1
            else:
                min1i = pos2
                pos2 += 1
        else:
            min1i = pos2
            pos2 += 1

        if pos1 >= 0:
            if count[pos1] < count[pos2]:
                min2i = pos1
                pos1 -= 1
            else:
                min2i = pos2
                pos2 += 1
        else:
            min2i = pos2
            pos2 += 1

        count[vocab_size + a] = count[min1i] + count[min2i]
        parent_node[min1i] = vocab_size + a
        parent_node[min2i] = vocab_size + a
        binary[min2i] = 1

    codes_array = np.full((vocab_size, max_code_length), -1, dtype=np.int32)
    points_array = np.full((vocab_size, max_code_length), -1, dtype=np.int32)
    code_lengths = np.zeros(vocab_size, dtype=np.int32)

    for a in range(vocab_size):
        b = a
        i = 0
        code = np.zeros(max_code_length, dtype=np.int32)
        point = np.zeros(max_code_length, dtype=np.int64)

        # Walk from the leaf up to the root. point[0] is the leaf itself and
        # point[i-1] is the child of the root.
        while True:
            code[i] = binary[b]
            point[i] = b
            i += 1
            b = parent_node[b]
            if b == root:
                break
            if i >= max_code_length:
                raise ValueError(
                    f"Huffman code for word {a} exceeds max_code_length={max_code_length}")

        code_lengths[a] = i
        # Store root-to-leaf order
        points_array[a, 0] = vocab_size - 2  # root's internal-node index
        for b_idx in range(i):
            codes_array[a, i - b_idx - 1] = code[b_idx]
            # Skip b_idx == 0: that is the leaf, which is not an internal node (its
            # slot i lies past the code length). All other path nodes land in
            # slots 1..i-1. The previous condition (b_idx < i - 1) skipped slot 1
            # and wrote the leaf instead, going out of bounds when i == max_code_length.
            if b_idx > 0:
                points_array[a, i - b_idx] = int(point[b_idx] - vocab_size)

    return codes_array, points_array, code_lengths


# III. **Data Handler**

This cell contains functions for downloading, preprocessing, and loading training datasets. It handles text cleaning, phrase detection, and data format conversion for Word2Vec training.

## **Dataset Download**

- **`download_wmt14_news()`**: Downloads and combines the WMT14 (News Crawl 2012) and WMT15 (News Crawl 2014) English datasets: [news.2012.en.shuffled.gz](https://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2012.en.shuffled.gz) and [news.2014.en.shuffled.gz](https://www.statmt.org/wmt15/training-monolingual-news-crawl/news.2014.en.shuffled.gz)
  - Downloads compressed files, extracts, and combines into single file
  - Skips download if files already exist
  - Returns path to combined file
- **`download_text8()`**: Downloads text8 dataset from http://mattmahoney.net/dc/text8.zip
  - Extracts and returns path to text8 file

## **Data Preprocessing**

- **`preprocess_wmt14_news()`**: Preprocesses WMT14 news file into sentence files
  - Removes punctuation, normalizes text
  - Splits into sentences (default: 1000 words per sentence)
  - Saves to numbered files (100K sentences per file)
  - Optional phrase detection (2-pass)
  - Supports `max_sentences` and `max_files` limits
- **`preprocess_text8()`**: Preprocesses text8 file into sentence files
  - Similar to WMT14 preprocessing
  - Optional phrase detection

## **Data Loading**

- **`get_data_file_names()`**: Retrieves and shuffles data file names (seeded for reproducibility)
- **`read_all_data_files()`**: Reads data files and converts words to integer indices
  - Returns: `(inps, offs, lens)` - word indices, sentence offsets, sentence lengths
  - Filters out sentences with < 2 words
  - Provides statistics on processed lines

## **Data Preparation Pipeline**

1. **Download**: Use `download_wmt14_news()` or `download_text8()` to get raw data
2. **Preprocess**: Use `preprocess_wmt14_news()` or `preprocess_text8()` to create sentence files
3. **Load**: Use `get_data_file_names()` and `read_all_data_files()` to load for training

## **Notes**

- All preprocessing removes punctuation (commas, periods, etc.) for consistency
- Phrase detection is optional but improves quality for multi-word expressions
- Processed files are cached - existing files are detected and skipped
- Data files are named with 4-digit numbers (e.g., "0000", "0001") starting with "0"

In [None]:
import os
import pathlib
import re
import time
import zipfile
import gzip
import json
from typing import List, Tuple, Dict
from collections import defaultdict
import requests
import tqdm


def clean_text_remove_punctuation(text: str) -> str:
    """
    Clean text by removing punctuation and normalizing whitespace.

    Args:
        text: Input text line

    Returns:
        Lowercase text made only of ASCII letters separated by single spaces,
        with no leading or trailing space. Everything else (punctuation, digits,
        non-ASCII letters) is removed; e.g. "don't" -> "dont".
    """
    if not text:
        return ""

    # Turn every whitespace character (tabs, newlines, \r, non-breaking spaces,
    # ...) into a plain space first. Otherwise removing non-letters below would
    # delete it and glue two words together.
    text = re.sub(r'\s', ' ', text)

    # Remove all punctuation, keep only letters and spaces
    text = re.sub(r'[^a-zA-Z ]', '', text)

    # Collapse runs of spaces last: the removal above can create new ones
    # ("a , b" -> "a  b"), which previously survived into the output.
    text = re.sub(r'[ ]{2,}', ' ', text)

    return text.lower().strip()


def detect_phrases(text: str, word_counts: Dict[str, int], bigram_counts: Dict[Tuple[str, str], int],
                   train_words: int, min_count: int = 5, threshold: float = 100.0) -> str:
    """
    Greedily merge adjacent word pairs whose bigram score exceeds ``threshold``.
    Mirrors the scoring of word2phrase.c TrainModel() from the original Word2Vec source code.

    Words are scanned left to right. A pair ``(a, b)`` is joined as ``"a_b"`` only when:

    * both words have a count of at least ``min_count``;
    * the bigram has a non-zero count;
    * ``(count(a, b) - min_count) / count(a) / count(b) * train_words > threshold``.

    A joined word is consumed, so it is never paired again with the word that follows it.

    Args:
        text: Whitespace-separated words
        word_counts: Word -> occurrence count
        bigram_counts: (first word, second word) -> occurrence count
        train_words: Total number of words in the training data
        min_count: Minimum count for either word of a candidate pair
        threshold: Score a pair must exceed to be merged (higher = fewer phrases)

    Returns:
        The single-space-joined result. If ``text`` holds fewer than two words, ``text`` itself
        is returned unchanged.
    """
    tokens = text.split()
    if len(tokens) < 2:
        return text

    def _forms_phrase(left: str, right: str) -> bool:
        # Returns True when the pair (left, right) passes every merge condition.
        left_count = word_counts.get(left, 0)
        right_count = word_counts.get(right, 0)
        if left_count < min_count or right_count < min_count:
            return False
        pair_count = bigram_counts.get((left, right), 0)
        if pair_count == 0:
            return False
        return (pair_count - min_count) / left_count / right_count * train_words > threshold

    merged = []
    n_tokens = len(tokens)
    pos = 0
    while pos < n_tokens:
        current = tokens[pos]
        has_next = pos + 1 < n_tokens
        if has_next and _forms_phrase(current, tokens[pos + 1]):
            merged.append(current + "_" + tokens[pos + 1])
            pos += 2  # both words are consumed by the phrase
        else:
            merged.append(current)
            pos += 1

    return " ".join(merged)


def learn_phrase_vocab(data_path: str, min_count: int = 5) -> Tuple[Dict[str, int], Dict[Tuple[str, str], int], int]:
    """
    Count unigrams and adjacent-word bigrams over every data file in ``data_path``.
    Modelled on word2phrase.c LearnVocabFromTrainFile() from the original Word2Vec source code.

    Only files whose names start with "0" are read, in sorted order. Tokens are whitespace
    separated and lowercased. Each line is independent, so bigrams never cross a line break.

    Args:
        data_path: Directory containing the data files
        min_count: Words seen fewer times than this are left out of the returned word counts

    Returns:
        ``(word_counts, bigram_counts, total_words)``:

        * ``word_counts`` is a plain dict of words with count >= ``min_count``;
        * ``bigram_counts`` is a ``defaultdict(int)`` holding *all* bigrams, unfiltered;
        * ``total_words`` counts every token read.
    """
    unigram_tally = defaultdict(int)
    pair_tally = defaultdict(int)
    total_words = 0

    files = sorted(name for name in os.listdir(data_path) if name.startswith("0"))
    n_files = len(files)

    print(f"Learning phrase vocabulary from {n_files} files...")

    for processed, name in enumerate(files, start=1):
        with open(os.path.join(data_path, name), 'r', encoding='utf-8') as handle:
            for raw in handle:
                tokens = [tok.lower() for tok in raw.split()]
                total_words += len(tokens)
                for tok in tokens:
                    unigram_tally[tok] += 1
                # Consecutive pairs within this line only
                for pair in zip(tokens, tokens[1:]):
                    pair_tally[pair] += 1

        if processed % 10 == 0:
            print(f"  Processed {processed}/{n_files} files...")

    kept_words = {word: cnt for word, cnt in unigram_tally.items() if cnt >= min_count}

    print(f"Vocabulary: {len(kept_words):,} words (min_count={min_count})")
    print(f"Bigrams: {len(pair_tally):,} unique bigrams")
    print(f"Total words: {total_words:,}")

    return kept_words, pair_tally, total_words


def apply_phrases_to_data(data_path: str, output_path: str, word_counts: Dict[str, int],
                          bigram_counts: Dict[Tuple[str, str], int], train_words: int,
                          min_count: int = 5, threshold: float = 100.0) -> str:
    """
    Rewrite each data file of ``data_path`` into ``output_path`` with phrases joined.

    Only files whose names start with "0" are processed, in sorted order, and each output
    file keeps its input file name. A blank input line is written as a blank line. Every
    other line is stripped and passed through ``detect_phrases`` with the given statistics.

    Args:
        data_path: Input data directory
        output_path: Output data directory (created if missing)
        word_counts: Word count dictionary
        bigram_counts: Bigram count dictionary
        train_words: Total number of words
        min_count: Minimum word count
        threshold: Phrase score threshold

    Returns:
        ``output_path``
    """
    pathlib.Path(output_path).mkdir(parents=True, exist_ok=True)

    names = sorted(name for name in os.listdir(data_path) if name.startswith("0"))
    n_files = len(names)

    print(f"Applying phrase detection (threshold={threshold}) to {n_files} files...")

    def _convert(raw_line: str) -> str:
        # Maps one input line to its output line, newline included.
        stripped = raw_line.strip()
        if not stripped:
            return '\n'
        return detect_phrases(stripped, word_counts, bigram_counts,
                              train_words, min_count, threshold) + '\n'

    for done, name in enumerate(names, start=1):
        src = os.path.join(data_path, name)
        dst = os.path.join(output_path, name)
        with open(src, 'r', encoding='utf-8') as fin, \
             open(dst, 'w', encoding='utf-8') as fout:
            fout.writelines(map(_convert, fin))

        if done % 10 == 0:
            print(f"  Processed {done}/{n_files} files...")

    print(f"Phrase detection complete. Output: {output_path}")
    return output_path


def preprocess_with_phrases(data_path: str, output_path: str, min_count: int = 5,
                            threshold1: float = 200.0, threshold2: float = 100.0) -> str:
    """
    Preprocess data with phrase detection (2 passes, like word2phrase.c from the original source code)

    Pass 1 learns unigram/bigram counts from ``data_path``. It writes phrase-joined files to
    the scratch directory ``output_path + "_phrase1"``.

    Pass 2 relearns the counts on that output, so an already-joined phrase can be joined
    again with a neighbouring word. It writes the final files to ``output_path``.

    The scratch directory is emptied before use and always removed afterwards, even if a
    step raises.

    Args:
        data_path: Input data directory
        output_path: Final output directory
        min_count: Minimum word count
        threshold1: First pass threshold (higher, fewer phrases)
        threshold2: Second pass threshold (lower, more phrases)

    Returns:
        Path to final output directory
    """
    import shutil

    print(f"Preprocessing with phrase detection...")
    print(f" -Input: {data_path}")
    print(f" -Output: {output_path}")
    print(f" -Threshold 1: {threshold1} (first pass)")
    print(f" -Threshold 2: {threshold2} (second pass)")

    print("\nStep 1: Learning vocabulary and bigram counts...")
    word_counts, bigram_counts, train_words = learn_phrase_vocab(data_path, min_count)

    temp_path1 = output_path + "_phrase1"
    # apply_phrases_to_data only overwrites files that exist in data_path. Leftovers from an
    # interrupted earlier run would otherwise be relearned in step 3 and copied in step 4.
    if os.path.exists(temp_path1):
        shutil.rmtree(temp_path1)

    try:
        print(f"\nStep 2: First pass phrase detection (threshold={threshold1})...")
        apply_phrases_to_data(data_path, temp_path1, word_counts, bigram_counts,
                              train_words, min_count, threshold1)

        print("\nStep 3: Relearning vocabulary from first pass...")
        word_counts2, bigram_counts2, train_words2 = learn_phrase_vocab(temp_path1, min_count)

        print(f"\nStep 4: Second pass phrase detection (threshold={threshold2})...")
        apply_phrases_to_data(temp_path1, output_path, word_counts2, bigram_counts2,
                              train_words2, min_count, threshold2)
    finally:
        # Never leave the scratch directory behind, even when a step fails.
        if os.path.exists(temp_path1):
            shutil.rmtree(temp_path1)
            print(f"Cleaned up temporary directory: {temp_path1}")

    print(f"\nPhrase preprocessing complete: {output_path}")
    return output_path


def download_wmt14_news(output_dir: str = "./data") -> str:
    """
    Download and combine multiple years of WMT14 and WMT15 News Crawl datasets
    Downloads WMT14 year 2012 and WMT15 year 2014, combines them into a single file
    Returns path to combined news file

    Args:
        output_dir: Root data directory; everything is stored under ``<output_dir>/wmt14``

    Returns:
        Path to ``news.combined.en.shuffled``. It holds the stripped, non-empty lines of every
        year that could be obtained.

    Raises:
        FileNotFoundError: If no yearly file could be downloaded and extracted

    Every file (download, extraction, combined output) is first written under a ``.part``
    name and renamed to its final name only once complete. An interrupted run therefore
    never leaves a truncated file that the "already exists" checks would later accept.
    """
    import shutil

    datasets = [
        ("WMT14", 2012, "http://www.statmt.org/wmt14/training-monolingual-news-crawl"),
        ("WMT15", 2014, "https://www.statmt.org/wmt15/training-monolingual-news-crawl"),
    ]

    output_path = os.path.join(output_dir, "wmt14")
    combined_file = os.path.join(output_path, "news.combined.en.shuffled")

    # Create output directory
    pathlib.Path(output_path).mkdir(parents=True, exist_ok=True)

    # Check if combined file already exists
    if os.path.isfile(combined_file):
        print(f"WMT14/WMT15 News combined file already exists at: {combined_file}")
        return combined_file

    # Download and extract each dataset
    downloaded_files = []
    for wmt_version, year, base_url in datasets:
        train_file = f"news.{year}.en.shuffled"
        train_gz = f"{train_file}.gz"
        train_url = f"{base_url}/{train_gz}"
        news_file = os.path.join(output_path, train_file)
        gz_path = os.path.join(output_path, train_gz)

        # Check if already extracted
        if os.path.isfile(news_file):
            print(f"{wmt_version} News {year} already exists at: {news_file}")
            downloaded_files.append(news_file)
            continue

        # Download if missing
        if not os.path.isfile(gz_path):
            print(f"Downloading {wmt_version} News {year} ({train_gz})...")
            gz_part = gz_path + ".part"
            try:
                with requests.get(train_url, stream=True, timeout=30) as response:
                    response.raise_for_status()
                    total_size = int(response.headers.get('content-length', 0))

                    with open(gz_part, 'wb') as f:
                        with tqdm.tqdm(total=total_size, unit='B', unit_scale=True,
                                       desc=f"Downloading {year}") as pbar:
                            for chunk in response.iter_content(chunk_size=8192):
                                if chunk:
                                    f.write(chunk)
                                    pbar.update(len(chunk))
                # Publish the archive under its final name only after the full body arrived.
                os.replace(gz_part, gz_path)
            except requests.exceptions.RequestException as e:
                if os.path.exists(gz_part):
                    os.remove(gz_part)
                print(f"⚠️ Warning: Could not download {train_url}: {e}")
                print(f"⚠️ Skipping {wmt_version} year {year}")
                continue

        # Extract if needed
        if os.path.isfile(gz_path) and not os.path.isfile(news_file):
            print(f"Extracting {gz_path}...")
            news_part = news_file + ".part"
            try:
                # Stream in 1 MiB chunks. The decompressed crawl can be GB-scale, and reading
                # it whole (source.read()) could exhaust memory.
                with gzip.open(gz_path, "rb") as source, open(news_part, "wb") as target:
                    shutil.copyfileobj(source, target, 1024 * 1024)
                os.replace(news_part, news_file)
                downloaded_files.append(news_file)
                print(f"✅ Extracted {train_file}")
                # Remove gz file to save space
                os.remove(gz_path)
            except Exception as e:
                # Best effort: a corrupt or truncated archive only skips this year.
                if os.path.exists(news_part):
                    os.remove(news_part)
                print(f"⚠️ Error extracting {gz_path}: {e}")
                continue

    if not downloaded_files:
        raise FileNotFoundError("No WMT14/WMT15 News files were successfully downloaded")

    # Combine all downloaded files into one
    print(f"\nCombining {len(downloaded_files)} WMT14/WMT15 News files into: {combined_file}")
    total_lines = 0
    combined_part = combined_file + ".part"

    with open(combined_part, 'w', encoding='utf-8') as outfile:
        for i, news_file in enumerate(downloaded_files):
            if not os.path.isfile(news_file):
                print(f"⚠️ Warning: {news_file} not found, skipping")
                continue

            print(f"Adding file {i+1}/{len(downloaded_files)}: {os.path.basename(news_file)}")
            line_count = 0

            with open(news_file, 'r', encoding='utf-8') as infile:
                for line in infile:
                    cleaned = line.strip()
                    if cleaned:  # Skip empty lines
                        outfile.write(cleaned + '\n')
                        line_count += 1
                        total_lines += 1

            print(f"Added {line_count:,} lines")

    # The combined file becomes visible to the early-return check above only once complete.
    os.replace(combined_part, combined_file)

    # Get file size
    file_size = os.path.getsize(combined_file) / (1024**3)  # GB

    print(f"\n✅ Combined WMT14/WMT15 News dataset created:")
    print(f" -File: {combined_file}")
    print(f" -Total lines: {total_lines:,}")
    print(f" -Size: {file_size:.2f} GB")
    print(f" -Estimated words: ~{total_lines * 20:,} (assuming ~20 words/line)")

    return combined_file


def download_text8(output_dir: str = "./data") -> str:
    """
    Download text8 dataset from http://mattmahoney.net/dc/text8.zip
    Returns path to downloaded text8 file

    Args:
        output_dir: Root data directory; the archive is extracted into ``<output_dir>/text8``

    Returns:
        Path to the extracted ``text8`` file (returned immediately if it already exists)

    Raises:
        requests.exceptions.RequestException: If the download fails or times out
        FileNotFoundError: If the archive did not produce the expected ``text8`` file
    """
    url = "http://mattmahoney.net/dc/text8.zip"
    output_path = os.path.join(output_dir, "text8")
    text8_file = os.path.join(output_path, "text8")

    # Create output directory
    pathlib.Path(output_path).mkdir(parents=True, exist_ok=True)

    # Check if already exists
    if os.path.isfile(text8_file):
        print(f"Text8 file already exists at: {text8_file}")
        return text8_file

    zip_path = os.path.join(output_path, "text8.zip")

    print(f"Downloading text8 from {url}...")
    try:
        # Same timeout as download_wmt14_news; without one a stalled server hangs forever.
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))

            with open(zip_path, 'wb') as f:
                with tqdm.tqdm(total=total_size, unit='B', unit_scale=True, desc="Downloading") as pbar:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
                            pbar.update(len(chunk))

        print(f"Extracting {zip_path}...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(output_path)
    finally:
        # Remove zip file to save space. This also discards a partial download on failure;
        # the zip is re-downloaded on every run anyway.
        if os.path.exists(zip_path):
            os.remove(zip_path)

    if not os.path.isfile(text8_file):
        raise FileNotFoundError(f"Archive {url} did not contain the expected file: {text8_file}")

    print(f"Text8 dataset ready at: {text8_file}")
    return text8_file


def preprocess_wmt14_news(news_file_path: str, output_dir: str, words_per_sentence: int = 1000,
                        max_sentences: int = None, max_files: int = None, use_phrases: bool = False,
                        phrase_threshold1: float = 200.0, phrase_threshold2: float = 100.0) -> str:
    """
    Preprocess WMT14 news file into sentence files

    Processing steps:

    1. Each input line is cleaned with ``clean_text_remove_punctuation``.
    2. The cleaned line is cut into chunks of at most ``words_per_sentence`` words; chunks
       shorter than 2 words are dropped.
    3. Chunks are written one per line, 100,000 lines per file, to files named "0000",
       "0001", ...
    4. Optionally, two-pass phrase detection is applied to those files.

    If ``output_dir`` already holds files starting with "0", nothing is done.

    Args:
        news_file_path: Path to WMT14 news file
        output_dir: Output directory for processed files
        words_per_sentence: Number of words per sentence (default: 1000)
        max_sentences: Maximum number of sentences to process (None = all)
        max_files: Maximum number of files to create (None = all)
        use_phrases: Whether to apply phrase detection (default: False)
        phrase_threshold1: First pass phrase threshold (default: 200.0)
        phrase_threshold2: Second pass phrase threshold (default: 100.0)

    Returns:
        Path to output directory
    """
    import shutil

    print(f"Preprocessing WMT14 news file: {news_file_path}")
    print(f"Output directory: {output_dir}")
    print(f"Words per sentence: {words_per_sentence}")
    print("Note: Punctuation will be removed from text (commas, periods, etc.)")
    if max_sentences:
        print(f"Max sentences: {max_sentences:,}")
    if max_files:
        print(f"Max files: {max_files}")
    if use_phrases:
        print(f"Phrase detection: Enabled (threshold1={phrase_threshold1}, threshold2={phrase_threshold2})")

    # Create output directory
    pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Check if already processed
    existing_files = [f for f in os.listdir(output_dir) if f.startswith("0")]
    if existing_files:
        print(f"Found {len(existing_files)} existing processed files. Skipping preprocessing.")
        print("⚠️ WARNING: If these files contain punctuation, delete them and reprocess to apply cleaning.")
        return output_dir

    # Step 1: Basic preprocessing
    temp_dir = output_dir + "_temp"
    # Start from an empty scratch directory. Files left by an interrupted run (e.g. "0005"
    # when this run only writes 0000-0003) would otherwise be moved into output_dir.
    if os.path.exists(temp_dir):
        shutil.rmtree(temp_dir)
    pathlib.Path(temp_dir).mkdir(parents=True, exist_ok=True)

    # Read news file (one sentence per line)
    sentences = []
    with open(news_file_path, 'r', encoding='utf-8') as f:
        for line in f:
            # Clean text: remove punctuation and normalize
            cleaned_line = clean_text_remove_punctuation(line)
            if not cleaned_line:  # Skip empty lines after cleaning
                continue
            words = cleaned_line.split()
            for i in range(0, len(words), words_per_sentence):
                sentence_words = words[i:i + words_per_sentence]
                if len(sentence_words) >= 2:  # Skip very short sentences
                    sentences.append(" ".join(sentence_words))
                    if max_sentences and len(sentences) >= max_sentences:
                        break
            if max_sentences and len(sentences) >= max_sentences:
                print(f"Reached max_sentences limit: {max_sentences:,}")
                break

    print(f"Total sentences: {len(sentences):,}")

    # Save to temporary files, sentences_per_file lines each (the last one may be shorter)
    sentences_per_file = 100000
    file_count = 0
    for start in range(0, len(sentences), sentences_per_file):
        chunk = sentences[start:start + sentences_per_file]
        filename = f"{file_count:04d}"
        filepath = os.path.join(temp_dir, filename)

        with open(filepath, 'w', encoding='utf-8') as f:
            f.writelines(sent + '\n' for sent in chunk)

        print(f"Wrote {len(chunk):,} sentences to {filename}")
        file_count += 1

        # Stop if we've reached max_files
        if max_files and file_count >= max_files:
            print(f"Reached max_files limit: {max_files}")
            break

    # Step 2: Apply phrase detection if enabled
    if use_phrases:
        print("\nApplying phrase detection...")
        preprocess_with_phrases(temp_dir, output_dir, min_count=5,
                                threshold1=phrase_threshold1, threshold2=phrase_threshold2)
        # Cleanup temp directory
        if os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)
    else:
        # Just move files from temp to output
        if os.path.exists(output_dir):
            shutil.rmtree(output_dir)
        shutil.move(temp_dir, output_dir)

    print(f"Preprocessing complete. Created {file_count} files in {output_dir}")
    return output_dir


def preprocess_text8(text8_file_path: str, output_dir: str, words_per_sentence: int = 1000,
                    use_phrases: bool = False, phrase_threshold1: float = 200.0,
                    phrase_threshold2: float = 100.0) -> str:
    """
    Preprocess text8 file into sentence files.

    Processing steps:

    1. The whole file is read and split on whitespace; no text cleaning is applied.
    2. The words are cut into consecutive chunks of ``words_per_sentence``. Chunks shorter
       than 2 words are dropped; only the last chunk can be that short, unless
       ``words_per_sentence`` < 2.
    3. Chunks are written one per line, 100,000 lines per file, to files named "0000",
       "0001", ...
    4. Optionally, two-pass phrase detection is applied to those files.

    If ``output_dir`` already holds files starting with "0", nothing is done.

    Args:
        text8_file_path: Path to text8 file
        output_dir: Output directory for processed files
        words_per_sentence: Number of words per sentence (default: 1000)
        use_phrases: Whether to apply phrase detection (default: False)
        phrase_threshold1: First pass phrase threshold (default: 200.0)
        phrase_threshold2: Second pass phrase threshold (default: 100.0)

    Returns:
        Path to output directory
    """
    import shutil

    print(f"Preprocessing text8 file: {text8_file_path}")
    print(f"Output directory: {output_dir}")
    print(f"Words per sentence: {words_per_sentence}")
    if use_phrases:
        print(f"Phrase detection: Enabled (threshold1={phrase_threshold1}, threshold2={phrase_threshold2})")

    # Create output directory
    pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Check if already processed
    existing_files = [f for f in os.listdir(output_dir) if f.startswith("0")]
    if existing_files:
        print(f"Found {len(existing_files)} existing processed files. Skipping preprocessing.")
        return output_dir

    # Step 1: Basic preprocessing
    temp_dir = output_dir + "_temp"
    # Start from an empty scratch directory. Files left by an interrupted run would
    # otherwise be moved into output_dir alongside the new ones.
    if os.path.exists(temp_dir):
        shutil.rmtree(temp_dir)
    pathlib.Path(temp_dir).mkdir(parents=True, exist_ok=True)

    # Read text8 file (single long line)
    with open(text8_file_path, 'r', encoding='utf-8') as f:
        text = f.read().strip()

    # Split into words
    words = text.split()
    print(f"Total words: {len(words):,}")

    # Group into sentences
    sentences = []
    for i in range(0, len(words), words_per_sentence):
        sentence_words = words[i:i + words_per_sentence]
        if len(sentence_words) >= 2:  # Skip very short sentences
            sentences.append(" ".join(sentence_words))

    print(f"Created {len(sentences):,} sentences")

    # Save to temporary files, sentences_per_file lines each (the last one may be shorter)
    sentences_per_file = 100000
    file_count = 0
    for start in range(0, len(sentences), sentences_per_file):
        chunk = sentences[start:start + sentences_per_file]
        filename = f"{file_count:04d}"
        filepath = os.path.join(temp_dir, filename)

        with open(filepath, 'w', encoding='utf-8') as f:
            f.writelines(sent + '\n' for sent in chunk)

        print(f"Wrote {len(chunk):,} sentences to {filename}")
        file_count += 1

    # Step 2: Apply phrase detection if enabled
    if use_phrases:
        print("\nApplying phrase detection...")
        preprocess_with_phrases(temp_dir, output_dir, min_count=5,
                                threshold1=phrase_threshold1, threshold2=phrase_threshold2)
        # Cleanup temp directory
        if os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)
    else:
        # Just move files from temp to output
        if os.path.exists(output_dir):
            shutil.rmtree(output_dir)
        shutil.move(temp_dir, output_dir)

    print(f"Preprocessing complete. Created {file_count} files in {output_dir}")
    return output_dir


def get_data_file_names(path: str, seed: int) -> List[str]:
    """
    Return the data file names in ``path`` (names starting with "0") in a seeded random order.

    The names are sorted before shuffling, so the resulting order depends only on ``seed``
    and the set of files, not on the order ``os.listdir`` happens to return.
    """
    import numpy as np

    names = sorted(name for name in os.listdir(path) if name.startswith("0"))
    np.random.default_rng(seed=seed).shuffle(names)
    return names


def read_all_data_files(data_path: str, file_names: List[str], word_to_idx: dict) -> Tuple[List[int], List[int], List[int]]:
    """
    Read all data files and convert words to indices
    Returns (inputs, offsets, lengths)

    Each line is split on runs of spaces and periods. Lines with fewer than 2 tokens are
    skipped. Every other line becomes one sentence whose words are mapped through
    ``word_to_idx``. Unknown words are dropped, so a kept sentence may still be shorter
    than 2 words, or even empty.

    Args:
        data_path: Directory containing the data files
        file_names: File names (relative to ``data_path``) to read, in this order
        word_to_idx: Mapping from word to vocabulary index

    Returns:
        ``(inps, offs, lens)``:

        * ``inps`` is the concatenated word indices of all sentences;
        * ``offs`` is the start offset of each sentence in ``inps``;
        * ``lens`` is each sentence's length.
    """
    from collections import defaultdict

    # Hoisted out of the per-line loop
    splitter = re.compile(r"[ .]+")

    start = time.time()
    inps, offs, lens = [], [], []
    offset_total = 0
    stats = defaultdict(int)

    for fn in file_names:
        fp = os.path.join(data_path, fn)
        ok_lines = 0
        too_short_lines = 0
        with open(fp, encoding="utf-8") as f:
            for line in f:
                words = [word for word in splitter.split(line.strip()) if word]
                if len(words) < 2:
                    too_short_lines += 1
                    continue
                idcs = [word_to_idx[w] for w in words if w in word_to_idx]
                le = len(idcs)
                ok_lines += 1
                offs.append(offset_total)
                lens.append(le)
                inps.extend(idcs)
                offset_total += le
        stats["file_read_lines_ok"] += ok_lines
        stats["one_word_sentence_lines_which_were_ignored"] += too_short_lines

    print(f"read_all_data_files() STATS: {stats}")
    tot_tm = time.time() - start
    n_files = len(file_names)
    # Guard the average: an empty file list used to raise ZeroDivisionError.
    avg_tm = tot_tm / n_files if n_files else 0.0
    print(f"read_all_data_files() Total time {tot_tm} s for {n_files} files (avg {avg_tm} s/file)")
    return inps, offs, lens


# IV. **Skip-gram Implementation**

This cell implements the Skip-gram architecture for Word2Vec training using GPU-accelerated CUDA kernels. Skip-gram predicts surrounding context words from a center word, making it effective for learning word representations.

In [None]:
import math
import os
import time
from typing import List, Tuple, Dict, Any

from numba import cuda
from numba.cuda import random as c_random
import numpy as np
from numpy import ndarray

@cuda.jit
def calc_skipgram(
        rows: int,
        c: int,
        k: int,
        learning_rate: float,
        w1,
        w2,
        calc_aux,
        random_states,
        subsample_weights,
        negsample_array,
        inp,
        offsets,
        lengths,
        use_hs,
        syn1,
        codes_array,
        points_array,
        code_lengths,
        exp_table,
        exp_table_size,
        max_exp):
    """
    CUDA kernel for Skip-gram training
    Based on word2vec.c Skip-gram implementation from the original source code of the Word2Vec paper
    Supports both Hierarchical Softmax and Negative Sampling

    Each thread handles one sentence. Thread ``idx`` walks the word indices
    ``inp[offsets[idx] : offsets[idx] + lengths[idx]]``; threads with ``idx >= rows`` exit
    immediately. For every centre word that survives subsampling:

    * a window radius ``r = ceil(u * c)`` is drawn, with ``u`` uniform, so ``r`` is at most ``c``;
    * ``step_skipgram`` is called once per (centre, context) pair within ``r`` positions on
      each side, clipped to the sentence boundaries.

    Args:
        rows: Number of sentences; thread indices at or above this do nothing.
        c: Maximum context window radius.
        k: Negative samples per pair, forwarded to ``step_skipgram`` (0 disables NS).
        learning_rate: Step size forwarded to ``step_skipgram``.
        w1: Centre-word embedding matrix (``step_skipgram`` updates ``w1[centre]``).
        w2: Output embedding matrix used by negative sampling.
        calc_aux: Per-thread scratch rows indexed ``calc_aux[thread, dim]``
            (presumably shape ``(rows, embed_dim)`` - TODO confirm at the launch site).
        random_states: Numba xoroshiro128p RNG states, indexed by thread index.
        subsample_weights: Per-word rejection probability. A centre word is used only when
            a uniform draw is strictly greater than it.
        negsample_array: Word-index table that negatives are drawn from uniformly.
        inp: Flat array of word indices for all sentences.
        offsets: Start offset of each sentence in ``inp``.
        lengths: Length of each sentence in ``inp``.
        use_hs: Hierarchical-softmax switch, forwarded to ``step_skipgram``.
        syn1: Inner-node weights, forwarded to ``step_skipgram``.
        codes_array, points_array, code_lengths: Per-word tree codes and node paths
            (presumably the Huffman tree built with the vocabulary - confirm), forwarded to
            ``step_skipgram``.
        exp_table, exp_table_size, max_exp: Sigmoid lookup parameters for ``fast_sigmoid``.
    """
    # One thread per sentence (1-D launch)
    idx = cuda.threadIdx.x + cuda.blockIdx.x * cuda.blockDim.x
    if idx >= rows:
        return
    le = lengths[idx]
    off = offsets[idx]

    for centre in range(0, le):
        word_idx = inp[off + centre]
        prob_to_reject = subsample_weights[word_idx]
        rnd = c_random.xoroshiro128p_uniform_float32(random_states, idx)

        # Subsampling: keep the centre word only if the draw exceeds its rejection probability
        if rnd > prob_to_reject:
            # Random effective window radius in [0, c] (0 only when the draw is exactly 0)
            r_f = c_random.xoroshiro128p_uniform_float32(random_states, idx)
            r: int = math.ceil(r_f * c)

            # Context before center word
            for context_pre in range(max(0, centre-r), centre):
                step_skipgram(idx, w1, w2, calc_aux, inp[off+centre], inp[off+context_pre],
                             k, learning_rate, negsample_array, random_states,
                             use_hs, syn1, codes_array, points_array, code_lengths,
                             exp_table, exp_table_size, max_exp)

            # Context after center word
            for context_post in range(centre + 1, min(le, centre + 1 + r)):
                step_skipgram(idx, w1, w2, calc_aux, inp[off+centre], inp[off+context_post],
                             k, learning_rate, negsample_array, random_states,
                             use_hs, syn1, codes_array, points_array, code_lengths,
                             exp_table, exp_table_size, max_exp)


@cuda.jit(device=True)
def fast_sigmoid(f, exp_table, exp_table_size, max_exp):
    """
    Fast sigmoid using precomputed exp table
    Based on word2vec.c exp table lookup from the original source code of the Word2Vec paper

    Behaviour:

    * returns 0.0 for ``f <= -max_exp`` and 1.0 for ``f >= max_exp``;
    * otherwise maps ``f`` linearly from ``(-max_exp, max_exp)`` onto ``[0, exp_table_size)``
      and returns ``exp_table[idx]`` as-is.

    NOTE(review): because the table value is returned unchanged, ``exp_table`` is expected
    to already hold sigmoid values (e.g. exp(x) / (exp(x) + 1)), not raw exponentials.
    Confirm against the code that builds it.
    """
    if f <= -max_exp:
        return 0.0
    elif f >= max_exp:
        return 1.0
    else:
        # Linear bucket index, as in word2vec.c: (f + MAX_EXP) * (SIZE / MAX_EXP / 2)
        idx = int((f + max_exp) * (exp_table_size / max_exp / 2.0))
        # Defensive clamps in case float rounding lands on either end of the table
        if idx < 0:
            idx = 0
        if idx >= exp_table_size:
            idx = exp_table_size - 1
        return exp_table[idx]


@cuda.jit(device=True)
def step_skipgram(thread_idx, w1, w2, calc_aux, x, y, k, learning_rate, negsample_array, random_states,
                  use_hs, syn1, codes_array, points_array, code_lengths,
                  exp_table, exp_table_size, max_exp):
    """
    Device function for Skip-gram gradient calculation
    Based on word2vec.c Skip-gram implementation from the original source code of the Word2Vec paper
    Supports both Hierarchical Softmax and Negative Sampling

    Performs one SGD step for the word pair (centre ``x``, context ``y``):

    * The gradient for ``w1[x]`` is accumulated in the scratch row
      ``calc_aux[thread_idx, :]`` and added to ``w1[x]`` once, at the end.
    * HS (when ``use_hs`` is truthy) walks the tree path stored for ``y``. It updates the
      ``syn1`` rows of the visited nodes immediately. Nodes whose score lies outside
      ``(-max_exp, max_exp)`` are skipped.
    * NS (when ``k > 0``) updates ``w2[y]`` with label 1. It then draws ``k`` negatives
      uniformly from ``negsample_array`` and updates their ``w2`` rows with label 0.

    All updates to ``w1``/``w2``/``syn1`` are plain read-modify-writes with no atomics, so
    threads touching the same row at the same time race. This is presumably intentional
    lock-free (Hogwild-style) training - confirm.

    NOTE(review): check this against word2vec.c, which skips a negative sample equal to
    the positive target. No ``neg == y`` check exists here.
    """
    emb_dim = w1.shape[1]
    negs_arr_len = len(negsample_array)

    # Reset this thread's scratch row; it accumulates the total update for w1[x]
    for i in range(emb_dim):
        calc_aux[thread_idx, i] = 0.0

    # Hierarchical Softmax (if enabled) - traverse tree for context word y
    if use_hs:
        codelen = code_lengths[y]
        max_code_len = codes_array.shape[1]
        for d in range(codelen):
            # Never read past the width of the code/point arrays
            if d >= max_code_len:
                break
            node_idx = points_array[y, d]
            # Negative node indices are skipped (presumably padding - confirm)
            if node_idx < 0:
                continue

            # Calculate dot product: w1[x] • syn1[node]
            f = 0.0
            for i in range(emb_dim):
                f += w1[x, i] * syn1[node_idx, i]

            # Early skip if f is outside range (same as original code)
            # This prevents unnecessary updates when sigmoid is saturated
            if f <= -max_exp:
                continue
            if f >= max_exp:
                continue

            # Get sigmoid from exp table (only if in range)
            sigmoid_val = fast_sigmoid(f, exp_table, exp_table_size, max_exp)

            # Get code bit (0 or 1)
            code_bit = codes_array[y, d]
            # Negative code entries are skipped (presumably padding - confirm)
            if code_bit < 0:
                continue

            # Calculate gradient: g = (1 - code_bit - sigmoid) * learning_rate
            g = (1.0 - float(code_bit) - sigmoid_val) * learning_rate

            # Propagate errors output -> hidden
            for i in range(emb_dim):
                calc_aux[thread_idx, i] += g * syn1[node_idx, i]

            # Learn weights hidden -> output
            for i in range(emb_dim):
                syn1[node_idx, i] += g * w1[x, i]

    # Negative Sampling (if enabled)
    if k > 0:
        # Positive sample: predict context word y
        dot_xy = 0.0
        for i in range(emb_dim):
            dot_xy += w1[x, i] * w2[y, i]
        # sigmoid(x·y) - 1, i.e. the negated label-1 error; -learning_rate * s_xdy_m1 is the step size g
        s_xdy_m1 = fast_sigmoid(dot_xy, exp_table, exp_table_size, max_exp) - 1.0

        # Positive sample gradients (calc_aux uses w2[y, i] before it is updated on the next line)
        for i in range(emb_dim):
            calc_aux[thread_idx, i] += -learning_rate * s_xdy_m1 * w2[y, i]
            w2[y, i] -= learning_rate * s_xdy_m1 * w1[x, i]

        # Negative samples
        for neg_sample in range(0, k):
            rnd = c_random.xoroshiro128p_uniform_float32(random_states, thread_idx)
            # Uniform pick from the sampling table
            q_idx: int = int(math.floor(negs_arr_len * rnd))
            neg = negsample_array[q_idx]
            dot_xq = 0.0
            for i in range(emb_dim):
                dot_xq += w1[x, i] * w2[neg, i]
            s_dxq = fast_sigmoid(dot_xq, exp_table, exp_table_size, max_exp)

            # Negative sample gradients (label 0, so the error is just -sigmoid)
            for i in range(emb_dim):
                calc_aux[thread_idx, i] -= learning_rate * s_dxq * w2[neg, i]
                w2[neg, i] -= learning_rate * s_dxq * w1[x, i]

    # Note: Original code does NOT use gradient clipping, only early skip
    # Gradient clipping may reduce training effectiveness
    # Update center word vector (same as original code)
    for i in range(emb_dim):
        w1[x, i] += calc_aux[thread_idx, i]


def _skipgram_lr_at(words_processed, total_words_for_training, lr_max, lr_min, epochs):
    """
    Return the learning rate after `words_processed` of `total_words_for_training` words.

    Same formula as word2vec.c:
        alpha = lr_max * (1 - words_processed / (total_words_for_training + 1))
    floored at lr_max * 0.0001 (word2vec.c's minimum) and, when epochs > 1,
    additionally floored at lr_min.
    """
    lr = lr_max * (1.0 - words_processed / (total_words_for_training + 1))
    lr = max(lr, lr_max * 0.0001)
    if epochs > 1:
        lr = max(lr, lr_min)
    return lr


def train_skipgram(
        data_path: str,
        out_file_path: str,
        epochs: int,
        embed_dim: int = 100,
        min_occurs: int = 3,
        c: int = 5,
        k: int = 5,
        t: float = 1e-5,
        vocab_freq_exponent: float = 0.75,
        lr_max: float = 0.025,
        lr_min: float = 0.0025,
        cuda_threads_per_block: int = 32,
        hs: int = 0,
        max_memory_gb: float = 70.0,
        max_words: int = None,
        vocab: list = None,
        w_to_i: dict = None,
        word_counts: list = None,
        ssw: np.ndarray = None,
        negs: np.ndarray = None):
    """
    Train Skip-gram model
    Based on word2vec.c Skip-gram implementation from the original source code of the Word2Vec paper

    Args:
        data_path: Corpus location, passed to handle_vocab, get_data_file_names and
                   read_all_data_files_ever.
        out_file_path: Where write_vectors saves the vectors; parameters and statistics
                       are written to "<out_file_path>_params.json" and
                       "<out_file_path>_stats.json".
        epochs: Number of passes over the data (must be >= 1).
        embed_dim: Embedding dimension.
        min_occurs: Minimum word count passed to handle_vocab.
        c: Window size passed to the kernel.
        k: Negative sampling count (0=HS only, >0=NS only). Cannot combine with hs=1
        t: Subsampling threshold passed to
           get_subsampling_weights_and_negative_sampling_array.
        vocab_freq_exponent: Passed to handle_vocab as freq_exponent.
        lr_max: Starting learning rate.
        lr_min: Lower bound on the learning rate, applied only when epochs > 1.
        cuda_threads_per_block: CUDA block size; the grid is sized to one thread per
                                sentence of the current batch.
        hs: Hierarchical Softmax flag (0=NS only, 1=HS only). Cannot combine with k>0
        max_memory_gb: Maximum GPU memory usage in GB. If estimated memory exceeds this,
                       the dataset will be automatically split into batches for processing
                       Default: 70.0 GB (safe for A100 80GB GPU)
        max_words: Optional cap on total words read (passed to read_all_data_files_ever).
        vocab, w_to_i, word_counts: Prebuilt vocabulary; reused only if all three are given.
        ssw, negs: Prebuilt subsampling weights / negative-sampling table; reused only if
                   both are given.

    Learning-rate schedule:
        word2vec.c-style linear decay over the total number of words processed. The
        rate is recomputed only before each kernel launch (once per batch) and is
        constant inside a launch, so without batch processing there is one rate per
        epoch.

    Raises:
        ValueError: If both hs=1 and k>0 are specified (HS and NS cannot be combined),
                    if epochs < 1, if no sentences were loaded, or if a single batch
                    spans more words than int32 offsets can address.
    """
    # Validate: HS and NS cannot be used together
    if hs == 1 and k > 0:
        raise ValueError(
            "Error: Cannot use HS (hs=1) and Negative Sampling (k>0) together. "
            "Please choose either HS only (hs=1, k=0) or NS only (hs=0, k>0)"
        )
    # Fail before any expensive work; epochs=0 used to crash at the end on min([]).
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    params = {
        "model_type": "skipgram",
        "w2v_version": W2V_VERSION,
        "data_path": data_path,
        "out_file_path": out_file_path,
        "epochs": epochs,
        "embed_dim": embed_dim,
        "min_occurs": min_occurs,
        "c": c,
        "k": k,
        "t": t,
        "vocab_freq_exponent": vocab_freq_exponent,
        "lr_max": lr_max,
        "lr_min": lr_min,
        "cuda_threads_per_block": cuda_threads_per_block,
        "hs": hs
    }
    stats = {}
    params_path = out_file_path + "_params.json"
    stats_path = out_file_path + "_stats.json"

    seed = 12345

    # Lowest rate the schedule can reach (see _skipgram_lr_at).
    lr_floor = max(lr_min, lr_max * 0.0001) if epochs > 1 else lr_max * 0.0001

    print(f"Skip-gram Training Parameters:")
    print(f"Seed: {seed}")
    print(f"Window size: {c}")
    if hs == 1:
        print(f"Hierarchical Softmax: Enabled")
    if k > 0:
        print(f"Negative samples: {k}")
    print(f"Learning rate: {lr_max} -> floor {lr_floor} (linear decay over total words processed, "
          f"as per word2vec.c; updated before each batch launch, constant within a batch)")
    print(f"Embedding dimension: {embed_dim}")
    print(f"Min word count: {min_occurs}")

    # Start timing for total execution
    start = time.time()

    # Build vocabulary if not provided (for reuse when training both models)
    if vocab is None or w_to_i is None or word_counts is None:
        print(f"\nBuilding vocabulary from: {data_path}")
        vocab_start = time.time()
        vocab, w_to_i, word_counts = handle_vocab(data_path, min_occurs, freq_exponent=vocab_freq_exponent, use_cache=True)
        vocab_size = len(vocab)
        build_time = time.time() - vocab_start
        print(f"Vocabulary {'loaded from cache' if build_time < 1.0 else 'built'} in {build_time:.2f}s. Vocab size: {vocab_size:,}")
    else:
        vocab_size = len(vocab)
        print(f"\nUsing pre-built vocabulary. Vocab size: {vocab_size:,}")

    # Build subsampling weights and negative sampling array if not provided
    if ssw is None or negs is None:
        ssw, negs = get_subsampling_weights_and_negative_sampling_array(vocab, t=t)

    # Create exp table
    print("Creating exp table for fast sigmoid...")
    exp_table = create_exp_table(EXP_TABLE_SIZE, MAX_EXP)

    # Setup Hierarchical Softmax if enabled
    use_hs = (hs == 1)
    syn1_cuda = None
    codes_array_cuda = None
    points_array_cuda = None
    code_lengths_cuda = None

    if use_hs:
        print("Creating Huffman tree for Hierarchical Softmax...")
        hs_start = time.time()
        codes_array, points_array, code_lengths = create_huffman_tree(word_counts, MAX_CODE_LENGTH)
        syn1 = init_hs_weight_matrix(vocab_size, embed_dim)
        print(f"Huffman tree created in {time.time() - hs_start:.2f}s")
        print(f" -Codes array shape: {codes_array.shape}")
        print(f" -Points array shape: {points_array.shape}")
        print(f" -Syn1 matrix shape: {syn1.shape}")

    data_files = get_data_file_names(data_path, seed=seed)
    print(f"Processing {len(data_files)} data files...")
    if max_words is not None:
        print(f"⚠️ Limiting to {max_words:,} total words (will stop early if reached)")
    inps_, offs_, lens_ = read_all_data_files_ever(data_path, data_files, w_to_i, max_words=max_words)
    inps = np.asarray(inps_, dtype=np.int32)
    # Cumulative word offsets exceed int32 once the corpus passes ~2.1B words, which is
    # the scale where batch processing kicks in. Keep them int64 on the host; each batch
    # is rebased to 0 and narrowed to int32 before transfer.
    # NOTE(review): if read_all_data_files_ever already returns int32 offsets, this cannot
    # undo a wrap that happened there -- confirm its output dtype.
    offs = np.asarray(offs_, dtype=np.int64)
    lens = np.asarray(lens_, dtype=np.int32)
    sentence_count = len(lens)
    total_words = len(inps)  # Total words for LR decay calculation

    print(f"Data loaded: {sentence_count:,} sentences, {total_words:,} total words")
    if sentence_count == 0:
        # A zero-sized batch would otherwise reach the kernel as a zero-block launch.
        raise ValueError(f"No sentences were loaded from {data_path}; nothing to train on")

    # Initialize weight matrices
    w1, w2 = init_weight_matrices(vocab_size, embed_dim, seed=seed)
    data_size_weights = 4 * (w1.size + w2.size)
    # Estimated with 4 bytes/element: that is the dtype the per-batch GPU copies use.
    data_size_inputs = 4 * (inps.size + offs.size + lens.size + ssw.size + negs.size)

    # Calculate memory usage and determine batch size
    weights_gb = data_size_weights / (1024**3)
    inputs_gb = data_size_inputs / (1024**3)

    # Estimate calc_aux memory for full dataset
    calc_aux_size_full = sentence_count * embed_dim * 4
    calc_aux_gb_full = calc_aux_size_full / (1024**3)
    total_memory_gb = weights_gb + inputs_gb + calc_aux_gb_full

    # Determine if batch processing is needed
    use_batch_processing = (total_memory_gb > max_memory_gb)

    if use_batch_processing:
        # Calculate batch size based on available memory
        available_memory_gb = max_memory_gb - weights_gb - inputs_gb
        # Reserve 5GB for overhead
        available_memory_gb = max(1.0, available_memory_gb - 5.0)

        # Calculate max sentences per batch
        bytes_per_sentence = embed_dim * 4  # float32
        max_batch_sentences = int((available_memory_gb * 1024**3) / bytes_per_sentence)

        # Round down to nice numbers
        if max_batch_sentences >= 10_000_000:
            batch_size = 10_000_000
        elif max_batch_sentences >= 5_000_000:
            batch_size = 5_000_000
        elif max_batch_sentences >= 2_000_000:
            batch_size = 2_000_000
        elif max_batch_sentences >= 1_000_000:
            batch_size = 1_000_000
        else:
            batch_size = max(100_000, max_batch_sentences)

        num_batches = math.ceil(sentence_count / batch_size)
        batch_aux_gb = (batch_size * embed_dim * 4) / (1024**3)
        batch_total_gb = weights_gb + inputs_gb + batch_aux_gb

        print(f"\n⚠️ Memory usage would be {total_memory_gb:.1f} GB (exceeds {max_memory_gb} GB limit)")
        print(f"Using batch processing: {num_batches} batches, {batch_size:,} sentences/batch")
        print(f"Memory per batch: {batch_total_gb:.1f} GB (calc_aux: {batch_aux_gb:.1f} GB)")
    else:
        batch_size = sentence_count
        num_batches = 1
        print(f"\n✅ Memory usage: {total_memory_gb:.1f} GB (within {max_memory_gb} GB limit)")
        print(f"Processing all {sentence_count:,} sentences in one batch")

    blocks: int = math.ceil(batch_size / cuda_threads_per_block)
    print(f"CUDA config: {cuda_threads_per_block} threads/block, {blocks} blocks per batch")

    # Precompute (sentence range, word range) for every batch and verify up front that
    # each batch's rebased offsets fit in int32, instead of failing mid-training.
    int32_max = int(np.iinfo(np.int32).max)
    batch_bounds = []
    for batch_idx in range(num_batches):
        batch_start = batch_idx * batch_size
        batch_end = min(batch_start + batch_size, sentence_count)
        word_start = int(offs[batch_start]) if batch_start < len(offs) else 0
        word_end = int(offs[batch_end]) if batch_end < len(offs) else total_words
        if word_end - word_start > int32_max:
            raise ValueError(
                f"Batch {batch_idx + 1} spans {word_end - word_start:,} words, which exceeds "
                f"the int32 offset range; lower max_memory_gb to force smaller batches"
            )
        batch_bounds.append((batch_start, batch_end, word_start, word_end))

    # Transfer to GPU - Transfer weights and vocab arrays (these are shared across batches)
    print("Transferring data to GPU...")
    data_transfer_start = time.time()
    ssw_cuda, negs_cuda = cuda.to_device(ssw), cuda.to_device(negs)
    w1_cuda, w2_cuda = cuda.to_device(w1), cuda.to_device(w2)
    exp_table_cuda = cuda.to_device(exp_table)

    # Input arrays stay on the CPU and are sliced/transferred per batch to save GPU memory.

    if use_hs:
        syn1_cuda = cuda.to_device(syn1)
        codes_array_cuda = cuda.to_device(codes_array)
        points_array_cuda = cuda.to_device(points_array)
        code_lengths_cuda = cuda.to_device(code_lengths)

    print(f"Data transfer completed in {time.time()-data_transfer_start:.2f}s")

    stats["sentence_count"] = sentence_count
    stats["word_count"] = total_words
    stats["vocab_size"] = vocab_size
    stats["approx_data_size_weights"] = data_size_weights
    stats["approx_data_size_inputs"] = data_size_inputs
    stats["use_batch_processing"] = use_batch_processing
    if use_batch_processing:
        stats["batch_size"] = batch_size
        stats["num_batches"] = num_batches
        batch_aux_size = batch_size * embed_dim * 4
        stats["approx_data_size_aux_per_batch"] = batch_aux_size
        stats["approx_data_size_total"] = data_size_weights + data_size_inputs + batch_aux_size
    else:
        data_size_aux = 4 * (sentence_count * embed_dim)
        stats["approx_data_size_aux"] = data_size_aux
        stats["approx_data_size_total"] = data_size_weights + data_size_inputs + data_size_aux

    # Prepare HS parameters (use dummy arrays if HS disabled)
    if not use_hs:
        # Dummy arrays are never read when HS is off, but the kernel signature needs them.
        syn1_param = cuda.device_array((1, embed_dim), dtype=np.float32)
        codes_param = cuda.device_array((vocab_size, MAX_CODE_LENGTH), dtype=np.int32)
        points_param = cuda.device_array((vocab_size, MAX_CODE_LENGTH), dtype=np.int32)
        lengths_param = cuda.device_array(vocab_size, dtype=np.int32)
    else:
        syn1_param = syn1_cuda
        codes_param = codes_array_cuda
        points_param = points_array_cuda
        lengths_param = code_lengths_cuda

    print_norms(w1_cuda)
    print(f"\nStarting Skip-gram training - {epochs} epochs...")
    epoch_times = []
    calc_start = time.time()

    # Words processed across all epochs drive the LR decay (word2vec.c's word_count_actual).
    # Plain Python ints are arbitrary-precision, so these cannot overflow.
    words_processed_total = 0
    total_words_for_training = epochs * total_words

    for epoch in range(epochs):
        epoch_start = time.time()

        for batch_idx, (batch_start, batch_end, batch_word_start, batch_word_end) in enumerate(batch_bounds):
            batch_sentence_count = batch_end - batch_start

            if num_batches > 1:
                print(f"  Epoch {epoch+1}, Batch {batch_idx+1}/{num_batches}: sentences {batch_start:,}-{batch_end:,}")

            # One rate per kernel launch, from the global progress before this batch.
            current_lr = _skipgram_lr_at(words_processed_total, total_words_for_training,
                                         lr_max, lr_min, epochs)

            if num_batches > 1 and batch_idx == 0:
                progress = (words_processed_total / total_words_for_training * 100) if total_words_for_training > 0 else 0.0
                print(f"    Learning rate: {current_lr:.6f} (decaying linearly, progress: {progress:.1f}%)")

            # Slice this batch on the CPU; offsets are rebased to start at 0 and narrowed
            # to int32 (range verified when batch_bounds was built).
            batch_lens = lens[batch_start:batch_end]
            batch_offs_local = (offs[batch_start:batch_end] - batch_word_start).astype(np.int32)
            batch_inps_local = inps[batch_word_start:batch_word_end]

            batch_lens_cuda = cuda.to_device(batch_lens)
            batch_offs_cuda = cuda.to_device(batch_offs_local)
            batch_inps_cuda = cuda.to_device(batch_inps_local)

            # Per-batch calc_aux buffer: one row of embed_dim floats per sentence
            batch_calc_aux = np.zeros((batch_sentence_count, embed_dim), dtype=np.float32)
            batch_calc_aux_cuda = cuda.to_device(batch_calc_aux)

            # Deterministic, distinct RNG seed per (epoch, batch)
            batch_random_states_cuda = c_random.create_xoroshiro128p_states(
                batch_sentence_count, seed=seed + epoch * 10000 + batch_idx * 100
            )

            batch_blocks = math.ceil(batch_sentence_count / cuda_threads_per_block)
            calc_skipgram[batch_blocks, cuda_threads_per_block](
                batch_sentence_count, c, k, current_lr, w1_cuda, w2_cuda, batch_calc_aux_cuda,
                batch_random_states_cuda, ssw_cuda, negs_cuda, batch_inps_cuda,
                batch_offs_cuda, batch_lens_cuda,
                use_hs, syn1_param, codes_param, points_param, lengths_param,
                exp_table_cuda, EXP_TABLE_SIZE, MAX_EXP)

            # Approximation: counts all words in the batch, ignoring subsampling
            words_processed_total += batch_word_end - batch_word_start

            # Release this batch's device arrays
            del batch_lens_cuda, batch_offs_cuda, batch_inps_cuda, batch_calc_aux_cuda, batch_random_states_cuda

        # Kernel launches are asynchronous; wait so the epoch time is real
        cuda.synchronize()
        epoch_time = time.time() - epoch_start
        epoch_times.append(epoch_time)

        final_lr = _skipgram_lr_at(words_processed_total, total_words_for_training,
                                   lr_max, lr_min, epochs)
        progress_percent = (words_processed_total / total_words_for_training * 100) if total_words_for_training > 0 else 0.0
        print(f"  Epoch {epoch+1} completed in {epoch_time:.2f}s (LR: {final_lr:.6f}, Progress: {progress_percent:.1f}%)")

    print(f"\nSkip-gram training completed!")
    print(f"Epoch times - Min: {min(epoch_times):.2f}s, Avg: {np.mean(epoch_times):.2f}s, Max: {max(epoch_times):.2f}s")
    print(f"Total training time: {time.time()-calc_start:.2f}s")
    print(f"Total time: {time.time()-start:.2f}s")

    print_norms(w1_cuda)

    # Save results
    stats["epoch_time_min_seconds"] = min(epoch_times)
    stats["epoch_time_avg_seconds"] = np.mean(epoch_times)
    stats["epoch_time_max_seconds"] = max(epoch_times)
    stats["epoch_time_total_seconds"] = sum(epoch_times)
    stats["epoch_times_all_seconds"] = epoch_times

    print(f"Saving Skip-gram vectors to: {out_file_path}")
    write_vectors(w1_cuda, vocab, out_file_path)

    print(f"Saving parameters to: {params_path}")
    write_json(params, params_path)

    print(f"Saving statistics to: {stats_path}")
    write_json(stats, stats_path)

    print("Skip-gram training completed successfully!")


# V. **CBOW Implementation**

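This cell implements the Continuous Bag-of-Words (CBOW) architecture for Word2Vec training using GPU-accelerated CUDA kernels. CBOW predicts a center word from the average of its surrounding context-word vectors. This makes it faster to train than Skip-gram, but typically less accurate for rare words. Both Hierarchical Softmax and Negative Sampling are supported, though not in the same run.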

In [None]:
import math
import os
import time
from typing import List, Tuple, Dict, Any

from numba import cuda
from numba.cuda import random as c_random
import numpy as np
from numpy import ndarray

@cuda.jit
def calc_cbow(
        rows: int,
        c: int,
        k: int,
        learning_rate: float,
        w1,
        w2,
        calc_aux,
        random_states,
        subsample_weights,
        negsample_array,
        inp,
        offsets,
        lengths,
        use_hs,
        syn1,
        codes_array,
        points_array,
        code_lengths,
        exp_table,
        exp_table_size,
        max_exp):
    """
    CUDA kernel for CBOW training
    Based on word2vec.c CBOW implementation from the original source code of the Word2Vec paper
    Supports both Hierarchical Softmax and Negative Sampling

    Thread `idx` processes the words inp[offsets[idx] : offsets[idx] + lengths[idx]].
    For each center word that survives subsampling it draws a window radius
    r = ceil(u * c), with u taken from the thread's RNG state. It then gathers up to
    64 context words within r positions of the center (clipped to the sentence and
    excluding the center itself) and hands them to step_cbow.
    """
    idx = cuda.threadIdx.x + cuda.blockIdx.x * cuda.blockDim.x
    if idx >= rows:
        return
    le = lengths[idx]
    off = offsets[idx]

    # Per-thread context buffer. cuda.local.array needs a compile-time size, so the
    # capacity is a literal. The guards below use the real capacity, so windows with
    # c <= 32 (at most 2*c = 64 words) are never truncated.
    context_words = cuda.local.array(64, dtype=np.int32)
    max_context = context_words.shape[0]

    for centre in range(0, le):
        word_idx = inp[off + centre]
        prob_to_reject = subsample_weights[word_idx]
        rnd = c_random.xoroshiro128p_uniform_float32(random_states, idx)

        # Train on this center word only if the draw exceeds its rejection probability
        if rnd > prob_to_reject:
            r_f = c_random.xoroshiro128p_uniform_float32(random_states, idx)
            r: int = math.ceil(r_f * c)

            context_count = 0

            # Context before center word
            for context_pre in range(max(0, centre - r), centre):
                if context_count < max_context:
                    context_words[context_count] = inp[off + context_pre]
                    context_count += 1

            # Context after center word
            for context_post in range(centre + 1, min(le, centre + 1 + r)):
                if context_count < max_context:
                    context_words[context_count] = inp[off + context_post]
                    context_count += 1

            # Only proceed if we have context words
            if context_count > 0:
                step_cbow(idx, w1, w2, calc_aux, context_words, context_count,
                          word_idx, k, learning_rate, negsample_array, random_states,
                          use_hs, syn1, codes_array, points_array, code_lengths,
                          exp_table, exp_table_size, max_exp)


@cuda.jit(device=True)
def fast_sigmoid(f, exp_table, exp_table_size, max_exp):
    """
    Approximate sigmoid(f) with a lookup into a precomputed table.

    Values at or below -max_exp give 0.0 and values at or above max_exp give 1.0.
    In between, (-max_exp, max_exp) is mapped linearly onto table slots. The slot
    index is clamped into [0, exp_table_size - 1] before the lookup. This follows
    the word2vec.c exp-table lookup.
    """
    if f <= -max_exp:
        return 0.0
    if f >= max_exp:
        return 1.0

    slots_per_unit = exp_table_size / max_exp / 2.0
    slot = int((f + max_exp) * slots_per_unit)
    slot = min(max(slot, 0), exp_table_size - 1)
    return exp_table[slot]


@cuda.jit(device=True)
def step_cbow(thread_idx, w1, w2, calc_aux, context_words, context_count,
              center_word, k, learning_rate, negsample_array, random_states,
              use_hs, syn1, codes_array, points_array, code_lengths,
              exp_table, exp_table_size, max_exp):
    """
    Device function for CBOW gradient calculation
    Based on word2vec.c CBOW implementation from the original source code of the Word2Vec paper
    Supports both Hierarchical Softmax and Negative Sampling

    Performs one update for `center_word` given its first `context_count` entries
    of `context_words`:
      1. neu1 = mean of the context rows of w1.
      2. HS (if use_hs): walk the center word's Huffman path, updating syn1 rows.
      3. NS (if k > 0): one positive update on w2[center_word] and k negative
         updates on rows drawn uniformly from negsample_array.
      4. Add the accumulated error neu1e to every context row of w1. As in
         word2vec.c, neu1e is not divided by the context count.

    `calc_aux` is not used here because errors accumulate in a per-thread local
    buffer; it is kept so the signature matches the caller. `thread_idx` selects this
    thread's state in `random_states`.
    """
    emb_dim = w1.shape[1]
    negs_arr_len = len(negsample_array)

    # Per-thread working buffers with a fixed compile-time capacity
    neu1 = cuda.local.array(1000, dtype=np.float32)   # mean context vector
    neu1e = cuda.local.array(1000, dtype=np.float32)  # error accumulation

    # An embedding wider than the buffers would write out of bounds (no bounds checks
    # on device). Skip the update rather than corrupt memory; the host should reject
    # embed_dim > 1000 before launching.
    if emb_dim > neu1.shape[0]:
        return

    # Initialize neu1 and neu1e
    for i in range(emb_dim):
        neu1[i] = 0.0
        neu1e[i] = 0.0

    # Average context word vectors
    for i in range(emb_dim):
        for ctx_idx in range(context_count):
            neu1[i] += w1[context_words[ctx_idx], i]
        neu1[i] /= context_count

    # Hierarchical Softmax (if enabled)
    if use_hs:
        codelen = code_lengths[center_word]
        max_code_len = codes_array.shape[1]  # Never index past the code array width
        for d in range(codelen):
            if d >= max_code_len:
                break
            node_idx = points_array[center_word, d]
            if node_idx < 0:
                continue

            # Dot product: neu1 . syn1[node]
            f = 0.0
            for i in range(emb_dim):
                f += neu1[i] * syn1[node_idx, i]

            # Same early skip as word2vec.c: no update when sigmoid is saturated
            if f <= -max_exp:
                continue
            if f >= max_exp:
                continue

            sigmoid_val = fast_sigmoid(f, exp_table, exp_table_size, max_exp)

            # Code bit is 0 or 1; negative entries mark padding
            code_bit = codes_array[center_word, d]
            if code_bit < 0:
                continue

            # g = (1 - code_bit - sigmoid) * learning_rate
            g = (1.0 - float(code_bit) - sigmoid_val) * learning_rate

            # Propagate errors output -> hidden
            for i in range(emb_dim):
                neu1e[i] += g * syn1[node_idx, i]

            # Learn weights hidden -> output
            for i in range(emb_dim):
                syn1[node_idx, i] += g * neu1[i]

    # Negative Sampling (if enabled)
    if k > 0:
        # Positive sample: predict center_word
        dot_xy = 0.0
        for i in range(emb_dim):
            dot_xy += neu1[i] * w2[center_word, i]
        s_xdy_m1 = fast_sigmoid(dot_xy, exp_table, exp_table_size, max_exp) - 1.0

        # Update w2[center_word] and accumulate neu1e
        for i in range(emb_dim):
            neu1e[i] += -learning_rate * s_xdy_m1 * w2[center_word, i]
            w2[center_word, i] -= learning_rate * s_xdy_m1 * neu1[i]

        # Negative samples
        for neg_sample in range(0, k):
            rnd = c_random.xoroshiro128p_uniform_float32(random_states, thread_idx)
            q_idx: int = int(math.floor(negs_arr_len * rnd))
            neg = negsample_array[q_idx]
            # word2vec.c skips a negative draw that equals the positive target
            # (`if (target == word) continue;`). Treating it as a negative would
            # partly undo the positive update above.
            if neg == center_word:
                continue
            dot_xq = 0.0
            for i in range(emb_dim):
                dot_xq += neu1[i] * w2[neg, i]
            s_dxq = fast_sigmoid(dot_xq, exp_table, exp_table_size, max_exp)

            # Update w2[neg] and accumulate neu1e
            for i in range(emb_dim):
                neu1e[i] -= learning_rate * s_dxq * w2[neg, i]
                w2[neg, i] -= learning_rate * s_dxq * neu1[i]

    # Backprop neu1e to all context words (no gradient clipping, as in word2vec.c)
    for ctx_idx in range(context_count):
        for i in range(emb_dim):
            w1[context_words[ctx_idx], i] += neu1e[i]


def train_cbow(
        data_path: str,
        out_file_path: str,
        epochs: int,
        embed_dim: int = 100,
        min_occurs: int = 3,
        c: int = 5,
        k: int = 5,
        t: float = 1e-5,
        vocab_freq_exponent: float = 0.75,
        lr_max: float = 0.025,
        lr_min: float = 0.0025,
        cuda_threads_per_block: int = 32,
        hs: int = 0,
        max_memory_gb: float = 70.0,
        max_words: int = None,
        vocab: list = None,
        w_to_i: dict = None,
        word_counts: list = None,
        ssw: np.ndarray = None,
        negs: np.ndarray = None):
    """
    Train CBOW model
    Based on word2vec.c CBOW implementation from the original source code of the Word2Vec paper

    Args:
        hs: Hierarchical Softmax flag (0=NS only, 1=HS only). Cannot combine with k>0
        k: Negative sampling count (0=HS only, >0=NS only). Cannot combine with hs=1
        max_memory_gb: Maximum GPU memory usage in GB. If estimated memory exceeds this,
                       the dataset will be automatically split into batches for processing
                       Default: 70.0 GB (safe for A100 80GB GPU)

    Raises:
        ValueError: If both hs=1 and k>0 are specified (HS and NS cannot be combined)
    """
    # Validate: HS and NS cannot be used together
    if hs == 1 and k > 0:
        raise ValueError(
            "Error: Cannot use HS (hs=1) and Negative Sampling (k>0) together. "
            "Please choose either HS only (hs=1, k=0) or NS only (hs=0, k>0)"
        )

    params = {
        "model_type": "cbow",
        "w2v_version": W2V_VERSION,
        "data_path": data_path,
        "out_file_path": out_file_path,
        "epochs": epochs,
        "embed_dim": embed_dim,
        "min_occurs": min_occurs,
        "c": c,
        "k": k,
        "t": t,
        "vocab_freq_exponent": vocab_freq_exponent,
        "lr_max": lr_max,
        "lr_min": lr_min,
        "cuda_threads_per_block": cuda_threads_per_block,
        "hs": hs
    }
    stats = {}
    params_path = out_file_path + "_params.json"
    stats_path = out_file_path + "_stats.json"

    seed = 12345

    # Adjust learning rate based on training method
    original_lr_max = lr_max
    original_lr_min = lr_min

    # Learning rate handling: HS only and NS only use the same learning rate
    # (as per word2vec.c original implementation)
    # No special adjustment needed for either method

    # Learning rate schedule
    # For multiple epochs: decrease between epochs
    # For all epochs: decrease LINEARLY within epoch (as per word2vec.c)
    if epochs > 1:
        lr_step = (lr_max - lr_min) / (epochs - 1)
    else:
        lr_step = 0.0  # Not used for single epoch (LR decays within epoch)

    print(f"CBOW Training Parameters:")
    print(f"Seed: {seed}")
    print(f"Window size: {c}")
    if hs == 1:
        print(f"Hierarchical Softmax: Enabled")
    if k > 0:
        print(f"Negative samples: {k}")
    if original_lr_max != lr_max:
        print(f"Learning rate adjusted: {original_lr_max} -> {lr_max} (reduced for stability)")
    if epochs == 1:
        print(f"Learning rate: {lr_max} -> ~0 (will decrease linearly within epoch, as per word2vec.c)")
    else:
        print(f"Learning rate: {lr_max} -> {lr_min} (step: {lr_step:.6f} between epochs, also decreases linearly within each epoch)")
    print(f"Embedding dimension: {embed_dim}")
    print(f"Min word count: {min_occurs}")

    # Start timing for total execution
    start = time.time()

    # Build vocabulary if not provided (for reuse when training both models)
    if vocab is None or w_to_i is None or word_counts is None:
        print(f"\nBuilding vocabulary from: {data_path}")
        vocab_start = time.time()
        vocab, w_to_i, word_counts = handle_vocab(data_path, min_occurs, freq_exponent=vocab_freq_exponent, use_cache=True)
        vocab_size = len(vocab)
        build_time = time.time() - vocab_start
        print(f"Vocabulary {'loaded from cache' if build_time < 1.0 else 'built'} in {build_time:.2f}s. Vocab size: {vocab_size:,}")
    else:
        vocab_size = len(vocab)
        print(f"\nUsing pre-built vocabulary. Vocab size: {vocab_size:,}")

    # Build subsampling weights and negative sampling array if not provided
    if ssw is None or negs is None:
        ssw, negs = get_subsampling_weights_and_negative_sampling_array(vocab, t=t)

    # Create exp table
    print("Creating exp table for fast sigmoid...")
    exp_table = create_exp_table(EXP_TABLE_SIZE, MAX_EXP)

    # Setup Hierarchical Softmax if enabled
    use_hs = (hs == 1)
    syn1_cuda = None
    codes_array_cuda = None
    points_array_cuda = None
    code_lengths_cuda = None

    if use_hs:
        print("Creating Huffman tree for Hierarchical Softmax...")
        hs_start = time.time()
        codes_array, points_array, code_lengths = create_huffman_tree(word_counts, MAX_CODE_LENGTH)
        syn1 = init_hs_weight_matrix(vocab_size, embed_dim)
        print(f"Huffman tree created in {time.time() - hs_start:.2f}s")
        print(f" -Codes array shape: {codes_array.shape}")
        print(f" -Points array shape: {points_array.shape}")
        print(f" -Syn1 matrix shape: {syn1.shape}")

    data_files = get_data_file_names(data_path, seed=seed)
    print(f"Processing {len(data_files)} data files...")
    if max_words is not None:
        print(f"⚠️ Limiting to {max_words:,} total words (will stop early if reached)")
    inps_, offs_, lens_ = read_all_data_files_ever(data_path, data_files, w_to_i, max_words=max_words)
    inps, offs, lens = (np.asarray(inps_, dtype=np.int32),
                       np.asarray(offs_, dtype=np.int32),
                       np.asarray(lens_, dtype=np.int32))
    sentence_count = len(lens)
    total_words = len(inps)  # Total words for LR decay calculation

    print(f"Data loaded: {sentence_count:,} sentences, {total_words:,} total words")

    # Initialize weight matrices
    data_init_start = time.time()
    w1, w2 = init_weight_matrices(vocab_size, embed_dim, seed=seed)
    data_size_weights = 4 * (w1.size + w2.size)
    data_size_inputs = 4 * (inps.size + offs.size + lens.size + ssw.size + negs.size)

    # Calculate memory usage and determine batch size
    weights_gb = data_size_weights / (1024**3)
    inputs_gb = data_size_inputs / (1024**3)

    # Estimate calc_aux memory for full dataset
    calc_aux_size_full = sentence_count * embed_dim * 4
    calc_aux_gb_full = calc_aux_size_full / (1024**3)
    total_memory_gb = weights_gb + inputs_gb + calc_aux_gb_full

    # Determine if batch processing is needed
    use_batch_processing = (total_memory_gb > max_memory_gb)

    if use_batch_processing:
        # Calculate batch size based on available memory
        available_memory_gb = max_memory_gb - weights_gb - inputs_gb
        # Reserve 5GB for overhead
        available_memory_gb = max(1.0, available_memory_gb - 5.0)

        # Calculate max sentences per batch
        bytes_per_sentence = embed_dim * 4  # float32
        max_batch_sentences = int((available_memory_gb * 1024**3) / bytes_per_sentence)

        # Round down to nice numbers for better performance
        if max_batch_sentences >= 10_000_000:
            batch_size = 10_000_000
        elif max_batch_sentences >= 5_000_000:
            batch_size = 5_000_000
        elif max_batch_sentences >= 2_000_000:
            batch_size = 2_000_000
        elif max_batch_sentences >= 1_000_000:
            batch_size = 1_000_000
        else:
            batch_size = max(100_000, max_batch_sentences)

        num_batches = math.ceil(sentence_count / batch_size)
        batch_aux_gb = (batch_size * embed_dim * 4) / (1024**3)
        batch_total_gb = weights_gb + inputs_gb + batch_aux_gb

        print(f"\n⚠️ Memory usage would be {total_memory_gb:.1f} GB (exceeds {max_memory_gb} GB limit)")
        print(f"Using batch processing: {num_batches} batches, {batch_size:,} sentences/batch")
        print(f"Memory per batch: {batch_total_gb:.1f} GB (calc_aux: {batch_aux_gb:.1f} GB)")
    else:
        batch_size = sentence_count
        num_batches = 1
        print(f"\n✅ Memory usage: {total_memory_gb:.1f} GB (within {max_memory_gb} GB limit)")
        print(f"Processing all {sentence_count:,} sentences in one batch")

    blocks: int = math.ceil(batch_size / cuda_threads_per_block)
    print(f"CUDA config: {cuda_threads_per_block} threads/block, {blocks} blocks per batch")

    # Transfer to GPU - Transfer weights and vocab arrays (these are shared across batches)
    print("Transferring data to GPU...")
    data_transfer_start = time.time()
    ssw_cuda, negs_cuda = cuda.to_device(ssw), cuda.to_device(negs)
    w1_cuda, w2_cuda = cuda.to_device(w1), cuda.to_device(w2)
    exp_table_cuda = cuda.to_device(exp_table)

    # Keep input arrays on CPU - will slice and transfer per batch
    # This saves GPU memory

    if use_hs:
        syn1_cuda = cuda.to_device(syn1)
        codes_array_cuda = cuda.to_device(codes_array)
        points_array_cuda = cuda.to_device(points_array)
        code_lengths_cuda = cuda.to_device(code_lengths)

    print(f"Data transfer completed in {time.time()-data_transfer_start:.2f}s")

    stats["sentence_count"] = len(lens)
    stats["word_count"] = len(inps)
    stats["vocab_size"] = vocab_size
    stats["approx_data_size_weights"] = data_size_weights
    stats["approx_data_size_inputs"] = data_size_inputs
    stats["use_batch_processing"] = use_batch_processing
    if use_batch_processing:
        stats["batch_size"] = batch_size
        stats["num_batches"] = num_batches
        batch_aux_size = batch_size * embed_dim * 4
        stats["approx_data_size_aux_per_batch"] = batch_aux_size
        stats["approx_data_size_total"] = data_size_weights + data_size_inputs + batch_aux_size
    else:
        data_size_aux = 4 * (sentence_count * embed_dim)
        stats["approx_data_size_aux"] = data_size_aux
        stats["approx_data_size_total"] = data_size_weights + data_size_inputs + data_size_aux

    # Prepare HS parameters (use dummy arrays if HS disabled)
    if not use_hs:
        # Create dummy arrays for HS (will not be used, but needed for kernel signature)
        dummy_syn1 = cuda.device_array((1, embed_dim), dtype=np.float32)
        dummy_codes = cuda.device_array((vocab_size, MAX_CODE_LENGTH), dtype=np.int32)
        dummy_points = cuda.device_array((vocab_size, MAX_CODE_LENGTH), dtype=np.int32)
        dummy_lengths = cuda.device_array(vocab_size, dtype=np.int32)
        syn1_param = dummy_syn1
        codes_param = dummy_codes
        points_param = dummy_points
        lengths_param = dummy_lengths
    else:
        syn1_param = syn1_cuda
        codes_param = codes_array_cuda
        points_param = points_array_cuda
        lengths_param = code_lengths_cuda

    print_norms(w1_cuda)
    print(f"\nStarting CBOW training - {epochs} epochs...")
    epoch_times = []
    calc_start = time.time()

    # Track total words processed across all epochs (as per word2vec.c)
    # Learning rate decays based on total words processed, not per epoch
    # Use int64 to avoid overflow with large datasets and multiple epochs
    words_processed_total = np.int64(0)
    total_words_for_training = np.int64(epochs) * np.int64(total_words)

    for epoch in range(0, epochs):
        epoch_start = time.time()

        # Process each batch
        for batch_idx in range(num_batches):
            batch_start = batch_idx * batch_size
            batch_end = min((batch_idx + 1) * batch_size, sentence_count)
            batch_sentence_count = batch_end - batch_start

            if num_batches > 1:
                print(f"  Epoch {epoch+1}, Batch {batch_idx+1}/{num_batches}: sentences {batch_start:,}-{batch_end:,}")

            # Calculate word offset for this batch (offsets are cumulative)
            batch_word_start = offs[batch_start] if batch_start < len(offs) else 0
            batch_word_end = offs[batch_end] if batch_end < len(offs) else len(inps)
            batch_word_count = batch_word_end - batch_word_start

            # Calculate learning rate for this batch (linear decay as per word2vec.c)
            # Formula from word2vec.c: alpha = starting_alpha * (1 - word_count_actual / (iter * train_words + 1))
            # word_count_actual is total words processed across all epochs
            # This ensures LR decreases linearly from lr_max to ~0 over entire training
            denominator = total_words_for_training + 1
            current_lr = lr_max * (1.0 - words_processed_total / denominator) if denominator > 0 else lr_max

            # Apply minimum threshold (as per word2vec.c: min = starting_alpha * 0.0001)
            min_lr_threshold = lr_max * 0.0001
            current_lr = max(current_lr, min_lr_threshold)

            # Also apply lr_min as additional constraint (for multi-epoch training)
            if epochs > 1:
                current_lr = max(current_lr, lr_min)

            if num_batches > 1 and batch_idx == 0:
                print(f"    Learning rate: {current_lr:.6f} (decaying linearly, progress: {words_processed_total/total_words_for_training*100:.1f}%)")

            # Create batch arrays (slicing from CPU arrays)
            batch_lens = lens[batch_start:batch_end]
            batch_offs_local = offs[batch_start:batch_end] - batch_word_start  # Adjust offsets to start from 0
            batch_inps_local = inps[batch_word_start:batch_word_end]

            # Transfer batch arrays to GPU
            batch_lens_cuda = cuda.to_device(batch_lens)
            batch_offs_cuda = cuda.to_device(batch_offs_local)
            batch_inps_cuda = cuda.to_device(batch_inps_local)

            # Create calc_aux for this batch
            batch_calc_aux = np.zeros((batch_sentence_count, embed_dim), dtype=np.float32)
            batch_calc_aux_cuda = cuda.to_device(batch_calc_aux)

            # Create random states for this batch
            batch_random_states_cuda = c_random.create_xoroshiro128p_states(
                batch_sentence_count, seed=seed + epoch * 10000 + batch_idx * 100
            )

            # Launch CUDA kernel for this batch with current learning rate
            batch_blocks = math.ceil(batch_sentence_count / cuda_threads_per_block)
            calc_cbow[batch_blocks, cuda_threads_per_block](
                batch_sentence_count, c, k, current_lr, w1_cuda, w2_cuda, batch_calc_aux_cuda,
                batch_random_states_cuda, ssw_cuda, negs_cuda, batch_inps_cuda,
                batch_offs_cuda, batch_lens_cuda,
                use_hs, syn1_param, codes_param, points_param, lengths_param,
                exp_table_cuda, EXP_TABLE_SIZE, MAX_EXP)

            # Update total words processed counter (as per word2vec.c)
            # Note: Actual words processed may vary due to subsampling, but this is an approximation
            # Use int64 to avoid overflow with large datasets and multiple epochs
            words_processed_total = np.int64(words_processed_total) + np.int64(batch_word_count)

            # Free batch arrays from GPU memory
            del batch_lens_cuda, batch_offs_cuda, batch_inps_cuda, batch_calc_aux_cuda, batch_random_states_cuda

        # Synchronize after all batches
        sync_start = time.time()
        cuda.synchronize()
        epoch_time = time.time() - epoch_start
        epoch_times.append(epoch_time)

        # Final LR after epoch (using same formula as word2vec.c)
        denominator = total_words_for_training + 1
        final_lr = lr_max * (1.0 - words_processed_total / denominator) if denominator > 0 else lr_max
        final_lr = max(final_lr, lr_max * 0.0001)
        if epochs > 1:
            final_lr = max(final_lr, lr_min)

        progress_percent = (words_processed_total / total_words_for_training * 100) if total_words_for_training > 0 else 0.0
        print(f"  Epoch {epoch+1} completed in {epoch_time:.2f}s (LR: {final_lr:.6f}, Progress: {progress_percent:.1f}%)")

    print(f"\nCBOW training completed!")
    print(f"Epoch times - Min: {min(epoch_times):.2f}s, Avg: {np.mean(epoch_times):.2f}s, Max: {max(epoch_times):.2f}s")
    print(f"Total training time: {time.time()-calc_start:.2f}s")
    print(f"Total time: {time.time()-start:.2f}s")

    print_norms(w1_cuda)

    # Save results
    stats["epoch_time_min_seconds"] = min(epoch_times)
    stats["epoch_time_avg_seconds"] = np.mean(epoch_times)
    stats["epoch_time_max_seconds"] = max(epoch_times)
    stats["epoch_time_total_seconds"] = sum(epoch_times)
    stats["epoch_times_all_seconds"] = epoch_times

    print(f"Saving CBOW vectors to: {out_file_path}")
    write_vectors(w1_cuda, vocab, out_file_path)

    print(f"Saving parameters to: {params_path}")
    write_json(params, params_path)

    print(f"Saving statistics to: {stats_path}")
    write_json(stats, stats_path)

    print("CBOW training completed successfully!")


# VI. **Evaluation**

This cell contains functions for evaluating trained Word2Vec embeddings using standard benchmarks. It provides word analogy testing, similarity analysis, and model comparison capabilities.

## **Output Format**

- Accuracy metrics: Float values (0.0 to 1.0)
- Detailed results: List of dictionaries with correct/incorrect examples per category
- Comparison JSON: Includes accuracy, training statistics, and summary metrics

In [None]:
import json
import os
import time
import re
from typing import List, Tuple, Dict, Any
import requests
from gensim.models import KeyedVectors


def download_questions_words(output_path: str = "./data/questions-words.txt",
                             timeout: float = 60.0) -> str:
    """
    Download questions-words.txt (the word-analogy test set) if not present.

    The download is streamed into ``<output_path>.part`` and only renamed to
    ``output_path`` after the transfer completes. An interrupted or failed
    download therefore never leaves a truncated file behind that a later
    call would mistake for a complete one.

    Args:
        output_path: Destination of the text file. Missing parent
            directories are created.
        timeout: Seconds passed to ``requests.get`` as its timeout, so a
            stalled server cannot hang the pipeline indefinitely.

    Returns:
        ``output_path``.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on other network failures/timeouts.
    """
    if os.path.isfile(output_path):
        print(f"Questions-words.txt already exists at: {output_path}")
        return output_path

    url = "https://raw.githubusercontent.com/nicholas-leonard/word2vec/master/questions-words.txt"
    print(f"Downloading questions-words.txt from {url}...")

    # open() fails if the target directory does not exist yet (e.g. a fresh ./data).
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    tmp_path = output_path + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(tmp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # skip empty chunks
                        f.write(chunk)
        # Atomic rename: output_path only ever exists as a complete file.
        os.replace(tmp_path, output_path)
    finally:
        # After a successful replace the temp file is gone; otherwise drop the partial data.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    print(f"Questions-words.txt downloaded to: {output_path}")
    return output_path


def word_analogy_test(vectors_path: str, questions_path: str = None) -> Tuple[dict, List[Dict]]:
    """
    Run the word-analogy benchmark (questions-words.txt) on trained vectors.

    Questions are scored by gensim's ``KeyedVectors.evaluate_word_analogies``
    (case-insensitive). The per-category results are then split into two
    groups:
      - semantic: capital-*, currency, city-in-state, family
      - syntactic: every other category, normally gram1-gram9

    Args:
        vectors_path: Path to vectors in word2vec *text* format (loaded with
            ``binary=False``).
        questions_path: Path to questions-words.txt. When None, the path
            comes from ``download_questions_words()``.

    Returns:
        ``(scores, details)``:
          - ``scores``: dict with ``semantic_accuracy``,
            ``syntactic_accuracy`` and ``total_accuracy``. Each is a
            fraction, or 0 when its group has no scored questions.
          - ``details``: the section list returned by gensim, unchanged.
    """
    if questions_path is None:
        questions_path = download_questions_words()

    print(f"Loading vectors from: {vectors_path}")
    start = time.time()
    vecs = KeyedVectors.load_word2vec_format(vectors_path, binary=False)
    print(f"Vectors loaded in {time.time() - start:.2f}s")

    print(f"Running word analogy test with: {questions_path}")
    eval_start = time.time()
    _, details = vecs.evaluate_word_analogies(questions_path, case_insensitive=True)
    print(f"Word analogy test completed in {time.time() - eval_start:.2f}s")

    # Semantic categories in questions-words.txt (5):
    #   capital-common-countries, capital-world -> "capital"
    #   currency, city-in-state, family
    # Syntactic categories (9): gram1-gram9, i.e. everything else.
    semantic_keywords = ("capital", "currency", "family", "city-in-state")

    semantic_correct = semantic_total = 0
    syntactic_correct = syntactic_total = 0

    for cat in details:
        section = cat["section"].strip().lower()

        # gensim appends an aggregate "Total accuracy" section that repeats
        # every per-category result. It matches no semantic keyword, so
        # counting it would add every question to the syntactic bucket a
        # second time and corrupt syntactic_accuracy.
        if section == "total accuracy":
            continue

        correct = len(cat["correct"])
        total = correct + len(cat["incorrect"])

        if any(keyword in section for keyword in semantic_keywords):
            semantic_correct += correct
            semantic_total += total
        else:
            syntactic_correct += correct
            syntactic_total += total

    semantic_acc = semantic_correct / semantic_total if semantic_total > 0 else 0
    syntactic_acc = syntactic_correct / syntactic_total if syntactic_total > 0 else 0

    # Overall accuracy over every scored question (semantic + syntactic).
    scored_total = semantic_total + syntactic_total
    total_acc = (semantic_correct + syntactic_correct) / scored_total if scored_total > 0 else 0

    return (
        {
            "semantic_accuracy": semantic_acc,
            "syntactic_accuracy": syntactic_acc,
            "total_accuracy": total_acc
        },
        details
    )

def similarity_test(vectors_path: str, test_words: List[str] = None) -> Dict[str, Any]:
    """
    Report nearest neighbours and a few fixed word-pair similarities.

    Args:
        vectors_path: Path to vectors in word2vec text format (``binary=False``).
        test_words: Words whose top-5 neighbours are printed. When None, a
            small built-in list (king, queen, man, ...) is used.

    Returns:
        Dict mapping each test word to its ``most_similar(topn=5)`` list, or
        ``[]`` when the word is out of vocabulary. The extra key
        ``"pair_similarities"`` maps ``"w1-w2"`` to the similarity score, or
        None when either word is missing.
    """
    words = test_words
    if words is None:
        words = ["king", "queen", "man", "woman", "computer", "science", "university", "student"]

    print(f"Loading vectors for similarity test: {vectors_path}")
    model = KeyedVectors.load_word2vec_format(vectors_path, binary=False)

    report = {}

    print("\nMost similar words:")
    for query in words:
        if query not in model:
            print(f"Word '{query}' not found in vocabulary")
            report[query] = []
            continue
        neighbours = model.most_similar(query, topn=5)
        report[query] = neighbours
        print(f"\n{query}:")
        for neighbour, score in neighbours:
            print(f"  {neighbour}: {score:.4f}")

    # Fixed probe pairs: related, antonym-ish and topical word pairs.
    candidate_pairs = (
        ("king", "queen"),
        ("man", "woman"),
        ("computer", "science"),
        ("university", "student"),
        ("good", "bad"),
        ("big", "small"),
    )

    print("\nWord pair similarities:")
    pair_scores = {}
    for left, right in candidate_pairs:
        label = f"{left}-{right}"
        if left in model and right in model:
            value = model.similarity(left, right)
            pair_scores[label] = value
            print(f"  {left} - {right}: {value:.4f}")
        else:
            print(f"  {left} - {right}: One or both words not found")
            pair_scores[label] = None

    report["pair_similarities"] = pair_scores
    return report


def save_evaluation_results(results: Dict[str, Any], output_path: str):
    """
    Save evaluation results to a UTF-8 JSON file.

    Values are made JSON-serializable recursively through dicts, lists and
    tuples:
      - every NumPy scalar (``np.generic``: float32, int64, bool_, ...)
        becomes the equivalent native Python value;
      - NumPy arrays become nested lists;
      - tuples become lists (JSON arrays);
      - NumPy scalar dict keys become native Python keys.

    Args:
        results: Nested structure of dicts/lists/tuples/scalars to save.
        output_path: Destination file. Parent directories are created when
            the path has a directory component; a bare filename is written
            to the current working directory.
    """
    import numpy as np

    def convert_numpy_types(obj):
        """
        Recursively convert NumPy values into JSON-serializable Python objects.
        """
        if isinstance(obj, np.generic):
            # .item() maps any NumPy scalar to the matching Python type.
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, dict):
            return {
                (key.item() if isinstance(key, np.generic) else key): convert_numpy_types(value)
                for key, value in obj.items()
            }
        if isinstance(obj, (list, tuple)):
            return [convert_numpy_types(item) for item in obj]
        return obj

    # Convert numpy types to Python native types
    results_converted = convert_numpy_types(results)

    # os.path.dirname("file.json") is "" and os.makedirs("") raises, so only
    # create a directory when the path actually has one.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(results_converted, f, indent=2, ensure_ascii=False)
    print(f"Evaluation results saved to: {output_path}")


def compare_models(skipgram_path: str, cbow_path: str, output_path: str = "./output/model_comparison.json",
                   sg_acc: float = None, sg_details: List[Dict] = None,
                   cbow_acc: float = None, cbow_details: List[Dict] = None):
    """
    Compare Skip-gram and CBOW models on analogy accuracy and training time.

    A model is re-evaluated with ``word_analogy_test`` unless both its
    accuracy and its details are supplied. Training statistics are read from
    ``<vectors path>_stats.json`` when that file exists; otherwise an empty
    dict is used.

    Args:
        skipgram_path: Path to Skip-gram vectors.
        cbow_path: Path to CBOW vectors.
        output_path: Where the comparison JSON is written.
        sg_acc: Pre-computed Skip-gram total accuracy (optional).
        sg_details: Pre-computed Skip-gram evaluation details (optional).
        cbow_acc: Pre-computed CBOW total accuracy (optional).
        cbow_details: Pre-computed CBOW evaluation details (optional).

    Returns:
        The comparison dict that was saved to ``output_path``.
    """
    def read_stats(vectors_path: str) -> dict:
        # Statistics sit next to the vectors as "<vectors_path>_stats.json".
        stats_file = vectors_path + "_stats.json"
        if not os.path.isfile(stats_file):
            return {}
        with open(stats_file, "r") as handle:
            return json.load(handle)

    print("Comparing Skip-gram vs CBOW models...")

    # Evaluate each model only when its results were not handed in.
    if sg_acc is not None and sg_details is not None:
        print("Using pre-computed Skip-gram accuracy...")
    else:
        print("Evaluating Skip-gram model...")
        sg_scores, sg_details = word_analogy_test(skipgram_path)
        sg_acc = sg_scores["total_accuracy"]

    if cbow_acc is not None and cbow_details is not None:
        print("Using pre-computed CBOW accuracy...")
    else:
        print("Evaluating CBOW model...")
        cbow_scores, cbow_details = word_analogy_test(cbow_path)
        cbow_acc = cbow_scores["total_accuracy"]

    sg_stats = read_stats(skipgram_path)
    cbow_stats = read_stats(cbow_path)

    sg_train_seconds = sg_stats.get("epoch_time_total_seconds", 0)
    cbow_train_seconds = cbow_stats.get("epoch_time_total_seconds", 0)
    acc_gap = sg_acc - cbow_acc

    comparison = {
        "models": {
            "skipgram": {"accuracy": sg_acc, "details": sg_details, "stats": sg_stats},
            "cbow": {"accuracy": cbow_acc, "details": cbow_details, "stats": cbow_stats},
        },
        "summary": {
            "skipgram_accuracy": sg_acc,
            "cbow_accuracy": cbow_acc,
            "accuracy_difference": acc_gap,
            "skipgram_training_time": sg_train_seconds,
            "cbow_training_time": cbow_train_seconds,
            "time_difference": sg_train_seconds - cbow_train_seconds,
        },
    }

    save_evaluation_results(comparison, output_path)

    print("\nModel Comparison Summary:")
    print(f"Skip-gram accuracy: {sg_acc:.4f} ({sg_acc*100:.2f}%)")
    print(f"CBOW accuracy: {cbow_acc:.4f} ({cbow_acc*100:.2f}%)")
    print(f"Difference: {acc_gap:.4f} ({acc_gap*100:.2f}%)")

    # Timing lines only make sense when both stats files were found.
    if sg_stats and cbow_stats:
        print(f"Skip-gram training time: {sg_train_seconds:.2f}s")
        print(f"CBOW training time: {cbow_train_seconds:.2f}s")
        print(f"Time difference: {sg_train_seconds - cbow_train_seconds:.2f}s")

    return comparison


# VII. **Main Pipeline**

This cell contains the main execution pipeline that orchestrates the complete Word2Vec training workflow: dataset download/preprocessing, model training, evaluation, and comparison.

## **Configuration Parameters**

All configuration is set at the top of the cell. Modify these values to customize the training pipeline:

### **Dataset Selection**
- **`use_wmt14`** (default: `False`): 
  - `False` = Use Text8 dataset (smaller, faster)
  - `True` = Use WMT14/WMT15 News Crawl (larger, higher quality)
- **`dataset_name`** (default: `"Text8"`): Display name for the dataset

### **Dataset Size Limits** (only for WMT14)
- **`max_sentences`** (default: `None`): Maximum number of sentences to process (`None` = all)
- **`max_files`** (default: `None`): Maximum number of files to create (`None` = all)
- **`max_words`** (default: `None`): Maximum total words for training (`None` = no limit, e.g., `700000000` for 700M words)

### **Training Method**
- **`use_hs_only`** (default: `True`):
  - `True` = Hierarchical Softmax only (HS=1, k=0)
  - `False` = Negative Sampling only (HS=0, k=5)

### **Model Selection**
- **`should_train_skipgram`** (default: `True`): Train Skip-gram model
- **`should_train_cbow`** (default: `True`): Train CBOW model

### **Phrase Detection**
- **`use_phrases`** (default: `False`): Enable phrase detection (combines frequent bigrams like "new_york")

## **Default Training Parameters**

The pipeline uses the following default training parameters (defined in `base_params`):
- **`epochs`**: 10
- **`embed_dim`**: 300
- **`min_occurs`**: 5 (minimum word count threshold)
- **`c`**: 5 (context window size)
- **`k`**: 0 if `use_hs_only=True`, else 5 (negative samples)
- **`t`**: 1e-5 (subsampling threshold)
- **`vocab_freq_exponent`**: 0.75 (frequency biasing for negative sampling)
- **`lr_max`**: 0.025 (maximum learning rate)
- **`lr_min`**: 0.025 if epochs=1, else 0.0001 (minimum learning rate)
- **`cuda_threads_per_block`**: 32 (tuned for free-tier Google Colab GPUs such as the T4; a larger value such as 512 can be used on an A100, a paid Colab GPU)
- **`hs`**: 1 if `use_hs_only=True`, else 0
- **`max_words`**: Uses value from configuration above

## **Pipeline Steps**

1. **Download & Preprocessing**: Downloads dataset and preprocesses into sentence files
2. **Build Shared Vocabulary** (if training both models): Builds vocabulary once for reuse
3. **Train Skip-gram** (if enabled): Trains Skip-gram model with specified parameters
4. **Train CBOW** (if enabled): Trains CBOW model with specified parameters
5. **Evaluate Skip-gram**: Runs word analogy test and similarity analysis
6. **Evaluate CBOW**: Runs word analogy test and similarity analysis
7. **Compare Models**: Side-by-side comparison if both models were trained

## **Output Files**

All results are saved to `./output/` directory:
- **Vectors**: `vectors_skipgram`, `vectors_cbow` (word2vec format)
- **Evaluations**: `skipgram_eval.json`, `cbow_eval.json`
- **Statistics**: `vectors_skipgram_stats.json`, `vectors_cbow_stats.json`
- **Comparison**: `model_comparison.json` (if both models trained)

## **Notes**

- Vocabulary is cached and reused when training both models (saves time)
- If both models are trained, vocabulary is built once and shared
- Learning rate schedule: Linear decay based on total words processed (matches word2vec.c)
- All steps can be skipped individually by setting corresponding flags to `False`

In [None]:
# ============================================
# CONFIGURATION - edit the values below
# ============================================

# --- Dataset selection ---
# False -> Text8, True -> WMT14 News Crawl
use_wmt14: bool = False
# Display name for logs; keep it in sync with use_wmt14 ("Text8" or "WMT14 News")
dataset_name: str = "Text8"

# --- Dataset size limits (only for WMT14) ---
# Each limit is None (no limit) or a positive integer.
max_sentences = None   # e.g. 100000 sentences
max_files = None       # e.g. 10 files
max_words = None       # e.g. 700000000 (700M words)

# --- Training method ---
# True  -> Hierarchical Softmax only (HS=1, k=0)
# False -> Negative Sampling only (HS=0, k=5)
use_hs_only: bool = True

# --- Models to train ---
should_train_skipgram: bool = True
should_train_cbow: bool = True

# --- Phrase detection ---
use_phrases: bool = False

import os
import sys
from pathlib import Path

def print_section_header(title: str):
    """
    Print ``title`` (indented two spaces) between two 60-character '=' rules.

    A blank line is emitted before the first rule.
    """
    rule = "=" * 60
    print("\n".join(["", rule, f"  {title}", rule]))


def print_summary(sg_acc: float, cbow_acc: float, sg_stats: dict, cbow_stats: dict,
                  sg_sem: float = None, sg_syn: float = None,
                  cbow_sem: float = None, cbow_syn: float = None):
    """
    Print the final pipeline report.

    The report covers accuracies, training times, data sizes and output
    file locations.

    Args:
        sg_acc: Skip-gram total analogy accuracy (fraction). None means the
            model was not evaluated, which also hides its output files.
        cbow_acc: CBOW total analogy accuracy (fraction), or None; same
            meaning as ``sg_acc``.
        sg_stats: Skip-gram training statistics dict; may be empty or None.
        cbow_stats: CBOW training statistics dict; may be empty or None.
        sg_sem, sg_syn: Optional Skip-gram semantic/syntactic accuracies.
            Printed only when both are given.
        cbow_sem, cbow_syn: Optional CBOW semantic/syntactic accuracies.
            Printed only when both are given.
    """
    def show_model(label, total, semantic, syntactic):
        # One accuracy line per evaluated model, plus an optional breakdown.
        if total is None:
            return
        print(f"{label} accuracy: {total:.4f} ({total*100:.2f}%)")
        if semantic is None or syntactic is None:
            return
        print(f" -Semantic:  {semantic:.4f} ({semantic*100:.2f}%)")
        print(f" -Syntactic: {syntactic:.4f} ({syntactic*100:.2f}%)")

    print_section_header("FINAL SUMMARY")

    print("Model Performance:")
    show_model("Skip-gram", sg_acc, sg_sem, sg_syn)
    show_model("CBOW", cbow_acc, cbow_sem, cbow_syn)
    if sg_acc is not None and cbow_acc is not None:
        acc_gap = sg_acc - cbow_acc
        print(f"Difference: {acc_gap:.4f} ({acc_gap*100:+.2f}%)")

    if sg_stats or cbow_stats:
        print("\nTraining Times:")
        if sg_stats:
            print(f" -Skip-gram: {sg_stats.get('epoch_time_total_seconds', 0):.2f}s")
        if cbow_stats:
            print(f" -CBOW: {cbow_stats.get('epoch_time_total_seconds', 0):.2f}s")
        if sg_stats and cbow_stats:
            time_gap = (sg_stats.get('epoch_time_total_seconds', 0)
                        - cbow_stats.get('epoch_time_total_seconds', 0))
            print(f"Difference: {time_gap:.2f}s")

        # Data sizes come from Skip-gram stats when available, else from CBOW stats.
        print("\nData Processed:")
        source = sg_stats or cbow_stats
        print(f" -Words: {source.get('word_count', 0):,}")
        print(f" -Sentences: {source.get('sentence_count', 0):,}")
        print(f" -Vocabulary: {source.get('vocab_size', 0):,}")

    print("\nOutput Files:")
    file_listing = (
        (sg_acc, (
            " -Skip-gram vectors: ./output/vectors_skipgram",
            " -Skip-gram evaluation: ./output/skipgram_eval.json",
            " -Skip-gram statistics: ./output/vectors_skipgram_stats.json",
        )),
        (cbow_acc, (
            " -CBOW vectors: ./output/vectors_cbow",
            " -CBOW evaluation: ./output/cbow_eval.json",
            " -CBOW statistics: ./output/vectors_cbow_stats.json",
        )),
    )
    for accuracy, lines in file_listing:
        if accuracy is None:
            continue
        for line in lines:
            print(line)


def _load_training_stats(path):
    """Load a training-statistics JSON file produced during training.

    Args:
        path: Location of the ``*_stats.json`` file.

    Returns:
        The decoded JSON object, or ``{}`` when the file does not exist or
        cannot be read/parsed. Read/parse failures print a warning naming the
        file instead of being swallowed, and do not prevent other stats files
        from being loaded.
    """
    import json

    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        # Missing stats are not an error: the summary simply omits them.
        return {}
    except (OSError, ValueError) as exc:
        # ValueError also covers json.JSONDecodeError (a subclass).
        print(f"Warning: Could not load statistics file {path}: {exc}")
        return {}


def _evaluate_saved_vectors(vectors_path, eval_path):
    """Evaluate saved word vectors and persist the results as JSON.

    Runs ``word_analogy_test`` and then ``similarity_test`` on
    ``vectors_path``, and writes both results to ``eval_path`` via
    ``save_evaluation_results``.

    Returns:
        Tuple ``(semantic_accuracy, syntactic_accuracy, total_accuracy,
        details)`` taken from the analogy test result.
    """
    result, details = word_analogy_test(vectors_path)
    semantic = result["semantic_accuracy"]
    syntactic = result["syntactic_accuracy"]
    total = result["total_accuracy"]

    similarity = similarity_test(vectors_path)

    save_evaluation_results({
        "semantic_accuracy": semantic,
        "syntactic_accuracy": syntactic,
        "total_accuracy": total,
        "details": details,
        "similarity_test": similarity,
    }, eval_path)
    return semantic, syntactic, total, details


def main():
    """Run the Word2Vec pipeline: preprocess, train, evaluate, compare, summarize.

    Behaviour is driven by module-level settings read here: ``dataset_name``,
    ``use_wmt14``, ``max_words``, ``max_sentences``, ``max_files``,
    ``use_hs_only``, ``use_phrases``, ``should_train_skipgram`` and
    ``should_train_cbow``. Vectors, evaluations and statistics are written
    under ``./output/``.

    Raises:
        ValueError: if the training parameters enable hierarchical softmax
            (``hs=1``) together with negative sampling (``k>0``).
    """
    import itertools
    import time

    # Number the printed steps sequentially so headers stay contiguous no
    # matter which models are enabled (1..7 when both models are trained).
    step_counter = itertools.count(1)

    def step_header(title):
        print_section_header(f"STEP {next(step_counter)}: {title}")

    train_both = should_train_skipgram and should_train_cbow

    # Print configuration
    print(f"\nDataset: {dataset_name}")
    if use_wmt14:
        print(" -WMT14/WMT15 News Crawl (combines WMT14 2012 + WMT15 2014)")
        print(" -Higher quality news articles")
        if max_words:
            print(f" -Limited to {max_words:,} words ({max_words/1e6:.1f}M words)")
        elif max_sentences:
            print(f" -Limited to {max_sentences:,} sentences")
    else:
        print(" -Text8")
        print(" -Smaller, faster to download and process")

    if use_hs_only:
        print("Training: Hierarchical Softmax ONLY (HS=1, k=0)")
    else:
        print("Training: Negative Sampling ONLY (HS=0, k=5)")

    if not should_train_skipgram:
        print("Skip-gram training: Disabled")
    if not should_train_cbow:
        print("CBOW training: Disabled")

    if use_phrases:
        print("Phrase detection: Enabled")

    step_header(f"DOWNLOADING & PREPROCESSING {dataset_name.upper()}")
    data_dir = "./data"

    if use_phrases:
        print("Phrase detection: Enabled (will combine frequent bigrams)")

    if use_wmt14:
        news_file = download_wmt14_news(data_dir)
        processed_dir = preprocess_wmt14_news(news_file, "./data/wmt14_processed",
                                              max_sentences=max_sentences, max_files=max_files,
                                              use_phrases=use_phrases)
    else:
        text8_file = download_text8(data_dir)
        processed_dir = preprocess_text8(text8_file, "./data/text8_processed",
                                         use_phrases=use_phrases)

    # Training parameters shared by both models.
    epochs_value = 10
    base_params = {
        "epochs": epochs_value,
        "embed_dim": 300,
        "min_occurs": 5,
        "c": 5,
        "k": 0 if use_hs_only else 5,
        "t": 1e-5,
        "vocab_freq_exponent": 0.75,
        "lr_max": 0.025,
        # For 1 epoch with large dataset, keep learning rate high (as in paper)
        "lr_min": 0.025 if epochs_value == 1 else 0.0001,
        "cuda_threads_per_block": 32,  # Optimized for A100 GPU
        "hs": 1 if use_hs_only else 0,
        "max_words": max_words  # Limit total words for training (None = no limit)
    }

    # HS and NS are mutually exclusive. Both models inherit hs/k from
    # base_params, so one check covers both and fails before any training.
    if base_params["hs"] == 1 and base_params["k"] > 0:
        raise ValueError("Error: Cannot use HS (hs=1) and Negative Sampling (k>0) together. Please choose either HS only (hs=1, k=0) or NS only (hs=0, k>0).")

    # When both models are trained, build the vocabulary once and pass it to
    # both trainers; otherwise each trainer builds its own.
    shared_kwargs = {}
    if train_both:
        step_header("BUILDING SHARED VOCABULARY")
        print("Building vocabulary once for both Skip-gram and CBOW models")
        print("Vocabulary will be cached for future runs (even with different epochs/dim)")
        start = time.time()
        vocab, w_to_i, word_counts = handle_vocab(
            processed_dir, base_params["min_occurs"], freq_exponent=base_params["vocab_freq_exponent"], use_cache=True
        )
        ssw, negs = get_subsampling_weights_and_negative_sampling_array(vocab, t=base_params["t"])
        build_time = time.time() - start
        # NOTE(review): "loaded from cache" is inferred purely from elapsed time.
        print(f"✅ Vocabulary {'loaded from cache' if build_time < 1.0 else 'built'} in {build_time:.2f}s. Vocab size: {len(vocab):,}")
        print("✅ Vocabulary will be reused for both models\n")
        shared_kwargs = {
            "vocab": vocab,
            "w_to_i": w_to_i,
            "word_counts": word_counts,
            "ssw": ssw,
            "negs": negs,
        }

    # Train Skip-gram (if selected)
    if should_train_skipgram:
        step_header("TRAINING SKIP-GRAM MODEL")
        skipgram_params = base_params.copy()

        if epochs_value == 1:
            print("Using 1 epoch: Learning rate will be kept constant at 0.025 (as per paper)")

        print("Skip-gram parameters:")
        for key, value in skipgram_params.items():
            print(f"  {key}: {value}")

        train_skipgram(processed_dir, "./output/vectors_skipgram",
                       **shared_kwargs, **skipgram_params)
    else:
        step_header("SKIPPING SKIP-GRAM TRAINING")
        print("Skip-gram training skipped as requested")

    # Train CBOW (if selected)
    if should_train_cbow:
        step_header("TRAINING CBOW MODEL")
        # CBOW deliberately uses the same learning-rate schedule as Skip-gram
        # (lr_max=0.025; lr_min held at 0.025 for a single epoch) to prevent
        # gradient explosion, so base_params is used unchanged.
        cbow_params = base_params.copy()

        if epochs_value == 1:
            print("Using 1 epoch: Learning rate will be kept constant at 0.025 (same as Skip-gram)")

        print("CBOW parameters:")
        for key, value in cbow_params.items():
            print(f"  {key}: {value}")

        train_cbow(processed_dir, "./output/vectors_cbow",
                   **shared_kwargs, **cbow_params)
    else:
        step_header("SKIPPING CBOW TRAINING")
        print("CBOW training skipped as requested")

    # Evaluate Skip-gram (if trained)
    sg_sem = sg_syn = sg_acc = sg_details = None
    if should_train_skipgram:
        step_header("EVALUATING SKIP-GRAM MODEL")
        sg_sem, sg_syn, sg_acc, sg_details = _evaluate_saved_vectors(
            "./output/vectors_skipgram", "./output/skipgram_eval.json")
    else:
        step_header("SKIPPING SKIP-GRAM EVALUATION")
        print("Skip-gram evaluation skipped (model not trained)")

    # Evaluate CBOW (if trained)
    cbow_sem = cbow_syn = cbow_acc = cbow_details = None
    if should_train_cbow:
        step_header("EVALUATING CBOW MODEL")
        cbow_sem, cbow_syn, cbow_acc, cbow_details = _evaluate_saved_vectors(
            "./output/vectors_cbow", "./output/cbow_eval.json")
    else:
        step_header("SKIPPING CBOW EVALUATION")
        print("CBOW evaluation skipped (model not trained)")

    # Model comparison (custom Skip-gram vs CBOW) - only if both trained
    if train_both:
        step_header("COMPARING CUSTOM MODELS (Skip-gram vs CBOW)")
        # Pass pre-computed accuracy values to avoid re-evaluating
        compare_models("./output/vectors_skipgram", "./output/vectors_cbow",
                       sg_acc=sg_acc, sg_details=sg_details,
                       cbow_acc=cbow_acc, cbow_details=cbow_details)
    else:
        step_header("SKIPPING MODEL COMPARISON")
        if should_train_skipgram:
            print("Model comparison skipped (CBOW not trained)")
        elif should_train_cbow:
            print("Model comparison skipped (Skip-gram not trained)")
        else:
            print("Model comparison skipped (no models trained)")

    # Load statistics for the summary; each file is loaded independently.
    sg_stats = (_load_training_stats("./output/vectors_skipgram_stats.json")
                if should_train_skipgram else {})
    cbow_stats = (_load_training_stats("./output/vectors_cbow_stats.json")
                  if should_train_cbow else {})

    # Final Summary
    print_summary(sg_acc, cbow_acc, sg_stats, cbow_stats,
                  sg_sem=sg_sem, sg_syn=sg_syn,
                  cbow_sem=cbow_sem, cbow_syn=cbow_syn)

    print("\nWord2Vec training and evaluation completed successfully!")
    print("Check the ./output/ directory for all results.")

    print(f"\nDataset used: {dataset_name}")


if __name__ == "__main__":
    # Script entry point: run the pipeline, report how it ended, and always
    # propagate the original exception so the process exit status reflects it.
    import traceback

    try:
        main()
    except KeyboardInterrupt:
        # Ctrl+C: acknowledge the interruption, then let it propagate.
        print("\n\n⚠️ Training interrupted by user.")
        raise
    except Exception as exc:
        # Any other failure: show the message plus full traceback, then re-raise.
        print(f"\n❌ Error occurred: {exc}")
        traceback.print_exc()
        raise
