In [1]:
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import transformers # requires transformers==4.35.2

  _torch_pytree._register_pytree_node(


In [2]:
model_name = "deepseek-ai/deepseek-coder-6.7b-base"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, device_map="cuda:0", torch_dtype=torch.float16, use_flash_attention_2=True) # , use_flash_attention=True)

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



In [3]:
import copy
import inspect
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from torch import nn

from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from transformers.models.auto import (
    MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
    MODEL_FOR_CAUSAL_LM_MAPPING,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
    MODEL_FOR_VISION_2_SEQ_MAPPING,
)
from transformers.utils import ExplicitEnum, ModelOutput, is_accelerate_available, logging
from transformers.generation.beam_constraints import DisjunctiveConstraint, PhrasalConstraint
from transformers.generation.beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.logits_process import (
    EncoderNoRepeatNGramLogitsProcessor,
    EncoderRepetitionPenaltyLogitsProcessor,
    EpsilonLogitsWarper,
    EtaLogitsWarper,
    ExponentialDecayLengthPenalty,
    ForcedBOSTokenLogitsProcessor,
    ForcedEOSTokenLogitsProcessor,
    ForceTokensLogitsProcessor,
    HammingDiversityLogitsProcessor,
    InfNanRemoveLogitsProcessor,
    LogitNormalization,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    MinNewTokensLengthLogitsProcessor,
    NoBadWordsLogitsProcessor,
    NoRepeatNGramLogitsProcessor,
    PrefixConstrainedLogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
    SequenceBiasLogitsProcessor,
    SuppressTokensAtBeginLogitsProcessor,
    SuppressTokensLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
    TypicalLogitsWarper,
    UnbatchedClassifierFreeGuidanceLogitsProcessor,
)
from transformers.generation.stopping_criteria import (
    MaxLengthCriteria,
    MaxTimeCriteria,
    StoppingCriteria,
    StoppingCriteriaList,
    validate_stopping_criteria,
)

from transformers.generation.utils import _crop_past_key_values
import difflib

@dataclass
class GreedySearchDecoderOnlyOutput(ModelOutput):
    """
    Base class for outputs of decoder-only generation models using greedy search.


    Args:
        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
            Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
    """

    sequences: torch.LongTensor = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None

In [4]:
@torch.no_grad()
def find_candidate_pred_tokens(input_ids, max_ngram_size=3, num_pred_tokens=10):
    input_length = input_ids.size(1)

    # Ensure max_ngram_size and num_pred_tokens are valid
    if max_ngram_size <= 0 or num_pred_tokens <= 0 or max_ngram_size > input_length:
        raise ValueError("Invalid max_ngram_size or num_pred_tokens")

    for ngram_size in range(max_ngram_size, 0, -1):
        # Extract the last n tokens as our search ngram
        ngram = input_ids[0, -ngram_size:].tolist()

        # Create sliding windows of size ngram_size
        windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)

        # Convert ngram to a tensor for comparison
        ngram_tensor = torch.tensor(ngram, device=input_ids.device).unsqueeze(0)

        # Find where the windows match the ngram
        matches = (windows == ngram_tensor).all(dim=2)

        # Get the indices of matches
        match_indices = matches.nonzero(as_tuple=True)[1]

        # Iterate through match indices to find a valid continuation
        for idx in match_indices:
            start_idx = idx + ngram_size
            end_idx = start_idx + num_pred_tokens
            # Ensure we don't go beyond the length of input_ids and avoid self-match
            # if end_idx <= input_length and start_idx < input_length - ngram_size:
            #     return input_ids[0, start_idx:end_idx]
            if start_idx < input_length - ngram_size:
                return input_ids[0, start_idx:min(end_idx, input_length)]

    # If no match is found, return an empty tensor
    return torch.tensor([], dtype=torch.long, device=input_ids.device)

In [5]:
def dynamic_cache_crop_with_start(self, start: int, max_length: int):
    """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
    negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search."""

    # In case it is negative
    if max_length < 0:
        max_length = self.get_seq_length() - abs(max_length)

    if self.get_seq_length() <= max_length:
        return

    self._seen_tokens = max_length
    for idx in range(len(self.key_cache)):
        self.key_cache[idx] = self.key_cache[idx][..., start:max_length, :]
        self.value_cache[idx] = self.value_cache[idx][..., start:max_length, :]


