In [None]:
# Simple in-context learning

In [None]:
!pip install -U sentence-transformers
!pip install transformers

In [None]:
import json
import heapq
import warnings

from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
# Silence the repeated generate() notice about the missing attention mask /
# pad token id so that the evaluation output stays readable.
warnings.filterwarnings(
  action="ignore",
  message="The attention mask and the pad token id were not set.",
)

# JSON-lines dataset that is read by json_to_list().
filename = '/content/drive/MyDrive/UZH/AI4PP/function_call_prefix.json'

# Intended cap on the number of samples to load.
# NOTE(review): the json_to_list(filename) call further down does not pass
# this value, so it currently has no effect there.
num_code = 100

In [None]:
# Checkpoint shared by the tokenizer and the causal language model.
CODEGEN_CHECKPOINT = 'Salesforce/codegen-350M-mono'

# NOTE(review): the tokenizer is given a new '<pad>' token while the model is
# told pad_token_id=50256; these presumably refer to different ids. Confirm
# this is intended before relying on padded batches.
tokenizer_codegen = AutoTokenizer.from_pretrained(
  CODEGEN_CHECKPOINT,
  pad_token='<pad>',
)
model_codegen = AutoModelForCausalLM.from_pretrained(
  CODEGEN_CHECKPOINT,
  pad_token_id=50256,
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [None]:
def json_to_list(filename, num_code=-1):
  """Load completion samples from a JSON-lines file and embed their prefixes.

  Args:
    filename: path to a text file holding one JSON object per line. Every
      object must provide the keys 'input' and 'gt'.
    num_code: maximum number of samples to load. A negative value (the
      default, -1) loads every sample in the file.

  Returns:
    A tuple (prefix_list, gt_list, emb_list) of equal-length lists:
      prefix_list: the 'input' strings with the '<mask0>' marker removed.
      gt_list: ground truth, i.e. the lower-cased first space-separated
        token of 'gt'.
      emb_list: one 'all-mpnet-base-v2' embedding tensor per prefix.
  """
  prefix_list = []
  gt_list = []  # ground truth

  # Explicit encoding: source code samples may contain non-ASCII characters
  # and the platform default encoding is not guaranteed to be UTF-8.
  with open(filename, 'r', encoding='utf-8') as f:
    for line in f:
      if len(prefix_list) == num_code:
        break
      if not line.strip():
        # Tolerate blank lines (e.g. a trailing newline at end of file)
        # instead of crashing inside json.loads.
        continue
      json_obj = json.loads(line)
      prefix_list.append(json_obj['input'].replace('<mask0>', ''))
      gt_list.append(json_obj['gt'].split(' ')[0].lower())

  if not prefix_list:
    return prefix_list, gt_list, []

  # Encode all prefixes in one batched call rather than one call per sample,
  # then split the resulting 2-D tensor back into one tensor per sample so
  # callers can keep slicing and indexing emb_list like a plain list.
  model_ST = SentenceTransformer('all-mpnet-base-v2')
  embeddings = model_ST.encode(prefix_list, convert_to_tensor=True)
  emb_list = list(embeddings)

  return prefix_list, gt_list, emb_list

In [None]:
def split_data(prefix_list, gt_list, emb_list, ratio):
  """Split three parallel lists into leading (train) and trailing (test) parts.

  The cut position is int(len(prefix_list) * ratio) and the same position is
  applied to all three lists.

  Returns:
    (prefix_train, prefix_test, gt_train, gt_test, emb_train, emb_test)
  """
  cut = int(len(prefix_list) * ratio)

  parts = []
  for sequence in (prefix_list, gt_list, emb_list):
    # head -> train portion, tail -> test portion
    parts.append(sequence[:cut])
    parts.append(sequence[cut:])

  return tuple(parts)

In [None]:
# Load the dataset (json_to_list is called with its default num_code, so the
# whole file is read), then keep the first 80% for retrieval and the
# remaining 20% for evaluation.
prefix_all, gt_all, emb_all = json_to_list(filename)
(
  prefix_train, prefix_test,
  gt_train, gt_test,
  emb_train, emb_test,
) = split_data(prefix_all, gt_all, emb_all, 0.8)

In [None]:
def compute_semantic_similarity(object_embeddings, sample_embeddings):
  """Return the cosine similarity of two embeddings as a Python number.

  The result of util.cos_sim is unwrapped with .item(), so this is meant for
  a single pair of embeddings (a one-element similarity result).
  """
  return util.cos_sim(object_embeddings, sample_embeddings).item()

In [None]:
# Loaded SentenceTransformer instances, keyed by model name.
_SENTENCE_MODEL_CACHE = {}


def _get_sentence_model(name='all-mpnet-base-v2'):
  # Load the sentence-embedding model once and reuse it on later calls.
  if name not in _SENTENCE_MODEL_CACHE:
    _SENTENCE_MODEL_CACHE[name] = SentenceTransformer(name)
  return _SENTENCE_MODEL_CACHE[name]


def generate_prompt(input, prefix_list, gt_list, emb_list, num_prompts=3):
  """Build a few-shot prompt for `input` from its most similar known samples.

  The `num_prompts` entries of `emb_list` with the highest cosine similarity
  to the embedding of `input` are selected; for each of them the matching
  prefix and ground truth are emitted as a solved example, followed by
  `input` itself.

  Args:
    input: the prefix to be completed.
    prefix_list: candidate example prefixes.
    gt_list: ground-truth continuation for each candidate prefix.
    emb_list: embedding for each candidate prefix (same order as prefix_list).
    num_prompts: number of solved examples to put in front of `input`.

  Returns:
    The prompt string. Every block starts with a separator line; examples are
    written as prefix immediately followed by ground truth, and the prompt
    ends with `input` plus a newline.
  """
  separator = '##############################'

  # Reuse the cached model: this function is called once per test sample and
  # re-loading the model each time dominates the runtime.
  input_emb = _get_sentence_model().encode(input, convert_to_tensor=True)

  scored = [
    (compute_semantic_similarity(input_emb, emb), index)
    for index, emb in enumerate(emb_list)
  ]
  best = heapq.nlargest(num_prompts, scored, key=lambda pair: pair[0])

  examples = [
    separator + '\n' + prefix_list[index] + gt_list[index]
    for _, index in best
  ]

  # Join with newlines so each separator sits on its own line. Previously the
  # separator of every example after the first was glued directly onto the
  # preceding ground truth. Output for zero or one example is unchanged.
  # NOTE(review): the trailing newline after `input` is kept as-is, but the
  # examples show the answer directly after the prefix with no newline in
  # between; confirm the newline is intended.
  return '\n'.join(examples) + '\n' + separator + '\n' + input + '\n'

In [None]:
def evaluate(model, tokenizer, prefix_train, prefix_test, gt_train, gt_test, emb_train, emb_test, num_prompts=1, in_context=False):
  """Score next-token prediction on the test prefixes.

  For every test prefix one new token is generated and compared with the
  ground truth. A prediction counts as correct when it is non-empty and the
  ground truth equals it or starts with it (a single generated token is often
  only the first sub-word of the expected identifier).

  Args:
    model: causal language model exposing generate(input_ids, max_new_tokens=...).
    tokenizer: tokenizer matching `model`.
    prefix_train, gt_train, emb_train: retrieval pool used to build the
      few-shot prompt when `in_context` is True.
    prefix_test, gt_test: prefixes to complete and their expected tokens.
    emb_test: unused; kept so existing callers keep working.
    num_prompts: number of retrieved examples per prompt (in-context only).
    in_context: if True, prepend retrieved examples via generate_prompt().

  Returns:
    The fraction of correct predictions, or 0.0 when there are no test samples.
  """
  total = 0
  correct = 0
  for idx, prefix in enumerate(prefix_test):
    if in_context:
      text = generate_prompt(prefix, prefix_train, gt_train, emb_train, num_prompts=num_prompts)
    else:
      text = prefix

    input_ids = tokenizer(text, return_tensors="pt").input_ids
    generated_ids = model.generate(input_ids, max_new_tokens=1)

    # generate() returns the prompt followed by the new tokens. Decode only
    # the newly generated part; otherwise a whitespace-only generation makes
    # the last word of the prompt look like the prediction.
    prompt_length = len(input_ids[0])
    new_text = tokenizer.decode(generated_ids[0][prompt_length:], skip_special_tokens=True)
    predicted_token = new_text.strip().lower()

    expected = gt_test[idx]
    total += 1
    print(predicted_token)
    print(expected)
    # An empty prediction must never match: ''.startswith-style checks would
    # otherwise accept it for every ground truth.
    if predicted_token and (predicted_token == expected or expected.startswith(predicted_token)):
      correct += 1

    print("EM: ", correct / total)

  # Guard against an empty test set (previously an UnboundLocalError).
  return correct / total if total else 0.0

In [None]:
# Baseline: complete each test prefix without any retrieved examples.
baseline_em = evaluate(
  model_codegen, tokenizer_codegen,
  prefix_train, prefix_test,
  gt_train, gt_test,
  emb_train, emb_test,
)
baseline_em

def
easter
EM:  0.0
=
get_fixture_path
EM:  0.0
=
get_fixture_path
EM:  0.0
=
scopelinker
EM:  0.0
:
super
EM:  0.0
assert
str
EM:  0.0
raise
attributeerror
EM:  0.0
*
satvapor
EM:  0.0
:
glbegin
EM:  0.0
=
dummyfungen
EM:  0.0
[
type
EM:  0.0
:
print
EM:  0.0
=
problem
EM:  0.0
if
isinstance
EM:  0.0
assert
autocrop_array_shapes
EM:  0.0
=
websocketclient
EM:  0.0
is
isinstance
EM:  0.058823529411764705
=
lock
EM:  0.05555555555555555
=
celery
EM:  0.05263157894736842
=
create_app
EM:  0.05
"""<str_lit>"""
print
EM:  0.047619047619047616
index_pages
configure_logging
EM:  0.045454545454545456
class
toobusymiddleware
EM:  0.043478260869565216
dir
join
EM:  0.041666666666666664
def
load_image_list
EM:  0.04
=
zeromqmedium
EM:  0.038461538461538464
=
config
EM:  0.037037037037037035
return
ord
EM:  0.03571428571428571
raise
notimplementederror
EM:  0.034482758620689655
),
open
EM:  0.03333333333333333
=
dependencydecoder
EM:  0.03225806451612903
=
labeldictionary
EM:  0.03125
in
range
EM

0.030303030303030304

In [None]:
# In-context run: same evaluation, but every prompt is preceded by retrieved
# examples. The function defined in this notebook is `evaluate`; `run` is not
# defined anywhere and raises NameError on a fresh kernel.
evaluate(model_codegen, tokenizer_codegen, prefix_train, prefix_test, gt_train, gt_test, emb_train, emb_test, in_context=True)

e
easter
EM:  1.0
path
get_fixture_path
EM:  0.5
fix
get_fixture_path
EM:  0.3333333333333333
def
scopelinker
EM:  0.25
super
super
EM:  0.4
assert
str
EM:  0.3333333333333333
def
attributeerror
EM:  0.2857142857142857
(
satvapor
EM:  0.25
:
glbegin
EM:  0.2222222222222222
sh
dummyfungen
EM:  0.2
b
type
EM:  0.18181818181818182
from
print
EM:  0.16666666666666666
from
problem
EM:  0.15384615384615385
def
isinstance
EM:  0.14285714285714285
aut
autocrop_array_shapes
EM:  0.2
web
websocketclient
EM:  0.25
self
isinstance
EM:  0.23529411764705882
lock
lock
EM:  0.2777777777777778
cel
celery
EM:  0.3157894736842105
################
create_app
EM:  0.3
def
print
EM:  0.2857142857142857
class
configure_logging
EM:  0.2727272727272727
def
toobusymiddleware
EM:  0.2608695652173913
################
join
EM:  0.25
train
load_image_list
EM:  0.24
#
zeromqmedium
EM:  0.23076923076923078
################
config
EM:  0.2222222222222222
def
ord
EM:  0.21428571428571427
def
notimplementederror
EM:  0.

0.18181818181818182