# TopicWeave: Model Distillation and Ensemble Embedding Generation

This notebook implements the distillation of SciBERT and SPECTER2 models and generates ensemble embeddings for scientific literature topic modeling.

## Setup

In [1]:
# %%
# Mount Google Drive for storing models and embeddings
from google.colab import drive
drive.mount('/content/drive')

# Define output directory in Drive
OUTPUT_DIR = '/content/drive/MyDrive/MIDS/W266/final_project/'

Mounted at /content/drive


In [2]:
# %%
# Install necessary packages
!pip install -q transformers sentence-transformers datasets torch scikit-learn numpy pandas tqdm nltk jsonlines

import os
import json
import torch
import numpy as np
import pandas as pd
import jsonlines
from tqdm.auto import tqdm
from pathlib import Path
from typing import List, Dict, Union, Optional
import random
import matplotlib.pyplot as plt

# For model distillation and embeddings
from transformers import AutoConfig, AutoModel, AutoTokenizer
from sentence_transformers import SentenceTransformer, models, losses, InputExample
from torch.utils.data import DataLoader

# For preprocessing
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
nltk.download('punkt')
nltk.download('stopwords')
nltk.download('punkt_tab')

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m487.4/487.4 kB[0m [31m34.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m109.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m81.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m60.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m10.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m36.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


True

In [3]:
import huggingface_hub
print(huggingface_hub.__version__)

import numpy as np
print(np.__version__)

try:
    from sentence_transformers import SentenceTransformer
    print(np.dtype(np.int32))
    print("Sentence transformers imported without error")
except AttributeError as e:
    print(f"AttributeError when importing sentence transformers: {e}")
except ImportError as e:
    print(f"ImportError when importing sentence transformers: {e}")
except Exception as e:
    print(f"General Error: {e}")

0.28.1
1.26.4
int32
Sentence transformers imported without error


In [4]:
# Create directories for saving models and embeddings
os.makedirs(f"{OUTPUT_DIR}/models", exist_ok=True)
os.makedirs(f"{OUTPUT_DIR}/embeddings", exist_ok=True)
os.makedirs(f"{OUTPUT_DIR}/figures", exist_ok=True)
os.environ["WANDB_DISABLED"] = "true"  # Disable wandb logging

In [5]:
# %%
# Function to check if GPU is available
def check_gpu():
    if torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    else:
        device = torch.device("cpu")
        print("GPU not available, using CPU")
    return device

device = check_gpu()

Using GPU: NVIDIA A100-SXM4-40GB


## Data loading and preprocessing

In [6]:
# %%
def load_arxiv_data(data_path: str, categories: List[str], samples_per_category: int = 5000, seed: int = 42) -> pd.DataFrame:
    """
    Load and preprocess arXiv dataset from JSON file.

    Args:
        data_path: Path to the arXiv metadata JSON file
        categories: List of categories to include
        samples_per_category: Number of papers to sample per category
        seed: Random seed for reproducibility

    Returns:
        DataFrame with processed arXiv data
    """
    print(f"Loading arXiv data from {data_path}")

    # Set random seed for reproducibility
    random.seed(seed)

    # Create a dictionary to store papers by category
    papers_by_category = {cat: [] for cat in categories}

    # Process the JSON file line by line to avoid loading the entire file into memory
    with jsonlines.open(data_path) as reader:
        for paper in tqdm(reader, desc="Processing papers"):
            # Check if the paper has a category we're interested in
            if 'categories' in paper:
                # Split categories and check if any match our target categories
                paper_cats = paper['categories'].split()
                primary_cat = paper_cats[0] if paper_cats else None

                # Only consider papers where target category is the primary category
                if primary_cat in categories and len(papers_by_category[primary_cat]) < samples_per_category:
                    if 'abstract' in paper and len(paper.get('abstract', '')) > 100:  # Ensure it has a decent abstract
                        # Add this paper to the category
                        papers_by_category[primary_cat].append({
                            'id': paper.get('id', ''),
                            'title': paper.get('title', ''),
                            'abstract': paper.get('abstract', ''),
                            'category': primary_cat
                        })

    # Combine all papers into a single list
    all_papers = []
    for cat, papers in papers_by_category.items():
        print(f"Category {cat}: {len(papers)} papers")
        all_papers.extend(papers)

    # Convert to DataFrame
    df = pd.DataFrame(all_papers)

    # Shuffle the data
    df = df.sample(frac=1, random_state=seed).reset_index(drop=True)

    print(f"Loaded {len(df)} papers across {len(categories)} categories")
    return df

In [7]:
# %%
def preprocess_text(text: str) -> str:
    """Clean and normalize text by removing stopwords and lowercasing"""
    if not isinstance(text, str) or not text:
        return ""

    try:
        # Tokenize text
        tokens = word_tokenize(text.lower())

        # Remove stopwords and short tokens
        stop_words = set(stopwords.words('english'))
        filtered_tokens = [token for token in tokens if token not in stop_words and len(token) > 2]

        return " ".join(filtered_tokens)
    except Exception as e:
        print(f"Error processing text: {e}")
        return text.lower()  # Fallback to just lowercasing if tokenization fails

In [8]:
# %%
def prepare_arxiv_dataset(df: pd.DataFrame) -> pd.DataFrame:
    """Prepare arXiv dataset for embedding generation"""
    # Combine title and abstract with weight on title
    df['text'] = df['title'] + " " + df['title'] + " " + df['abstract']

    # Clean text
    print("Cleaning and preprocessing text...")
    df['processed_text'] = df['text'].apply(preprocess_text)

    # Remove papers with very short processed text
    len_before = len(df)
    df = df[df['processed_text'].str.len() > 50]
    print(f"Removed {len_before - len(df)} papers with short text")

    return df

In [9]:
# %%
class EarlyStoppingCallback:
    """
    Early stopping callback for SentenceTransformer training.
    Stops training when loss drops below a threshold or doesn't improve.
    """
    def __init__(self, min_loss=0.01, patience=3):
        self.min_loss = min_loss
        self.patience = patience
        self.best_loss = float('inf')
        self.counter = 0
        self.history = []

    def __call__(self, score, epoch, steps):
        self.history.append(score)

        # Stop if loss is below minimum threshold
        if score < self.min_loss:
            print(f"\nEarly stopping: Loss {score:.4f} below threshold {self.min_loss:.4f}")
            return True

        # Stop if loss hasn't improved in 'patience' epochs
        if score < self.best_loss:
            self.best_loss = score
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                print(f"\nEarly stopping: No improvement for {self.patience} epochs")
                return True

        return False

## Model distillation

In [10]:
def create_distilled_model(model_name: str, num_layers: int = 6, max_seq_length: int = 256) -> SentenceTransformer:
    """
    Create a distilled model by reducing the number of layers in a transformer model.
    """
    print(f"Creating distilled model from {model_name} with {num_layers} layers")

    # Create a reduced configuration
    config = AutoConfig.from_pretrained(model_name)
    original_layers = config.num_hidden_layers
    print(f"Original model has {original_layers} layers")

    # Set the number of layers to our desired count
    config.num_hidden_layers = num_layers

    # Create the smaller model with explicit device assignment
    small_model = AutoModel.from_pretrained(model_name, config=config)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Create a sentence transformer with standard components
    word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length)

    # Important: Replace the transformer model with our reduced version
    word_embedding_model.auto_model = small_model

    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())

    # Create the sentence transformer model
    model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

    # Verify the model exists before returning
    if model is None:
        raise ValueError("Failed to create SentenceTransformer model")

    return model

