In [3]:
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForMaskedLM,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
    AutoModel
)
from datasets import Dataset
import numpy as np
from typing import List
from data_processor import DataProcessor
import torch.nn.functional as F
from typing import Union
from transformers import ModernBertModel
from tqdm import tqdm

In [4]:
torch.cuda.is_available()

True

In [5]:
class MLMTrainer:
    def __init__(
        self,
        model_name: str = "answerdotai/ModernBERT-base",
        output_dir: str = "/ceph/submit/data/user/b/blaised/mlm_output",
        cache_dir: str = "/ceph/submit/data/user/b/blaised/cache",
    ):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.cache_dir = cache_dir
        self.output_dir = output_dir

        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, cache_dir=self.cache_dir
        )
        
        # Initialize model with Flash Attention 2 disabled
        if self.device == "cuda":
            self.model = AutoModelForMaskedLM.from_pretrained(
                model_name,
                cache_dir=self.cache_dir,
                torch_dtype=torch.bfloat16,  # More efficient than float32
                attn_implementation="flash_attention_2",
                reference_compile=False,
                classifier_pooling="mean",
            ).to(f"cuda:{torch.cuda.current_device()}")
        else:
            self.model = AutoModelForMaskedLM.from_pretrained(
                model_name,
                cache_dir=self.cache_dir,
                reference_compile=False,
            )

    def train(
        self,
        train_dataset: Dataset,
        eval_dataset: Dataset = None,
        num_train_epochs: int = 3,
        per_device_train_batch_size: int = 8,
        gradient_accumulation_steps: int = 4,
        learning_rate: float = 3e-4,
        weight_decay: float = 1e-5
    ):
        training_args = TrainingArguments(
            output_dir=self.output_dir,
            num_train_epochs=num_train_epochs,
            per_device_train_batch_size=per_device_train_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            logging_steps=1,
            save_strategy="epoch",
            eval_strategy="epoch" if eval_dataset else "no",
            # Performance optimizations
            fp16=True,                    # Enable mixed precision training
            dataloader_num_workers=4,     # Parallel data loading
            dataloader_pin_memory=True,   # Faster data transfer to GPU
            optim="adamw_torch_fused",    # Use fused optimizer
            lr_scheduler_type="cosine",   # Cosine decay often works well
            warmup_ratio=0.1,             # Gradual warmup for first 10% of steps
            # Save best model
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
        )
    
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer, mlm=True, mlm_probability=0.15
        )
    
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=data_collator,
        )
    
        trainer.train()
        trainer.save_model(f"{self.output_dir}/final_model")
        self.tokenizer.save_pretrained(f"{self.output_dir}/final_model")
    
        return trainer

In [6]:
data = DataProcessor.load_and_process_data(
    "/ceph/submit/data/user/b/blaised/lhcb_corpus/lhcb_papers.pkl"
)
texts = data["abstract"].tolist()

# Train/eval split
np.random.seed(42)
eval_size = int(len(texts) * 0.1)
eval_indices = np.random.choice(len(texts), eval_size, replace=False)
train_indices = [i for i in range(len(texts)) if i not in eval_indices]

train_texts = [texts[i] for i in train_indices]
eval_texts = [texts[i] for i in eval_indices]

In [7]:
# ------------------------------------------------------------
# 2) Initialize trainer and get tokenizer
# ------------------------------------------------------------
mlm_trainer = MLMTrainer()
tokenizer = mlm_trainer.tokenizer  # Get reference to tokenizer

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


In [8]:
def tokenize_function(examples):
    return tokenizer(
        examples["text"],
        padding=True,  # ModernBERT unpadded usage
        truncation=True,
        return_special_tokens_mask=True,
    )

In [9]:
    # ------------------------------------------------------------
# 3) Prepare and tokenize datasets
# ------------------------------------------------------------
# Create datasets
train_dataset = Dataset.from_dict({"text": train_texts})
eval_dataset = Dataset.from_dict({"text": eval_texts})

# Tokenize using multiple processes
train_dataset = train_dataset.map(
    tokenize_function,
    batched=False,
    num_proc=4,
    remove_columns=["text"],
)
eval_dataset = eval_dataset.map(
    tokenize_function,
    batched=True,
    num_proc=4,
    remove_columns=["text"],
)

Map (num_proc=4):   0%|          | 0/730 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/81 [00:00<?, ? examples/s]

In [10]:
train_dataset

Dataset({
    features: ['input_ids', 'attention_mask', 'special_tokens_mask'],
    num_rows: 730
})

In [11]:
# ------------------------------------------------------------
# 4) Train the model
# ------------------------------------------------------------
mlm_trainer.train(
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    num_train_epochs=30,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
)



Epoch,Training Loss,Validation Loss
1,1.193,0.578467
2,1.0685,0.574446
3,1.2677,0.591494
4,1.0096,0.574323
5,0.9321,0.539634
6,0.7521,0.548926
7,0.7116,0.523832
8,0.6805,0.554088
9,0.8593,0.540707
10,0.7422,0.509413


There were missing keys in the checkpoint model loaded: ['decoder.weight'].


<transformers.trainer.Trainer at 0x7f0a2b82c580>

