# Parliamentary Speech Segment Embeddings

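This notebook computes one embedding per segment for three parliamentary datasets (AT, HR, GB). For each dataset, the rows that share a segment ID (e.g. `Segment_ID`) are concatenated from the text column (e.g. `Text`), and the resulting segment text is embedded with the `BAAI/bge-m3` model.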

## Key Features:
- **Multi-language support**: English and native language embeddings for AT and HR, English-only for GB.
- **Google Colab compatibility**: Includes setup for Google Colab.
- **Efficient processing**: Runs the embedding model on the GPU when one is available and saves checkpoints so an interrupted run can resume.
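- **Chunking for long texts**: Segment texts longer than 8,192 tokens are split into overlapping chunks (1,024-token overlap), and the chunk embeddings are averaged into a single segment embedding.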

## Input:
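Pre-segmented datasets (pickled DataFrames) with text and segment-ID columns: `Text` with `Segment_ID_english`, and `Text_native_language` with `Segment_ID_german` (AT) or `Segment_ID_croatian` (HR); GB uses `Text` with `Segment_ID`.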

## Output:
Dataframes with additional columns for segment embeddings:
- `segment_embeddings_english`
- `segment_embeddings_native_language` (for AT and HR only)

In [None]:
# === GOOGLE COLAB SETUP ===
# Mount Google Drive at /content/drive; the dataset paths used later in this
# notebook live under that mount point.
from google.colab import drive
drive.mount('/content/drive')

# IPython shell escape (notebook-only syntax, not plain Python): install the
# sentence-transformers and tqdm packages that are imported right after this.
!pip install sentence-transformers tqdm

import pandas as pd
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

# GPU check: report which CUDA device (index 0) is visible, or warn that the
# notebook will fall back to slower processing without one.
cuda_ready = torch.cuda.is_available()
if not cuda_ready:
    print("⚠️ No GPU detected. Processing will be slower.")
else:
    print(f"✅ GPU detected: {torch.cuda.get_device_name(0)}")

In [2]:
# === LOAD DATASETS ===
# Root folder of the data on the mounted Google Drive; update it to your own
# Drive location.
# NOTE: the value already ends with "/", so the f-strings below produce "//" in
# the joined paths. POSIX treats repeated slashes as a single separator.
base_path = "/content/drive/My Drive/thesis data/"

# Locations of the three pre-segmented pickles.
at_pickle_path = f"{base_path}/AT/AT_combined_with_segments.pkl"
hr_pickle_path = f"{base_path}/HR/HR_combined_with_segments.pkl"
gb_pickle_path = f"{base_path}/GB/GB_with_segments.pkl"

# NOTE: read_pickle unpickles arbitrary Python objects; only load files you trust.
AT = pd.read_pickle(at_pickle_path)
HR = pd.read_pickle(hr_pickle_path)
GB = pd.read_pickle(gb_pickle_path)

# Summarize the shape of each loaded frame as a quick sanity check.
print(f"✅ Datasets loaded: AT ({AT.shape}), HR ({HR.shape}), GB ({GB.shape})")

In [None]:
# === EMBEDDING FUNCTIONS ===

def embed_long_text(text, model, tokenizer, max_length=8192, overlap=1024):
    """Embed a text longer than the model's token limit via overlapping chunks.

    The text is tokenized without special tokens, cut into windows of at most
    ``max_length`` tokens whose starts are ``max_length - overlap`` tokens
    apart, each window is decoded back to a string, and all windows are
    encoded in one ``model.encode`` call. The chunk embeddings are averaged
    (unweighted mean over the chunk axis).

    Args:
        text: The text to embed.
        model: Object exposing ``encode(list_of_str, batch_size=...,
            convert_to_tensor=True, show_progress_bar=False)`` that returns a
            2-D torch tensor with one row per chunk.
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``
            and ``decode(ids, skip_special_tokens=True)``.
        max_length: Maximum number of tokens per chunk.
        overlap: Number of tokens shared by two consecutive chunks. Must
            satisfy ``0 <= overlap < max_length``.

    Returns:
        A numpy array holding the mean of the chunk embeddings.

    Raises:
        ValueError: If ``overlap`` is negative or not smaller than
            ``max_length`` (the window would never advance).
    """
    if not 0 <= overlap < max_length:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < max_length; "
            f"got overlap={overlap}, max_length={max_length}"
        )

    token_ids = tokenizer.encode(text, add_special_tokens=False)
    total_tokens = len(token_ids)
    stride = max_length - overlap

    chunks = []
    for start in range(0, total_tokens, stride):
        end = min(start + max_length, total_tokens)
        chunks.append(tokenizer.decode(token_ids[start:end], skip_special_tokens=True))
        if end == total_tokens:
            # This window already reaches the last token. Any later window
            # would be a strict subset of it and would over-weight the tail of
            # the text in the mean, so stop here.
            break

    if not chunks:
        # The tokenizer produced no tokens (e.g. empty text). Encode the raw
        # text so the caller still gets a regular embedding vector instead of
        # a mean over an empty batch.
        chunks = [text]

    # NOTE(review): chunks are max_length tokens *before* the model adds its own
    # special tokens, so a full chunk may be truncated by a couple of tokens if
    # max_length equals the model limit — the overlap covers that; confirm.
    chunk_embeddings = model.encode(chunks, batch_size=128, convert_to_tensor=True, show_progress_bar=False)
    return torch.mean(chunk_embeddings, dim=0).cpu().numpy()

def generate_segment_embeddings(df, text_column, segment_id_column, model, batch_size=128, checkpoint_path=None,
                                max_length=8192, overlap=1024):
    """Embed each segment's concatenated text and align the vectors to ``df`` rows.

    Rows are grouped by ``segment_id_column``; within each group the values of
    ``text_column`` are cast to ``str`` and joined with single spaces. Each
    segment text is then embedded on its own: texts with more than
    ``max_length`` tokens go through ``embed_long_text`` (overlapping chunks,
    averaged), all others through a single ``model.encode`` call.

    Args:
        df: DataFrame containing ``text_column`` and ``segment_id_column``.
        text_column: Name of the column holding the text pieces.
        segment_id_column: Name of the column identifying the segment of each row.
        model: Embedding model exposing ``tokenizer`` and ``encode``.
        batch_size: Number of segments processed between two checkpoint
            saves. Texts are still encoded one at a time, so this does not
            change the model's batch size.
        checkpoint_path: Optional file used to save and resume progress. If
            the file exists, processing resumes from the stored index.
        max_length: Token count above which a segment text is chunked.
        overlap: Token overlap between consecutive chunks of a long text;
            must be smaller than ``max_length``.

    Returns:
        A Series aligned with ``df`` whose values are the numpy embedding of
        each row's segment. Rows whose segment ID has no embedding (for
        example missing IDs, which ``groupby`` drops by default) map to NaN.

    Raises:
        ValueError: If an existing checkpoint is inconsistent with the
            current segments (it would silently mis-align the embeddings).
    """
    # Local import: this function needs `os`, which the notebook's import cell
    # does not provide.
    import os

    print(f"🔄 Generating embeddings for {text_column} grouped by {segment_id_column}...")

    # Concatenate texts within each segment.
    # NOTE(review): astype(str) turns missing values into the literal "nan";
    # confirm the text columns contain no NaN.
    segment_texts = df.groupby(segment_id_column)[text_column].apply(lambda x: ' '.join(x.astype(str)))
    segment_ids = segment_texts.index.tolist()
    segment_texts = segment_texts.tolist()
    total_segments = len(segment_texts)

    tokenizer = model.tokenizer
    embeddings = []
    start_idx = 0

    # Load checkpoint if available
    if checkpoint_path and os.path.exists(checkpoint_path):
        # The checkpoint stores a list of numpy arrays, which the
        # weights-only unpickler (default since PyTorch 2.6) rejects.
        # SECURITY: this unpickles arbitrary objects — only load checkpoints
        # written by this notebook.
        checkpoint_data = torch.load(checkpoint_path, weights_only=False)
        embeddings = checkpoint_data.get("embeddings", [])
        start_idx = checkpoint_data.get("start_idx", 0)
        # The stored index may overshoot the end by up to one batch, hence min().
        if len(embeddings) != min(start_idx, total_segments):
            raise ValueError(
                f"Checkpoint {checkpoint_path!r} holds {len(embeddings)} embeddings at index "
                f"{start_idx}, which does not match the {total_segments} segments of this run; "
                "delete the checkpoint or point checkpoint_path at a different file."
            )
        print(f"⏳ Resuming from checkpoint at index {start_idx}...")

    for i in tqdm(range(start_idx, total_segments, batch_size), desc="🚀 Embedding segments", unit="batch"):
        batch_embeddings = []
        for text in segment_texts[i:i+batch_size]:
            if len(tokenizer.encode(text, add_special_tokens=False)) > max_length:
                batch_embeddings.append(
                    embed_long_text(text, model, tokenizer, max_length=max_length, overlap=overlap)
                )
            else:
                batch_embeddings.append(model.encode(text, convert_to_tensor=True).cpu().numpy())
        embeddings.extend(batch_embeddings)

        # Save checkpoint
        if checkpoint_path:
            checkpoint_data = {"embeddings": embeddings, "start_idx": i + batch_size}
            # Write to a temporary file first and move it into place, so an
            # interruption during the save cannot corrupt the last good checkpoint.
            temp_path = f"{checkpoint_path}.tmp"
            torch.save(checkpoint_data, temp_path)
            os.replace(temp_path, checkpoint_path)
            print(f"💾 Checkpoint saved at index {i + batch_size}")

    # Map embeddings back to the dataframe
    embedding_map = dict(zip(segment_ids, embeddings))
    return df[segment_id_column].map(embedding_map)

In [None]:
# === LOAD MODEL ===
# Run the model on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    target_device = "cuda"
else:
    target_device = "cpu"

# BAAI/bge-m3 is the single model used for every embedding in this notebook.
model = SentenceTransformer("BAAI/bge-m3", device=target_device)
print("✅ SentenceTransformer model loaded.")

In [None]:
# === PROCESS DATASETS ===

def _embed_and_save(frame, target_column, text_column, segment_id_column,
                    checkpoint_file, output_file, done_message):
    """Add one embedding column to ``frame``, pickle the frame, and report it.

    The embeddings come from ``generate_segment_embeddings`` using the
    notebook-level ``model``; ``frame`` is modified in place.
    """
    frame[target_column] = generate_segment_embeddings(
        frame,
        text_column=text_column,
        segment_id_column=segment_id_column,
        model=model,
        checkpoint_path=checkpoint_file,
    )
    frame.to_pickle(output_file)
    print(done_message)


# AT: English embeddings first, then German. The second pickle is written after
# both columns have been added to the frame.
print("🔄 Processing Austrian Parliament dataset...")
_embed_and_save(
    AT, "segment_embeddings_english", "Text", "Segment_ID_english",
    f"{base_path}/AT/english_checkpoint.pt",
    f"{base_path}/AT/AT_with_english_embeddings.pkl",
    "✅ Saved Austrian Parliament dataset with English embeddings.",
)
_embed_and_save(
    AT, "segment_embeddings_native_language", "Text_native_language", "Segment_ID_german",
    f"{base_path}/AT/native_checkpoint.pt",
    f"{base_path}/AT/AT_with_native_embeddings.pkl",
    "✅ Saved Austrian Parliament dataset with native language embeddings.",
)

# HR: English embeddings first, then Croatian.
print("🔄 Processing Croatian Parliament dataset...")
_embed_and_save(
    HR, "segment_embeddings_english", "Text", "Segment_ID_english",
    f"{base_path}/HR/english_checkpoint.pt",
    f"{base_path}/HR/HR_with_english_embeddings.pkl",
    "✅ Saved Croatian Parliament dataset with English embeddings.",
)
_embed_and_save(
    HR, "segment_embeddings_native_language", "Text_native_language", "Segment_ID_croatian",
    f"{base_path}/HR/native_checkpoint.pt",
    f"{base_path}/HR/HR_with_native_embeddings.pkl",
    "✅ Saved Croatian Parliament dataset with native language embeddings.",
)

# GB: English embeddings only.
print("🔄 Processing British Parliament dataset...")
_embed_and_save(
    GB, "segment_embeddings_english", "Text", "Segment_ID",
    f"{base_path}/GB/english_checkpoint.pt",
    f"{base_path}/GB/GB_with_english_embeddings.pkl",
    "✅ Saved British Parliament dataset with English embeddings.",
)