In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

device = "cuda"  # device the model weights are moved to after loading

# Single source of truth for the checkpoint used by the tokenizer and the model.
MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.2"

# Load in half precision, then move the weights onto `device`.
merge_model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
merge_model = merge_model.to(device)

## Pipeline to generate a single token

In [None]:
from datasets import load_dataset

# CNN/DailyMail news articles, config "1.0.0"; indexed by split name (e.g. df["train"]).
df = load_dataset(path="cnn_dailymail", name="1.0.0")

In [None]:
from tqdm.auto import tqdm
from transformers import Pipeline
from torch import Tensor
from adpated_forward_call import run_merge

class MyPipeline(Pipeline):
    """Pipeline that runs one token-merging forward pass per input text and
    returns the top-5 next-token candidates at the last position.

    The merging hyper-parameters are read from instance attributes that must be
    assigned after construction (see the configuration cell):

    * ``sl`` -- maximum number of input tokens kept per text.
    * ``lc`` -- passed to ``run_merge`` as ``cutoff`` (layer at which merging is applied).
    * ``starting_tokens``, ``max_tokens_before_keeping_end``, ``ending_tokens`` --
      forwarded unchanged to ``run_merge``.
    """

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into (preprocess, forward, postprocess) kwargs.

        No per-call parameters are supported: all configuration is read from the
        instance attributes documented on the class. Forwarding unknown kwargs to
        ``preprocess`` would only raise ``TypeError``, since it takes none.
        """
        return {}, {}, {}

    def preprocess(self, inputs):
        """Tokenize one text and keep at most ``self.sl`` tokens.

        Returns ``{"model_input": input_ids}``, where ``input_ids`` is the
        tokenizer's PyTorch tensor (batch dimension of 1 for a single string).
        """
        encoded = self.tokenizer(inputs, return_tensors="pt", max_length=self.sl, truncation=True)
        # Keep the tokenizer's integer dtype: token ids index the embedding table.
        # Do not wrap them in the legacy torch.Tensor(...) constructor, which is a float tensor type.
        model_input = encoded["input_ids"][:, : self.sl]
        return {"model_input": model_input}

    def _forward(self, model_inputs):
        """Run the token-merging forward pass and return its logits and length."""
        logits, length = run_merge(
            tokens=model_inputs["model_input"],
            cutoff=self.lc,
            starting_tokens=self.starting_tokens,
            max_tokens_before_keeping_end=self.max_tokens_before_keeping_end,
            ending_tokens=self.ending_tokens,
            # NOTE(review): the model cannot be passed positionally after keyword
            # arguments. Confirm that run_merge names this parameter `model`.
            model=self.model,
        )
        return {"logits": logits, "length": length}

    def postprocess(self, model_outputs):
        """Return the 5 largest logits at the last position of the first sequence,
        their token ids, and the ``length`` reported by ``run_merge``.
        """
        # Select the last position before topk: this gives the same result as topk
        # over every position followed by [0, -1, :], without the wasted work.
        last_logits = model_outputs["logits"][0, -1, :]
        top_5_l, top_5_i = torch.topk(last_logits, k=5, dim=-1)
        # .cpu() is a no-op if the base Pipeline already moved the outputs to the CPU.
        return {"top_5_l": top_5_l.cpu().numpy(),
                "top_5_i": top_5_i.cpu().numpy(),
                "length": model_outputs["length"]}

In [None]:
from transformers.pipelines.pt_utils import KeyDataset

# Fixed seed so the evaluation subset is the same after every kernel restart.
SHUFFLE_SEED = 42
N_EVAL_ARTICLES = 5000

check_df = df["train"].shuffle(seed=SHUFFLE_SEED).select(range(N_EVAL_ARTICLES))

# Feed only the "article" column to the pipeline.
k_df = KeyDataset(check_df, "article")

In [None]:
pipeline = MyPipeline(
    model=merge_model,
    tokenizer=tokenizer,
    device=0,
    num_workers=8,
)

# Merging hyper-parameters; MyPipeline reads these attributes each time it is called.
merge_settings = {
    "sl": 516,  # max sequence length (tokens) kept per article
    "lc": 8,  # layer at which the merging is applied
    "starting_tokens": 0,
    "max_tokens_before_keeping_end": 100,
    "ending_tokens": 32,
}
for setting_name, setting_value in merge_settings.items():
    setattr(pipeline, setting_name, setting_value)

In [None]:
# Run every article through the merging pipeline. Results are collected rather than
# printed one by one, which would flood the cell output with one dict per article.
single_token_results = list(tqdm(pipeline(k_df), total=len(k_df), desc="merged forward passes"))
single_token_results[:5]

## Pipeline to generate with merged tokens

In [None]:
from OptimizedMistral import OptimizedInferenceMistral

# Merging settings for generation, applied as attributes on the optimized model.
opt_model_settings = {
    "cutoff": 20,  # index of the layer at which merging is applied
    "max_tokens_before_keeping_end": 64,  # token count after which the last `ending_tokens` tokens are no longer merged
    "ending_tokens": 16,  # how many trailing tokens are left unmerged during generation
}

# NOTE(review): this loads a second full copy of the weights while `merge_model`
# still holds its own inner model; peak GPU memory is roughly doubled here.
opt_model = OptimizedInferenceMistral.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", torch_dtype=torch.float16)
opt_model.to(device)
for setting_name, setting_value in opt_model_settings.items():
    setattr(opt_model, setting_name, setting_value)

# Replace the inner model of `merge_model` in place. The MyPipeline instance built
# from `merge_model` also sees this change.
merge_model.model = opt_model

In [None]:
# Import under an alias: the bare name `pipeline` is already bound to the MyPipeline
# instance created earlier, and rebinding it breaks re-running the cells that use it.
from transformers import pipeline as hf_pipeline

generator = hf_pipeline("text-generation", model=merge_model, tokenizer=tokenizer, device=0, torch_dtype=torch.float16)
# Reuse EOS as the padding id (presumably the tokenizer defines no pad token).
# `tokenizer` is the same object passed to MyPipeline, so this changes it there too.
generator.tokenizer.pad_token_id = generator.tokenizer.eos_token_id

In [None]:
your_input: str = ''  # replace with the prompt to generate from

In [None]:
# Number of prompt tokens as counted by `tokenizer` with its default special-token handling.
# NOTE(review): the text-generation pipeline may tokenize with a different
# `add_special_tokens` setting depending on the transformers version, so this count
# can differ by the BOS token. Confirm.
length_of_prompt = len(tokenizer(your_input)["input_ids"])

# Presumably excludes the first `starting_tokens` tokens (the prompt) from merging;
# skip this line to merge the prompt as well.
generator.model.model.starting_tokens = length_of_prompt

MAX_NEW_TOKENS = 384
res = generator(your_input, max_new_tokens=MAX_NEW_TOKENS, return_full_text=False, pad_token_id=generator.tokenizer.eos_token_id)

# Display the completion as the cell output; a trailing assignment shows nothing.
res[0]["generated_text"]