In [1]:
#!pip install datasets==1.4.1

In [2]:
#!pip install transformers

In [3]:
#!pip install accelerate -U

In [4]:
#!python -m pip install -U nn_pruning

In [5]:
import torch
import datasets
import transformers
datasets.logging.set_verbosity_error()
transformers.logging.set_verbosity_error()
print(f"Using transformers v{transformers.__version__} and datasets v{datasets.__version__} and torch v{torch.__version__}")

Using transformers v4.35.0 and datasets v1.4.1 and torch v2.1.0+cu121


In [6]:
from datasets import load_dataset

boolq = load_dataset("super_glue", "boolq")
boolq

DatasetDict({
    train: Dataset({
        features: ['question', 'passage', 'idx', 'label'],
        num_rows: 9427
    })
    validation: Dataset({
        features: ['question', 'passage', 'idx', 'label'],
        num_rows: 3270
    })
    test: Dataset({
        features: ['question', 'passage', 'idx', 'label'],
        num_rows: 3245
    })
})

In [7]:
boolq['train'][0]

{'idx': 0,
 'label': 1,
 'passage': 'Persian language -- Persian (/ˈpɜːrʒən, -ʃən/), also known by its endonym Farsi (فارسی fārsi (fɒːɾˈsiː) ( listen)), is one of the Western Iranian languages within the Indo-Iranian branch of the Indo-European language family. It is primarily spoken in Iran, Afghanistan (officially known as Dari since 1958), and Tajikistan (officially known as Tajiki since the Soviet era), and some other regions which historically were Persianate societies and considered part of Greater Iran. It is written in the Persian alphabet, a modified variant of the Arabic script, which itself evolved from the Aramaic alphabet.',
 'question': 'do iran and afghanistan speak the same language'}

In [8]:
boolq.rename_column("label", "labels")

DatasetDict({
    train: Dataset({
        features: ['question', 'passage', 'idx', 'labels'],
        num_rows: 9427
    })
    validation: Dataset({
        features: ['question', 'passage', 'idx', 'labels'],
        num_rows: 3270
    })
    test: Dataset({
        features: ['question', 'passage', 'idx', 'labels'],
        num_rows: 3245
    })
})

In [9]:
#!pip install --upgrade --quiet jupyter_client ipywidgets

In [10]:
from transformers import AutoTokenizer

bert_ckpt = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(bert_ckpt)

In [11]:
def tokenize_and_encode(examples):
    return tokenizer(examples['question'], examples['passage'], truncation="only_second")

boolq_enc = boolq.map(tokenize_and_encode, batched=True)

In [13]:
!pip install nn_pruning

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collecting nn_pruning
  Using cached nn_pruning-0.1.2-py3-none-any.whl (33 kB)
Collecting scikit-learn>=0.24
  Downloading scikit_learn-1.3.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (11.1 MB)
