In [5]:
import os
from tqdm import tqdm

os.environ['TRANSFORMERS_CACHE'] = '/datasets/Large_Language_Models'
from datasets import load_dataset, load_from_disk

In [6]:
from transformers import AutoTokenizer

In [7]:
MODEL_NAME = 'codellama/CodeLlama-7b-hf'
TOKEN = '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'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=TOKEN)

In [62]:
def get_token_len(input):
    try:
        input_tokens = tokenizer(input, return_tensors="pt")["input_ids"]
        return len(input_tokens[0])
    except :
        print(input)

In [104]:
dataset = load_dataset('h4iku/coconut_java2006_preprocessed', split='train')

In [105]:
dataset

Dataset({
    features: ['rem', 'add', 'context'],
    num_rows: 1125599
})

In [17]:
df=dataset.to_pandas()

In [None]:
for index, row in df.iterrows():
    if index<100:
        print("Index:", index)  # 注意这里是大写的 'Index'
        print("Row Data:", row['rem'])

## java code format

In [2]:
import subprocess
import tempfile

def format_java_code(java_code):
    # 创建一个临时文件来保存Java代码
    with tempfile.NamedTemporaryFile(mode='w+', suffix='.java', delete=False) as temp_file:
        temp_file_name = temp_file.name
        # print(java_code)
        temp_file.write('public class A {  '+java_code+'}')

    # 使用java命令行工具格式化代码
    try:
        subprocess.run(['java', '-jar', '/home/dwu25/google-java-format-1.19.1-all-deps.jar', '--replace', temp_file_name], check=True)
        
        # 读取格式化后的代码
        with open(temp_file_name, 'r') as file:
            formatted_code = file.read()
        return formatted_code[19:-3]
    except subprocess.CalledProcessError as e:
        # print(java_code)
        return None
    finally:
        # 删除临时文件
        if temp_file_name:
            os.remove(temp_file_name)

# java_code = """
# public class A { void B(){return B;}}"""

# formatted_code = format_java_code(java_code)
# if formatted_code:
#     print("Formatted Java Code:\n")
#     print(formatted_code)
# formatted_code[19:-3]

## coconut preprocess

In [60]:
import re

## 去掉注释
def remove_comments(text):
    # 正则表达式匹配以 // 开头，后面跟任意字符（除了换行符），并以两个或更多空格结尾的子字符串
    pattern = r'//.*?(  +|\t)'
    # 使用空字符串替换匹配的文本
    return re.sub(pattern, '', text)

In [123]:
remove_comments('protected void ensureRowsAreVisible(int beginRow, int endRow)	{    // FIXME: not implemented	}')

'protected void ensureRowsAreVisible(int beginRow, int endRow)\t{    }'

In [136]:
def format_context(sample):
    context=sample['context']
    context=remove_comments(context)
    
    if sample['rem'].replace('{', '').replace('}', '')=='':
        sample['formatted_context']=None
        return sample
    
    if len(context)>1000:
        sample['formatted_context']=None
        return sample
    
    sample['formatted_context']=format_java_code(context)
    return sample

In [None]:
format_dataset=dataset.map(format_context, num_proc=12)

## formatted coconut

In [4]:
formatted_dataset=load_from_disk('data/formatted_coconut')

In [5]:
formatted_dataset=formatted_dataset.filter(lambda x:x['formatted_context'] != None )

In [6]:
formatted_dataset[3]

