# Quantization and Pruning of BERT with PyTorch

## Dataset Preparation

For this example, we will use the AG News dataset. The dataset contains 120,000 training samples and 7,600 test samples. Each sample consists of a title and a description of a news article, and a label that classifies the article into one of four categories: World, Sports, Business, and Sci/Tech.

In [None]:
from datasets import load_dataset

# Load AG News dataset
dataset = load_dataset("ag_news")
train_dataset = dataset["train"]
test_dataset = dataset["test"]

# Preview the dataset
print(train_dataset[0])
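
The printed sample stores its category as an integer. The sketch below maps that integer back to a readable name, assuming the label ids follow the order in which the categories are listed above (0 = World through 3 = Sci/Tech); this can be checked against `train_dataset.features["label"].names`. `LABEL_NAMES` and `label_to_name` are illustrative names, not part of the `datasets` library.

```python
# Assumed label order: matches the category order given in the dataset description.
LABEL_NAMES = ["World", "Sports", "Business", "Sci/Tech"]

def label_to_name(label_id):
    """Translate an integer class id into its category name."""
    return LABEL_NAMES[label_id]

print(label_to_name(2))
```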

## Data Preprocessing


During this process we will use the `BertTokenizer` from the `transformers` library to tokenize the text data. We will also apply padding and truncation to ensure that all sequences have the same length. Finally, we will format the datasets for PyTorch. 

> Note: Be careful with truncation, as it may remove important information from the text. Padding is safer, because the attention mask tells the model to ignore the padded positions, but those positions still cost memory and computation. If most of your texts are much shorter than `max_length`, consider a smaller `max_length` or padding each batch only to its longest sequence to reduce the amount of padding.

---
**Truncation** is the process of removing tokens from the beginning or end of a sequence so that it fits within a maximum length. This is useful when working with sequences that are longer than the maximum length allowed by the model.

In this example the max_length is 5, and tokens are removed from the end.

Text: "This is a test sentence with more than 5 words."

Tokenized: `['This', 'is', 'a', 'test', 'sentence', 'with', 'more', 'than', '5', 'words', '.']`

Truncated: `['This', 'is', 'a', 'test', 'sentence']`

---
**Padding** is the process of adding tokens to the end of a sequence to make it fit within a certain length. This is useful when working with sequences that are shorter than the maximum length allowed by the model.

In this example the max_length is 20 and the padding token is `[PAD]`.

Text: "This is a sentence with less than 20 words."

Tokenized: `['This', 'is', 'a', 'sentence', 'with', 'less', 'than', '20', 'words', '.']`

Padded: `['This', 'is', 'a', 'sentence', 'with', 'less', 'than', '20', 'words', '.', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']`
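
The two operations can be sketched together in plain Python. This is a simplified illustration at the word level, not what `BertTokenizer` does internally: the real tokenizer works on subword ids and also adds special tokens such as `[CLS]` and `[SEP]`. The helper name `pad_or_truncate` is hypothetical. The attention mask it returns marks real tokens with 1 and padding with 0, which is how the model knows which positions to ignore.

```python
def pad_or_truncate(tokens, max_length, pad_token="[PAD]"):
    """Return (tokens, attention_mask), both exactly max_length long."""
    kept = tokens[:max_length]                       # truncation: drop tokens from the end
    n_pad = max_length - len(kept)                   # number of padding tokens needed
    attention_mask = [1] * len(kept) + [0] * n_pad   # 1 = real token, 0 = padding
    return kept + [pad_token] * n_pad, attention_mask

long_tokens = ["This", "is", "a", "test", "sentence", "with", "more", "than", "5", "words", "."]
short_tokens = ["This", "is", "a", "sentence"]

print(pad_or_truncate(long_tokens, 5))    # sequence is cut to 5 tokens
print(pad_or_truncate(short_tokens, 6))   # sequence is extended with two [PAD] tokens
```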

In [None]:
from transformers import BertTokenizer

# Load BERT tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Tokenize function
def tokenize_function(example):
    return tokenizer(example["text"], padding="max_length", truncation=True, max_length=128)

# Apply tokenizer
train_dataset = train_dataset.map(tokenize_function, batched=True)
test_dataset = test_dataset.map(tokenize_function, batched=True)

# Format datasets for PyTorch
train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])


## Model Setup

We will use the `BertForSequenceClassification` model from the `transformers` library. It combines the pre-trained BERT encoder with a newly initialized classification head, so the model has to be fine-tuned before its predictions are useful. We load the pre-trained weights with the `from_pretrained` method and set `num_labels=4` to match the four AG News categories.

In [None]:
from transformers import BertForSequenceClassification

# Load BERT for sequence classification
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4)

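## Model Pruning

Pruning reduces the cost of a neural network by removing (zeroing out) weights that contribute little to its output. In this example, we will use the `prune` module from `torch.nn.utils` to prune, in each self-attention linear layer (query, key, value) of the first encoder layer, the 20% of weights with the smallest absolute value. Note that unstructured pruning sets weights to zero rather than deleting them, so savings in model size and compute depend on sparse storage formats or sparse kernels.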

Pros:
- Reduces the size of the model
- Reduces the computational cost of running the model
- Can improve the efficiency of the model

Cons:
- May reduce the accuracy of the model
- May require re-training the model after pruning
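
To make the idea concrete, here is a minimal pure-Python sketch of magnitude-based (L1) unstructured pruning on a flat list of weights. It illustrates the principle behind the `prune.l1_unstructured` call used in this section, under the assumption that the number of pruned weights is the rounded fraction `amount` of the total; it is not the PyTorch implementation. The function name `l1_unstructured_mask` and the example weights are made up for illustration.

```python
def l1_unstructured_mask(weights, amount):
    """Return a 0/1 mask that zeroes the `amount` fraction of weights with the smallest |w|."""
    n_prune = round(amount * len(weights))
    # Indices ordered by absolute value, smallest first
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = set(order[:n_prune])
    return [0 if i in pruned else 1 for i in range(len(weights))]

weights = [0.8, -0.05, 0.3, -0.9, 0.01, 0.4, -0.2, 0.6, -0.7, 0.1]
mask = l1_unstructured_mask(weights, amount=0.2)

# Applying the mask zeroes the pruned weights but keeps the list the same length
pruned_weights = [w * m for w, m in zip(weights, mask)]
sparsity = mask.count(0) / len(mask)

print(mask)
print(f"Sparsity: {sparsity:.0%}")
```

The masked list has the same length as the original, which mirrors why a pruned dense tensor does not shrink on disk by itself.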

In [None]:
import torch
from torch.nn.utils import prune

# Prune the query/key/value linear layers in the self-attention of the first encoder layer
for name, module in model.bert.encoder.layer[0].attention.self.named_modules():
    if isinstance(module, torch.nn.Linear):
        # Zero the 20% of weights with the smallest absolute value in this layer
        prune.l1_unstructured(module, name="weight", amount=0.2)

### - Fine-Tune the Pruned Model

In [None]:
from transformers import TrainingArguments, Trainer
from sklearn.metrics import accuracy_score

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = logits.argmax(axis=-1)  # Convert logits to predictions
    accuracy = accuracy_score(labels, predictions)
    return {"accuracy": accuracy}


# Define training arguments
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    num_train_epochs=3,
    save_steps=100,
    weight_decay=0.01,
)

# Define Trainer for fine-tuning
trainer = Trainer(
    model=model,  # Pruned model
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
)

# Train the pruned model
trainer.train()

In [None]:
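results = trainer.evaluate()
print(f"Pruned and Fine-Tuned Model Accuracy: {results['eval_accuracy']:.2f}")

# Make the pruning permanent so the checkpoint stores plain `weight` tensors
for module in model.bert.encoder.layer[0].attention.self.modules():
    if isinstance(module, torch.nn.Linear):
        prune.remove(module, "weight")
trainer.save_model("./models/pruned_fine_tuned_model.pth")  # saved as a directory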

## Post-Training Quantization

In [None]:
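# Load the pruned model
model = BertForSequenceClassification.from_pretrained("./models/pruned_fine_tuned_model.pth")
model.to("cpu").eval()  # dynamic quantization targets CPU inference

# Quantize the weights of all linear layers to int8
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)

# trainer.evaluate() takes a dataset, not a model, so wrap the quantized model in its own Trainer.
# `use_cpu` needs a recent transformers release; older releases use `no_cuda=True` instead.
quantized_trainer = Trainer(
    model=quantized_model,
    args=TrainingArguments(output_dir="./results_quantized", use_cpu=True),
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
)
quantized_results = quantized_trainer.evaluate()

print(f"Quantized Model Accuracy: {quantized_results['eval_accuracy']:.2f}")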

## Quantization-aware Training (QAT)

QAT is a technique that trains a model with quantization in mind. It simulates the effects of quantization during training, allowing the model to learn to be more robust to the quantization process. This can lead to better accuracy than when the model is quantized post-training.

Under the hood, the code below uses PyTorch's eager-mode quantization API: `torch.quantization.prepare_qat` inserts fake-quantization modules into the model, and `torch.quantization.convert` later replaces them with int8 modules. During training, the fake-quantization modules simulate the effects of converting 32-bit floating-point weights and activations into 8-bit integers, allowing the model to learn to compensate for quantization-induced errors. At inference time, the quantized 8-bit representation is typically used, which reduces the size of the model and improves computational efficiency.
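
The simulated conversion can be sketched in a few lines of plain Python. The example below is a simplified illustration of affine 8-bit "fake quantization" (quantize, clamp, then dequantize back to float), assuming a per-tensor scale and zero point derived from the minimum and maximum of the values. The function name `fake_quantize` and the sample values are made up for illustration; PyTorch's own fake-quantization modules track ranges with observers and differ in detail.

```python
def fake_quantize(values, num_bits=8):
    """Quantize floats to unsigned integers and map them back to floats (affine scheme)."""
    qmin, qmax = 0, 2 ** num_bits - 1
    # The representable range must include zero so that 0.0 maps to an exact integer
    lo, hi = min(min(values), 0.0), max(max(values), 0.0)
    scale = (hi - lo) / (qmax - qmin) or 1.0   # fall back to 1.0 if all values are zero
    zero_point = round(qmin - lo / scale)
    dequantized = []
    for x in values:
        q = round(x / scale) + zero_point       # quantize to an integer
        q = max(qmin, min(qmax, q))             # clamp to the 8-bit range
        dequantized.append((q - zero_point) * scale)  # map back to float
    return dequantized, scale, zero_point

activations = [-1.0, -0.5, 0.0, 0.3, 1.5]
dequantized, scale, zero_point = fake_quantize(activations)
max_error = max(abs(a - d) for a, d in zip(activations, dequantized))

print(dequantized)
print(f"scale={scale:.5f}, zero_point={zero_point}, max_error={max_error:.5f}")
```

The difference between `activations` and `dequantized` is the rounding error the network sees during QAT; for values inside the range it stays within half a quantization step (`scale / 2`).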

Pros:
- Can improve the accuracy of the quantized model compared with post-training quantization
- Can reduce the size of the model
- Can reduce the computational cost of running the model

Cons:
- May require additional training time
- May require additional computational resources

In [None]:
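import torch
from torch.quantization import DeQuantStub

class QATBERT(torch.nn.Module):
    """Wraps a BERT classifier so the whole model can be prepared for QAT."""

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.dequant = DeQuantStub()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
        # input_ids are integer token indices, so they are passed through unquantized.
        # Keyword arguments keep `labels` from landing in the `position_ids` slot.
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            labels=labels,
        )
        return self.dequant(output.logits)

qat_model = QATBERT(model)
qat_model.train()  # prepare_qat expects a model in training mode
qat_model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
# Keep the embeddings in float; the default QAT qconfig does not apply to nn.Embedding
qat_model.model.bert.embeddings.qconfig = None
torch.quantization.prepare_qat(qat_model, inplace=True)
print("QAT model prepared.")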


### - Fine-tune the QAT model

In [None]:
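trainer = Trainer(
    model=qat_model.model,  # Inner BERT model, now containing fake-quantization modules
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,  # Needed so that evaluate() reports eval_accuracy
)

trainer.train()

In [None]:
# Evaluate while fake quantization is still active; this approximates int8 accuracy
results = trainer.evaluate()
print(f"Quantized and Fine-Tuned Model Accuracy: {results['eval_accuracy']:.2f}")

# Convert to int8 modules; conversion runs on CPU with the model in eval mode.
# Eager-mode int8 modules expect quantized inputs, so end-to-end int8 inference
# also needs QuantStub/DeQuantStub boundaries inside the model.
qat_model.to("cpu").eval()
torch.quantization.convert(qat_model, inplace=True)
print("Model converted to int8.")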
