In [None]:

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import Counter, defaultdict
import re
from datasets import load_dataset  # Requires 'datasets' library
from sklearn.decomposition import PCA  # Optional, for minimal visualization
import matplotlib.pyplot as plt

# Load IMDb dataset (subset for efficiency)
def load_imdb_data(max_samples=1000, train_fraction=0.8, seed=42):
    """Load a reproducible random subset of the IMDb train split.

    The Hugging Face IMDb train split is stored grouped by label (all negative
    reviews first, then all positive ones), so a plain prefix such as
    ``train[:1000]`` contains a single class and any downstream accuracy is
    meaningless. The split is therefore shuffled with a fixed seed before the
    subset is taken, giving an approximately balanced label mix.

    Args:
        max_samples: Number of reviews to draw (capped at the split size).
        train_fraction: Fraction of the sample used for training; the rest
            becomes the test set.
        seed: Shuffle seed, so repeated runs see the same subset.

    Returns:
        ``(train_texts, train_labels, test_texts, test_labels)`` as lists;
        labels are 0 for negative and 1 for positive.
    """
    dataset = load_dataset("imdb", split="train")
    num_samples = min(max_samples, len(dataset))
    subset = dataset.shuffle(seed=seed).select(range(num_samples))
    texts = list(subset["text"])
    labels = list(subset["label"])  # 0: negative, 1: positive
    train_size = int(train_fraction * len(texts))
    return (
        texts[:train_size],
        labels[:train_size],
        texts[train_size:],
        labels[train_size:],
    )

# Tokenization functions
def tokenize_whitespace(texts, max_tokens=100):
    """Lower-case, strip punctuation, and split each text on whitespace.

    IMDb reviews contain HTML line breaks (``<br />``). They are replaced by a
    space *before* punctuation is removed. Otherwise ``"end.<br />Next"``
    collapses to the bogus tokens ``"endbr"`` and ``"brnext"``.

    Args:
        texts: Iterable of raw strings.
        max_tokens: Keep at most this many tokens per text (for efficiency).

    Returns:
        A list with one token list per input text.
    """
    tokenized = []
    for text in texts:
        text = re.sub(r'<br\s*/?>', ' ', text, flags=re.IGNORECASE)
        cleaned = re.sub(r'[^\w\s]', '', text.lower())
        tokenized.append(cleaned.split()[:max_tokens])
    return tokenized

def tokenize_char(texts, max_tokens=100):
    """Turn each text into a list of lower-cased alphabetic characters.

    Non-alphabetic characters (digits, punctuation, whitespace) are dropped.
    HTML line breaks (``<br />``) are removed first so their letters do not
    leak in as spurious ``'b'``, ``'r'`` characters.

    Args:
        texts: Iterable of raw strings.
        max_tokens: Keep at most this many characters per text (for efficiency).

    Returns:
        A list with one character list per input text.
    """
    tokenized = []
    for text in texts:
        text = re.sub(r'<br\s*/?>', ' ', text, flags=re.IGNORECASE)
        chars = [c.lower() for c in text if c.isalpha()][:max_tokens]
        tokenized.append(chars)
    return tokenized

def train_bpe(texts, num_merges=10, max_words=100):
    """Learn a small BPE merge table from ``texts`` and return a tokenizer.

    Text is lower-cased, HTML line breaks (``<br />``) are turned into spaces,
    and punctuation is stripped before splitting into words. Starting from
    single characters, the most frequent adjacent symbol pair (counts weighted
    by word frequency) is merged, up to ``num_merges`` times. Ties go to the
    pair encountered first.

    Args:
        texts: Training texts.
        num_merges: Maximum number of merges to learn. Training stops early
            when no word has two or more symbols left.
        max_words: The returned tokenizer encodes at most this many words per
            text (for efficiency).

    Returns:
        ``tokenize(texts) -> list[list[str]]``, which applies the learned
        merges, in learned order, to every word of every text.
    """
    br_tag = re.compile(r'<br\s*/?>', re.IGNORECASE)
    non_word = re.compile(r'[^\w\s]')

    def clean(text):
        # Replace line breaks with a space *before* stripping punctuation;
        # otherwise "end.<br />Next" collapses to "endbr brnext".
        return non_word.sub('', br_tag.sub(' ', text).lower())

    def merge_pair(symbols, pair):
        # Replace non-overlapping occurrences of `pair`, scanning left to right.
        merged = []
        i = 0
        while i < len(symbols):
            if i + 1 < len(symbols) and symbols[i] == pair[0] and symbols[i + 1] == pair[1]:
                merged.append(pair[0] + pair[1])
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        return merged

    word_freqs = Counter(word for text in texts for word in clean(text).split())
    splits = {word: list(word) for word in word_freqs}

    bpe_merges = []
    for _ in range(num_merges):
        pair_counts = defaultdict(int)
        for word, freq in word_freqs.items():
            symbols = splits[word]
            for left, right in zip(symbols, symbols[1:]):
                pair_counts[(left, right)] += freq
        if not pair_counts:
            break
        best = max(pair_counts, key=pair_counts.get)
        bpe_merges.append(best)
        splits = {word: merge_pair(symbols, best) for word, symbols in splits.items()}

    encoded_cache = {}

    def apply_bpe(word):
        # Words repeat heavily across reviews, so memoize their encodings.
        if word not in encoded_cache:
            symbols = list(word)
            for pair in bpe_merges:
                symbols = merge_pair(symbols, pair)
            encoded_cache[word] = symbols
        return encoded_cache[word]

    def tokenize(texts):
        tokenized = []
        for text in texts:
            tok_sent = []
            for word in clean(text).split()[:max_words]:  # Limit for efficiency
                tok_sent.extend(apply_bpe(word))
            tokenized.append(tok_sent)
        return tokenized

    return tokenize

