#!/usr/bin/env python
import pandas as pd
from tqdm import tqdm

SRCTRAINFILE = 'yahoo_answers_csv/train.csv'
SRCVALIDATIONFILE = 'yahoo_answers_csv/test.csv'
DSTTRAINFILE = 'comprehend-train.csv'
DSTVALIDATIONFILE = 'comprehend-test.csv'
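
# Orientation note (added commentary; the column layout below is inferred from
# the code and from the Yahoo! Answers topic-classification dataset of
# Zhang et al., 2015). Each input CSV row has no header and four columns:
#   0: class index (1-10)  1: question title  2: question content  3: best answer
# The script emits Comprehend-ready CSVs: "LABEL,document" rows for training,
# and one bare document per row for testing.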
# Preparation of the train set
trainFrame = pd.read_csv(SRCTRAINFILE, header=None)
tqdm.pandas()
# Amazon Comprehend "recommend[s] that you train the model with up to 1,000
# training documents for each label", and accepts no more than 1,000,000
# documents in total.
#
# Here, we are limiting the training set to 10,000 documents per label in
# order to reduce the costs of this demo.
#
# If you want to train Amazon Comprehend on as much data as it accepts, set
# MAXITEM to 100000 (10 labels x 100,000 documents = the 1,000,000 limit).
# MAXITEM = 100000
MAXITEM = 10000
# Keeping at most MAXITEM rows for each of the 10 labels
for i in range(1, 11):
    num = len(trainFrame[trainFrame[0] == i])
    dropnum = max(num - MAXITEM, 0)
    indextodrop = trainFrame[trainFrame[0] == i].sample(n=dropnum).index
    trainFrame.drop(indextodrop, inplace=True)
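# Optional sanity check (a sketch added for this write-up, not part of the
# original flow): after downsampling, no label should exceed MAXITEM rows.
# assert (trainFrame[0].value_counts() <= MAXITEM).all()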
# Translating the numerical class codes into Comprehend label names
trainFrame[0] = trainFrame[0].progress_apply({
    1: 'SOCIETY_AND_CULTURE',
    2: 'SCIENCE_AND_MATHEMATICS',
    3: 'HEALTH',
    4: 'EDUCATION_AND_REFERENCE',
    5: 'COMPUTERS_AND_INTERNET',
    6: 'SPORTS',
    7: 'BUSINESS_AND_FINANCE',
    8: 'ENTERTAINMENT_AND_MUSIC',
    9: 'FAMILY_AND_RELATIONSHIPS',
    10: 'POLITICS_AND_GOVERNMENT'
}.get)
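# Note (added): dict.get returns None for any class code outside 1-10, which
# pandas stores as NaN. A cheap guard against a malformed input file:
# assert trainFrame[0].notna().all()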
# Joining "question title", "question content", and "best answer" into a
# single document; the ' \\n ' separator inserts a literal "\n" token so each
# document stays on a single CSV line.
trainFrame['document'] = trainFrame[trainFrame.columns[1:]].progress_apply(
    lambda x: ' \\n '.join(x.dropna().astype(str)),
    axis=1
)
# Keeping only the first two columns: label and joined text
trainFrame.drop([1, 2, 3], axis=1, inplace=True)
# Escaping ',' so the document text cannot spill into extra CSV columns
trainFrame['document'] = trainFrame['document'].str.replace(',', '\\,', regex=False)
# Writing the training CSV file
trainFrame.to_csv(path_or_buf=DSTTRAINFILE,
                  header=False,
                  index=False,
                  escapechar='\\',
                  doublequote=False,
                  quotechar='"')
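# For orientation, a resulting training row should look roughly like this
# (illustrative text, not real dataset content):
# SPORTS,Who won the 1998 World Cup? \n ... \n France.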
# Preparation of the validation set
validationFrame = pd.read_csv(SRCVALIDATIONFILE, header=None)
tqdm.pandas()
# Here, we are limiting the test set to 100 documents in order to reduce the
# costs of this demo.
# If you want to test Amazon Comprehend on the full test set, set MAXITEM to None
# MAXITEM = None
MAXITEM = 100
# Keeping at most MAXITEM documents
if MAXITEM:
    num = len(validationFrame)
    dropnum = max(num - MAXITEM, 0)
    indextodrop = validationFrame.sample(n=dropnum).index
    validationFrame.drop(indextodrop, inplace=True)
# Joining "question title", "question content", and "best answer", as above
validationFrame['document'] = validationFrame[validationFrame.columns[1:]].progress_apply(
    lambda x: ' \\n '.join(x.dropna().astype(str)),
    axis=1
)
# Removing all columns but the aggregated document; the test file carries no
# label column, since Comprehend will predict the labels
validationFrame.drop([0, 1, 2, 3], axis=1, inplace=True)
# Escaping ',' as above
validationFrame['document'] = validationFrame['document'].str.replace(',', '\\,', regex=False)
# Writing the test CSV file
validationFrame.to_csv(path_or_buf=DSTVALIDATIONFILE,
                       header=False,
                       index=False,
                       escapechar='\\',
                       doublequote=False,
                       quotechar='"')
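
# Usage sketch (added note; assumes the Yahoo! Answers archive was extracted
# into ./yahoo_answers_csv next to this script):
#   pip install pandas tqdm
#   ./prepare_data.py
# comprehend-train.csv then feeds the Comprehend custom classifier training
# job, and comprehend-test.csv the subsequent classification job.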