In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache
from transformers.generation.utils import _crop_past_key_values
import difflib
import torch.nn.functional as F
import torch

In [55]:
model_name = "deepseek-ai/deepseek-coder-6.7b-base"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
target_model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, device_map="auto", torch_dtype=torch.float32)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



In [56]:
token_ids = tokenizer("# Write fibbonacci sequence in python", return_tensors="pt")
prediction = target_model.generate(input_ids=token_ids.input_ids, max_new_tokens=40)
print(prediction)
print(tokenizer.batch_decode(prediction))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:32014 for open-end generation.


tensor([[32013,     2, 17437, 12606,  6656,   305,  2711,  6905,   279,  9942,
           185,   185,  1551, 12606,     7,    77,  1772,   185,   315,   245,
            11,   270,  1412,    15,    11,   207,    16,   185,   315,  1470,
           245,  1013,   291,    25,   185,   436,  3628,     7,    64,    11,
          1223,    28,     6,   651,     8,   185,   436,   245,    11,   270]])
["<｜begin▁of▁sentence｜># Write fibbonacci sequence in python\n\ndef fib(n):\n    a, b = 0, 1\n    while a < n:\n        print(a, end=' ')\n        a, b"]


In [75]:
@torch.no_grad()
def regular_generate(target_model, tokenizer, input_str, max_seq_len=50):
    input_ids = tokenizer(input_str, return_tensors="pt").input_ids
    current_position = input_ids.shape[1]
    token_ids = torch.zeros((1, max_seq_len), dtype=torch.int32)
    token_ids[:, :current_position] = input_ids
    
    prediction = target_model(input_ids=token_ids[:, :current_position], use_cache=True, return_dict=True)
    cache = prediction.past_key_values
    pred_logits = prediction.logits
    pred_logits = F.softmax(pred_logits, dim=-1)
    token_ids[:, current_position] = torch.argmax(pred_logits[:, current_position - 1], dim=-1)
    current_position += 1

    while len(token_ids) < max_seq_len:
        prepared_input = target_model.prepare_inputs_for_generation()
        prediction = target_model(input_ids=token_ids[:, :current_position], use_cache=True, past_key_values=cache, return_dict=True, cache_position=torch.arange(current_position - 1, current_position, device=target_model.device))
        cache = prediction.past_key_values
        pred_logits = prediction.logits
        pred_logits = F.softmax(pred_logits, dim=-1)
        token_ids[:, current_position] = torch.argmax(pred_logits[:, current_position - 1], dim=-1)
        print(tokenizer.batch_decode(token_ids[:, :current_position]))
        current_position += 1
        if token_ids[0, -1] == tokenizer.eos_token or token_ids[0, -1] == tokenizer.pad_token: # Check for newline
            break

    # print(tokenizer.batch_decode(token_ids))

In [76]:
regular_generate(target_model, tokenizer, "# Write a python function that generates the fibonacci sequence")

['<｜begin▁of▁sentence｜># Write a python function that generates the fibonacci sequence up']
['<｜begin▁of▁sentence｜># Write a python function that generates the fibonacci sequence up to']
['<｜begin▁of▁sentence｜># Write a python function that generates the fibonacci sequence up to a']
['<｜begin▁of▁sentence｜># Write a python function that generates the fibonacci sequence up to a python']
['<｜begin▁of▁sentence｜># Write a python function that generates the fibonacci sequence up to a python function']
['<｜begin▁of▁sentence｜># Write a python function that generates the fibonacci sequence up to a python function that']
['<｜begin▁of▁sentence｜># Write a python function that generates the fibonacci sequence up to a python function that generates']


KeyboardInterrupt: 