In [6]:
def _crop_past_key_values_with_start(model, past_key_values, start, max_length):
    """Crops the past key values up to a certain maximum length."""
    new_past = []
    # if isinstance(past_key_values, DynamicCache):
    #     past_key_values.crop_with_start = dynamic_cache_crop_with_start.__get__(past_key_values, type(past_key_values))
    #     past_key_values.crop_with_start(start, max_length)
    # elif past_key_values is not None:
    for idx in range(len(past_key_values)):
        new_past.append(
            (
                past_key_values[idx][0][:, :, start:max_length, :],
                past_key_values[idx][1][:, :, start:max_length, :],
            )
        )
    past_key_values = tuple(new_past)
    # return past_key_values

In [7]:
COLORS = ["\x1b[31m", "\x1b[32m", "\x1b[34m", "\x1b[35m"]  # Red, Green, Blue, Magenta
UNDERLINE = "\x1b[4m"
RESET = "\x1b[0m"
NEWLINE_THRESHOLD = 5

In [26]:
import difflib

@torch.no_grad()
def greedy_search_pld(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        streamer: Optional["BaseStreamer"] = None,
        draft_matching_window_size = 3,
        draft_num_candidate_tokens = 10,
        print_output=True,
        kv_tolerance=15,
        **model_kwargs,
    ):

        global tokenizer

        # init values
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None

        # # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None

        max_len = stopping_criteria[0].max_length

        i = 0
        current_color_index = 0
        start_length = input_ids.shape[-1]

        while True:
            i += 1
            cur_len = input_ids.shape[-1]

            candidate_pred_tokens = find_candidate_pred_tokens(input_ids, draft_matching_window_size, draft_num_candidate_tokens)
            
            if len(candidate_pred_tokens) == 0:
                candidate_pred_tokens = torch.tensor([100], device=input_ids.device).unsqueeze(0)
            else:
                candidate_pred_tokens = candidate_pred_tokens.unsqueeze(0)

            # pre-generation, you should be cropping the inputs
            sm = difflib.SequenceMatcher(input_ids[0, :start_length].tolist(), input_ids[0, start_length:].tolist())
            original_code_index = 0
            last_deleted = 0
            for tag, i1, i2, j1, h2 in sm.get_opcodes():
                if tag == 'delete':
                    original_code_index += i2 - i1
                    last_deleted += i2 - i1
                elif tag == 'insert':
                    ...
                elif tag == 'equal':
                    original_code_index += i2 - i1
                elif tag == 'replace':
                    original_code_index += i2 - i1
            original_code_index -= last_deleted
            original_code_index = max(0, original_code_index - kv_tolerance)

            if "past_key_values" in model_kwargs:
                cropped_input_ids = input_ids[original_code_index:]
                
                candidate_input_ids = torch.cat((cropped_input_ids, candidate_pred_tokens), dim=1)
                
                candidate_length = candidate_input_ids.shape[1] - cropped_input_ids.shape[1]
                
                candidate_kwargs = copy.copy(model_kwargs)
                candidate_kwargs["past_key_values"] = _crop_past_key_values_with_start(self, model_kwargs["past_key_values"], original_code_index, input_ids.shape[1])
            else:
                candidate_input_ids = torch.cat((input_ids, candidate_pred_tokens), dim=1)
                candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
                candidate_kwargs = copy.copy(model_kwargs)
            
            candidate_kwargs = self._extend_attention_mask(candidate_kwargs, candidate_input_ids.shape[1])
            candidate_kwargs = self._extend_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])

            model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
            
            # prepare model inputs
            # model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )


            new_logits = outputs.logits[:, -candidate_length - 1 :]  # excludes the input prompt if present
            selected_tokens = new_logits.argmax(dim=-1)
            candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
            n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()

            
            # if last_assistant_token_is_eos and n_matches == candidate_length: # todo: do this earlier somehow
            #     n_matches -= 1
            
            n_matches = min(n_matches, max_len - cur_len - 1)

            # print(n_matches)
            # i+= n_matches.item()

            if print_output:
                current_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
            
            valid_tokens = selected_tokens[:, : n_matches + 1]
            input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
            new_cur_len = input_ids.shape[-1]

            if input_ids.shape[-1] > NEWLINE_THRESHOLD: # Check that there are max 5 consecutive newlines.
                flag = True
                for i in range(NEWLINE_THRESHOLD):
                    if not(input_ids[0, -i] == 185): # Is a newline
                        flag = False
                if flag:
                    break

            if print_output:
                updated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
                # Find and print the newly added text
                if updated_text != current_text:
                    new_text = updated_text[len(current_text):]
                    if len(valid_tokens[0]) > 1:
                        color = COLORS[current_color_index]
                        print(f"{color}{new_text}{RESET}", end='')
                        # Update color for next generation
                        current_color_index = (current_color_index + 1) % len(COLORS)
                    else:
                        print(f"{new_text}", end='')

            
            new_cache_size = new_cur_len - 1            
            outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size)
        
            model_kwargs["past_key_values"] = outputs.past_key_values

            # stop if we exceed the maximum length

            if (valid_tokens == eos_token_id_tensor.item()).any():
                break
            
            if stopping_criteria(input_ids, scores):
                break


        if return_dict_in_generate:
            return GreedySearchDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                # attentions=decoder_attentions,
                # hidden_states=decoder_hidden_states,
            )
        else:
            return input_ids

