-
Notifications
You must be signed in to change notification settings - Fork 0
/
no_dot_data_porcessing.py
71 lines (38 loc) · 2.2 KB
/
no_dot_data_porcessing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
import numpy as np
import pandas as pd
# Start-up banner, printed once when the script is run.
print('#########################################')
print('Precessing....')

# Root directory under which each dataset has its own sub-folder of splits.
# Built with os.path.join rather than the literal '.\data\...' so the value
# no longer depends on the invalid escape sequences '\d' and '\p' and uses
# the separator of the platform the script runs on (identical on Windows).
saved_path = os.path.join('.', 'data', 'processed_training_data')

# Base names of the datasets to process; each is used both as a CSV file
# name under ./data and as a sub-folder name under `saved_path`.
file_name_set = [
    "symp_diagnosis_relation_training",
    "complaint_training_data",
    "combine_complaint_symp",
]
def replace_dot(df_train):
    """Remove '.' from the second '[SEP]'-delimited segment of each symptom string.

    Every value of ``df_train['symptoms']`` is split on the literal
    ``'[SEP]'``. The first segment is kept untouched, all dots are removed
    from the second segment, and the two are glued back together with
    ``' [SEP] '``. Any segments after the second one are discarded.

    The column is overwritten in place and the same DataFrame object is
    returned. A value without ``'[SEP]'`` raises ``IndexError``.
    """
    rebuilt = []
    for entry in df_train['symptoms'].values:
        pieces = entry.split('[SEP]')
        left, right = pieces[0], pieces[1]
        rebuilt.append(left + ' [SEP] ' + right.replace('.', ''))

    df_train['symptoms'] = np.array(rebuilt)
    return df_train
# For every dataset: load its train/val/test splits plus the original CSV,
# remove '.' characters from the 'symptoms' column according to the dataset's
# position in `file_name_set`, and write the frames back out next to the
# splits under 'no_dot_*' file names.
for file_idx, file_name in enumerate(file_name_set):
    # The original CSV sits directly under ./data; the splits sit under
    # <saved_path>/<file_name>/. os.path.join avoids the invalid '\d' escape
    # of the former '.\data\\' literal and uses the platform's separator.
    text_path = os.path.join('.', 'data', file_name + '.csv')
    data_path = os.path.join(saved_path, file_name)

    df_train_path = os.path.join(data_path, 'train.csv')
    df_val_path = os.path.join(data_path, 'val.csv')
    df_test_path = os.path.join(data_path, 'test.csv')

    df_train = pd.read_csv(df_train_path)
    df_val = pd.read_csv(df_val_path)
    df_test = pd.read_csv(df_test_path)
    original_text = pd.read_csv(text_path)

    # Compare with `==`, not `is`: identity against an int literal is
    # implementation-dependent and triggers a SyntaxWarning on Python 3.8+.
    if file_idx == 0:
        # First dataset: drop every '.' from the whole 'symptoms' string.
        for frame in (df_train, df_val, df_test, original_text):
            frame['symptoms'] = np.array(
                [text.replace('.', '') for text in frame['symptoms'].values]
            )
    elif file_idx == 2:
        # Third dataset: 'symptoms' holds '[SEP]'-joined text; replace_dot
        # only strips dots from the segment after the separator.
        df_train = replace_dot(df_train)
        df_val = replace_dot(df_val)
        df_test = replace_dot(df_test)
        original_text = replace_dot(original_text)
    # Any other index (the second dataset) is written out unchanged.

    outputs = (
        (df_train, 'no_dot_train.csv'),
        (df_val, 'no_dot_val.csv'),
        (df_test, 'no_dot_test.csv'),
        (original_text, 'no_dot_original_text.csv'),
    )
    for frame, out_name in outputs:
        frame.to_csv(os.path.join(data_path, out_name), index=False, encoding="utf-8")

# NOTE(review): the source's indentation was lost; this message is assumed to
# be printed once after all datasets are processed - confirm.
print('Has already delete the dot!!')