In [2]:
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import transformers # requires transformers==4.35.2

In [3]:
# Show the installed transformers version. Later cells rely on private generation internals
# (e.g. `_crop_past_key_values`, and monkeypatching `_get_candidate_generator`), so behaviour
# can change between releases.
# NOTE(review): the comment on the `import transformers` line says 4.35.2 is required, but this
# notebook last ran on the version printed below — confirm which one is actually needed.
print(transformers.__version__)

4.43.3


In [4]:
# Target (verifier) model for speculative decoding.
# Alternatives tried earlier: "Qwen/CodeQwen1.5-7B-Chat" and
# "TheBloke/deepseek-coder-6.7B-base-GPTQ" (revision "gptq-4bit-32g-actorder_True").
model_name = "deepseek-ai/deepseek-coder-6.7b-base"

# Smaller draft model. Alternative tried earlier:
# "TheBloke/deepseek-coder-1.3b-base-GPTQ" (revision "gptq-4bit-32g-actorder_True").
draft_model_name = "deepseek-ai/deepseek-coder-1.3b-base"

# One tokenizer serves both models (assumes they share a vocabulary — TODO confirm).
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Loaded in order: target model first, then the draft model.
model, draft_model = (
    AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True, device_map="auto")
    for checkpoint in (model_name, draft_model_name)
)

Unrecognized keys in `rope_scaling` for 'rope_type'='linear': {'type'}


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Unrecognized keys in `rope_scaling` for 'rope_type'='linear': {'type'}


In [5]:
# Newline-count threshold.
# NOTE(review): not referenced by any other visible cell; NewlineStoppingCriteria uses its own
# `newline_count` argument (default 5) instead.
NEWLINE_THRESHOLD: int = 10

In [6]:
# Inspect how a few short strings tokenize; one list per line. encode() prepends the
# begin-of-sentence token, which is why every list starts with the same id.
print(*(tokenizer.encode(sample) for sample in ("...", "\n", "##")), sep="\n")

# Id of the newline token; [-1] skips the prepended begin-of-sentence token.
newline_token = tokenizer.encode("\n")[-1]

[32013, 1202]
[32013, 185]
[32013, 1672]


In [7]:
@torch.no_grad()
def find_candidate_pred_tokens(input_ids, max_ngram_size=3, num_pred_tokens=10):
    """Prompt-lookup drafting: propose the tokens that followed an earlier copy of the current suffix.

    Suffix n-grams are tried from ``max_ngram_size`` down to 1. For each size, every window of
    ``input_ids`` equal to the trailing n-gram is found, and the first one whose continuation
    starts before the final n-gram yields up to ``num_pred_tokens`` following tokens.

    Args:
        input_ids: token ids; row 0 supplies the suffix and the returned slice
            (batch size 1 is assumed — TODO confirm).
        max_ngram_size: longest suffix to try; must lie in ``[1, input_ids.size(1)]``.
        num_pred_tokens: maximum number of tokens to propose; must be positive.

    Returns:
        A 1-D tensor of proposed tokens. When no suffix matches, the placeholder
        ``tensor([100])`` is returned rather than an empty tensor (presumably so callers always
        receive at least one candidate; nothing visible here gives id 100 a special meaning).

    Raises:
        ValueError: if ``max_ngram_size`` or ``num_pred_tokens`` is out of range.
    """
    seq_len = input_ids.size(1)
    if max_ngram_size <= 0 or num_pred_tokens <= 0 or max_ngram_size > seq_len:
        raise ValueError("Invalid max_ngram_size or num_pred_tokens")

    for n in reversed(range(1, max_ngram_size + 1)):
        suffix = torch.tensor(input_ids[0, -n:].tolist(), device=input_ids.device)
        window_hits = (input_ids.unfold(1, n, 1) == suffix).all(dim=2)
        # Column positions of every matching window, in ascending order.
        _, hit_positions = window_hits.nonzero(as_tuple=True)

        # The continuation must begin before the final n-gram; this excludes the suffix's
        # own window (and windows overlapping it).
        limit = seq_len - n
        for pos in hit_positions:
            follow = pos + n
            if follow < limit:
                return input_ids[0, follow:min(follow + num_pred_tokens, seq_len)]

    return torch.tensor([100], dtype=torch.long, device=input_ids.device)