In [27]:
model.greedy_search_pld = greedy_search_pld.__get__(model, type(model))

In [28]:
code_text = """import numpy as np
import matplotlib.pyplot as plt

# Calculate the average
average_throughput = np.mean(tokens_per_sec_arr)
print(f"Average Throughput: {average_throughput} tokens/sec")

# Plotting the histogram
plt.hist(tokens_per_sec_arr, bins=20, color='blue', edgecolor='black', alpha=0.7)
plt.title('Histogram of Throughput Values')
plt.xlabel('Tokens per Second')
plt.ylabel('Frequency')
plt.axvline(average_throughput, color='red', linestyle='dashed', linewidth=1)
plt.text(average_throughput*0.9, max(plt.ylim())*0.9, f'Average: {average_throughput:.2f}', color = 'red')
plt.show()
"""

question = "Can you please change x axis to start from 0"
prompt = "[INST] Code:```python\n{code_text}``` \n\n Question: {question} \n\n Modified code:[/INST]".format(code_text=code_text, question=question)

inputs = tokenizer(prompt, return_tensors="pt")
# Move all tensor values in the inputs to GPU
for key in inputs:
    inputs[key] = inputs[key].to(model.device)

In [29]:
model.greedy_search_pld = greedy_search_pld.__get__(model, type(model))

In [31]:
import time

start_time = time.perf_counter()
test_out = model.greedy_search_pld(inputs.input_ids,
            attention_mask = inputs.attention_mask,
            stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=len(inputs.input_ids[0]) + 500)]),
            prompt_matching_window_size = 3,
            prompt_num_candidate_tokens = 50,
            use_cache=True, 
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            print_output=True
        )
end_time = time.perf_counter()
print("\nTime taken: ", end_time - start_time)



```python[31m
import numpy as np
import matplotlib[0m[32m.pyplot as plt

# Calculate[0m[34m the average
average_throughput = np.[0m[35mmean(tokens_per_sec_arr)[0m[31m
print(f"Average Throughput: {[0m[32maverage_throughput} tokens/sec")
[0m[34m
# Plotting the histogram
plt[0m[35m.hist(tokens_per_sec_arr[0m,[31m bins=20, color='blue', edge[0m[32mcolor='black', alpha=0.7)[0m[34m
plt.title('Histogram of Throughput[0m[35m Values')
plt.xlabel('Tok[0m[31mens per Second')
plt.ylabel('[0m[32mFrequency')
plt.axvline([0m[34maverage_throughput, color='red',[0m[35m linestyle='dashed', linewidth=1[0m[31m)
plt.text(average_through[0m[32mput*[0m[34m0.9, max(plt.ylim[0m[35m())*0.9, f'Average:[0m[31m {average_throughput:.2f}',[0m[32m color = 'red')
plt.x[0mlim(0[34m, max[0m[35m(tok[0m[31mens_per_sec_arr))[0m
pl[32mt.show[0m[34m()
```
[0m





[35m




[0m[31m









[0m[32m










[0m[34m










[0m

KeyboardInterrupt: 