In [1]:
import os
os.environ["HF_ENDPOINT"] = "https://huggingface.co"
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM
import re
from typing import Pattern
import torch
from pygments import highlight
from pygments.formatters import Terminal256Formatter
from pygments.lexers import PythonLexer

tokenizer = AutoTokenizer.from_pretrained("YurtsAI/yurts-python-code-gen-30-sparse")
model = AutoModelForCausalLM.from_pretrained("YurtsAI/yurts-python-code-gen-30-sparse")

In [2]:
# Affix the pad token ID
PAD_TOKEN_ID = 50256

In [3]:
# Load tokeniser and model
tokenizer = AutoTokenizer.from_pretrained("YurtsAI/yurts-python-code-gen-30-sparse")
tokenizer.padding_side = "left"
tokenizer.pad_token = PAD_TOKEN_ID
model = AutoModelForCausalLM.from_pretrained("YurtsAI/yurts-python-code-gen-30-sparse").eval()

In [4]:
def run_single_inference(tokenizer: AutoTokenizer, model: AutoModelForCausalLM, prompt: str, max_output_token_len: int = 128,
                         max_prompt_token_length: int = 2048, top_p: float = 0.95, temp: float = 0.2, batch_size: int = 1,
                         device: str = "cuda") -> str:
    """Runs a single inference and retrieves the decoded results."""
    device = torch.device(device)
    input_ids = tokenizer(prompt, truncation=True, padding=True, max_length=max_prompt_token_length, 
                          return_tensors="pt").input_ids
    input_ids_len = input_ids.shape[1]

    with torch.no_grad():
        input_ids = input_ids.to()
        tokens = model.generate(input_ids, do_sample=True, num_return_sequences=batch_size, temperature=temp,
                                max_length=input_ids_len + max_output_token_len, top_p=top_p, pad_token_id=PAD_TOKEN_ID,
                                use_cache=True)
        text = tokenizer.batch_decode(tokens[:, input_ids_len:, ...])
    return text

In [5]:
TERMINALS = re.compile(r"<\|endoftext\|>", re.MULTILINE)
FUNCTION_DEFS = re.compile(r"(^.*def .*)", re.MULTILINE)
CLASS_DEFS = re.compile(r"(^.*class .*)", re.MULTILINE)
COMMENTS = re.compile(r"(^\s*#.*)", re.MULTILINE)


def find_repeats_and_trim(pattern: Pattern, string: str) -> str:
    """For the given pattern, find all group matches and trim anything after
    the first occurence of a repeat.
    """
    repeats = [
        (match.group(), match.end()) for match in pattern.finditer(string)
    ]
    print(repeats)
    if len(repeats) > 1:
        for f_i, f_j in zip(repeats, repeats[1:]):
            if (
                f_i[0] == f_j[0]
            ):  # if text is the same - we have a consecutive, repeating patterns
                string = string[: f_i[1]]
                break
    return string

def format_output(prompt: str, output: str) -> str:
    """Given the prompt and the output, join them and trim any repeating
    patterns as well as any terminals.
    """
    code_gen = prompt + output
    for pattern in [CLASS_DEFS, FUNCTION_DEFS, COMMENTS]:
        code_gen = find_repeats_and_trim(pattern, code_gen)

    first_terminal_match = TERMINALS.search(code_gen)
    if first_terminal_match is not None:
        return code_gen[first_terminal_match.start()]
    return code_gen

In [None]:
prompt = "def hello_world():"
max_output_token_len = 50

In [17]:
raw_output = run_single_inference(tokenizer, model, prompt, max_output_token_len=max_output_token_len).pop()
# Print raw output
print(highlight(prompt + raw_output, PythonLexer(), Terminal256Formatter()), end="")

