In [1]:
import os
from tqdm import tqdm

# Point the Hugging Face cache at /datasets/Large_Language_Models.
# It is set before importing `datasets` below. NOTE(review): presumably the
# cache location is read at import time; confirm for the installed version.
os.environ['TRANSFORMERS_CACHE'] = '/datasets/Large_Language_Models'
from datasets import load_dataset, load_from_disk

In [104]:
# Load the 'train' split of the preprocessed CoCoNuT Java 2006 dataset
# (columns, as displayed below: 'rem', 'add', 'context').
dataset = load_dataset('h4iku/coconut_java2006_preprocessed', split='train')

In [105]:
# Display the schema and row count.
dataset

Dataset({
    features: ['rem', 'add', 'context'],
    num_rows: 1125599
})

In [17]:
# Materialize the whole split as a pandas DataFrame for the row-wise loops below.
df=dataset.to_pandas()

In [None]:
# Print the 'rem' field of the first 100 rows for a quick manual inspection.
# `head(100)` avoids walking all ~1.1M rows with iterrows() just to print 100.
for index, row in df.head(100).iterrows():
    print("Index:", index)
    print("Row Data:", row['rem'])

In [133]:
from transformers import AutoTokenizer



In [134]:
MODEL_NAME = 'codefuse-ai/CodeFuse-CodeLlama-34B'
# Read the Hugging Face access token from the environment instead of hardcoding it.
# SECURITY: the token that was previously committed in this cell is exposed and
# should be revoked.
TOKEN = os.environ.get('HF_TOKEN')
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=TOKEN)

In [135]:
def get_token_len(input):
    """Return how many tokens the module-level `tokenizer` produces for `input`.

    The encoding is requested as a PyTorch tensor with a leading batch-of-one
    dimension, so the size of its last dimension is the token count.
    """
    return tokenizer.encode(input, return_tensors="pt").shape[-1]

## Java code formatting

In [2]:
import subprocess
import tempfile

# Path to the google-java-format "all-deps" jar used for formatting.
GOOGLE_JAVA_FORMAT_JAR = '/home/dwu25/google-java-format-1.19.1-all-deps.jar'

# The snippet is wrapped in a dummy class so the formatter accepts it.
# Assumes the formatted output starts with "public class A {\n" plus a 2-space
# member indent (19 characters) and ends with "\n}\n" (3 characters); both are
# sliced off after formatting.
_WRAPPER_PREFIX = 'public class A {  '
_WRAPPER_SUFFIX = '}'
_FORMATTED_HEAD_LEN = 19
_FORMATTED_TAIL_LEN = 3


def format_java_code(java_code, jar_path=GOOGLE_JAVA_FORMAT_JAR, timeout=None):
    """Format a Java member declaration (e.g. a method) with google-java-format.

    The snippet is wrapped in ``public class A { ... }`` and written to a
    temporary ``.java`` file. That file is formatted in place with
    ``--replace``, and the wrapper is sliced back off. Only the first line
    loses the class-member indentation; the following lines keep it.

    Args:
        java_code: Java source for one or more class members.
        jar_path: Path to the google-java-format all-deps jar.
        timeout: Seconds to wait for the formatter; ``None`` waits indefinitely.

    Returns:
        The formatted snippet. Returns ``None`` if the formatter rejects the
        code (non-zero exit) or does not finish within ``timeout``.
    """
    temp_file_name = None
    try:
        # delete=False so the external formatter can reopen the file by name;
        # the file is removed in the ``finally`` clause below.
        with tempfile.NamedTemporaryFile(
            mode='w', suffix='.java', delete=False, encoding='utf-8'
        ) as temp_file:
            temp_file_name = temp_file.name
            temp_file.write(_WRAPPER_PREFIX + java_code + _WRAPPER_SUFFIX)

        subprocess.run(
            ['java', '-jar', jar_path, '--replace', temp_file_name],
            check=True,
            timeout=timeout,
        )

        with open(temp_file_name, 'r', encoding='utf-8') as file:
            formatted_code = file.read()
        return formatted_code[_FORMATTED_HEAD_LEN:-_FORMATTED_TAIL_LEN]
    except (subprocess.CalledProcessError, subprocess.TimeoutExpired):
        # Unparseable snippet or hung formatter: signal "could not format".
        return None
    finally:
        if temp_file_name:
            os.remove(temp_file_name)