In [8]:
from dataclasses import dataclass

@dataclass
class Hunk:
    """One hunk of a unified diff."""

    filepath: str  # file name from the most recent "+++" header ("" if none)
    text: str  # hunk body lines, still carrying their " "/"-"/"+" prefixes, newline-terminated


@dataclass
class SearchReplaceChange:
    """A search/replace pair reconstructed from one diff hunk."""

    filepath: str  # file the change applies to, copied from the hunk
    search_block: str  # old side of the hunk: context lines plus removed lines
    replace_block: str  # new side of the hunk: context lines plus added lines


def find_hunks(diff_string):
    """Split a unified diff into ``Hunk`` objects.

    Header recognition (all checks are at column 0, so context lines, which start with a space,
    are never mistaken for headers):

    * A "---" line is a file header when it appears outside a hunk, or inside one when the next
      non-blank line starts with "+++". Otherwise it is a removed line whose text starts with
      "--" (e.g. an SQL/Lua comment) and is kept as hunk content.
    * A "+++" line is a file header only right after such a "---" header (blank lines in between
      are tolerated); otherwise it is an added line. Its file name is the text after "+++" up to
      any tab, stripped.
    * An "@@" line starts a new hunk.

    The text gathered since the previous header is emitted as a hunk whenever a new "+++" or
    "@@" header follows an earlier one, and once more at the end, so the result is never empty.
    Text between a file header and its first "@@" (normally nothing) becomes a hunk of its own.

    Known limitation: a removed "--..." line immediately followed by an added "++..." line is
    still read as a file-header pair.
    """
    lines = diff_string.splitlines()
    hunks = []
    current_filename = ""
    current_lines = ""
    section_open = False  # a "+++" or "@@" header has been seen, so there is a section to flush
    in_hunk = False  # an "@@" header has been seen since the last file header
    after_old_header = False  # the last non-blank line was a "---" file header

    def next_nonblank(start):
        # First line at or after `start` with non-whitespace content, or "" if there is none.
        for candidate in lines[start:]:
            if candidate.strip():
                return candidate
        return ""

    for index, line in enumerate(lines):
        if line.startswith("---") and (not in_hunk or next_nonblank(index + 1).startswith("+++")):
            after_old_header = True
            continue
        if line.startswith("+++") and after_old_header:
            if section_open:
                hunks.append(Hunk(current_filename, current_lines))
            # Drop the "+++" marker and any tab-separated timestamp after the name.
            current_filename = line[3:].split("\t", 1)[0].strip()
            current_lines = ""
            section_open = True
            in_hunk = False
        elif line.startswith("@@"):
            if section_open:
                hunks.append(Hunk(current_filename, current_lines))
            current_lines = ""
            section_open = True
            in_hunk = True
        else:
            current_lines += line + "\n"
        if line.strip():
            after_old_header = False
    hunks.append(Hunk(current_filename, current_lines))
    return hunks


def parse_diff(diff_string):
    """Convert a unified diff into one ``SearchReplaceChange`` per hunk.

    The one-character diff prefix is removed from every hunk line:

    * " " (context) lines go to both blocks;
    * "-" lines go only to ``search_block``;
    * "+" lines go only to ``replace_block``;
    * "\\" marker lines ("\\ No newline at end of file") are skipped;
    * any other line (e.g. a blank line) is copied unchanged to both blocks.

    Every output line is newline-terminated.

    Returns:
        A list with one change per hunk from ``find_hunks``. Hunks without diff lines yield
        changes whose blocks are empty; no extra trailing change is appended.
    """
    changes = []
    for hunk in find_hunks(diff_string):
        search_lines = []
        replace_lines = []
        for line in hunk.text.splitlines():
            prefix, content = line[:1], line[1:]
            if prefix == "-":
                search_lines.append(content)
            elif prefix == "+":
                replace_lines.append(content)
            elif prefix == "\\":
                continue
            else:
                # Context line: drop the leading " " marker; unprefixed lines are kept whole.
                shared = content if prefix == " " else line
                search_lines.append(shared)
                replace_lines.append(shared)
        changes.append(
            SearchReplaceChange(
                hunk.filepath,
                "".join(f"{entry}\n" for entry in search_lines),
                "".join(f"{entry}\n" for entry in replace_lines),
            )
        )
    return changes

