# Fine-tune a BitNet Model on IMDB
This notebook converts a small pretrained model (`prajjwal1/bert-tiny`) into a **BitNet**-style model with ternary weights and fine-tunes it on IMDB binary (positive/negative) sentiment classification. It demonstrates:
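- A custom `BitLinear` layer that quantizes its weights to the ternary values {−α, 0, +α} (α = mean absolute weight) and trains them with a straight-through estimator; activations stay in full precision.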
- Replacing every linear layer in the transformer except the classification head, which stays full precision.
- Fine-tuning with Hugging Face `Trainer`.

**Note:** This notebook targets Google Colab. Use a GPU runtime (Runtime → Change runtime type → T4 GPU) for faster training.

In [None]:
# Install dependencies (the leading `!` makes Jupyter/Colab run this line as a shell command)
!pip install transformers datasets accelerate scikit-learn

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
)
from sklearn.metrics import accuracy_score, f1_score

## Define BitLinear Layer

In [None]:
class BitLinear(nn.Linear):
    """``nn.Linear`` variant whose forward pass uses ternary weights.

    The full-precision ``weight`` parameter is kept as the trainable "latent"
    weight. On every forward call it is projected onto the three values
    ``{-alpha, 0, +alpha}``, where ``alpha`` is the mean absolute weight. A
    straight-through estimator then lets gradients reach the latent weight as
    if the projection were the identity. Inputs and the bias are not quantized.
    """

    def quantize_weights(self):
        """Return a ternary copy of ``self.weight`` with the same shape.

        Entries above ``+alpha / 2`` become ``+alpha``, entries below
        ``-alpha / 2`` become ``-alpha``, and everything else becomes 0.
        Here ``alpha = mean(|weight|)``, floored at 1e-8 so an all-zero
        weight never yields a zero scale.
        """
        latent = self.weight
        scale = latent.abs().mean().clamp(min=1e-8)
        cutoff = 0.5 * scale
        return torch.where(
            latent > cutoff,
            scale,
            torch.where(latent < -cutoff, -scale, 0.0),
        )

    def forward(self, x):
        """Apply the linear map using ternary weights with STE gradients."""
        latent = self.weight
        ternary = self.quantize_weights()
        # Numerically this equals `ternary`. Because the correction term is
        # detached, backprop treats d(ste_weight)/d(latent) as the identity.
        ste_weight = latent + (ternary - latent).detach()
        return F.linear(x, ste_weight, self.bias)

## Replace Linear Layers in Model

In [None]:
def replace_linear_with_bitlinear(model, exclude=("classifier",)):
    """Recursively swap every ``nn.Linear`` inside *model* for a ``BitLinear``, in place.

    The pretrained weight and bias of each replaced layer are copied into the
    new ``BitLinear``. The new layer is created on the same device and with the
    same dtype as the original and keeps its ``requires_grad`` flags.

    Args:
        model: ``nn.Module`` to modify in place.
        exclude: Collection of child attribute names whose whole subtree is
            left untouched. The default keeps the ``classifier`` head in full
            precision, whether it is a single ``nn.Linear`` or a submodule
            containing several.

    Returns:
        The same *model* object, for convenient chaining.
    """
    for name, child in model.named_children():
        if name in exclude:
            # Skip the entire subtree, not just a top-level Linear with this name.
            continue
        if isinstance(child, nn.Linear) and not isinstance(child, BitLinear):
            new_layer = BitLinear(
                child.in_features,
                child.out_features,
                bias=child.bias is not None,
                device=child.weight.device,
                dtype=child.weight.dtype,
            )
            with torch.no_grad():
                new_layer.weight.copy_(child.weight)
                if child.bias is not None:
                    new_layer.bias.copy_(child.bias)
            # Preserve freezing decisions made before the swap.
            new_layer.weight.requires_grad_(child.weight.requires_grad)
            if child.bias is not None:
                new_layer.bias.requires_grad_(child.bias.requires_grad)
            setattr(model, name, new_layer)
        else:
            replace_linear_with_bitlinear(child, exclude)
    return model

## Load Dataset and Model

In [None]:
# Load IMDB dataset (small, shuffled subsets for speed)
dataset = load_dataset("imdb")
train_small = dataset["train"].shuffle(seed=42).select(range(5000))
test_small = dataset["test"].shuffle(seed=42).select(range(1000))

tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")

def tokenize_function(examples):
    """Tokenize a batch of reviews, padding/truncating each one to 512 tokens."""
    # Fixed-length padding gives every example the same length, so batches can
    # be stacked without dynamic padding (at the cost of extra compute on short reviews).
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=512)

tokenized_train = train_small.map(tokenize_function, batched=True)
tokenized_test = test_small.map(tokenize_function, batched=True)

# Take class names from the dataset's own label metadata, so the saved config
# (and any pipeline built from it) reports readable labels rather than
# LABEL_0 / LABEL_1.
label_names = train_small.features["label"].names
id2label = dict(enumerate(label_names))
label2id = {label: idx for idx, label in id2label.items()}

# Load the pretrained model and swap its linear layers for BitLinear
model = AutoModelForSequenceClassification.from_pretrained(
    "prajjwal1/bert-tiny",
    num_labels=len(label_names),
    id2label=id2label,
    label2id=label2id,
)
replace_linear_with_bitlinear(model)

## Define Metrics

In [None]:
def compute_metrics(eval_pred):
    """Compute accuracy and weighted F1 for a ``Trainer`` evaluation pass.

    Args:
        eval_pred: ``(predictions, label_ids)`` pair supplied by ``Trainer``
            (an ``EvalPrediction`` unpacks the same way). ``predictions`` holds
            per-class logits. If the model returned several outputs, it arrives
            as a tuple and the logits are assumed to be its first entry.

    Returns:
        dict with ``"accuracy"`` and ``"f1"`` (support-weighted F1) scores.
    """
    logits, labels = eval_pred
    if isinstance(logits, tuple):
        logits = logits[0]
    predictions = np.argmax(logits, axis=-1)
    return {
        "accuracy": accuracy_score(labels, predictions),
        "f1": f1_score(labels, predictions, average="weighted"),
    }

## Training Arguments

In [None]:
# `eval_strategy` replaces the old `evaluation_strategy` keyword, which
# transformers >= 4.46 (installed unpinned above) no longer accepts.
training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",
    save_strategy="epoch",  # must match eval_strategy for load_best_model_at_end
    learning_rate=2e-4,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=1,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",  # key returned by compute_metrics
    logging_dir="./logs",
    report_to="none",
)

## Train

In [None]:
# Wire the BitNet model, tokenized data and metric function into a Trainer, then fine-tune.
# NOTE(review): the test subset serves both as the per-epoch eval set and for
# best-checkpoint selection (load_best_model_at_end), so the reported accuracy
# is not a fully held-out estimate.
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
)

trainer.train()

## Save Model

In [None]:
# Save the fine-tuned weights and tokenizer. BitLinear subclasses nn.Linear, so
# the checkpoint stores the latent full-precision weights under the standard
# parameter names. The ternary projection is recomputed by BitLinear at forward
# time, so reload the model and re-apply replace_linear_with_bitlinear before inference.
model.save_pretrained("./bitnet-imdb-finetuned")
tokenizer.save_pretrained("./bitnet-imdb-finetuned")

# Zip and download (optional; the browser download only works inside Google Colab)
import shutil

archive_path = shutil.make_archive("bitnet-imdb-finetuned", "zip", "./bitnet-imdb-finetuned")
try:
    from google.colab import files
except ImportError:
    print(f"Not running in Colab; archive saved to {archive_path}")
else:
    files.download(archive_path)

## Quick Test

In [None]:
from transformers import pipeline

# from_pretrained rebuilds the architecture from its config, i.e. with plain
# nn.Linear layers holding the latent full-precision weights. Re-apply the
# BitLinear swap so inference uses the same ternary weights the model was trained with.
bitnet_model = AutoModelForSequenceClassification.from_pretrained("./bitnet-imdb-finetuned")
replace_linear_with_bitlinear(bitnet_model)
bitnet_model.eval()

classifier = pipeline("text-classification", model=bitnet_model, tokenizer="./bitnet-imdb-finetuned")
print(classifier("This movie was absolutely wonderful!"))
print(classifier("Worst film ever made."))