In [11]:
# %%
def create_distilled_scibert(teacher_model_name: str,
                           train_texts: List[str],
                           eval_texts: List[str] = None,
                           max_seq_length: int = 256,
                           num_epochs: int = 10,
                           num_layers: int = 6) -> SentenceTransformer:
    """
    Create a distilled SciBERT model using knowledge distillation.

    Args:
        teacher_model_name: Name of the teacher model
        train_texts: List of texts for training
        eval_texts: List of texts for evaluation (optional)
        max_seq_length: Maximum sequence length
        num_epochs: Maximum number of training epochs
        num_layers: Number of layers in the distilled model

    Returns:
        Distilled SentenceTransformer model
    """
    print(f"\nDistilling SciBERT from {teacher_model_name}...")

    # If no evaluation texts, use a subset of training texts
    if eval_texts is None and train_texts:
        eval_size = min(1000, int(len(train_texts) * 0.1))
        eval_texts = train_texts[-eval_size:]
        train_texts = train_texts[:-eval_size]

    # Load the teacher model as SentenceTransformer
    teacher_model = SentenceTransformer(teacher_model_name)

    # Create a smaller student model
    student_model = create_distilled_model(
        model_name=teacher_model_name,
        num_layers=num_layers,
        max_seq_length=max_seq_length
    )

    # Generate teacher embeddings for training
    print("Generating teacher embeddings for training...")
    train_teacher_embeddings = teacher_model.encode(
        train_texts,
        batch_size=32,
        show_progress_bar=True,
        convert_to_tensor=False
    )

    # Create training examples
    train_examples = []
    for idx, text in enumerate(train_texts):
        train_examples.append(InputExample(texts=[text], label=train_teacher_embeddings[idx]))

    # Create a DataLoader for our training data
    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)

    # Use MSE loss for distillation
    train_loss = losses.MSELoss(model=student_model)

    # Set up early stopping
    early_stopping = EarlyStoppingCallback(min_loss=0.005, patience=3)

    # Train the student model
    warmup_steps = int(len(train_examples) * 0.1)

    # Set model path
    student_model_path = f"{OUTPUT_DIR}/models/distilled_scibert"
    os.makedirs(student_model_path, exist_ok=True)

    print(f"Training student model with {len(train_examples)} examples for up to {num_epochs} epochs...")
    student_model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        epochs=num_epochs,
        warmup_steps=warmup_steps,
        output_path=student_model_path,
        show_progress_bar=True,
        save_best_model=True,
        callback=early_stopping
    )

    # Load best model
    student_model = SentenceTransformer(student_model_path)

    # Plot training history if available
    if early_stopping.history:
        plt.figure(figsize=(10, 6))
        plt.plot(early_stopping.history)
        plt.xlabel("Evaluation Step")
        plt.ylabel("MSE Loss")
        plt.title("SciBERT Distillation Training Progress")
        plt.grid()
        plt.savefig(f"{OUTPUT_DIR}/figures/scibert_distillation_progress.png")
        plt.close()

    return student_model

