In [47]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Allocate new tensors (and therefore the model weights) on the GPU by default.
torch.set_default_device('cuda')

# Single source of truth for the checkpoint, so the tokenizer and the model
# can never drift apart.
MODEL_ID = "microsoft/phi-1_5"

# NOTE: trust_remote_code=True executes Python code shipped with the checkpoint
# repository; only keep it for sources you trust.
# A tokenizer has no tensor dtype, so `torch_dtype` is only passed to the model.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# Load the weights in the dtype stored in the checkpoint, switch to eval mode
# and wrap the module with torch.compile.
model = torch.compile(
    AutoModelForCausalLM.from_pretrained(
        MODEL_ID, trust_remote_code=True, torch_dtype="auto")
    .eval()
)

In [27]:
from human_eval.data import read_problems, write_jsonl
# Load the HumanEval problems; the result is indexed by task id
# (e.g. 'HumanEval/0') and each entry carries the task's 'prompt'.
problems = read_problems()

In [28]:
# Display one problem record to inspect its fields
# (task_id, prompt, entry_point, canonical_solution, test).
problems['HumanEval/0']

{'task_id': 'HumanEval/0',
 'prompt': 'from typing import List\n\n\ndef has_close_elements(numbers: List[float], threshold: float) -> bool:\n    """ Check if in given list of numbers, are any two numbers closer to each other than\n    given threshold.\n    >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n    False\n    >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n    True\n    """\n',
 'entry_point': 'has_close_elements',
 'canonical_solution': '    for idx, elem in enumerate(numbers):\n        for idx2, elem2 in enumerate(numbers):\n            if idx != idx2:\n                distance = abs(elem - elem2)\n                if distance < threshold:\n                    return True\n\n    return False\n',
 'test': "\n\nMETADATA = {\n    'author': 'jt',\n    'dataset': 'test'\n}\n\n\ndef check(candidate):\n    assert candidate([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.3) == True\n    assert candidate([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.05) == False\n    assert candidate([1.0, 2.0, 5.9,

In [64]:
def parse_response_phi15(response: str):
    """Extract the solution from phi1 model's response, as it often
    generates some random function after the required solution was generated.

    The response is kept from its start up to (but excluding) the second
    top-level ``def`` statement, i.e. the second line that starts with
    ``'def '`` at column 0. Indented (nested) ``def`` lines are part of the
    solution body and never trigger the cut.

    Args:
        response: decoded model output. The original prompt is included in
            the response and is kept in the returned text.

    Returns:
        The truncated response. If the response holds fewer than two
        top-level ``def`` statements (including none at all) it is returned
        unchanged instead of raising.

    NOTE(review): when the prompt itself contains several top-level
    functions, the cut still happens at the second one, i.e. inside the
    prompt. This could be improved further with more time.
    """
    kept_lines = []
    top_level_defs = 0
    # keepends=True so that joining the kept lines reproduces the original
    # text byte-for-byte up to the cut position.
    for line in response.splitlines(keepends=True):
        if line.startswith('def '):
            top_level_defs += 1
            if top_level_defs == 2:
                # Start of the unrelated function: drop it and the rest.
                break
        kept_lines.append(line)
    return ''.join(kept_lines)

In [58]:
@torch.inference_mode()
def mix_generate(model, tokenizer, prompt, max_new_tokens:int=512, num_sequences:int=10):
  """Generate output that is a mix of greedy and sampling.

  The greedy approach seems to be the most effective, while beam-search
  is not supported by the model. So, generate 1 sequence using greedy
  and the remaining ``num_sequences - 1`` using top-k sampling.

  Args:
    model: object exposing a Hugging Face style ``generate`` method.
    tokenizer: tokenizer used to encode ``prompt`` and decode the outputs.
    prompt: text the model is asked to continue.
    max_new_tokens: maximum number of tokens generated per sequence.
    num_sequences: total number of sequences to return, the greedy one
      included. The greedy sequence is always produced, so any value
      below 2 yields exactly one sequence.

  Returns:
    A list of decoded responses, each post-processed with
    ``parse_response_phi15``; the greedy result comes first.
  """
  inputs = tokenizer(prompt, return_tensors="pt", return_attention_mask=False)
  # Options shared by both generation passes.
  common_kwargs = dict(max_new_tokens=max_new_tokens, use_cache=True)

  # greedy generation
  outputs = model.generate(**inputs, **common_kwargs)
  # skip_special_tokens keeps markers such as the end-of-text token out of
  # the completion, consistent with the sampling branch below.
  text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
  seqs = [parse_response_phi15(text)]

  # sampling generation: only when at least one extra sequence is wanted,
  # so that generate() is never asked for zero return sequences.
  num_sampled = num_sequences - 1
  if num_sampled >= 1:
    outputs = model.generate(**inputs, **common_kwargs,
                             do_sample=True, top_k=3, num_return_sequences=num_sampled)
    seqs.extend(parse_response_phi15(sampled_text)
                for sampled_text in tokenizer.batch_decode(outputs, skip_special_tokens=True))

  return seqs



In [67]:
from tqdm import tqdm
import time

# Number of completions generated for every HumanEval task.
num_samples_per_task = 10
# Generation budget (in new tokens) for each completion.
max_new_tokens = 300

start = time.time()
samples = []
# The context manager closes the progress bar even if generation fails midway.
with tqdm(total=len(problems) * num_samples_per_task) as pbar:
  for task_id in problems:
    prompt = problems[task_id]['prompt']
    # Pass the sample count explicitly instead of relying on the default of
    # mix_generate happening to equal num_samples_per_task.
    solutions = mix_generate(model, tokenizer, prompt,
                             max_new_tokens=max_new_tokens,
                             num_sequences=num_samples_per_task)
    samples.extend(dict(task_id=task_id, completion=solution) for solution in solutions)
    # Advance by what was actually produced so the bar stays truthful.
    pbar.update(len(solutions))

elapsed = time.time() - start

print('Total generation time: ', elapsed)

# One JSON line per (task_id, completion) pair.
write_jsonl("/content/drive/MyDrive/phi15_humaneval.jsonl", samples)


  4%|▍         | 70/1640 [09:54<3:42:09,  8.49s/it]

  1%|          | 10/1640 [00:41<1:53:38,  4.18s/it][A
  1%|          | 20/1640 [01:14<1:37:37,  3.62s/it][A
  2%|▏         | 30/1640 [01:45<1:31:11,  3.40s/it][A
  2%|▏         | 40/1640 [02:17<1:28:19,  3.31s/it][A
  3%|▎         | 50/1640 [02:55<1:32:30,  3.49s/it][A
  4%|▎         | 60/1640 [03:27<1:29:05,  3.38s/it][A
  4%|▍         | 70/1640 [03:59<1:27:06,  3.33s/it][A
  5%|▍         | 80/1640 [04:31<1:25:20,  3.28s/it][A
  5%|▌         | 90/1640 [05:03<1:24:13,  3.26s/it][A
  6%|▌         | 100/1640 [05:34<1:22:58,  3.23s/it][A
  7%|▋         | 110/1640 [06:06<1:22:11,  3.22s/it][A
  7%|▋         | 120/1640 [06:38<1:21:02,  3.20s/it][A
  8%|▊         | 130/1640 [07:09<1:20:03,  3.18s/it][A
  9%|▊         | 140/1640 [07:40<1:19:05,  3.16s/it][A
  9%|▉         | 150/1640 [08:12<1:18:21,  3.16s/it][A
 10%|▉         | 160/1640 [08:43<1:17:49,  3.16s/it][A
 10%|█         | 170/1640 [09:16<1:17:47,  3.18s/it][A
 11

Total generation time:  5380.125176906586


In [None]:
# Total generation time: 5380 s for 1640 samples (about 3.3 s per sample).