Model Inference

In [1]:
#!pip install ipywidgets
#!pip install sentencepiece
import ipywidgets as widgets
from ipywidgets import HBox, Label
from ipywidgets import Layout
from IPython.display import display, clear_output
from transformers import T5ForConditionalGeneration,AutoModelForSeq2SeqLM,AutoTokenizer,RobertaTokenizer
import numpy as np
import torch
import re

In [2]:
# Maximum token length of the sequence built by toks_with_prompt (used when
# no explicit max_length is passed).
max_size = 512


def toks_with_prompt(text, prompt, tokenizer, max_length=None):
    """Prepend a natural-language control prompt to tokenized code.

    The code is tokenized with the tokenizer's special tokens, and the prompt
    is wrapped as ``<nl>{prompt}</nl>`` and tokenized without special tokens.
    The result is ``control_tokens + code_tokens`` (code tokens up to and
    including the first EOS). If that would exceed ``max_length``, the code
    tokens are truncated and an EOS token is re-appended so the sequence is
    still terminated.

    Args:
        text: Source code to transform.
        prompt: Natural-language instruction placed inside ``<nl>...</nl>``.
        tokenizer: Hugging Face-style tokenizer; ``tokenizer(s).input_ids``
            and ``tokenizer(s, add_special_tokens=False)["input_ids"]`` must
            return lists of ints.
        max_length: Maximum total number of tokens; defaults to ``max_size``.

    Returns:
        ``torch.long`` tensor of shape ``(1, seq_len)``.

    Raises:
        ValueError: If the tokenized code contains no EOS token.
    """
    if max_length is None:
        max_length = max_size

    # Prefer the tokenizer's own EOS id; 2 is the value previously hard-coded.
    eos_id = getattr(tokenizer, "eos_token_id", None)
    if eos_id is None:
        eos_id = 2

    input_ids = tokenizer(text).input_ids
    control_ids = tokenizer(f"<nl>{prompt}</nl>", add_special_tokens=False)["input_ids"]

    try:
        eos_pos = input_ids.index(eos_id)
    except ValueError:
        raise ValueError(
            f"tokenized code contains no EOS token (id {eos_id}); "
            "was it tokenized with special tokens?"
        ) from None
    # Everything up to and including the first EOS; anything after is ignored.
    code_ids = input_ids[: eos_pos + 1]

    if len(control_ids) + len(code_ids) <= max_length:
        combined = control_ids + code_ids
    else:
        # Leave room for the prompt and a closing EOS, truncating the code.
        fit_len = max_length - len(control_ids) - 1
        combined = control_ids + code_ids[:fit_len] + [eos_id]

    return torch.tensor([combined], dtype=torch.long)

In [3]:
# Root directory holding the fine-tuned CodeT5-small checkpoints.
CHECKPOINT_ROOT = "/data/users/team2_capstone/code-style-probing/seq2seq_results"

# Single-feature models, keyed by the lower-cased dropdown label:
# (checkpoint directory, length every input is padded to).
FEATURE_CHECKPOINTS = {
    "casing": (f"{CHECKPOINT_ROOT}/outlier_casing_codet5small/checkpoint-27000", 512),
    "comments": (f"{CHECKPOINT_ROOT}/outlier_codet5small/checkpoint-40500", 512),
    # NOTE(review): only the class model pads to 1024 tokens — confirm this
    # matches how it was trained.
    "class": (f"{CHECKPOINT_ROOT}/outlier_class_codet5small/checkpoint-49000", 1024),
    "docstring": (f"{CHECKPOINT_ROOT}/outlier_updated_docstring_codet5small/checkpoint-85500", 512),
    "list comprehensions": (f"{CHECKPOINT_ROOT}/outlier_comp_codet5small/checkpoint-9500", 512),
}

# Prompt-conditioned model used for "Combined" and for any other feature name.
COMBINED_CHECKPOINT = f"{CHECKPOINT_ROOT}/combined_nl_prompt_base_features_contd_codet5small/checkpoint-144856"
COMBINED_TOKENIZER = "Salesforce/codet5-small"

# (class, name_or_path) -> loaded object. Reloading weights from disk on every
# button click is wasteful, so each model/tokenizer is loaded once per kernel
# session. Every model that has been used stays resident in memory.
_PRETRAINED_CACHE = {}


def _load_pretrained(cls, name_or_path):
    """Return ``cls.from_pretrained(name_or_path)``, loading it only on first use."""
    key = (cls, name_or_path)
    if key not in _PRETRAINED_CACHE:
        _PRETRAINED_CACHE[key] = cls.from_pretrained(name_or_path)
    return _PRETRAINED_CACHE[key]