# Build vocabulary
def build_vocab(tokenized_corpus):
    """Assign a contiguous integer id to every distinct token.

    Ids follow first-occurrence order in the corpus. Iterating a ``set`` of
    strings would make the assignment depend on Python's per-process hash
    seed, so ids (and everything trained on them) would change between runs.

    Args:
        tokenized_corpus: Iterable of token lists.

    Returns:
        ``(token_to_id, id_to_token, vocab_size)``.
    """
    vocab = list(dict.fromkeys(tok for sent in tokenized_corpus for tok in sent))
    token_to_id = {tok: i for i, tok in enumerate(vocab)}
    id_to_token = dict(enumerate(vocab))
    return token_to_id, id_to_token, len(vocab)

# Generate training data for Word2Vec
def generate_training_data(tokenized_corpus, token_to_id, window_size=2,
                           max_pairs=5000, seed=0):
    """Build ``(center_id, context_id)`` skip-gram pairs.

    Out-of-vocabulary tokens are dropped before windowing. For each position,
    every other position within ``window_size`` becomes a context.

    When more than ``max_pairs`` pairs exist, a seeded uniform random sample
    of ``max_pairs`` is returned. Keeping only the first ``max_pairs`` pairs
    would cover just the first few sentences and leave most of the vocabulary
    with no training signal. Pass ``max_pairs=None`` to keep every pair.

    Args:
        tokenized_corpus: Iterable of token lists.
        token_to_id: Mapping from token to integer id.
        window_size: Context radius on each side of the center word.
        max_pairs: Upper bound on returned pairs, or ``None`` for no limit.
        seed: Seed for the subsampling RNG (deterministic output).

    Returns:
        A list of ``(center_id, context_id)`` tuples. It is in corpus order
        when not subsampled, and in random order when subsampled.
    """
    pairs = []
    for sent in tokenized_corpus:
        sent_ids = [token_to_id[tok] for tok in sent if tok in token_to_id]
        for i, center in enumerate(sent_ids):
            lo = max(0, i - window_size)
            hi = min(len(sent_ids), i + window_size + 1)
            pairs.extend((center, sent_ids[j]) for j in range(lo, hi) if j != i)
    if max_pairs is None or len(pairs) <= max_pairs:
        return pairs
    rng = np.random.default_rng(seed)
    chosen = rng.choice(len(pairs), size=max_pairs, replace=False)
    return [pairs[k] for k in chosen]

# Skip-gram Word2Vec model
class SkipGram(nn.Module):
    """Full-softmax skip-gram model.

    Each token owns two vectors: a row of ``in_embeddings`` (used when it is
    the center word) and a row of ``out_embeddings`` (used when it is a
    context word). ``forward`` scores every vocabulary entry as a candidate
    context for each center id.
    """

    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        self.in_embeddings = nn.Embedding(vocab_size, embed_dim)
        self.out_embeddings = nn.Embedding(vocab_size, embed_dim)

    def forward(self, center_ids):
        """Return unnormalized scores of shape ``(*center_ids.shape, vocab_size)``."""
        center_vectors = self.in_embeddings(center_ids)
        context_table = self.out_embeddings.weight  # (vocab_size, embed_dim)
        return center_vectors @ context_table.T

