<a href="https://colab.research.google.com/github/shrisha-rao/bitnet-imdb/blob/main/notebooks/bitnet_imdb_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine-tune a BitNet Model on IMDB
This notebook converts a small pretrained model (`bert-tiny`) into a **BitNet** with ternary weights and fine-tunes it on IMDB sentiment classification. It demonstrates:
- Custom `BitLinear` layer with weight quantization and straight-through estimator.
- Replacing all linear layers in a transformer.
- Fine-tuning with Hugging Face `Trainer`.

**Note:** Use a GPU runtime (Runtime → Change runtime type → T4 GPU) for faster training.

In [15]:
# Install dependencies
!pip install transformers datasets accelerate scikit-learn



In [None]:
!pip install --upgrade transformers

Collecting transformers
  Downloading transformers-5.2.0-py3-none-any.whl.metadata (32 kB)
Downloading transformers-5.2.0-py3-none-any.whl (10.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.4/10.4 MB[0m [31m37.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: transformers
  Attempting uninstall: transformers
    Found existing installation: transformers 5.0.0
    Uninstalling transformers-5.0.0:
      Successfully uninstalled transformers-5.0.0
Successfully installed transformers-5.2.0


In [None]:
!pip install sentencepiece



In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
)
from sklearn.metrics import accuracy_score, f1_score

## Define BitLinear Layer

In [None]:
class BitLinear(nn.Linear):
    def quantize_weights(self):
        w = self.weight
        alpha = w.abs().mean().clamp(min=1e-8)
        ternary = torch.where(w > 0.5 * alpha, alpha, torch.where(w < -0.5 * alpha, -alpha, 0.0))
        return ternary

    def forward(self, x):
        quantized_w = self.quantize_weights()
        w_ste = self.weight + (quantized_w - self.weight).detach()
        return F.linear(x, w_ste, self.bias)

## Replace Linear Layers in Model

In [None]:
def replace_linear_with_bitlinear(model):
    for name, child in model.named_children():
        if isinstance(child, nn.Linear) and name != 'classifier':
            new_layer = BitLinear(child.in_features, child.out_features, bias=child.bias is not None)
            new_layer.weight.data = child.weight.data.clone()
            if child.bias is not None:
                new_layer.bias.data = child.bias.data.clone()
            setattr(model, name, new_layer)
        else:
            replace_linear_with_bitlinear(child)

## Load Dataset and Model

In [None]:
# Load IMDB dataset (small subset for speed)
dataset = load_dataset("imdb")
train_small = dataset["train"].shuffle(seed=42).select(range(5000))
test_small = dataset["test"].shuffle(seed=42).select(range(1000))


model_name = "distilbert-base-uncased"  # 67M params, still small and fast
tokenizer = AutoTokenizer.from_pretrained(model_name)

#tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")
# tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny",
#                                           trust_remote_code=True)

def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True,
                     max_length=512)

tokenized_train = train_small.map(tokenize_function, batched=True)
tokenized_test = test_small.map(tokenize_function, batched=True)

# Load model
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

# model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny",
#                                                            num_labels=2)
replace_linear_with_bitlinear(model)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]



plain_text/train-00000-of-00001.parquet:   0%|          | 0.00/21.0M [00:00<?, ?B/s]

plain_text/test-00000-of-00001.parquet:   0%|          | 0.00/20.5M [00:00<?, ?B/s]

plain_text/unsupervised-00000-of-00001.p(…):   0%|          | 0.00/42.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating unsupervised split:   0%|          | 0/50000 [00:00<?, ? examples/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Map:   0%|          | 0/5000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Loading weights:   0%|          | 0/100 [00:00<?, ?it/s]

[1mDistilBertForSequenceClassification LOAD REPORT[0m from: distilbert-base-uncased
Key                     | Status     | 
------------------------+------------+-
vocab_layer_norm.bias   | UNEXPECTED | 
vocab_projector.bias    | UNEXPECTED | 
vocab_transform.bias    | UNEXPECTED | 
vocab_layer_norm.weight | UNEXPECTED | 
vocab_transform.weight  | UNEXPECTED | 
pre_classifier.weight   | MISSING    | 
classifier.bias         | MISSING    | 
pre_classifier.bias     | MISSING    | 
classifier.weight       | MISSING    | 

[3mNotes:
- UNEXPECTED[3m	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING[3m	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.[0m


## Define Metrics

In [None]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {
        "accuracy": accuracy_score(labels, predictions),
        "f1": f1_score(labels, predictions, average="weighted"),
    }

## Training Arguments

In [None]:
training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-4,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=1,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    logging_dir="./logs",
    report_to="none",
)

`logging_dir` is deprecated and will be removed in v5.2. Please set `TENSORBOARD_LOGGING_DIR` instead.


## Train

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
    compute_metrics=compute_metrics,
)

trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,F1
1,No log,0.340793,0.844,0.844022


Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

There were missing keys in the checkpoint model loaded: ['distilbert.embeddings.LayerNorm.weight', 'distilbert.embeddings.LayerNorm.bias'].
There were unexpected keys in the checkpoint model loaded: ['distilbert.embeddings.LayerNorm.beta', 'distilbert.embeddings.LayerNorm.gamma'].


TrainOutput(global_step=313, training_loss=0.43188271811975837, metrics={'train_runtime': 282.5023, 'train_samples_per_second': 17.699, 'train_steps_per_second': 1.108, 'total_flos': 662336993280000.0, 'train_loss': 0.43188271811975837, 'epoch': 1.0})

## Save Model

In [None]:
model.save_pretrained("./bitnet-imdb-finetuned")
tokenizer.save_pretrained("./bitnet-imdb-finetuned")

# Zip and download (optional)
import shutil
from google.colab import files
shutil.make_archive("bitnet-imdb-finetuned", 'zip', "./bitnet-imdb-finetuned")
files.download("bitnet-imdb-finetuned.zip")

Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

## Quick Test

In [None]:
from transformers import pipeline
classifier = pipeline("text-classification", model="./bitnet-imdb-finetuned", tokenizer="./bitnet-imdb-finetuned")
print(classifier("This movie was absolutely wonderful!"))
print(classifier("Worst film ever made."))

Loading weights:   0%|          | 0/104 [00:00<?, ?it/s]

[{'label': 'LABEL_1', 'score': 0.983174741268158}]
[{'label': 'LABEL_0', 'score': 0.9422519207000732}]