# java_code = """
# public class A { void B(){return B;}}"""

# formatted_code = format_java_code(java_code)
# if formatted_code:
#     print("Formatted Java Code:\n")
#     print(formatted_code)
# formatted_code[19:-3]

## CoCoNuT preprocessing

In [3]:
import re

# Remove comments.
# Alternatives, in priority order:
#   1. a double-quoted string literal,
#   2. a char literal,
#   3. a `//` line comment.
# The dataset's code appears to have its newlines collapsed (see the sample
# contexts), so a line comment is taken to end at the first run of two or more
# spaces, or a tab. Literals are matched first so that `//` inside them (e.g.
# "http://...") is not mistaken for a comment.
_COMMENT_OR_LITERAL_RE = re.compile(
    r'"(?:\\.|[^"\\])*"'
    r"|'(?:\\.|[^'\\])*'"
    r'|//.*?(?:  +|\t)'
)


def remove_comments(text):
    """Strip `//` line comments from whitespace-collapsed Java source.

    A comment runs from `//` up to and including the first run of two or more
    spaces, or a tab. String and char literals are kept verbatim, so a `//`
    inside a literal is preserved.
    """
    def _keep_literals(match):
        token = match.group(0)
        return '' if token.startswith('//') else token

    return _COMMENT_OR_LITERAL_RE.sub(_keep_literals, text)

In [123]:
# Sanity check: the `// FIXME` comment is stripped up to and including the tab that ends it.
remove_comments('protected void ensureRowsAreVisible(int beginRow, int endRow)	{    // FIXME: not implemented	}')

'protected void ensureRowsAreVisible(int beginRow, int endRow)\t{    }'

In [136]:
def format_context(sample, max_context_len=1000):
    """Store a google-java-format'ed copy of ``sample['context']`` in 'formatted_context'.

    'formatted_context' is set to ``None`` (sample skipped) when any of these hold:
      * the removed code ``sample['rem']`` contains nothing but braces and
        whitespace (presumably nothing meaningful to infill);
      * the comment-stripped context is longer than ``max_context_len`` characters;
      * the formatter rejects the code (``format_java_code`` returns ``None``).

    Written for ``Dataset.map``: it mutates ``sample`` and returns it.
    """
    # Cheap check first. strip() also catches patches such as '{ }', which the
    # brace removal alone reduced to ' ' and therefore let through.
    if sample['rem'].replace('{', '').replace('}', '').strip() == '':
        sample['formatted_context'] = None
        return sample

    context = remove_comments(sample['context'])
    if len(context) > max_context_len:
        sample['formatted_context'] = None
        return sample

    sample['formatted_context'] = format_java_code(context)
    return sample

In [None]:
# Format every context in parallel, then persist the result.
# The "formatted coconut" section below reloads it from 'data/formatted_coconut'.
format_dataset=dataset.map(format_context, num_proc=12)
format_dataset.save_to_disk('data/formatted_coconut')

## Formatted CoCoNuT dataset

In [4]:
# Reload the formatted dataset from disk (presumably the output of the
# `format_context` map in the previous section).
formatted_dataset=load_from_disk('data/formatted_coconut')

In [5]:
# Keep only samples with a formatted context. Samples that failed formatting or
# the size/patch checks in `format_context` have None here.
formatted_dataset=formatted_dataset.filter(lambda x: x['formatted_context'] is not None)

In [6]:
# Inspect one formatted sample.
formatted_dataset[3]