In [9]:
from thefuzz import fuzz

@dataclass
class Match:
    """Result of ``find_best_match``: the matched snippet and its score."""

    block: str  # exact lines of the original code that matched, joined with "\n"
    score: float  # fuzzy score of that snippet; -1 when nothing was compared


def line_relevant(line):
    """Return True unless ``line`` is blank or starts at column 0 with ``#`` or ``//``."""
    return not (len(line.strip()) == 0 or line.startswith("#") or line.startswith("//"))


def find_best_match(query_code: str, original_code: str):
    """Find the snippet of ``original_code`` that best matches ``query_code`` fuzzily.

    Every window of ``original_code`` lines whose length is roughly the query's length +/- 5 lines
    is scored as::

        fuzz.ratio(relevant window lines, relevant query lines)   # each line stripped, joined by " "
        + 3 * fuzz.ratio(first window line, first query line)     # both sides stripped
        + 3 * fuzz.ratio(last window line, last query line)       # both sides stripped

    where "relevant" lines are those accepted by ``line_relevant``.

    Returns:
        ``Match(block, score)`` for the highest-scoring window; ties keep the earliest window.
        ``Match("SUPERDOCSTHISSTRINGWILLNEVEREVERBEFOUND", 100)`` when ``query_code`` is blank.
        ``Match("", -1)`` when ``original_code`` has no lines.
    """
    query_code = query_code.strip()

    original_lines = original_code.splitlines()
    query_lines = query_code.splitlines()

    if len(query_lines) == 0:
        return Match("SUPERDOCSTHISSTRINGWILLNEVEREVERBEFOUND", 100)

    # The query side of every comparison is loop-invariant, so compute it once.
    stripped_query = " ".join(line.strip() for line in query_lines if line_relevant(line))
    first_query_line = query_lines[0].strip()
    last_query_line = query_lines[-1].strip()

    best_match = Match("", -1)

    for start_line in range(len(original_lines)):
        # end_line is inclusive; the range gives windows of about len(query) +/- 5 lines.
        min_end = min(len(original_lines), max(start_line, start_line + len(query_lines) - 5))
        max_end = min(len(original_lines), start_line + len(query_lines) + 5)
        for end_line in range(min_end, max_end):
            window = original_lines[start_line:end_line + 1]
            stripped_original = " ".join(line.strip() for line in window if line_relevant(line))

            score = fuzz.ratio(stripped_original, stripped_query)

            # Weight the first and last lines by 3x. Both sides are stripped: query_code.strip()
            # already removed the first query line's indentation, so comparing it against a raw,
            # indented original line would penalise correctly placed matches.
            score += 3 * fuzz.ratio(original_lines[start_line].strip(), first_query_line)
            score += 3 * fuzz.ratio(original_lines[end_line].strip(), last_query_line)

            if score > best_match.score:
                best_match = Match("\n".join(window), score)
    return best_match

In [10]:
import difflib

