# Week 8: Applications of BERT

In this notebook we explore practical ways to use BERT-style encoders:

1. **Masked Language Modeling (MLM):** We will demonstrate two ways to use this option:
    - **Single-word masked prediction:** Predict the most likely token for `[MASK]` using context on both sides.
    - **Multiple masks:** See why independent predictions aren’t coordinated, then try a left-to-right “coordinated fill” to produce coherent combinations.
2. **Sentiment classification:** Fine-tune a compact BERT (DistilBERT) on IMDB movie reviews and visualize training/validation curves.
3. **Relationship of sentences:** Score whether sentence **B** plausibly follows sentence **A** using BERT’s Next Sentence Prediction (NSP) head.


In [1]:
# Imports

# --- Standard Library ---
import os
import random
import time
import math
from pathlib import Path
from collections import Counter

# --- Third-Party Libraries ---
import numpy as np
import matplotlib.pyplot as plt
import sklearn
from sklearn.model_selection import train_test_split

# --- TensorFlow / Keras ---
import tensorflow as tf
from tensorflow.keras import Input, Sequential, layers, models
from tensorflow.keras.datasets import imdb
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import CosineDecay, ExponentialDecay
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau


# Layers
from tensorflow.keras.layers import (
    Embedding, GlobalAveragePooling1D, Dense, 
    LSTM, GRU, Dropout, SpatialDropout1D, Bidirectional, Lambda
)

# Preprocessing
from tensorflow.keras.preprocessing import sequence
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.text import Tokenizer

# --- Reproducibility Settings ---
random_seed = 42

'''
# OS-level determinism
os.environ['PYTHONHASHSEED'] = '0'        # Disable hash randomization
os.environ['TF_DETERMINISTIC_OPS'] = '1'  # Make TF ops deterministic (where possible)
os.environ['TF_CUDNN_DETERMINISM'] = '1'  # CuDNN deterministic (if using GPU)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Suppress TensorFlow INFO and WARNING messages

# Set seeds
random.seed(random_seed)
np.random.seed(random_seed)
tf.random.set_seed(random_seed)
'''

