### Step 3: Distill knowledge from teacher into pruned students
In this step, we distill the depth-pruned and width-pruned models using knowledge distillation. For usage details, refer to the [distillation docs](https://docs.nvidia.com/nemo-framework/user-guide/latest/model-optimization/distillation/distillation.html).

Let's first define the parameters common to distilling both the depth-pruned and width-pruned models.

> `NOTE:` This notebook uses the `wikitext` dataset because it is the easiest to get started with. In practice, we recommend bigger, more recent, and much higher-quality datasets like [ClimbMix](https://huggingface.co/datasets/OptimalScale/ClimbMix) or [Nemotron-Pretraining-SFT-v1](https://huggingface.co/datasets/nvidia/Nemotron-Pretraining-SFT-v1). The WikiText dataset has only ~125M tokens, whereas we recommend distilling the pruned model for ~50-100B tokens. Generally, the larger the dataset, the better the pruned model will perform, and the more aggressive the pruning, the more tokens are needed.
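
To get a feel for what the recommended token budgets mean in training steps, here is a minimal sketch that applies the same tokens-per-step arithmetic as the parameter cell below (`SEQ_LENGTH * GLOBAL_BATCH_SIZE` tokens per step). The helper name `steps_for_token_budget` is hypothetical and only serves this illustration; the defaults mirror the values used in this notebook.

```python
from math import ceil


def steps_for_token_budget(num_tokens: int, seq_length: int = 4096, global_batch_size: int = 768) -> int:
    """Return the number of optimizer steps needed to consume `num_tokens` once."""
    # Each step consumes one global batch of full-length sequences.
    tokens_per_step = seq_length * global_batch_size
    return ceil(num_tokens / tokens_per_step)


# Compare the WikiText budget with the recommended 50-100B token range.
for label, budget in [("WikiText (~125M)", int(125e6)), ("50B", int(50e9)), ("100B", int(100e9))]:
    print(f"{label:>18}: {steps_for_token_budget(budget):>6} steps")
```

With these defaults, the WikiText budget works out to 40 steps, while a 50B-token run needs 15,895 steps, so expect the run time to grow accordingly when you switch to a larger dataset.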

In [None]:
from math import ceil


NEMO_ROOT = "/opt/NeMo"
ROOT_DIR = "/workspace"
TEACHER_MODEL_PATH = f"{ROOT_DIR}/Qwen3-8B-nemo"

##### Set data paths
# NOTE: If you have multiple partitioned datasets, you can pass in a space-separated list of paths below.
DATA_PATH = f"{ROOT_DIR}/wikitext-data"
DATA_PATHS = f"{DATA_PATH}/wikitext-train_text_document"
INDEX_MAPPING_DIR = f"{DATA_PATH}/index_mappings"
# NOTE: Update this to the number of tokens in your dataset
NUM_TOKENS = int(125e6)
NUM_VAL_TOKENS = int(NUM_TOKENS * 0.01)

##### Set Training Parameters
# NOTE: Use 4096 or 8192 Seq Len depending on whether your dataset texts are short or long
SEQ_LENGTH = 4096
# NOTE: GBS 768 and LR decaying from 1e-4 to 1e-5 generally work well, so don't change them unless you know what you are doing
GLOBAL_BATCH_SIZE = 768
LR = 1e-4
MIN_LR = 1e-5

MAX_STEPS = ceil(NUM_TOKENS / (SEQ_LENGTH * GLOBAL_BATCH_SIZE))
WARMUP_STEPS = min(100, ceil(MAX_STEPS / 10))
LOG_INTERVAL = min(100, ceil(MAX_STEPS / 10))
VAL_CHECK_INTERVAL = min(100, ceil(MAX_STEPS / 10))
LIMIT_VAL_BATCHES = min(32, ceil(NUM_VAL_TOKENS / (SEQ_LENGTH * GLOBAL_BATCH_SIZE)))

# Change these to accommodate your resources
DEVICES = 8
NODES = 1
TENSOR_PARALLEL_SIZE = DEVICES
PIPELINE_PARALLEL_SIZE = 1
# NOTE: Use as large a micro batch size as your GPU memory allows for better utilization
MICRO_BATCH_SIZE = 8


print("Training parameters:")
for k, v in list(locals().items()):
    if not k.startswith('_') and k.upper() == k:
        print("\t", k, v)
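
Before launching a run, it can help to confirm that the batch and parallelism settings are mutually consistent. The sketch below assumes the usual Megatron-style relationship: the data-parallel size is the total GPU count divided by the product of tensor and pipeline parallel sizes, and the global batch size must be a multiple of micro batch size times data-parallel size. The helper `check_batch_config` is hypothetical and not part of NeMo; it ignores other parallelism modes such as context parallelism.

```python
def check_batch_config(devices, nodes, tp_size, pp_size, global_batch_size, micro_batch_size):
    """Return (data_parallel_size, grad_accumulation_steps) or raise ValueError."""
    world_size = devices * nodes
    model_parallel_size = tp_size * pp_size
    if world_size % model_parallel_size != 0:
        raise ValueError(
            f"{world_size} GPUs cannot be split evenly into model-parallel groups of {model_parallel_size}"
        )
    data_parallel_size = world_size // model_parallel_size

    # Samples processed across all data-parallel replicas in one forward/backward pass.
    samples_per_pass = micro_batch_size * data_parallel_size
    if global_batch_size % samples_per_pass != 0:
        raise ValueError(
            f"global batch size {global_batch_size} is not a multiple of "
            f"micro batch size x data-parallel size ({samples_per_pass})"
        )
    return data_parallel_size, global_batch_size // samples_per_pass


# Values from the parameter cell above: 8 GPUs on 1 node, TP=8, PP=1, GBS=768, MBS=8.
dp_size, accumulation_steps = check_batch_config(8, 1, 8, 1, 768, 8)
print(f"data-parallel size: {dp_size}, gradient accumulation steps: {accumulation_steps}")
```

With the notebook's defaults, all 8 GPUs form a single tensor-parallel group, so the data-parallel size is 1 and each optimizer step accumulates 96 micro batches.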

#### Step 3a: Distilling depth-pruned student
When distilling knowledge from the teacher into the depth-pruned model, the student model path (`STUDENT_MODEL_PATH`) is `<ROOT_DIR>/Qwen3-8B-nemo-depth-pruned`, as produced by the depth-pruning step in the [pruning](./02_pruning.ipynb) notebook.

In [None]:
STUDENT_MODEL_PATH = f"{ROOT_DIR}/Qwen3-8B-nemo-depth-pruned"
LOG_DIR = ROOT_DIR
EXP_NAME = "Qwen3-8B-nemo-depth-pruned-distill"

!torchrun --nproc_per_node "{DEVICES}" "{NEMO_ROOT}/scripts/llm/gpt_train.py" \
    --name "{EXP_NAME}" \
    --devices "{DEVICES}" \
    --num_nodes "{NODES}" \
    --tp_size "{TENSOR_PARALLEL_SIZE}" \
    --pp_size "{PIPELINE_PARALLEL_SIZE}" \
    --model_path "{STUDENT_MODEL_PATH}" \
    --teacher_path "{TEACHER_MODEL_PATH}" \
    --legacy_ckpt \
    --max_steps "{MAX_STEPS}" \
    --warmup_steps "{WARMUP_STEPS}" \
    --gbs "{GLOBAL_BATCH_SIZE}" \
    --mbs "{MICRO_BATCH_SIZE}" \
    --lr "{LR}" \
    --min_lr "{MIN_LR}" \
    --seq_length "{SEQ_LENGTH}" \
    --log_dir "{LOG_DIR}" \
    --log_interval "{LOG_INTERVAL}" \
    --val_check_interval "{VAL_CHECK_INTERVAL}" \
    --limit_val_batches "{LIMIT_VAL_BATCHES}" \
    --data_paths "{DATA_PATHS}" \
    --index_mapping_dir "{INDEX_MAPPING_DIR}"

This will create the final distilled model at something like `<ROOT_DIR>/Qwen3-8B-nemo-depth-distilled/checkpoints/{model_name}--{val_loss:.2f}-{step}-{consumed_samples}`. The exact path depends on your distillation run. For simplicity in the next steps, we can rename it to `<ROOT_DIR>/Qwen3-8B-nemo-depth-distilled/checkpoints/best`.
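
The rename suggested above can be done by hand or scripted. The following is a minimal sketch using only the standard library; the helper `rename_checkpoint` is hypothetical, and it assumes that the most recently modified subdirectory under `checkpoints/` is the one you want. If your run kept several checkpoints, check the validation loss in the directory names and pick the checkpoint yourself instead.

```python
from pathlib import Path


def rename_checkpoint(checkpoints_dir, new_name="best"):
    """Rename the most recently modified checkpoint directory to `new_name`."""
    checkpoints_dir = Path(checkpoints_dir)
    target = checkpoints_dir / new_name
    if target.exists():
        raise FileExistsError(f"{target} already exists; refusing to overwrite it")

    # Assumption: each checkpoint is stored as its own subdirectory.
    candidates = [p for p in checkpoints_dir.iterdir() if p.is_dir()]
    if not candidates:
        raise FileNotFoundError(f"no checkpoint directories found in {checkpoints_dir}")

    newest = max(candidates, key=lambda p: p.stat().st_mtime)
    return newest.rename(target)


# Example usage (adjust the path to match your run):
# rename_checkpoint("/workspace/Qwen3-8B-nemo-depth-distilled/checkpoints")
```

The same helper works for the width-pruned run by pointing it at that run's `checkpoints` directory.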

> `NOTE:` This script takes about 1 hour on 8x H100 GPUs to generate the final distilled model.

#### Step 3b: Distilling width-pruned student
When distilling knowledge from the teacher into the width-pruned model, the student model path (`STUDENT_MODEL_PATH`) is `<ROOT_DIR>/Qwen3-8B-nemo-width-pruned`, as produced by the width-pruning step in the [pruning](./02_pruning.ipynb) notebook.

In [None]:
STUDENT_MODEL_PATH = f"{ROOT_DIR}/Qwen3-8B-nemo-width-pruned"
LOG_DIR = ROOT_DIR
EXP_NAME = "Qwen3-8B-nemo-width-pruned-distill"

!torchrun --nproc_per_node "{DEVICES}" "{NEMO_ROOT}/scripts/llm/gpt_train.py" \
    --name "{EXP_NAME}" \
    --devices "{DEVICES}" \
    --num_nodes "{NODES}" \
    --tp_size "{TENSOR_PARALLEL_SIZE}" \
    --pp_size "{PIPELINE_PARALLEL_SIZE}" \
    --model_path "{STUDENT_MODEL_PATH}" \
    --teacher_path "{TEACHER_MODEL_PATH}" \
    --legacy_ckpt \
    --max_steps "{MAX_STEPS}" \
    --warmup_steps "{WARMUP_STEPS}" \
    --gbs "{GLOBAL_BATCH_SIZE}" \
    --mbs "{MICRO_BATCH_SIZE}" \
    --lr "{LR}" \
    --min_lr "{MIN_LR}" \
    --seq_length "{SEQ_LENGTH}" \
    --log_dir "{LOG_DIR}" \
    --log_interval "{LOG_INTERVAL}" \
    --val_check_interval "{VAL_CHECK_INTERVAL}" \
    --limit_val_batches "{LIMIT_VAL_BATCHES}" \
    --data_paths "{DATA_PATHS}" \
    --index_mapping_dir "{INDEX_MAPPING_DIR}"

This will create the final distilled model at something like `<ROOT_DIR>/Qwen3-8B-nemo-width-distilled/checkpoints/{model_name}--{val_loss:.2f}-{step}-{consumed_samples}`. The exact path depends on your distillation run. For simplicity in the next steps, we can rename it to `<ROOT_DIR>/Qwen3-8B-nemo-width-distilled/checkpoints/best`.

> `NOTE:` This script takes about 1 hour on 8x H100 GPUs to generate the final distilled model.


Check out the next notebook to compare the depth-pruned and width-pruned models.