In [12]:
# %%
def create_citation_aware_model(base_model_name: str,
                              train_texts: List[str],
                              eval_texts: List[str] = None,
                              max_seq_length: int = 256,
                              num_epochs: int = 10,
                              num_layers: int = 6) -> SentenceTransformer:
    """
    Create a citation-aware model to approximate SPECTER2's behavior.

    This model is trained to identify relationships between sections of the same document,
    simulating how SPECTER2 captures citation relationships.

    Args:
        base_model_name: Name of the base model
        train_texts: List of texts for training
        eval_texts: List of texts for evaluation (optional)
        max_seq_length: Maximum sequence length
        num_epochs: Maximum number of training epochs
        num_layers: Number of layers in the model

    Returns:
        SentenceTransformer model trained for citation-awareness
    """
    print(f"\nCreating citation-aware model based on {base_model_name}...")

    # Create a smaller model
    model = create_distilled_model(
        model_name=base_model_name,
        num_layers=num_layers,
        max_seq_length=max_seq_length
    )

    # Create training pairs that help simulate citation context learning
    print("Creating citation simulation training pairs...")
    train_examples = []

    # We'll create positive pairs (sections from same document) and negative pairs (sections from different documents)
    for i, text in enumerate(tqdm(train_texts, desc="Creating training pairs")):
        if len(text.split()) > 30:  # Only use texts with sufficient length
            # Split text into sections
            sentences = text.split('. ')
            if len(sentences) >= 3:
                # Create positive pairs from same document
                sections = []

                # Create 2-3 sections from document
                section_size = max(1, len(sentences) // 3)
                for j in range(0, len(sentences), section_size):
                    if j + section_size <= len(sentences):
                        section = '. '.join(sentences[j:j+section_size]) + '.'
                        sections.append(section)

                # Create positive pairs (same document)
                if len(sections) >= 2:
                    for j in range(len(sections)):
                        for k in range(j+1, len(sections)):
                            train_examples.append(InputExample(texts=[sections[j], sections[k]], label=1.0))

                    # Create negative pairs (different documents)
                    if i > 0 and len(train_texts) > 1:
                        other_idx = (i + 1) % len(train_texts)
                        if len(train_texts[other_idx].split()) > 30:
                            other_sentences = train_texts[other_idx].split('. ')
                            if len(other_sentences) >= 2:
                                other_section = '. '.join(other_sentences[:min(len(other_sentences), section_size)]) + '.'
                                train_examples.append(InputExample(texts=[sections[0], other_section], label=0.0))

    # Shuffle and limit examples for efficiency
    random.shuffle(train_examples)
    max_examples = 100000
    if len(train_examples) > max_examples:
        train_examples = train_examples[:max_examples]

    if len(train_examples) == 0:
        print("Error: No training examples created. Check your input texts.")
        return model

    # Create DataLoader
    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)

    # Use cosine similarity loss to learn citation-like relationships
    train_loss = losses.CosineSimilarityLoss(model=model)

    # Set up early stopping
    early_stopping = EarlyStoppingCallback(min_loss=0.05, patience=3)

    # Train the model
    model_path = f"{OUTPUT_DIR}/models/citation_aware_model"
    os.makedirs(model_path, exist_ok=True)

    print(f"Training citation-aware model with {len(train_examples)} pairs for up to {num_epochs} epochs...")
    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        epochs=num_epochs,
        warmup_steps=int(len(train_examples) * 0.1),
        output_path=model_path,
        show_progress_bar=True,
        save_best_model=True,
        callback=early_stopping
    )

    # Load best model
    model = SentenceTransformer(model_path)

    # Plot training history if available
    if early_stopping.history:
        plt.figure(figsize=(10, 6))
        plt.plot(early_stopping.history)
        plt.xlabel("Evaluation Step")
        plt.ylabel("Loss")
        plt.title("Citation-Aware Model Training Progress")
        plt.grid()
        plt.savefig(f"{OUTPUT_DIR}/figures/citation_model_progress.png")
        plt.close()

    return model