# --- Utility Function ---
def format_hms(seconds):
    """Format an elapsed duration in seconds as an ``HH:MM:SS`` string.

    Unlike a clock-time formatter, the hour field is not wrapped at 24, so a
    25-hour run is shown as ``25:00:00`` rather than ``01:00:00``.

    Args:
        seconds: Elapsed time in seconds (int or float). Fractions of a second
            are dropped; negative values are treated as zero.

    Returns:
        The duration as a zero-padded ``HH:MM:SS`` string.
    """
    whole_seconds = max(0, int(seconds))
    hours, remainder = divmod(whole_seconds, 3600)
    minutes, secs = divmod(remainder, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"

'''
# Example usage for timing code:

start_time = time.time()

# <your code here>

print("Execution Time:", format_hms(time.time() - start_time))
'''
None

### Logging and Displaying Results

In [None]:
# Registry of experiment outcomes, keyed by plot title.
# Each value is a (validation accuracy, epoch) pair.
results = {}

def print_metrics():
    """Print every recorded experiment, highest validation accuracy first.

    Reads the module-level `results` dict and writes one line per entry:
    the title left-aligned in 40 columns, a tab, the accuracy to four
    decimals, then the epoch.
    """
    def accuracy_of(entry):
        # entry is (title, (acc, epoch)); rank on acc only
        return entry[1][0]

    ranked = sorted(results.items(), key=accuracy_of, reverse=True)
    for name, (accuracy, epoch) in ranked:
        print(f"{name:<40}\t{accuracy:.4f} @ {epoch}")


In [3]:
def plot_learning_curves(hist, 
                         title="Learning Curves", 
                         best_epoch=None,
                         verbose=True,
                         lr_logger=None       # optional callback exposing a `.lrs` list
                        ):
    """Plot loss/accuracy curves from a Keras History and record the best result.

    Args:
        hist: Object with a `.history` dict containing the keys 'loss',
            'val_loss', 'accuracy' and 'val_accuracy'.
        title: Prefix for the subplot titles; also the key under which the
            result is stored in the module-level `results` dict.
        best_epoch: 0-based epoch index to highlight. When None, the epoch
            with the smallest validation loss is used.
        verbose: When True, print a text summary after the figure.
        lr_logger: Optional object with a `.lrs` sequence of learning rates.
            When given, a third subplot shows the learning-rate curve. The
            sequence may hold one value per epoch or any other number of
            samples (e.g. one per batch); samples are spread evenly over the
            epoch axis so the last one lands on the final epoch.

    Side effects:
        Shows a matplotlib figure and sets
        `results[title] = (val_accuracy_at_best_epoch, best_epoch + 1)`.
    """
    val_losses = hist.history['val_loss']
    
    # Determine best epoch (min val_loss or provided)
    if best_epoch is not None:
        min_val_epoch = best_epoch
    else:
        min_val_epoch = val_losses.index(min(val_losses))
        
    val_loss_at_min_loss = hist.history['val_loss'][min_val_epoch]    
    val_acc_at_min_loss = hist.history['val_accuracy'][min_val_epoch]

    n_epochs = len(val_losses)
    epochs = list(range(1, n_epochs + 1))

    # Tick interval for x-axis (at most ~20 ticks)
    tick_interval = max(1, n_epochs // 20)
    xticks = list(range(0, n_epochs + 1, tick_interval))

    # Figure layout: add 3rd subplot if LR is provided
    n_subplots = 3 if lr_logger is not None else 2
    plt.figure(figsize=(8, 10 if n_subplots == 3 else 8))

    # --- Loss Plot ---
    plt.subplot(n_subplots, 1, 1)
    plt.plot(epochs, hist.history['loss'], label='Train Loss')
    plt.plot(epochs, hist.history['val_loss'], label='Val Loss')
    plt.scatter(min_val_epoch + 1, val_loss_at_min_loss, color='red', marker='x', s=50, label='Min Val Loss') 
    plt.title(f'{title} - Loss')
    plt.ylabel('Loss')
    plt.xticks(xticks)
    plt.legend()
    plt.grid(True)

    # --- Accuracy Plot ---
    plt.subplot(n_subplots, 1, 2)
    plt.plot(epochs, hist.history['accuracy'], label='Train Accuracy')
    plt.plot(epochs, hist.history['val_accuracy'], label='Val Accuracy')
    plt.scatter(min_val_epoch + 1, val_acc_at_min_loss, color='red', marker='x', s=50, label='Acc @ Min Val Loss')
    plt.title(f'{title} - Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.xticks(xticks)
    plt.legend()
    plt.grid(True)

    # --- LR Plot (if provided) ---
    if lr_logger is not None:
        lrs = list(lr_logger.lrs)
        # The logger may hold a different number of samples than there are
        # epochs (e.g. one per batch). Build a matching x-axis instead of
        # reusing `epochs`, which would raise a length-mismatch error.
        # With exactly one sample per epoch this yields 1, 2, ..., n_epochs.
        lr_x = [(i + 1) * n_epochs / len(lrs) for i in range(len(lrs))]
        plt.subplot(n_subplots, 1, 3)
        plt.plot(lr_x, lrs, color='gray', label='Learning Rate')
        plt.title(f'{title} - Learning Rate Schedule')
        plt.xlabel('Epoch')
        plt.ylabel('Learning Rate')
        plt.xticks(xticks)
        plt.legend()
        plt.grid(True)

    plt.tight_layout()
    plt.show()

    # Print summary
    if verbose:
        print(f"Final Training Loss:            {hist.history['loss'][-1]:.4f}")
        print(f"Final Training Accuracy:        {hist.history['accuracy'][-1]:.4f}")
        print(f"Final Validation Loss:          {hist.history['val_loss'][-1]:.4f}")
        print(f"Final Validation Accuracy:      {hist.history['val_accuracy'][-1]:.4f}")
        print(f"Minimum Validation Loss:        {val_loss_at_min_loss:.4f} (Epoch {min_val_epoch + 1})")
        print(f"Validation Accuracy @ Min Loss: {val_acc_at_min_loss:.4f}")

    # Store result in global results dict (epoch reported 1-based)
    results[title] = (val_acc_at_min_loss, min_val_epoch + 1)


### LR Schedulers and Callbacks

In [4]:
# Reduce the learning rate by half if val_loss does not improve for 5 epochs,
# but never go below 1e-8 (the `min_lr` value set below).
# NOTE: this callback is only created here; pass it via `callbacks=[reduce_lr]`
# to the training call for it to take effect.
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',    # Quantity to be monitored.
    factor=0.5,            # Factor by which the learning rate will be reduced.
                           # new_lr = lr * factor
    patience=5,            # Number of epochs with no improvement
                           # after which learning rate will be reduced.
    min_delta=1e-5,        # Threshold for measuring the new optimum,
                           # to only focus on significant changes.
    cooldown=0,            # Number of epochs to wait before resuming
                           # normal operation after lr has been reduced.
    min_lr=1e-8,           # Lower bound on the learning rate.
    verbose=0,             # 0: quiet, 1: update messages.
)

# Copy this into cell where X_train, epochs, and batch_size are defined
'''
exp_decay = ExponentialDecay(
    initial_learning_rate = 0.00001,  # ─ the starting learning rate (before any decay)
    decay_steps = epochs * int(np.ceil(len(X_train) / batch_size)),  # ─ how many training steps (batches) before decay rate % is reached
    decay_rate  = 0.9,                # ─ target % of lr after decay_steps steps (training of one batch)
    staircase   = False,              # ─ if True, decay in discrete intervals (floor(step/decay_steps)),
                                      #   if False, decay smoothly each step
)
'''

# Ditto, copy this
'''
cosine_decay = CosineDecay(
    initial_learning_rate=0.00001,   # ─ the starting learning rate
    decay_steps = epochs * int(np.ceil(len(X_train) / batch_size)),            # ─ number of training steps (batches) over which to decay
    alpha=0.0,                    # ─ minimum learning rate value as a fraction of initial_learning_rate
                                  #    (final_lr = initial_lr * alpha)
)
'''

# Used to display learning rate plot

class LRSchedulerLogger(tf.keras.callbacks.Callback):
    """Keras callback that appends the optimizer's learning rate to `self.lrs`
    at the end of every epoch.

    NOTE: a class with the same name is defined again in a later cell of this
    notebook; once that cell runs, it replaces this definition.
    """

    def __init__(self):
        super().__init__()
        self.lrs = []   # one entry per completed epoch

    def on_epoch_end(self, epoch, logs=None):
        rate = self.model.optimizer.learning_rate

        # Plain value/variable: read it directly.
        if not isinstance(rate, tf.keras.optimizers.schedules.LearningRateSchedule):
            self.lrs.append(tf.keras.backend.get_value(rate))
            return

        # Schedule object: evaluate it at the optimizer's current step count.
        step = tf.cast(self.model.optimizer.iterations, tf.float32)
        self.lrs.append(rate(step).numpy())



### Train and Test Wrapper

In [5]:
import time, numpy as np
import tensorflow as tf
from tensorflow.keras.optimizers import Adam    # tf.keras, not keras
from tensorflow.keras.callbacks  import EarlyStopping, Callback

# If you have a custom logger, make sure it subclasses tf.keras.callbacks.Callback
class LRSchedulerLogger(Callback):
    """Keras callback that records the learning rate once per epoch in `self.lrs`.

    Recording happens at epoch end so that `self.lrs` lines up with the
    per-epoch x-axis used by `plot_learning_curves`. The list is reset at the
    start of every `fit` call.

    Handles both a plain learning-rate value/variable and a
    `LearningRateSchedule` object (which is evaluated at the optimizer's
    current iteration count instead of being passed to `float`, which would
    raise `TypeError`).
    """
    _supports_tf_logs = True

    def __init__(self):
        super().__init__()
        self.lrs = []   # defined up front so the attribute exists before training

    def on_train_begin(self, logs=None):
        self.lrs = []

    def on_epoch_end(self, epoch, logs=None):
        self.lrs.append(self._current_lr())

    def _current_lr(self):
        """Return the optimizer's current learning rate as a float, or None if it cannot be read."""
        optimizer = self.model.optimizer
        lr = getattr(optimizer, "learning_rate", None)
        if lr is None:
            return None
        if isinstance(lr, tf.keras.optimizers.schedules.LearningRateSchedule):
            # A schedule is a callable of the step, not a number.
            lr = lr(tf.cast(optimizer.iterations, tf.float32))
        try:
            return float(tf.keras.backend.get_value(lr))
        except Exception:
            # Best effort, as in the original: fall back to .numpy() or give up.
            return float(lr.numpy()) if hasattr(lr, "numpy") else None

def train_and_test(model, X_train, y_train, X_test, y_test,
                   title         = "Learning Curves",
                   epochs        = 200,
                   lr_schedule   = 0.0001,
                   optimizer     = "Adam",
                   loss          = "binary_crossentropy",
                   batch_size    = 256,
                   use_early_stopping = True,
                   patience      = 10,
                   min_delta     = 1e-4,
                   log_learning_rate = True,
                   callbacks     = None,
                   verbose       = 0,
                   return_history = True
                  ):
    """Compile, fit, plot and evaluate `model` in one call.

    Args:
        model: Keras model; it is (re)compiled here with an "accuracy" metric.
        X_train, y_train: Training data; the last 20% is held out by
            `validation_split=0.2`.
        X_test, y_test: Data passed to `model.evaluate` after training.
        title: Label printed before training and forwarded to
            `plot_learning_curves`.
        epochs, batch_size: Forwarded to `model.fit`.
        lr_schedule: Learning rate (or schedule) given to `Adam`; only used
            when `optimizer` is the string "adam" (case-insensitive).
        optimizer: "adam", or any other value, which is handed to
            `model.compile` unchanged.
        loss: Loss passed to `model.compile`.
        use_early_stopping, patience, min_delta: Configure an `EarlyStopping`
            callback on `val_loss` with `restore_best_weights=True`.
        log_learning_rate: When True, attach an `LRSchedulerLogger` and pass
            it to the plotting helper.
        callbacks: Optional iterable of extra callbacks; it is copied, so the
            caller's list is not modified.
        verbose: Verbosity for `EarlyStopping`, `fit` and `evaluate`.
        return_history: When True, return the Keras History; otherwise None.
    """
    print(f"{title}:  ", end="")

    # Work on a private copy so repeated calls never share callback state
    active_callbacks = [] if callbacks is None else list(callbacks)

    # Only the string "adam" is built here; anything else goes to compile as-is
    wants_adam = isinstance(optimizer, str) and optimizer.lower() == "adam"
    opt = Adam(learning_rate=lr_schedule) if wants_adam else optimizer

    model.compile(optimizer=opt, loss=loss, metrics=["accuracy"])

    if use_early_stopping:
        active_callbacks.append(
            EarlyStopping(
                monitor="val_loss", patience=patience, min_delta=min_delta,
                restore_best_weights=True, verbose=verbose
            )
        )

    lr_logger = LRSchedulerLogger() if log_learning_rate else None
    if lr_logger is not None:
        active_callbacks.append(lr_logger)

    started_at = time.time()

    history = model.fit(
        X_train, y_train,
        epochs=epochs,
        batch_size=batch_size,
        validation_split=0.2,
        callbacks=active_callbacks,
        verbose=verbose
    )

    # Validation accuracy at the epoch with the lowest validation loss
    val_losses = history.history["val_loss"]
    best_epoch = int(np.argmin(val_losses))
    best_acc   = float(history.history["val_accuracy"][best_epoch])

    plot_learning_curves(history, title=title, lr_logger=lr_logger)

    test_loss, test_accuracy = model.evaluate(X_test, y_test, verbose=verbose)

    print(f"\nTest Accuracy: {test_accuracy:.4f}")
    print(f"Test Loss: {test_loss:.4f}")
    print(f"\nValidation-Test Gap (accuracy): {abs(best_acc - test_accuracy):.6f}")

    # Elapsed time covers fit, plotting and evaluation
    elapsed = time.time() - started_at
    print(f"\nExecution Time: {format_hms(elapsed)}")

    return history if return_history else None


In [6]:
# Colab/local once per runtime (safe to re-run)
# !pip -q install "tensorflow==2.17.0" "tf-keras==2.17.0" "transformers==4.44.2"

import os
os.environ["TRANSFORMERS_NO_TORCH"] = "1"   # keep backend on TF for consistency


In [7]:
from transformers import logging
logging.set_verbosity_error()

import warnings
warnings.filterwarnings("ignore", module="transformers")

## 1. Basic Masked Language Modeling with BERT

We now demonstrate how a BERT-style encoder predicts missing words marked as `[MASK]`.

**Notes:**
- BERT is **bidirectional** — it looks at the words **before and after** the mask and picks the most likely token in that position.
- **Why this matters:** MLM trains encoders to learn syntax (e.g., subject–verb agreement) and semantics (word sense from context), which makes them useful for classification, QA, etc. Note that it does **not** do any kind of reasoning; it simply calculates probabilities based on its training, much as your N-Gram language generation code did. 


### Filling in a Single Mask

In [8]:
from transformers import pipeline
fill = pipeline("fill-mask", model="bert-base-uncased", framework="tf")
fill.tokenizer.clean_up_tokenization_spaces = True  # nicer printing

def demo(s, k=3):
    """Print the top-`k` fill-mask candidates for sentence `s` as a two-column table.

    Uses the module-level `fill` pipeline; `s` is expected to contain one
    mask token, since each prediction is read as a dict with 'token_str'
    and 'score'.
    """
    candidates = fill(s, top_k=k)
    header = f"{'Word':>12} | {'Probability':>12}"

    print("\n" + s + "\n")
    print(header)
    print("-" * 27)
    for cand in candidates:
        word = cand['token_str'].strip()
        print(f"{word:>12} | {cand['score']:.4f}")


demo("The keys [MASK] on the table.", 3)    
demo("The key  [MASK] on the table.", 3)   
demo("He went to the [MASK] to withdraw cash.", 5)  
demo("The fish swam near the [MASK] of the river.", 5) 




The keys [MASK] on the table.

        Word |  Probability
---------------------------
        were | 0.5268
         lay | 0.1232
         are | 0.0349

The key  [MASK] on the table.

        Word |  Probability
---------------------------
         was | 0.4977
         lay | 0.1415
      rested | 0.0636

He went to the [MASK] to withdraw cash.

        Word |  Probability
---------------------------
        bank | 0.9382
       store | 0.0098
         atm | 0.0073
       banks | 0.0036
      office | 0.0033

The fish swam near the [MASK] of the river.

        Word |  Probability
---------------------------
       mouth | 0.3574
        edge | 0.1705
        bank | 0.1014
       banks | 0.0540
      bottom | 0.0447


#### We can also score a list of suggested fills:

In [20]:
text = "A cat is a type of [MASK]."

res = fill(text,
           targets=["animal", "vehicle", "color"])

print(text,"\n")
print(f"{'Word':>10} | {'Probability':>12}")
print("-" * 25)
for p in res:
    print(f"{p['token_str'].strip():>10} | {p['score']:.6f}")

A cat is a type of [MASK]. 

      Word |  Probability
-------------------------
    animal | 0.024314
   vehicle | 0.000027
     color | 0.000005


### Multiple Masks in One Sentence (Unlinked Predictions)

> When a sentence contains more than one `[MASK]` token, the BERT *fill-mask* pipeline predicts each mask **independently**, keeping the others as `[MASK]`.
> This means that the predictions for Mask 1, Mask 2, and Mask 3 are **not coordinated** — each list shows the most likely words *in that position alone*, not combinations that make a coherent sentence.

So in the example below, the model gives separate lists of possibilities for each `[MASK]`, but they don’t necessarily form meaningful sequences together.


In [10]:
text = "The [MASK] of The United [MASK] lives in the [MASK] House."
pred_lists = fill(text, top_k=5)  # list of lists (one per mask)

print(text)
for i, plist in enumerate(pred_lists, 1):
    words = [p["token_str"].strip() for p in plist]
    print(f"Mask {i}: {words}")


The [MASK] of The United [MASK] lives in the [MASK] House.
Mask 1: ['president', 'embassy', 'queen', 'ambassador', 'government']
Mask 2: ['kingdom', 'states', 'nations', 'nation', 'emirates']
Mask 3: ['white', 'same', 'guest', 'opera', 'glass']


### Multiple Masks in One Sentence: Linked

To generate **coherent replacements** across several `[MASK]` tokens, we can perform a **left-to-right coordinated fill** using a small *beam search*.

This approach fills one mask at a time, using the best-scoring partial sentences to guide the next prediction.
It works for Hugging Face pipeline outputs whether they return a **flat list** (one mask) or a **list of lists** (multiple masks), and it returns a ranked list of completed sentences with their associated probabilities.


In [11]:
import math
import numpy as np

def coordinated_fill(fill_pipeline, text, beam_size=5, per_mask_k=5):
    """Fill every mask in `text` left to right with a small beam search.

    At each step the first remaining mask of every surviving hypothesis is
    replaced by each of the pipeline's candidate tokens, and the `beam_size`
    hypotheses with the highest summed log-probability are kept.

    Args:
        fill_pipeline: Callable `fill_pipeline(sentence, top_k=...)` returning
            either a flat list of prediction dicts (one mask) or a list of
            such lists (several masks). Each dict needs 'token_str' and
            'score'. Must expose `fill_pipeline.tokenizer.mask_token`.
        text: Sentence containing zero or more mask tokens.
        beam_size: Number of hypotheses kept after each step.
        per_mask_k: Number of candidates requested per mask.

    Returns:
        A list of `(completed_text, log_probability)` tuples, best first.
        `[(text, 0.0)]` when `text` has no mask; an empty list when the
        pipeline yields no usable candidate for some mask.
    """
    MASK = fill_pipeline.tokenizer.mask_token
    beams = [(text, 0.0)]  # (current_text, logprob)

    # Every surviving hypothesis has had the same number of masks replaced,
    # so inspecting the first one is enough. `beams` may become empty if the
    # pipeline returns nothing usable; stop instead of indexing into it.
    while beams and MASK in beams[0][0]:
        new_beams = []
        for s, lp in beams:
            preds = fill_pipeline(s, top_k=per_mask_k)

            # Normalize shapes: with several masks the result is a list of
            # lists; only the first mask's candidates are needed here.
            if isinstance(preds, list) and preds and isinstance(preds[0], list):
                preds = preds[0]

            for p in preds:
                token = p["token_str"].strip()
                if MASK in token:
                    # Substituting the mask with itself would never reduce the
                    # number of masks and the loop would not terminate.
                    continue
                score = float(p["score"])
                # Replace only the FIRST mask
                s_next = s.replace(MASK, token, 1)
                # Clamp so a zero score does not produce log(0)
                new_beams.append((s_next, lp + math.log(max(score, 1e-12))))

        # Keep top `beam_size` hypotheses
        new_beams.sort(key=lambda t: t[1], reverse=True)
        beams = new_beams[:beam_size]

    return beams

# Example
text = "The [MASK] of the United [MASK] lives in the [MASK] House."

print(text,"\n")
joint = coordinated_fill(fill, text, beam_size=5, per_mask_k=5)
for s, lp in joint:
    print(f"{s}   (p={np.exp(lp):.2f})")


The [MASK] of the United [MASK] lives in the [MASK] House. 

The president of the United states lives in the white House.   (p=0.54)
The president of the United states lives in the state House.   (p=0.03)
The president of the United states lives in the guest House.   (p=0.03)
The president of the United states lives in the glass House.   (p=0.01)
The president of the United states lives in the washington House.   (p=0.01)


## 2. Text Classification using BERT

We return to the IMDB Movie Review dataset to show how BERT can be used as a text classifier, using a BERT model that’s already set up for classification:
`TFAutoModelForSequenceClassification`. It’s the TensorFlow/Keras wrapper that adds a small classification head on top of BERT (dropout + dense layer on the [CLS] representation) and returns logits for your labels.

We'll fine-tune the model on our dataset with every layer unfrozen, so both the pretrained encoder and the small classification head are updated.

Here are the steps we'll follow:

1. Tokenize texts → tensors

    - Convert raw reviews to input_ids and attention_mask with a BERT tokenizer.
    
    - Use truncation=True, a sensible max_length (128 is plenty for IMDB), and pad to that length.

2. Model

    - Load TFAutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2).
    
    - This attaches the classification head; no manual head building needed.

3. Loss/metrics

    - Use SparseCategoricalCrossentropy(from_logits=True) (because the model returns logits).
    
    - Track accuracy. (You can add F1 later for reports.)

4. Optimizer & LR schedule

    - Use HF’s create_optimizer to get AdamW + warmup (good defaults for transformers).

5. Train/eval loop

    - Use plain Keras model.fit(...) (works with dict inputs).
    
    - Keep your existing train_and_test(...) plotting and early-stopping style.

6. Inference

    - argmax(logits, -1) → predicted class (0=NEG, 1=POS).

### Tokenize IMDB → NumPy arrays (works with validation_split)

In [13]:
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"


MODEL = "distilbert-base-uncased"
MAX_LEN = 128
batch_size = 256  # chunk size for tokenization to keep memory sane

ds = load_dataset("imdb")
tokenizer = AutoTokenizer.from_pretrained(MODEL)

def to_str_list(x):
    """Return `x` as a plain Python list suitable for the tokenizer.

    Args:
        x: A single string, any iterable of strings (list, tuple, or another
           iterable container), or a non-iterable value.

    Returns:
        `[x]` for a single string (a string is iterable, so it must be
        checked first or it would be split into characters), `list(x)` for
        other iterables, and `[str(x)]` for non-iterable values.
    """
    if isinstance(x, str):
        return [x]
    try:
        return list(x)  # lists, tuples and other iterable containers
    except TypeError:
        # not iterable: wrap its string form
        return [str(x)]

def tokenize_to_np(texts, max_len=MAX_LEN, batch_size=batch_size):
    """Tokenize `texts` chunk by chunk and return stacked NumPy arrays.

    Args:
        texts: Input texts; normalized with `to_str_list` first.
        max_len: Every sequence is truncated/padded to this many tokens.
        batch_size: Number of texts sent to the tokenizer per call.

    Returns:
        Dict with "input_ids" and "attention_mask", each the row-wise
        concatenation of the per-chunk tokenizer outputs.
    """
    texts = to_str_list(texts)

    # One tokenizer call per chunk of `batch_size` texts
    encoded_chunks = [
        tokenizer(
            texts[start:start + batch_size],
            truncation=True,
            padding="max_length",
            max_length=max_len,
            return_tensors="np",
        )
        for start in range(0, len(texts), batch_size)
    ]

    return {
        "input_ids": np.concatenate([enc["input_ids"] for enc in encoded_chunks], axis=0),
        "attention_mask": np.concatenate([enc["attention_mask"] for enc in encoded_chunks], axis=0),
    }

# Pull texts/labels (can subselect for speed while testing)
train_texts  = list(ds["train"]["text"])
train_labels = list(ds["train"]["label"])
test_texts  = ds["test"]["text"]
test_labels = ds["test"]["label"]

# Keras `validation_split` holds out the LAST fraction of the arrays, taken
# before any shuffling. If the rows are ordered by label, the validation
# slice is not representative of the training data.
# NOTE(review): the Hugging Face IMDB train split appears to be stored
# grouped by label (negatives first) — confirm against the dataset card.
# A seeded permutation makes the held-out slice independent of row order
# while keeping runs reproducible.
shuffle_idx = np.random.default_rng(random_seed).permutation(len(train_texts))
train_texts  = [train_texts[i] for i in shuffle_idx]
train_labels = [train_labels[i] for i in shuffle_idx]

X_train = tokenize_to_np(train_texts)
y_train = np.array(train_labels, dtype="int32")

# The test set is only evaluated, so its order does not matter
X_test  = tokenize_to_np(test_texts)
y_test  = np.array(test_labels, dtype="int32")


### Build a TF/Keras classifier (already a Keras `Model`)

In [14]:
from tensorflow.keras import mixed_precision
from transformers import TFAutoModelForSequenceClassification

mixed_precision.set_global_policy("mixed_float16")

model = TFAutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=2
)


