# Formal vs. informal classification of Russian text using the Transformer architecture

## EDA and data preparation

### Imports and params
- `DOC_STRIDE = 128` creates a 25% overlap between 512-token windows during tokenization so the model can see transitions across chunks without exceeding the max length.
- `RANDOM_SEED = 42` keeps train/val/test splits and shuffling deterministic so experiments are comparable.


In [None]:
%matplotlib inline

import os
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import display
import textwrap
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, f1_score
from collections import Counter, defaultdict
from datasets import Dataset, DatasetDict, Value
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
    set_seed,
    EarlyStoppingCallback,
)
import torch
import numpy as np

# Plot defaults for every figure in this notebook.
plt.rcParams.update({'figure.figsize': (10, 4), 'axes.grid': True})
# Seed NumPy's global RNG (the split cell below draws from it).
np.random.seed(42)

# Turn off the W&B integration and Hugging Face Hub progress bars.
os.environ.update({
    'WANDB_DISABLED': 'true',
    'HF_HUB_DISABLE_PROGRESS_BARS': '1',
})

### Set data paths and verify they exist

In [None]:
# All corpora live in ../data relative to the notebook's working directory.
DATA_DIR = Path('..') / 'data'
LENTA_PATH_SMALL = DATA_DIR / 'lenta_data.csv'
LENTA_PATH = DATA_DIR / 'lenta_texts_2000-20.csv'
ANY_PHRASES_PATH = DATA_DIR / 'dataset_any_phrases.csv'
TAIGA_PATH = DATA_DIR / 'taiga_style_dataset_clean.csv'
SOVA_PATH = DATA_DIR / 'sovadataset.csv'

# Warn up front so missing files are visible before any loader runs;
# the loaders themselves assert on their own paths.
expected_files = [LENTA_PATH, LENTA_PATH_SMALL, ANY_PHRASES_PATH, TAIGA_PATH, SOVA_PATH]
for path in expected_files:
    if path.exists():
        continue
    print(f'Warning: missing file {path}')



### `load_lenta_dataset`
- **Logic**: Reads the newline-delimited Lenta dataset, where every line is a formal news passage, and converts it into the unified DataFrame schema shared by every dataset used below (text, numeric label, human-readable label, source).
- **Line by line**:
  - `assert path.exists()`: fail fast if the expected file is missing
  - Initialize `texts = []` and iterate through each raw line, stripping whitespace with `raw_line.strip()`
  - `if text:` keeps only non-empty lines
  - `assert texts`: guarantees at least one sample was read—if not, raising to avoid passing an empty frame downstream.
  - `pd.DataFrame({'text': texts})`: wraps the collected strings into a DataFrame for further processing.
  - `df['label'] = 1` / `df['label_name'] = 'formal'`: mark every record as formal since the corpus only contains official news.
  - `df['source'] = 'lenta' + source_prefix`: keeps provenance so later analysis can filter by origin. The optional `source_prefix` argument is appended after `'lenta'` despite its name (e.g. `'_small'` → `'lenta_small'`). It lets the same function load alternative Lenta files while keeping them distinguishable.
  - `return df[...]`: select the standardized column order expected by later steps.
- **Notes**: Keeping the loader minimal avoids extra transformations here; any deduplication or length filtering happens after all sources are merged so a single policy applies to everyone.


In [None]:
def load_lenta_dataset(path: Path, source_prefix=""):
    """Load the Lenta news corpus as formal-labelled rows.

    Every non-blank line of the file becomes one sample after stripping
    surrounding whitespace. Lines are taken verbatim and are NOT parsed as
    CSV fields.

    Parameters
    ----------
    path : Path
        Text file with one passage per line.
    source_prefix : str, optional
        Appended (despite the name) to ``'lenta'`` to build the ``source``
        value, e.g. ``'_small'`` -> ``'lenta_small'``.

    Returns
    -------
    pd.DataFrame
        Columns ``text``, ``label`` (always 1), ``label_name`` (always
        ``'formal'``) and ``source``.

    Raises
    ------
    AssertionError
        If the file is missing or contains no non-blank lines.
    """
    assert path.exists(), f'Missing Lenta dataset: {path}'

    # NOTE(review): the file has a .csv extension; if it carries a header row,
    # that header is kept as a sample here — confirm the file has no header.
    with path.open(encoding='utf-8') as handle:
        stripped_lines = (line.strip() for line in handle)
        texts = [passage for passage in stripped_lines if passage]

    assert texts, f'No text rows found in {path}'

    frame = pd.DataFrame({'text': texts})
    return frame.assign(
        label=1,
        label_name='formal',
        source='lenta' + source_prefix,
    )[['text', 'label', 'label_name', 'source']]

### `load_any_phrases_dataset`
- **Logic**: Parses the informal messenger phrases dataset where each line starts with an optional numeric id followed by a comma and the quote, producing informal-labeled rows.
- **Line by line**:
  - Asserts that the file exists, then iterates over its lines, stripping whitespace and skipping (`continue`) blank ones
  - `parts = line.split(',', 1)`: split at most once; commas that belong to the phrase stay intact because we only separate the leading id portion.
  - `text = parts[1] if len(parts) > 1 else parts[0]`: handles both `id,text` and `text`-only lines to keep the loader tolerant of minor format drift.
  - `text.strip().strip('"')`: remove surrounding whitespace and stray double quotes left in the raw dump.
  - Append non-empty strings, verify list isn't empty
  - `df['word_count'] = df['text'].str.split().map(len)` counts whitespace-separated words; `df = df[df['word_count'] >= min_words]` enforces the ≥5-word rule (the `min_words` default). It drops extremely short acknowledgements, which are not supported by design.
  - `df = df.drop(columns=['word_count'])` cleans up the helper column once filtering is done.

In [None]:
def load_any_phrases_dataset(path: Path, min_words: int = 5):
    """Load the informal messenger-phrases corpus as informal-labelled rows.

    Each non-blank line is either ``<numeric id>,<phrase>`` or just
    ``<phrase>``. The part before the first comma is treated as an id and
    dropped ONLY when it consists of digits (optionally wrapped in double
    quotes). This way, a text-only phrase that itself contains a comma is kept
    intact instead of losing everything before that comma.

    Parameters
    ----------
    path : Path
        Text file with one phrase per line.
    min_words : int, optional
        Phrases with fewer whitespace-separated words are dropped (default 5).

    Returns
    -------
    pd.DataFrame
        Columns ``text``, ``label`` (always 0), ``label_name`` (always
        ``'informal'``) and ``source`` (always ``'any_phrases'``).

    Raises
    ------
    AssertionError
        If the file is missing or has no non-blank phrase.
    """
    assert path.exists(), f'Missing dataset_any_phrases corpus: {path}'

    texts = []
    with path.open(encoding='utf-8') as handle:
        for raw_line in handle:
            line = raw_line.strip()
            if not line:
                continue
            head, sep, tail = line.partition(',')
            # Only a numeric leading field is an id; otherwise the comma is part
            # of the phrase and the whole line is the text.
            has_id = bool(sep) and head.strip().strip('"').isdigit()
            text = tail if has_id else line
            text = text.strip().strip('"')
            if text:
                texts.append(text)

    assert texts, f'No text rows found in {path}'

    df = pd.DataFrame({'text': texts})
    word_count = df['text'].str.split().map(len)
    df = df[word_count >= min_words].copy()
    if df.empty:
        # Same diagnostic as the Sova loader: the threshold removed everything.
        print(f'Warning: no any_phrases rows with >= {min_words} words')

    df['label'] = 0
    df['label_name'] = 'informal'
    df['source'] = 'any_phrases'
    return df[['text', 'label', 'label_name', 'source']]



### `load_taiga_dataset`
- **Logic**: Normalizes the Taiga style CSV, which already includes labels, by renaming columns, stripping whitespace, and mapping label ids to our two-class schema.
- **Line by line**:
  - Verify that the file exists, load the CSV and rename the `new_text` column to `text` to align with the other datasets.
  - `dropna(subset=['text', 'label'])` removes rows that lost either the text or its class during upstream preprocessing.
  - `df['text'] = df['text'].astype(str).str.strip()` ensures even numeric-looking entries become strings and removes leading/trailing whitespace
  - Filter empty strings and convert labels to `int` to avoid potential issues in future
  - `label_map = {1: 'formal', 0: 'informal'}` maps Taiga's label ids to our naming convention; `map` fills `label_name` accordingly.
  - Set source and return DataFrame
- **Notes**: We keep both classes because Taiga already distinguishes formality


In [None]:
def load_taiga_dataset(path: Path):
    """Load the Taiga style corpus, which already carries formal/informal labels.

    The ``new_text`` column is renamed to ``text``. Rows missing text or label
    are dropped, texts are stripped, and empty texts are removed. Labels are
    mapped ``1 -> 'formal'`` and ``0 -> 'informal'``.

    Parameters
    ----------
    path : Path
        CSV file with ``new_text`` (or ``text``) and ``label`` columns.

    Returns
    -------
    pd.DataFrame
        Columns ``text``, ``label`` (int), ``label_name`` and ``source``
        (always ``'taiga'``).

    Raises
    ------
    AssertionError
        If the file or a required column is missing, no rows survive cleaning,
        or a label other than 0/1 is present.
    """
    assert path.exists(), f'Missing Taiga dataset: {path}'
    df = pd.read_csv(path)
    df = df.rename(columns={'new_text': 'text'})
    # Validate the schema BEFORE using the columns; otherwise dropna raises a
    # KeyError and this message is never shown.
    assert 'text' in df.columns, f'Missing text column: {path}'
    assert 'label' in df.columns, f'Missing label column: {path}'
    df = df.dropna(subset=['text', 'label']).copy()

    df['text'] = df['text'].astype(str).str.strip()
    df = df[df['text'] != ''].copy()
    assert not df.empty, f'No text rows found in {path}'

    df['label'] = df['label'].astype(int)
    label_map = {1: 'formal', 0: 'informal'}
    # Unknown ids would otherwise become NaN label names silently.
    unknown = set(df['label']) - set(label_map)
    assert not unknown, f'Unexpected Taiga labels {sorted(unknown)} in {path}'
    df['label_name'] = df['label'].map(label_map)
    df['source'] = 'taiga'
    return df[['text', 'label', 'label_name', 'source']]


### `load_sova_dataset`
- **Logic**: Loads the CSV, drops very short snippets, and labels everything as informal, since this corpus contains only casual messenger replies
- **Line by line**:
  - Verify that the file exists, load the CSV and verify that the `text` column exists
  - `df = df[['text']].copy()` keeps only the text column. We ignore the provided `id`/`label` columns because the id is irrelevant and every text in this corpus is informal
  - Cast to string, strip surrounding whitespace, then remove blank rows
  - The `df['word_count']` logic is the same as in `load_any_phrases_dataset`
  - Set `label=0`, `label_name='informal'`, `source='sova'`, and warn if nothing survives the filter so the user can adjust the threshold
  - Return the standardized columns for concatenation with other datasets


In [None]:
def load_sova_dataset(path: Path, min_words: int = 5):
    """Load the Sova messenger corpus; every row is labelled informal.

    Only the ``text`` column is used; any ``id``/``label`` columns are ignored.
    Missing texts are dropped, texts are stripped, and empty texts and texts
    with fewer than ``min_words`` whitespace-separated words are removed.

    Parameters
    ----------
    path : Path
        CSV file with a ``text`` column.
    min_words : int, optional
        Minimum number of words a text must have to be kept (default 5).

    Returns
    -------
    pd.DataFrame
        Columns ``text``, ``label`` (always 0), ``label_name`` (always
        ``'informal'``) and ``source`` (always ``'sova'``). It may be empty, in
        which case a warning is printed.

    Raises
    ------
    AssertionError
        If the file or its ``text`` column is missing.
    """
    assert path.exists(), f'Missing Sova dataset: {path}'
    df = pd.read_csv(path)
    assert 'text' in df.columns, f'Missing text column: {path}'

    # Drop missing values BEFORE casting: astype(str) would turn NaN into the
    # literal string 'nan', which passes the empty-string filter below.
    df = df[['text']].dropna().copy()
    df['text'] = df['text'].astype(str).str.strip()
    df = df[df['text'] != '']

    word_count = df['text'].str.split().map(len)
    df = df[word_count >= min_words].copy()

    df['label'] = 0
    df['label_name'] = 'informal'
    df['source'] = 'sova'
    if df.empty:
        print(f'Warning: no Sova rows with >= {min_words} words')
    return df[['text', 'label', 'label_name', 'source']]


### Dataset loading & sanity report
Call every source-specific loader def, gather them into a dictionary, and print per-source counts so we immediately see label balance or missing files before merging


In [None]:
# Load every corpus through its dedicated loader. Dict insertion order is kept,
# so the report below (and any later iteration over `datasets`) follows it.
datasets = {
    'lenta_small': load_lenta_dataset(LENTA_PATH_SMALL, "_small"),
    'lenta': load_lenta_dataset(LENTA_PATH),
    'any_phrases': load_any_phrases_dataset(ANY_PHRASES_PATH),
    'taiga': load_taiga_dataset(TAIGA_PATH),
    'sova': load_sova_dataset(SOVA_PATH),
}
lenta_df_small = datasets['lenta_small']
lenta_df = datasets['lenta']
any_phrases_df = datasets['any_phrases']
taiga_df = datasets['taiga']
sova_df = datasets['sova']

# Per-source sanity report: sample count, then class counts when non-empty.
for name, data in datasets.items():
    print(f'{name} samples: {len(data):,}')
    if data.empty:
        print()
        continue
    print(data['label_name'].value_counts())
    print()


### Merge, deduplicate, and derive length stats
- **Concept**: Combine all non-empty sources into a single DataFrame, drop exact text duplicates, and precompute character/word/token-length proxies that downstream EDA and bucketing rely on
- **Details**:
  - We raise if everything is empty to avoid silent failures
  - `pd.concat(sources, ignore_index=True)` forms the unified dataset; `df.duplicated('text')` identifies identical transcripts, log how many rows are removed before resetting the index
  - The length columns (`char_length`, `word_length`, `approx_token_length`) give multiple perspectives on transcript size. 
  - `TOKENS_TO_WORD_RATIO` approximates subword growth so we can reason about the 512-token window limit even before tokenizing. Its value is taken from Figure 1(b) of the paper https://aclanthology.org/2021.acl-long.243.pdf
- **Why**: Deduping prevents leakage between train/val/test splits (duplicate sentences would otherwise appear in multiple splits). Precomputing lengths lets us analyze coverage without running a tokenizer


In [None]:
TOKENS_TO_WORD_RATIO = 1.3

# Report and skip empty corpora; refuse to continue if nothing was loaded.
for name, data in datasets.items():
    if data.empty:
        print(f'Warning: dataset {name} is empty and will be skipped')
sources = [data for data in datasets.values() if not data.empty]
if not sources:
    raise ValueError('No data loaded from any source.')

# Exact-text deduplication keeps the first occurrence (concat follows the
# order of `datasets`), so identical texts cannot land in two splits.
df = pd.concat(sources, ignore_index=True)
duplicates = df.duplicated(subset='text')
print(f"Dropping {duplicates.sum()} duplicate rows based on text column")
df = df.loc[~duplicates].reset_index(drop=True)

# Length proxies: characters, whitespace-separated words, and a word-based
# token estimate (word count * TOKENS_TO_WORD_RATIO, rounded to int).
df = df.assign(
    char_length=lambda frame: frame['text'].str.len(),
    word_length=lambda frame: frame['text'].str.split().map(len),
    approx_token_length=lambda frame: (
        (frame['word_length'] * TOKENS_TO_WORD_RATIO).round().astype(int)
    ),
)
display(df.head())

### Label-level summary tables
- Produce quick aggregate tables for counts, descriptive statistics, and length (>512 tokens) to understand class balance and sequence-length distribution before modeling.
- **Details**:
  - `length_summary = ...` aggregates mean/median/max across character, word, and approximate token lengths for each label
  - `share_over_512 = ...` computes both absolute counts and percentages of samples whose approximate token length exceeds the 512-token BERT window, per label.
- **Why**: These summaries validate that the merge worked, highlight any major imbalance

In [None]:
# Per-class sample counts.
label_counts = df.groupby('label_name').size().rename('samples').to_frame()

# Mean / median / max of each length proxy, per class.
length_summary = (
    df.groupby('label_name')[['char_length', 'word_length', 'approx_token_length']]
    .agg(['mean', 'median', 'max'])
    .round(2)
)

# How many samples per class exceed the 512-token window (by the approximation).
share_over_512 = (
    df.groupby('label_name')['approx_token_length']
    .agg(
        total='size',
        over_512=lambda s: (s > 512).sum()
    )
)
# Scale to percent first, then round. Rounding the fraction and multiplying
# afterwards can display float artefacts such as 4.880000000000001.
share_over_512['percent_over_512'] = (
    share_over_512['over_512'] / share_over_512['total'] * 100
).round(2)

display(label_counts)
display(length_summary)
display(share_over_512)
print(f'Total samples: {len(df):,}')

Almost 5% of the samples exceed ~512 (approximate) tokens, so they need special handling during training. Here that handling is sliding-window tokenization.

### Draw statistics and show examples of data

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 4))
word_ax, token_ax = axes

# One semi-transparent histogram per class on each panel so the classes overlap visibly.
for label_name, subset in df.groupby('label_name'):
    word_ax.hist(subset['word_length'], bins=60, alpha=0.6, label=label_name)
    token_ax.hist(subset['approx_token_length'], bins=60, alpha=0.6, label=label_name)

for ax, title, xlabel in (
    (word_ax, 'Word length distribution', 'Words per sample'),
    (token_ax, 'Approximate token length distribution', 'Approx tokens per sample'),
):
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Count')
    ax.legend()
plt.tight_layout()

In [None]:
def show_examples(source_df, label_name, n=3, width=120, seed=13):
    """Print up to ``n`` randomly sampled texts whose ``label_name`` matches.

    Sampling uses ``random_state=seed``, so the same rows are shown on every
    call. Each text is wrapped to ``width`` columns and followed by an 80-dash
    separator. If no row has the label, a notice is printed instead.
    """
    matching = source_df.loc[source_df['label_name'] == label_name]
    if matching.empty:
        print(f'No rows for label {label_name}')
        return

    picked_texts = matching.sample(n=min(n, len(matching)), random_state=seed)['text']
    header = label_name.upper()
    separator = '-' * 80
    for position, body in enumerate(picked_texts, start=1):
        print(f'{header} #{position}')
        print(textwrap.fill(body, width=width))
        print(separator)


for label in ('formal', 'informal'):
    show_examples(df, label)


### Train/val/test split

  - Shuffle an array of row indices, cut it into consecutive train/val/test slices, and write each row's split name into a `split` column via `df.loc`
  - `TRAIN_FRAC = 0.8`, `VAL_FRAC = 0.1`, and `TEST_FRAC = 0.1` keeps most data for learning while reserving enough samples to tune hyperparameters and report final metrics; equal-sized val/test splits make comparisons fair.
  - Shuffle with a fixed seed so every experiment sees the same partition

In [None]:
TRAIN_FRAC = 0.8
VAL_FRAC = 0.1
TEST_FRAC = 0.1  # the test split takes every row left after train/val
assert np.isclose(TRAIN_FRAC + VAL_FRAC + TEST_FRAC, 1.0), 'Split fractions must sum to 1'

# Use a dedicated, freshly seeded RNG instead of NumPy's global state, so
# re-running this cell (or any other np.random call) cannot change the partition.
# RandomState(42).shuffle yields the same permutation as np.random.seed(42)
# followed by a first np.random.shuffle. Assuming nothing drew from the global
# RNG in between, a fresh-kernel split is therefore unchanged.
SPLIT_SEED = 42

num_samples = len(df)
indices = np.arange(num_samples)
np.random.RandomState(SPLIT_SEED).shuffle(indices)

train_end = int(TRAIN_FRAC * num_samples)
train_idx = indices[:train_end]

val_end = train_end + int(VAL_FRAC * num_samples)
val_idx = indices[train_end:val_end]

test_idx = indices[val_end:]

# df has a fresh 0..N-1 index after deduplication, so positions double as .loc labels.
df['split'] = 'test'
df.loc[train_idx, 'split'] = 'train'
df.loc[val_idx, 'split'] = 'val'
print(f'Train samples: {len(train_idx):,}')
print(f'Val samples: {len(val_idx):,}')
print(f'Test samples: {len(test_idx):,}')

### Split sanity table
- **Why print this**: After slicing the dataset into train/val/test, we immediately inspect the cross-tab of split vs. label to ensure no class vanished in a split.
- **Reading it**: Each row shows how many formal/informal samples ended up in each partition; ideally every split has both labels in roughly expected proportions.

In [None]:
# Cross-tab of split x label as a one-column 'samples' frame, sorted by index.
split_counts = df.groupby(['split', 'label_name']).size().rename('samples').to_frame().sort_index()
display(split_counts)

# Warn about any split that lost one of the two classes entirely.
expected_labels = {'formal', 'informal'}
for split, group in df.groupby('split'):
    missing = expected_labels.difference(group['label_name'])
    if not missing:
        continue
    print(f'Warning: split {split} missing labels: {missing}')


### Choosing the model, stride and random seed
- **Model selection**: We use BERT models pretrained specifically on Russian, because they are tuned to the language's vocabulary, morphology, and structure.
- `DOC_STRIDE = 128` - yields ~25 % overlap between 512-token windows, preserving context at chunk boundaries.
- `RANDOM_SEED = 42` - keeps data splits and shuffling deterministic for consistent experiments.


In [None]:
# Candidate checkpoints (RuBERT variants):
#   "cointegrated/rubert-tiny2"     -> https://huggingface.co/cointegrated/rubert-tiny2
#   "DeepPavlov/rubert-base-cased"  -> https://huggingface.co/DeepPavlov/rubert-base-cased
MODEL_NAME = "DeepPavlov/rubert-base-cased"

# Window size in tokens, overlap (in tokens) between consecutive windows, and
# the seed handed to set_seed / TrainingArguments.
MAX_SEQ_LENGTH, DOC_STRIDE, RANDOM_SEED = 512, 128, 42


### Preparing Hugging Face datasets & metadata
- **Resetting indices**: After we deduplicate and merge everything, we call `df = df.reset_index(drop=True)` and then add `df['example_id'] = df.index`. That `example_id` is a simple 0…N−1 counter that uniquely tags every original message. Later, when tokenization creates sliding windows with `return_overflowing_tokens=True`, each window inherits its parent's `example_id`
- **Split snapshots**: `split_to_df` stores separate pandas frames for train/val/test, letting us inspect or export them independently before converting to `Dataset` objects
- **`source_to_id` mapping**: Keeping a compact integer id per source helps with analysis (e.g., per-source metrics) without dragging string columns through the trainer; we attach it to every window later
- **`DatasetDict` conversion**: Moving to Hugging Face `Dataset` objects now means tokenization, batching, and trainer hooks all operate on the same structure. Using `preserve_index=False` drops pandas' index to avoid carrying redundant columns


In [None]:
# Stable 0..N-1 id per message so sliding windows can be regrouped later.
df = df.reset_index(drop=True).copy()
df['example_id'] = df.index

split_to_df = {}
for split in ('train', 'val', 'test'):
    split_to_df[split] = df.loc[df['split'] == split].reset_index(drop=True)

# Compact integer id per source, assigned in alphabetical order of source names.
source_to_id = {
    source: position
    for position, source in enumerate(sorted(df['source'].unique()))
}

datasets_dict = DatasetDict({
    name: Dataset.from_pandas(frame, preserve_index=False)
    for name, frame in split_to_df.items()
})
source_to_id

## Tokenizing

#### **Tokenizer configuration**
- **Checkpoint-aware tokenizer**: `AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)` loads the exact subword split rules that match the RuBERT checkpoint so ids align with the encoder. `use_fast=True` is enabled to rely on Hugging Face's Rust-backed tokenizer, which is significantly faster and exposes the `overflow_to_sample_mapping` metadata required for sliding windows

#### **Sliding-window tokenization**
- **Why sliding windows**: Approximately 5% of data could exceed BERT's 512-token limit. Instead of truncating, we split them into overlapping windows so the model sees every part of the message, and we can later aggregate logits per message
- **Parameters explained**:
  - `max_length=MAX_SEQ_LENGTH` caps each window at 512 tokens (BERT's limit)
  - `padding=False` lets us keep variable-length windows; padding happens later in the collator so memory is only spent on the longest items in a batch
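  - `truncation=True` is required for windowing. The tokenizer cuts each text at `max_length`, and together with `return_overflowing_tokens=True` it emits the remainder as additional windows instead of discarding it
  - `stride` is, in Hugging Face tokenizers, the number of tokens shared by two consecutive windows. To get the 128-token (~25%) overlap described above, it must equal `DOC_STRIDE`. Passing `MAX_SEQ_LENGTH - DOC_STRIDE` (= 384) would make consecutive windows share 384 tokens instead. Each window would then advance by only ~126 content tokens rather than ~382, producing roughly three times as many windows for long texts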
  - `return_overflowing_tokens=True` emits extra rows for each additional window instead of truncating
  - `return_attention_mask=True` keeps attention masks for later batching; special tokens mask is off because we don't need it
- **Mapping and weights**:
  - `overflow_to_sample_mapping` tells us which original example produced each window. We use it to copy labels, `example_id`, and `source_id` to every window
  - `counts = Counter(mapping)` tracks how many windows each message spawned; we set `sample_weight = 1 / counts[idx]` so each message contributes the same total loss despite having multiple windows
  - `source_id` uses the earlier `source_to_id` lookup so we can track per-source metrics later without string columns in the trainer
- **Flow**: The function returns a flat dict of tokenized windows ready for Hugging Face `Dataset.map`, ensuring the trainer operates on consistent lists (`input_ids`, `attention_mask`, `labels`, `example_id`, `sample_weight`, `source_id`)



In [None]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)


def tokenize_with_windows(examples):
    """
    Tokenize a batch of transcripts using sliding windows so texts longer than
    `MAX_SEQ_LENGTH` are split into overlapping segments of at most
    `MAX_SEQ_LENGTH` tokens.

    Parameters
    ----------
    examples : dict
        Mini-batch pulled from the Hugging Face `Dataset` containing lists for
        `text`, `label`, `example_id`, and `source`.

    Returns
    -------
    dict
        Flat tokenized batch where every window has:
        - `input_ids` / `attention_mask`
        - `labels` copied from the parent example
        - `example_id` so we can regroup windows later
        - `sample_weight = 1 / num_windows_for_example` so each original message
          contributes the same total loss
        - `source_id` (via the module-level `source_to_id` mapping) for
          per-source metrics

    Notes
    -----
    - `truncation=True` combined with `return_overflowing_tokens=True` makes the
      tokenizer emit every additional window instead of dropping the tail.
    - In Hugging Face tokenizers `stride` is the number of tokens shared by two
      consecutive windows, so it is set to `DOC_STRIDE` directly. Passing
      `MAX_SEQ_LENGTH - DOC_STRIDE` would make windows overlap by
      384 tokens and multiply the window count for long texts roughly threefold.
    - `overflow_to_sample_mapping` reports which original example produced each
      window (e.g., [0, 0, 0, 1, 1, ...]). We pop it from the tokenizer output,
      count it with a `Counter`, and use it to (a) copy metadata to each window
      and (b) compute the per-window `sample_weight`.
    """
    tokenized = tokenizer(
        examples['text'],
        max_length=MAX_SEQ_LENGTH,
        padding=False,
        truncation=True,
        # Overlap (in tokens) between consecutive windows of the same text.
        stride=DOC_STRIDE,
        return_overflowing_tokens=True,
        return_attention_mask=True,
        return_special_tokens_mask=False,
    )

    mapping = tokenized.pop('overflow_to_sample_mapping')
    counts = Counter(mapping)

    tokenized['labels'] = [examples['label'][idx] for idx in mapping]
    tokenized['example_id'] = [examples['example_id'][idx] for idx in mapping]
    # Weights of one message's windows sum to 1, so long texts are not over-counted.
    tokenized['sample_weight'] = [1.0 / counts[idx] for idx in mapping]
    tokenized['source_id'] = [source_to_id[examples['source'][idx]] for idx in mapping]
    return tokenized

### Tokenized dataset cleanup
- **Why drop columns**: once the text is tokenized, we no longer need the raw `text`, the pandas-side label columns, or the other helper columns in the Hugging Face dataset. The trainer only needs the fields returned by `tokenize_with_windows`, so `remove_columns` keeps memory low and batches lean.
- **`Dataset.map` usage**: applying `tokenize_with_windows` to each split via `.map(..., batched=True, desc='Tokenizing with sliding windows')` ensures the exact same logic (sliding windows, metadata copy, sample weights) runs across train/val/test.
- **Real token length**: immediately after, a second `.map` calculates `length = len(input_ids)` for every window so later steps (length-aware batching, stats) use the true tokenizer length instead of the word-based approximation computed earlier.


In [None]:
# All splits share the pandas-derived columns; drop them so only the fields
# produced by tokenize_with_windows remain.
columns_to_remove = datasets_dict['train'].column_names

# Tokenize into windows, then record each window's true token length
# (consumed by group_by_length via length_column_name='length').
tokenized_datasets = datasets_dict.map(
    tokenize_with_windows,
    batched=True,
    remove_columns=columns_to_remove,
    desc='Tokenizing with sliding windows',
).map(
    lambda batch: {'length': list(map(len, batch['input_ids']))},
    batched=True,
    desc='Computing segment lengths',
)
tokenized_datasets

Display per-split sliding-window statistics

In [None]:
# Per split: total windows, and how many windows each original message
# (example_id) was cut into. Class counts here are per window, not per message.
for split, ds in tokenized_datasets.items():
    print(f"Split: {split}")
    print(f"  Total windows: {len(ds):,}")
    if len(ds) == 0:
        # Guard: the average below would divide by zero and max() would fail.
        print("  No windows in this split")
        print()
        continue
    window_counts = Counter(ds['example_id'])
    avg_windows = len(ds) / len(window_counts)
    max_windows = max(window_counts.values())
    print(f"  Avg windows per example: {avg_windows:.2f}")
    print(f"  Max windows per example: {max_windows}")
    print(f"  Class distribution: {Counter(ds['labels'])}")
    print()

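At the window level, the class balance shifts noticeably toward formal (presumably because long formal texts are split into more windows), but the shift is not critical.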

### Reproducibility and casting
- `set_seed(RANDOM_SEED)` seeds the Python, NumPy, and PyTorch random number generators, so every run with the same config gets the same batch order and weight initializations. The train/val/test split was already fixed earlier with a NumPy seed, so this call does not affect it
- Enforces column types to avoid future problems

In [None]:
set_seed(RANDOM_SEED)

# Explicit dtype per tokenized column, applied to every split in this order.
column_dtypes = {
    'sample_weight': 'float32',
    'labels': 'int64',
    'example_id': 'int64',
    'source_id': 'int64',
    'length': 'int64',
}
for split in tokenized_datasets.keys():
    for column, dtype in column_dtypes.items():
        tokenized_datasets[split] = tokenized_datasets[split].cast_column(column, Value(dtype))

## Training

### Label dictionaries
- Providing both `id2label` and `label2id` ensures the model's config names logits consistently for metrics and exporting, and logging `id2label` gives readable class names in reports.

In [None]:
# Class index -> name (0 = informal, 1 = formal) and its inverse, stored in the model config.
id2label = dict(enumerate(('informal', 'formal')))
label2id = {name: idx for idx, name in id2label.items()}

model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    id2label=id2label,
    label2id=label2id,
    num_labels=len(id2label),
)

### Freezing the encoder and clear classifier
- Looping over `model.bert.parameters()` and setting `requires_grad=False` keeps the RuBERT encoder fixed; because of limited compute resources, only the classification head is updated.
- `model.classifier.reset_parameters()` wipes the classification head to a fresh random state so it doesn't inherit any biases from previous fine-tunes, giving the frozen encoder a clean readout layer.

In [None]:
# Freeze the pretrained backbone: requires_grad_(False) clears the flag on
# every parameter of model.bert in one call.
model.bert.requires_grad_(False)

# Re-draw the classification head's weights with the layer's own initializer.
model.classifier.reset_parameters()

### WeightedDataCollator logic
- Collator role:  take the list of Python dictionaries that the dataset returns—each dict containing fields like input_ids, attention_mask, labels, + extra metadata we added and stack/pad them into PyTorch tensors so the model can consume the whole batch in one forward pass
---
- **Weighted's goal**: wrap Hugging Face's `DataCollatorWithPadding` so each batch keeps the per-window metadata we computed earlier (`sample_weight`, `example_id`, `source_id`). The base collator only handles padding `input_ids`/`attention_mask`; this wrapper pads first, then reattaches our custom tensors
- **Line by line**:
  - `sample_weight = torch.tensor(...)`: convert the Python floats to a tensor so they can participate in the loss computation. We keep them in the batch dict under `sample_weight` for the custom Trainer
  - `example_id` / `source_id`: same idea—store them as tensors so we can track which original message (and which source dataset) each window belongs to. Useful for aggregation and diagnostics
  - `for f in features: ... pop(...)`: remove the extra keys from each feature dict before calling the base collator. The base collator only expects standard tokenizer outputs; leaving our custom keys would make it crash.
  - `batch = self.base_collator(features)`: pad `input_ids`/`attention_mask` to the longest sequence in the batch using the tokenizer's pad token, returning PyTorch tensors
  - `batch['sample_weight'] = ...`: reattach the metadata tensors so the Trainer can consume them later
- **Why this design**: keeping metadata out of the base collator avoids reimplementing padding logic, while still giving us a batch object that contains everything required for weighted loss and per-example bookkeeping


In [None]:
class WeightedDataCollator:
    """Pad a batch with a base collator while carrying per-window metadata.

    The base collator only receives the standard fields. ``sample_weight``,
    ``example_id`` and ``source_id`` are turned into tensors and re-attached to
    the padded batch; the ``length`` column (if present) is discarded.

    Parameters
    ----------
    base_collator : callable
        Takes a list of feature dicts and returns a batch dict (for example a
        ``DataCollatorWithPadding``).
    """

    # Keys the base collator must not see; 'length' only exists for group_by_length.
    _METADATA_KEYS = ('sample_weight', 'example_id', 'source_id', 'length')

    def __init__(self, base_collator):
        self.base_collator = base_collator

    def __call__(self, features):
        sample_weight = torch.tensor([f['sample_weight'] for f in features], dtype=torch.float32)
        example_id = torch.tensor([f['example_id'] for f in features], dtype=torch.int64)
        source_id = torch.tensor([f['source_id'] for f in features], dtype=torch.int64)

        # Build stripped copies rather than popping keys, so the caller's
        # feature dicts are left untouched and can safely be collated again.
        stripped = [
            {key: value for key, value in f.items() if key not in self._METADATA_KEYS}
            for f in features
        ]

        batch = self.base_collator(stripped)
        batch['sample_weight'] = sample_weight
        batch['example_id'] = example_id
        batch['source_id'] = source_id
        return batch

# Pad each batch to its longest window; the wrapper re-attaches the metadata tensors.
base_collator = DataCollatorWithPadding(
    tokenizer=tokenizer,
    padding='longest',
    return_tensors='pt',
)
data_collator = WeightedDataCollator(base_collator=base_collator)

### TrainingArguments
- **`output_dir='training/artifacts/rubert_formal_informal'`** – controls where checkpoints/metrics land to be able to get them from there
- **`eval_strategy='epoch'`** – tells Trainer to run validation after each epoch. Because our dataset is mid-sized and we care about message-level F1, per-epoch eval is frequent enough to catch regressions without slowing training too much
- **`save_strategy='epoch'`** – mirrors the eval cadence so every validation pass has a matching checkpoint. This is required for early stopping / best-model logic to align with eval results
- **`logging_strategy='epoch'`** – aggregate logs once per epoch to avoid noisy per-step logs
- **`per_device_train_batch_size=16` / `per_device_eval_batch_size=16`** – batch size per accelerator. On MPS this fits comfortably once the encoder is frozen; it keeps gradient noise low while leaving headroom for sliding windows
- **`gradient_accumulation_steps=2`** – effectively doubles the batch to 32 windows without needing more memory. We accumulate two mini-batches before each optimizer step, which gives the trainable classification head a less noisy gradient estimate (the frozen encoder receives no gradient at all)
- **`learning_rate=2e-5`** – standard fine-tuning LR for BERT-style encoders. Even though we train only the head, sticking to 2e-5 keeps optimizer behavior familiar and avoids overshooting
- **`num_train_epochs=3`** – typical sweet spot for RuBERT; gives the head enough passes to converge while keeping experiments short (~<6h target). We rely on early stopping if validation stops improving sooner
- **`weight_decay=0.01`** – light (decoupled, AdamW-style) weight decay to discourage the classifier head from overfitting the training windows. 0.01 is a common choice for BERT fine-tuning; note that the `TrainingArguments` default is 0.0
- **`warmup_ratio=0.06`** – gradually ramps LR over the first 6% of training steps to avoid sudden jumps. Helpful when gradients start from scratch due to the reset head
- **`logging_steps=1`** – only used when `logging_strategy='steps'`; with `logging_strategy='epoch'` it has no effect
- **`save_total_limit=3`** – keep only the latest/best three checkpoints so disk usage stays low while still letting us roll back if needed
- **`load_best_model_at_end=True`** – after training, automatically reloads the checkpoint with the best validation metric; saves us from manually tracking which epoch won
- **`metric_for_best_model='f1_macro'`** / **`greater_is_better=True`** – defines the ranking metric so the best checkpoint is chosen by macro-F1 (balanced between classes), maximizing the score we actually report
- **`fp16=False`** – disable mixed precision; stick to float32 for predictable training
- **`group_by_length=True`** / **`length_column_name='length'`** – batches windows of similar token length to reduce padding waste; uses the true tokenizer length we computed earlier
- **`remove_unused_columns=False`** – keeps `sample_weight`, `example_id`, and `source_id` in the datasets so our custom trainer can access them
- **`seed=RANDOM_SEED`** – duplicates the global randomness lock inside `TrainingArguments` so Trainer-controlled RNG streams stay deterministic
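- **`use_mps_device=True`** / **`dataloader_pin_memory=False`** – `use_mps_device` asks Trainer to use Apple’s MPS backend (faster than CPU). Recent Transformers releases deprecate this flag and select MPS automatically when it is available. Pinned memory is disabled because PyTorch warns that it isn’t supported on MPS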


In [None]:
training_args = TrainingArguments(
    output_dir='training/artifacts/rubert_formal_informal',
    # Evaluate and checkpoint once per epoch so load_best_model_at_end can pair them.
    eval_strategy='epoch',
    save_strategy='epoch',
    logging_strategy='epoch',
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=2,
    learning_rate=2e-5,
    num_train_epochs=3,
    weight_decay=0.01,
    warmup_ratio=0.06,
    logging_steps=1,  # ignored while logging_strategy='epoch'
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model='f1_macro',
    greater_is_better=True,
    fp16=False,
    # Batch windows of similar true token length to reduce padding.
    group_by_length=True,
    length_column_name='length',
    # Keep sample_weight / example_id / source_id for the custom collator and loss.
    remove_unused_columns=False,
    seed=RANDOM_SEED,
    # `use_mps_device` is intentionally omitted: the flag is deprecated (slated
    # for removal in Transformers 5.0), and Trainer picks MPS automatically when
    # it is available. Pinned memory is not supported on MPS, so keep it off.
    dataloader_pin_memory=False,
)

### Trainer & metrics
- **Why a custom trainer**: sliding windows mean one original message can produce 3–5 training rows. Without weighting, those rows would contribute 3–5× the gradient of a short message. Subclassing `Trainer` and overriding `compute_loss` lets the model consume the `sample_weight` tensor so each original message still counts as 1
- **Metadata handling**: `inputs` initially contains `input_ids`, `attention_mask`, `labels`, plus metadata (`sample_weight`, `example_id`, `source_id`, `length`). The overridden `compute_loss` pops the metadata keys before calling the base model (it would error on unknown kwargs), but holds onto them for loss weighting or later analysis
- **Loss computation (example)**: suppose `example_id 7` was split into 3 windows, so each window’s `sample_weight` is `1/3 ≈ 0.33`. Cross-entropy is computed per window (`reduction='none'`), then multiplied by 0.33 before averaging the batch. The sum of the three weighted losses equals what a single window would have contributed, keeping the gradient per message constant
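- **`compute_metrics`**: converts logits to predictions and reports accuracy plus macro/weighted F1. Macro-F1 is not optimized directly; it is used to select the best checkpoint, because it rewards balanced performance on formal and informal texts even when the dataset is skewed. Note that these metrics are computed per window, not per original message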
- **Trainer wiring**: `WeightedLossTrainer(...)` ties together the model, arguments, datasets, tokenizer, collator, and metrics so the weighted loss/metadata flow happens automatically each step. An early-stopping callback monitors validation macro-F1 and halts if it stalls for two epochs, keeping training efficient

In [None]:
class WeightedLossTrainer(Trainer):
    """Trainer whose loss re-weights sliding windows by a per-row ``sample_weight``.

    Each training row is expected to carry a ``sample_weight`` column (intended to be
    ``1 / n_windows`` of its parent message), so a message split into several windows
    does not dominate the gradient. The metadata columns ``example_id``, ``source_id``
    and ``length`` are stripped before the forward pass.
    """

    # Columns produced by the windowing step that must never reach model.forward.
    _METADATA_KEYS = ('example_id', 'source_id', 'length')

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
        """Return the sample-weighted cross-entropy loss (and the outputs if requested).

        Args:
            model: the (possibly wrapped) sequence-classification model.
            inputs: collated batch dict. It is mutated in place: ``sample_weight``,
                ``labels`` and the metadata keys are popped.
            return_outputs: when True, return ``(loss, outputs)`` instead of ``loss``.
            num_items_in_batch: accepted for Trainer API compatibility; not used.

        Returns:
            A scalar loss tensor, or ``(loss, outputs)`` when ``return_outputs`` is True.

        Raises:
            KeyError: if ``inputs`` has no ``labels`` entry.
        """
        # Optional: a split without per-window weights falls back to a plain mean.
        sample_weight = inputs.pop('sample_weight', None)
        for key in self._METADATA_KEYS:
            inputs.pop(key, None)
        labels = inputs.pop('labels')

        outputs = model(**inputs)
        logits = outputs.logits

        # Take the class count from the logits rather than `model.config`: a model
        # wrapped in DataParallel/DDP by the Trainer does not expose `.config`.
        losses = torch.nn.functional.cross_entropy(
            logits.view(-1, logits.size(-1)),
            labels.view(-1),
            reduction='none',
        )
        if sample_weight is not None:
            # Match device and dtype so the product stays in the loss precision.
            losses = losses * sample_weight.to(device=losses.device, dtype=losses.dtype)
        # Mean over the windows in the batch; with weights 1/n_windows, every message
        # contributes the same total weight regardless of how many windows it has.
        weighted_loss = losses.mean()
        # NOTE(review): num_items_in_batch is ignored. If gradient_accumulation_steps > 1,
        # confirm how the installed transformers version normalizes a custom compute_loss.

        if return_outputs:
            return weighted_loss, outputs
        return weighted_loss


def compute_metrics(eval_pred):
    """Compute window-level accuracy and macro/weighted F1 for the Trainer.

    Args:
        eval_pred: an ``EvalPrediction`` or a ``(predictions, label_ids)`` tuple.
            It is indexed rather than unpacked into two names, because an
            ``EvalPrediction`` also yields ``inputs``/``losses`` when the Trainer
            is configured to collect them.

    Returns:
        dict with ``accuracy``, ``f1_macro`` (used as ``metric_for_best_model``)
        and ``f1_weighted``.
    """
    logits, labels = eval_pred[0], eval_pred[1]
    # Models that return extra outputs hand over a tuple; the class logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = np.argmax(logits, axis=-1)
    return {
        'accuracy': accuracy_score(labels, preds),
        'f1_macro': f1_score(labels, preds, average='macro'),
        'f1_weighted': f1_score(labels, preds, average='weighted'),
    }

- **`WeightedLossTrainer(...)` arguments**:
  - `model`: the frozen RuBERT encoder plus reinitialized classifier head
  - `args`: the `TrainingArguments` block that governs batching, logging, saving; passing it in ensures the Trainer obeys those settings
  - `train_dataset` / `eval_dataset`: the tokenized Hugging Face splits, allowing the Trainer to iterate and evaluate over it
  - `processing_class=tokenizer`: bundles the tokenizer with checkpoints so reloading the model later recreates the same subword vocab
  - `data_collator`: supplies the padded batches with `sample_weight`/`example_id` tensors that the custom loss expects.
  - `compute_metrics`: hooks up the accuracy/F1 reporter for evaluation
- **Early stopping**: `trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=2))` stops training if validation macro-F1 fails to improve for two evaluation cycles, keeping runtime bounded and preventing the head from overfitting


In [None]:
# Wire the weighted-loss trainer: model and arguments, the tokenizer passed as the
# processing class, the padding collator, the window-level train/val splits, and
# the accuracy/F1 reporter.
trainer = WeightedLossTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    data_collator=data_collator,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['val'],
    compute_metrics=compute_metrics,
)
# Stop once the selection metric (f1_macro, per training_args) has not improved
# for two consecutive evaluations.
trainer.add_callback(
    EarlyStoppingCallback(early_stopping_patience=2)
)

In [None]:
# test_metrics = trainer.evaluate(tokenized_datasets['test'])
# trainer.log_metrics('test', test_metrics)
# trainer.save_metrics('test', test_metrics)
# test_metrics

# val_preds = trainer.predict(tokenized_datasets['val'])
# test_preds = trainer.predict(tokenized_datasets['test'])

### Training loop
- `train_result = trainer.train()` launches fine-tuning with all previously defined wiring (sliding windows, weighted loss, early stopping)
- Saving the model, logging metrics, and persisting the Trainer state immediately afterwards captures everything needed to resume or review the run later


In [None]:
# Fine-tune (weighted loss + early stopping), then persist the model, the training
# metrics and the Trainer state so the run can be reviewed or resumed later.
train_result = trainer.train()
train_metrics = train_result.metrics

trainer.save_model()
trainer.log_metrics('train', train_metrics)
trainer.save_metrics('train', train_metrics)
trainer.save_state()

train_metrics  # last expression of the cell -> shown as the cell output

### Validation & test evaluation
- `trainer.evaluate(tokenized_datasets['val'])` runs the validation split using the best checkpoint (because `load_best_model_at_end=True`), logging and saving the raw validation metrics over windows
- After that, `trainer.predict(tokenized_datasets['val'])` runs the validation split again and returns the raw window-level logits and labels. We need them to regroup sliding windows into message-level predictions
- The following code repeats the process on the test split

In [None]:
def _evaluate_and_predict(split, metric_prefix):
    """Log/save window-level metrics for one split and return them with raw predictions.

    The metrics come from ``trainer.evaluate``. After training, that is the best
    checkpoint, since ``load_best_model_at_end=True``. The ``trainer.predict`` output
    keeps the per-window logits/labels needed for message-level regrouping.
    """
    split_dataset = tokenized_datasets[split]
    split_metrics = trainer.evaluate(split_dataset)
    trainer.log_metrics(metric_prefix, split_metrics)
    trainer.save_metrics(metric_prefix, split_metrics)
    return split_metrics, trainer.predict(split_dataset)


val_metrics, val_preds = _evaluate_and_predict('val', 'eval')
test_metrics, test_preds = _evaluate_and_predict('test', 'test')

### Message-level aggregation & reporting
- **Goal**: convert per-window predictions back into per-message metrics so evaluation reflects the original transcripts, not individual sliding windows
- **`aggregate_message_logits`**
  - Inputs: raw window logits (`logits`), true labels (`labels`), and `example_ids` identifying the parent message for each window
  - It buckets windows by `example_id`, stacks their logits, and pools them (`mean` by default, `max` optional) to obtain **one logit vector per message**. The function returns three arrays: pooled logits, the true label per message, and the message ids
  - Example: if a message produced three windows, its bucket contains three logit vectors. `np.vstack` stacks them into a 3-row array and `mean` averages across the rows; the resulting pooled logit vector decides the single prediction for this message
- **`report_message_metrics`**
  - Calls `aggregate_message_logits`, converts logits to predictions via `argmax`, and prints a `classification_report` (class names taken from `id2label`) plus a confusion matrix. This gives accuracy/F1 **per original message**
  - Returns a dictionary containing the per-message logits, labels, predicted labels, example ids, and confusion matrix—handy for further analysis or saving to disk
- **Usage**: runs `report_message_metrics` on both the validation and test splits, using the outputs of `trainer.predict`. The window-level metrics from `trainer.evaluate` are quick checks. These aggregated numbers are the authoritative scores for reporting; checkpoint selection during training still uses window-level macro-F1


In [None]:
def aggregate_message_logits(logits, labels, example_ids, reduction='mean'):
    """Pool window-level logits into one logit vector per original message.

    Args:
        logits: per-window logits, one row per window.
        labels: per-window gold labels; all windows of a message must share one label.
        example_ids: per-window id of the parent message.
        reduction: ``'mean'`` (default) or ``'max'`` pooling across a message's windows.

    Returns:
        ``(pooled_logits, message_labels, message_ids)``, ordered by the first
        appearance of each id. ``pooled_logits`` has one row per message. For empty
        input, all three arrays are empty (the logits keep their trailing class axis).

    Raises:
        ValueError: if ``reduction`` is not ``'mean'``/``'max'``, or if windows of the
            same message carry different labels.
    """
    poolers = {'mean': np.mean, 'max': np.max}
    if reduction not in poolers:
        raise ValueError(f"reduction must be 'mean' or 'max', got {reduction!r}")
    pool = poolers[reduction]

    logits = np.asarray(logits)
    buckets = defaultdict(list)
    label_map = {}
    for logit, label, eid in zip(logits, labels, example_ids):
        eid, label = int(eid), int(label)
        # Windows of one message inherit the message's label; a mismatch means ids
        # and rows are misaligned, and silently keeping the last label would hide it.
        if label_map.setdefault(eid, label) != label:
            raise ValueError(
                f'example_id {eid} has windows with conflicting labels '
                f'({label_map[eid]} vs {label})'
            )
        buckets[eid].append(logit)

    if not buckets:
        # np.vstack([]) raises, so return correctly shaped empty arrays instead.
        return (
            np.empty((0,) + logits.shape[1:], dtype=logits.dtype),
            np.empty(0, dtype=int),
            np.empty(0, dtype=int),
        )

    agg_ids = list(buckets)  # dict preserves first-appearance order
    agg_logits = np.vstack([pool(np.vstack(buckets[eid]), axis=0) for eid in agg_ids])
    agg_labels = np.array([label_map[eid] for eid in agg_ids])
    return agg_logits, agg_labels, np.array(agg_ids)


def report_message_metrics(split_name, dataset, preds_output):
    """Print and return message-level classification metrics for one split.

    Args:
        split_name: name used in the printed header (e.g. ``'val'``, ``'test'``).
        dataset: tokenized split whose ``'example_id'`` column aligns row-for-row
            with ``preds_output``.
        preds_output: the ``trainer.predict`` result for ``dataset``
            (its ``predictions`` and ``label_ids`` are used).

    Returns:
        dict with per-message ``'logits'``, ``'labels'``, ``'pred_labels'``,
        ``'example_ids'`` and the 2x2 ``'confusion_matrix'``.
    """
    msg_logits, msg_labels, msg_ids = aggregate_message_logits(
        preds_output.predictions,
        preds_output.label_ids,
        dataset['example_id'],
    )
    msg_pred_labels = msg_logits.argmax(axis=-1)

    # Pin the label set. Without it, a split where one class never appears makes
    # classification_report raise (two target_names, one observed class) and shrinks
    # the confusion matrix to 1x1.
    label_ids = [0, 1]
    target_names = [id2label[i] for i in label_ids]

    print(f"Message-level metrics for {split_name} split")
    print(classification_report(
        msg_labels, msg_pred_labels, labels=label_ids, target_names=target_names,
    ))
    cm = confusion_matrix(msg_labels, msg_pred_labels, labels=label_ids)
    # sklearn convention: rows are true labels, columns are predicted labels.
    print(pd.DataFrame(
        cm,
        index=[f'true: {name}' for name in target_names],
        columns=[f'pred: {name}' for name in target_names],
    ))
    return {
        'logits': msg_logits,
        'labels': msg_labels,
        'pred_labels': msg_pred_labels,
        'example_ids': msg_ids,
        'confusion_matrix': cm,
    }
    

# Regroup window predictions into per-message metrics for both held-out splits.
val_message_stats = report_message_metrics(
    split_name='val',
    dataset=tokenized_datasets['val'],
    preds_output=val_preds,
)
test_message_stats = report_message_metrics(
    split_name='test',
    dataset=tokenized_datasets['test'],
    preds_output=test_preds,
)

### Summarizing test metrics
- Uses `classification_report` on the message-level predictions to produce precision/recall/F1 per class, then converts it to a tidy DataFrame for inspection or logging. This is the aggregate score to record for the test split


In [None]:
# Per-class precision/recall/F1 on the message-level test predictions, as a tidy table.
# labels=[0, 1] keeps both classes in the report even if one is absent from the split;
# without it sklearn raises, because target_names lists two classes.
test_report = classification_report(
    test_message_stats['labels'],
    test_message_stats['pred_labels'],
    labels=[0, 1],
    target_names=[id2label[0], id2label[1]],
    output_dict=True,
)
pd.DataFrame(test_report).T

### Inspecting per-message predictions
- Builds a DataFrame from `test_message_stats` so each row corresponds to an original transcript with its true label, predicted label, and aggregated class probabilities
- Displaying `.head()` provides a quick sanity check on the aggregated outputs

In [None]:
# One row per original test message: gold/predicted label names plus the softmax of
# the pooled message logits (computed once, not once per probability column).
# NOTE(review): the column names assume index 0 = informal and 1 = formal in id2label; confirm.
test_probs = torch.softmax(torch.tensor(test_message_stats['logits']), dim=-1).numpy()
results_df = pd.DataFrame({
    'example_id': test_message_stats['example_ids'],
    'true_label': [id2label[i] for i in test_message_stats['labels']],
    'pred_label': [id2label[i] for i in test_message_stats['pred_labels']],
    'P(informal)': test_probs[:, 0],
    'P(formal)': test_probs[:, 1],
})
results_df.head()