## Generate embeddings

In [13]:
# %%
def generate_embeddings(model: SentenceTransformer, texts: List[str], batch_size: int = 32, model_name: str = "") -> np.ndarray:
    """
    Generate embeddings using a SentenceTransformer model.

    Args:
        model: SentenceTransformer model
        texts: List of texts to encode
        batch_size: Batch size for encoding
        model_name: Model name for progress bar description

    Returns:
        NumPy array of embeddings
    """
    print(f"Generating embeddings using {model_name} model...")
    embeddings = model.encode(texts, batch_size=batch_size, show_progress_bar=True)

    # Normalize embeddings
    norm = np.linalg.norm(embeddings, axis=1, keepdims=True)
    normalized_embeddings = embeddings / norm

    return normalized_embeddings

In [14]:
# %%
def combine_embeddings(scibert_embeddings: np.ndarray,
                       citation_embeddings: np.ndarray,
                       weight_scibert: float = 0.5) -> np.ndarray:
    """
    Combine embeddings with specified weights.

    Args:
        scibert_embeddings: SciBERT embeddings
        citation_embeddings: Citation-aware embeddings
        weight_scibert: Weight for SciBERT embeddings (1-weight for citation model)

    Returns:
        Combined embeddings
    """
    assert scibert_embeddings.shape[0] == citation_embeddings.shape[0], "Number of embeddings must match"

    # Weighted combination (embeddings are already normalized)
    combined = (weight_scibert * scibert_embeddings) + ((1 - weight_scibert) * citation_embeddings)

    # Normalize combined embeddings
    norm = np.linalg.norm(combined, axis=1, keepdims=True)
    combined_norm = combined / norm

    return combined_norm