### Train and Test

Use a **logits**-aware loss; everything else can stay with your defaults.

In [None]:
from transformers import create_optimizer

# HF optimizer (tf.keras-compatible) – optional; you can also just pass optimizer="adam"
epochs = 5

# train_and_test fits with validation_split=0.2, so 80% of the rows are trained on.
# The final, partially filled batch of each epoch is also an optimizer step, so
# round UP; with floor division the schedule would reach its end before training does.
n_train_samples   = int(0.8 * X_train["input_ids"].shape[0])
steps_per_epoch   = math.ceil(n_train_samples / batch_size)
total_train_steps = max(1, steps_per_epoch * epochs)
warmup_steps      = int(0.1 * total_train_steps)   # warm up over the first 10% of steps

optimizer, _ = create_optimizer(
    init_lr=2e-5, num_warmup_steps=warmup_steps, num_train_steps=total_train_steps
)

hist = train_and_test(
    model,
    X_train, y_train,
    X_test,  y_test,
    title="IMDB — DistilBERT (TF/Keras)",
    epochs=epochs,
    optimizer=optimizer,  # or "adam" and set lr_schedule=2e-5
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    batch_size=batch_size,
    use_early_stopping=False,    # OFF for first run
    log_learning_rate=False,     # OFF for first run
    callbacks=[],                # ensure empty list
    verbose=1,
    return_history=True
)