{'rem': 'if (beginIndex < 0 || endIndex > count || beginIndex > endIndex) throw new StringIndexOutOfBoundsException(); if (beginIndex == 0 && endIndex == count) return this; int len = endIndex - beginIndex;  return new String(value, beginIndex + offset, len, (len << 2) >= value.length);',
 'add': 'return substring(begin, count);',
 'context': '  public String substring(int beginIndex, int endIndex)  {    if (beginIndex < 0 || endIndex > count || beginIndex > endIndex)      throw new StringIndexOutOfBoundsException();    if (beginIndex == 0 && endIndex == count)      return this;    int len = endIndex - beginIndex;    // Package constructor avoids an array copy.    return new String(value, beginIndex + offset, len,                      (len << 2) >= value.length);  }',
 'formatted_context': 'public String substring(int beginIndex, int endIndex) {\n    if (beginIndex < 0 || endIndex > count || beginIndex > endIndex)\n      throw new StringIndexOutOfBoundsException();\n    if (beginIndex 

In [8]:
def remove_spaces_newlines_and_get_indices(java_code):
    cleaned_code = ""
    indices = []
    for index, char in enumerate(java_code):
        if char not in [' ', '\n', '\r', '\t']:
            cleaned_code += char
            indices.append(index)
    return indices, cleaned_code

def find_substring_indices(main_string, substring):
    start_index = main_string.find(substring)
    
    # 如果找不到子字符串，则返回-1
    if start_index == -1:
        return -1, -1

    end_index = start_index + len(substring) - 1
    return start_index, end_index

def replace_pacth(java_code, java_patch):
    rm_code_ind_lst, rm_java_code=remove_spaces_newlines_and_get_indices(java_code)
    _, rm_java_patch=remove_spaces_newlines_and_get_indices(java_patch)
    start_ind, end_ind=find_substring_indices(rm_java_code, rm_java_patch)
    code_start_ind, code_end_ind=rm_code_ind_lst[start_ind], rm_code_ind_lst[end_ind]
    patch=java_code[code_start_ind:code_end_ind+1]
    return patch

def get_infill_code(code, rem_patch, add_patch):
    rem_patch= replace_pacth(code, rem_patch)
    return rem_patch, code.replace(rem_patch, '<INFILL>'), code.replace(rem_patch, add_patch)

code="""public synchronized StringBuffer append(char ch) {\n    ensureCapacity_unsynchronized(count + 1);\n    value[count++] = ch;\n    return this;\n  }"""
rem_patch="""ensureCapacity_unsynchronized(count + 1); value[count++] = ch; return this;"""
add_patch="""return append(obj == null ? "null" : obj.toString());"""

print( get_infill_code(code, rem_patch, add_patch))

('ensureCapacity_unsynchronized(count + 1);\n    value[count++] = ch;\n    return this;', 'public synchronized StringBuffer append(char ch) {\n    <INFILL>\n  }', 'public synchronized StringBuffer append(char ch) {\n    return append(obj == null ? "null" : obj.toString());\n  }')


In [19]:
def reformat(sample):
    rem_patch, add_patch, java_code=sample['rem'], sample['add'], sample['formatted_context']
    sample['rem_patch'], sample['infill_context'], add_context = get_infill_code(java_code, rem_patch, add_patch)
    formatted_add_context = format_java_code(add_context)
    if formatted_add_context:
        add_patch=replace_pacth(formatted_add_context, add_patch)
        sample['add_patch']=add_patch
    else:
        sample['add_patch']=None
    return sample

In [None]:
# reformat_dataset=formatted_dataset.map(reformat, num_proc=12)

In [4]:
reformat_dataset

Dataset({
    features: ['rem', 'add', 'context', 'formatted_context', 'rem_patch', 'infill_context', 'add_patch'],
    num_rows: 529792
})

## coconut_dschat_training_dataset

In [15]:
reformat_dataset=load_from_disk('data/reformatted_coconut')

In [16]:
dschat_rlhf_dataset = reformat_dataset.filter(lambda x:x['add_patch']!=None)

Loading cached processed dataset at /tmp/tmpyyta3gjv/data/reformatted_coconut/cache-ce3ed1ca6df456ab.arrow


In [17]:
def process_dschat(sample):
    pre, suffix=sample['infill_context'].split('<INFILL>')[0], sample['infill_context'].split('<INFILL>')[1]
    # format as "<PRE> {pre} <SUF>{suf} <MID>"
    sample['prompt']='<PRE> '+pre+' <SUF>'+suffix+' <MID>'
    return sample

In [18]:
dschat_rlhf_dataset = dschat_rlhf_dataset.map(process_dschat)

Map:   0%|          | 0/499660 [00:00<?, ? examples/s]

In [19]:
dschat_rlhf_dataset=dschat_rlhf_dataset.rename_columns({'rem_patch':'rejected', 'add_patch':'chosen'}).filter(lambda x:get_token_len(x['prompt'])<=500 and get_token_len(x['prompt'])>20 and get_token_len(x['chosen'])<=100 and get_token_len(x['rejected'])<=100).remove_columns(['rem', 'add', 'context', 'formatted_context','infill_context',])

Filter:   0%|          | 0/499660 [00:00<?, ? examples/s]

In [20]:
dschat_rlhf_dataset = dschat_rlhf_dataset.train_test_split(test_size=0.05)

In [21]:
dschat_rlhf_dataset.save_to_disk(f"data/apr_rlhf_coconut")

Saving the dataset (0/1 shards):   0%|          | 0/339925 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/17891 [00:00<?, ? examples/s]

In [24]:
dschat_rlhf_dataset['train'][0]

{'rejected': 'result = this.getCompetenceCourse().getAutonomousWorkHours(curricularPeriod.getOrder());',
 'chosen': 'result = this.getCompetenceCourse().getAutonomousWorkHours();',
 'prompt': '<PRE> public Double getAutonomousWorkHours(CurricularPeriod curricularPeriod) {\n    double result = 0.0;\n    if (this.getCompetenceCourse() != null) {\n       <SUF>\n    }\n    return result;\n  } <MID>'}

## coconut_rlhf_training_dataset

In [65]:
reformat_dataset=load_from_disk('data/reformatted_coconut')

In [66]:
rm_dataset = reformat_dataset.filter(lambda x:x['add_patch']!=None)

def process_dschat(sample):
    # pre, suffix=sample['infill_context'].split('<INFILL>')[0], sample['infill_context'].split('<INFILL>')[1]
    # # format as "<PRE> {pre} <SUF>{suf} <MID>"
    sample['prompt']=sample['infill_context'].replace('<INFILL>','<FILL_ME>')
    return sample

rm_dataset =rm_dataset.map(process_dschat)

Loading cached processed dataset at /tmp/tmpyyta3gjv/data/reformatted_coconut/cache-e92837652d431f7f.arrow
Loading cached processed dataset at /tmp/tmpyyta3gjv/data/reformatted_coconut/cache-6020518e0d9d2a85.arrow


In [67]:
rm_dataset=rm_dataset.rename_columns({'rem_patch':'rejected', 'add_patch':'chosen'})

In [57]:
def filter_sample(x):
    return get_token_len(x['prompt'])<=500 and get_token_len(x['prompt'])>20 and get_token_len(x['chosen'])<=100 and get_token_len(x['rejected'])<=100

In [68]:
rm_dataset=rm_dataset.filter(lambda x:x['prompt'].count('FILL_ME')==1)
rm_dataset

Filter:   0%|          | 0/499660 [00:00<?, ? examples/s]

Dataset({
    features: ['rem', 'add', 'context', 'formatted_context', 'rejected', 'infill_context', 'chosen', 'prompt'],
    num_rows: 494013
})

In [69]:
rm_dataset=rm_dataset.filter(lambda x:get_token_len(x['prompt'])<=500 and get_token_len(x['prompt'])>20 and get_token_len(x['chosen'])<=100 and get_token_len(x['rejected'])<=100)
rm_dataset

Filter:   0%|          | 0/494013 [00:00<?, ? examples/s]

Dataset({
    features: ['rem', 'add', 'context', 'formatted_context', 'rejected', 'infill_context', 'chosen', 'prompt'],
    num_rows: 352237
})

In [70]:
rm_dataset.remove_columns(['rem', 'add', 'context', 'formatted_context','infill_context',])
rm_dataset

Dataset({
    features: ['rem', 'add', 'context', 'formatted_context', 'rejected', 'infill_context', 'chosen', 'prompt'],
    num_rows: 352237
})

In [72]:
rm_dataset = rm_dataset.filter(lambda x:'protected' not in x['rejected'] and 'private' not in x['rejected'] and 'public' not in x['rejected'] )
rm_dataset

Filter:   0%|          | 0/352237 [00:00<?, ? examples/s]

Dataset({
    features: ['rem', 'add', 'context', 'formatted_context', 'rejected', 'infill_context', 'chosen', 'prompt'],
    num_rows: 304568
})

In [73]:
rm_dataset = rm_dataset.train_test_split(test_size=0.05)

In [74]:
rm_dataset.save_to_disk(f"data/apr_rlhf_coconut")

Saving the dataset (0/2 shards):   0%|          | 0/289339 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/15229 [00:00<?, ? examples/s]

In [63]:
import json
with open('data/apr_rm_train.json', 'w') as file:
    j_lst=[]
    for sample in rm_dataset['train']:
        j_sample={"instruction": sample['prompt'],'input':'','output':[sample['chosen'], sample['rejected']]}
        j_lst.append(j_sample)
    json.dump(j_lst, file, indent=2)

In [64]:
import json
with open('data/apr_rm_test.json', 'w') as file:
    j_lst=[]
    for sample in rm_dataset['test']:
        j_sample={"instruction": sample['prompt'],'input':'','output':[sample['chosen'], sample['rejected']]}
        j_lst.append(j_sample)
    json.dump(j_lst, file, indent=2)

## coconut comparison finetune dataset for MFTcoder

In [None]:
import re

def replace_multiple_spaces_with_single(s):
    return re.sub(r'\s+', ' ', s)

def preprocess(sample):
    rem, add, context = replace_multiple_spaces_with_single(sample['rem']).strip(), replace_multiple_spaces_with_single(sample['add']).strip(), replace_multiple_spaces_with_single(sample['context']).strip()
    infill =context.replace(rem, '<INFILL>')
    return rem.strip()=='' or infill.count('<INFILL>')!=1
    
dataset.filter(preprocess)

In [None]:
import random

def process(ind, sample): 
    correct_patch_choice=''
    
    random_number = random.randint(0, 1)
    if random_number==0:
        A_patch, B_patch, correct_patch_choice=sample['add'], sample['rem'],'A'
    else:
        A_patch, B_patch, correct_patch_choice=sample['rem'], sample['add'],'B'
        
    
    system_content="You are an intelligent programming assistant for JAVA."
    human_content="Choose a correct patch from the following two patches to infill the Java code.\n\nJava code:\n"+sample['context'].replace(sample['rem'], '<INFILL>')+ "\n\nPatches:\nA. "+A_patch.strip()+"\nB. "+B_patch.strip()+'<|role_start|>bot<|role_end|>'
    bot_content="The correct patch is "+correct_patch_choice
    
    jsonl_sample={"id":ind,
                  "data_name":"comparison_finetune",
                  "chat_rounds":[
                      {
                          "role": "system",
                          "content": system_content,
                          "chat_round_id": 0
                      },
                      {
                          "role": "human",
                          "content": human_content,
                          "chat_round_id": 1
                      },
                      {
                          "role": "bot",
                          "content": bot_content,
                          "chat_round_id": 2
                      }
                  ]
                 }
    if ind<10:
        print(human_content)
    return ind, jsonl_sample, correct_patch_choice

In [None]:
from tqdm import tqdm
import json

sample_num=0

with open('data/coconut_comparison.jsonl', 'w') as file:
    with open('data/coconut_comparison_label.txt', 'w') as label_file:
        
        for ind, sample in tqdm(df.iterrows()):
            ind, jsonl_sample, correct_patch_choice = process(ind, sample)
            
            json_string = json.dumps(jsonl_sample)
            if get_token_len(json_string)<=500:
                file.write(json_string)
                file.write('\n')  

                label_file.write(f"{ind} {correct_patch_choice}\n")
                
                sample_num+=1
print(sample_num)
    

## add buggy lines

In [124]:
reformat_dataset=load_from_disk('data/reformatted_coconut')
reformat_dataset

Dataset({
    features: ['rem', 'add', 'context', 'formatted_context', 'rem_patch', 'infill_context', 'add_patch'],
    num_rows: 529792
})

In [125]:
rm_dataset = reformat_dataset.filter(lambda x:x['add_patch']!=None)
rm_dataset

Loading cached processed dataset at /tmp/tmprx0bqt1y/data/reformatted_coconut/cache-e92837652d431f7f.arrow


Dataset({
    features: ['rem', 'add', 'context', 'formatted_context', 'rem_patch', 'infill_context', 'add_patch'],
    num_rows: 499660
})

In [126]:
rm_dataset = rm_dataset.filter(lambda x:'protected' not in x['rem_patch'] and 'private' not in x['rem_patch'] and 'public' not in x['rem_patch'] )
rm_dataset

Loading cached processed dataset at /tmp/tmprx0bqt1y/data/reformatted_coconut/cache-e451d2dce8f7339a.arrow


Dataset({
    features: ['rem', 'add', 'context', 'formatted_context', 'rem_patch', 'infill_context', 'add_patch'],
    num_rows: 399528
})

In [127]:
rm_dataset=rm_dataset.map(process_dschat)
rm_dataset

Map:   0%|          | 0/399528 [00:00<?, ? examples/s]

Dataset({
    features: ['rem', 'add', 'context', 'formatted_context', 'rem_patch', 'infill_context', 'add_patch', 'prompt'],
    num_rows: 399528
})

In [128]:
rm_dataset=rm_dataset.rename_columns({'rem_patch':'rejected', 'add_patch':'chosen'}).filter(lambda x:get_token_len(x['prompt'])<=500 and get_token_len(x['prompt'])>20 and get_token_len(x['chosen'])<=100 and get_token_len(x['rejected'])<=100).remove_columns(['rem', 'add', 'context', 'formatted_context','infill_context',])
rm_dataset

Filter:   0%|          | 0/399528 [00:00<?, ? examples/s]

Dataset({
    features: ['rejected', 'chosen', 'prompt'],
    num_rows: 350233
})

In [112]:
def count_leading_spaces(s):
    stripped_string = s.lstrip()
    leading_spaces = len(s) - len(stripped_string)
    return leading_spaces

In [113]:
buggy_line="if (beginIndex < 0 || endIndex > count || beginIndex > endIndex)\n      throw new StringIndexOutOfBoundsException();\n    if (beginIndex == 0 && endIndex == count) return this;\n    int len = endIndex - beginIndex;\n    return new String(value, beginIndex + offset, len, (len << 2) >= value.length);"

def add_buggy_lines(buggy_line, space_num):
    lst=buggy_line.split('\n')
    commented_lst=['// buggy code\n'+space_num*' '+'// '+lst[0]]+[space_num*' '+'// '+s[space_num:] for s in lst[1:]]
    return '\n'.join(commented_lst)+'\n'

In [123]:
def process_dschat(sample):
    pre, suffix, buggy_line=sample['infill_context'].split('<INFILL>')[0], sample['infill_context'].split('<INFILL>')[1], sample['rem_patch']
    # format as "<PRE> {pre} <SUF>{suf} <MID>"
    space_num = len(pre)-len(pre.rstrip(' '))
    buggy_line=add_buggy_lines(buggy_line, space_num)
    sample['prompt']=('<PRE> '+pre+buggy_line+' <SUF>'+suffix+' <MID>').replace('\n <SUF>\n', '\n <SUF>')
    return sample

In [139]:
rm_dataset=rm_dataset.filter(lambda x:not x['rejected'].startswith('{'))

Filter:   0%|          | 0/350233 [00:00<?, ? examples/s]

In [141]:
rm_dataset

Dataset({
    features: ['rejected', 'chosen', 'prompt'],
    num_rows: 348842
})

In [143]:
rm_dataset=rm_dataset.filter(lambda x: not '// buggy code' in remove_comments(x['prompt']))
rm_dataset

Filter:   0%|          | 0/348842 [00:00<?, ? examples/s]

Dataset({
    features: ['rejected', 'chosen', 'prompt'],
    num_rows: 321061
})

In [145]:
import re

## 去掉注释
def remove_commented_buggy_line(prompt):
    def remove_comments(code):
        lines = code.strip().split('\n')
        return '\n'.join(line for line in lines if not line.strip().startswith('//')) + '\n'

    pre, suffix = prompt.split(' <SUF>')[0], prompt.split(' <SUF>')[1]
    prompt = remove_comments(pre) + ' <SUF>' + suffix
    return prompt


In [None]:
for i in range(100):
    sample=rm_dataset[i]
    print('----------------------')
    print(sample['prompt'])
    sample['prompt']=remove_commented_buggy_line(sample['prompt'])
    print(sample['prompt'])

In [157]:
rm_dataset = rm_dataset.train_test_split(test_size=0.05)

In [158]:
rm_dataset.save_to_disk(f"data/apr_rlhf_rm_coconut")

Saving the dataset (0/1 shards):   0%|          | 0/305007 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/16054 [00:00<?, ? examples/s]

## torch test

In [148]:
import torch
tensor = torch.tensor([[3,5,1,2],[3,1,5,3],[7,5,8,3]],dtype=torch.float)
print(tensor)

tensor([[3., 5., 1., 2.],
        [3., 1., 5., 3.],
        [7., 5., 8., 3.]])
tensor([[3.2000, 5.2000, 1.2000, 2.2000],
        [3.2000, 1.2000, 5.2000, 3.2000],
        [7.2000, 5.2000, 8.2000, 3.2000]])


In [152]:
torch.cat([tensor[:,:1], tensor[:,1:]], dim=1)

tensor([[3.2000, 5.2000, 1.2000, 2.2000],
        [3.2000, 1.2000, 5.2000, 3.2000],
        [7.2000, 5.2000, 8.2000, 3.2000]])

In [153]:
for i in range(tensor.shape[0]):
    print(tensor[i])

tensor([3.2000, 5.2000, 1.2000, 2.2000])
tensor([3.2000, 1.2000, 5.2000, 3.2000])
tensor([7.2000, 5.2000, 8.2000, 3.2000])
