In [1]:
#!pip install datasets==1.4.1

In [2]:
#!pip install transformers

In [3]:
#!pip install accelerate -U

In [4]:
#!python -m pip install -U nn_pruning

In [5]:
import torch
import datasets
import transformers
datasets.logging.set_verbosity_error()
transformers.logging.set_verbosity_error()
print(f"Using transformers v{transformers.__version__} and datasets v{datasets.__version__} and torch v{torch.__version__}")

Using transformers v4.35.0 and datasets v1.4.1 and torch v2.1.0+cu121


In [6]:
from datasets import load_dataset

boolq = load_dataset("super_glue", "boolq")
boolq

DatasetDict({
    train: Dataset({
        features: ['question', 'passage', 'idx', 'label'],
        num_rows: 9427
    })
    validation: Dataset({
        features: ['question', 'passage', 'idx', 'label'],
        num_rows: 3270
    })
    test: Dataset({
        features: ['question', 'passage', 'idx', 'label'],
        num_rows: 3245
    })
})

In [7]:
boolq['train'][0]

{'idx': 0,
 'label': 1,
 'passage': 'Persian language -- Persian (/ˈpɜːrʒən, -ʃən/), also known by its endonym Farsi (فارسی fārsi (fɒːɾˈsiː) ( listen)), is one of the Western Iranian languages within the Indo-Iranian branch of the Indo-European language family. It is primarily spoken in Iran, Afghanistan (officially known as Dari since 1958), and Tajikistan (officially known as Tajiki since the Soviet era), and some other regions which historically were Persianate societies and considered part of Greater Iran. It is written in the Persian alphabet, a modified variant of the Arabic script, which itself evolved from the Aramaic alphabet.',
 'question': 'do iran and afghanistan speak the same language'}

In [8]:
boolq.rename_column("label", "labels")

DatasetDict({
    train: Dataset({
        features: ['question', 'passage', 'idx', 'labels'],
        num_rows: 9427
    })
    validation: Dataset({
        features: ['question', 'passage', 'idx', 'labels'],
        num_rows: 3270
    })
    test: Dataset({
        features: ['question', 'passage', 'idx', 'labels'],
        num_rows: 3245
    })
})

In [9]:
#!pip install --upgrade --quiet jupyter_client ipywidgets

In [10]:
from transformers import AutoTokenizer

bert_ckpt = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(bert_ckpt)

In [11]:
def tokenize_and_encode(examples):
    return tokenizer(examples['question'], examples['passage'], truncation="only_second")

boolq_enc = boolq.map(tokenize_and_encode, batched=True)

In [12]:
!pip install nn_pruning

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




In [13]:
from nn_pruning.sparse_trainer import SparseTrainer

In [14]:
!ls

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


01_sparse_trainer_rohan.ipynb  checkpoints  sst_thres30.ipynb
boolq.ipynb		       models
boolq_thres30.ipynb	       sst.ipynb