In [None]:
from sklearn.metrics import classification_report
logits = model.predict(X_test, verbose=0).logits
y_pred = logits.argmax(axis=-1)
print(classification_report(y_test, y_pred, target_names=["NEG","POS"], digits=3))


## 3. Classifying the Relationship of Sentences with BERT

We use **BERT’s Next Sentence Prediction (NSP)** head to score whether a sentence **B** plausibly follows sentence **A**.

- **Input:** a pair of sentences *(A, B)* tokenized together.
- **Model:** `BertForNextSentencePrediction` (or TF: `TFBertForNextSentencePrediction`).
  - Warning: `DistilBERT` does **not** have an NSP head—use full BERT (e.g., `bert-base-uncased`).
- **Output:** two logits for classes **[IS_NEXT, NOT_NEXT]**; we apply softmax to get  
  $ p(\text{IS\_NEXT}\mid A,B) $ and $ p(\text{NOT\_NEXT}\mid A,B) $.

**How to interpret:**  
- Higher $ p(\text{IS\_NEXT}) $ means B is a plausible continuation of A (per BERT’s pretraining).  
- Scores are **not perfectly calibrated**; NSP tends to be overconfident for **topically related** pairs.

**Good demo patterns:**
- Compare the true continuation vs. a same-topic-but-wrong sentence vs. an off-topic sentence, and **rank by $ p(\text{IS\_NEXT}) $**.
- For clearer behavior, construct **hard negatives** (shuffle within the same paragraph) rather than random sentences.

