
Commit 3bddc3a

fix: fix huggingface logprob when not generating tokens and echoing prompt
ruixin31 committed Oct 17, 2023
1 parent 36c4d8a commit 3bddc3a
Showing 1 changed file with 44 additions and 25 deletions.
69 changes: 44 additions & 25 deletions src/helm/proxy/clients/huggingface_client.py
@@ -84,33 +84,44 @@ def serve_request(self, raw_request: Dict[str, Any]):
         }
 
         # Use HuggingFace's `generate` method.
-        output = self.model.generate(**encoded_input, **relevant_raw_request)
-        sequences = output.sequences
-        scores = output.scores
+        if relevant_raw_request["max_new_tokens"] == 0:
+            output = self.model(encoded_input["input_ids"])
+            sequences = encoded_input["input_ids"]
+            scores = output.logits
+        else:
+            output = self.model.generate(**encoded_input, **relevant_raw_request)
+            sequences = output.sequences
+            scores = output.scores
 
         # Compute logprobs for each completed sequence.
         all_logprobs_of_chosen_tokens = []
         all_top_logprobs_dicts = []
-        for completion_id in range(raw_request["num_return_sequences"]):
-            logprobs_of_chosen_tokens = []
-            top_logprobs_dicts = []
-            for i in range(len(sequences[completion_id]) - len(encoded_input.input_ids[0])):
-                logprobs = torch.nn.functional.log_softmax(scores[i][completion_id], dim=0)
-
-                # Get top tokens in terms of log probability.
-                topk_logprobs = torch.topk(logprobs, k=top_k_per_token)
-                top_logprobs_dicts.append(
-                    {
-                        self.tokenizer.convert_ids_to_tokens(k.item()): v.item()
-                        for (k, v) in zip(topk_logprobs.indices, topk_logprobs.values)
-                    }
-                )
-
-                # Get log probability of chosen token.
-                j = i + len(encoded_input.input_ids[0])
-                logprobs_of_chosen_tokens.append(logprobs[sequences[completion_id][j]].item())
-            all_logprobs_of_chosen_tokens.append(logprobs_of_chosen_tokens)
-            all_top_logprobs_dicts.append(top_logprobs_dicts)
+        if relevant_raw_request["max_new_tokens"] == 0 and raw_request["echo_prompt"]:
+            for completion_id in range(raw_request["num_return_sequences"]):
+                logprobs_of_chosen_tokens = []
+                top_logprobs_dicts = []
+                for i in range(len(sequences[completion_id]) - 1):
+                    logprobs = torch.nn.functional.log_softmax(scores[completion_id][i], dim=0)
+                    topk_logprobs = torch.topk(logprobs, k=top_k_per_token)
+                    top_logprobs_dicts.append({self.tokenizer.convert_ids_to_tokens(k.item()): v.item()
+                                               for (k, v) in zip(topk_logprobs.indices, topk_logprobs.values)})
+                    logprobs_of_chosen_tokens.append(logprobs[sequences[completion_id][i + 1]].item())
+                all_logprobs_of_chosen_tokens.append(logprobs_of_chosen_tokens)
+                all_top_logprobs_dicts.append(top_logprobs_dicts)
+        else:
+            for completion_id in range(raw_request["num_return_sequences"]):
+                logprobs_of_chosen_tokens = []
+                top_logprobs_dicts = []
+                for i in range(len(sequences[completion_id]) - len(encoded_input.input_ids[0])):
+                    logprobs = torch.nn.functional.log_softmax(scores[i][completion_id], dim=0)
+                    # Get top tokens in terms of log probability.
+                    topk_logprobs = torch.topk(logprobs, k=top_k_per_token)
+                    top_logprobs_dicts.append({self.tokenizer.convert_ids_to_tokens(k.item()): v.item()
+                                               for (k, v) in zip(topk_logprobs.indices, topk_logprobs.values)})
+                    j = i + len(encoded_input.input_ids[0])
+                    logprobs_of_chosen_tokens.append(logprobs[sequences[completion_id][j]].item())
+                all_logprobs_of_chosen_tokens.append(logprobs_of_chosen_tokens)
+                all_top_logprobs_dicts.append(top_logprobs_dicts)
 
         # Remove prompt from the start of each sequence if echo_prompt is False.
         if not raw_request["echo_prompt"]:
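
The max_new_tokens == 0 branch above scores the prompt with a single forward pass instead of calling generate: the logits at position i give the distribution over token i + 1, so the first prompt token gets no logprob and position i's log-softmax scores the next prompt token. A minimal standalone sketch of that idea, assuming a stock transformers causal LM (gpt2 here) rather than the HELM client itself:

# Minimal sketch (not HELM code): per-token logprobs of an echoed prompt
# from one forward pass, the same idea the max_new_tokens == 0 branch uses.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumption: any causal LM works the same way
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids

with torch.no_grad():
    logits = model(input_ids).logits[0]  # shape: (seq_len, vocab_size)

logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
# Position i predicts token i + 1; the first token has no preceding context.
for i in range(input_ids.shape[1] - 1):
    next_id = input_ids[0, i + 1].item()
    print(tokenizer.convert_ids_to_tokens(next_id), logprobs[i, next_id].item())
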
@@ -222,8 +233,16 @@ def do_it():
             if request.echo_prompt:
                 # Add prompt to list of generated tokens.
                 generated_tokens = raw_completion["tokens"][response["input_length"] :]
-                for token_text in raw_completion["tokens"][: response["input_length"]]:
-                    tokens.append(Token(text=token_text, logprob=0.0, top_logprobs={}))
+                if request.max_tokens == 0:
+                    for token_text, logprob, top_logprobs_dict in zip(
+                        raw_completion["tokens"][: response["input_length"]], raw_completion["logprobs"][: response["input_length"]], raw_completion["top_logprobs_dicts"][: response["input_length"]]
+                    ):
+                        tokens.append(Token(text=token_text, logprob=logprob, top_logprobs=top_logprobs_dict))
+                        sequence_logprob += logprob
+                else:
+                    for token_text in raw_completion["tokens"][: response["input_length"]]:
+                        tokens.append(Token(text=token_text, logprob=0.0, top_logprobs={}))
+
             else:
                 generated_tokens = raw_completion["tokens"]
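
The second hunk then attaches those prompt logprobs to the returned tokens when max_tokens == 0, instead of the old placeholder logprob of 0.0, and folds them into the sequence logprob. A toy sketch of that bookkeeping, using a stand-in Token dataclass and made-up values rather than HELM's own Token class and server response:

# Toy sketch (stand-in Token, fabricated values) of attaching prompt
# logprobs to echoed tokens and accumulating the sequence logprob.
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Token:
    text: str
    logprob: float
    top_logprobs: Dict[str, float] = field(default_factory=dict)


raw_completion = {  # assumption: mirrors the shape serve_request returns
    "tokens": ["The", " quick", " brown"],
    "logprobs": [0.0, -2.1, -3.4],
    "top_logprobs_dicts": [{}, {" quick": -2.1}, {" brown": -3.4}],
}
input_length = 3
sequence_logprob = 0.0
tokens: List[Token] = []

for token_text, logprob, top_logprobs_dict in zip(
    raw_completion["tokens"][:input_length],
    raw_completion["logprobs"][:input_length],
    raw_completion["top_logprobs_dicts"][:input_length],
):
    tokens.append(Token(text=token_text, logprob=logprob, top_logprobs=top_logprobs_dict))
    sequence_logprob += logprob

print(sequence_logprob)  # -5.5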