def do_inference(input_code, featureName, promptValue, **generate_kwargs):
    """Rewrite ``input_code`` with the model for the selected style feature.

    Args:
        input_code: Source code to transform.
        featureName: Dropdown label (case-insensitive). "Casing", "Comments",
            "Class", "Docstring" and "List Comprehensions" select a
            single-feature model; any other value uses the prompt-conditioned
            combined model.
        promptValue: Natural-language instruction; only used by the combined
            model.
        **generate_kwargs: Forwarded to ``model.generate`` (e.g.
            ``max_new_tokens=512``). Without them the output length comes from
            the checkpoint's generation config; transformers falls back to
            ``max_length=20`` when the config does not set it.

    Returns:
        A list holding one decoded string, with special tokens stripped.
    """
    feature = featureName.lower()
    if feature in FEATURE_CHECKPOINTS:
        checkpoint, max_length = FEATURE_CHECKPOINTS[feature]
        model = _load_pretrained(AutoModelForSeq2SeqLM, checkpoint)
        tokenizer = _load_pretrained(AutoTokenizer, checkpoint)
        encoded = tokenizer(
            [input_code], max_length=max_length, padding="max_length", return_tensors="pt"
        )
        # Pass the mask explicitly so the padding added above is ignored.
        generated_ids = model.generate(
            encoded.input_ids, attention_mask=encoded.attention_mask, **generate_kwargs
        )
    else:
        model = _load_pretrained(AutoModelForSeq2SeqLM, COMBINED_CHECKPOINT)
        tokenizer = _load_pretrained(RobertaTokenizer, COMBINED_TOKENIZER)
        # toks_with_prompt prepends the <nl>prompt</nl> control tokens.
        input_ids = toks_with_prompt(input_code, promptValue, tokenizer)
        generated_ids = model.generate(input_ids, **generate_kwargs)
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    

In [4]:
# Demo controls: model selector, prompt box, code box, run button, and a
# bordered output area that receives the coloured diff.
button = widgets.Button(description="Run Inference!")
output = widgets.Output(layout={'border': '1px solid black'})
feature_selection = widgets.Dropdown(
    options=['Casing', 'Comments', 'Class', 'Docstring', 'List Comprehensions', 'Combined'],
    description='Features : ',
    disabled=False,
)

# Layout for the large input-code box.
l = Layout(flex='0 1 auto', height='250px', min_height='40px', width='550px')
inputText = widgets.Textarea(
    layout=l,
    placeholder='Type something',
    description='Input Code:',
    disabled=False,
    # Font size is a style attribute, not a widget keyword. Passed as a bare
    # ``font_size=...`` keyword, traitlets dropped it with only a deprecation
    # warning. NOTE(review): ``style.font_size`` exists on text widgets in
    # ipywidgets 8; older versions may still ignore it.
    style={'font_size': '18px'},
)

# Free-text instruction for the prompt-conditioned (combined) model.
promptText = widgets.Textarea(
    placeholder='Type something',
    description='Prompt:',
    disabled=False,
)

In [5]:
#!pip install simple_colors
import difflib
from simple_colors import *
# 24-bit ANSI foreground colours for the diff view. Each helper wraps its text
# and then resets all attributes (\033[0m) so the colour cannot bleed into
# whatever is printed next.
_ANSI_RESET = "\033[0m"


def red(text):
    """Wrap ``text`` in red (used for deleted characters)."""
    return f"\033[38;2;255;0;0m{text}{_ANSI_RESET}"


def green(text):
    """Wrap ``text`` in green (used for inserted characters)."""
    return f"\033[38;2;0;150;0m{text}{_ANSI_RESET}"


def blue(text):
    """Wrap ``text`` in blue."""
    return f"\033[38;2;0;0;255m{text}{_ANSI_RESET}"


def white(text):
    """Wrap ``text`` for unchanged characters.

    Despite the name, this uses RGB (0, 0, 0), i.e. black ink, so unchanged
    code stays readable on a light output background.
    """
    return f"\033[38;2;0;0;0m{text}{_ANSI_RESET}"


def get_edits_string(old, new):
    """Render a character-level diff of ``old`` -> ``new`` as ANSI-coloured text.

    Unchanged text goes through ``white``, removed characters through ``red``
    and inserted characters through ``green``. A replacement is shown as the
    removed text followed by the inserted text.

    ``autojunk`` is disabled. With difflib's default heuristic, once ``new``
    is 200+ characters long, every character occurring in more than about 1%
    of it (spaces, common letters) is ignored when matching. That turns most
    edits to real code into a single huge replace block.
    """
    matcher = difflib.SequenceMatcher(a=old, b=new, autojunk=False)
    pieces = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == "equal":
            pieces.append(white(old[i1:i2]))
        elif tag == "delete":
            pieces.append(red(old[i1:i2]))
        elif tag == "insert":
            pieces.append(green(new[j1:j2]))
        elif tag == "replace":
            pieces.append(red(old[i1:i2]) + green(new[j1:j2]))
    return "".join(pieces)

In [6]:
from multiprocessing.sharedctypes import Value

display(feature_selection)
display(promptText)
display(inputText)
display(button, output)

# Re-running this cell must not stack a second click handler on the same
# button (each click would then run inference once per registration), so
# detach the handler registered by a previous run of this cell, if any.
_previous_handler = globals().get("on_button_clicked")
if _previous_handler is not None:
    button.on_click(_previous_handler, remove=True)


@output.capture(clear_output=True)
def on_button_clicked(b):
    """Run the selected model on the input code and print a coloured diff.

    ``output.capture(clear_output=True)`` clears the output widget and routes
    everything printed here into it, so no extra ``with output:`` is needed.

    Args:
        b: The clicked ``Button``; disabled while inference runs to prevent
            overlapping runs from repeated clicks.
    """
    b.disabled = True
    try:
        output_code = do_inference(inputText.value, feature_selection.value, promptText.value)
        print(get_edits_string(inputText.value, output_code[0]))
    finally:
        b.disabled = False


button.on_click(on_button_clicked)

Dropdown(description='Features : ', options=('Casing', 'Comments', 'Class', 'Docstring', 'List Comprehensions'…

Textarea(value='', description='Prompt:', placeholder='Type something')

Textarea(value='', description='Input Code:', layout=Layout(flex='0 1 auto', height='250px', min_height='40px'…

Button(description='Run Inference!', style=ButtonStyle())

Output(layout=Layout(border='1px solid black'))