{'rem': 'if (beginIndex < 0 || endIndex > count || beginIndex > endIndex) throw new StringIndexOutOfBoundsException(); if (beginIndex == 0 && endIndex == count) return this; int len = endIndex - beginIndex;  return new String(value, beginIndex + offset, len, (len << 2) >= value.length);',
 'add': 'return substring(begin, count);',
 'context': '  public String substring(int beginIndex, int endIndex)  {    if (beginIndex < 0 || endIndex > count || beginIndex > endIndex)      throw new StringIndexOutOfBoundsException();    if (beginIndex == 0 && endIndex == count)      return this;    int len = endIndex - beginIndex;    // Package constructor avoids an array copy.    return new String(value, beginIndex + offset, len,                      (len << 2) >= value.length);  }',
 'formatted_context': 'public String substring(int beginIndex, int endIndex) {\n    if (beginIndex < 0 || endIndex > count || beginIndex > endIndex)\n      throw new StringIndexOutOfBoundsException();\n    if (beginIndex 

In [8]:
# Characters ignored when matching a patch against code.
_WHITESPACE = frozenset(' \n\r\t')


def remove_spaces_newlines_and_get_indices(java_code):
    """Drop spaces, newlines, carriage returns and tabs from ``java_code``.

    Returns:
        A pair ``(indices, cleaned_code)``. ``cleaned_code`` is ``java_code``
        without those characters, and ``indices[i]`` is the position in
        ``java_code`` of ``cleaned_code[i]``.
    """
    indices = [i for i, ch in enumerate(java_code) if ch not in _WHITESPACE]
    cleaned_code = ''.join(java_code[i] for i in indices)
    return indices, cleaned_code


def find_substring_indices(main_string, substring):
    """Return the inclusive ``(start, end)`` of the first ``substring`` in ``main_string``.

    Returns ``(-1, -1)`` if ``substring`` does not occur.
    """
    start_index = main_string.find(substring)
    if start_index == -1:
        return -1, -1
    end_index = start_index + len(substring) - 1
    return start_index, end_index


def replace_pacth(java_code, java_patch):
    """Return the span of ``java_code`` that matches ``java_patch``, ignoring whitespace.

    Both strings are compared with whitespace removed. The first match is
    mapped back to positions in ``java_code``, so the returned span keeps
    ``java_code``'s own newlines and indentation.

    Returns:
        The matching span, or '' if ``java_patch`` is empty or whitespace-only.

    Raises:
        ValueError: if ``java_patch`` (ignoring whitespace) does not occur in
            ``java_code``.
    """
    code_indices, compact_code = remove_spaces_newlines_and_get_indices(java_code)
    _, compact_patch = remove_spaces_newlines_and_get_indices(java_patch)
    if not compact_patch:
        return ''
    start_ind, end_ind = find_substring_indices(compact_code, compact_patch)
    if start_ind == -1:
        # Using the -1 sentinel as an index would silently select the last character.
        raise ValueError('patch not found in code (ignoring whitespace)')
    return java_code[code_indices[start_ind]:code_indices[end_ind] + 1]


def get_infill_code(code, rem_patch, add_patch):
    """Locate ``rem_patch`` in ``code`` and build the infill and patched variants.

    Returns:
        A tuple ``(rem_span, infill_code, patched_code)``:
          * ``rem_span`` is the exact text of ``code`` matching ``rem_patch``
            (whitespace-insensitive);
          * ``infill_code`` is ``code`` with that span replaced by '<INFILL>';
          * ``patched_code`` is ``code`` with that span replaced by ``add_patch``.

    Raises:
        ValueError: if ``rem_patch`` is empty/whitespace-only or is not found in ``code``.
    """
    rem_span = replace_pacth(code, rem_patch)
    if not rem_span:
        raise ValueError('rem patch is empty')
    # Replace only the located (first) occurrence. The same statement may appear
    # again elsewhere in the method and must not be touched.
    return (
        rem_span,
        code.replace(rem_span, '<INFILL>', 1),
        code.replace(rem_span, add_patch, 1),
    )

# Demo
code="""public synchronized StringBuffer append(char ch) {\n    ensureCapacity_unsynchronized(count + 1);\n    value[count++] = ch;\n    return this;\n  }"""
rem_patch="""ensureCapacity_unsynchronized(count + 1); value[count++] = ch; return this;"""
add_patch="""return append(obj == null ? "null" : obj.toString());"""

print( get_infill_code(code, rem_patch, add_patch))

('ensureCapacity_unsynchronized(count + 1);\n    value[count++] = ch;\n    return this;', 'public synchronized StringBuffer append(char ch) {\n    <INFILL>\n  }', 'public synchronized StringBuffer append(char ch) {\n    return append(obj == null ? "null" : obj.toString());\n  }')


In [19]:
def reformat(sample):
    """Add infill training fields to one formatted CoCoNuT sample and return it.

    Fields added to ``sample``:
      * 'rem_patch': text of 'formatted_context' matching 'rem';
      * 'infill_context': 'formatted_context' with that text replaced by '<INFILL>';
      * 'add_patch': text matching 'add' inside the re-formatted patched code.
        It is '' when 'add' is empty, and None when formatting or matching fails.

    If 'rem' cannot be located in the formatted context, all three fields are
    set to None instead of raising, so a single bad sample does not raise
    inside ``Dataset.map``.
    """
    rem_patch, add_patch, java_code = sample['rem'], sample['add'], sample['formatted_context']
    try:
        sample['rem_patch'], sample['infill_context'], add_context = get_infill_code(java_code, rem_patch, add_patch)
    except (ValueError, IndexError):
        # The removed code could not be located in the formatted context.
        sample['rem_patch'] = sample['infill_context'] = sample['add_patch'] = None
        return sample

    if not add_patch.strip():
        # Pure deletion: there is nothing to locate. Searching for an empty
        # patch would match the whole method.
        sample['add_patch'] = ''
        return sample

    formatted_add_context = format_java_code(add_context)
    if not formatted_add_context:
        sample['add_patch'] = None
        return sample

    try:
        sample['add_patch'] = replace_pacth(formatted_add_context, add_patch)
    except (ValueError, IndexError):
        sample['add_patch'] = None
    return sample

In [21]:
# Build the infill fields for every formatted sample in parallel.
# NOTE(review): the recorded run of this cell ended in a TimeoutError after a
# worker process failed (see the output below).
reformat_dataset=formatted_dataset.map(reformat, num_proc=12)

Map (num_proc=12):   0%|          | 0/529792 [00:00<?, ? examples/s]

/tmp/tmph994h3aa.java:2:6: error: reached end of file while parsing
/tmp/tmpap03a0co.java:2:6: error: reached end of file while parsing
/tmp/tmpjhck0mre.java:2:6: error: reached end of file while parsing
/tmp/tmp9ym77nwt.java:4:35: error: reached end of file while parsing
/tmp/tmpfhwcimno.java:4:17: error: reached end of file while parsing
/tmp/tmpmhs883z2.java:4:17: error: reached end of file while parsing
/tmp/tmpyohfe81r.java:1:89: error: 'try' without 'catch', 'finally' or resource declarations
/tmp/tmpyohfe81r.java:2:6: error: reached end of file while parsing
/tmp/tmpeoeuam0z.java:9:6: error: reached end of file while parsing
/tmp/tmp0_k03k0n.java:9:13: error: variable declaration not allowed here
/tmp/tmp337ht8k_.java:3:13: error: variable declaration not allowed here
/tmp/tmpum1b7nvy.java:15:7: error: reached end of file while parsing
/tmp/tmpa6dn5ez9.java:11:44: error: while expected
/tmp/tmpa6dn5ez9.java:12:6: error: illegal start of expression
Process ForkPoolWorker-3:
Proce

TimeoutError: 

## CoCoNuT patch-comparison fine-tuning dataset for MFTCoder

In [12]:
import re

def replace_multiple_spaces_with_single(s):
    """Collapse every run of whitespace (spaces, tabs, newlines) in ``s`` into one space."""
    return re.sub(r'\s+', ' ', s)

def preprocess(sample):
    """Return True if ``sample`` is NOT usable as an infill example.

    Whitespace in 'rem' and 'context' is normalized first. A sample is
    unusable when the removed code is empty, or when it does not occur exactly
    once in the context (so '<INFILL>' could not mark a unique position).

    Note that True means "bad", so ``dataset.filter(preprocess)`` keeps the
    unusable samples.
    """
    rem = replace_multiple_spaces_with_single(sample['rem']).strip()
    context = replace_multiple_spaces_with_single(sample['context']).strip()
    return rem == '' or context.count(rem) != 1
    
# `preprocess` returns True for unusable samples, so this collects them for
# inspection. It does not remove anything from `dataset`.
invalid_samples = dataset.filter(preprocess)
invalid_samples

Filter:   0%|          | 0/1125599 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['rem', 'add', 'context'],
        num_rows: 219320
    })
})

