Skip to content

Commit

Permalink
support hellaswag dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
mikecovlee committed May 5, 2024
1 parent 74b4fc7 commit e312bd9
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions mlora/tasks/qa_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,42 @@ def loading_data(self,
return ret


class HellaSwag(QuestionAnswerTask):
def __init__(self) -> None:
super().__init__(["A", "B", "C", "D"])

def loading_data(self,
tokenizer: Tokenizer,
is_train: bool = True) -> List[DataClass]:
data = hf_datasets.load_dataset(
"Rowan/hellaswag")["train" if is_train else "validation"]
logging.info("Preparing data for HellaSwag")
ret: List[DataClass] = []
for idx, data_point in enumerate(data):
prompt = "Please choose the correct ending to complete the given sentence: " + \
f"{data_point['activity_label']}. {data_point['ctx']}\n"
for label, text in enumerate(data_point["endings"]):
prompt += f" ({self.labels_[label]}) {text}"
prompt += "\nAnswer:"
if is_train:
prompt += " " + self.labels_[int(data_point["label"])]
labels = None
else:
labels = [int(data_point["label"])]
tokens = tokenizer.encode(data=prompt)
ret.append(DataClass(tokens_=tokens, labels_=labels))
if idx % 10000 == 0:
logging.info(f"Encode text data: {idx}/{len(data)}")

return ret


def update_task_dict(task_dict):
task_dict.update({
"arc-e": ARC("ARC-Easy"),
"arc-c": ARC("ARC-Challenge"),
"boolq": Boolq(),
"obqa": OpenBookQA(),
"piqa": PIQA(),
"hellas": HellaSwag(),
})

0 comments on commit e312bd9

Please sign in to comment.