In [15]:
!python -m pip install -e ".[dev]"

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Obtaining file:///home/ubuntu/pruning
[31mERROR: file:///home/ubuntu/pruning does not appear to be a Python project: neither 'setup.py' nor 'pyproject.toml' found.[0m[31m
[0m

In [16]:
from transformers import Trainer
from nn_pruning.sparse_trainer import SparseTrainer

class PruningTrainer(SparseTrainer, Trainer):
    def __init__(self, sparse_args, *args, **kwargs):
        Trainer.__init__(self, *args, **kwargs)
        SparseTrainer.__init__(self, sparse_args)

    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(**inputs)
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
        self.metrics["ce_loss"] += float(loss)
        self.loss_counter += 1
        return (loss, outputs) if return_outputs else loss

In [17]:
from nn_pruning.patch_coordinator import SparseTrainingArguments

sparse_args = SparseTrainingArguments()
sparse_args

SparseTrainingArguments(mask_scores_learning_rate=0.01, dense_pruning_method='topK', attention_pruning_method='topK', ampere_pruning_method='disabled', attention_output_with_dense=True, bias_mask=True, mask_init='constant', mask_scale=0.0, dense_block_rows=1, dense_block_cols=1, attention_block_rows=1, attention_block_cols=1, initial_threshold=1.0, final_threshold=0.5, initial_warmup=1, final_warmup=2, initial_ampere_temperature=0.0, final_ampere_temperature=20.0, regularization='disabled', regularization_final_lambda=0.0, attention_lambda=1.0, dense_lambda=1.0, distil_teacher_name_or_path=None, distil_alpha_ce=0.5, distil_alpha_teacher=0.5, distil_temperature=2.0, final_finetune=False, layer_norm_patch=False, layer_norm_patch_steps=50000, layer_norm_patch_start_delta=0.99, gelu_patch=False, gelu_patch_steps=50000, linear_min_parameters=0.005, rewind_model_name_or_path=None)

In [18]:
hyperparams = {
    "dense_pruning_method": "topK:1d_alt",
    "attention_pruning_method": "topK",
    "initial_threshold": 1.0,
    "final_threshold": 0.3,
    "initial_warmup": 1,
    "final_warmup": 3,
    "attention_block_rows":32,
    "attention_block_cols":32,
    "attention_output_with_dense": 0
}

for k,v in hyperparams.items():
    if hasattr(sparse_args, k):
        setattr(sparse_args, k, v)
    else:
        print(f"sparse_args does not have argument {k}")

In [19]:
from transformers import TrainingArguments

batch_size = 16
learning_rate = 2e-5
num_train_epochs = 6
logging_steps = len(boolq_enc["train"]) // batch_size
warmup_steps = logging_steps * num_train_epochs * 0.1

args = TrainingArguments(
    output_dir="checkpoints",
    evaluation_strategy="epoch",
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=learning_rate,
    weight_decay=0.01,
    logging_steps=logging_steps,
    save_strategy="epoch",
    disable_tqdm=False,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    report_to=None
)

In [20]:
import torch
from transformers import AutoModelForSequenceClassification
from nn_pruning.patch_coordinator import ModelPatchingCoordinator

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

mpc = ModelPatchingCoordinator(
    sparse_args=sparse_args,
    device=device,
    cache_dir="checkpoints",
    logit_names="logits",
    teacher_constructor=None)

In [21]:
bert_model = AutoModelForSequenceClassification.from_pretrained(bert_ckpt).to(device)
mpc.patch_model(bert_model)

bert_model.save_pretrained("models/patched")

In [22]:
import numpy as np
from datasets import load_metric

accuracy_score = load_metric('accuracy')

def compute_metrics(pred):
    predictions, labels = pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy_score.compute(predictions=predictions, references=labels)

In [23]:
trainer = PruningTrainer(
    sparse_args=sparse_args,
    args=args,
    model=bert_model,
    train_dataset=boolq_enc["train"],
    eval_dataset=boolq_enc["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [24]:
trainer.set_patch_coordinator(mpc)

and fine-prune:

In [25]:
trainer.train();



Epoch,Training Loss,Validation Loss,Accuracy,Runtime,Samples Per Second,Steps Per Second,Threshold,Regu Lambda,Ampere Temperature
1,0.6613,0.670379,0.621713,63.9896,51.102,3.204,0.3,0.0,20.0
2,0.6206,0.660605,0.621713,63.7401,51.302,3.216,0.3,0.0,20.0
3,0.559,0.635593,0.633333,63.7194,51.319,3.217,0.3,0.0,20.0
4,0.4889,0.593086,0.688991,63.6877,51.344,3.219,0.3,0.0,20.0
5,0.421,0.717192,0.661468,63.71,51.326,3.218,0.3,0.0,20.0
6,0.3654,0.736642,0.678899,63.8196,51.238,3.212,0.3,0.0,20.0


In [26]:
output_model_path = "models/bert-base-uncased-finepruned-boolq-less"
trainer.save_model(output_model_path)

In [27]:
mpc.compile_model(trainer.model)

(11, 144)

In [28]:
!pip install matplotlib

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




In [30]:
from nn_pruning.inference_model_patcher import optimize_model

prunebert_model = optimize_model(trainer.model, "dense")

removed heads 0, total_heads=133, percentage removed=0.0
bert.encoder.layer.0.intermediate.dense, sparsity = 69.99
bert.encoder.layer.0.output.dense, sparsity = 69.99
bert.encoder.layer.1.intermediate.dense, sparsity = 69.99
bert.encoder.layer.1.output.dense, sparsity = 69.99
bert.encoder.layer.2.intermediate.dense, sparsity = 69.99
bert.encoder.layer.2.output.dense, sparsity = 69.99
bert.encoder.layer.3.intermediate.dense, sparsity = 69.99
bert.encoder.layer.3.output.dense, sparsity = 69.99
bert.encoder.layer.4.intermediate.dense, sparsity = 69.99
bert.encoder.layer.4.output.dense, sparsity = 69.99
bert.encoder.layer.5.intermediate.dense, sparsity = 69.99
bert.encoder.layer.5.output.dense, sparsity = 69.99
bert.encoder.layer.6.intermediate.dense, sparsity = 69.99
bert.encoder.layer.6.output.dense, sparsity = 69.99
bert.encoder.layer.7.intermediate.dense, sparsity = 69.99
bert.encoder.layer.7.output.dense, sparsity = 69.99
bert.encoder.layer.8.intermediate.dense, sparsity = 69.99
bert.

We can also see what fraction of total parameters remain in our pruned model:

In [36]:
bert_model.num_parameters()

107318978

In [37]:
prunebert_model.num_parameters()

67664378

In [31]:
prunebert_model.num_parameters() / bert_model.num_parameters()

0.6304977857690743

In [32]:
from time import perf_counter

def compute_latencies(model,
                      question="Is Saving Private Ryan based on a book?",
                      passage="""In 1994, Robert Rodat wrote the script for the film. Rodat’s script was submitted to
                      producer Mark Gordon, who liked it and in turn passed it along to Spielberg to direct. The film is
                      loosely based on the World War II life stories of the Niland brothers. A shooting date was set for
                      June 27, 1997"""):
    inputs = tokenizer(question, passage, truncation="only_second", return_tensors="pt")
    latencies = []

    # Warmup
    for _ in range(10):
        _ = model(**inputs)

    for _ in range(100):
        start_time = perf_counter()
        _ = model(**inputs)
        latency = perf_counter() - start_time
        latencies.append(latency)
        # Compute run statistics
        time_avg_ms = 1000 * np.mean(latencies)
        time_std_ms = 1000 * np.std(latencies)
    print(f"Average latency (ms) - {time_avg_ms:.2f} +\- {time_std_ms:.2f}")
    return {"time_avg_ms": time_avg_ms, "time_std_ms": time_std_ms}

In [33]:
latencies = {}
latencies["prunebert"] = compute_latencies(prunebert_model.to("cpu"))

Average latency (ms) - 60.76 +\- 0.14


In [34]:
bert_unpruned = AutoModelForSequenceClassification.from_pretrained("lewtun/bert-base-uncased-finetuned-boolq").to("cpu")

latencies["bert-base"] = compute_latencies(bert_unpruned.to("cpu"))

Average latency (ms) - 104.06 +\- 3.07


In [None]:
#ref: https://github.com/huggingface/nn_pruning