In [12]:
class EncoderModel:
    """Model loading, tokenisatuionn, and inference."""

    def __init__(
        self,
        model_name: str,
        cache_dir: str = "/ceph/submit/data/user/b/blaised/cache",
        device: Union[str, None] = None,
    ) -> None:
        """Initialize the model."""
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        self.cache_dir = cache_dir

        # tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            cache_dir=self.cache_dir,
        )

        # book the encoding model
        if self.device.type == "cuda":
            self.model = AutoModel.from_pretrained(
                model_name,
                cache_dir=self.cache_dir,
                #attn_implementation="flash_attention_2",
            ).to("cuda")
        else:
            self.model = ModernBertModel.from_pretrained(
                model_name,
                cache_dir=self.cache_dir,
            )

        # sanity device check
        assert (
            self.model.device.type == self.device.type
        ), f"Model is on {self.model.device.type}, but expected {self.device.type}."

    def encode(self, texts: list[str], batch_size: int = 1) -> torch.Tensor:
        """Get embeddings for a list of texts."""
        embeddings = []

        # Process texts in full-batch mode
        self.model.eval()
        for i in tqdm(range(0, len(texts), batch_size), desc="Processing batches"):
            batch_text = texts[i : i + min(batch_size, len(texts))]

            # Tokenize and encode the batch
            inputs = self.tokenizer(batch_text, return_tensors="pt").to(self.device) 

            with torch.no_grad():
                outputs = self.model(**inputs)

                # Fetch the [CLS] representation in the last embedding layer - following BERT - see ModernBertConfig() in Transformers
                batch_embeddings = outputs.last_hidden_state[:, 0, :]  # checked

                embeddings.append(batch_embeddings.cpu())

        return torch.cat(embeddings, dim=0)


    def mean_pooling(self, model_output, attention_mask):
        token_embeddings = model_output[0]
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
    )
    
    def mean_pool_encode(self, texts, batch_size=1, prefix: Union[str, None] = None):
        """Encoding via mean pooling"""
        embeddings = []

        # for modernbert-embed, we need to preped prefix
        if prefix:
            texts = [f"{prefix}: {t}" for t in texts]
        
        for i in tqdm(range(0, len(texts), batch_size), desc="Processing batched"):
            batch_texts = texts[i:i + batch_size]
            
            # Tokenize and encode the batch
            inputs = self.tokenizer(batch_texts, padding=True, truncation=True, return_tensors="pt").to(self.device) # FIXME: truncation, and/or padding
            
            with torch.no_grad():
                outputs = self.model(**inputs)
                batch_embeddings = self.mean_pooling(outputs, inputs['attention_mask'])
                batch_embeddings = F.normalize(batch_embeddings, p=2, dim=1)
                embeddings.append(batch_embeddings.cpu())
                
        return torch.cat(embeddings, dim=0)

In [None]:
lhcb_abstract_dataset = DataProcessor.load_and_process_data(
    "/ceph/submit/data/user/b/blaised/lhcb_corpus/lhcb_papers.pkl"
)

abstract_corpus = lhcb_abstract_dataset["abstract"].tolist()
working_groups = lhcb_abstract_dataset["working_groups"].tolist()
abstract_labels = lhcb_abstract_dataset["encoded_wg"].tolist()

print(f"Loaded {len(abstract_corpus)} abstracts")

In [None]:
trained_model = EncoderModel(
        #model_name="/ceph/submit/data/user/b/blaised/mlm_output/final_model",
        model_name="answerdotai/ModernBERT-base",
        #model_name="nomic-ai/modernbert-embed-base",
        #model_name="lightonai/modernbert-embed-large",
        #model_name="thellert/physbert_cased",
        device="cuda:0",
)

In [None]:
embeddings = trained_model.mean_pool_encode(abstract_corpus, prefix=None)
embeddings_np = embeddings.numpy()

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
import numpy as np
import umap.umap_ as umap


class Visualizer:
    """Handles visualization of embeddings and metrics"""

    @staticmethod
    def plot_embeddings(embeddings, labels, method="pca", save_path="test.png"):
        # Convert inputs to numpy arrays if they aren't already
        embeddings_np = np.array(embeddings)
        labels_np = np.array(labels)
        
        plt.figure()
        
        reducer = PCA(n_components=2) if method.lower() == "pca" else umap.UMAP(random_state=42)
        reduced = reducer.fit_transform(embeddings_np)
        
        # Get unique labels and assign colors
        unique_labels = np.unique(labels_np)
        colors = sns.color_palette("Spectral", n_colors=len(unique_labels))
        
        for i, label in enumerate(unique_labels):
            # Create boolean mask for this label
            mask = (labels_np == label)
            if np.sum(mask) > 0:  # Only plot if we have points for this label
                plt.scatter(reduced[mask, 0], reduced[mask, 1],
                           c=colors[i], label=str(label), alpha=0.75)
        
        plt.title(f"LHCb Abstracts Embeddings ({method.upper()})")
        plt.legend(bbox_to_anchor=(1.05, 0.85), loc="upper left", bbox_transform=plt.gcf().transFigure)
        plt.tight_layout()
        
        if save_path:
            plt.savefig(save_path, bbox_inches="tight", dpi=300)
        plt.show()

In [None]:
Visualizer().plot_embeddings(
    embeddings_np, working_groups, method="umap"
)

In [None]:
from metrics import MetricsCalculator

metrics_calc = MetricsCalculator()

print("\nComputing clustering metrics...")
clustering_results = metrics_calc.compute_clustering_metrics(
    embeddings=embeddings_np, labels=working_groups
)

print(f"NMI score: {clustering_results['nmi_score']:.3f}")

In [None]:
# 2. Compute group metrics
print("\nComputing group metrics...")
group_metrics = metrics_calc.compute_group_metrics(
    embeddings=embeddings_np, groups=working_groups
)

print("\nGroup statistics:")
for group, metrics in group_metrics.items():
    print(f"\nGroup: {group}")
    print(f"Count: {metrics['count']}")
    print(f"Average similarity: {metrics['avg_similarity']:.3f}")