# BERT Knowledge Distillation for Translation

This notebook implements the three-stage approach from the paper ["Distilling Knowledge Learned in BERT for Text Generation"](https://arxiv.org/abs/1911.03829) (ACL 2020):

1. **CMLM Finetuning**: Fine-tune BERT as a conditional masked language model on translation data
2. **Knowledge Extraction**: Extract hidden states and compute teacher logits
3. **Student Training**: Train a smaller encoder-decoder translation model with knowledge distillation

We'll work with the IWSLT14 German-English dataset.

## Setup and Installation

First, we'll install the required dependencies and set up our environment.

In [None]:
# Install required packages (shelve, used below for the feature dumps,
# ships with the Python standard library and needs no install)
!pip install transformers==4.26.0
!pip install pytorch-pretrained-bert
!pip install cytoolz
!pip install tqdm

# Import common libraries
import os
import sys
import torch
import numpy as np
import random
import shelve
import io
import argparse
import yaml
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from pytorch_pretrained_bert import BertTokenizer

# Add project directories to path for importing modules
sys.path.append('.')
sys.path.append('./opennmt')

# Set seed for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f"Using device: {device}")

## Download and Preprocess the IWSLT14 Dataset

We'll download the German-English translation dataset and prepare it for training.

In [None]:
# Create directories for data and outputs
!mkdir -p data/
!mkdir -p output/cmlm_model
!mkdir -p output/bert_dump
!mkdir -p output/kd-model/ckpt
!mkdir -p output/kd-model/log
!mkdir -p output/translation

# Download IWSLT German-English dataset using the provided script
!bash scripts/download-iwslt_deen.sh

In [None]:
# Prepare the dataset using the provided script
!bash scripts/prepare-iwslt_deen.sh

## Apply BERT Tokenization

We need to tokenize our data with the BERT tokenizer for the CMLM finetuning.

In [None]:
from scripts.bert_tokenize import tokenize, process

# Load BERT tokenizer
bert_model = "bert-base-multilingual-cased"
tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case='uncased' in bert_model)

# Define data directories
data_dir = "data/de-en"

# BERT tokenize our dataset files
for language in ['de', 'en']:
    for split in ['train', 'valid', 'test']:
        input_file = f"{data_dir}/{split}.{language}"
        output_file = f"{data_dir}/{split}.{language}.bert"
        print(f"Tokenizing {input_file}...")
        
        with open(input_file, 'r') as reader, open(output_file, 'w') as writer:
            process(reader, writer, tokenizer)
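The `process` helper writes each line back out as space-separated WordPiece tokens. BERT's WordPiece step splits every whitespace token greedily into the longest prefix found in the vocabulary, and marks continuation pieces with `##`. The sketch below reproduces that greedy longest-match-first idea using a tiny vocabulary, `TOY_VOCAB`. Both `TOY_VOCAB` and the helper names are hypothetical. The real tokenizer also splits punctuation first and uses the full multilingual vocabulary.

```python
from typing import List, Set

# Hypothetical toy vocabulary; the real one ships with bert-base-multilingual-cased
TOY_VOCAB: Set[str] = {"[UNK]", "Haus", "##es", "##e", "Kind", "##er", "gut", "."}


def wordpiece(word: str, vocab: Set[str], unk: str = "[UNK]") -> List[str]:
    """Greedy longest-match-first split of one word into WordPieces."""
    pieces: List[str] = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Shrink the candidate from the right until it is in the vocabulary
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry a ## prefix
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            return [unk]  # no valid split: the whole word becomes [UNK]
        pieces.append(match)
        start = end
    return pieces


def tokenize_line(line: str, vocab: Set[str]) -> str:
    """Whitespace-split a line, WordPiece each token, and re-join with spaces."""
    return " ".join(p for w in line.split() for p in wordpiece(w, vocab))


print(tokenize_line("Hauses Kinder gut .", TOY_VOCAB))  # Haus ##es Kind ##er gut .
```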

## Prepare BERT Training Data

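Next, we build the `DEEN.db` database of tokenized sentence pairs used by the CMLM data loader. We also run OpenNMT's `preprocess.py` to create the vocabulary file (`DEEN.vocab.pt`), which both training stages load.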

In [None]:
# Create dataset DB for BERT training
from scripts.bert_prepro import main as bert_prepro

# Set up args for bert_prepro
prepro_args = argparse.Namespace(
    src=f"{data_dir}/train.de.bert",
    tgt=f"{data_dir}/train.en.bert",
    output='data/DEEN.db'
)

# Run preprocessing
bert_prepro(prepro_args)

# Create vocabulary file using OpenNMT's preprocess.py
!python opennmt/preprocess.py \
    -train_src {data_dir}/train.de.bert \
    -train_tgt {data_dir}/train.en.bert \
    -valid_src {data_dir}/valid.de.bert \
    -valid_tgt {data_dir}/valid.en.bert \
    -save_data data/DEEN \
    -src_seq_length 150 -tgt_seq_length 150
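`preprocess.py` saves `data/DEEN.vocab.pt`. Stage 1 loads the target field's `vocab.stoi` from that file. Conceptually, the mapping comes from counting tokens in the length-filtered training data and assigning indices, with special tokens first. The plain-Python sketch below shows the idea. The special-token list and the `build_stoi` helper are assumptions for illustration, not OpenNMT's exact code. OpenNMT also applies vocabulary-size and minimum-frequency limits and filters whole sentence pairs.

```python
from collections import Counter
from typing import Dict, Iterable, List

# Assumed OpenNMT-style special tokens; the real list comes from the field definitions
SPECIALS: List[str] = ["<unk>", "<blank>", "<s>", "</s>"]


def build_stoi(lines: Iterable[str], max_len: int = 150) -> Dict[str, int]:
    """Map tokens to indices: specials first, then by descending frequency."""
    counts: Counter = Counter()
    for line in lines:
        tokens = line.split()
        if len(tokens) > max_len:  # rough analogue of -tgt_seq_length filtering
            continue
        counts.update(tokens)
    # Sort by frequency (descending), breaking ties alphabetically for determinism
    ordered = sorted((t for t in counts if t not in SPECIALS),
                     key=lambda t: (-counts[t], t))
    return {tok: i for i, tok in enumerate(SPECIALS + ordered)}


stoi = build_stoi(["the house", "the child ##ren", "the house ."])
print(stoi)
```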

## Stage 1: CMLM (Conditional Masked Language Model) Finetuning

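In this stage, we fine-tune BERT as a Conditional Masked Language Model (CMLM). Each source sentence and its translation are fed to BERT as a single sequence, and some target tokens are masked. The model learns to predict the masked tokens from the source and from the surrounding target context on both sides.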
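The sketch below builds one CMLM training example in plain Python, presumably packed BERT-style as `[CLS] source [SEP] target [SEP]`:

- Segment ids are 0 for the source part and 1 for the target part.
- Only target tokens are ever replaced by `[MASK]`.

Several details are assumptions borrowed from standard BERT practice rather than read from `cmlm.data`:

- the 15% masking rate,
- forcing at least one mask per example,
- using `None` as a stand-in for the ignore label,
- the `make_cmlm_example` name.

```python
import random
from typing import List, Optional, Tuple


def make_cmlm_example(src: List[str], tgt: List[str], mask_prob: float = 0.15,
                      rng: Optional[random.Random] = None
                      ) -> Tuple[List[str], List[int], List[Optional[str]]]:
    """Build (tokens, segment_ids, labels) for one conditional MLM example.

    Only target positions are ever masked. labels holds the original token at
    masked positions and None elsewhere (None stands in for the ignore index).
    """
    if rng is None:
        rng = random.Random(0)
    tokens = ["[CLS]"] + src + ["[SEP]"]
    segment_ids = [0] * len(tokens)
    tgt_start = len(tokens)
    tokens += tgt + ["[SEP]"]
    segment_ids += [1] * (len(tgt) + 1)
    labels: List[Optional[str]] = [None] * len(tokens)

    # Choose target positions to mask; force at least one so the example has a loss
    tgt_positions = range(tgt_start, tgt_start + len(tgt))
    masked = [i for i in tgt_positions if rng.random() < mask_prob]
    if not masked:
        masked = [rng.choice(tgt_positions)]
    for i in masked:
        labels[i] = tokens[i]
        tokens[i] = "[MASK]"
    return tokens, segment_ids, labels


tokens, segment_ids, labels = make_cmlm_example(
    ["Das", "Haus", "ist", "alt", "."], ["The", "house", "is", "old", "."],
    rng=random.Random(1))
print(tokens)
print(labels)
```

Because the source is never masked and the target keeps unmasked tokens on both sides of each `[MASK]`, every prediction is conditioned on the full source plus bidirectional target context.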

In [None]:
from pytorch_pretrained_bert.optimization import BertAdam

# Import needed modules
from cmlm.data import BertDataset, TokenBucketSampler
from cmlm.model import convert_embedding, BertForSeq2seq
from cmlm.util import Logger, RunningMeter
from run_cmlm_finetuning import noam_schedule

# Load vocabulary
vocab_file = "data/DEEN.vocab.pt"
train_file = "data/DEEN.db"
valid_src = f"{data_dir}/valid.de.bert"
valid_tgt = f"{data_dir}/valid.en.bert"
output_dir = "output/cmlm_model"

# Load vocabulary
vocab_dump = torch.load(vocab_file)
vocab = vocab_dump['tgt'].fields[0][1].vocab.stoi

# Create dataset
train_dataset = BertDataset(train_file, tokenizer, vocab, seq_len=512, max_len=150)

# Define sampler and data loader
BUCKET_SIZE = 8192
train_sampler = TokenBucketSampler(
    train_dataset.lens, BUCKET_SIZE, 6144, batch_multiple=1)

train_loader = DataLoader(train_dataset, batch_sampler=train_sampler,
                         num_workers=4,
                         collate_fn=BertDataset.pad_collate)

# Prepare model
model = BertForSeq2seq.from_pretrained(bert_model)
embedding = convert_embedding(tokenizer, vocab, model.bert.embeddings.word_embeddings.weight)
model.update_output_layer(embedding)
model.to(device)
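`TokenBucketSampler` batches by token count rather than by a fixed number of sentences. `BUCKET_SIZE` examples are drawn at a time and sorted by length inside the bucket. They are then packed so that (longest example in the batch) × (number of examples) presumably stays within the 6144-token budget. The sketch below re-implements that idea in plain Python to show how the two numbers interact. `token_bucket_batches` is a hypothetical helper, and the real sampler in `cmlm.data` may differ in details such as shuffling and `batch_multiple` rounding.

```python
import random
from typing import List


def token_bucket_batches(lens: List[int], bucket_size: int, max_tokens: int,
                         seed: int = 0) -> List[List[int]]:
    """Group example indices into batches whose padded size fits max_tokens.

    Padded size = (longest example in the batch) * (number of examples).
    """
    rng = random.Random(seed)
    ids = list(range(len(lens)))
    rng.shuffle(ids)
    batches: List[List[int]] = []
    for start in range(0, len(ids), bucket_size):
        # Sort each bucket by length so similar-length examples share a batch
        bucket = sorted(ids[start:start + bucket_size], key=lambda i: lens[i], reverse=True)
        batch: List[int] = []
        max_len = 0
        for i in bucket:
            new_max = max(max_len, lens[i])
            if batch and new_max * (len(batch) + 1) > max_tokens:
                batches.append(batch)  # adding i would overflow: close the batch
                batch, new_max = [], lens[i]
            batch.append(i)
            max_len = new_max
        if batch:
            batches.append(batch)
    rng.shuffle(batches)  # avoid always visiting the long-sequence batches first
    return batches


lens = [10, 50, 20, 30, 5, 60, 25, 15, 40, 35, 45, 12]
batches = token_bucket_batches(lens, bucket_size=8, max_tokens=100)
for b in batches:
    print(b, [lens[i] for i in b])
```

Sorting within a bucket keeps padding low, while drawing buckets from a shuffled order keeps batches varied across epochs.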

In [None]:
# Training parameters
learning_rate = 5e-5
warmup_steps = 4000
max_steps = 100000  # Full training uses 100k steps
num_steps_to_run = 5000  # We'll do fewer steps for demonstration

# Optimizer
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)],
     'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer
                if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0}
]
# BertAdam's `warmup` is a fraction of t_total, not a step count
optimizer = BertAdam(optimizer_grouped_parameters,
                     lr=learning_rate,
                     warmup=warmup_steps / max_steps,
                     t_total=max_steps)

# Training loop. One pass over train_loader is a single epoch, which can be
# shorter than num_steps_to_run, so keep iterating until the step budget is used.
running_loss = RunningMeter('loss')
model.train()

print("Starting CMLM fine-tuning...")
global_step = 0
pbar = tqdm(total=num_steps_to_run, desc="Training")
while global_step < num_steps_to_run:
    for batch in train_loader:
        # Move batch to device
        batch = tuple(t.to(device) for t in batch)
        input_ids, input_mask, segment_ids, lm_label_ids = batch

        # Forward/backward pass on the masked target tokens
        optimizer.zero_grad()
        loss = model(input_ids, segment_ids, input_mask, lm_label_ids)
        loss.backward()
        optimizer.step()

        running_loss(loss.item())
        global_step += 1
        pbar.update(1)
        if global_step % 100 == 0:
            print(f"Step {global_step}, Loss: {running_loss.val:.4f}")
        if global_step >= num_steps_to_run:
            break
pbar.close()

# Save model checkpoint
torch.save(model.state_dict(), f"{output_dir}/model_step_{num_steps_to_run}.pt")
print(f"Model saved to {output_dir}/model_step_{num_steps_to_run}.pt")
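`BertAdam` applies its own learning-rate schedule on top of `lr`. With the default linear warmup schedule, the rate rises linearly over the first `warmup` fraction of `t_total` steps, then decays linearly to zero at `t_total`. The `warmup` value is expressed as a fraction of `t_total`, so the 4000 warmup steps out of 100,000 used above correspond to 0.04. The sketch below computes that multiplier. It is written from scratch and only modeled on the `pytorch_pretrained_bert` schedule; exact behavior may differ between package versions, and `warmup_linear` here is a hypothetical helper.

```python
def warmup_linear(step: int, warmup_steps: int, total_steps: int) -> float:
    """LR multiplier: linear warmup to 1.0, then linear decay to 0 at total_steps."""
    progress = step / total_steps
    warmup = warmup_steps / total_steps
    if progress < warmup:
        return progress / warmup
    return max((progress - 1.0) / (warmup - 1.0), 0.0)


base_lr = 5e-5
for step in (0, 2000, 4000, 5000, 52000, 100000):
    print(step, base_lr * warmup_linear(step, 4000, 100000))
```

Under this schedule, the 5000-step demonstration run ends almost at the peak rate. It never reaches the long decay phase of the 100k-step schedule.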

## Stage 2: Extract Knowledge from Teacher Model

Now we'll extract the hidden states from our fine-tuned BERT model and compute the top-k logits that will be used for knowledge distillation.

In [None]:
# Import extraction functions
from dump_teacher_hiddens import tensor_dumps, gather_hiddens, BertSampleDataset, batch_features, process_batch

# Path to model checkpoint from Stage 1
ckpt_path = f"{output_dir}/model_step_{num_steps_to_run}.pt"
bert_dump_path = "output/bert_dump"

# Load the fine-tuned BERT model (on CPU first; moved to the device below)
state_dict = torch.load(ckpt_path, map_location='cpu')
vsize = state_dict['cls.predictions.decoder.weight'].size(0)
bert = BertForSeq2seq.from_pretrained(bert_model).eval()

# Resize the output layer to the target vocabulary and load the weights,
# then move the whole model (including the resized layer) to the device
bert.update_output_layer_by_size(vsize)
bert.load_state_dict(state_dict)
bert.to(device)

# Save the final projection layer; its output size is the target
# vocabulary size (vsize), not BERT's original WordPiece vocabulary size
linear = torch.nn.Linear(bert.config.hidden_size, vsize)
linear.weight.data = state_dict['cls.predictions.decoder.weight']
linear.bias.data = state_dict['cls.predictions.bias']
torch.save(linear, f'{bert_dump_path}/linear.pt')

In [None]:
# Function to extract hidden states
def build_db_batched(corpus_path, out_db, bert, toker, batch_size=8):
    dataset = BertSampleDataset(corpus_path, toker)
    loader = DataLoader(dataset, batch_size=batch_size,
                       num_workers=4, collate_fn=batch_features)
    
    with tqdm(desc='Computing BERT features', total=len(dataset)) as pbar:
        for ids, *batch in loader:
            outputs = process_batch(batch, bert, toker)
            for id_, output in zip(ids, outputs):
                if output is not None:
                    out_db[id_] = tensor_dumps(output)
            pbar.update(len(ids))

# Extract hidden states
db_path = "data/DEEN.db"
print("Extracting hidden states...")
with shelve.open(f'{bert_dump_path}/db', 'c') as out_db, torch.no_grad():
    build_db_batched(db_path, out_db, bert, tokenizer, batch_size=8)

print(f"Hidden states extracted and saved to {bert_dump_path}/db")
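The teacher can only produce a prediction for a target position that is masked in its input. One simple scheme guarantees a prediction for every target position:

1. Create several copies of the target, each masking a disjoint, evenly strided subset of positions.
2. Gather the hidden state at each masked position from the copy that masked it.

The sketch below shows that scheme. Whether `process_batch` uses exactly this striding is an assumption, and `stride` and `strided_mask_copies` are hypothetical names.

```python
from typing import List, Tuple


def strided_mask_copies(tgt: List[str], stride: int,
                        mask: str = "[MASK]") -> List[Tuple[List[str], List[int]]]:
    """Return (masked_target, masked_positions) pairs covering every position once.

    Copy r masks positions r, r + stride, r + 2*stride, ...
    """
    copies = []
    for r in range(min(stride, len(tgt))):
        positions = list(range(r, len(tgt), stride))
        masked = list(tgt)
        for p in positions:
            masked[p] = mask
        copies.append((masked, positions))
    return copies


for masked, positions in strided_mask_copies(["The", "house", "is", "old", "."], stride=2):
    print(positions, masked)
```

A larger stride leaves more target context visible in each pass but costs more forward passes per sentence.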

## Computing Top-K Logits

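Next, each stored hidden state is projected through the saved output layer (`linear.pt`). For every target position, only the k largest logits are kept, together with their vocabulary indices.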
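For one target position, this amounts to computing `logits = W h + b` with the saved layer and keeping the `k` largest entries along with their indices. The plain-Python sketch below spells this out for a single hidden vector, using a made-up 2-dimensional hidden size and a 4-word vocabulary. `project_topk` is a hypothetical helper. The notebook does the same thing in batch with `torch.nn.Linear` and `Tensor.topk`.

```python
import heapq
from typing import List, Tuple


def project_topk(hidden: List[float], weight: List[List[float]], bias: List[float],
                 k: int) -> List[Tuple[float, int]]:
    """Return the k largest (logit, vocab_index) pairs of weight @ hidden + bias.

    weight has one row per vocabulary entry, as in torch.nn.Linear.
    """
    logits = [sum(w * h for w, h in zip(row, hidden)) + b
              for row, b in zip(weight, bias)]
    return heapq.nlargest(k, ((z, i) for i, z in enumerate(logits)))


W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.0]]  # toy 4-word vocabulary
b = [0.0, 0.5, 0.0, 0.0]
print(project_topk([2.0, 1.0], W, b, k=2))  # [(3.0, 2), (2.0, 0)]
```

With k = 8, each target position stores 8 logits and 8 indices instead of a score for every vocabulary entry. This is what keeps the `topk` shelf small.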

In [None]:
# Import functions for top-k computation
from dump_teacher_topk import tensor_loads, dump_topk

# Top-K parameter
k = 8  # Following the paper

# Load the saved projection layer
linear = torch.load(f'{bert_dump_path}/linear.pt')
linear.to(device).eval()

# Compute top-k logits; no_grad because this is inference only and the
# results must be detached from the autograd graph before serialization
print("Computing top-k logits...")
with shelve.open(f'{bert_dump_path}/db', 'r') as db, \
     shelve.open(f'{bert_dump_path}/topk', 'c') as topk_db, \
     torch.no_grad():
    for key, value in tqdm(db.items(), total=len(db), desc='Computing top-k'):
        # Cast to float32 to match the layer, in case features were stored in half precision
        bert_hidden = torch.tensor(tensor_loads(value)).float().to(device)
        topk = linear(bert_hidden).topk(dim=-1, k=k)
        topk_db[key] = dump_topk(topk)

print(f"Top-k logits computed and saved to {bert_dump_path}/topk")

## Stage 3: Train Student Translation Model with Knowledge Distillation

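With the teacher's soft targets stored, we train a Transformer encoder-decoder student on the translation data. Its loss combines the usual cross-entropy on the reference tokens with a distillation loss toward BERT's top-k predictions.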

In [None]:
# Import required modules for training
from onmt.inputters.bert_kd_dataset import BertKdDataset, TokenBucketSampler
from onmt.utils.optimizers import Optimizer
from onmt.train_single import build_model_saver, build_trainer, cycle_loader

# Define paths
data_db = "data/DEEN.db"
bert_dump = "output/bert_dump"
data = "data/DEEN"
config_path = "opennmt/config/config-transformer-base-mt-deen.yml"
output_path = "output/kd-model"

# Load configuration
with open(config_path, 'r') as stream:
    config = yaml.safe_load(stream)

# Create args object
args = argparse.Namespace(**config)

# Setup KD parameters
args.train_from = None
args.max_grad_norm = None
args.kd_topk = 8
args.train_steps = 100000
args.kd_temperature = 10.0
args.kd_alpha = 0.5
args.warmup_steps = 8000
args.learning_rate = 2.0
args.bert_dump = bert_dump
args.data_db = data_db
args.bert_kd = True
args.data = data

args.save_model = os.path.join(output_path, 'ckpt', 'model')
args.log_file = os.path.join(output_path, 'log', 'log')
args.tensorboard_log_dir = os.path.join(output_path, 'log')
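The `kd_topk`, `kd_temperature`, and `kd_alpha` settings control how the teacher's soft targets enter the student's loss. The sketch below shows one plausible formulation for a single target position, under three assumptions:

1. The teacher distribution is a temperature-scaled softmax, renormalized over its top-k entries.
2. The student's temperature-scaled log-probabilities are read off at the same vocabulary indices.
3. The result is mixed with the cross-entropy on the gold token as `alpha * kd + (1 - alpha) * nll`.

The helper names are hypothetical. The exact scaling inside `onmt`'s KD loss (for example, any extra factor of T²) is not shown in this notebook.

```python
import math
from typing import Dict, List


def log_softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable log-softmax of logits / temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    log_norm = m + math.log(sum(math.exp(s - m) for s in scaled))
    return [s - log_norm for s in scaled]


def kd_position_loss(student_logits: List[float], teacher_topk: Dict[int, float],
                     gold: int, temperature: float, alpha: float) -> float:
    """alpha * soft cross-entropy(teacher top-k, student) + (1 - alpha) * gold NLL."""
    idx = list(teacher_topk)
    # Teacher: softmax over its top-k logits only (renormalized)
    t_prob = [math.exp(lp) for lp in log_softmax([teacher_topk[i] for i in idx], temperature)]
    # Student: full-vocabulary log-probabilities, looked up at the teacher's indices
    s_logp_t = log_softmax(student_logits, temperature)
    kd = -sum(p * s_logp_t[i] for p, i in zip(t_prob, idx))
    # Standard cross-entropy on the reference token at temperature 1
    nll = -log_softmax(student_logits)[gold]
    return alpha * kd + (1 - alpha) * nll


student = [2.0, 0.5, 1.0, -1.0]
teacher = {0: 5.0, 2: 4.0}  # top-2 teacher logits keyed by vocabulary index
print(kd_position_loss(student, teacher, gold=0, temperature=10.0, alpha=0.5))
```

A high temperature such as 10.0 flattens the teacher distribution, so lower-ranked but plausible tokens in the top-k still pass useful signal to the student.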

In [None]:
# Load vocabulary and dataset
vocab = torch.load(data + '.vocab.pt')
src_vocab = vocab['src'].fields[0][1].vocab.stoi
tgt_vocab = vocab['tgt'].fields[0][1].vocab.stoi

# Create dataset
train_dataset = BertKdDataset(data_db, bert_dump, 
                             src_vocab, tgt_vocab,
                             max_len=150, k=args.kd_topk)

# Create data loader
BUCKET_SIZE = 8192
train_sampler = TokenBucketSampler(
    train_dataset.keys, BUCKET_SIZE, 6144,
    batch_multiple=1)

train_loader = DataLoader(train_dataset, batch_sampler=train_sampler,
                         num_workers=4,
                         collate_fn=BertKdDataset.pad_collate)

train_iter = cycle_loader(train_loader, device)

In [None]:
# Build the model (OpenNMT's build_model takes the vocabulary fields as its third argument)
from onmt.model_builder import build_model
model = build_model(args, args, vocab, checkpoint=None)
model.to(device)

# Build optimizer
optim = Optimizer.from_opt(model, args, checkpoint=None)

# Build model saver
model_saver = build_model_saver(args, args, model, vocab, optim)

# Build trainer
trainer = build_trainer(args, 0, model, vocab, optim, model_saver=model_saver)

# Train - for demonstration, we'll only do a few steps
num_steps_to_run_kd = 500  # Adjust for full training (paper used 100k steps)

print("Starting model training with knowledge distillation...")
trainer.train(
    train_iter,
    num_steps_to_run_kd,
    valid_iter=None
)

print(f"Model trained for {num_steps_to_run_kd} steps and saved to {output_path}/ckpt")
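With `learning_rate = 2.0` and `warmup_steps = 8000`, the effective step size depends on the decay method selected in the YAML config, which is not reproduced here. Assume the config selects OpenNMT's `noam` schedule with a model dimension of 512; both are assumptions about `config-transformer-base-mt-deen.yml`. The rate then follows `lr = learning_rate * d_model**-0.5 * min(step**-0.5, step * warmup**-1.5)`, which peaks at `step == warmup_steps`. The sketch below evaluates that formula with a hypothetical `noam_lr` helper.

```python
def noam_lr(step: int, base_lr: float = 2.0, warmup_steps: int = 8000,
            d_model: int = 512) -> float:
    """Noam schedule: linear warmup, then inverse-square-root decay."""
    step = max(step, 1)  # avoid 0 ** -0.5
    return base_lr * d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


peak = noam_lr(8000)
for step in (500, 8000, 32000, 100000):
    print(step, f"{noam_lr(step):.2e}", f"{noam_lr(step) / peak:.3f} of peak")
```

Under these assumptions, the 500-step demonstration run only reaches about 6% of the peak learning rate.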

## Translation and Evaluation

Finally, we translate the test set with the trained student using beam search. We then detokenize the output and score it with `multi-bleu.perl` against the tokenized references.
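`multi-bleu.perl` reports corpus-level BLEU. Clipped n-gram precisions for n = 1 to 4 are pooled over all sentences and combined by geometric mean. The result is multiplied by a brevity penalty when the output is shorter than the references. The sketch below is a from-scratch Python version for a single reference per sentence, meant only to show what the score measures. `corpus_bleu` is a hypothetical helper, not a drop-in replacement for the Perl script, which among other details supports multiple references.

```python
import math
from collections import Counter
from typing import List, Sequence


def ngrams(tokens: Sequence[str], n: int) -> Counter:
    """Count the n-grams of a token sequence."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def corpus_bleu(hyps: List[str], refs: List[str], max_n: int = 4) -> float:
    """Corpus BLEU (0-100) for whitespace-tokenized hypotheses, one reference each."""
    matches = [0] * max_n
    totals = [0] * max_n
    hyp_len = ref_len = 0
    for hyp, ref in zip(hyps, refs):
        h, r = hyp.split(), ref.split()
        hyp_len += len(h)
        ref_len += len(r)
        for n in range(1, max_n + 1):
            h_counts, r_counts = ngrams(h, n), ngrams(r, n)
            # Clip each hypothesis n-gram count by its count in the reference
            matches[n - 1] += sum(min(c, r_counts[g]) for g, c in h_counts.items())
            totals[n - 1] += max(len(h) - n + 1, 0)
    if min(matches) == 0:
        return 0.0  # any zero precision makes the geometric mean zero
    log_prec = sum(math.log(m / t) for m, t in zip(matches, totals)) / max_n
    # Brevity penalty: exp(1 - r/c) when the output is not longer than the references
    bp = 1.0 if hyp_len > ref_len else math.exp(1 - ref_len / hyp_len)
    return 100 * bp * math.exp(log_prec)


print(f"{corpus_bleu(['the cat sat on the'], ['the cat sat on the mat']):.2f}")
```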

In [None]:
# Define paths for translation
model_path = f"{output_path}/ckpt/model_step_{num_steps_to_run_kd}.pt"
src_file = f"{data_dir}/test.de.bert"
tgt_file = f"{data_dir}/test.en.bert"
out_dir = "output/translation"
ref_file = f"{data_dir}/test.en"

# Run translation if model exists
if os.path.exists(model_path):
    # Run translation
    !python opennmt/translate.py -model {model_path} \
                                -src {src_file} \
                                -tgt {tgt_file} \
                                -output {out_dir}/result.en \
                                -beam_size 5 -alpha 0.6 \
                                -length_penalty wu

    # Detokenize output
    !python scripts/bert_detokenize.py --file {out_dir}/result.en \
                                      --output_dir {out_dir}

    # Evaluate with BLEU
    !perl opennmt/tools/multi-bleu.perl {ref_file} \
                                       < {out_dir}/result.en.detok \
                                       > {out_dir}/result.bleu

    # Display BLEU score
    with open(f"{out_dir}/result.bleu", "r") as f:
        print(f.read())
else:
    print(f"Model file {model_path} not found. Skipping translation.")
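The `-length_penalty wu -alpha 0.6` flags rescore finished beam hypotheses with the GNMT length penalty of Wu et al., `lp(Y) = ((5 + |Y|) / 6) ** alpha`. Each hypothesis's summed log-probability is divided by `lp`, so longer outputs are not unduly penalized. The sketch below shows the effect on two hypotheses whose scores are invented for illustration. `wu_length_penalty` and `rerank` are hypothetical helpers.

```python
from typing import List, Tuple


def wu_length_penalty(length: int, alpha: float = 0.6) -> float:
    """GNMT length penalty: ((5 + |Y|) / 6) ** alpha."""
    return ((5 + length) / 6.0) ** alpha


def rerank(hyps: List[Tuple[str, float]], alpha: float = 0.6) -> List[Tuple[str, float]]:
    """Sort (hypothesis, summed log-prob) pairs by length-normalized score."""
    scored = [(h, lp / wu_length_penalty(len(h.split()), alpha)) for h, lp in hyps]
    return sorted(scored, key=lambda x: x[1], reverse=True)


# Invented scores: the longer hypothesis has the lower raw log-probability
hyps = [("the house", -2.0), ("the house is very old", -2.4)]
for hyp, score in rerank(hyps):
    print(f"{score:.3f}  {hyp}")
```

With `alpha = 0.6` the longer hypothesis wins after normalization. With `alpha = 0` (no penalty), the shorter one keeps the higher score.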

## Visualize Training Results

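The cell below displays the reference figures stored under `figures/` (CMLM finetuning, translation losses, and translation accuracy) for comparison. It does not plot the logs of the short runs above.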

In [None]:
# Install matplotlib if needed
!pip install matplotlib
import matplotlib.pyplot as plt

# Display the figures from the paper
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Plot CMLM finetuning
axes[0].set_title('CMLM Finetuning')
img = plt.imread('figures/cmlm-finetuning.png')
axes[0].imshow(img)
axes[0].axis('off')

# Plot translation losses
axes[1].set_title('Translation Losses')
img = plt.imread('figures/translation-losses.png')
axes[1].imshow(img)
axes[1].axis('off')

# Plot translation accuracy
axes[2].set_title('Translation Accuracy')
img = plt.imread('figures/translation-accuracy.png')
axes[2].imshow(img)
axes[2].axis('off')

plt.tight_layout()
plt.show()

## Conclusion

In this notebook, we've implemented the three-stage knowledge distillation process described in the paper "Distilling Knowledge Learned in BERT for Text Generation":

1. Fine-tuned a BERT model as a Conditional Masked Language Model (CMLM)
2. Extracted knowledge from the BERT teacher model and computed top-k logits
3. Trained a student translation model with knowledge distillation from BERT

To approach the paper's results, both stages need their full training schedules rather than the shortened demonstration runs used here:
- CMLM finetuning: 100,000 steps
- Student model training: 100,000 steps

The paper reports that distilling from the fine-tuned BERT teacher improves translation quality over a Transformer baseline trained without distillation. Runs as short as the ones above should not be expected to reproduce that gap.