In [15]:
# %%
def optimize_weights(scibert_embeddings: np.ndarray,
                     citation_embeddings: np.ndarray,
                     categories: List[str],
                     texts: List[str] = None,
                     weight_range: List[float] = None) -> Dict:
    """
    Find the optimal weight combination for SciBERT and citation-aware embeddings.

    Args:
        scibert_embeddings: SciBERT embeddings
        citation_embeddings: Citation-aware embeddings
        categories: Document categories (for evaluating alignment)
        texts: Document texts (optional, for evaluating topic coherence)
        weight_range: List of weights to try (default: 11 weights from 0.0 to 1.0)

    Returns:
        Dictionary with optimal weight and scores
    """
    from sklearn.metrics import silhouette_score, adjusted_rand_score, normalized_mutual_info_score
    from sklearn.metrics.pairwise import cosine_similarity
    from sklearn.cluster import KMeans
    import matplotlib.pyplot as plt

    if weight_range is None:
        weight_range = np.linspace(0.0, 1.0, 11)  # [0.0, 0.1, 0.2, ..., 1.0]

    print("Optimizing ensemble weights...")
    results = []

    # Convert categories to numeric labels
    from sklearn.preprocessing import LabelEncoder
    encoder = LabelEncoder()
    category_labels = encoder.fit_transform(categories)

    # Number of clusters equals number of categories
    n_clusters = len(set(categories))

    for weight in tqdm(weight_range, desc="Testing weights"):
        # Combine embeddings with this weight
        combined = combine_embeddings(scibert_embeddings, citation_embeddings, weight_scibert=weight)

        # 1. Calculate silhouette score using categories as labels
        sil_score = silhouette_score(combined, category_labels, sample_size=min(5000, len(categories)), metric='cosine')

        # 2. Perform K-means clustering and evaluate against true categories
        kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
        cluster_labels = kmeans.fit_predict(combined)

        ari = adjusted_rand_score(category_labels, cluster_labels)
        nmi = normalized_mutual_info_score(category_labels, cluster_labels)

        # 3. Calculate within-category similarity
        # Sample up to 1000 documents for efficiency
        sample_size = min(1000, len(categories))
        indices = np.random.choice(len(categories), sample_size, replace=False)

        sampled_embeddings = combined[indices]
        sampled_categories = [categories[i] for i in indices]

        # Calculate within-category similarity
        similarities = []
        for cat in set(sampled_categories):
            cat_indices = [i for i, c in enumerate(sampled_categories) if c == cat]
            if len(cat_indices) > 1:
                cat_embeddings = sampled_embeddings[cat_indices]
                sim_matrix = cosine_similarity(cat_embeddings)
                # Remove self-similarity (diagonal)
                np.fill_diagonal(sim_matrix, 0)
                avg_sim = sim_matrix.sum() / (len(cat_indices) * (len(cat_indices) - 1))
                similarities.append(avg_sim)

        within_cat_similarity = np.mean(similarities) if similarities else 0

        # Store results
        results.append({
            'weight_scibert': weight,
            'silhouette_score': sil_score,
            'ari': ari,
            'nmi': nmi,
            'within_category_similarity': within_cat_similarity,
            # Combined score (weighted average)
            'combined_score': 0.3 * sil_score + 0.3 * nmi + 0.2 * ari + 0.2 * within_cat_similarity
        })

    # Convert to DataFrame
    results_df = pd.DataFrame(results)

    # Find optimal weight
    best_idx = results_df['combined_score'].idxmax()
    optimal_weight = results_df.loc[best_idx, 'weight_scibert']

    # Plot results
    plt.figure(figsize=(12, 8))
    plt.plot(results_df['weight_scibert'], results_df['silhouette_score'], 'b-o', label='Silhouette Score')
    plt.plot(results_df['weight_scibert'], results_df['nmi'], 'g-o', label='NMI')
    plt.plot(results_df['weight_scibert'], results_df['ari'], 'r-o', label='ARI')
    plt.plot(results_df['weight_scibert'], results_df['within_category_similarity'], 'c-o', label='Within-Category Similarity')
    plt.plot(results_df['weight_scibert'], results_df['combined_score'], 'k-o', label='Combined Score')
    plt.axvline(x=optimal_weight, color='black', linestyle='--', label=f'Optimal Weight: {optimal_weight:.2f}')
    plt.xlabel('SciBERT Weight')
    plt.ylabel('Score')
    plt.title('Weight Optimization Results')
    plt.legend()
    plt.grid(True)
    plt.savefig(f"{OUTPUT_DIR}/figures/weight_optimization.png", dpi=300)
    plt.show()

    print(f"Optimal SciBERT weight: {optimal_weight:.2f}")
    print(f"Optimal Citation Model weight: {1-optimal_weight:.2f}")

    return {
        'optimal_weight_scibert': optimal_weight,
        'optimal_weight_citation': 1 - optimal_weight,
        'results': results_df
    }