**When to fine-tune instead:**  
If you need “nextness” that matches your domain or dataset, fine-tune a small classifier on your own consecutive vs. mismatched pairs rather than relying on the pretrained NSP head.



In [17]:
import nltk

# Make sure both resources are present
for pkg, handle in [
    ("punkt", "tokenizers/punkt"),
    ("punkt_tab", "tokenizers/punkt_tab"),
]:
    try:
        nltk.data.find(handle)
    except LookupError:
        nltk.download(pkg)


In [18]:
# !pip -q install transformers torch

import torch
from transformers import AutoTokenizer, AutoModelForNextSentencePrediction

# BERT checkpoints include an NSP head; DistilBERT does NOT.
MODEL = "bert-base-uncased"

tok = AutoTokenizer.from_pretrained(MODEL)
nsp = AutoModelForNextSentencePrediction.from_pretrained(MODEL)
nsp.eval()  # inference mode

def next_sentence_prob(s1: str, s2: str):
    """Score whether `s2` follows `s1` using the module-level NSP model.

    The pair is tokenized together (truncated to 256 tokens) and the model's
    two logits are turned into probabilities with a softmax.

    Returns:
        (p_is_next, p_not_next, predicted_label), where the label is
        "IS_NEXT" when p_is_next >= p_not_next and "NOT_NEXT" otherwise.
    """
    encoded = tok(s1, s2, return_tensors="pt", truncation=True, max_length=256)

    # Inference only: no gradients needed
    with torch.no_grad():
        pair_logits = nsp(**encoded).logits          # index 0 = IsNext, index 1 = NotNext
        pair_probs = torch.softmax(pair_logits, dim=-1).squeeze(0).tolist()

    p_is_next = pair_probs[0]
    p_not_next = pair_probs[1]

    if p_is_next >= p_not_next:
        label = "IS_NEXT"
    else:
        label = "NOT_NEXT"
    return p_is_next, p_not_next, label