# Train Word2Vec
def train_word2vec(training_data, vocab_size, embed_dim=50, epochs=10, lr=0.001,
                   batch_size=1):
    """Train a full-softmax skip-gram model and return its center-word embeddings.

    Args:
        training_data: Sequence of ``(center_id, context_id)`` integer pairs.
        vocab_size: Number of rows in each embedding table.
        embed_dim: Embedding dimensionality.
        epochs: Passes over ``training_data``.
        lr: Adam learning rate.
        batch_size: Pairs per optimizer step. The default of 1 reproduces
            per-pair updates; larger values train much faster.

    Returns:
        A ``(vocab_size, embed_dim)`` NumPy array holding the
        ``in_embeddings`` weights.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    model = SkipGram(vocab_size, embed_dim)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    num_pairs = len(training_data)
    if num_pairs == 0:
        # Nothing to learn from; this also avoids dividing by zero in the log line.
        print('Word2Vec: no training pairs; returning untrained embeddings.')
        return model.in_embeddings.weight.detach().numpy()
    # Build the index tensors once instead of allocating two tensors per step.
    centers = torch.tensor([center for center, _ in training_data], dtype=torch.long)
    contexts = torch.tensor([context for _, context in training_data], dtype=torch.long)
    for epoch in range(epochs):
        total_loss = 0.0
        for start in range(0, num_pairs, batch_size):
            batch_centers = centers[start:start + batch_size]
            batch_contexts = contexts[start:start + batch_size]
            optimizer.zero_grad()
            loss = criterion(model(batch_centers), batch_contexts)
            loss.backward()
            optimizer.step()
            # `loss` is a batch mean; weight by batch size so the printed
            # value stays an average over pairs.
            total_loss += loss.item() * len(batch_centers)
        print(f'Word2Vec Epoch {epoch}, Average loss: {total_loss / num_pairs:.4f}')
    return model.in_embeddings.weight.detach().numpy()

# GloVe simple implementation
def train_glove(tokenized_corpus, token_to_id, embed_dim=50, epochs=10, lr=0.01,
                window_size=2, x_max=None, alpha=0.75):
    """Train simplified GloVe embeddings with per-entry SGD.

    Co-occurrences within ``window_size`` tokens are accumulated with a
    1/distance weight. For every non-zero count ``X_ij`` the model fits
    ``W[i]·W~[j] + b[i] + b~[j] ≈ log(X_ij)`` under squared error.

    Counts are stored sparsely. A dense ``vocab_size × vocab_size`` matrix
    needs O(V²) memory (~0.8 GB of float64 at V = 10k), and scanning every
    cell each epoch costs O(V²) Python iterations although very few cells are
    non-zero. Non-zero entries are visited in the same row-major order a dense
    scan would use, so the updates are unchanged.

    Args:
        tokenized_corpus: Iterable of token lists.
        token_to_id: Mapping from token to id in ``range(len(token_to_id))``.
        embed_dim: Embedding dimensionality.
        epochs: Passes over the non-zero co-occurrence entries.
        lr: SGD learning rate.
        window_size: Context radius on each side.
        x_max: If given, weight each term by GloVe's
            ``min(1, (X_ij / x_max) ** alpha)``. The default ``None`` weights
            every entry by 1.
        alpha: Exponent of the weighting function (used only with ``x_max``).

    Returns:
        A ``(vocab_size, embed_dim)`` array: the mean of word and context vectors.
    """
    vocab_size = len(token_to_id)
    cooc = defaultdict(float)
    for sent in tokenized_corpus:
        sent_ids = [token_to_id[tok] for tok in sent if tok in token_to_id]
        for i in range(len(sent_ids)):
            for j in range(max(0, i - window_size), min(len(sent_ids), i + window_size + 1)):
                if i != j:
                    cooc[sent_ids[i], sent_ids[j]] += 1.0 / abs(i - j)
    # Sort keys into row-major (i, then j) order, matching a dense double loop.
    # Every stored count is >= 1/window_size > 0, so log() is always defined.
    entries = sorted(cooc.items())

    W = np.random.normal(0, 0.01 / embed_dim, (vocab_size, embed_dim))
    Wtilde = np.random.normal(0, 0.01 / embed_dim, (vocab_size, embed_dim))
    b = np.zeros(vocab_size)
    btilde = np.zeros(vocab_size)
    for epoch in range(epochs):
        total_loss = 0
        for (i, j), count in entries:
            weight = 1.0 if x_max is None else min(1.0, (count / x_max) ** alpha)
            error = np.dot(W[i], Wtilde[j]) + b[i] + btilde[j] - np.log(count)
            total_loss += weight * error ** 2
            # Both vector gradients are computed from pre-update values.
            dWi = weight * error * Wtilde[j]
            dWtildej = weight * error * W[i]
            dbias = weight * error
            W[i] -= lr * dWi
            Wtilde[j] -= lr * dWtildej
            b[i] -= lr * dbias
            btilde[j] -= lr * dbias
        print(f'GloVe Epoch {epoch}, Loss: {total_loss:.4f}')
    return (W + Wtilde) / 2.0

# Word relationship analysis (cosine similarity)
def print_word_relationship(embeddings, token_to_id, words=None):
    """Print the cosine similarity of each unordered pair of probe words.

    Cosine similarity is symmetric, so each pair is printed once. A pair that
    involves an all-zero vector prints ``nan`` instead of triggering a
    divide-by-zero warning.

    Args:
        embeddings: 2-D array whose rows are indexed by token id.
        token_to_id: Mapping from token to row index in ``embeddings``.
        words: Probe words; defaults to ``['good', 'bad', 'movie', 'film']``.
            Words missing from ``token_to_id`` are skipped.
    """
    if words is None:
        words = ['good', 'bad', 'movie', 'film']
    present = [w for w in dict.fromkeys(words) if w in token_to_id]
    for pos, word1 in enumerate(present):
        emb1 = embeddings[token_to_id[word1]]
        for word2 in present[pos + 1:]:
            emb2 = embeddings[token_to_id[word2]]
            denom = np.linalg.norm(emb1) * np.linalg.norm(emb2)
            sim = np.dot(emb1, emb2) / denom if denom > 0 else float('nan')
            print(f"Cosine similarity between '{word1}' and '{word2}': {sim:.4f}")

# Minimal visualization (optional, comment out if sklearn not available)
def visualize_embeddings(embeddings, id_to_token, method='PCA', words=None):
    """Scatter-plot a 2-D PCA projection of selected word embeddings.

    Args:
        embeddings: 2-D array whose rows are indexed by token id.
        id_to_token: Mapping from token id to token string.
        method: Projection method; only ``'PCA'`` is implemented.
        words: Words to plot; defaults to a small sentiment/topic word list.
            Words absent from the vocabulary are skipped.

    Raises:
        ValueError: If ``method`` is anything other than ``'PCA'``.
    """
    if method != 'PCA':
        raise ValueError(f"Unsupported projection method {method!r}; only 'PCA' is implemented.")
    if words is None:
        words = ['good', 'bad', 'movie', 'film', 'great', 'poor']
    targets = set(words)
    # Resolve ids from id_to_token itself rather than a global token_to_id,
    # so the function also works when imported from another module.
    found = [(idx, tok) for idx, tok in sorted(id_to_token.items()) if tok in targets]
    if not found:
        print("No target words found for visualization.")
        return
    if len(found) < 2:
        # PCA(n_components=2) needs at least two samples.
        print("Need at least two target words for a 2-D projection.")
        return
    ids = [idx for idx, _ in found]
    labels = [tok for _, tok in found]
    emb_2d = PCA(n_components=2).fit_transform(embeddings[ids])
    plt.figure(figsize=(8, 6))
    for (x, y), word in zip(emb_2d, labels):
        plt.scatter(x, y)
        plt.annotate(word, (x, y))
    plt.title('Embedding Visualization (PCA)')
    plt.show()

# Sentence embedding
def get_sentence_embedding(sentence_tokens, token_to_id, embeddings):
    """Average the embedding rows of a sentence's in-vocabulary tokens.

    Tokens absent from ``token_to_id`` are ignored. If none remain, a zero
    vector of length ``embeddings.shape[1]`` is returned.
    """
    known_ids = [token_to_id[tok] for tok in sentence_tokens if tok in token_to_id]
    if known_ids:
        return embeddings[known_ids].mean(axis=0)
    return np.zeros(embeddings.shape[1])

# Simple classifier for downstream task
class Classifier(nn.Module):
    """Logistic-regression head: a single linear layer followed by a sigmoid."""

    def __init__(self, input_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, 1)

    def forward(self, x):
        """Map features of shape ``(..., input_dim)`` to probabilities of shape ``(..., 1)``."""
        logits = self.linear(x)
        return logits.sigmoid()

def train_classifier(train_reps, train_labels, embed_dim, epochs=20):
    """Fit a ``Classifier`` on binary labels with full-batch Adam (lr 0.01).

    Args:
        train_reps: Array-like of shape ``(n_samples, embed_dim)``.
        train_labels: Sequence of 0/1 labels, one per row of ``train_reps``.
        embed_dim: Feature dimensionality (input size of the classifier).
        epochs: Number of full-batch gradient steps.

    Returns:
        The trained ``Classifier``.
    """
    model = Classifier(embed_dim)
    loss_fn = nn.BCELoss()
    opt = optim.Adam(model.parameters(), lr=0.01)
    features = torch.tensor(train_reps, dtype=torch.float32)
    targets = torch.tensor(train_labels, dtype=torch.float32).unsqueeze(1)
    for _ in range(epochs):
        opt.zero_grad()
        loss_fn(model(features), targets).backward()
        opt.step()
    return model

def evaluate_classifier(model, test_reps, test_labels):
    """Return accuracy: the fraction of rows whose rounded model output equals the label.

    ``model`` is called on a float32 tensor built from ``test_reps`` without
    gradient tracking. Its outputs are rounded to 0/1 and compared with
    ``test_labels`` reshaped to a column.
    """
    inputs = torch.tensor(test_reps, dtype=torch.float32)
    with torch.no_grad():
        predictions = model(inputs).round()
    targets = torch.tensor(test_labels).unsqueeze(1).float()
    hits = (predictions == targets).float()
    return hits.mean().item()

# Main execution
if __name__ == "__main__":
    # Seed the RNGs used for Word2Vec/classifier initialization (torch) and
    # GloVe initialization (NumPy) so repeated runs are comparable.
    torch.manual_seed(0)
    np.random.seed(0)

    # Load data
    train_texts, train_labels, test_texts, test_labels = load_imdb_data(max_samples=1000)
    corpus = train_texts  # Use training texts for embedding training

    # Define tokenizers
    tokenizers = {
        'Whitespace': tokenize_whitespace,
        'Character': tokenize_char,
        'BPE': train_bpe(corpus, num_merges=10),
    }

    embed_methods = ['Word2Vec', 'GloVe']
    results = {}

    for tok_name, tokenizer in tokenizers.items():
        print(f"\n--- {tok_name} Tokenization ---")
        tokenized_corpus = tokenizer(corpus)
        token_to_id, id_to_token, vocab_size = build_vocab(tokenized_corpus)
        training_data = generate_training_data(tokenized_corpus, token_to_id)
        print(f"Vocab size: {vocab_size}, Training pairs: {len(training_data)}")

        # The embedding corpus is the training split, so reuse its
        # tokenization rather than tokenizing the same texts a second time.
        tokenized_train = tokenized_corpus if corpus is train_texts else tokenizer(train_texts)
        tokenized_test = tokenizer(test_texts)

        method_results = {}
        for method in embed_methods:
            print(f"\nTraining {method}...")
            if method == 'Word2Vec':
                embeddings = train_word2vec(training_data, vocab_size, embed_dim=50)
            else:
                embeddings = train_glove(tokenized_corpus, token_to_id, embed_dim=50)

            # Word relationship analysis
            print(f"\n{method} Word Relationships:")
            print_word_relationship(embeddings, token_to_id)

            # Optional visualization (uncomment to use)
            # visualize_embeddings(embeddings, id_to_token)

            # Downstream task: sentiment classification on averaged token embeddings
            train_reps = np.array([get_sentence_embedding(sent, token_to_id, embeddings) for sent in tokenized_train])
            test_reps = np.array([get_sentence_embedding(sent, token_to_id, embeddings) for sent in tokenized_test])

            classifier = train_classifier(train_reps, train_labels, embeddings.shape[1])
            acc = evaluate_classifier(classifier, test_reps, test_labels)
            method_results[method] = acc
            print(f"{method} downstream accuracy: {acc:.4f}")

        results[tok_name] = method_results

    # Comparative analysis
    print("\nComparative Analysis Summary (Accuracy on Sentiment Classification):")
    for tok, methods in results.items():
        print(f"{tok}: Word2Vec = {methods['Word2Vec']:.4f}, GloVe = {methods['GloVe']:.4f}")

    print("\nAnalysis:")
    print("- Whitespace: Captures whole words, effective for common terms but struggles with OOV.")
    print("- Character: Fine-grained, handles OOV well but may lose semantic context.")
    print("- BPE: Balances word and subword, better for rare words and morphological variations.")
    print("Higher accuracy suggests better representations for sentiment classification.")