[38;5;28;01mdef[39;00m [38;5;21mhello_world[39m():
    [38;5;28mprint[39m([38;5;124m"[39m[38;5;124mHello World[39m[38;5;124m"[39m)

hello_world()

[38;5;66;03m# A function that takes a string and returns a string with the first letter of the string[39;00m
[38;5;28;01mdef[39;00m [38;5;21mfirst_letter[39m(string):
    [38;5;28;01mreturn[39;00m string[[38;5;241m0[39m]


In [20]:
final_output = format_output(prompt, raw_output)

[]
[('def hello_world():', 18), ('def first_letter(string):', 175)]
[('\n# A function that takes a string and returns a string with the first letter of the string', 149)]


In [21]:
print(highlight(final_output, PythonLexer(), Terminal256Formatter()), end="")

[38;5;28;01mdef[39;00m [38;5;21mhello_world[39m():
    [38;5;28mprint[39m([38;5;124m"[39m[38;5;124mHello World[39m[38;5;124m"[39m)

hello_world()

[38;5;66;03m# A function that takes a string and returns a string with the first letter of the string[39;00m
[38;5;28;01mdef[39;00m [38;5;21mfirst_letter[39m(string):
    [38;5;28;01mreturn[39;00m string[[38;5;241m0[39m]


In [22]:
def hello_world():
    print("Hello World")

hello_world()

Hello World


Testing different prompts

In [33]:
# Generating with the default parameters
def tester(prompt, max_output_token_len=128):
    raw_output = run_single_inference(tokenizer, model, prompt, max_output_token_len=max_output_token_len).pop()
    final_output = format_output(prompt, raw_output)
    print(highlight(final_output, PythonLexer(), Terminal256Formatter()), end="")

In [29]:
prompt_1 = "def find_mean(a_list):"
tester(prompt_1)

[]
[('def find_mean(a_list):', 22), ('def find_median(a_list):', 295)]
[]
[38;5;28;01mdef[39;00m [38;5;21mfind_mean[39m(a_list):
    [38;5;124;03m"""[39;00m
[38;5;124;03m    Find the mean of a list of numbers.[39;00m

[38;5;124;03m    Parameters[39;00m
[38;5;124;03m    ----------[39;00m
[38;5;124;03m    a_list : list[39;00m
[38;5;124;03m        A list of numbers.[39;00m

[38;5;124;03m    Returns[39;00m
[38;5;124;03m    -------[39;00m
[38;5;124;03m    float[39;00m
[38;5;124;03m        The mean of the list of numbers.[39;00m

[38;5;124;03m    """[39;00m
    [38;5;28;01mreturn[39;00m [38;5;28msum[39m(a_list) [38;5;241m/[39m [38;5;28mlen[39m(a_list)


[38;5;28;01mdef[39;00m [38;5;21mfind_median[39m(a_list):
    [38;5;124m"""[39m
[38;5;124m    Find the median of a list of numbers.[39m

[38;5;124m    Parameters[39m
[38;5;124m    ----------[39m
[38;5;124m    a_list : list[39m
[38;5;124m        A list of numbers.[39m

[38;5;124m    Returns[39

In [35]:
prompt_2 = """
# This function calculates the most frequent entry in the given list
def most_frequent(a_list):
"""
tester(prompt_2, 200)

[]
[('def most_frequent(a_list):', 96), ('def most_frequent_2(a_list):', 692)]
[('\n# This function calculates the most frequent entry in the given list', 69), ('    # Initialize the most frequent entry', 137), ('    # Iterate through the list', 196), ('        # If the entry is the most frequent, update the most frequent entry', 305), ('        # If the entry is not the most frequent, update the most frequent entry', 475), ('    # Return the most frequent entry', 562), ('\n# This function calculates the most frequent entry in the given list', 663), ('    # Initialize the most frequent entry', 733), ('    # Iterate through the list', 792)]
[38;5;66;03m# This function calculates the most frequent entry in the given list[39;00m
[38;5;28;01mdef[39;00m [38;5;21mmost_frequent[39m(a_list):
    [38;5;66;03m# Initialize the most frequent entry[39;00m
    most_frequent_entry [38;5;241m=[39m [38;5;241m0[39m
    [38;5;66;03m# Iterate through the list[39;00m
    [38;5;28;01mfor[39;

In [36]:
def most_frequent(a_list):
    # Initialize the most frequent entry
    most_frequent_entry = 0
    # Iterate through the list
    for i in range(len(a_list)):
        # If the entry is the most frequent, update the most frequent entry
        if a_list[i] == a_list[most_frequent_entry]:
            most_frequent_entry += 1
        # If the entry is not the most frequent, update the most frequent entry
        else:
            most_frequent_entry = i
    # Return the most frequent entry
    return most_frequent_entry

In [38]:
test_list = [1, 1, 2, 2, 3, 3, 3, 4, 4]
# It just returns the size of the given list
most_frequent(test_list)

9

In [39]:
# The same compared to the previous prompt, but without the comment
prompt_3 = """
def most_frequent(a_list):
"""
tester(prompt_3, 200)

[]
[('def most_frequent(a_list):', 27), ('def get_frequent_numbers(a_list):', 198), ('def get_frequent_numbers_2(a_list):', 364), ('def get_frequent_numbers_3(a_list):', 553)]
[]
[38;5;28;01mdef[39;00m [38;5;21mmost_frequent[39m(a_list):
    [38;5;124;03m"""[39;00m
[38;5;124;03m    :param a_list: list of numbers[39;00m
[38;5;124;03m    :return: the most frequent number[39;00m
[38;5;124;03m    """[39;00m
    [38;5;28;01mreturn[39;00m [38;5;28mmax[39m([38;5;28mset[39m(a_list), key[38;5;241m=[39ma_list[38;5;241m.[39mcount)


[38;5;28;01mdef[39;00m [38;5;21mget_frequent_numbers[39m(a_list):
    [38;5;124;03m"""[39;00m
[38;5;124;03m    :param a_list: list of numbers[39;00m
[38;5;124;03m    :return: the list of frequent numbers[39;00m
[38;5;124;03m    """[39;00m
    [38;5;28;01mreturn[39;00m [most_frequent(a_list)]


[38;5;28;01mdef[39;00m [38;5;21mget_frequent_numbers_2[39m(a_list):
    [38;5;124;03m"""[39;00m
[38;5;124;03m    :param a_list: list o

In [40]:
def most_frequent(a_list):
    """
    :param a_list: list of numbers
    :return: the most frequent number
    """
    return max(set(a_list), key=a_list.count)


def get_frequent_numbers(a_list):
    """
    :param a_list: list of numbers
    :return: the list of frequent numbers
    """
    return [most_frequent(a_list)]

most_frequent(test_list)

3