@torch.no_grad()
def find_candidate_pred_tokens_diff(input_ids, code_ids, orig_input_len=0, ngram_size=3, num_pred_tokens=10):
    """Propose draft tokens by locating the generated text's current position inside ``code_ids``.

    The tokens generated so far (``input_ids[0, orig_input_len:]``) are aligned against the
    original code tokens with ``difflib.SequenceMatcher`` to estimate how far into ``code_ids``
    generation has progressed. Starting ``ngram_size`` tokens before that estimate, ``code_ids``
    is scanned for the last (up to) ``ngram_size`` generated tokens. On the first hit, ALL
    remaining code tokens after it are proposed; ``num_pred_tokens`` does not cap this path.
    Without a hit, it falls back to ``find_candidate_pred_tokens(input_ids, ngram_size, num_pred_tokens)``.

    Args:
        input_ids: prompt plus generated tokens; row 0 is used (batch size 1 assumed — TODO confirm).
        code_ids: list of token ids of the original code.
        orig_input_len: prompt length; only tokens after it count as generated.
        ngram_size: suffix length to search for (shorter while fewer tokens have been generated).
        num_pred_tokens: forwarded to the fallback lookup only.

    Returns:
        1-D ``torch.long`` tensor of proposed tokens on ``input_ids.device``.

    Raises:
        ValueError: if ``ngram_size`` is not in ``[1, input_ids.size(1)]``.
    """
    total_len = input_ids.size(1)
    if ngram_size <= 0 or ngram_size > total_len:
        raise ValueError("Invalid max_ngram_size or num_pred_tokens")

    generated = input_ids[0, orig_input_len:].tolist()
    opcodes = difflib.SequenceMatcher(None, code_ids, generated).get_opcodes()

    # Code tokens covered by replaced/deleted/equal regions, minus the size of the most recent
    # deletion (presumably the not-yet-generated tail of the code).
    covered = 0
    latest_deletion = 0
    for tag, i1, i2, _j1, _j2 in opcodes:
        if tag in ("replace", "delete", "equal"):
            covered += i2 - i1
        if tag == "delete":
            latest_deletion = i2 - i1
    estimated_pos = covered - latest_deletion

    needle = input_ids[0, max(total_len - ngram_size, orig_input_len):].tolist()
    width = len(needle)

    # Only positions whose match leaves at least one code token after it are tried.
    for pos in range(max(0, estimated_pos - ngram_size), len(code_ids) - width):
        if code_ids[pos:pos + width] == needle:
            return torch.tensor(code_ids[pos + width:], dtype=torch.long, device=input_ids.device)

    return find_candidate_pred_tokens(input_ids, ngram_size, num_pred_tokens)


In [11]:
# ANSI SGR escape sequences for terminal highlighting (not referenced by any other visible cell).
COLORS = [f"\x1b[{code}m" for code in (31, 32, 34, 35)]  # red, green, blue, magenta
UNDERLINE = "\x1b[4m"  # underline on
RESET = "\x1b[0m"  # reset all attributes

In [36]:
from transformers.generation.candidate_generator import CandidateGenerator, _crop_past_key_values
from transformers.generation.stopping_criteria import StoppingCriteria
from transformers.generation.configuration_utils import GenerationConfig
from typing import Tuple, Optional

class DiffPromptLookupCandidateGenerator(CandidateGenerator):
    """Candidate generator that drafts by aligning generated output with the original code tokens.

    Each ``get_candidates`` call appends the proposal from ``find_candidate_pred_tokens_diff`` to
    the current ``input_ids``; no candidate logits are produced.
    """

    def __init__(self, input_ids, code_ids, ngram_size=3, num_pred_tokens=10):
        # Prompt length: tokens past this index are treated as generated output.
        self.orig_input_len = input_ids.shape[-1]
        self.code_ids = code_ids
        self.ngram_size = ngram_size
        self.num_pred_tokens = num_pred_tokens

    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """Return ``(input_ids + proposed tokens, None)``."""
        proposal = find_candidate_pred_tokens_diff(
            input_ids,
            self.code_ids,
            self.orig_input_len,
            self.ngram_size,
            self.num_pred_tokens,
        )
        candidate_ids = torch.cat((input_ids, proposal[None, :]), dim=-1)
        return candidate_ids, None

    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
        """No-op; the match count could later drive an adaptive threshold."""
        return None

class NumRunsStoppingCriteria(StoppingCriteria):
    """Stop once this criterion has been called ``max_num_runs`` times.

    The call counter lives on the instance and is never reset, so build a fresh instance for
    every ``generate`` call.
    """

    def __init__(self, max_num_runs=4):
        # Honour the argument; it used to be ignored in favour of a hard-coded 4
        # (the default stays 4, so existing callers behave the same).
        self.max_num_runs = max_num_runs
        self.num_runs = 0

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        # Returns a plain Python bool even though the annotation says BoolTensor.
        self.num_runs += 1
        return self.num_runs >= self.max_num_runs

