<sub>Developed by SeongKu Kang, August 2025 - Do not distribute</sub>

# 📘 Guide to Using BERT

To use BERT for classification, we need to understand two key steps:  
1. **How raw text is converted into model inputs**  
2. **How to interpret and leverage BERT outputs**


---

## [Part A] Understanding BERT Input
This part focuses on the preprocessing step with Hugging Face's `AutoTokenizer`.  
You will learn how to obtain the following inputs required by BERT:

- `input_ids`: tokenized text represented as integer IDs  
- `attention_mask`: binary mask (1 for real tokens, 0 for padding)  

This illustrates the transformation from **raw string → tokenized IDs → BERT input tensors**.

### 🔎 BERT Input Example

When we feed raw texts into BERT, the tokenizer automatically performs several steps:

1. **Add special tokens**  
   - `[CLS]` (101) is added at the beginning of every sequence. It serves as a special representation for the whole sequence.  
   - `[SEP]` (102) is added at the end of every sequence to mark separation (even for single sentences).

2. **Convert tokens to IDs**  
   - Each token is mapped to its corresponding integer ID in the BERT vocabulary.

3. **Handle variable lengths**  
   - Sentences differ in length, but every sequence in a batch must have the same length so that the batch can be stacked into a single tensor.  
   - The tokenizer appends the **padding token `[PAD]` (ID 0)** to shorter sequences and truncates sequences that exceed `max_length`.  
   - An **attention mask** is created, where `1` indicates real tokens and `0` indicates padding.  
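
The three steps above can be sketched in plain Python. This is only an illustration of the bookkeeping, not the Hugging Face implementation: it assumes the text has already been split into vocabulary IDs (the real tokenizer also performs WordPiece splitting), and `build_inputs` is a hypothetical helper name. The special-token IDs 101, 102, and 0 are the ones quoted above, and the two content IDs are those that `bert-base-uncased` assigns to "Knitting hooks".

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # special-token IDs quoted in the steps above


def build_inputs(token_ids, max_length):
    """Wrap token IDs in [CLS]/[SEP], truncate, pad, and build the mask.

    Hypothetical helper for illustration; assumes max_length >= 2.
    """
    # Reserve two positions for [CLS] and [SEP], then truncate the content.
    content = token_ids[: max_length - 2]
    ids = [CLS_ID] + content + [SEP_ID]
    # 1 marks a real token (special tokens included), 0 marks padding.
    mask = [1] * len(ids)
    n_pad = max_length - len(ids)
    return ids + [PAD_ID] * n_pad, mask + [0] * n_pad


# "Knitting hooks" maps to IDs 26098 and 18008 in bert-base-uncased.
input_ids, attention_mask = build_inputs([26098, 18008], max_length=12)
print(input_ids)       # [101, 26098, 18008, 102, 0, 0, 0, 0, 0, 0, 0, 0]
print(attention_mask)  # [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]
```

Note that `[CLS]` and `[SEP]` count as real tokens in the mask; only the padding positions are marked `0`.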

In [1]:
from utils import *
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Example texts with different lengths
texts = [
    "Knitting hooks",
    "Crochet hooks for beginners with ergonomic handle",
    "Sewing and knitting accessories including needles, hooks, and more"
]

# Tokenize with padding & truncation
encoded = tokenizer(
    texts,
    padding="max_length",   # pad to fixed length
    truncation=True,        # truncate if longer than max_length
    max_length=12,          # small value for illustration
    return_tensors="pt"
)

print("[Input IDs]:\n", encoded["input_ids"])
print("[Attention Mask]:\n", encoded["attention_mask"])

[Input IDs]:
 tensor([[  101, 26098, 18008,   102,     0,     0,     0,     0,     0,     0,
             0,     0],
        [  101, 13675, 23555,  2102, 18008,  2005,  4088, 16912,  2007,  9413,
          7446,   102],
        [  101, 22746,  1998, 26098, 16611,  2164, 17044,  1010, 18008,  1010,
          1998,   102]])
[Attention Mask]:
 tensor([[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])


---

## [Part B] Understanding BERT Output

When we feed tokenized inputs into BERT, the model returns several outputs.  
For classification, the most important one is:

- **`last_hidden_state`**  
  Shape: `(batch_size, seq_len, hidden_size)`  
  → Contextual embedding for each token in the sequence.

From `last_hidden_state`, we can derive two common sequence-level representations:

1. **[CLS] token embedding**  
   - The first token (`[CLS]`) is designed to capture the meaning of the entire sequence.  
   - Example usage: `cls_embedding = outputs.last_hidden_state[:, 0]`  
   - Shape: `(batch_size, hidden_size)`

2. **Mean pooling**  
   - Average all token embeddings across the sequence, weighted by the attention mask (ignoring padding).  
   - Captures information from all tokens, not just the first one.  
   - Example usage:  
     ```python
     mean_embedding = (outputs.last_hidden_state * attention_mask.unsqueeze(-1)).sum(1) \
                      / attention_mask.sum(1).unsqueeze(-1)
     ```  
   - Shape: `(batch_size, hidden_size)`
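
To see why the attention mask matters for mean pooling, here is a small pure-Python sketch on made-up numbers. The vectors, the hidden size of 2, and the helper name `masked_mean` are all assumptions chosen for readability; it handles a single sequence, whereas the tensor expression above does the same computation for a whole batch at once.

```python
def masked_mean(token_vectors, attention_mask):
    """Average the token vectors whose mask entry is 1 (one sequence)."""
    hidden_size = len(token_vectors[0])
    totals = [0.0] * hidden_size
    for vector, keep in zip(token_vectors, attention_mask):
        if keep:  # skip padding positions
            for d, value in enumerate(vector):
                totals[d] += value
    n_real = max(sum(attention_mask), 1)  # guard against an all-padding row
    return [t / n_real for t in totals]


# Toy sequence: two real tokens followed by one padding position.
token_vectors = [[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]
attention_mask = [1, 1, 0]

print(masked_mean(token_vectors, attention_mask))  # [2.0, 3.0]
print(masked_mean(token_vectors, [1, 1, 1]))       # padding vector leaks into the mean
```

With the correct mask, the padding vector `[9.0, 9.0]` has no effect on the result; treating it as a real token shifts the average.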

In [2]:
from transformers import AutoModel
import torch

# Load BERT model
model = AutoModel.from_pretrained("bert-base-uncased")

# Forward pass
with torch.no_grad():
    outputs = model(**encoded)

# CLS embedding (first token from last_hidden_state)
cls_embedding = outputs.last_hidden_state[:, 0]   # (batch_size, hidden_size)

# Mean pooling (mask out padding tokens)
mean_embedding = (outputs.last_hidden_state * encoded["attention_mask"].unsqueeze(-1)).sum(1) \
                 / encoded["attention_mask"].sum(1).unsqueeze(-1)

print("[CLS Embedding Shape]:", cls_embedding.shape)
print("[Mean-pooled Embedding Shape]:", mean_embedding.shape)

# (Optional) Show first 5 values for one example
print("\nExample CLS embedding (first 5 dims):", cls_embedding[0][:5])
print("Example Mean embedding (first 5 dims):", mean_embedding[0][:5])

[CLS Embedding Shape]: torch.Size([3, 768])
[Mean-pooled Embedding Shape]: torch.Size([3, 768])

Example CLS embedding (first 5 dims): tensor([-0.3316,  0.0918, -0.9048,  0.1675,  0.2810])
Example Mean embedding (first 5 dims): tensor([ 0.3487, -0.2039, -0.7064,  0.1054,  0.3160])


---
## [Part C] Using BERT Mean-Pooled Embeddings

In this assignment, we mainly use **frozen BERT mean-pooled embeddings** for efficiency (i.e., partial fine-tuning with a frozen encoder, where the BERT weights are never updated).

Instead of fine-tuning BERT every time, we pre-compute the embeddings once and reuse them.  

You can extract and save these embeddings with the following code snippet.  
*(You will need them in your assignments.)*
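
The "compute once, reuse afterwards" idea can be sketched as a small caching helper. This is a minimal illustration under stated assumptions: `load_or_compute` and `fake_encoder` are hypothetical names, and `pickle` is used only so the sketch runs with the standard library. In this notebook, `torch.save` and `torch.load` play that role for the real tensors.

```python
import pickle
import tempfile
from pathlib import Path


def load_or_compute(cache_path, compute_fn):
    """Return cached embeddings if the file exists, else compute and save them."""
    cache_path = Path(cache_path)
    if cache_path.exists():
        with cache_path.open("rb") as f:
            return pickle.load(f)
    result = compute_fn()  # the expensive encoder pass happens only here
    with cache_path.open("wb") as f:
        pickle.dump(result, f)
    return result


def fake_encoder():
    # Hypothetical stand-in for the BERT forward pass; returns toy vectors.
    return {"ids": [0, 1], "embeddings": [[0.1, 0.2], [0.3, 0.4]]}


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "toy_embeddings.pkl"
    first = load_or_compute(path, fake_encoder)   # computes and writes the file
    second = load_or_compute(path, fake_encoder)  # reads the file back
    print(first == second)  # True
```

The dictionary layout (`"ids"` alongside `"embeddings"`) mirrors the files saved in this part, so each row of the embedding matrix can be matched back to its ID.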

### 1. Save corpus & query embeddings

In [3]:
from utils import load_corpus, load_queries, dict2list  # helpers used in the cells below
import json
from pathlib import Path
import torch
from transformers import BertTokenizer, BertModel
from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"

# Path config
ROOT = Path("./dataset")
CORPUS_PATH = ROOT / "corpus.jsonl"     # fields: product_id, title, description
QUERY_PATH = ROOT / "queries_1k.jsonl"  # fields: query_id, query, product_id

# Load PLM
MODEL_NAME = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model = BertModel.from_pretrained(MODEL_NAME).eval().to(device)

In [4]:
def mean_pooling(model_output, attention_mask):
    """
    Apply mean pooling on BERT token embeddings, masking out padding tokens.

    Args:
        model_output: Output object from a BERT model (contains last_hidden_state).
        attention_mask (torch.Tensor): Attention mask of shape (batch_size, seq_len),
                                       where 1 = real token and 0 = padding.

    Returns:
        torch.Tensor: Sentence embeddings of shape (batch_size, hidden_size).
    """
    token_embeddings = model_output.last_hidden_state
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, dim=1)
    sum_mask = torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9)
    return sum_embeddings / sum_mask

def encode_texts(texts, batch_size=64):
    """
    Encode a list of texts into mean-pooled BERT embeddings.

    Args:
        texts (list of str): Input texts to encode.
        batch_size (int, optional): Batch size for encoding. Default is 64.

    Returns:
        torch.Tensor: Tensor of shape (len(texts), hidden_size) containing embeddings.
    """
    all_embeddings = []

    # Process texts in mini-batches
    for i in tqdm(range(0, len(texts), batch_size)):
        batch = texts[i:i+batch_size]

        # Tokenize and move to model device
        encoded = tokenizer(
            batch,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        ).to(model.device)

        # Forward pass through BERT
        with torch.no_grad():
            output = model(**encoded)

        # Mean pooling (exclude padding tokens)
        embeddings = mean_pooling(output, encoded["attention_mask"])
        all_embeddings.append(embeddings.cpu())

    # Concatenate all batch embeddings
    return torch.cat(all_embeddings, dim=0)
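
Unlike Part A, `encode_texts` calls the tokenizer with `padding=True`, which pads each mini-batch only to its own longest sequence instead of to a fixed `max_length`. The plain-Python sketch below illustrates that behaviour together with the batching loop. The helper names `pad_to_longest` and `iter_batches` and the toy ID values are assumptions for illustration, not part of the tokenizer API.

```python
PAD_ID = 0  # padding ID used by bert-base-uncased, as in Part A


def pad_to_longest(batch):
    """Pad every ID list in the batch to the length of the longest one."""
    longest = max(len(ids) for ids in batch)
    input_ids = [ids + [PAD_ID] * (longest - len(ids)) for ids in batch]
    attention_mask = [[1] * len(ids) + [0] * (longest - len(ids)) for ids in batch]
    return input_ids, attention_mask


def iter_batches(items, batch_size):
    """Yield consecutive slices of at most batch_size items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


# Toy ID sequences of lengths 4, 6, 3, 5, 7 (content values are placeholders).
sequences = [[101, 5, 6, 102], [101, 5, 6, 7, 8, 102], [101, 5, 102],
             [101, 5, 6, 7, 102], [101, 5, 6, 7, 8, 9, 102]]

for batch in iter_batches(sequences, batch_size=2):
    input_ids, attention_mask = pad_to_longest(batch)
    print(len(batch), "sequence(s) padded to length", len(input_ids[0]))
```

Here the three batches are padded to lengths 6, 5, and 7 respectively, so short batches do not pay for the longest text in the whole corpus. The last batch may hold fewer than `batch_size` items, which is why the final embeddings are concatenated rather than stacked.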

In [3]:
# Encode corpus with BERT (mean pooling)
pid2text = load_corpus(CORPUS_PATH)              # Load corpus as {pid: text} dictionary
corpus_ids, corpus_texts = dict2list(pid2text)   # Convert dict → (ids, texts) lists

corpus_emb = encode_texts(corpus_texts)          # Compute mean-pooled BERT embeddings
torch.save({"ids": corpus_ids, "embeddings": corpus_emb}, ROOT / "corpus_bert_mean.pt")

100%|██████████| 617/617 [05:46<00:00,  1.78it/s]


In [4]:
# Encode training queries with BERT (mean pooling)
qid2text = load_queries(QUERY_PATH)               # Load queries as {qid: text} dictionary
query_ids, query_texts = dict2list(qid2text)      # Convert dict → (ids, texts) lists

query_emb = encode_texts(query_texts)             # Compute mean-pooled BERT embeddings

torch.save({"ids": query_ids, "embeddings": query_emb}, ROOT / "queries_1k_bert_mean.pt")

100%|██████████| 16/16 [00:01<00:00,  8.81it/s]


In [5]:
TEST_QUERY_PATH = ROOT / "queries_test.jsonl"

# Encode test queries (set 1) with BERT (mean pooling)
qid2text_test = load_queries(TEST_QUERY_PATH)                # Load as {qid: text}
query_ids_test, query_texts_test = dict2list(qid2text_test)  # Convert to lists
query_emb_test = encode_texts(query_texts_test)              # Compute embeddings

torch.save({"ids": query_ids_test, "embeddings": query_emb_test}, ROOT / "test_queries_bert_mean.pt")

100%|██████████| 8/8 [00:00<00:00, 51.81it/s]




### 2. Category class label embedding

In [7]:
ROOT = Path("dataset")
LABEL_MAP_PATH = ROOT / "category_classification"
LABEL_TEXT_PATH = LABEL_MAP_PATH / "labelid2label.json"

In [8]:
with open(LABEL_TEXT_PATH) as f:
    id2label = json.load(f)

label_ids = list(map(int, id2label.keys()))
label_texts = [id2label[str(i)] for i in label_ids]

# label encoding
batch_size = 64
all_embeddings = []
for i in tqdm(range(0, len(label_texts), batch_size)):
    batch = label_texts[i:i+batch_size]
    encoded = tokenizer(batch, padding=True, truncation=True, max_length=128, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model(**encoded)
    emb = mean_pooling(output, encoded["attention_mask"])  # (B, D)
    all_embeddings.append(emb.cpu())

label_embeddings = torch.cat(all_embeddings, dim=0)  # (C, D)
torch.save({"ids": label_ids, "embeddings": label_embeddings}, LABEL_MAP_PATH / "category_labels_bert_mean.pt")

100%|██████████| 15/15 [00:01<00:00,  7.58it/s]