In [23]:
def show_candidates(s1, s2_list):
    """Print a table of candidate follow-up sentences ranked by p(IS_NEXT).

    Args:
        s1: Context sentence.
        s2_list: Iterable of candidate second sentences; each is scored with
            `next_sentence_prob(s1, candidate)`.
    """
    # Each row: (candidate, p_is_next, p_not_next, label)
    scored = [(cand, *next_sentence_prob(s1, cand)) for cand in s2_list]
    scored.sort(key=lambda row: row[1], reverse=True)

    print(f"Context: {s1}\n")
    print(f"{'Candidate':<80} | {'p(IS_NEXT)':>10} | Label")
    print("-"*110)
    for cand, p_next, _, label in scored:
        # Candidates longer than 80 characters are cut to fit the column
        print(f"{cand[:80]:<80} | {p_next:10.6f} | {label}")

s1 = "Natural language processing models are widely used in industry."
cands = [
    "They power search, chatbots, and content moderation.",             # true next
    "Boosting methods like XGBoost focus on hard-to-predict examples.", # same topic
    "The Eiffel Tower is located in Paris, France.",                    # off-topic
    "The President announced a new round of tariffs.",
    "But the President tweeted something about AI, so the future of NLP in industry is unclear.",
    "But the President tweeted something, so the future of NLP in industry is unclear.",
    "So the class is quite popular.",
    "The coal industry, however, is in decline and does not use these techniques.",
    "The coal industry, however, does not.",
    "Most industries are affected.",
    "Most industries are affected by this trend.",
    "And most students want to learn them for this reason.",
    "But the future of AI is uncertain."
]
show_candidates(s1, cands)


Context: Natural language processing models are widely used in industry.

Candidate                                                                        | p(IS_NEXT) | Label
--------------------------------------------------------------------------------------------------------------
Boosting methods like XGBoost focus on hard-to-predict examples.                 |   0.999970 | IS_NEXT
They power search, chatbots, and content moderation.                             |   0.999966 | IS_NEXT
But the future of AI is uncertain.                                               |   0.999946 | IS_NEXT
And most students want to learn them for this reason.                            |   0.999928 | IS_NEXT
The coal industry, however, is in decline and does not use these techniques.     |   0.998985 | IS_NEXT
But the President tweeted something about AI, so the future of NLP in industry i |   0.992282 | IS_NEXT
So the class is quite popular.                                                   |   0.03