class NewlineStoppingCriteria(StoppingCriteria):
    """Stop once the generated tokens contain ``newline_count`` consecutive newline tokens.

    Only positions after the first ``prompt_tokens`` tokens are inspected, so newlines inside the
    prompt never trigger a stop. Each batch row is checked independently.
    """

    def __init__(self, tokenizer, prompt_tokens: int, newline_count=5):
        # [-1] skips any begin-of-sentence token that encode() prepends.
        self.newline_token = tokenizer.encode("""
""")[-1]
        self.newline_count = newline_count
        self.prompt_tokens = prompt_tokens

    def _has_newline_run(self, tokens):
        """Return True if ``tokens`` contains ``newline_count`` consecutive newline tokens."""
        run = 0
        for token in tokens:
            run = run + 1 if token == self.newline_token else 0
            if run >= self.newline_count:
                return True
        return False

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        # Previously: `.list()` (no such tensor method), an undefined `newline_line`, and a
        # `list in list_of_ints` test that could never be true. This checks for the intended run.
        generated_rows = input_ids[:, self.prompt_tokens:].tolist()
        return torch.tensor(
            [self._has_newline_run(row) for row in generated_rows],
            dtype=torch.bool,
            device=input_ids.device,
        )

def _get_default_candidate_generator_generator(generator: CandidateGenerator):
    """Build a replacement for a model's ``_get_candidate_generator`` method.

    The returned function is intended to be bound to a model with ``__get__``. Whatever keyword
    arguments it receives, it hands back the same pre-built ``generator``.
    """

    def _get_candidate_generator(self, **_ignored_kwargs):
        return generator

    return _get_candidate_generator

class TwoLayerLookupCandidateGenerator(CandidateGenerator):
    """Candidate generator that drafts with a small model, which itself drafts via diff lookup.

    Outer layer: ``draft_model.generate`` proposes the continuation that the target model verifies.
    Inner layer: the draft model's own ``_get_candidate_generator`` is replaced so that it uses a
    ``DiffPromptLookupCandidateGenerator`` over ``code_ids``.
    """
    def __init__(self, tokenizer, prompt_tokens, draft_model, input_ids, code_ids, num_runs=4, **diff_prompt_args):
        """
        Args:
            tokenizer: tokenizer for the draft model; its ``pad_token_id`` (which may be None)
                is copied onto ``draft_model.generation_config``.
            prompt_tokens: prompt length, passed to ``NewlineStoppingCriteria``.
            draft_model: small causal LM. NOTE: mutated here (``generation_config.pad_token_id``
                and a monkeypatched ``_get_candidate_generator``).
            input_ids: prompt token ids; their length marks where generated output begins for
                the diff lookup.
            code_ids: token ids of the original code.
            num_runs: passed to ``NumRunsStoppingCriteria`` for every draft ``generate`` call.
            **diff_prompt_args: forwarded to ``DiffPromptLookupCandidateGenerator``
                (e.g. ``ngram_size``, ``num_pred_tokens``).
        """
        self.tokenizer = tokenizer
        self.prompt_tokens = prompt_tokens
        self.draft_model = draft_model
        self.input_ids = input_ids
        self.code_ids = code_ids
        self.candidate_generator = DiffPromptLookupCandidateGenerator(
            self.input_ids, 
            self.code_ids,
            **diff_prompt_args
        )
        self.draft_model.generation_config.pad_token_id = tokenizer.pad_token_id
        
        # NOTE(review): nothing in this class ever assigns a non-None value, so no draft KV cache
        # is reused between calls.
        self.past_keys_values = None
        self.num_runs = num_runs

        # Route the draft model's assisted-generation drafting through the diff lookup generator.
        self.draft_model._get_candidate_generator = (_get_default_candidate_generator_generator(self.candidate_generator)).__get__(self.draft_model, type(self.draft_model))

    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """Run the draft model on ``input_ids`` and return ``(its full output sequences, None)``.

        The returned sequences contain ``input_ids`` followed by the drafted tokens; no candidate
        logits are returned.
        """
        # Dead branch as written: self.past_keys_values is always None (see __init__).
        if self.past_keys_values:
            self.past_keys_values = _crop_past_key_values(self.draft_model, self.past_keys_values, input_ids.shape[-1] - 1)

        # NOTE(review): unused.
        starting_input_length = input_ids.shape[-1]
        # print("Getting draft candidates")
        
        generation = self.draft_model.generate(
            inputs=input_ids,
            # All-ones mask of shape (1, seq_len): assumes batch size 1.
            attention_mask=torch.ones(input_ids.shape[-1], device=input_ids.device).unsqueeze(0),
            # A non-zero value appears to be what switches generate() into assisted decoding,
            # which calls the patched _get_candidate_generator — TODO confirm.
            prompt_lookup_num_tokens=1,
            max_new_tokens=1000,
            # Bounded by a call-count limit and the newline-run check.
            stopping_criteria=[NumRunsStoppingCriteria(self.num_runs), NewlineStoppingCriteria(self.tokenizer, self.prompt_tokens)],
            past_key_values=self.past_keys_values,
            use_cache=True,
            # output_logits=True,
            # NOTE(review): scores are requested but never used; only .sequences is returned.
            output_scores=True,
            return_dict_in_generate=True
        )

        return generation.sequences, None

    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int): # Maybe use the number of matches/scores to have a threshold
        """No-op: the drafting strategy is not adapted between steps."""
        pass 

