In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hugging Face checkpoint shared by the model and its tokenizer.
checkpoint = "mistralai/Mistral-7B-v0.1"
# Device the model weights are moved onto.
device = "cuda"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Load the causal-LM weights and move them to `device` in one expression.
merge_model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)

In [None]:
from datasets import load_dataset
# Configuration "1.0.0" of the CNN/DailyMail dataset; indexed by split name
# (e.g. df["train"]) further down.
df = load_dataset(path="cnn_dailymail", name="1.0.0")

In [None]:
import torch
from torch import Tensor
from tqdm.auto import tqdm
from transformers import Pipeline

from adpated_forward_call import run_merge

class MyPipeline(Pipeline):
    """Custom pipeline: tokenize a text, run ``run_merge``, report the top-5 last-position logits.

    Two attributes must be assigned on the instance before it is called
    (they are read by ``preprocess`` and ``_forward`` respectively):

    * ``sl`` -- maximum number of tokens kept per input.
    * ``lc`` -- second positional argument forwarded to ``run_merge``
      (presumably a layer-cut setting, judging by the name -- confirm
      against ``run_merge``).
    """

    # Declared for readers and type checkers only; no default value is set.
    sl: int
    lc: int

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into (preprocess, forward, postprocess) kwargs.

        Only ``maybe_arg`` is recognised, and it is routed to ``preprocess``.
        """
        preprocess_kwargs = {}
        if "maybe_arg" in kwargs:
            preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
        return preprocess_kwargs, {}, {}

    def preprocess(self, inputs, maybe_arg=None):
        """Tokenize ``inputs`` and return its token ids, capped at ``self.sl`` tokens.

        Args:
            inputs: Text handed to ``self.tokenizer``.
            maybe_arg: Accepted because ``_sanitize_parameters`` forwards it
                here; currently unused.

        Returns:
            ``{"model_input": input_ids}`` with the ids as returned by the
            tokenizer (``return_tensors="pt"``).
        """
        encoded = self.tokenizer(
            inputs, return_tensors="pt", max_length=self.sl, truncation=True
        )
        # Keep the ids in the tokenizer's integer dtype: the legacy
        # ``Tensor(...)`` constructor is float-typed and is not meant for
        # wrapping integer token ids.
        input_ids = encoded["input_ids"][:, : self.sl]
        return {"model_input": input_ids}

    def _forward(self, model_inputs):
        """Run ``run_merge`` on the token ids with ``self.lc`` and ``self.model``.

        Returns:
            ``{"logits": ..., "length": ...}`` -- the two values returned by
            ``run_merge``, in that order.
        """
        logits, length = run_merge(model_inputs["model_input"], self.lc, self.model)
        return {"logits": logits, "length": length}

    def postprocess(self, model_outputs):
        """Extract the five largest logits (and their indices) at the last position.

        Returns:
            A dict with ``top_5_l`` (values) and ``top_5_i`` (indices) as NumPy
            arrays for the first batch element, plus ``length`` passed through
            unchanged.
        """
        # Select first batch element / last position first, so top-k runs on
        # that slice only instead of on every position.
        last_logits = model_outputs["logits"][0, -1]
        top_5_l, top_5_i = torch.topk(last_logits, k=5, dim=-1)
        return {
            # detach()/cpu() are no-ops for CPU tensors without grad and make
            # the NumPy conversion safe otherwise.
            "top_5_l": top_5_l.detach().cpu().numpy(),
            "top_5_i": top_5_i.detach().cpu().numpy(),
            "length": model_outputs["length"],
        }

In [None]:
from transformers.pipelines.pt_utils import KeyDataset

# Fixed seed so the same articles are sampled on every Restart & Run All.
SAMPLE_SEED = 42
N_SAMPLES = 5000

train_split = df["train"]
# Cap at the split size so `select` never indexes past the end.
n_selected = min(N_SAMPLES, len(train_split))
check_df = train_split.shuffle(seed=SAMPLE_SEED).select(range(n_selected))

# Expose only the "article" column of each row to the pipeline.
k_df = KeyDataset(check_df, "article")

In [None]:
# Wrap the already-loaded model and tokenizer in the custom pipeline.
pipeline = MyPipeline(
    model=merge_model,
    tokenizer=tokenizer,
    device=0,
    num_workers=8,
)

# Knobs read by MyPipeline.preprocess (sl) and MyPipeline._forward (lc).
pipeline.sl, pipeline.lc = 516, 8  # Change these

In [None]:
seq_lengths = [4, 8, 16]
layer_cut = [40, 28, 16, 12, 4, 2]
# NOTE(review): the two lists above are not read anywhere in this cell;
# kept as-is in case a later cell sweeps over them.

# Collect the outputs under a progress bar instead of printing one dict per
# article, which floods the cell output and discards the results.
results = list(tqdm(pipeline(k_df), total=len(k_df), desc="run_merge"))

# Preview only the first few results.
results[:5]