In [19]:
import random

def process(ind, sample, rng=random):
    """Turn one CoCoNuT sample into a 'chat_rounds' JSON record for comparison fine-tuning.

    The 'add' and 'rem' code are shown as options A and B in random order. The
    bot answer names the option holding 'add' as the correct patch.

    Args:
        ind: Sample id, stored under "id". Prompts with ``ind < 10`` are printed.
        sample: Mapping (dict or pandas row) with 'rem', 'add' and 'context'.
        rng: Object providing ``randint``. Pass ``random.Random(seed)`` for a
            reproducible A/B assignment; the default is the global ``random``
            module.

    Returns:
        A tuple ``(ind, jsonl_sample, correct_patch_choice)``, where the choice
        is 'A' or 'B'.
    """
    if rng.randint(0, 1) == 0:
        A_patch, B_patch, correct_patch_choice = sample['add'], sample['rem'], 'A'
    else:
        A_patch, B_patch, correct_patch_choice = sample['rem'], sample['add'], 'B'

    system_content = "You are an intelligent programming assistant for JAVA."
    # NOTE(review): the bot role tag is appended to the human turn by hand;
    # confirm the MFTCoder chat template does not add it a second time.
    human_content = (
        "Choose a correct patch from the following two patches to infill the Java code.\n\nJava code:\n"
        + sample['context'].replace(sample['rem'], '<INFILL>')
        + "\n\nPatches:\nA. " + A_patch.strip()
        + "\nB. " + B_patch.strip()
        + '<|role_start|>bot<|role_end|>'
    )
    bot_content = "The correct patch is " + correct_patch_choice

    jsonl_sample = {
        "id": ind,
        "data_name": "comparison_finetune",
        "chat_rounds": [
            {"role": "system", "content": system_content, "chat_round_id": 0},
            {"role": "human", "content": human_content, "chat_round_id": 1},
            {"role": "bot", "content": bot_content, "chat_round_id": 2},
        ],
    }
    if ind < 10:
        print(human_content)
    return ind, jsonl_sample, correct_patch_choice