In [None]:
from fastapi import FastAPI, Path
from pydantic import BaseModel
from typing import Annotated, Union
import uvicorn
import asyncio
from fastapi.middleware.cors import CORSMiddleware
from transformers import TextStreamer

class EditRequest(BaseModel):
    """Request body for ``POST /edit_request``."""

    file_content: str  # full text of the file to edit
    query: str  # natural-language edit instruction


app = FastAPI()

# Permissive CORS so browser clients on any origin can call the API.
# NOTE(review): wildcard origins combined with allow_credentials=True is very permissive;
# restrict allow_origins if this server is ever exposed beyond localhost.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def _build_edit_prompt(file_content, query):
    """Build the one-shot "Code Before / Instruction / Code After" prompt for one edit request."""
    # One worked example. Its trailing blank lines presumably teach the model to emit a run of
    # newlines after the rewritten code, which NewlineStoppingCriteria then uses to stop.
    shot = """## Code Before:
def add(a, b):
    return a + b
## Instruction:
Add a "sub" function that subtracts two numbers. Also write docstrings for both functions and change a,b to x,y.
## Code After:
def add(x, y):
    \"\"\"Adds two numbers.\"\"\"
    return x + y

def sub(x, y):
    \"\"\"Subtracts two numbers.\"\"\"
    return x - y









"""
    return shot + "\n" + f"""## Code Before:
{file_content}
## Instruction:
{query}
## Code After:
"""


def _annotate_changes(original, regenerated):
    """Return ``original`` with each region changed in ``regenerated`` wrapped in edit markers.

    Every changed block becomes::

        <<<<<<< SEARCH
        <matching lines of original>
        =======
        <new lines>
        >>>>>>> REPLACE
    """
    # lineterm="" keeps the "---"/"+++"/"@@" header lines free of embedded newlines; with the
    # default "\n", the join below would insert a spurious blank line after every header.
    unified_diff = "\n".join(
        difflib.unified_diff(original.splitlines(), regenerated.splitlines(), n=3, lineterm="")
    )

    fixed_file = original
    for sr in parse_diff(unified_diff):
        if len(sr.search_block.strip()) == 0:
            continue
        print("SEARCH\n", sr.search_block)
        best = find_best_match(sr.search_block, original)
        if not best.block:
            # Nothing to anchor on (e.g. an empty original); str.replace("") would insert the
            # markers between every character of the file.
            continue
        sr.search_block = best.block
        print("FOUND BLOCK\n", sr.search_block)
        print("REPLACE\n", sr.replace_block)
        fixed_file = fixed_file.replace(sr.search_block, f"""<<<<<<< SEARCH
{sr.search_block}
=======
{sr.replace_block}
>>>>>>> REPLACE""")
    return fixed_file


