# Example Merge Script

This example demonstrates how to merge two fine-tuned checkpoints using Fisher mask nodes. New tasks or architectures
may require changes to the code, but these are generally straightforward.

## Setup
We use `bert-tiny` with the tasks `mnli` and `sst2` for demonstration purposes.

In [4]:
import torch
from calc_fisher import calculate_fisher
from fisher_nodes import FisherNodeWrapper
from transformers import AutoModelForSequenceClassification, AutoTokenizer, BertModel

SEED = 0
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_SAMPLES = 128

checkpoint_names = ["M-FAC/bert-tiny-finetuned-mnli", "M-FAC/bert-tiny-finetuned-sst2"]
tasks = ["mnli", "sst2"]


In [5]:
tokenizer = AutoTokenizer.from_pretrained(checkpoint_names[0], model_max_length=512)
checkpoints = [AutoModelForSequenceClassification.from_pretrained(name) for name in checkpoint_names]


Because each task uses a different classification head, we keep track of the heads separately.

In [43]:
dropouts = [checkpoint.dropout for checkpoint in checkpoints]
classifiers = [checkpoint.classifier for checkpoint in checkpoints]

## Calculate Fisher Information

`calculate_fisher` returns `neuron_mask` and `head_mask`, the tensors holding the Fisher information of the inserted mask nodes.

In [6]:
neuron_masks = []
head_masks = []
for task, checkpoint in zip(tasks, checkpoints):
    neuron_mask, head_mask = calculate_fisher(
        model=checkpoint,
        task_name=task,
        tokenizer=tokenizer,
        num_samples=NUM_SAMPLES,
        device=DEVICE,
        seed=SEED,  # reuse the constant defined in Setup
    )
    
    neuron_masks.append(neuron_mask)
    head_masks.append(head_mask)
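
The internals of `calculate_fisher` are not shown in this example. As a rough illustration only, the sketch below assumes the score of a mask node is the empirical diagonal Fisher information, i.e. the mean squared gradient of the loss with respect to that mask variable, evaluated with the mask set to 1. The toy linear model and every name in the block are hypothetical and pure Python; the real function works on the checkpoint and task data.

```python
# Toy sketch: empirical diagonal Fisher information of mask variables.
# Assumption: a node's score is the mean squared gradient of the loss with
# respect to its mask variable m_i, evaluated at m_i = 1.

def mask_gradients(weights, x, target):
    """Gradient of 0.5 * (y - target)**2 w.r.t. each mask m_i at m = 1,
    where y = sum_i m_i * w_i * x_i."""
    y = sum(w * xi for w, xi in zip(weights, x))
    error = y - target
    return [error * w * xi for w, xi in zip(weights, x)]

def empirical_fisher(weights, samples):
    """Average the squared per-sample gradients over all samples."""
    totals = [0.0] * len(weights)
    for x, target in samples:
        for i, g in enumerate(mask_gradients(weights, x, target)):
            totals[i] += g * g
    return [t / len(samples) for t in totals]

toy_weights = [1.0, 0.0, 2.0]
toy_samples = [([1.0, 1.0, 1.0], 1.0), ([2.0, 1.0, 0.0], 0.0)]
toy_fisher = empirical_fisher(toy_weights, toy_samples)
print(toy_fisher)
```

In this toy setting the second node has zero weight, so masking it never changes the loss and its Fisher score is zero.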


## Merge

`FisherNodeWrapper` stores a checkpoint together with the Fisher information of its nodes. Because it implements the `__add__` method, summing the wrappers performs the merge.

In [7]:
model_wrappers = [
    FisherNodeWrapper(finetuned=checkpoint, neuron_mask=neuron_mask, head_mask=head_mask)
    for checkpoint, neuron_mask, head_mask in zip(checkpoints, neuron_masks, head_masks)
]
merged_params = sum(model_wrappers)
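
To make the role of `__add__` concrete, here is a minimal sketch of Fisher-weighted averaging, where each merged parameter is the sum of Fisher-weighted parameters divided by the sum of Fisher values. `ToyFisherWrapper` is a hypothetical stand-in written for this illustration; it is not the real `FisherNodeWrapper`, whose implementation may differ. The sketch also assumes an `__radd__` method, because the built-in `sum` starts from the integer `0`.

```python
class ToyFisherWrapper:
    """Hypothetical stand-in for FisherNodeWrapper (not the real class)."""

    def __init__(self, params, fisher):
        # Keep Fisher-weighted parameters and Fisher totals separately so
        # that addition is a plain element-wise sum of both lists.
        self.weighted = [f * p for f, p in zip(fisher, params)]
        self.fisher = list(fisher)

    def __add__(self, other):
        result = ToyFisherWrapper([], [])
        result.weighted = [a + b for a, b in zip(self.weighted, other.weighted)]
        result.fisher = [a + b for a, b in zip(self.fisher, other.fisher)]
        return result

    def __radd__(self, other):
        # sum() begins with 0 + first_item, so 0 + wrapper returns the wrapper.
        return self if other == 0 else NotImplemented

    def merged(self, eps=1e-12):
        # Fisher-weighted average; eps guards against division by zero.
        return [w / (f + eps) for w, f in zip(self.weighted, self.fisher)]


toy_a = ToyFisherWrapper(params=[1.0, 4.0], fisher=[3.0, 1.0])
toy_b = ToyFisherWrapper(params=[5.0, 0.0], fisher=[1.0, 1.0])
toy_merged = sum([toy_a, toy_b]).merged()
print(toy_merged)
```

In the toy example the first parameter is pulled toward `toy_a` (Fisher 3 versus 1), while the second is a plain average because both Fisher values are equal.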

`apply_to` sets the parameters of the given model to those of the merged model.

In [27]:
model = BertModel.from_pretrained('prajjwal1/bert-tiny')
model = merged_params.apply_to(model)

## Test

We test the merged model with one example from each task's dataset. Because BERT needs a separate head for each task, we use the matching head stored in `classifiers` and `dropouts`. Other architectures may not need this step.
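
The pattern used in both tests is a shared encoder output routed through a task-specific head. The pure-Python sketch below illustrates only that routing; `linear_head`, `toy_heads`, `predict`, and the weights are hypothetical, and in the notebook the heads are the `torch` modules stored in `classifiers` and `dropouts`.

```python
# Toy sketch: one shared feature vector routed through per-task heads.
# All names and weights are hypothetical and chosen only for illustration.

def linear_head(weights, bias):
    """Return a function mapping a feature vector to one logit per class."""
    def head(features):
        return [sum(w * f for w, f in zip(row, features)) + b
                for row, b in zip(weights, bias)]
    return head

toy_heads = {
    "mnli": (linear_head([[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]], [0.0, 0.0, 0.0]),
             ["entailment", "neutral", "contradiction"]),
    "sst2": (linear_head([[-1.0, 0.0], [1.0, 0.0]], [0.0, 0.0]),
             ["negative", "positive"]),
}

def predict(task, features):
    """Select the head for `task`, compute logits, and return the top label."""
    head, labels = toy_heads[task]
    logits = head(features)
    best = max(range(len(logits)), key=logits.__getitem__)
    return labels[best]

shared_features = [0.5, 2.0]  # stand-in for the merged encoder's pooled output
print(predict("mnli", shared_features), predict("sst2", shared_features))
```

The same feature vector yields a three-way prediction for one task and a two-way prediction for the other, which is why the heads are kept outside the merged encoder.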

### MNLI

In [64]:
premise = "Fun for adults and children."
hypothesis = "Fun for only children."
mnli_labels = ['entailment', 'neutral', 'contradiction']

inputs = tokenizer(premise, hypothesis, return_tensors="pt", padding=True, truncation=True)

with torch.no_grad():
    bert_output = model(**inputs).pooler_output
    logits = classifiers[0](dropouts[0](bert_output))

predicted_class_id = logits.argmax().item()
mnli_labels[predicted_class_id]

'contradiction'

### SST2

In [63]:
sentence = "equals the original and in some ways even betters it"
sst2_labels = ['negative', 'positive']

inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True)

with torch.no_grad():
    bert_output = model(**inputs).pooler_output
    logits = classifiers[1](dropouts[1](bert_output))

predicted_class_id = logits.argmax().item()
sst2_labels[predicted_class_id]

'positive'