## Main execution

In [16]:
from datetime import datetime
start_time = datetime.now()
print(f"TopicWeave: Starting execution at {start_time.strftime('%Y-%m-%d %H:%M:%S')}")

# Configuration
data_path = '/content/drive/MyDrive/MIDS/W266/final_project/data/arxiv-metadata-oai-snapshot.json'
target_categories = [
    "cs.LG", "cs.AI", "cs.CL", "cs.CV", "physics.comp-ph",
    "math.ST", "q-bio.QM", "q-fin.ST", "stat.ML", "cs.DL"
]
samples_per_category = 5000
max_seq_length = 256
distilled_layers = 6

# Model names
scibert_model_name = "allenai/scibert_scivocab_uncased"

# Step 1: Load and preprocess data
print("\n===== Loading and Preprocessing Data =====")
df = load_arxiv_data(
    data_path=data_path,
    categories=target_categories,
    samples_per_category=samples_per_category
)
df = prepare_arxiv_dataset(df)

# Split data for training/validation
train_val_split = int(len(df) * 0.9)
train_texts = df['processed_text'][:train_val_split].tolist()
val_texts = df['processed_text'][train_val_split:].tolist()
all_texts = df['processed_text'].tolist()
all_categories = df['category'].tolist()

TopicWeave: Starting execution at 2025-03-16 14:19:53

===== Loading and Preprocessing Data =====
Loading arXiv data from /content/drive/MyDrive/MIDS/W266/final_project/data/arxiv-metadata-oai-snapshot.json


Processing papers: 0it [00:00, ?it/s]

Category cs.LG: 5000 papers
Category cs.AI: 5000 papers
Category cs.CL: 5000 papers
Category cs.CV: 5000 papers
Category physics.comp-ph: 5000 papers
Category math.ST: 5000 papers
Category q-bio.QM: 5000 papers
Category q-fin.ST: 2087 papers
Category stat.ML: 5000 papers
Category cs.DL: 3568 papers
Loaded 45655 papers across 10 categories
Cleaning and preprocessing text...
Removed 0 papers with short text


In [17]:
# Step 2: Create distilled SciBERT model
print("\n===== Creating Distilled SciBERT Model =====")
distilled_scibert = create_distilled_scibert(
    teacher_model_name=scibert_model_name,
    train_texts=train_texts,
    eval_texts=val_texts,
    max_seq_length=max_seq_length,
    num_layers=distilled_layers
)


===== Creating Distilled SciBERT Model =====

Distilling SciBERT from allenai/scibert_scivocab_uncased...




