Skip to content

Commit

Permalink
eosphoros-ai#122: support multi-turn datasets (chase/sparc/cosql), and…
Browse files Browse the repository at this point in the history
… merge all data together
  • Loading branch information
John-Saxon committed Nov 6, 2023
1 parent 1fbd892 commit 03a6f49
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 29 deletions.
28 changes: 28 additions & 0 deletions dbgpt_hub/configs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
EXT2TYPE = {"csv": "csv", "json": "json", "jsonl": "json", "txt": "text"}

# text2sql dataset information for processing sql data
# TODO: add support for BIRD, WikiSQL, ...
SQL_DATA_INFO = [
{
"data_source": "spider",
Expand All @@ -53,6 +54,33 @@
"db_id_name": "db_id",
"is_multiple_turn": False,
}
,
{
"data_source": "chase",
"train_file": ["Chase/chase_train.json"],
"dev_file": ["Chase/chase_dev.json"],
"tables_file": "Chase/chase_tables.json",
"db_id_name": "database_id",
"is_multiple_turn": True,
}
,
{
"data_source": "cosql_dataset",
"train_file": ["sql_state_tracking/cosql_train.json"],
"dev_file": ["sql_state_tracking/cosql_dev.json"],
"tables_file": "tables.json",
"db_id_name": "database_id",
"is_multiple_turn": True,
}
,
{
"data_source": "sparc",
"train_file": ["train.json"],
"dev_file": ["dev.json"],
"tables_file": "tables.json",
"db_id_name": "database_id",
"is_multiple_turn": True,
}
]
INSTRUCTION_PROMPT = """\
I want you to act as a SQL terminal in front of an example database, \
Expand Down
83 changes: 54 additions & 29 deletions dbgpt_hub/data_process/sql_data_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@


class ProcessSqlData:
def __init__(self) -> None:
pass
def __init__(self, train_file=None, dev_file=None) -> None:
    """Remember the output paths for the merged SFT data.

    Args:
        train_file: path the merged training JSON is written to.
        dev_file: path the merged dev JSON is written to.
    """
    self.train_file, self.dev_file = train_file, dev_file

def decode_json_file(self, data_file_list, table_file, out_file):
def decode_json_file(self, data_file_list, table_file, db_id_name, is_multiple_turn=False):
"""
TO DO:
1.将相关prompt放入config中
2.将不同数据来源的字段信息放入config中
3.支持多轮对话数据集
"""

if table_file.endswith(".jsonl"):
Expand Down Expand Up @@ -87,46 +87,71 @@ def decode_json_file(self, data_file_list, table_file, out_file):
# 单论对话
res = []
for data in tqdm(datas):
if data["db_id"] in db_dict.keys():
input = {
"db_id": data["db_id"],
"instruction": INSTRUCTION_PROMPT.format(db_dict[data["db_id"]]),
"input": INPUT_PROMPT.format(data["question"]),
"output": data["query"],
"history": [],
}
res.append(input)

with open(out_file, "w", encoding="utf-8") as s:
json.dump(res, s, indent=4, ensure_ascii=False)
if data[db_id_name] in db_dict.keys():
if is_multiple_turn:
history = []
for interaction in data["interaction"]:
input = {
"db_id": data[db_id_name],
"instruction": INSTRUCTION_PROMPT.format(db_dict[data[db_id_name]]),
"input": INPUT_PROMPT.format(interaction["utterance"]),
"output": interaction["query"],
"history": history,
}
res.append(input)
history.append((INPUT_PROMPT.format(interaction["utterance"]), interaction["query"]))
else:
input = {
"db_id": data[db_id_name],
"instruction": INSTRUCTION_PROMPT.format(db_dict[data[db_id_name]]),
"input": INPUT_PROMPT.format(data["question"]),
"output": data["query"],
"history": [],
}
res.append(input)
return res

def create_sft_raw_data(self):
    """Merge every dataset listed in SQL_DATA_INFO into one train and one dev file.

    For each configured data source, the train and dev splits are decoded
    with ``decode_json_file`` and accumulated across sources; the two merged
    lists are then dumped as JSON to ``self.train_file`` and ``self.dev_file``.

    Raises:
        TypeError: if ``self.train_file`` / ``self.dev_file`` were left as None.
    """
    # split key in SQL_DATA_INFO -> accumulated examples across all sources
    merged = {"train_file": [], "dev_file": []}
    for data_info in SQL_DATA_INFO:
        source_dir = os.path.join(DATA_PATH, data_info["data_source"])
        table_file = os.path.join(source_dir, data_info["tables_file"])
        # Train and dev are processed identically; loop instead of duplicating.
        for split_key in ("train_file", "dev_file"):
            data_file_list = [
                os.path.join(source_dir, name) for name in data_info[split_key]
            ]
            merged[split_key].extend(
                self.decode_json_file(
                    data_file_list=data_file_list,
                    table_file=table_file,
                    db_id_name=data_info["db_id_name"],
                    is_multiple_turn=data_info["is_multiple_turn"],
                )
            )
    # Write both merged splits only after all sources were decoded, so a
    # failure mid-way does not leave a partially-merged file pair behind.
    for out_path, records in (
        (self.train_file, merged["train_file"]),
        (self.dev_file, merged["dev_file"]),
    ):
        with open(out_path, "w", encoding="utf-8") as s:
            json.dump(records, s, indent=4, ensure_ascii=False)


if __name__ == "__main__":
precess = ProcessSqlData()
all_in_one_train_file = os.path.join(DATA_PATH, "example_text2sql_train.json")
all_in_one_dev_file = os.path.join(DATA_PATH, "example_text2sql_dev.json")
precess = ProcessSqlData(train_file=all_in_one_train_file, dev_file=all_in_one_dev_file)
precess.create_sft_raw_data()

0 comments on commit 03a6f49

Please sign in to comment.