[K     |████████████████████████████████| 11.1 MB 5.9 MB/s eta 0:00:01     |██████████████████████████████▋ | 10.6 MB 5.9 MB/s eta 0:00:01
Collecting scipy>=1.5.0
  Downloading scipy-1.10.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (34.5 MB)
[K     |████████████████████████████████| 34.5 MB 99.5 MB/s eta 0:00:01
[?25hCollecting threadpoolctl>=2.0.0
  Using cached threadpoolctl-3.2.0-py3-none-any.whl (15 kB)
Collecting joblib>=1.1.1
  Using cached joblib-1.3.2-py3-none-any.whl (302 kB)
Installing collected packages: scipy, threadpoolctl, joblib, scikit-learn, nn-pruning
Successfully installed joblib-1.3.2 nn-pruning-0.1.2 scikit-learn-1.3.2 scipy-1.10.1 threadpoolctl-3.2.0


In [14]:
from nn_pruning.sparse_trainer import SparseTrainer

In [15]:
!ls

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


01_sparse_trainer.ipynb


In [None]:
!python -m pip install -e ".[dev]"

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Obtaining file:///home/ubuntu/pruning/nn_pruning
  Installing build dependencies ... [?25ldone
[?25h  Checking if build backend supports build_editable ... [?25ldone
[?25h  Getting requirements to build editable ... [?25ldone
[?25h  Preparing editable metadata (pyproject.toml) ... [?25ldone
Collecting pytest
  Downloading pytest-7.4.3-py3-none-any.whl (325 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m325.1/325.1 kB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Collecting iniconfig
  Downloading iniconfig-2.0.0-py3-none-any.whl (5.9 kB)
Collecting tomli>=1.0.0
  Downloading tomli-2.0.1-py3-none-any.whl (12 kB)
Collecting exceptiongroup>=1.0.0rc8
  Downloading exceptiongroup-1.1.3-py3-none-any.whl (14 kB)
Building wheels for collected packages: nn-pruning
  Building editable for nn-pruning (pyproject.toml) ... [?25ldone
[?25h  Created wheel for nn-pruning: filename=nn_pruning-0.1.2-0.editable-py3-none-any.whl size=7289 sha256=34d09fcb65de0123

In [16]:
from transformers import Trainer
from nn_pruning.sparse_trainer import SparseTrainer

class PruningTrainer(SparseTrainer, Trainer):
    def __init__(self, sparse_args, *args, **kwargs):
        Trainer.__init__(self, *args, **kwargs)
        SparseTrainer.__init__(self, sparse_args)

    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(**inputs)
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
        self.metrics["ce_loss"] += float(loss)
        self.loss_counter += 1
        return (loss, outputs) if return_outputs else loss

In [17]:
from nn_pruning.patch_coordinator import SparseTrainingArguments

sparse_args = SparseTrainingArguments()
sparse_args

SparseTrainingArguments(mask_scores_learning_rate=0.01, dense_pruning_method='topK', attention_pruning_method='topK', ampere_pruning_method='disabled', attention_output_with_dense=True, bias_mask=True, mask_init='constant', mask_scale=0.0, dense_block_rows=1, dense_block_cols=1, attention_block_rows=1, attention_block_cols=1, initial_threshold=1.0, final_threshold=0.5, initial_warmup=1, final_warmup=2, initial_ampere_temperature=0.0, final_ampere_temperature=20.0, regularization='disabled', regularization_final_lambda=0.0, attention_lambda=1.0, dense_lambda=1.0, distil_teacher_name_or_path=None, distil_alpha_ce=0.5, distil_alpha_teacher=0.5, distil_temperature=2.0, final_finetune=False, layer_norm_patch=False, layer_norm_patch_steps=50000, layer_norm_patch_start_delta=0.99, gelu_patch=False, gelu_patch_steps=50000, linear_min_parameters=0.005, rewind_model_name_or_path=None)

In [18]:
hyperparams = {
    "dense_pruning_method": "topK:1d_alt",
    "attention_pruning_method": "topK",
    "initial_threshold": 1.0,
    "final_threshold": 0.5,
    "initial_warmup": 1,
    "final_warmup": 3,
    "attention_block_rows":32,
    "attention_block_cols":32,
    "attention_output_with_dense": 0
}

for k,v in hyperparams.items():
    if hasattr(sparse_args, k):
        setattr(sparse_args, k, v)
    else:
        print(f"sparse_args does not have argument {k}")

In [19]:
from transformers import TrainingArguments

batch_size = 16
learning_rate = 2e-5
num_train_epochs = 6
logging_steps = len(boolq_enc["train"]) // batch_size
warmup_steps = logging_steps * num_train_epochs * 0.1

args = TrainingArguments(
    output_dir="checkpoints",
    evaluation_strategy="epoch",
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=learning_rate,
    weight_decay=0.01,
    logging_steps=logging_steps,
    save_strategy="epoch",
    disable_tqdm=False,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    report_to=None
)

In [20]:
import torch
from transformers import AutoModelForSequenceClassification
from nn_pruning.patch_coordinator import ModelPatchingCoordinator

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

mpc = ModelPatchingCoordinator(
    sparse_args=sparse_args,
    device=device,
    cache_dir="checkpoints",
    logit_names="logits",
    teacher_constructor=None)

In [21]:
bert_model = AutoModelForSequenceClassification.from_pretrained(bert_ckpt).to(device)
mpc.patch_model(bert_model)

bert_model.save_pretrained("models/patched")

HBox(children=(FloatProgress(value=0.0, description='model.safetensors', max=440449768.0, style=ProgressStyle(…




In [22]:
import numpy as np
from datasets import load_metric

accuracy_score = load_metric('accuracy')

def compute_metrics(pred):
    predictions, labels = pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy_score.compute(predictions=predictions, references=labels)

In [23]:
trainer = PruningTrainer(
    sparse_args=sparse_args,
    args=args,
    model=bert_model,
    train_dataset=boolq_enc["train"],
    eval_dataset=boolq_enc["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [24]:
trainer.set_patch_coordinator(mpc)

In [25]:
trainer.train();



Epoch,Training Loss,Validation Loss,Accuracy,Runtime,Samples Per Second,Steps Per Second,Threshold,Regu Lambda,Ampere Temperature
1,0.6574,0.664053,0.622018,64.6737,50.562,3.17,0.5,0.0,20.0
2,0.603,0.630869,0.638226,64.7031,50.539,3.168,0.5,0.0,20.0
3,0.5168,0.596548,0.687156,64.6741,50.561,3.17,0.5,0.0,20.0
4,0.429,0.639322,0.680428,64.6497,50.58,3.171,0.5,0.0,20.0
5,0.3447,0.81536,0.688073,64.6559,50.575,3.171,0.5,0.0,20.0
6,0.2768,0.837686,0.678899,64.5995,50.62,3.173,0.5,0.0,20.0


In [26]:
output_model_path = "models/bert-base-uncased-finepruned-boolq"
trainer.save_model(output_model_path)

In [27]:
mpc.compile_model(trainer.model)

(0, 144)

In [31]:
from nn_pruning.inference_model_patcher import optimize_model

prunebert_model = optimize_model(trainer.model, "dense")

removed heads 0, total_heads=144, percentage removed=0.0
bert.encoder.layer.0.intermediate.dense, sparsity = 50.00
bert.encoder.layer.0.output.dense, sparsity = 50.00
bert.encoder.layer.1.intermediate.dense, sparsity = 50.00
bert.encoder.layer.1.output.dense, sparsity = 50.00
bert.encoder.layer.2.intermediate.dense, sparsity = 50.00
bert.encoder.layer.2.output.dense, sparsity = 50.00
bert.encoder.layer.3.intermediate.dense, sparsity = 50.00
bert.encoder.layer.3.output.dense, sparsity = 50.00
bert.encoder.layer.4.intermediate.dense, sparsity = 50.00
bert.encoder.layer.4.output.dense, sparsity = 50.00
bert.encoder.layer.5.intermediate.dense, sparsity = 50.00
bert.encoder.layer.5.output.dense, sparsity = 50.00
bert.encoder.layer.6.intermediate.dense, sparsity = 50.00
bert.encoder.layer.6.output.dense, sparsity = 50.00
bert.encoder.layer.7.intermediate.dense, sparsity = 50.00
bert.encoder.layer.7.output.dense, sparsity = 50.00
bert.encoder.layer.8.intermediate.dense, sparsity = 50.00
bert.

In [32]:
prunebert_model.num_parameters() / bert_model.num_parameters()

0.7412403506937804

In [33]:
from time import perf_counter

def compute_latencies(model,
                      question="Is Saving Private Ryan based on a book?",
                      passage="""In 1994, Robert Rodat wrote the script for the film. Rodat’s script was submitted to
                      producer Mark Gordon, who liked it and in turn passed it along to Spielberg to direct. The film is
                      loosely based on the World War II life stories of the Niland brothers. A shooting date was set for
                      June 27, 1997"""):
    inputs = tokenizer(question, passage, truncation="only_second", return_tensors="pt")
    latencies = []

    # Warmup
    for _ in range(10):
        _ = model(**inputs)

    for _ in range(100):
        start_time = perf_counter()
        _ = model(**inputs)
        latency = perf_counter() - start_time
        latencies.append(latency)
        # Compute run statistics
        time_avg_ms = 1000 * np.mean(latencies)
        time_std_ms = 1000 * np.std(latencies)
    print(f"Average latency (ms) - {time_avg_ms:.2f} +\- {time_std_ms:.2f}")
    return {"time_avg_ms": time_avg_ms, "time_std_ms": time_std_ms}

In [34]:
latencies = {}
latencies["prunebert"] = compute_latencies(prunebert_model.to("cpu"))

Average latency (ms) - 61.33 +\- 0.31


In [35]:
bert_unpruned = AutoModelForSequenceClassification.from_pretrained("lewtun/bert-base-uncased-finetuned-boolq").to("cpu")

latencies["bert-base"] = compute_latencies(bert_unpruned.to("cpu"))

HBox(children=(FloatProgress(value=0.0, description='(…)finetuned-boolq/resolve/main/config.json', max=563.0, …




HBox(children=(FloatProgress(value=0.0, description='pytorch_model.bin', max=438022420.0, style=ProgressStyle(…


Average latency (ms) - 85.28 +\- 0.61


In [None]:
#ref: https://github.com/huggingface/nn_pruning