-
Notifications
You must be signed in to change notification settings - Fork 0
/
process_data.py
52 lines (34 loc) · 1.51 KB
/
process_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from datasets import Dataset, load_dataset, DatasetDict
from config import *
# Load the dataset identified by DATASET_ID. This runs at import time, and the
# resulting object is also what split_dataset() captures as its default
# argument. DATASET_ID is not defined in this file -- presumably it comes from
# the wildcard import of config (TODO confirm).
dataset = load_dataset(DATASET_ID)
def format_instruction(medical_condition: str, treatment_options: str):
    """Assemble the prompt text for one condition/treatment pair.

    Each argument is stripped of surrounding whitespace and placed on its own
    line under the matching ``###`` header.  The assembled text is stripped
    once more, so the result never starts or ends with whitespace.

    Args:
        medical_condition: Text placed under the
            "Explaining medical conditions" header.
        treatment_options: Text placed under the
            "Describe the treatment options" header.

    Returns:
        The formatted prompt as a single string.
    """
    segments = (
        '### Instruction:',
        'For describe the treatment options the following conversation.',
        '### Explaining medical conditions:',
        medical_condition.strip(),
        '### Describe the treatment options:',
        treatment_options.strip(),
        '',  # keeps the newline after the last section ahead of the final strip
    )
    return '\n'.join(segments).strip()
def generate_instruction_dataset(data_point):
    """Derive the prompt-related fields for a single record.

    Reads the ``'instruction'`` and ``'output'`` entries of ``data_point`` and
    returns them under the names ``'medical_condition'`` and
    ``'treatment_options'``, together with a ``'text'`` entry holding the
    prompt produced by ``format_instruction`` from those two values.

    Args:
        data_point: Mapping with ``'instruction'`` and ``'output'`` keys.

    Returns:
        A dict with ``'medical_condition'``, ``'treatment_options'`` and
        ``'text'`` keys.
    """
    condition = data_point['instruction']
    treatment = data_point['output']
    prompt = format_instruction(
        medical_condition=condition,
        treatment_options=treatment,
    )
    return {
        'medical_condition': condition,
        'treatment_options': treatment,
        'text': prompt,
    }
def process_dataset(data: Dataset):
    """Shuffle ``data``, add the prompt fields, and drop the ``'input'`` column.

    The shuffle is seeded with ``RANDOM_SEED``.  Every record is then passed
    through ``generate_instruction_dataset`` via ``map``, and finally the
    ``'input'`` column is removed from the result.

    Args:
        data: Object exposing ``shuffle``, ``map`` and ``remove_columns``.

    Returns:
        Whatever the final ``remove_columns`` call returns.
    """
    shuffled = data.shuffle(seed=RANDOM_SEED)
    with_prompts = shuffled.map(generate_instruction_dataset)
    return with_prompts.remove_columns(['input'])
def split_dataset(dataset=dataset):
    """Process the dataset and split it into train/test/validation subsets.

    The ``'train'`` entry of the processed dataset is divided 80/20, and the
    20% hold-out is then halved, giving roughly 80% train, 10% test and 10%
    validation.  Both splits are seeded with ``RANDOM_SEED`` (the same seed
    ``process_dataset`` uses for shuffling) so that repeated runs produce the
    same subset membership.

    Args:
        dataset: Object handed to ``process_dataset``; the processed result
            must expose a ``'train'`` entry.  Defaults to the module-level
            ``dataset`` as it was when this function was defined.

    Returns:
        A ``DatasetDict`` with ``'train'``, ``'test'`` and ``'validation'``
        keys.
    """
    processed = process_dataset(dataset)

    # Hold out 20% of the data; seeded so the hold-out is stable across runs.
    train_testvalid = processed['train'].train_test_split(
        test_size=0.2, seed=RANDOM_SEED
    )

    # Split the 20% hold-out in half: 10% test, 10% validation.
    test_valid = train_testvalid['test'].train_test_split(
        test_size=0.5, seed=RANDOM_SEED
    )

    # Gather the three subsets into a single DatasetDict.
    return DatasetDict({
        'train': train_testvalid['train'],
        'test': test_valid['test'],
        'validation': test_valid['train'],
    })