@app.post("/edit_request")
async def edit_request(edit_request: EditRequest):
    """Apply ``edit_request.query`` to ``edit_request.file_content`` and return marked-up edits.

    The target ``model`` regenerates the whole file from a one-shot prompt, drafting with
    ``TwoLayerLookupCandidateGenerator``. The regenerated text is diffed against the original,
    and each changed region is wrapped in SEARCH/REPLACE markers inside the original file.

    Returns:
        ``{"text": <original file with SEARCH/REPLACE markers inserted>}``.

    NOTE(review): ``model.generate`` is a blocking call inside an ``async`` handler, so it holds
    the event loop for the whole request. That also serialises requests, which currently matters
    because ``model._get_candidate_generator`` is re-patched per request; keep this in mind
    before moving generation to a thread pool.
    """
    prompt = _build_edit_prompt(edit_request.file_content, edit_request.query)

    # Token ids of the original file for the diff-aware drafter. No special tokens: a leading
    # begin-of-sentence id would be proposed as the very first draft token and always rejected.
    code_inputs = tokenizer(
        edit_request.file_content, add_special_tokens=False, return_tensors="pt"
    ).input_ids[0].tolist()

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs.input_ids.shape[-1]

    # New-token budget: prompt length plus 300.
    num_max_gen_tokens = prompt_length + 300

    two_layer_candidate_generator = TwoLayerLookupCandidateGenerator(
        tokenizer,
        prompt_length,
        draft_model,
        inputs.input_ids,
        code_inputs,
        ngram_size=5,
        num_pred_tokens=50,
    )
    # Make the target model's assisted decoding draft through the two-layer generator.
    model._get_candidate_generator = _get_default_candidate_generator_generator(
        two_layer_candidate_generator
    ).__get__(model, type(model))

    generated_ids = model.generate(
        inputs=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        prompt_lookup_num_tokens=1,
        max_new_tokens=num_max_gen_tokens,
        stopping_criteria=[NewlineStoppingCriteria(tokenizer, prompt_length)],
        use_cache=True,
        streamer=TextStreamer(tokenizer),
    )

    text = tokenizer.batch_decode(generated_ids[:, prompt_length:], skip_special_tokens=True)[0]

    return {"text": _annotate_changes(edit_request.file_content, text)}

if __name__ == "__main__":
    config = uvicorn.Config(app)
    server = uvicorn.Server(config)
    await server.serve()

INFO:     Started server process [1262226]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)
Setting `pad_token_id` to `eos_token_id`:32014 for open-end generation.


<｜begin▁of▁sentence｜>## Code Before:
def add(a, b):
    return a + b
## Instruction:
Add a "sub" function that subtracts two numbers. Also write docstrings for both functions and change a,b to x,y.
## Code After:
def add(x, y):
    """Adds two numbers."""
    return x + y

def sub(x, y):
    """Subtracts two numbers."""
    return x - y










## Code Before:
import numpy as np
import matplotlib.pyplot as plt

# Calculate the average
average_throughput = np.mean(tokens_per_sec_arr)
print(f"Average Throughput: {average_throughput} tokens/sec")

# Plotting the histogram
plt.hist(tokens_per_sec_arr, bins=20, color='blue', edgecolor='black', alpha=0.7)
plt.title('Histogram of Throughput Values')
plt.xlabel('Tokens per Second')
plt.ylabel('Frequency')
plt.axvline(average_throughput, color='red', linestyle='dashed', linewidth=1)
plt.text(average_throughput*0.9, max(plt.ylim())*0.9, f'Average: {average_throughput:.2f}', color = 'red')
plt.show()

## Instruction:
Can you please change x axi