This repository has been archived by the owner on Jul 31, 2024. It is now read-only.

eval_bleu with pretrained gpt model #3

Closed
Gugutse opened this issue Oct 4, 2021 · 5 comments
Labels
question Further information is requested

Comments


Gugutse commented Oct 4, 2021

Hi @wasiahmad,
I'm trying to evaluate a gpt-2 model with your code. Thus, I run run.py with microsoft/CodeGPT-small-py in pretrain_dir parameter and do_infer. In eval_blue script outputs equal to model(inputs)[1] – these are hidden states of pretrained gpt – and it's a tuple of 12 elements (n_layers) consisting of 2 elements each, and these two have [1, 12, 48, 64]. When it goes to this line past_hidden = [x[:, i:i + 1].expand(-1, beam_size, -1, -1, -1) for x in outputs] an error occurs: TypeError: tuple indices must be integers or slices, not tuple – and it also implies that the shape of each element in outputs should have 5 dimensions.
Which corrections should be done in this case?

@wasiahmad
Owner

There is too much information here, so I am unable to understand the problem. I also don't see any eval_bleu script. Please look at the code for the other models.


Gugutse commented Oct 5, 2021

eval_bleu is here: https://github.com/wasiahmad/AVATAR/blob/main/codegpt/run.py

for step, (batch, token_labels) in enumerate(test_dataloader):
    inputs = batch.to(args.device)
    with torch.no_grad():
        beam_size = args.beam_size
        m = torch.nn.LogSoftmax(dim=-1)
        outputs = model(inputs)[1]
        p = []
        zero = torch.cuda.LongTensor(1).fill_(0)
        for i in range(inputs.shape[0]):
            past_hidden = [x[:, i:i + 1].expand(-1, beam_size, -1, -1, -1) for x in outputs]

I would like to evaluate the model, but TypeError: tuple indices must be integers or slices, not tuple is raised on the line past_hidden = [x[:, i:i + 1].expand(-1, beam_size, -1, -1, -1) for x in outputs].
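For context, the TypeError itself is plain Python tuple behavior: each element of outputs is a tuple of two tensors rather than one tensor, and the expression x[:, i:i + 1] builds a (slice, slice) tuple as the index, which tuples do not accept. A minimal illustration (with placeholder strings standing in for the key/value tensors):

```python
# Each element of `outputs` is a (key, value) pair, not a tensor.
layer = ("key_tensor", "value_tensor")

layer[0]    # an integer index works on a tuple
layer[0:1]  # a single slice works too

try:
    layer[:, 0:1]  # x[:, i:i + 1] builds the index (slice(None), slice(0, 1))
except TypeError as e:
    print(e)  # tuple indices must be integers or slices, not tuple
```

A torch.Tensor accepts such multi-dimensional indices, which is why the original code worked when each layer was a single stacked tensor.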

@wasiahmad
Owner

I am not sure. The CodeGPT codebase is supposed to work fine, and I don't know why you are getting this error. I will run CodeGPT training and evaluation again in my environment to check that they work correctly.


Gugutse commented Oct 5, 2021

The problem is that the GPT-2 model's output format changed in newer versions of Hugging Face's transformers library: past_key_values is now a tuple of per-layer (key, value) tensor pairs instead of one stacked tensor per layer, so torch.stack() needs to be applied to each layer's output.
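A minimal sketch of that conversion, using dummy zero tensors with the shapes reported in this issue (batch 1, 12 heads, sequence length 48, head dimension 64, 12 layers) in place of a real model call, and assuming the newer per-layer (key, value) tuple format:

```python
import torch

batch, heads, seq_len, head_dim, n_layers = 1, 12, 48, 64, 12

# Simulated newer-style model(inputs)[1]: a tuple of n_layers elements,
# each a (key, value) pair of [batch, heads, seq_len, head_dim] tensors.
outputs = tuple(
    (torch.zeros(batch, heads, seq_len, head_dim),
     torch.zeros(batch, heads, seq_len, head_dim))
    for _ in range(n_layers)
)

# Stack each (key, value) pair into the single 5-D tensor per layer that
# the beam-search code in run.py expects: [2, batch, heads, seq_len, head_dim].
outputs = [torch.stack(layer) for layer in outputs]

# The original indexing now works: expand the batch slice across the beam.
beam_size, i = 4, 0
past_hidden = [x[:, i:i + 1].expand(-1, beam_size, -1, -1, -1) for x in outputs]
print(past_hidden[0].shape)  # torch.Size([2, 4, 12, 48, 64])
```

In the actual script the same one-line change would go right after outputs = model(inputs)[1]; expand() only creates a view, so the per-beam copies add no extra memory.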

@wasiahmad
Owner

I see, you can modify the code to be compatible with the newer versions of the transformers API. I am closing this issue for now.

@wasiahmad wasiahmad added the question Further information is requested label Dec 3, 2022