In [21]:
from tqdm import tqdm
import json

# Keep a record only if its serialized JSON (prompt plus metadata) is at most this many tokens.
MAX_JSON_TOKENS = 500

sample_num = 0
skipped_no_infill = 0

with open('data/coconut_comparison.jsonl', 'w') as file, \
        open('data/coconut_comparison_label.txt', 'w') as label_file:
    for ind, sample in tqdm(df.iterrows(), total=len(df)):
        # `process` marks the infill slot with a plain str.replace of 'rem' in
        # 'context'. Skip samples where that would produce no slot or several
        # slots; an empty 'rem' would match everywhere.
        if not sample['rem'].strip() or sample['context'].count(sample['rem']) != 1:
            skipped_no_infill += 1
            continue

        ind, jsonl_sample, correct_patch_choice = process(ind, sample)
        json_string = json.dumps(jsonl_sample)
        if get_token_len(json_string) <= MAX_JSON_TOKENS:
            file.write(json_string + '\n')
            # One "<id> <A|B>" line per kept record, aligned with the JSONL file.
            label_file.write(f"{ind} {correct_patch_choice}\n")
            sample_num += 1

print(sample_num)
print(f"skipped (rem not found exactly once in context): {skipped_no_infill}")
    

TypeError: 'generator' object is not subscriptable