# Flan-T5 Model: Distillation, Pruning, and Quantization

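This notebook demonstrates how to distill, prune, and quantize a Flan-T5 model. `flan-t5-base` is distilled into `flan-t5-small` on 1,000 English-to-German pairs from the WMT16 dataset, the distilled model is magnitude-pruned, and the pruned model is converted to a quantized TensorFlow Lite model.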

In [1]:
!pip install transformers datasets tensorflow tf-keras


## Import Libraries and Load Dataset

In [2]:
import tensorflow as tf
from transformers import TFAutoModelForSeq2SeqLM, AutoTokenizer
from datasets import load_dataset
import numpy as np
import os

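# Load the WMT16 German-English dataset and keep the first 1,000 training pairs
dataset = load_dataset('wmt16', 'de-en')
train_dataset = dataset['train'].select(range(1000))

# Load teacher (base) and student (small) models. Both Flan-T5 checkpoints use the
# same 32,128-entry vocabulary, so one tokenizer serves both and the teacher's
# logits have the same vocabulary dimension as the student's.
teacher_model = TFAutoModelForSeq2SeqLM.from_pretrained('google/flan-t5-base')
student_model = TFAutoModelForSeq2SeqLM.from_pretrained('google/flan-t5-small')
tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-base')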



All PyTorch model weights were used when initializing TFT5ForConditionalGeneration.

All the weights of TFT5ForConditionalGeneration were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.


All PyTorch model weights were used when initializing TFT5ForConditionalGeneration.

All the weights of TFT5ForConditionalGeneration were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.


## Tokenization

In [3]:

# Tokenize the dataset: English sources -> German targets
def preprocess_function(examples):
    inputs = [f"Translate from English to German: {ex['en']}" for ex in examples['translation']]
    targets = [ex['de'] for ex in examples['translation']]
    model_inputs = tokenizer(inputs, max_length=128, truncation=True, padding='max_length')
    labels = tokenizer(targets, max_length=128, truncation=True, padding='max_length')
    model_inputs["labels"] = labels["input_ids"]
    # The decoder must see the targets shifted right by one position; T5 starts
    # decoding from the pad token. Feeding the unshifted labels would let the model
    # read the very token it is supposed to predict.
    model_inputs["decoder_input_ids"] = [
        [tokenizer.pad_token_id] + ids[:-1] for ids in labels["input_ids"]
    ]
    return model_inputs

train_dataset = train_dataset.map(preprocess_function, batched=True)
train_dataset.set_format('tensorflow', columns=['input_ids', 'attention_mask', 'labels', 'decoder_input_ids'])

# Convert datasets to tf.data.Dataset
def dataset_to_tfdata(dataset):
    def generator():
        for example in dataset:
            yield {key: example[key].numpy() for key in example.keys()}

    return tf.data.Dataset.from_generator(generator, output_signature={
        'input_ids': tf.TensorSpec(shape=(128,), dtype=tf.int32),
        'attention_mask': tf.TensorSpec(shape=(128,), dtype=tf.int32),
        'labels': tf.TensorSpec(shape=(128,), dtype=tf.int32),
        'decoder_input_ids': tf.TensorSpec(shape=(128,), dtype=tf.int32)
    })

train_tf_dataset = dataset_to_tfdata(train_dataset).batch(8)


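T5 is trained with teacher forcing: at each step the decoder receives the previous target token, so `decoder_input_ids` must be the labels shifted one position to the right, starting from the decoder start token (the pad token, id 0, for T5). The plain-Python sketch below illustrates that alignment; `shift_right` is a hypothetical helper and the token ids are made up.

```python
def shift_right(label_ids, decoder_start_token_id=0):
    """Return teacher-forcing decoder inputs: prepend the start token, drop the last label."""
    return [decoder_start_token_id] + label_ids[:-1]

# Hypothetical target token ids for one sentence (the last position is padding)
labels = [100, 200, 300, 1, 0]
decoder_inputs = shift_right(labels)
print(decoder_inputs)  # [0, 100, 200, 300, 1]

# At step t the decoder reads decoder_inputs[t] and must predict labels[t],
# so it never sees the token it is asked to predict.
for t, (seen, target) in enumerate(zip(decoder_inputs, labels)):
    print(f"step {t}: input {seen} -> predict {target}")
```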

## Calculating Teacher Logits

In [4]:

# Calculate teacher logits and save them to disk in chunks.
# Each example's logits have shape (128, 32128) in float32 (about 16 MB), so
# 1,000 examples need roughly 16 GB once they are concatenated again below.
logits = []
chunk_size = 125  # examples per chunk; adjust to your memory constraints
chunk_counter = 0
os.makedirs('teacher_logits_chunks', exist_ok=True)

for batch in train_tf_dataset:
    logit = teacher_model(batch, training=False).logits
    logits.append(logit)

    if len(logits) * logit.shape[0] >= chunk_size:
        # Zero-pad the chunk index so lexicographic sorting matches numeric order
        np.save(f'teacher_logits_chunks/teacher_logits_chunk_{chunk_counter:04d}.npy', tf.concat(logits, axis=0).numpy())
        logits = []
        chunk_counter += 1

# Save remaining logits
if logits:
    np.save(f'teacher_logits_chunks/teacher_logits_chunk_{chunk_counter:04d}.npy', tf.concat(logits, axis=0).numpy())

# Load saved logits for training, in the order they were written
def load_teacher_logits():
    logits = []
    for chunk_file in sorted(os.listdir('teacher_logits_chunks')):
        chunk_logits = np.load(os.path.join('teacher_logits_chunks', chunk_file))
        logits.append(chunk_logits)
    return tf.convert_to_tensor(np.concatenate(logits, axis=0))

teacher_logits = load_teacher_logits()
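`load_teacher_logits` orders the chunk files with `sorted(os.listdir(...))`, which compares names as strings. The sketch below (the file names are hypothetical) shows why an unpadded chunk index breaks that order once there are ten or more chunks, and two ways around it: zero-padding the index, or sorting on the parsed number.

```python
import re

# Hypothetical file names produced by an unpadded f-string such as f'chunk_{i}.npy'
unpadded = [f"teacher_logits_chunk_{i}.npy" for i in range(12)]
print(sorted(unpadded)[:4])
# ['teacher_logits_chunk_0.npy', 'teacher_logits_chunk_1.npy',
#  'teacher_logits_chunk_10.npy', 'teacher_logits_chunk_11.npy']

# Option 1: zero-pad the index so string order equals numeric order
padded = [f"teacher_logits_chunk_{i:04d}.npy" for i in range(12)]
print(sorted(padded) == padded)  # True

# Option 2: sort on the integer parsed from the file name
def chunk_index(name):
    return int(re.search(r"_(\d+)\.npy$", name).group(1))

print(sorted(unpadded, key=chunk_index) == unpadded)  # True
```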


## Creating the Final Dataset

In [5]:

# Combine teacher logits with the training data.
# The attention mask is kept so the student does not attend to padding tokens.
def add_teacher_logits_to_dataset(dataset, teacher_logits):
    new_dataset = []
    for i, example in enumerate(dataset):
        input_ids = example['input_ids']
        attention_mask = example['attention_mask']
        decoder_input_ids = example['decoder_input_ids']
        logit = teacher_logits[i]
        new_dataset.append((input_ids, attention_mask, decoder_input_ids, logit))
    return new_dataset

train_dataset_final = add_teacher_logits_to_dataset(train_dataset, teacher_logits)

# Convert to a tf.data.Dataset of (model inputs, teacher logits) pairs
def create_tf_dataset(dataset):
    inputs = {'input_ids': [], 'attention_mask': [], 'decoder_input_ids': []}
    logits = []
    for input_ids, attention_mask, decoder_input_ids, logit in dataset:
        inputs['input_ids'].append(input_ids)
        inputs['attention_mask'].append(attention_mask)
        inputs['decoder_input_ids'].append(decoder_input_ids)
        logits.append(logit)

    inputs = {key: tf.convert_to_tensor(value) for key, value in inputs.items()}
    logits = tf.convert_to_tensor(logits)
    return tf.data.Dataset.from_tensor_slices((inputs, logits))

train_tf_dataset = create_tf_dataset(train_dataset_final).batch(8)
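`create_tf_dataset` holds every teacher logit in memory at once. A back-of-the-envelope estimate, assuming the `max_length=128` used above and Flan-T5's 32,128-entry vocabulary (the `vocab_size` in its config), shows how large that tensor is and what storing it in float16 would save. The helper `logits_nbytes` is hypothetical.

```python
def logits_nbytes(num_examples, seq_len, vocab_size, bytes_per_value=4):
    """Bytes needed for dense logits of shape (num_examples, seq_len, vocab_size)."""
    return num_examples * seq_len * vocab_size * bytes_per_value

# 1,000 examples, 128 tokens each, 32,128 vocabulary entries per position
fp32 = logits_nbytes(1000, 128, 32128)
fp16 = logits_nbytes(1000, 128, 32128, bytes_per_value=2)
print(f"float32: {fp32 / 1e9:.1f} GB, float16: {fp16 / 1e9:.1f} GB")
# float32: 16.4 GB, float16: 8.2 GB
```

Halving the precision, keeping only the top-k teacher logits per position, or streaming chunks into training instead of concatenating them are the usual ways to shrink this footprint.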


## Training the Student Model

In [6]:

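# Distillation loss: KL divergence between the temperature-softened teacher and
# student distributions. Keras' KLDivergence expects probabilities, not raw logits,
# so both are passed through a softmax first. Padding positions are not masked here.
TEMPERATURE = 2.0  # a common choice for soft targets; not tuned in this notebook

def custom_loss(y_true, y_pred):
    # y_true: teacher logits, y_pred: student logits, both (batch, seq_len, vocab)
    teacher_probs = tf.nn.softmax(y_true / TEMPERATURE, axis=-1)
    student_probs = tf.nn.softmax(y_pred / TEMPERATURE, axis=-1)
    # Scale by T^2 so gradient magnitudes stay comparable across temperatures
    return tf.keras.losses.KLDivergence()(teacher_probs, student_probs) * TEMPERATURE ** 2

# Compile the model; a learning rate of 1e-2 is far too high for Adam on a pretrained transformer
student_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4), loss=custom_loss)

# Train the student model
student_model.fit(train_tf_dataset, epochs=3)

# Save the distilled model
student_model.save_pretrained('./distilled_model')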


Epoch 1/3
Epoch 2/3
Epoch 3/3
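Logit distillation compares probability distributions, not raw scores: Keras' `KLDivergence` clips its inputs to the range [epsilon, 1], so logits must pass through a softmax before reaching it. The NumPy sketch below shows the usual soft-target objective, KL(teacher || student) on temperature-softened distributions scaled by T², averaged over all positions. The temperature of 2.0 and the random "logits" are assumptions for illustration only.

```python
import numpy as np

def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def distillation_kl(teacher_logits, student_logits, temperature=2.0):
    """Mean KL(teacher || student) over positions on softened distributions, scaled by T^2."""
    p = softmax(teacher_logits / temperature)
    q = softmax(student_logits / temperature)
    kl = np.sum(p * (np.log(p) - np.log(q)), axis=-1)  # one value per (batch, position)
    return float(kl.mean() * temperature ** 2)

rng = np.random.default_rng(0)
teacher = rng.normal(size=(2, 4, 10))  # (batch, seq_len, vocab) with made-up values
student = rng.normal(size=(2, 4, 10))

print(distillation_kl(teacher, teacher))      # 0.0: identical distributions
print(distillation_kl(teacher, student) > 0)  # True
```

A higher temperature flattens both distributions, so the student also learns from the relative scores the teacher assigns to wrong tokens rather than only its top prediction.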


In [7]:
import gc

tf.keras.backend.clear_session()
tf.compat.v1.reset_default_graph()
del train_tf_dataset
del train_dataset_final
del teacher_logits
del teacher_model
del student_model
del logits
del dataset
gc.collect()


4405

## Prune the Model

In [8]:
import numpy as np
import tf_keras

def prune_layer(layer, pruning_factor=0.2):
    """Magnitude-prune a Dense layer in place: zero the smallest `pruning_factor` fraction of its kernel."""
    if isinstance(layer, tf_keras.layers.Dense):
        weights = layer.get_weights()
        kernel = weights[0]
        abs_kernel = np.abs(kernel)
        threshold = np.percentile(abs_kernel, pruning_factor * 100)
        mask = (abs_kernel >= threshold).astype(kernel.dtype)
        # Keep any bias unchanged (T5's Dense layers are built without one)
        layer.set_weights([kernel * mask] + weights[1:])
    return layer

def prune_model(model, pruning_factor=0.2):
    # Encoder Dense layers are pruned at twice the base rate (keep pruning_factor <= 0.5),
    # decoder layers at the base rate, and the final layer at half the base rate.
    # _flatten_layers() is a private Keras API that yields every nested sublayer.
    for layer in model.encoder._flatten_layers():
        prune_layer(layer, pruning_factor * 2)
    for layer in model.decoder._flatten_layers():
        prune_layer(layer, pruning_factor)
    # Expected to be the LM head, which Flan-T5 does not tie to the input embeddings
    prune_layer(model.layers[-1], pruning_factor / 2)
    return model

def get_zero_and_nonzero_params(model):
    total_zero_params = 0
    total_nonzero_params = 0
    for weight in model.get_weights():
        nonzero = np.count_nonzero(weight)
        total_zero_params += weight.size - nonzero
        total_nonzero_params += nonzero
    return total_zero_params, total_nonzero_params
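The thresholding in `prune_layer` can be checked without loading a model. This NumPy sketch applies the same percentile-based magnitude pruning to a random stand-in kernel at the three rates `prune_model` uses with `pruning_factor=0.3` (0.15 for the final layer, 0.3 for the decoder, 0.6 for the encoder). The helper `magnitude_prune` and the 64×64 matrix are hypothetical.

```python
import numpy as np

def magnitude_prune(kernel, pruning_factor):
    """Zero out roughly the smallest `pruning_factor` fraction of entries by absolute value."""
    threshold = np.percentile(np.abs(kernel), pruning_factor * 100)
    mask = np.abs(kernel) >= threshold
    return kernel * mask

rng = np.random.default_rng(42)
kernel = rng.normal(size=(64, 64)).astype(np.float32)  # stand-in for one Dense kernel

for factor in (0.15, 0.3, 0.6):
    pruned = magnitude_prune(kernel, factor)
    sparsity = 1 - np.count_nonzero(pruned) / pruned.size
    print(f"pruning_factor={factor}: sparsity={sparsity:.3f}")
```

Because the threshold is a percentile of the layer's own weights, each Dense kernel ends up with close to the requested fraction of zeros regardless of its scale.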

In [9]:
import gc
from transformers import TFAutoModelForSeq2SeqLM

model = TFAutoModelForSeq2SeqLM.from_pretrained('./distilled_model')
# With pruning_factor=0.3: encoder kernels 60%, decoder kernels 30%, final layer 15%
pruned_model = prune_model(model, pruning_factor=0.3)
zero_params_pruned, total_nonzero_params_pruned = get_zero_and_nonzero_params(pruned_model)
print(f"Total number of zero parameters after pruning: {zero_params_pruned}, and nonzero parameters after pruning: {total_nonzero_params_pruned}")
# The zeros are still stored as dense float32 values, so the saved checkpoint is not smaller
pruned_model.save_pretrained('./pruned_model')
del model
del pruned_model
gc.collect()

All model checkpoint layers were used when initializing TFT5ForConditionalGeneration.

All the layers of TFT5ForConditionalGeneration were initialized from the model checkpoint at ./distilled_model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.


Total number of zero parameters after pruning: 21341856, and nonzero parameters after pruning: 55619296


189667
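The counts printed above translate into an overall sparsity well below the per-layer rates, because the shared embedding matrix, layer norms, and relative-position biases are not Dense layers and keep all their weights. The arithmetic below uses only the reported numbers and assumes 4 bytes per float32 parameter; the result matches the roughly 308 MB `model.safetensors` downloaded for `flan-t5-small`, which is what the pruned checkpoint still occupies when saved densely.

```python
zero_params = 21_341_856      # reported above
nonzero_params = 55_619_296   # reported above
total_params = zero_params + nonzero_params

sparsity = zero_params / total_params
dense_fp32_mb = total_params * 4 / 1e6  # every parameter is still stored as a 4-byte float

print(f"total parameters: {total_params:,}")     # total parameters: 76,961,152
print(f"overall sparsity: {sparsity:.1%}")       # overall sparsity: 27.7%
print(f"dense float32 size: {dense_fp32_mb:.0f} MB")  # dense float32 size: 308 MB
```

Realizing a size or speed benefit from these zeros requires a sparse storage format or sparse kernels; quantization, as in the next step, shrinks the dense representation directly.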

## Quantize the Model

In [10]:
model = TFAutoModelForSeq2SeqLM.from_pretrained('./pruned_model')
converter = tf.lite.TFLiteConverter.from_keras_model(model)
# Optimize.DEFAULT without a representative dataset applies dynamic-range
# quantization: weights are stored as 8-bit integers and no calibration data is needed.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# Allow TensorFlow ops that have no TFLite builtin equivalent (requires the Flex delegate at runtime)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # Enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS     # Enable TensorFlow Select ops.
]
tflite_model = converter.convert()

# Save the quantized model
tflite_model_path = "model_quantized.tflite"
with open(tflite_model_path, "wb") as f:
    f.write(tflite_model)


All model checkpoint layers were used when initializing TFT5ForConditionalGeneration.

All the layers of TFT5ForConditionalGeneration were initialized from the model checkpoint at ./pruned_model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.
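To see why 8-bit weights cut storage by about 4× at a bounded precision cost, here is a minimal NumPy sketch of symmetric per-tensor int8 weight quantization. It is an illustration of the general idea, not TFLite's exact scheme (which may, for example, use per-channel scales); the helpers `quantize_int8` and `dequantize` and the random 512×384 matrix are hypothetical.

```python
import numpy as np

def quantize_int8(w):
    """Symmetric per-tensor quantization: w is approximated by scale * q with q in [-127, 127]."""
    scale = np.abs(w).max() / 127.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
w = rng.normal(scale=0.05, size=(512, 384)).astype(np.float32)  # stand-in weight matrix

q, scale = quantize_int8(w)
w_hat = dequantize(q, scale)

print(f"float32 bytes: {w.nbytes}, int8 bytes: {q.nbytes}")  # 786432 vs 196608
print(f"max abs error: {np.abs(w - w_hat).max():.2e} (bound: scale/2 = {scale / 2:.2e})")
```

Rounding to the nearest integer step means each reconstructed weight is off by at most half a step (`scale / 2`), which is why weight-only quantization usually costs little accuracy.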
