
Unable to match results on code generation #47

Closed
sindhura97 opened this issue Apr 16, 2022 · 2 comments
@sindhura97

I am unable to get performance on the validation set for code generation that comes close to the test-set numbers reported in the paper when using your fine-tuned checkpoint. I get a BLEU score of 29.49 and an exact match (EM) of 12.65. Here is my code. Am I doing something wrong?

from datasets import load_dataset

class Example(object):
    """A single natural-language-to-code example."""
    def __init__(self, idx, source, target):
        self.idx = idx
        self.source = source
        self.target = target

def read_examples(split):
    # Load the CONCODE text-to-code dataset from the Hugging Face Hub
    # and wrap each record in an Example.
    dataset = load_dataset('code_x_glue_tc_text_to_code')[split]
    examples = []
    for eg in dataset:
        examples.append(Example(idx=eg['id'], source=eg['nl'], target=eg['code']))
    return examples

examples = read_examples('validation')

from transformers import RobertaTokenizer, T5ForConditionalGeneration
import torch
from tqdm import tqdm
import os

os.environ["CUDA_VISIBLE_DEVICES"]="0"

tokenizer = RobertaTokenizer.from_pretrained('Salesforce/codet5-base')
model = T5ForConditionalGeneration.from_pretrained('Salesforce/codet5-base')

# Load the released fine-tuned CONCODE checkpoint on top of the base model.
model.load_state_dict(torch.load('finetuned_models_concode_codet5_base.bin'))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model.to(device)

preds = []
for eg in tqdm(examples):
    # Tokenize the NL description and generate code with beam search.
    input_ids = tokenizer(eg.source, return_tensors="pt").input_ids.to(device)
    generated_ids = model.generate(input_ids, max_length=200, num_beams=5)
    preds.append(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
    
import numpy as np
from bleu import _bleu

accs = []
with open("test.output", 'w') as f, open("test.gold", 'w') as f1:
    for pred, eg in zip(preds, examples[:len(preds)]):
        f.write(pred + '\n')
        f1.write(eg.target + '\n')
        # Exact match: whitespace-tokenized prediction equals the reference code.
        accs.append(pred.strip().split() == eg.target.split())

print(np.mean(accs), _bleu('test.gold', 'test.output'))
@yuewang-cuhk
Contributor

Hi, before we get time to examine your code and figure out where the problem comes from, we suggest you first use the run_gen.py script to reproduce the results. You can pass the do_test argument here and pass the fine-tuned checkpoint to load here.
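For reference, a rough sketch of the kind of invocation this points to; the flag names below are assumptions for illustration only, so check run_gen.py (and the lines linked above) for the actual argument names:

# Hypothetical invocation; treat every flag and path as a placeholder, not the verbatim CLI.
python run_gen.py \
    --do_test \
    --task concode \
    --model_name_or_path Salesforce/codet5-base \
    --load_model_path finetuned_models_concode_codet5_base.bin \
    --data_dir <path_to_concode_data> \
    --output_dir <output_dir>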

@sindhura97
Author

Thanks, that worked. I got a test BLEU of 39.6, which is close to but not exactly the same as the number reported in the paper. Perhaps the paper used a different checkpoint?