config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/442M [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/228k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/442M [00:00<?, ?B/s]

Creating distilled model from allenai/scibert_scivocab_uncased with 6 layers
Original model has 12 layers
Generating teacher embeddings for training...


Batches:   0%|          | 0/1285 [00:00<?, ?it/s]

Training student model with 41089 examples for up to 10 epochs...


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss
500,0.1512
1000,0.0671
1500,0.0547
2000,0.0455
2500,0.0402
3000,0.0354
3500,0.0325
4000,0.03
4500,0.0283
5000,0.0268


Step,Training Loss
500,0.1512
1000,0.0671
1500,0.0547
2000,0.0455
2500,0.0402
3000,0.0354
3500,0.0325
4000,0.03
4500,0.0283
5000,0.0268


In [18]:
# Step 3: Create citation-aware model
print("\n===== Creating Citation-Aware Model =====")
citation_model = create_citation_aware_model(
    base_model_name=scibert_model_name,
    train_texts=train_texts,
    eval_texts=val_texts,
    max_seq_length=max_seq_length,
    num_layers=distilled_layers
)


===== Creating Citation-Aware Model =====

Creating citation-aware model based on allenai/scibert_scivocab_uncased...
Creating distilled model from allenai/scibert_scivocab_uncased with 6 layers
Original model has 12 layers
Creating citation simulation training pairs...


Creating training pairs:   0%|          | 0/41089 [00:00<?, ?it/s]

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Training citation-aware model with 3608 pairs for up to 10 epochs...


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss
500,0.0249
1000,0.0113
1500,0.0034
2000,0.0009


In [None]:
# Step 4: Generate embeddings
print("\n===== Generating Embeddings =====")
scibert_embeddings = generate_embeddings(
    model=distilled_scibert,
    texts=all_texts,
    model_name="Distilled SciBERT"
)

In [20]:
citation_embeddings = generate_embeddings(
        model=citation_model,
        texts=all_texts,
        model_name="Citation-Aware Model"
    )

Generating embeddings using Citation-Aware Model model...


Batches:   0%|          | 0/1427 [00:00<?, ?it/s]

In [None]:
# Step 5: Optimize weights
print("\n===== Optimizing Ensemble Weights =====")
weight_results = optimize_weights(
    scibert_embeddings=scibert_embeddings,
    citation_embeddings=citation_embeddings,
    categories=all_categories,
    texts=all_texts
)

optimal_weight = weight_results['optimal_weight_scibert']

In [None]:
# Step 6: Generate combined embeddings
print("\n===== Generating Combined Embeddings =====")
combined_embeddings = combine_embeddings(
    scibert_embeddings=scibert_embeddings,
    citation_embeddings=citation_embeddings,
    weight_scibert=optimal_weight
)

In [23]:
# Step 7: Save results
print("\n===== Saving Results =====")
np.save(f"{OUTPUT_DIR}/embeddings/scibert_embeddings.npy", scibert_embeddings)
np.save(f"{OUTPUT_DIR}/embeddings/citation_embeddings.npy", citation_embeddings)
np.save(f"{OUTPUT_DIR}/embeddings/topicweave_embeddings.npy", combined_embeddings)

# Save document info
df[['id', 'title', 'processed_text', 'category']].to_csv(f"{OUTPUT_DIR}/embeddings/document_info.csv", index=False)

# Save weights results
weight_results['results'].to_csv(f"{OUTPUT_DIR}/embeddings/weight_optimization_results.csv", index=False)

 # Save metadata about the embeddings
with open(f"{OUTPUT_DIR}/embeddings/embeddings_metadata.json", "w") as f:
    json.dump({
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "scibert_model": scibert_model_name,
        "scibert_weight": float(optimal_weight),
        "citation_model_weight": float(1 - optimal_weight),
        "num_documents": len(all_texts),
        "embedding_dimension": int(combined_embeddings.shape[1]),
        "categories": list(set(all_categories)),
        "max_seq_length": max_seq_length,
        "distilled_layers": distilled_layers
    }, f, indent=2)

end_time = datetime.now()
duration = end_time - start_time
print(f"\n===== TopicWeave Embedding Generation Complete =====")
print(f"Total execution time: {duration}")
print(f"Number of documents: {len(all_texts)}")
print(f"SciBERT Embeddings Shape: {scibert_embeddings.shape}")
print(f"Citation Embeddings Shape: {citation_embeddings.shape}")
print(f"Combined Embeddings Shape: {combined_embeddings.shape}")
print(f"Optimal SciBERT Weight: {optimal_weight:.4f}")
print(f"Optimal Citation Model Weight: {(1-optimal_weight):.4f}")
print("\nAll results saved to:")
print(f"  - {OUTPUT_DIR}/embeddings/")
print(f"  - {OUTPUT_DIR}/models/")
print(f"  - {OUTPUT_DIR}/figures/")


===== Saving Results =====

===== TopicWeave Embedding Generation Complete =====
Total execution time: 0:49:12.165451
Number of documents: 45655
SciBERT Embeddings Shape: (45655, 768)
Citation Embeddings Shape: (45655, 768)
Combined Embeddings Shape: (45655, 768)
Optimal SciBERT Weight: 0.7000
Optimal Citation Model Weight: 0.3000

All results saved to:
  - /content/drive/MyDrive/MIDS/W266/final_project//embeddings/
  - /content/drive/MyDrive/MIDS/W266/final_project//models/
  - /content/drive/MyDrive/MIDS/W266/final_project//figures/
