In [1]:
!git clone https://github.com/pasapas321/spec.git

Cloning into 'spec'...
remote: Enumerating objects: 81, done.[K
remote: Counting objects: 100% (81/81), done.[K
remote: Compressing objects: 100% (73/73), done.[K
remote: Total 81 (delta 23), reused 0 (delta 0), pack-reused 0 (from 0)[K
Receiving objects: 100% (81/81), 582.51 KiB | 6.00 MiB/s, done.
Resolving deltas: 100% (23/23), done.


In [2]:
import sys

sys.path.append('/kaggle/working/spec')

In [1]:
!pip install -U "huggingface_hub[cli]"

Collecting huggingface_hub[cli]
  Downloading huggingface_hub-0.31.4-py3-none-any.whl.metadata (13 kB)
Collecting InquirerPy==0.3.4 (from huggingface_hub[cli])
  Downloading InquirerPy-0.3.4-py3-none-any.whl.metadata (8.1 kB)
Collecting pfzy<0.4.0,>=0.3.1 (from InquirerPy==0.3.4->huggingface_hub[cli])
  Downloading pfzy-0.3.4-py3-none-any.whl.metadata (4.9 kB)
Downloading InquirerPy-0.3.4-py3-none-any.whl (67 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.7/67.7 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading huggingface_hub-0.31.4-py3-none-any.whl (489 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m489.3/489.3 kB[0m [31m10.9 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hDownloading pfzy-0.3.4-py3-none-any.whl (8.5 kB)
Installing collected packages: pfzy, InquirerPy, huggingface_hub
  Attempting uninstall: huggingface_hub
    Found existing installation: huggingface-hub 0.29.0
    Uninstalling huggingface-hub-0

In [6]:
!pip install optimum

Collecting optimum
  Downloading optimum-1.24.0-py3-none-any.whl.metadata (21 kB)
Downloading optimum-1.24.0-py3-none-any.whl (433 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m433.6/433.6 kB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0mta [36m0:00:01[0m
[?25hInstalling collected packages: optimum
Successfully installed optimum-1.24.0


In [4]:
!pip install -U bitsandbytes

Collecting bitsandbytes
  Downloading bitsandbytes-0.45.5-py3-none-manylinux_2_24_x86_64.whl.metadata (5.0 kB)
Downloading bitsandbytes-0.45.5-py3-none-manylinux_2_24_x86_64.whl (76.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.1/76.1 MB[0m [31m22.8 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.45.5


### SpecExec c lossy

In [8]:
"""
Base class for SpecInfer, SpecExec and whatever comes next
"""

from abc import ABC, abstractmethod

import numpy as np
import torch

from specdec.trees import Tree
from specdec import utils

if "logger" not in globals():
    logger = utils.get_logger()


class SpecBase(ABC):
    def __init__(self, draft_engine, target_engine, tokenizer):
        self.draft_engine = draft_engine
        self.target_engine = target_engine
        self.tokenizer = tokenizer
        self.device = self.draft_engine.device

    def generate(self, *args, **kwargs):
        """wrapper around generator"""
        for _ in self.generate_and_yield(*args, **kwargs):
            pass
        return self.prefix_tokens

    @torch.inference_mode()
    def generate_and_yield(
        self,
        prompt,
        max_new_tokens,
        lossy_const=2,
        seed=None,
        verbose_output=False,
        **kwargs,
    ):
        if kwargs:
            logger.debug(f"Found unused {kwargs=}")

        self.prefix_tokens = self.tokenizer.encode(prompt)
        self.original_num_tokens = len(self.prefix_tokens)

        #logger.info(f"{self.__class__.__name__} starting generation.")
        logger.debug(f"Prompt: '{prompt}'")

        self.history = []
        self.log = []
        self.summary = {
            **kwargs,
            "draft_model_name": self.draft_engine.config._name_or_path,
            "target_model_name": self.target_engine.config._name_or_path,
            "prompt_len": len(self.prefix_tokens),
            "prompt_text": prompt,
            "seed": seed,
            "max_new_tokens": max_new_tokens,
        }

        utils.set_seed(seed)
        torch.cuda.reset_peak_memory_stats()

        if hasattr(self, "tree"):
            del self.tree
        self.reset_engines(prefix_len=len(self.prefix_tokens), max_new_tokens=max_new_tokens, **kwargs)

        self.tree = Tree(prefix_tokens=self.prefix_tokens, draft_engine=self.draft_engine, tokenizer=self.tokenizer)
        self.levels = self.get_tree_levels(**kwargs)  # in case the child class works with fixed trees

        # warmup:
        logger.debug("=====  W A R M U P  ========")
        with utils.Timing(synchronize=True) as tw0:
            stats0 = self.grow_tree(prefix_tokens=self.prefix_tokens, **kwargs)
        with utils.Timing(synchronize=True) as tw1:
            stats1, warmup_tokens = self.validate_tree(lossy_const=lossy_const, **kwargs)
        torch.cuda.empty_cache()

        logger.debug(f"warmup time={tw0.elapsed + tw1.elapsed:.3f}; generated {len(warmup_tokens)} tokens.")
        self.prefix_tokens.extend(warmup_tokens)

        # main generation cycle
        iter = 1
        test_time = 0
        eos_flag = False

        while len(self.prefix_tokens) < max_new_tokens + self.original_num_tokens + len(warmup_tokens) and not eos_flag:
            logger.debug(f"=====  I T E R  {iter}  ========")

            with utils.Timing(synchronize=True) as t0:
                stats0 = self.grow_tree(prefix_tokens=self.prefix_tokens, **kwargs)
            with utils.Timing(synchronize=True) as t1:
                stats1, fresh_tokens = self.validate_tree(lossy_const=lossy_const, **kwargs)
            test_time += t0.elapsed + t1.elapsed
            torch.cuda.empty_cache()

            # logger.info(
            #     f"{iter:>3}.  "
            #     + f"Draft: {t0.elapsed:.3f}s, {stats0['tree_w']}w/{stats0['tree_h']}h/{stats0['tree_size']}size;  "
            #     + f"Target: {t1.elapsed:.3f}s, +{len(fresh_tokens)} tokens: {self.tokenizer.convert_ids_to_tokens(fresh_tokens)};  inp1:{stats1['input_len_1']}"
            # )

            if self.tokenizer.eos_token_id in fresh_tokens:
                fresh_tokens = fresh_tokens[: fresh_tokens.index(self.tokenizer.eos_token_id)]
                eos_flag = True
            self.prefix_tokens.extend(fresh_tokens)

            log1 = {
                "iter": iter,
                "t0": round(t0.elapsed, 2),
                "t1": round(t1.elapsed, 2),
                "new_tokens": len(fresh_tokens),
            }
            self.log.append({**log1, **stats0, **stats1})
            iter += 1
            yield fresh_tokens

        self.summary.update(
            {
                "iters": len(self.log),
                "new_tokens": len(self.prefix_tokens) - self.original_num_tokens - len(warmup_tokens),
                "tree_h": round(np.mean([x.get("tree_h") for x in self.log]), 1),
                "tree_w": int(np.mean([x.get("tree_w") for x in self.log])),
                "tree_size": int(np.mean([x.get("tree_size") for x in self.log])),
                "t0": round(sum([x.get("t0", 0) for x in self.log]) / len(self.log), 4),
                "t1": round(sum([x.get("t1", 0) for x in self.log]) / len(self.log), 4),
                "tft": round(tw0.elapsed + tw1.elapsed, 4),
                "input_0": int(sum([x.get("input_len_0", 0) for x in self.log]) / len(self.log)),
                "input_1": int(sum([x.get("input_len_1", 0) for x in self.log]) / len(self.log)),
                "min_CLP": round(np.mean([x.get("lowest_cum_log_prob", 0) for x in self.log]), 2),
                "draft_iters": round(np.mean([x.get("draft_iters", 0) for x in self.log]), 1),
                "mem_use": round(torch.cuda.max_memory_allocated() / 2**30, 2),
            }
        )
        self.summary["gen_rate"] = round(self.summary["new_tokens"] / self.summary["iters"], 1)
        self.summary["speed"] = round(self.summary["new_tokens"] / test_time, 2)
        logger.debug(f"\nResult tokens: {self.prefix_tokens}\n string:  {repr(self.tokenizer.decode(self.prefix_tokens))}")

        if verbose_output:
            print("Prompt:", "." * 80)
            print(utils.colored(prompt, "GREEN"))
            print("Generated", "." * 80)
            print(utils.colored(repr(self.tokenizer.decode(self.prefix_tokens[self.original_num_tokens :])), "CYAN"))
            print("=" * 80)

    @abstractmethod
    def grow_tree(self, tree, **kwargs):
        pass

    @abstractmethod
    def validate_tree(self, **kwargs):
        pass

    def get_tree_levels(self, **kwargs):
        # sets self.levels to None unless a child class overrides this method
        pass

    def reset_engines(self, **kwargs):
        self.draft_engine.clear_kv()
        self.target_engine.clear_kv()

In [9]:
"""
SpecExec, version 4b
"""

import numpy as np
import torch
import logging
import math

from specdec import utils
from specdec.trees import TOKENS, POSITIONS, PARENTS, STATUS  # noqa: F401

logger = utils.get_logger()


class SpecExecBase(SpecBase):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def reset_engines(self, prefix_len, max_budget, max_new_tokens, **kwargs):
        super().reset_engines()
        safety_margin = 64  # account for warmup tokens and excess generation
        draft_max_len = prefix_len + max_new_tokens + max_budget + safety_margin + 256
        target_max_len = prefix_len + max_new_tokens + max_budget + safety_margin
        if hasattr(self, "tree"):
            self.tree.set_max_len(draft_max_len)
        else:
            self.draft_engine.set_max_len(draft_max_len)

        self.target_engine.set_max_len(target_max_len)
        #logger.info(f"Max_len reset: {draft_max_len=}, {target_max_len=}")

    @torch.inference_mode()
    def grow_tree(self, max_budget, max_beam_len, max_n_beams=32, max_branch_width=16, min_log_prob=-10, decay_factor=0.95, **kwargs):
        """grows speculative tree
        Args:
            Engines, tokenizer,
            max_tokens (_type_): maximum tree size in tokens
            max_beam_length (_type_): number of growth iterations
            max_n_beams (int, optional): number of tokens considered at each iteration
            max_branch_width (int, optional): max number of children per branch.
            min_log_prob: limit for proba of added tokens. Defaults to -10.
        Returns:
            statistics and tree
        """

        logger.debug(f"=================  G R O W  {self.__class__.__name__}  ================================================== ")
        self.decay_factor_log = math.log(decay_factor)
        expected_ratios = []
        input_tokens_count = []  # for logging
        self.max_budget = max_budget
        pick_best = self.draft_engine.config.model_type in ["llama"]  # list models that support `cache_position` argument

        for next_position in range(self.tree.prefix_len, self.tree.prefix_len + max_beam_len):
            logger.debug(f"{next_position=} ----------------------- {self.tree.end=}")

            draft_logits, parent_indices, parent_scores, parent_positions = self.tree.process_candidates(lim=max_n_beams, pick_best=pick_best)

            logger.debug(f"after process_candidates {draft_logits.shape=}, {self.tree.end=}")

            best_child_token_ids, best_child_positions, best_parent_idxs, cum_beam_scores = self.get_next_beams(
                draft_logits,
                parent_pos=parent_positions,
                parent_idxs=parent_indices,
                beam_scores=parent_scores,
                num_beams=max_budget,
                decay_factor_log=self.decay_factor_log,
                max_branch_width=max_branch_width,
            )

            input_tokens_count.append(best_child_token_ids.numel())

            if best_child_token_ids.shape[-1] == 0:  # no new beams
                logger.debug("no children offered")
                # break if no child generated (possibly due to min_log_prob too high for the budget)
                break
            # break if new tokens are not in top budget by log prob
            if self.tree.end - self.tree.prefix_len >= max_budget:
                lowest_tree_log_prob = self.tree.log_probs[self.tree.prefix_len : self.tree.end].topk(k=max_budget, dim=-1, sorted=False).values.min()
                best_new_log_prob = cum_beam_scores.max()
                if best_new_log_prob <= lowest_tree_log_prob:
                    logger.debug(f"early stop: pos {next_position}: best_new={best_new_log_prob:.2f} <= lowest_tree_prob={lowest_tree_log_prob:.2f}")
                    break
            self.tree.add(best_child_token_ids, best_child_positions, best_parent_idxs, cum_beam_scores, new_status=self.tree.CANDIDATE)

        if self.tree.end - self.tree.prefix_len > max_budget:
            logger.debug(f"Have to trim: Draft token count: {self.tree.end - self.tree.prefix_len} > max_budget {max_budget}")
            self.tree.trim_budget(budget=max_budget)

        stats = {
            "tree_w": np.unique(self.tree.positions.tolist(), return_counts=True)[1].max(),
            "tree_h": self.tree.positions.max().item() - self.tree.prefix_len + 1,
            "tree_size": self.tree.size - self.tree.prefix_len,  # tree size net of prefix len
            "input_len_0": sum(input_tokens_count),
            "draft_iters": next_position - self.tree.prefix_len + 1 if "next_position" in locals() else 0,
            "lowest_cum_log_prob": round(self.tree.log_probs[: self.tree.end].min().item(), 4),
        }
        logger.debug(f"input_tokens_count: {sum(input_tokens_count)}, {input_tokens_count}")
        logger.debug(
            f"tree layer sizes: {torch.unique(self.tree.positions[self.tree.prefix_len:], return_counts=True)[1].tolist()}"
        )  # Tree nodes counts by level
        #logger.info(f"{stats}")

        return stats

    @torch.inference_mode()
    def validate_tree(self, temperature=1.0, top_p=1.0, lossy_const=2, **kwargs):
        """validation of the generated sequences with Target model"""
        logger.debug(f"=================  V A L I D A T E   {self.__class__.__name__}   ============================")
        target_token_map_bool = self.tree.status[: self.tree.end] >= self.tree.PROCESSED  # tokens generated in the current iteration
        target_token_map_bool[: self.tree.prefix_len] = False  # addresses problem of the last prefix token status
        target_token_idxs = torch.where(target_token_map_bool)[0]
        target_parent_idxs = self.tree.parents[: self.tree.end][target_token_map_bool]

        input_token_map_bool = target_token_map_bool.clone()  # tokens needed as target_engine forward inputs
        input_token_map_bool[target_parent_idxs] = True  # inputs for fwd
        if self.target_engine.kv_len_used == 0:
            input_token_map_bool[: self.tree.prefix_len] = True

        if self.tree.end > self.target_engine.max_len:
            #logger.info(f"target_engine max_len expands from {self.target_engine.set_max_len} to {self.tree.end}")
            self.target_engine.set_max_len(self.tree.end)

        input_ids = self.tree.tokens[input_token_map_bool].unsqueeze(0)
        cache_position = torch.where(input_token_map_bool)[0]
        amask_target = self.tree.amask[:, :, cache_position, : self.target_engine.max_len]  # clipping to target max_len
        position_ids = self.tree.positions[input_token_map_bool].unsqueeze(0)
        #logger.info(f"VAL {input_ids.shape=}, {amask_target.shape=}, {self.target_engine.kv_len_used=}, {self.tree.prefix_len=}, {self.tree.end=}")

        target_logits = self.target_engine.forward(
            input_ids=input_ids,
            attention_mask=self.tree.invert_mask(amask_target, dtype=self.target_engine.model.dtype),
            position_ids=position_ids,
            cache_position=cache_position,
        )
        target_logits = target_logits.squeeze(0)  # remove batch dim
        draft_token_choices = self.tree.tokens[target_token_map_bool]
        
        all_target_token_choices, all_target_token_logprobs = self.sampler_from_logits(logits=target_logits, 
                                                                                       draft_token_choices=draft_token_choices, 
                                                                                       lossy_const=lossy_const, 
                                                                                       temperature=temperature, 
                                                                                       top_p=top_p)

        # Matching target and draft choices to find the longest accepted sequence
        interim_t = torch.ones_like(self.tree.tokens)
        interim_t[input_token_map_bool] = all_target_token_choices

        target_token_choices = interim_t[target_parent_idxs]

        # get accept_mask
        accept_flags = draft_token_choices == target_token_choices  # flags of positions where draft & target match in <target_token_idxs space>
        accept_idxs = target_token_idxs[accept_flags]  # indices of positions where draft & target match

        accept_mask = torch.zeros(self.tree.end, device=self.device)  # mask for selecting rows from amask
        accept_mask[: self.tree.prefix_len] = 1  # assume whole prefix accepted
        accept_mask[accept_idxs] = 1  # add accepted idxs based on draft==target
        accepted_amask = amask_target[0, 0, :, : self.tree.end] * accept_mask

        mask_row_sums = amask_target[0, 0, :, : self.tree.end].sum(axis=1)

        # get the best sequence
        seq_lengths = accepted_amask.sum(axis=1)
        best_sequence_index = (mask_row_sums * (mask_row_sums == seq_lengths)).argmax()
        best_sequence_mask = amask_target[0, 0, best_sequence_index, : self.tree.end].to(torch.bool)

        fresh_token_idxs = torch.where(best_sequence_mask[self.tree.prefix_len :])[0] + self.tree.prefix_len
        fresh_token_ids = self.tree.tokens[fresh_token_idxs].tolist()

        last_accepted_token_position = fresh_token_idxs[-1] if fresh_token_idxs.numel() else self.tree.prefix_len - 1
        self.tree.reset_to_sequence(best_sequence_mask, target_engine=self.target_engine)

        # Generate one extra token based on target model logits
        extra_token_id = interim_t[last_accepted_token_position]  # all_target_token_choices[last_accepted_token_position - logits_offset]
        self.tree.add(
            token_ids=extra_token_id,
            positions=self.tree.positions[self.tree.size - 1] + 1,
            parent_idxs=torch.tensor([self.tree.size - 1], device=self.device),
            log_probs=torch.tensor(0.0, device=self.device),
            new_status=self.tree.CANDIDATE,
        )
        self.tree.prefix_len = self.tree.end
        self.tree.data[STATUS, : self.tree.prefix_len - 1] = self.tree.PROMPT

        fresh_token_ids.append(extra_token_id.item())

        if logger.level <= logging.DEBUG:
            logger.debug(f"{extra_token_id=}, '{self.tokenizer.convert_ids_to_tokens(extra_token_id.item())}'")
            logger.debug(f"sampled {len(fresh_token_ids)} tokens: {fresh_token_ids} {self.tokenizer.convert_ids_to_tokens(fresh_token_ids)}")
        stats = {"input_len_1": input_ids.shape[-1], "cache_len_1": self.target_engine.kv_len_used, "accepted_tokens": len(fresh_token_ids)}

        return stats, fresh_token_ids

    @staticmethod
    @torch.inference_mode()
    def sampler_from_logits(logits, draft_token_choices, lossy_const=2, temperature=1.0, top_p=0.9, min_tokens_to_keep=1):
        """
        Performs token sampling from logits using top-p (nucleus) sampling or deterministic selection.
        Args:
            logits (torch.Tensor): Logits from a language model.
            temperature (float): Adjusts distribution sharpness (higher = more random);  0 for greedy.
            top_p (float): Cumulative probability threshold for top-p sampling.
            min_tokens_to_keep (int): Minimum tokens to keep regardless of top_p.
        Returns: Tuple[torch.Tensor, torch.Tensor]: Indices and log probabilities of selected tokens.
        """

        if temperature > 0:
            if temperature != 1:
                scores = logits / temperature  # Apply temperature scaling
            else:
                scores = logits

            if top_p != 1.0:
                # Sort scores in descending order for top-p sampling
                sorted_logits, sorted_indices = torch.sort(scores, descending=True)
                cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)

                # Create a mask to remove logits not in the top-p
                sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
                sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0  # Keep at least min_tokens_to_keep tokens

                # Scatter the indices to the original order and mask the logits
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                scores.masked_fill_(indices_to_remove, -float("inf"))

            # Sampling from the filtered logits

            # for i, j in enumerate(range(len(scores))[-len(draft_token_choices):]):
            #     scores[j][draft_token_choices[i]] += (torch.median(scores) * lossy_const)
            #print(torch.median(scores))

            #print(torch.std(scores))
            
            K = len(draft_token_choices)
            last_K_rows_indices = torch.arange(len(scores))[-K:]
            # scores_update = abs(torch.median(scores)) * lossy_const
            # scores[last_K_rows_indices, draft_token_choices] += 20
            scores_update = torch.std(scores) * lossy_const
            scores[last_K_rows_indices, draft_token_choices] += scores_update

            
            probs = torch.softmax(scores, dim=-1)
            # print('probs', probs, probs.shape)
            _ = torch.multinomial(probs, 1)[:, 0]  # warmup
            selection = torch.multinomial(probs, 1)[:, 0]

            # Compute log probabilities
            logprobs = torch.log_softmax(scores, dim=-1)
            logprobs = torch.index_select(logprobs, 1, selection).diag()

        else:
            # Greedy selection
            selection = torch.argmax(logits, dim=-1)
            logprobs = torch.zeros_like(selection)

        return selection.to(logits.device), logprobs.to(logits.device)

    @torch.inference_mode()
    def get_next_beams(self, logits, parent_pos, parent_idxs, beam_scores, num_beams=None, min_log_prob=None, decay_factor_log=0, max_branch_width=4, **kwargs):
        """
        produces up to num_beams top beams by cumulative log_prob
        with log_prob >= min_log_prob limit
        decay_factor_log - accounts for uncertainty in acceptance of the draft tokens
        """

        logprobs = torch.log_softmax(logits, dim=-1)  # shape: [n_beams, voc_size]
        logprobs_k = logprobs.topk(k=max_branch_width, dim=-1, sorted=False)
        leaves_ids = logprobs_k.indices
        leaves_p = logprobs_k.values

        flat_incoming_probs = (beam_scores.unsqueeze(-1) + decay_factor_log + leaves_p).flatten()
        flat_incoming_ids = leaves_ids.flatten()
        sorted_incoming_probs = flat_incoming_probs.sort(descending=True)
        flat_sorted_log_probs = sorted_incoming_probs.values
        flat_sorted_indices = sorted_incoming_probs.indices

        joint_probs = torch.concat(
            [self.tree.log_probs[self.tree.prefix_len : self.tree.end], flat_sorted_log_probs]
        )  # existing + new probs, for finding threshold

        if joint_probs.shape[-1] > num_beams or joint_probs.shape[-1] + (self.tree.end - self.tree.prefix_len) > self.tree.max_len:
            min_joint_prob = joint_probs.topk(k=num_beams, sorted=False, dim=-1).values.min()

            flat_best_mask = torch.where(flat_sorted_log_probs >= min_joint_prob)[0]
            flat_best_probs = flat_sorted_log_probs[flat_best_mask]
            flat_best_indices = flat_sorted_indices[flat_best_mask]
            best_child_token_ids = flat_incoming_ids[flat_best_indices]

            if flat_best_indices.shape[-1] + self.tree.end > self.tree.max_len:
                logger.debug(f"get_next_beams: trimming draft from {self.tree.end - self.tree.prefix_len} to {self.max_budget=} tokens; {self.tree.end}")
                interim_idx = self.tree.trim_budget(min_log_prob=min_joint_prob)
                parent_idxs = interim_idx[parent_idxs]
        else:
            flat_best_probs = flat_sorted_log_probs
            flat_best_indices = flat_sorted_indices
            best_child_token_ids = flat_incoming_ids[flat_sorted_indices]

        best_hypo_ids = flat_best_indices // max_branch_width

        best_parent_idxs = parent_idxs[best_hypo_ids]
        best_child_pos = parent_pos[best_hypo_ids] + 1

        return best_child_token_ids, best_child_pos, best_parent_idxs, flat_best_probs

In [10]:
import argparse
import datetime
import json
import logging
import os
import socket
import subprocess
from itertools import product
from pathlib import Path

import pandas as pd
import torch
import transformers
from tqdm.auto import tqdm
from transformers import BitsAndBytesConfig, AutoModelForCausalLM

#from offloading.offload_model import load_gptq_offloaded_model, load_offloaded_model
from specdec import SpecExecBeams, SpecInfer, utils
import engine
from specdec.utils import colored

device = torch.device("cuda:0")
_DEFAULT_DEVICE_SIZE = 2
DISPLAY_WIDTH = 160
pd.set_option("display.width", DISPLAY_WIDTH)
pd.set_option("display.max_columns", 32)


def create_spec_generator(
    model_name_0,
    model_name_1,
    draft_engine_class,
    gen_type="SX",
    offload=False,
    device_size=_DEFAULT_DEVICE_SIZE,
    check_tokenizer=False,
    tree_max_len=4096
):
    """Creates a SpecGenerator object for different generation types.

    This function loads draft and target pre-trained language models specified by their names
    and creates a SpecBase subclass object based on the provided generation type.
    It also handles several configuration options like device placement and tokenizer verification.

    Args:
        model_name_0 (str): Name of the draft model.
        model_name_1 (str): Name of the target model.
        gen_type (str, optional): Generation type. Defaults to "SX" (SpecExec).
            Valid options include:
                - "SpecExecBase", : SpecExec generator
                - "SI", "spec_infer", "specinfer": SpecInfer generator
        offload (bool, optional): Whether to offload model 1 using offloading library. Defaults to False.
        device_size (int, optional): Device size for offloading. Defaults to `_DEFAULT_DEVICE_SIZE`.
        check_tokenizer (bool, optional): Whether to verify if both models have the same tokenizer. Defaults to False.

    Returns:
        SpecGenerator: An instance of a SpecBase subclass object based on the provided parameters.

    Raises:
        ValueError: If an invalid `gen_type` is provided.
    """

    if len(model_name_0.split("::")) == 2:
        model_name_0, rev_0 = model_name_0.split("::")
    else:
        rev_0 = "main"  # default in `from_pretrained()`

    if len(model_name_1.split("::")) == 2:
        model_name_1, rev_1 = model_name_1.split("::")
    else:
        rev_1 = "main"  # default in `from_pretrained()`

    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_0, legacy=False)

    if check_tokenizer:
        # verify that the two models have the same tokenizer
        tokenizer_1 = transformers.AutoTokenizer.from_pretrained(model_name_1, legacy=False)
        vv0 = tokenizer.get_vocab()
        vv1 = tokenizer_1.get_vocab()

        ignored_tokens = ["[PAD]"]  # disregard these tokens when comparing the cokonizers' vocabs
        assert set(vv0.keys()).difference(ignored_tokens) == set(vv1.keys()).difference(ignored_tokens)
        for k in set(vv0.keys()).difference(ignored_tokens):
            assert vv0[k] == vv1[k]
        del tokenizer_1, vv0, vv1

    #logger.info(f"Loading Model 0: `{model_name_0}`, {draft_engine_class=}")
    if draft_engine_class.lower() in ("es", "static", "enginestatic"):
        model_0 = transformers.AutoModelForCausalLM.from_pretrained(model_name_0, device_map=device, torch_dtype=torch.float16, revision=rev_0)
        draft_engine = engine.EngineStatic(model_0, max_len=tree_max_len)
    # elif draft_engine_class.lower() in ("esc", "staticcompiled", "enginestaticcompiled"):
    #     model_0 = transformers.AutoModelForCausalLM.from_pretrained(model_name_0, device_map=device, torch_dtype=torch.float16, revision=rev_0)
    #     draft_engine = engine.EngineStaticCompiled(model_0, max_len=args.tree_max_len)
    # elif draft_engine_class.lower() in ("ie", "inferenceengine"):
    #     draft_engine = engine.InferenceEngine(model_name_0, max_len=args.tree_max_len)
    elif draft_engine_class.lower() in ("padded", "inferenceenginepadded"):
        draft_engine = engine.InferenceEnginePadded(model_name_0, max_len=tree_max_len)
    elif draft_engine_class.lower() in ("er", "regular", "engineregular"):
        draft_engine = engine.EngineRegular(model_name_0, max_len=tree_max_len)
    else:
        raise ValueError(f"Unsupported engine class: {draft_engine_class} !")

    #logger.info(f"Loading Model 1: `{model_name_1}`")
    gptq_max_input_length = 16384  # constant for GPTQ models

    if offload:
        if "gptq" in model_name_1.lower():
            model_1 = load_gptq_offloaded_model(model_name_1, device_size=device_size, main_device=device, max_input_length=gptq_max_input_length)
        else:
            model_1 = load_offloaded_model(model_name_1, device_size=device_size, main_device=device)

    else:
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
        model_1 = AutoModelForCausalLM.from_pretrained(model_name_1,
                                                       quantization_config=quantization_config,
                                                       torch_dtype=torch.float16, 
                                                       revision=rev_1,
                                                       device_map=device)
        # model_1 = transformers.AutoModelForCausalLM.from_pretrained(model_name_1, device_map=device, torch_dtype=torch.float16, revision=rev_1)

        if "gptq" in model_name_1.lower():
            model_1_config = transformers.AutoConfig.from_pretrained(model_name_1)
            if getattr(model_1_config.quantization_config, "act_order", False) and (model_1_config.config.max_length < 16384):
                try:
                    from auto_gptq import exllama_set_max_input_length

                    model_1 = exllama_set_max_input_length(model_1, gptq_max_input_length)
                    print("set `exllama_set_max_input_length` OK")
                except (AttributeError, ValueError, ImportError):
                    # AttributeError may happen if GPTQ-quantized model has no attribute 'device_to_buffers'
                    # could be fixed by using code from post_init()
                    # ImportError resembles https://github.com/open-mmlab/mmdetection3d/issues/1152
                    logger.warning("Failed to set `exllama_set_max_input_length`")

    # target_engine = EngineStatic(model_1, max_len=args.tree_max_len)
    target_engine = engine.EngineRegular(model_1, max_len=tree_max_len)

    if gen_type.lower() in ("sx_base", "base", "sx2", "spec_exec_base", "specexecbase"):
        spec_generator = SpecExecBase(draft_engine, target_engine, tokenizer)
    elif gen_type.lower() in ("spec_exec_beams", "specexecbeams", "sx_beams"):
        spec_generator = SpecExecBeams(draft_engine, target_engine, tokenizer)
    elif gen_type.lower() in ("sa", "a", "spec_adaptive", "specadaptive"):
        spec_generator = SpecAdaptive(draft_engine, target_engine, tokenizer)
    elif gen_type.lower() in ("sf", "f", "spec_fixed", "specfixed"):
        spec_generator = SpecFixed(draft_engine, target_engine, tokenizer)
    elif gen_type.lower() in ("si", "spec_infer", "specinfer"):
        spec_generator = SpecInfer(draft_engine, target_engine, tokenizer)
    elif gen_type.lower() in ("sis", "spec_infer_stems", "specinferstems"):
        spec_generator = SpecInferStems(draft_engine, target_engine, tokenizer)
    else:
        raise ValueError(f"unknown {gen_type=}")

    #logger.info(f"Created spec_generator of type {gen_type}; Models: {model_name_0}, {model_name_1}")
    return spec_generator

[15:37:42.915] INFO: NumExpr defaulting to 4 threads.


### Обычная генерация

In [7]:
from typing import Union, Sequence, Tuple
import json
import os
import time

from tqdm.auto import tqdm
from termcolor import colored
import numpy as np
import pandas as pd

import itertools
import torch
from torch import nn
import torch.nn.functional as F
import transformers
import re
from transformers import BitsAndBytesConfig, AutoModelForCausalLM

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def extract_answer(s, suffix='<|eot_id|>'):
    s = s.lower().replace(suffix, '').replace('the final answer is', '=')
    idx = s.rfind("=")
    if idx != - 1:
        return s[idx + 1:].strip()
    

def extract_float(num_str):
    try:
        num_str = re.sub(r'[^0-9.-]', '', num_str).strip(".")
        return float(num_str)
    except (ValueError, TypeError):
        return

class args:
    draft_model = 'meta-llama/Llama-3.2-1B-Instruct'
    target_model = 'meta-llama/Llama-3.1-8B-Instruct'
    torch_dtype = 'auto'
    # data from https://github.com/openai/grade-school-math/tree/master/grade_school_math/data
    # use train.jsonl and test.jsonl
    gsm8k_test_path = '/kaggle/input/gsm888/test.jsonl'
    gsm8k_train_path = '/kaggle/input/gsm888/train.jsonl'
    random_seed = 42
    max_new_tokens = 1024
    n_samples = None


prompt_with_8_shots = "Given the following problem, reason and give a final answer to the problem.\nProblem: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?\nYour response should end with \"The final answer is [answer]\" where [answer] is the response to the problem.\n There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The final answer is 6\n\nGiven the following problem, reason and give a final answer to the problem.\nProblem: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?\nYour response should end with \"The final answer is [answer]\" where [answer] is the response to the problem.\n There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The final answer is 5\n\nGiven the following problem, reason and give a final answer to the problem.\nProblem: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?\nYour response should end with \"The final answer is [answer]\" where [answer] is the response to the problem.\n Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39. The final answer is 39\n\nGiven the following problem, reason and give a final answer to the problem.\nProblem: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?\nYour response should end with \"The final answer is [answer]\" where [answer] is the response to the problem.\n Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8. The final answer is 8\n\nGiven the following problem, reason and give a final answer to the problem.\nProblem: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?\nYour response should end with \"The final answer is [answer]\" where [answer] is the response to the problem.\n Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 4 more toys. 5 + 4 = 9. The final answer is 9\n\nGiven the following problem, reason and give a final answer to the problem.\nProblem: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?\nYour response should end with \"The final answer is [answer]\" where [answer] is the response to the problem.\n There were originally 9 computers. For each of 4 days, 5 more computers were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29. The final answer is 29\n\nGiven the following problem, reason and give a final answer to the problem.\nProblem: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?\nYour response should end with \"The final answer is [answer]\" where [answer] is the response to the problem.\n Michael started with 58 golf balls. After losing 23 on tuesday, he had 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls. The final answer is 33\n\nGiven the following problem, reason and give a final answer to the problem.\nProblem: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?\nYour response should end with \"The final answer is [answer]\" where [answer] is the response to the problem.\n Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8. The final answer is 8\n\nGiven the following problem, reason and give a final answer to the problem.\nProblem: "
prompt_with_0_shots = "Given the following problem, reason and give a final answer to the problem.\n"
formatting_prompt = "Your response should end with \"The final answer is [answer]\" where [answer] is the response to the problem."

np.random.seed(args.random_seed)

def load_questions(args):
    with open(args.gsm8k_test_path) as f:
        gsm_questions = [json.loads(line) for line in f]

    n_samples = args.n_samples or len(gsm_questions)
    gsm_questions = [
        {
            'question': i['question'],
            'answer': i['answer'][i['answer'].rfind('#### ') + 5:]
        }
        for i in gsm_questions[:n_samples]
    ]

    return gsm_questions

gsm_questions = load_questions(args)
#n_samples = len(gsm_questions)
n_samples = 64
print(f'Num samples: {n_samples}')

Num samples: 64


In [6]:
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
target_model = AutoModelForCausalLM.from_pretrained(args.target_model,
                                                    quantization_config=quantization_config,
                                                    device_map=device)
draft_model = AutoModelForCausalLM.from_pretrained(args.draft_model,
                                                   quantization_config=quantization_config,
                                                   device_map=device)

config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/184 [00:00<?, ?B/s]

AttributeError: type object 'args' has no attribute 'draft_model'

In [8]:
draft_model = AutoModelForCausalLM.from_pretrained(args.draft_model,
                                                   quantization_config=quantization_config,
                                                   device_map=device)

config.json:   0%|          | 0.00/877 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

In [9]:
tokenizer = transformers.AutoTokenizer.from_pretrained(
    'meta-llama/Llama-3.1-8B-Instruct', 
)
tokenizer.pad_token_id = tokenizer.eos_token_id

tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

In [80]:
def generate(input_ids, model, args):
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    torch.cuda.synchronize()
    start_event.record()
    
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=1024,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id
        )

    end_event.record()
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)

    return output_ids, elapsed_time_ms


correct = 0
n_samples = 256
times = []

with tqdm(total=n_samples) as pbar:
    for sample_idx in range(0, n_samples):
        question_sample = gsm_questions[sample_idx]
        answer = question_sample['answer']
        question = question_sample['question']
        formatted_zero_shot_prompt = prompt_with_8_shots + question + "\n" + formatting_prompt
        batch_input_ids = tokenizer.apply_chat_template(
            [{'role': 'user', 'content': formatted_zero_shot_prompt}],
            tokenize=True, return_tensors='pt', padding=True, continue_final_message=False  # <--- MIGHT BE A BUG
        ).to(device)
        generation, elapsed_time = generate(batch_input_ids, target_model, args)
        times.append(elapsed_time)
        generation_str = tokenizer.batch_decode(generation)[0]
        #print(generation_str)
        raw_pred = extract_answer(generation_str)

        pred_float = extract_float(raw_pred)
        answer_float = extract_float(answer)

        gen_tokens = generation.shape[-1] - batch_input_ids.shape[-1]
        verdict = int(answer_float == pred_float)

        correct += verdict
        accuracy = correct / (sample_idx + 1)

        print(
            f"{sample_idx:>4}.  verdict:{verdict:<1} "
            f"Averages:  acc:{accuracy:1.3f} "
            f"pred: {pred_float}, answer: {answer_float} "
            f"time: {np.round(elapsed_time / 1000, 2)}s"
        )
        pbar.update(1)

  0%|          | 0/256 [00:00<?, ?it/s]

   0.  verdict:1 Averages:  acc:1.000 pred: 18.0, answer: 18.0 time: 19.51s
   1.  verdict:1 Averages:  acc:1.000 pred: 3.0, answer: 3.0 time: 12.56s
   2.  verdict:0 Averages:  acc:0.667 pred: 25000.0, answer: 70000.0 time: 29.23s
   3.  verdict:1 Averages:  acc:0.750 pred: 540.0, answer: 540.0 time: 13.07s
   4.  verdict:1 Averages:  acc:0.800 pred: 20.0, answer: 20.0 time: 26.01s
   5.  verdict:1 Averages:  acc:0.833 pred: 64.0, answer: 64.0 time: 27.14s
   6.  verdict:1 Averages:  acc:0.857 pred: 260.0, answer: 260.0 time: 13.38s
   7.  verdict:0 Averages:  acc:0.750 pred: 120.0, answer: 160.0 time: 22.76s
   8.  verdict:0 Averages:  acc:0.667 pred: 315.0, answer: 45.0 time: 37.94s
   9.  verdict:1 Averages:  acc:0.700 pred: 460.0, answer: 460.0 time: 17.74s
  10.  verdict:1 Averages:  acc:0.727 pred: 366.0, answer: 366.0 time: 21.53s
  11.  verdict:1 Averages:  acc:0.750 pred: 694.0, answer: 694.0 time: 13.24s
  12.  verdict:0 Averages:  acc:0.692 pred: 12.0, answer: 13.0 time: 84

In [81]:
print(f"Таргетная модель: time - {np.mean(times) / 1000} s, accuracy - {accuracy}")

Таргетная модель: time - 21.397363939285277 s, accuracy - 0.859375


In [82]:
correct = 0
n_samples = 256
times = []

with tqdm(total=n_samples) as pbar:
    for sample_idx in range(0, n_samples):
        question_sample = gsm_questions[sample_idx]
        answer = question_sample['answer']
        question = question_sample['question']
        formatted_zero_shot_prompt = prompt_with_8_shots + question + "\n" + formatting_prompt
        batch_input_ids = tokenizer.apply_chat_template(
            [{'role': 'user', 'content': formatted_zero_shot_prompt}],
            tokenize=True, return_tensors='pt', padding=True, continue_final_message=False  # <--- MIGHT BE A BUG
        ).to(device)
        generation, elapsed_time = generate(batch_input_ids, draft_model, args)
        times.append(elapsed_time)
        generation_str = tokenizer.batch_decode(generation)[0]
        #print(generation_str)
        raw_pred = extract_answer(generation_str)

        pred_float = extract_float(raw_pred)
        answer_float = extract_float(answer)

        gen_tokens = generation.shape[-1] - batch_input_ids.shape[-1]
        verdict = int(answer_float == pred_float)

        correct += verdict
        accuracy = correct / (sample_idx + 1)

        print(
            f"{sample_idx:>4}.  verdict:{verdict:<1} "
            f"Averages:  acc:{accuracy:1.3f} "
            f"pred: {pred_float}, answer: {answer_float} "
            f"time: {np.round(elapsed_time / 1000, 2)}s"
        )
        pbar.update(1)

  0%|          | 0/256 [00:00<?, ?it/s]

   0.  verdict:0 Averages:  acc:0.000 pred: 26.0, answer: 18.0 time: 9.1s
   1.  verdict:1 Averages:  acc:0.500 pred: 3.0, answer: 3.0 time: 4.82s
   2.  verdict:0 Averages:  acc:0.333 pred: 50000.0, answer: 70000.0 time: 4.02s
   3.  verdict:0 Averages:  acc:0.250 pred: 6060.0, answer: 540.0 time: 7.06s
   4.  verdict:0 Averages:  acc:0.200 pred: 120.0, answer: 20.0 time: 5.62s
   5.  verdict:0 Averages:  acc:0.167 pred: 48.0, answer: 64.0 time: 5.19s
   6.  verdict:0 Averages:  acc:0.143 pred: 20.081405, answer: 260.0 time: 14.43s
   7.  verdict:0 Averages:  acc:0.125 pred: None, answer: 160.0 time: 9.88s
   8.  verdict:0 Averages:  acc:0.111 pred: 3403404.0, answer: 45.0 time: 9.61s
   9.  verdict:0 Averages:  acc:0.100 pred: 15.0, answer: 460.0 time: 0.71s
  10.  verdict:1 Averages:  acc:0.182 pred: 366.0, answer: 366.0 time: 10.01s
  11.  verdict:0 Averages:  acc:0.167 pred: 168.0, answer: 694.0 time: 0.71s
  12.  verdict:0 Averages:  acc:0.154 pred: None, answer: 13.0 time: 9.15s

In [83]:
print(f"Драфтовая модель: {np.mean(times) / 1000} s, accuracy - {accuracy}")

Драфтовая модель: 7.950274225234986 s, accuracy - 0.25


__________________________________________________

### Тестирование спекулятивной генерации с lossy Hugging Face

In [5]:
from typing import Union, Sequence, Tuple
import json
import os

from tqdm.auto import tqdm
from termcolor import colored
import numpy as np
import pandas as pd

import itertools
import torch
from torch import nn
import torch.nn.functional as F
import transformers
import re
from transformers import BitsAndBytesConfig, AutoModelForCausalLM

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def extract_answer(s, suffix='<|eot_id|>'):
    s = s.lower().replace(suffix, '').replace('the final answer is', '=')
    idx = s.rfind("=")
    if idx != - 1:
        return s[idx + 1:].strip()
    

def extract_float(num_str):
    try:
        num_str = re.sub(r'[^0-9.-]', '', num_str).strip(".")
        return float(num_str)
    except (ValueError, TypeError):
        return

class args:
    draft_model = 'meta-llama/Llama-3.2-1B-Instruct'
    target_model = 'meta-llama/Llama-3.1-8B-Instruct'
    torch_dtype = 'auto'
    # data from https://github.com/openai/grade-school-math/tree/master/grade_school_math/data
    # use train.jsonl and test.jsonl
    gsm8k_test_path = '/kaggle/input/gsm888/test.jsonl'
    gsm8k_train_path = '/kaggle/input/gsm888/train.jsonl'
    random_seed = 42
    max_new_tokens = 1024
    n_samples = None


prompt_with_8_shots = "Given the following problem, reason and give a final answer to the problem.\nProblem: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?\nYour response should end with \"The final answer is [answer]\" where [answer] is the response to the problem.\n There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The final answer is 6\n\nGiven the following problem, reason and give a final answer to the problem.\nProblem: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?\nYour response should end with \"The final answer is [answer]\" where [answer] is the response to the problem.\n There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The final answer is 5\n\nGiven the following problem, reason and give a final answer to the problem.\nProblem: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?\nYour response should end with \"The final answer is [answer]\" where [answer] is the response to the problem.\n Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39. The final answer is 39\n\nGiven the following problem, reason and give a final answer to the problem.\nProblem: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?\nYour response should end with \"The final answer is [answer]\" where [answer] is the response to the problem.\n Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8. The final answer is 8\n\nGiven the following problem, reason and give a final answer to the problem.\nProblem: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?\nYour response should end with \"The final answer is [answer]\" where [answer] is the response to the problem.\n Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 4 more toys. 5 + 4 = 9. The final answer is 9\n\nGiven the following problem, reason and give a final answer to the problem.\nProblem: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?\nYour response should end with \"The final answer is [answer]\" where [answer] is the response to the problem.\n There were originally 9 computers. For each of 4 days, 5 more computers were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29. The final answer is 29\n\nGiven the following problem, reason and give a final answer to the problem.\nProblem: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?\nYour response should end with \"The final answer is [answer]\" where [answer] is the response to the problem.\n Michael started with 58 golf balls. After losing 23 on tuesday, he had 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls. The final answer is 33\n\nGiven the following problem, reason and give a final answer to the problem.\nProblem: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?\nYour response should end with \"The final answer is [answer]\" where [answer] is the response to the problem.\n Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8. The final answer is 8\n\nGiven the following problem, reason and give a final answer to the problem.\nProblem: "
prompt_with_0_shots = "Given the following problem, reason and give a final answer to the problem.\n"
formatting_prompt = "Your response should end with \"The final answer is [answer]\" where [answer] is the response to the problem."

np.random.seed(args.random_seed)

def load_questions(args):
    with open(args.gsm8k_test_path) as f:
        gsm_questions = [json.loads(line) for line in f]

    n_samples = args.n_samples or len(gsm_questions)
    gsm_questions = [
        {
            'question': i['question'],
            'answer': i['answer'][i['answer'].rfind('#### ') + 5:]
        }
        for i in gsm_questions[:n_samples]
    ]

    return gsm_questions

gsm_questions = load_questions(args)
#n_samples = len(gsm_questions)
n_samples = 256
print(f'Num samples: {n_samples}')

Num samples: 256


In [6]:
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
target_model = AutoModelForCausalLM.from_pretrained(args.target_model,
                                                    quantization_config=quantization_config,
                                                    device_map=device)
draft_model = AutoModelForCausalLM.from_pretrained(args.draft_model,
                                                   quantization_config=quantization_config,
                                                   device_map=device)

config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/184 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/877 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

In [94]:
draft_model.generation_config.num_assistant_tokens = 20
draft_model.generation_config.assistant_confidence_threshold = 0.01

In [7]:
tokenizer = transformers.AutoTokenizer.from_pretrained(
    args.target_model, 
)
tokenizer.pad_token_id = tokenizer.eos_token_id

tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

In [8]:
import copy

from transformers.generation.candidate_generator import (
    CandidateGenerator,
    _crop_past_key_values,
    _prepare_attention_mask,
    _prepare_token_type_ids,
)

from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList
from transformers.generation.configuration_utils import GenerationConfig
from typing import Optional, Union

from transformers.generation.utils import GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput

GenerateNonBeamOutput = Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput]


def _assisted_decoding_2(
    input_ids: torch.LongTensor,
    candidate_generator: CandidateGenerator,
    logits_processor: LogitsProcessorList,
    stopping_criteria: StoppingCriteriaList,
    generation_config: GenerationConfig,
    synced_gpus: bool,
    streamer: Optional["BaseStreamer"],
    **model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
    r"""
    Generates sequences of token ids for models with a language modeling head using **greedy decoding** or
    **sample** (depending on `do_sample`), assisted by candidate sequences. Assisted generation is an example of a
    candidate decoding strategy. Can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text
    models.

    Parameters:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The sequence used as a prompt for the generation.
        candidate_generator (`CandidateGenerator`):
            A derived instance of [`CandidateGenerator`] that defines how candidate sequences are generated. For
            more information, the documentation of [`CandidateGenerator`] should be read.
        logits_processor (`LogitsProcessorList`):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
        stopping_criteria (`StoppingCriteriaList`):
            An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
            used to tell if the generation loop should stop.
        generation_config ([`~generation.GenerationConfig`]):
            The generation configuration to be used as parametrization of the decoding method.
        synced_gpus (`bool`):
            Whether to continue running the while loop until max_length (needed to avoid deadlocking with
            `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
        streamer (`BaseStreamer`, *optional*):
            Streamer object that will be used to stream the generated sequences. Generated tokens are passed
            through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
        model_kwargs:
            Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
            If model is an encoder-decoder model the kwargs should include `encoder_outputs`.

    Return:
        [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or
        `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
        [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
        `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
        `model.config.is_encoder_decoder=True`.
    """

    global target_model
    # init values
    do_sample = generation_config.do_sample
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    output_logits = generation_config.output_logits
    return_dict_in_generate = generation_config.return_dict_in_generate

    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    raw_logits = () if (return_dict_in_generate and output_logits) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and target_model.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # keep track of which sequences are already finished
    batch_size = input_ids.shape[0]
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
    model_kwargs = target_model._get_initial_cache_position(input_ids, model_kwargs)

    this_peer_finished = False
    is_first_iteration = True  # to preserve the same API in the output as other generation methods

    added_lengths = []
    candidate_length_arr = []
    
    while target_model._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
        cur_len = input_ids.shape[-1]

        #  1. Fetch candidate sequences from a `CandidateGenerator` and move to the correct device
        candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids)
        candidate_input_ids = candidate_input_ids.to(target_model.device)
        if candidate_logits is not None:
            candidate_logits = candidate_logits.to(target_model.device)

        candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
        candidate_length_arr.append(candidate_length)
        #print('candidate_length:', candidate_length)
        is_done_candidate = stopping_criteria(candidate_input_ids, None)

        # 2. Use the original model to obtain the next token logits given the candidate sequence. We obtain
        # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct,
        # we use this forward pass to also pick the subsequent logits in the original model.

        # 2.1. Prepare the model inputs
        candidate_kwargs = copy.copy(model_kwargs)
        candidate_kwargs = _prepare_attention_mask(
            candidate_kwargs, candidate_input_ids.shape[1], target_model.config.is_encoder_decoder
        )
        candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
        if "cache_position" in candidate_kwargs:
            candidate_kwargs["cache_position"] = torch.cat(
                (
                    candidate_kwargs["cache_position"],
                    torch.arange(cur_len, cur_len + candidate_length, device=input_ids.device, dtype=torch.long),
                ),
                dim=0,
            )

        model_inputs = target_model.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
        
        if "num_logits_to_keep" in model_inputs:
            model_inputs["num_logits_to_keep"] = candidate_length + 1

        # 2.2. Run a forward pass on the candidate sequence
        # prepare variable output controls (note: some models won't accept all output controls)
        model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
        model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

        outputs = target_model(**model_inputs)

        # 2.3. Process the new logits
        # .float() is needed to retain precision for later logits manipulations
        new_logits = outputs.logits[:, -candidate_length - 1 :].float()  # excludes the input prompt if present
        new_logits = new_logits.to(input_ids.device)
        next_token_logits = new_logits.clone()
        
        if len(logits_processor) > 0:
            for i in range(candidate_length + 1):
                new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])

        # 3. Select the accepted tokens. There are two possible cases:
        # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding)
        # 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf).
        if do_sample and candidate_logits is not None:
            valid_tokens, n_matches = _speculative_sampling_2(
                candidate_input_ids,
                candidate_logits,
                candidate_length,
                new_logits,
                is_done_candidate,
            )

        # Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the
        # original model logits with the candidate tokens. We can keep the candidate tokens until the first
        # mismatch, or until the max length is reached.
        else:
            if do_sample:
                probs = new_logits.softmax(dim=-1)
                selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
            else:
                selected_tokens = new_logits.argmax(dim=-1)

            candidate_new_tokens = candidate_input_ids[:, cur_len:]
            n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()

            # Ensure we don't generate beyond max_len or an EOS token
            if is_done_candidate and n_matches == candidate_length:
                n_matches -= 1
            valid_tokens = selected_tokens[:, : n_matches + 1]

        added_lengths.append(n_matches.item() + 1)
        #print('added_length:', added_lengths[-1])

        # 4. Update variables according to the number of matching assistant tokens. Remember: the token generated
        # by the model after the last candidate match is also valid, as it is generated from a correct sequence.
        # Because of this last token, assisted generation search reduces to a normal greedy search/sample if there
        # is no match.

        # 4.1. Get the valid continuation, after the matching tokens
        input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
        if streamer is not None:
            streamer.put(valid_tokens.cpu())
        new_cur_len = input_ids.shape[-1]

        # 4.2. Discard past key values relative to unused assistant tokens
        new_cache_size = new_cur_len - 1
        outputs.past_key_values = _crop_past_key_values(target_model, outputs.past_key_values, new_cache_size)

        # 5. Update the candidate generation strategy if needed
        candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches)

        # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
        model_kwargs = target_model._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
            is_encoder_decoder=target_model.config.is_encoder_decoder,
            num_new_tokens=n_matches + 1,
        )
        if synced_gpus and this_peer_finished:
            continue

        # Store scores, attentions and hidden_states when required
        # Assistant: modified to append one tuple element per token, as in the other generation methods.
        if return_dict_in_generate:
            newly_added_length = n_matches + 1
            if output_scores:
                scores += tuple(new_logits[:, i, :] for i in range(newly_added_length))
            if output_logits:
                raw_logits += tuple(next_token_logits[:, i, :] for i in range(newly_added_length))

            newly_added_length = new_cur_len if is_first_iteration else newly_added_length
            if output_attentions:
                if target_model.config.is_encoder_decoder:
                    cross_attentions = _split_model_outputs(
                        cross_attentions, outputs.cross_attentions, cur_len, newly_added_length
                    )
                    decoder_attentions = _split_model_outputs(
                        decoder_attentions,
                        outputs.decoder_attentions,
                        cur_len,
                        newly_added_length,
                        is_decoder_attention=True,
                    )
                # some (V)LLMs have hard requirement on SDPA and thus never return attn
                elif outputs.attentions[0] is not None:
                    decoder_attentions = _split_model_outputs(
                        decoder_attentions,
                        outputs.attentions,
                        cur_len,
                        newly_added_length,
                        is_decoder_attention=True,
                    )
            if output_hidden_states:
                if target_model.config.is_encoder_decoder:
                    decoder_hidden_states = _split_model_outputs(
                        decoder_hidden_states, outputs.decoder_hidden_states, cur_len, newly_added_length
                    )
                else:
                    decoder_hidden_states = _split_model_outputs(
                        decoder_hidden_states, outputs.hidden_states, cur_len, newly_added_length
                    )

        unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
        this_peer_finished = unfinished_sequences.max() == 0
        is_first_iteration = False

    print('candidate_lengths:', candidate_length_arr)
    print('added_lengths:', added_lengths)
    
    if streamer is not None:
        streamer.end()

    if (
        hasattr(candidate_generator, "assistant_model")
        and candidate_generator.assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic"
    ):
        candidate_generator.assistant_model.generation_config.num_assistant_tokens = (
            candidate_generator.num_assistant_tokens
        )
    if return_dict_in_generate:
        if target_model.config.is_encoder_decoder:
            return GenerateEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        else:
            return GenerateDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
    else:
        return input_ids, added_lengths


def _speculative_sampling_2(
    candidate_input_ids,
    candidate_logits,
    candidate_length,
    new_logits,
    is_done_candidate,
):
    """
    Applies sampling as in the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf, algorithm 1). Returns
    the selected tokens, as well as the number of candidate matches.

    NOTE: Unless otherwise stated, the variable names match those in the paper.
    """
    new_candidate_input_ids = candidate_input_ids[:, -candidate_length:]
    # Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of the tokens
    # selected by the assistant, respectively.
    q = candidate_logits.softmax(dim=-1)
    q_i = q[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
    p = new_logits.softmax(dim=-1)
    p_i = p[:, torch.arange(candidate_length), new_candidate_input_ids].squeeze(0, 1)
    #probability_ratio = p_i / q_i

    # lossy speculation
    lossy_rate = float(os.environ.get("LOSSY_RATE"))
    if lossy_rate <= 0 or lossy_rate > 1:
        raise ValueError(f"Unsupported lossy rate: {lossy_rate}. Must be in (0; 1].")
    #print('probability_ratio:', p_i / (q_i * lossy_rate))
    probability_ratio = p_i / (q_i * lossy_rate)

    # When probability_ratio > 1 (i.e. q_i(x) < p_i(x), or "assistant probability of the candidate token is smaller
    # than the model probability for the same token"), keep the token. Otherwise reject with p = 1 - probability_ratio
    # (= keep with p = probability_ratio). Keep all the tokens until the first rejection
    r_i = torch.rand_like(probability_ratio)
    is_accepted = r_i <= probability_ratio
    n_matches = ((~is_accepted).cumsum(dim=-1) < 1).sum()  # this is `n` in algorithm 1

    # Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct behavior)
    if is_done_candidate and n_matches == candidate_length:
        # Output length is assumed to be `n_matches + 1`. Since we won't generate another token with the target model
        # due to acceptance on EOS we fix `n_matches`
        n_matches -= 1
        valid_tokens = new_candidate_input_ids[:, : n_matches + 1]
    else:
        # Next token selection: if there is a rejection, adjust the distribution from the main model before sampling.
        gamma = candidate_logits.shape[1]
        p_n_plus_1 = p[:, n_matches, :]
        if n_matches < gamma:
            q_n_plus_1 = q[:, n_matches, :]
            p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0)
            p_prime.div_(p_prime.sum())
        else:
            p_prime = p_n_plus_1
        t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]

        # The selected tokens include the matches (if any) plus the next sampled tokens
        if n_matches > 0:
            valid_tokens = torch.cat((new_candidate_input_ids[:, :n_matches], t), dim=-1)
        else:
            valid_tokens = t

    return valid_tokens, n_matches

In [9]:
target_model._assisted_decoding = _assisted_decoding_2

In [10]:
def generate(input_ids, model, args):
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    torch.cuda.synchronize()
    start_event.record()

    with torch.inference_mode():
        output_ids, accepted_lengths = model.generate(
            input_ids,
            max_new_tokens=1024,
            do_sample=True,
            assistant_model=draft_model,
            pad_token_id=tokenizer.pad_token_id,
        )
        #print('accepted_lengths:', accepted_lengths)

    end_event.record()
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)

    return output_ids, elapsed_time_ms, np.mean(accepted_lengths)

In [11]:
def lossy_experiment(n_samples=5):
    correct = 0
    times = []
    nums_accepted = []
    
    lossy_rate = float(os.environ.get("LOSSY_RATE"))
    print(f"LOSSY RATE: {lossy_rate}")
    
    with tqdm(total=n_samples) as pbar:
        for sample_idx in range(0, n_samples):
            question_sample = gsm_questions[sample_idx]
            answer = question_sample['answer']
            question = question_sample['question']
            formatted_zero_shot_prompt = prompt_with_8_shots + question + "\n" + formatting_prompt
            batch_input_ids = tokenizer.apply_chat_template(
                [{'role': 'user', 'content': formatted_zero_shot_prompt}],
                tokenize=True, return_tensors='pt', padding=True, continue_final_message=False  # <--- MIGHT BE A BUG
            ).to(device)
            generation, elapsed_time, num_tokens_accepted = generate(batch_input_ids, target_model, args)
            
            nums_accepted.append(num_tokens_accepted)
            times.append(elapsed_time)
            
            generation_str = tokenizer.batch_decode(generation)[0]
            # print(generation_str)
            raw_pred = extract_answer(generation_str)
    
            pred_float = extract_float(raw_pred)
            answer_float = extract_float(answer)
    
            gen_tokens = generation.shape[-1] - batch_input_ids.shape[-1]
            verdict = int(answer_float == pred_float)
    
            correct += verdict
            accuracy = correct / (sample_idx + 1)
    
            print(
                f"{sample_idx:>4}.  verdict:{verdict:<1} "
                f"Averages:  acc:{accuracy:1.3f} "
                f"pred: {pred_float}, answer: {answer_float} "
                f"mean_num_accepted: {np.round(num_tokens_accepted, 2)} "
                f"time: {np.round(elapsed_time / 1000, 2)}s"
            )
            pbar.update(1)

    return times, accuracy, nums_accepted

#### Treshold = 0.01

In [99]:
os.environ["LOSSY_RATE"] = "1"

times, accuracy, nums_accepted = lossy_experiment(n_samples=256)

LOSSY RATE: 1.0


  0%|          | 0/256 [00:00<?, ?it/s]

candidate_lengths: [20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 3]
added_lengths: [21, 2, 12, 1, 2, 9, 5, 5, 11, 6, 1, 17, 3, 1, 6, 9, 2, 21, 3]
   0.  verdict:1 Averages:  acc:1.000 pred: 18.0, answer: 18.0 mean_num_accepted: 7.21 time: 21.09s
candidate_lengths: [14, 6, 3, 4, 2, 1, 5, 13, 15, 10, 20, 10, 1, 2, 1, 6, 9, 5, 20, 20]
added_lengths: [6, 7, 4, 3, 2, 1, 6, 13, 16, 9, 6, 8, 2, 2, 2, 5, 7, 5, 3, 19]
   1.  verdict:1 Averages:  acc:1.000 pred: 3.0, answer: 3.0 mean_num_accepted: 6.3 time: 11.81s
candidate_lengths: [6, 1, 3, 2, 2, 4, 1, 2, 6, 4, 1, 1, 8, 3, 6, 1, 2, 2, 4, 1, 4, 2, 2, 3, 3, 2, 1, 1, 2, 1, 1, 1, 4, 6, 9, 1, 1, 2, 1, 13, 1, 4, 4, 1, 8, 7, 5, 1, 2, 1, 11, 1, 2, 20, 5]
added_lengths: [7, 2, 4, 2, 3, 5, 1, 3, 7, 5, 1, 2, 9, 3, 7, 2, 2, 3, 4, 2, 5, 3, 3, 4, 3, 3, 1, 1, 2, 2, 1, 2, 4, 7, 9, 2, 2, 2, 2, 14, 2, 4, 5, 2, 4, 8, 5, 2, 3, 2, 12, 1, 3, 21, 5]
   2.  verdict:0 Averages:  acc:0.667 pred: 90000.0, answer: 70000.0 mean_num_accepted: 4.0

In [100]:
lossy_rate = float(os.environ.get("LOSSY_RATE"))
print(f"Lossy спекуляция: LOSSY RATE - {lossy_rate}, time - {np.mean(times) / 1000} s, acc - {accuracy}, mean_accepted_tokens - {np.mean(nums_accepted)}")

Lossy спекуляция: LOSSY RATE - 1.0, time - 15.083965808868408 s, acc - 0.85546875, mean_accepted_tokens - 5.908096242274175


In [101]:
os.environ["LOSSY_RATE"] = "0.5"

times, accuracy, nums_accepted = lossy_experiment(n_samples=256)

LOSSY RATE: 0.5


  0%|          | 0/256 [00:00<?, ?it/s]

candidate_lengths: [6, 16, 6, 1, 1, 1, 1, 2, 12, 1, 3, 4, 3, 2, 4, 1, 6, 11, 8, 2, 9, 2, 1, 8, 9, 6, 6, 9, 6]
added_lengths: [7, 17, 6, 2, 1, 2, 2, 3, 13, 1, 3, 2, 3, 3, 1, 1, 6, 12, 9, 1, 6, 2, 2, 9, 10, 7, 7, 6, 6]
   0.  verdict:1 Averages:  acc:1.000 pred: 18.0, answer: 18.0 mean_num_accepted: 5.17 time: 12.3s
candidate_lengths: [14, 4, 1, 4, 2, 13, 8, 1, 3, 16, 8, 1, 20, 20]
added_lengths: [12, 4, 1, 3, 3, 10, 8, 1, 1, 16, 8, 1, 10, 19]
   1.  verdict:1 Averages:  acc:1.000 pred: 3.0, answer: 3.0 mean_num_accepted: 6.93 time: 8.46s
candidate_lengths: [9, 4, 3, 4, 3, 2, 3, 5, 4, 3, 1, 9, 3, 1, 20, 10, 2, 12, 20, 13, 2, 18, 1, 3, 1, 17, 6, 20, 13]
added_lengths: [9, 5, 4, 5, 4, 3, 4, 6, 4, 1, 1, 10, 1, 2, 17, 10, 3, 12, 21, 14, 3, 19, 1, 1, 1, 7, 7, 21, 13]
   2.  verdict:1 Averages:  acc:1.000 pred: 70000.0, answer: 70000.0 mean_num_accepted: 7.21 time: 15.41s
candidate_lengths: [5, 1, 1, 12, 9, 3, 2, 3, 1, 6, 1, 3, 3, 3, 20, 15, 4, 14, 1, 16]
added_lengths: [5, 1, 2, 1, 9, 4, 3, 4

In [103]:
lossy_rate = float(os.environ.get("LOSSY_RATE"))
print(f"Lossy спекуляция: LOSSY RATE - {lossy_rate}, time - {np.mean(times) / 1000} s, acc - {accuracy}, mean_accepted_tokens - {np.mean(nums_accepted)}")

Lossy спекуляция: LOSSY RATE - 0.5, time - 14.290847503662109 s, acc - 0.83984375, mean_accepted_tokens - 7.188244466096613


In [104]:
os.environ["LOSSY_RATE"] = "0.3"

times, accuracy, nums_accepted = lossy_experiment(n_samples=256)

LOSSY RATE: 0.3


  0%|          | 0/256 [00:00<?, ?it/s]

candidate_lengths: [8, 20, 14, 8, 20, 20, 4, 1, 4, 19, 20, 1, 15, 8, 11, 18, 3, 20, 12]
added_lengths: [7, 21, 14, 8, 6, 3, 2, 1, 4, 2, 13, 1, 15, 9, 9, 1, 4, 21, 12]
   0.  verdict:1 Averages:  acc:1.000 pred: 18.0, answer: 18.0 mean_num_accepted: 8.05 time: 14.84s
candidate_lengths: [6, 6, 3, 10, 2, 6, 7, 14, 1, 17, 9, 3, 1, 20, 2]
added_lengths: [6, 7, 3, 4, 3, 7, 8, 9, 2, 18, 9, 4, 2, 21, 1]
   1.  verdict:1 Averages:  acc:1.000 pred: 3.0, answer: 3.0 mean_num_accepted: 6.93 time: 8.67s
candidate_lengths: [8, 1, 4, 2, 1, 4, 9, 4, 3, 4, 7, 1, 4, 2, 7, 14, 8, 2, 1, 20, 20, 18, 17, 15, 4, 1, 2, 4, 1, 8, 1]
added_lengths: [8, 1, 5, 3, 2, 5, 10, 5, 4, 5, 8, 2, 5, 3, 8, 15, 9, 2, 2, 11, 4, 19, 17, 16, 1, 1, 2, 4, 2, 8, 2]
   2.  verdict:0 Averages:  acc:0.667 pred: -10000.0, answer: 70000.0 mean_num_accepted: 6.1 time: 14.79s
candidate_lengths: [5, 9, 6, 5, 11, 2, 10, 6, 11, 11, 20]
added_lengths: [5, 10, 7, 5, 12, 3, 11, 7, 12, 12, 20]
   3.  verdict:1 Averages:  acc:0.750 pred: 540.0, 

In [105]:
lossy_rate = float(os.environ.get("LOSSY_RATE"))
print(f"Lossy спекуляция: LOSSY RATE - {lossy_rate}, time - {np.mean(times) / 1000} s, acc - {accuracy}, mean_accepted_tokens - {np.mean(nums_accepted)}")

Lossy спекуляция: LOSSY RATE - 0.3, time - 13.79042766571045 s, acc - 0.84375, mean_accepted_tokens - 7.278357959899259


In [106]:
os.environ["LOSSY_RATE"] = "0.1"

times, accuracy, nums_accepted = lossy_experiment(n_samples=256)

LOSSY RATE: 0.1


  0%|          | 0/256 [00:00<?, ?it/s]

candidate_lengths: [6, 6, 6, 2, 1, 2, 2, 3, 2, 2, 16, 1, 8, 2, 3, 1, 16, 12, 4, 2, 9, 1, 5, 1, 5, 2, 1, 20, 7, 17, 5, 2, 3, 20]
added_lengths: [7, 7, 7, 3, 2, 3, 3, 2, 1, 2, 17, 1, 5, 2, 3, 1, 4, 13, 4, 3, 10, 2, 6, 1, 6, 3, 2, 21, 8, 17, 2, 3, 4, 20]
   0.  verdict:1 Averages:  acc:1.000 pred: 18.0, answer: 18.0 mean_num_accepted: 5.74 time: 15.08s
candidate_lengths: [6, 7, 2, 2, 1, 1, 2, 12, 1, 6, 2, 5, 1, 8, 1, 19, 4, 18, 19]
added_lengths: [7, 8, 2, 3, 1, 2, 3, 13, 1, 7, 2, 6, 1, 8, 2, 20, 4, 19, 18]
   1.  verdict:1 Averages:  acc:1.000 pred: 3.0, answer: 3.0 mean_num_accepted: 6.68 time: 9.18s
candidate_lengths: [9, 3, 2, 5, 9, 13, 1, 10, 11, 11, 4, 6, 2, 3, 20, 4, 6, 20, 6, 20, 4, 20, 4]
added_lengths: [10, 4, 2, 6, 10, 4, 1, 11, 3, 12, 5, 7, 2, 4, 21, 5, 7, 18, 6, 21, 5, 21, 4]
   2.  verdict:0 Averages:  acc:0.667 pred: 120000.0, answer: 70000.0 mean_num_accepted: 8.22 time: 13.65s
candidate_lengths: [5, 1, 2, 8, 2, 4, 2, 1, 4, 1, 2, 5, 10, 2, 5, 7, 5, 2, 9, 20, 2]
added_lengt

In [107]:
lossy_rate = float(os.environ.get("LOSSY_RATE"))
print(f"Lossy спекуляция: LOSSY RATE - {lossy_rate}, time - {np.mean(times) / 1000} s, acc - {accuracy}, mean_accepted_tokens - {np.mean(nums_accepted)}")

Lossy спекуляция: LOSSY RATE - 0.1, time - 13.428416919708251 s, acc - 0.8046875, mean_accepted_tokens - 7.927992733562955


In [108]:
os.environ["LOSSY_RATE"] = "0.01"

times, accuracy, nums_accepted = lossy_experiment(n_samples=256)

LOSSY RATE: 0.01


  0%|          | 0/256 [00:00<?, ?it/s]

candidate_lengths: [6, 18, 7, 6, 10, 20, 3, 3, 4, 2, 12, 12, 11, 10, 2, 7, 6, 7, 10, 2, 3, 11]
added_lengths: [7, 19, 8, 6, 10, 8, 4, 1, 5, 2, 13, 13, 11, 11, 3, 2, 7, 7, 11, 3, 4, 11]
   0.  verdict:1 Averages:  acc:1.000 pred: 18.0, answer: 18.0 mean_num_accepted: 7.55 time: 12.61s
candidate_lengths: [7, 5, 6, 3, 3, 2, 20, 12, 1, 20, 1, 5, 1, 20, 1, 2, 20]
added_lengths: [8, 5, 2, 2, 3, 3, 1, 7, 2, 21, 2, 6, 2, 6, 2, 3, 21]
   1.  verdict:1 Averages:  acc:1.000 pred: 3.0, answer: 3.0 mean_num_accepted: 5.65 time: 9.5s
candidate_lengths: [9, 3, 1, 2, 1, 4, 8, 4, 9, 1, 4, 6, 8, 6, 3, 2, 14, 1, 20, 11, 7, 20, 5]
added_lengths: [9, 3, 2, 3, 2, 5, 9, 5, 9, 1, 5, 6, 9, 7, 4, 2, 15, 2, 21, 12, 8, 21, 5]
   2.  verdict:0 Averages:  acc:0.667 pred: 75000.0, answer: 70000.0 mean_num_accepted: 7.17 time: 11.67s
candidate_lengths: [5, 8, 6, 7, 14, 2, 5, 4, 8, 20, 16]
added_lengths: [6, 9, 7, 8, 15, 2, 5, 4, 9, 12, 16]
   3.  verdict:0 Averages:  acc:0.500 pred: 180.0, answer: 540.0 mean_num_acce

In [109]:
lossy_rate = float(os.environ.get("LOSSY_RATE"))
print(f"Lossy спекуляция: LOSSY RATE - {lossy_rate}, time - {np.mean(times) / 1000} s, acc - {accuracy}, mean_accepted_tokens - {np.mean(nums_accepted)}")

Lossy спекуляция: LOSSY RATE - 0.01, time - 13.069739028930664 s, acc - 0.85546875, mean_accepted_tokens - 7.803476787045731


In [110]:
# os.environ["LOSSY_RATE"] = "0.001"

# times, accuracy, nums_accepted = lossy_experiment(n_samples=256)

In [111]:
# lossy_rate = float(os.environ.get("LOSSY_RATE"))
# print(f"Lossy спекуляция: LOSSY RATE - {lossy_rate}, time - {np.mean(times) / 1000} s, acc - {accuracy}, mean_accepted_tokens - {np.mean(nums_accepted)}")

#### Treshold = 0

In [12]:
draft_model.generation_config.assistant_confidence_threshold = 0
draft_model.generation_config.num_assistant_tokens = 20

In [35]:
os.environ["LOSSY_RATE"] = "1"

times, accuracy, nums_accepted = lossy_experiment(n_samples=256)

LOSSY RATE: 1.0


  0%|          | 0/256 [00:00<?, ?it/s]

candidate_lengths: [20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 16, 12]
added_lengths: [10, 3, 10, 2, 3, 2, 2, 2, 4, 4, 11, 2, 5, 1, 3, 2, 12, 1, 1, 6, 5, 8, 10, 9, 3, 21, 1, 9, 12]
   0.  verdict:1 Averages:  acc:1.000 pred: 18.0, answer: 18.0 mean_num_accepted: 5.66 time: 31.98s
candidate_lengths: [20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 9, 7]
added_lengths: [14, 9, 2, 1, 1, 16, 16, 15, 3, 6, 16, 1, 6]
   1.  verdict:1 Averages:  acc:1.000 pred: 3.0, answer: 3.0 mean_num_accepted: 8.15 time: 13.91s
candidate_lengths: [20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 11]
added_lengths: [21, 4, 4, 4, 6, 12, 20, 1, 1, 21, 7, 21, 13, 7, 4, 20, 6, 7, 21, 11]
   2.  verdict:1 Averages:  acc:1.000 pred: 70000.0, answer: 70000.0 mean_num_accepted: 10.55 time: 22.3s
candidate_lengths: [12, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 19, 17]
added_lengths: [5, 1, 3, 3,

In [36]:
lossy_rate = float(os.environ.get("LOSSY_RATE"))
print(f"Lossy спекуляция: LOSSY RATE - {lossy_rate}, time - {np.mean(times) / 1000} s, acc - {accuracy}, mean_accepted_tokens - {np.mean(nums_accepted)}")

Lossy спекуляция: LOSSY RATE - 1.0, time - 23.946503761291503 s, acc - 0.86328125, mean_accepted_tokens - 8.686347416503905


In [37]:
os.environ["LOSSY_RATE"] = "0.5"

times, accuracy, nums_accepted = lossy_experiment(n_samples=256)

LOSSY RATE: 0.5


  0%|          | 0/256 [00:00<?, ?it/s]

candidate_lengths: [20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 19, 20, 20, 14]
added_lengths: [8, 13, 2, 3, 2, 9, 3, 1, 2, 2, 13, 2, 3, 2, 6, 14, 21, 13, 2, 1, 3, 1, 5, 8, 1, 19, 5, 14]
   0.  verdict:1 Averages:  acc:1.000 pred: 18.0, answer: 18.0 mean_num_accepted: 6.36 time: 31.45s
candidate_lengths: [20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20]
added_lengths: [6, 6, 8, 4, 13, 1, 1, 2, 1, 21, 1, 9, 19]
   1.  verdict:1 Averages:  acc:1.000 pred: 3.0, answer: 3.0 mean_num_accepted: 7.08 time: 15.01s
candidate_lengths: [20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 9]
added_lengths: [7, 6, 3, 4, 3, 2, 5, 5, 4, 10, 11, 3, 7, 10, 8, 8, 5, 4, 7, 1, 6, 16, 11, 11, 21, 21, 21, 21, 1, 3, 11, 20, 8]
   2.  verdict:0 Averages:  acc:0.667 pred: 0.0, answer: 70000.0 mean_num_accepted: 8.61 time: 36.65s
candidate_lengths: [12, 20, 20, 20, 20, 20, 20, 20, 2

In [39]:
lossy_rate = float(os.environ.get("LOSSY_RATE"))
print(f"Lossy спекуляция: LOSSY RATE - {lossy_rate}, time - {np.mean(times) / 1000} s, acc - {accuracy}, mean_accepted_tokens - {np.mean(nums_accepted)}")

Lossy спекуляция: LOSSY RATE - 0.5, time - 19.225626674652098 s, acc - 0.859375, mean_accepted_tokens - 10.661923355014363


In [40]:
os.environ["LOSSY_RATE"] = "0.3"

times, accuracy, nums_accepted = lossy_experiment(n_samples=256)

LOSSY RATE: 0.3


  0%|          | 0/256 [00:00<?, ?it/s]

candidate_lengths: [20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 19]
added_lengths: [21, 4, 17, 16, 5, 4, 3, 6, 1, 1, 19, 4, 21, 1, 4, 1, 21, 1, 3, 18]
   0.  verdict:1 Averages:  acc:1.000 pred: 18.0, answer: 18.0 mean_num_accepted: 8.55 time: 22.39s
candidate_lengths: [20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 7]
added_lengths: [12, 9, 13, 1, 8, 1, 4, 5, 9, 20, 21, 6]
   1.  verdict:1 Averages:  acc:1.000 pred: 3.0, answer: 3.0 mean_num_accepted: 9.08 time: 13.11s
candidate_lengths: [20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 9]
added_lengths: [15, 13, 10, 16, 21, 14, 7, 6, 21, 21, 21, 11, 21, 14, 21, 6, 21, 2, 7, 13, 5, 8, 1, 5, 1, 21, 21, 21, 9]
   2.  verdict:0 Averages:  acc:0.667 pred: 170000.0, answer: 70000.0 mean_num_accepted: 12.86 time: 32.16s
candidate_lengths: [20, 20, 20, 20, 20, 20, 20, 20, 13]
added_lengths: [6, 16, 8, 7, 2, 21, 2, 21, 13]
   3.  verdict:1 Averages:  a

In [41]:
lossy_rate = float(os.environ.get("LOSSY_RATE"))
print(f"Lossy спекуляция: LOSSY RATE - {lossy_rate}, time - {np.mean(times) / 1000} s, acc - {accuracy}, mean_accepted_tokens - {np.mean(nums_accepted)}")

Lossy спекуляция: LOSSY RATE - 0.3, time - 18.333567180633544 s, acc - 0.82421875, mean_accepted_tokens - 11.610591745964244


In [42]:
os.environ["LOSSY_RATE"] = "0.1"

times, accuracy, nums_accepted = lossy_experiment(n_samples=256)

LOSSY RATE: 0.1


  0%|          | 0/256 [00:00<?, ?it/s]

candidate_lengths: [20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 12]
added_lengths: [21, 4, 4, 4, 2, 1, 5, 5, 17, 5, 6, 10, 1, 1, 21, 21, 11]
   0.  verdict:1 Averages:  acc:1.000 pred: 18.0, answer: 18.0 mean_num_accepted: 8.18 time: 18.81s
candidate_lengths: [20, 20, 20, 20, 20, 20, 20, 20, 20, 9]
added_lengths: [7, 13, 14, 2, 6, 1, 1, 21, 21, 8]
   1.  verdict:1 Averages:  acc:1.000 pred: 3.0, answer: 3.0 mean_num_accepted: 9.4 time: 11.04s
candidate_lengths: [20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 16]
added_lengths: [9, 14, 4, 12, 14, 4, 15, 14, 7, 21, 21, 12, 21, 16]
   2.  verdict:0 Averages:  acc:0.667 pred: 43333.0, answer: 70000.0 mean_num_accepted: 13.14 time: 15.8s
candidate_lengths: [20, 20, 20, 20, 20, 20, 20, 20, 20, 13, 15]
added_lengths: [21, 2, 18, 8, 3, 21, 4, 21, 6, 1, 15]
   3.  verdict:1 Averages:  acc:0.750 pred: 540.0, answer: 540.0 mean_num_accepted: 10.91 time: 12.16s
candidate_lengths: [20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20

In [43]:
lossy_rate = float(os.environ.get("LOSSY_RATE"))
print(f"Lossy спекуляция: LOSSY RATE - {lossy_rate}, time - {np.mean(times) / 1000} s, acc - {accuracy}, mean_accepted_tokens - {np.mean(nums_accepted)}")

Lossy спекуляция: LOSSY RATE - 0.1, time - 16.999432041168212 s, acc - 0.82421875, mean_accepted_tokens - 11.812966250613588


In [44]:
os.environ["LOSSY_RATE"] = "0.01"

times, accuracy, nums_accepted = lossy_experiment(n_samples=256)

LOSSY RATE: 0.01


  0%|          | 0/256 [00:00<?, ?it/s]

candidate_lengths: [20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20]
added_lengths: [21, 2, 9, 6, 12, 2, 16, 5, 8, 21, 20, 10, 20]
   0.  verdict:1 Averages:  acc:1.000 pred: 18.0, answer: 18.0 mean_num_accepted: 11.69 time: 15.03s
candidate_lengths: [20, 20, 20, 20, 20, 20, 20, 20, 20, 8]
added_lengths: [12, 8, 3, 11, 3, 3, 15, 12, 13, 7]
   1.  verdict:1 Averages:  acc:1.000 pred: 3.0, answer: 3.0 mean_num_accepted: 8.7 time: 11.03s
candidate_lengths: [20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 3, 10]
added_lengths: [18, 9, 11, 10, 21, 21, 2, 6, 21, 12, 9, 15, 2, 21, 3, 10]
   2.  verdict:1 Averages:  acc:1.000 pred: 70000.0, answer: 70000.0 mean_num_accepted: 11.94 time: 16.97s
candidate_lengths: [20, 20, 20, 20, 20, 20, 20, 20, 20, 9]
added_lengths: [7, 21, 17, 5, 4, 6, 7, 16, 21, 9]
   3.  verdict:1 Averages:  acc:1.000 pred: 540.0, answer: 540.0 mean_num_accepted: 11.3 time: 11.08s
candidate_lengths: [20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20,

In [45]:
lossy_rate = float(os.environ.get("LOSSY_RATE"))
print(f"Lossy спекуляция: LOSSY RATE - {lossy_rate}, time - {np.mean(times) / 1000} s, acc - {accuracy}, mean_accepted_tokens - {np.mean(nums_accepted)}")

Lossy спекуляция: LOSSY RATE - 0.01, time - 17.34913021850586 s, acc - 0.828125, mean_accepted_tokens - 12.064152305545356


In [14]:
os.environ["LOSSY_RATE"] = "0.001"

times, accuracy, nums_accepted = lossy_experiment(n_samples=256)

LOSSY RATE: 0.001


  0%|          | 0/256 [00:00<?, ?it/s]

candidate_lengths: [20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20]
added_lengths: [8, 1, 12, 7, 7, 7, 8, 18, 5, 7, 13, 6, 1, 21, 21, 21]
   0.  verdict:1 Averages:  acc:1.000 pred: 18.0, answer: 18.0 mean_num_accepted: 10.19 time: 18.06s
candidate_lengths: [20, 20, 20, 20, 20, 20, 20, 20, 8, 20, 19, 16]
added_lengths: [19, 6, 10, 3, 7, 4, 11, 11, 1, 15, 1, 15]
   1.  verdict:1 Averages:  acc:1.000 pred: 3.0, answer: 3.0 mean_num_accepted: 8.58 time: 12.85s
candidate_lengths: [20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 19]
added_lengths: [9, 19, 2, 2, 3, 1, 11, 14, 10, 21, 21, 21, 13, 21, 10, 2, 21, 19]
   2.  verdict:1 Averages:  acc:1.000 pred: 70000.0, answer: 70000.0 mean_num_accepted: 12.22 time: 20.06s
candidate_lengths: [12, 20, 20, 20, 20, 20, 20, 20, 6]
added_lengths: [5, 21, 19, 9, 21, 5, 18, 21, 5]
   3.  verdict:1 Averages:  acc:1.000 pred: 540.0, answer: 540.0 mean_num_accepted: 13.78 time: 9.29s
candidate_lengths: [20, 20, 20, 20, 

In [15]:
lossy_rate = float(os.environ.get("LOSSY_RATE"))
print(f"Lossy спекуляция: LOSSY RATE - {lossy_rate}, time - {np.mean(times) / 1000} s, acc - {accuracy}, mean_accepted_tokens - {np.mean(nums_accepted)}")

Lossy спекуляция: LOSSY RATE - 0.001, time - 16.311200370788573 s, acc - 0.8359375, mean_accepted_tokens - 12.167193646415562


In [12]:
draft_model.generation_config.num_assistant_tokens = 12
draft_model.generation_config.assistant_confidence_threshold = 0

In [15]:
os.environ["LOSSY_RATE"] = "1"

times, accuracy, nums_accepted = lossy_experiment(n_samples=256)

LOSSY RATE: 1.0


  0%|          | 0/256 [00:00<?, ?it/s]

candidate_lengths: [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 7]
added_lengths: [7, 5, 9, 2, 6, 13, 3, 2, 5, 1, 5, 4, 6, 2, 2, 13, 6, 13, 2, 13, 1, 3, 13, 6]
   0.  verdict:1 Averages:  acc:1.000 pred: 18.0, answer: 18.0 mean_num_accepted: 5.92 time: 18.3s
candidate_lengths: [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 8]
added_lengths: [6, 12, 3, 2, 3, 3, 6, 1, 1, 2, 13, 10, 7, 12, 9, 7]
   1.  verdict:1 Averages:  acc:1.000 pred: 3.0, answer: 3.0 mean_num_accepted: 6.06 time: 12.31s
candidate_lengths: [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 2]
added_lengths: [9, 6, 11, 13, 1, 1, 9, 4, 8, 4, 5, 1, 2, 4, 13, 13, 12, 13, 8, 2, 8, 3, 1, 9, 2, 5, 1, 2, 1, 1, 1, 4, 7, 7, 13, 7, 3, 7, 7, 13, 13, 13, 1, 13, 13, 1]
   2.  verdict:1 Averages:  acc:1.000 pred: 70000.0, answer: 70000.0 mean_n

In [16]:
lossy_rate = float(os.environ.get("LOSSY_RATE"))
print(f"Lossy спекуляция: LOSSY RATE - {lossy_rate}, time - {np.mean(times) / 1000} s, acc - {accuracy}, mean_accepted_tokens - {np.mean(nums_accepted)}")

Lossy спекуляция: LOSSY RATE - 1.0, time - 19.15632596206665 s, acc - 0.85546875, mean_accepted_tokens - 7.334224421234497


In [19]:
os.environ["LOSSY_RATE"] = "0.5"

times, accuracy, nums_accepted = lossy_experiment(n_samples=256)

LOSSY RATE: 0.5


  0%|          | 0/256 [00:00<?, ?it/s]

candidate_lengths: [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 2]
added_lengths: [7, 2, 13, 3, 5, 1, 13, 2, 2, 5, 4, 3, 10, 4, 11, 3, 13, 6, 6, 13, 1]
   0.  verdict:1 Averages:  acc:1.000 pred: 18.0, answer: 18.0 mean_num_accepted: 6.05 time: 15.72s
candidate_lengths: [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 11]
added_lengths: [12, 9, 2, 1, 12, 3, 1, 10, 13, 7, 13, 13, 10]
   1.  verdict:1 Averages:  acc:1.000 pred: 3.0, answer: 3.0 mean_num_accepted: 8.15 time: 10.18s
candidate_lengths: [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 3]
added_lengths: [13, 1, 6, 2, 13, 1, 1, 13, 1, 7, 6, 4, 3, 3, 13, 13, 6, 9, 2, 5, 9, 13, 13, 5, 13, 13, 13, 2]
   2.  verdict:0 Averages:  acc:0.667 pred: 120000.0, answer: 70000.0 mean_num_accepted: 7.25 time: 21.18s
candidate_lengths: [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12]
added_lengths: [5, 1, 13, 9, 13, 5, 5, 2

In [20]:
lossy_rate = float(os.environ.get("LOSSY_RATE"))
print(f"Lossy спекуляция: LOSSY RATE - {lossy_rate}, time - {np.mean(times) / 1000} s, acc - {accuracy}, mean_accepted_tokens - {np.mean(nums_accepted)}")

Lossy спекуляция: LOSSY RATE - 0.5, time - 17.301128684997558 s, acc - 0.8359375, mean_accepted_tokens - 8.47801146298399


In [21]:
os.environ["LOSSY_RATE"] = "0.3"

times, accuracy, nums_accepted = lossy_experiment(n_samples=256)

LOSSY RATE: 0.3


  0%|          | 0/256 [00:00<?, ?it/s]

candidate_lengths: [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12]
added_lengths: [13, 5, 2, 2, 3, 3, 13, 10, 2, 5, 6, 8, 6, 13, 3, 13, 13, 9, 13, 13, 13]
   0.  verdict:1 Averages:  acc:1.000 pred: 18.0, answer: 18.0 mean_num_accepted: 8.0 time: 16.24s
candidate_lengths: [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 3]
added_lengths: [13, 6, 1, 10, 9, 13, 4, 8, 13, 13, 13, 13, 2]
   1.  verdict:1 Averages:  acc:1.000 pred: 3.0, answer: 3.0 mean_num_accepted: 9.08 time: 9.76s
candidate_lengths: [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 7]
added_lengths: [8, 1, 13, 13, 3, 1, 13, 13, 13, 10, 1, 13, 13, 13, 13, 4, 4, 13, 13, 6, 13, 13, 13, 7]
   2.  verdict:1 Averages:  acc:1.000 pred: 70000.0, answer: 70000.0 mean_num_accepted: 9.46 time: 18.29s
candidate_lengths: [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 3]
added_lengths: [5, 13, 8, 13, 13, 13, 13, 12, 3, 13, 3]
   3.  verdict:1 Averages:  acc:1.00

In [22]:
lossy_rate = float(os.environ.get("LOSSY_RATE"))
print(f"Lossy спекуляция: LOSSY RATE - {lossy_rate}, time - {np.mean(times) / 1000} s, acc - {accuracy}, mean_accepted_tokens - {np.mean(nums_accepted)}")

Lossy спекуляция: LOSSY RATE - 0.3, time - 15.704358388900756 s, acc - 0.8515625, mean_accepted_tokens - 8.745080569514583


In [23]:
os.environ["LOSSY_RATE"] = "0.1"

times, accuracy, nums_accepted = lossy_experiment(n_samples=256)

LOSSY RATE: 0.1


  0%|          | 0/256 [00:00<?, ?it/s]

candidate_lengths: [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12]
added_lengths: [13, 6, 3, 1, 2, 3, 10, 7, 9, 2, 4, 13, 1, 1, 13, 13, 11]
   0.  verdict:1 Averages:  acc:1.000 pred: 18.0, answer: 18.0 mean_num_accepted: 6.59 time: 13.17s
candidate_lengths: [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 2]
added_lengths: [13, 6, 6, 4, 13, 2, 13, 4, 13, 4, 13, 3, 13, 13, 13, 1]
   1.  verdict:1 Averages:  acc:1.000 pred: 3.0, answer: 3.0 mean_num_accepted: 8.38 time: 11.87s
candidate_lengths: [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12]
added_lengths: [13, 2, 7, 6, 11, 12, 4, 4, 13, 3, 13, 13, 13, 5, 13, 2, 5, 1, 13, 13]
   2.  verdict:0 Averages:  acc:0.667 pred: 150000.0, answer: 70000.0 mean_num_accepted: 8.3 time: 15.49s
candidate_lengths: [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12]
added_lengths: [6, 1, 13, 3, 13, 4, 10, 8, 13, 13, 13, 13, 11]
   3.  verdict:1 Averages:  acc:0.750 pred: 540.0, answer: 

In [24]:
lossy_rate = float(os.environ.get("LOSSY_RATE"))
print(f"Lossy спекуляция: LOSSY RATE - {lossy_rate}, time - {np.mean(times) / 1000} s, acc - {accuracy}, mean_accepted_tokens - {np.mean(nums_accepted)}")

Lossy спекуляция: LOSSY RATE - 0.1, time - 15.36103758430481 s, acc - 0.83203125, mean_accepted_tokens - 9.073583519346865


In [25]:
os.environ["LOSSY_RATE"] = "0.05"

times, accuracy, nums_accepted = lossy_experiment(n_samples=256)

LOSSY RATE: 0.05


  0%|          | 0/256 [00:00<?, ?it/s]

candidate_lengths: [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 9]
added_lengths: [13, 5, 1, 5, 3, 3, 5, 6, 13, 5, 13, 9, 5, 5, 1, 3, 6, 2, 13, 13, 13, 9]
   0.  verdict:1 Averages:  acc:1.000 pred: 18.0, answer: 18.0 mean_num_accepted: 6.86 time: 16.88s
candidate_lengths: [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 9]
added_lengths: [13, 6, 1, 7, 2, 9, 13, 3, 7, 13, 10, 8]
   1.  verdict:1 Averages:  acc:1.000 pred: 3.0, answer: 3.0 mean_num_accepted: 7.67 time: 9.33s
candidate_lengths: [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 3]
added_lengths: [8, 4, 5, 1, 4, 7, 4, 6, 13, 13, 10, 9, 4, 2, 1, 13, 13, 3, 13, 13, 7, 13, 13, 13, 2, 13, 13, 13, 13, 13, 3]
   2.  verdict:0 Averages:  acc:0.667 pred: 3333.33, answer: 70000.0 mean_num_accepted: 8.45 time: 23.19s
candidate_lengths: [12, 12, 12, 12, 12, 12, 12, 12, 12, 8]
added_lengths: [13, 13, 11, 3, 13, 13, 13, 5, 13, 8]


In [26]:
lossy_rate = float(os.environ.get("LOSSY_RATE"))
print(f"Lossy спекуляция: LOSSY RATE - {lossy_rate}, time - {np.mean(times) / 1000} s, acc - {accuracy}, mean_accepted_tokens - {np.mean(nums_accepted)}")

Lossy спекуляция: LOSSY RATE - 0.05, time - 15.059922733306884 s, acc - 0.83203125, mean_accepted_tokens - 9.087630553518569


In [21]:
os.environ["LOSSY_RATE"] = "0.01"

times, accuracy, nums_accepted = lossy_experiment(n_samples=256)

LOSSY RATE: 0.01


  0%|          | 0/256 [00:00<?, ?it/s]

candidate_lengths: [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 6]
added_lengths: [7, 13, 13, 1, 2, 1, 2, 13, 2, 5, 6, 10, 9, 4, 13, 5, 7, 4, 6]
   0.  verdict:1 Averages:  acc:1.000 pred: 18.0, answer: 18.0 mean_num_accepted: 6.47 time: 14.28s
candidate_lengths: [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 4]
added_lengths: [5, 7, 9, 13, 3, 2, 7, 13, 1, 5, 12, 13, 13, 3]
   1.  verdict:1 Averages:  acc:1.000 pred: 3.0, answer: 3.0 mean_num_accepted: 7.57 time: 10.37s
candidate_lengths: [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 8]
added_lengths: [8, 1, 13, 8, 6, 10, 1, 1, 13, 13, 13, 13, 13, 4, 13, 13, 1, 13, 13, 13, 13, 8]
   2.  verdict:1 Averages:  acc:1.000 pred: 70000.0, answer: 70000.0 mean_num_accepted: 9.27 time: 16.59s
candidate_lengths: [11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 6]
added_lengths: [5, 1, 1, 13, 6, 2, 12, 13, 13, 8, 13, 13, 5]
   3.  verdict:1 Averages:  acc:1.000 pred: 540.0, 

In [22]:
lossy_rate = float(os.environ.get("LOSSY_RATE"))
print(f"Lossy спекуляция: LOSSY RATE - {lossy_rate}, time - {np.mean(times) / 1000} s, acc - {accuracy}, mean_accepted_tokens - {np.mean(nums_accepted)}")

Lossy спекуляция: LOSSY RATE - 0.01, time - 14.716214794158935 s, acc - 0.81640625, mean_accepted_tokens - 9.075393501622166


#### Результаты

In [19]:
Драфтовая модель: 7.950274225234986 s, accuracy - 0.25
Таргетная модель: time - 21.397363939285277 s, accuracy - 0.859375
Lossy спекуляция: LOSSY RATE - 1.0, time - 14.738972623825074 s, acc - 0.85546875, mean_accepted_tokens - 5.783127490207388
Lossy спекуляция: LOSSY RATE - 0.5, time - 13.542602083206177 s, acc - 0.8359375, mean_accepted_tokens - 7.181438002028306
Lossy спекуляция: LOSSY RATE - 0.3, time - 13.780028093338013 s, acc - 0.84765625, mean_accepted_tokens - 7.534450154459602
Lossy спекуляция: LOSSY RATE - 0.1, time - 13.497806966781615 s, acc - 0.8203125, mean_accepted_tokens - 8.002264932672569
Lossy спекуляция: LOSSY RATE - 0.01, time - 14.116077709197999 s, acc - 0.83203125, mean_accepted_tokens - 7.867767907980765
Lossy спекуляция: LOSSY RATE - 0.001, time - 14.258594095230103 s, acc - 0.828125, mean_accepted_tokens - 7.8543560757065105
Lossy спекуляция: LOSSY RATE - 0.0001, time - 13.31358434677124 s, acc - 0.8359375, mean_accepted_tokens - 7.8614880717472015
Lossy спекуляция: LOSSY RATE - 1e-05, time - 13.57956993484497 s, acc - 0.8046875, mean_accepted_tokens - 7.991759691454453
Lossy спекуляция: LOSSY RATE - 1e-06, time - 13.545313425064087 s, acc - 0.84375, mean_accepted_tokens - 7.9721436960552

SyntaxError: invalid syntax (<ipython-input-19-f08f377966ee>, line 1)

In [10]:
stats = {
    'Среднее время, сек': [7.95, 21.40, 14.74, 13.54, 13.78, 13.50, 14.12, 14.26, 13.31, 13.58, 13.55],
    'Среднее число принятых токенов': ['-', '-', 5.78, 7.18, 7.53, 8.00, 7.87, 7.85, 7.86, 7.99, 7.97],
    'Точность': [0.25, 0.86, 0.86, 0.84, 0.85, 0.82, 0.83, 0.83, 0.84, 0.80, 0.84]
}

row_names = ['Драфтовая модель', 'Таргетная модель', 'Lossy rate = 1 (обычная спекулятивная генерация)', 'Lossy rate = 0.5', 'Lossy rate = 0.3', 'Lossy rate = 0.1', 
            r'Lossy rate = $10^{-2}$', r'Lossy rate = $10^{-3}$', r'Lossy rate = $10^{-4}$', r'Lossy rate = $10^{-5}$', r'Lossy rate = $10^{-6}$']

df = pd.DataFrame(stats, index=row_names)

In [11]:
df

Unnamed: 0,"Среднее время, сек",Среднее число принятых токенов,Точность
Драфтовая модель,7.95,-,0.25
Таргетная модель,21.4,-,0.86
Lossy rate = 1 (обычная спекулятивная генерация),14.74,5.78,0.86
Lossy rate = 0.5,13.54,7.18,0.84
Lossy rate = 0.3,13.78,7.53,0.85
Lossy rate = 0.1,13.5,8.0,0.82
Lossy rate = $10^{-2}$,14.12,7.87,0.83
Lossy rate = $10^{-3}$,14.26,7.85,0.83
Lossy rate = $10^{-4}$,13.31,7.86,0.84
Lossy rate = $10^{-5}$,13.58,7.99,0.8


In [2]:
import numpy as np
import pandas as pd

In [3]:
# Tokens = 20, Threshold = 0.01
stats1 = {
    'Среднее время, сек': [7.95, 21.40, np.round(15.083965808868408, 2), 
                           np.round(14.290847503662109, 2), np.round(13.79042766571045, 2), 
                           np.round(13.428416919708251, 2), np.round(13.069739028930664, 2)],
    'Среднее число принятых токенов': ['-', '-', np.round(5.908096242274175, 2), np.round(7.188244466096613, 2), 
                                       np.round(7.278357959899259, 2), np.round(7.927992733562955, 2),
                                       np.round(7.803476787045731, 2)],
    'Точность': [0.25, 0.86, np.round(0.85546875, 2), np.round(0.83984375, 2), np.round(0.84375, 2), np.round(0.8046875, 2), np.round(0.85546875, 2)]
}

row_names1 = ['Драфтовая модель', 'Таргетная модель', 'l = 1 (обычная спекулятивная генерация)', 'l = 0.5', 'l = 0.3', 'l = 0.1', 'l = 0.01']

df1 = pd.DataFrame(stats1, index=row_names1)

In [4]:
# Tokens = 20, Threshold = 0
stats2 = {
    'Среднее время, сек': [7.95, 21.40, np.round(23.946503761291503, 2), 
                           np.round(19.225626674652098, 2), np.round(18.333567180633544, 2), 
                           np.round(16.999432041168212, 2), np.round(17.34913021850586, 2), 
                           np.round(16.311200370788573, 2)],
    'Среднее число принятых токенов': ['-', '-', np.round(8.686347416503905, 2), np.round(10.661923355014363, 2), 
                                       np.round(11.610591745964244, 2), np.round(11.812966250613588, 2),
                                       np.round(12.064152305545356, 2), np.round(12.167193646415562, 2)],
    'Точность': [0.25, 0.86, np.round(0.86328125, 2), np.round(0.859375, 2), np.round(0.82421875, 2), 
                 np.round(0.82421875, 2), np.round(0.828125, 2), np.round(0.8359375, 2)]
}

row_names2 = ['Драфтовая модель', 'Таргетная модель', 'l = 1 (обычная спекулятивная генерация)', 'l = 0.5', 'l = 0.3', 'l = 0.1', 'l = 0.01', 'l = 0.001']

df2 = pd.DataFrame(stats2, index=row_names2)

In [5]:
# Tokens = 12, Threshold = 0
stats3 = {
    'Среднее время, сек': [7.95, 21.40, np.round(19.15632596206665, 2), 
                           np.round(17.301128684997558, 2), np.round(15.704358388900756, 2), 
                           np.round(15.36103758430481, 2), np.round(15.059922733306884, 2), 
                           np.round(14.716214794158935, 2)],
    'Среднее число принятых токенов': ['-', '-', np.round(7.334224421234497, 2), np.round(8.47801146298399, 2), 
                                       np.round(8.745080569514583, 2), np.round(9.073583519346865, 2),
                                       np.round(9.087630553518569, 2), np.round(9.075393501622166, 2)],
    'Точность': [0.25, 0.86, np.round(0.85546875, 2), np.round(0.8359375, 2), np.round(0.8515625, 2), 
                 np.round(0.83203125, 2), np.round(0.83203125, 2), np.round(0.81640625, 2)]
}

row_names3 = ['Драфтовая модель', 'Таргетная модель', 'l = 1 (обычная спекулятивная генерация)', 'l = 0.5', 'l = 0.3', 'l = 0.1', 'l = 0.05', 'l = 0.001']

df3 = pd.DataFrame(stats3, index=row_names3)

**Num tokens = 20, threshold = 0.01**

In [6]:
df1

Unnamed: 0,"Среднее время, сек",Среднее число принятых токенов,Точность
Драфтовая модель,7.95,-,0.25
Таргетная модель,21.4,-,0.86
l = 1 (обычная спекулятивная генерация),15.08,5.91,0.86
l = 0.5,14.29,7.19,0.84
l = 0.3,13.79,7.28,0.84
l = 0.1,13.43,7.93,0.8
l = 0.01,13.07,7.8,0.86


**Num tokens = 20, threshold = 0**

In [7]:
df2

Unnamed: 0,"Среднее время, сек",Среднее число принятых токенов,Точность
Драфтовая модель,7.95,-,0.25
Таргетная модель,21.4,-,0.86
l = 1 (обычная спекулятивная генерация),23.95,8.69,0.86
l = 0.5,19.23,10.66,0.86
l = 0.3,18.33,11.61,0.82
l = 0.1,17.0,11.81,0.82
l = 0.01,17.35,12.06,0.83
l = 0.001,16.31,12.17,0.84


**Num tokens = 12, threshold = 0**

In [8]:
df3

Unnamed: 0,"Среднее время, сек",Среднее число принятых токенов,Точность
Драфтовая модель,7.95,-,0.25
Таргетная модель,21.4,-,0.86
l = 1 (обычная спекулятивная генерация),19.16,7.33,0.86
l = 0.5,17.3,8.48,0.84
l = 0.3,15.7,8.75,0.85
l = 0.1,15.36,9.07,0.83
l = 0.05,15.06,9.09,0.83
l = 0.001,14.72,9.08,0.82


__________________________________________________

### Тестирование со SpecExec

In [11]:
from typing import Union, Sequence, Tuple
import json
import os

from tqdm.auto import tqdm
from termcolor import colored
import numpy as np
import pandas as pd

import itertools
import torch
from torch import nn
import torch.nn.functional as F
import transformers
import re
from transformers import BitsAndBytesConfig, AutoModelForCausalLM

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def extract_answer(s, suffix='<|eot_id|>'):
    s = s.lower().replace(suffix, '').replace('the final answer is', '=')
    idx = s.rfind("=")
    if idx != - 1:
        return s[idx + 1:].strip()
    

def extract_float(num_str):
    try:
        num_str = re.sub(r'[^0-9.-]', '', num_str).strip(".")
        return float(num_str)
    except (ValueError, TypeError):
        return

class args:
    draft_model = 'meta-llama/Llama-3.2-1B-Instruct'
    target_model = 'meta-llama/Llama-3.1-8B-Instruct'
    #draft_model = 'JackFram/llama-68m'
    #target_model = 'lmsys/vicuna-7b-v1.3'
    torch_dtype = 'auto'
    # data from https://github.com/openai/grade-school-math/tree/master/grade_school_math/data
    # use train.jsonl and test.jsonl
    gsm8k_test_path = '/kaggle/input/gsm888/test.jsonl'
    gsm8k_train_path = '/kaggle/input/gsm888/train.jsonl'
    random_seed = 42
    max_new_tokens = 1024
    n_samples = None


prompt_with_8_shots = "Given the following problem, reason and give a final answer to the problem.\nProblem: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?\nYour response should end with \"The final answer is [answer]\" where [answer] is the response to the problem.\n There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The final answer is 6\n\nGiven the following problem, reason and give a final answer to the problem.\nProblem: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?\nYour response should end with \"The final answer is [answer]\" where [answer] is the response to the problem.\n There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The final answer is 5\n\nGiven the following problem, reason and give a final answer to the problem.\nProblem: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?\nYour response should end with \"The final answer is [answer]\" where [answer] is the response to the problem.\n Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39. The final answer is 39\n\nGiven the following problem, reason and give a final answer to the problem.\nProblem: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?\nYour response should end with \"The final answer is [answer]\" where [answer] is the response to the problem.\n Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8. The final answer is 8\n\nGiven the following problem, reason and give a final answer to the problem.\nProblem: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?\nYour response should end with \"The final answer is [answer]\" where [answer] is the response to the problem.\n Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 4 more toys. 5 + 4 = 9. The final answer is 9\n\nGiven the following problem, reason and give a final answer to the problem.\nProblem: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?\nYour response should end with \"The final answer is [answer]\" where [answer] is the response to the problem.\n There were originally 9 computers. For each of 4 days, 5 more computers were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29. The final answer is 29\n\nGiven the following problem, reason and give a final answer to the problem.\nProblem: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?\nYour response should end with \"The final answer is [answer]\" where [answer] is the response to the problem.\n Michael started with 58 golf balls. After losing 23 on tuesday, he had 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls. The final answer is 33\n\nGiven the following problem, reason and give a final answer to the problem.\nProblem: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?\nYour response should end with \"The final answer is [answer]\" where [answer] is the response to the problem.\n Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8. The final answer is 8\n\nGiven the following problem, reason and give a final answer to the problem.\nProblem: "
prompt_with_0_shots = "Given the following problem, reason and give a final answer to the problem.\n"
formatting_prompt = "Your response should end with \"The final answer is [answer]\" where [answer] is the response to the problem."

np.random.seed(args.random_seed)

def load_questions(args):
    with open(args.gsm8k_test_path) as f:
        gsm_questions = [json.loads(line) for line in f]

    n_samples = args.n_samples or len(gsm_questions)
    gsm_questions = [
        {
            'question': i['question'],
            'answer': i['answer'][i['answer'].rfind('#### ') + 5:]
        }
        for i in gsm_questions[:n_samples]
    ]

    return gsm_questions

gsm_questions = load_questions(args)
n_samples = len(gsm_questions)
#n_samples = 4
print(f'Num samples: {n_samples}')

Num samples: 1319


In [12]:
spec_generator = create_spec_generator(
    model_name_0=args.draft_model,
    model_name_1=args.target_model,
    draft_engine_class='padded',
    gen_type='sx_base',
    offload=False,
    tree_max_len=4096
)

tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/877 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

AttributeError: 'LlamaAttention_FI' object has no attribute 'rotary_emb'

In [13]:
tokenizer_ = transformers.AutoTokenizer.from_pretrained(
    'meta-llama/Llama-2-7b-chat-hf', 
)
spec_generator.tokenizer.chat_template = tokenizer_.chat_template
spec_generator.tokenizer.pad_token_id = spec_generator.tokenizer.eos_token_id


tokenizer_config.json:   0%|          | 0.00/1.62k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

In [20]:
spec_generator.tokenizer.chat_template

"{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' '  + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"

In [17]:
def generate(input_ids, spec_generator, args):
    with torch.inference_mode():
        _ = spec_generator.generate(
            input_ids,
            lossy_const=0,
            max_n_beams=128,
            max_beam_len=32,
            max_new_tokens=args.max_new_tokens,
            max_budget=16,
            max_branch_width=32,
            temperature=1.0,
            #draft_temperature=None,
            top_p=1.0,
            #min_log_prob=min_log_prob,
            seed=0,
            tree_max_len=4096,
            #**kwargs,
        )
        output_ids = spec_generator.prefix_tokens
        print(spec_generator.tokenizer.decode(output_ids))
        return output_ids


correct = 0

with tqdm(total=n_samples) as pbar:
    for sample_idx in range(0, n_samples):
        question_sample = gsm_questions[sample_idx]
        answer = question_sample['answer']
        question = question_sample['question']
        formatted_zero_shot_prompt = prompt_with_8_shots + question + "\n" + formatting_prompt
        batch_input_ids = spec_generator.tokenizer.apply_chat_template(
            [{'role': 'user', 'content': formatted_zero_shot_prompt}],
            tokenize=False, return_tensors='pt', padding=True, continue_final_message=False  # <--- MIGHT BE A BUG
        )
        #print(batch_input_ids)
        generation = generate(batch_input_ids, spec_generator, args)
        generation_str = spec_generator.tokenizer.decode(generation)
        raw_pred = extract_answer(generation_str)
        #print('raw_pred', raw_pred)
        pred_float = extract_float(raw_pred)
        #print('pred_float', pred_float)
        #print('answer', answer)
        answer_float = extract_float(answer)
        #print('answer_float', answer_float)
        #break

        # gen_tokens = generation.shape[-1] - batch_input_ids.shape[-1]
        gen_tokens = len(spec_generator.prefix_tokens[spec_generator.original_num_tokens :])
        verdict = int(answer_float == pred_float)

        correct += verdict
        accuracy = correct / (sample_idx + 1)

        print(
            f"{sample_idx:>4}.  verdict:{verdict:<1} "
            f"Averages:  acc:{accuracy:1.3f} "
            f"pred: {pred_float}, answer: {answer_float}"
        )
        pbar.update(1)

  0%|          | 0/4 [00:00<?, ?it/s]

<s><s>[INST] Given the following problem, reason and give a final answer to the problem.
Problem: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
Your response should end with "The final answer is [answer]" where [answer] is the response to the problem.
 There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The final answer is 6

Given the following problem, reason and give a final answer to the problem.
Problem: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
Your response should end with "The final answer is [answer]" where [answer] is the response to the problem.
 There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The final answer is 5

Given the following problem, reason and give a final answer to the problem.
Problem: Leah 

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [10]:
import gc
draft_model.to('cpu')
del draft_model
torch.cuda.empty_cache()
gc.collect()

0