-
Notifications
You must be signed in to change notification settings - Fork 32
/
passage_nn.py
82 lines (66 loc) · 2.89 KB
/
passage_nn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""
Adapted from Passage's sentiment.py at
https://github.com/IndicoDataSolutions/Passage/blob/master/examples/sentiment.py
License: MIT
"""
import argparse
import numpy as np
from passage.models import RNN
from passage.updates import Adadelta
from passage.layers import Embedding, GatedRecurrent, LstmRecurrent, Dense
from passage.preprocessing import Tokenizer
from io_util import load_data
def _build_layers(modeltype, n_features):
    """Return the Embedding -> recurrent -> Dense layer stack for *modeltype*.

    Only the recurrent layer class differs between the two supported model
    types; all hyper-parameters are shared.
    """
    if modeltype == "gated_recurrent":
        recurrent_cls = GatedRecurrent
    else:
        recurrent_cls = LstmRecurrent
    return [
        Embedding(size=256, n_features=n_features),
        recurrent_cls(size=512, activation='tanh', gate_activation='steeper_sigmoid',
                      init='orthogonal', seq_output=False, p_drop=0.75),
        Dense(size=1, activation='sigmoid', init='orthogonal'),
    ]


def main(ptrain, ntrain, ptest, ntest, out, modeltype):
    """Train a Passage RNN sentiment classifier and write test-set scores.

    Args:
        ptrain: path of the positive training text file.
        ntrain: path of the negative training text file.
        ptest: path of the positive test text file.
        ntest: path of the negative test text file.
        out: path of the score file to write. Each line is
            "<label> <pos_prob> <neg_prob>", where label is 1 when the
            predicted positive probability is >= 0.5 and -1 otherwise, and
            neg_prob is 1 - pos_prob.
        modeltype: "gated_recurrent" or "lstm_recurrent".

    Raises:
        ValueError: if *modeltype* is not one of the supported values.
    """
    # Validate explicitly instead of with ``assert``: asserts are removed under
    # ``python -O``, after which an unknown modeltype would silently fall
    # through to the LSTM model.
    if modeltype not in ("gated_recurrent", "lstm_recurrent"):
        raise ValueError("Unknown modeltype %r; expected 'gated_recurrent' "
                         "or 'lstm_recurrent'" % (modeltype,))
    print("Using the %s model ..." % modeltype)

    print("Loading data ...")
    trX, trY = load_data(ptrain, ntrain)
    teX, teY = load_data(ptest, ntest)

    # Fit the vocabulary on the training text only, then reuse it for test.
    tokenizer = Tokenizer(min_df=10, max_features=100000)
    trX = tokenizer.fit_transform(trX)
    teX = tokenizer.transform(teX)

    print("Training ...")
    layers = _build_layers(modeltype, tokenizer.n_features)
    model = RNN(layers=layers, cost='bce', updater=Adadelta(lr=0.5))
    model.fit(trX, trY, n_epochs=10)

    # Predicting the probabilities of positive labels
    print("Predicting ...")
    pr_teX = model.predict(teX).flatten()
    # Threshold at 0.5: +1 for positive, -1 for negative.
    predY = np.ones(len(teY))
    predY[pr_teX < 0.5] = -1

    with open(out, "w") as f:
        for lab, pos_pr, neg_pr in zip(predY, pr_teX, 1 - pr_teX):
            f.write("%d %f %f\n" % (lab, pos_pr, neg_pr))
if __name__ == "__main__":
    # Usage:
    #   python passage_nn.py \
    #       --ptrain /PATH/data/full-train-pos.txt \
    #       --ntrain /PATH/data/full-train-neg.txt \
    #       --ptest /PATH/data/test-pos.txt \
    #       --ntest /PATH/data/test-neg.txt \
    #       --modeltype model_type \
    #       --out TEST-SCORE
    parser = argparse.ArgumentParser(description='Use Passage for sentiment analysis.')
    # Every option is mandatory: a missing path would otherwise reach main()
    # as None and fail obscurely deep inside data loading.
    parser.add_argument('--ptrain', required=True,
                        help='path of the text file TRAIN POSITIVE')
    parser.add_argument('--ntrain', required=True,
                        help='path of the text file TRAIN NEGATIVE')
    parser.add_argument('--ptest', required=True,
                        help='path of the text file TEST POSITIVE')
    parser.add_argument('--ntest', required=True,
                        help='path of the text file TEST NEGATIVE')
    # Reject unsupported model types at the CLI with a clear argparse error.
    parser.add_argument('--modeltype', required=True,
                        choices=['gated_recurrent', 'lstm_recurrent'],
                        help='Passage\'s model type: gated_recurrent or lstm_recurrent')
    parser.add_argument('--out', required=True,
                        help='path and filename for score output')
    args = vars(parser.parse_args())
    main(**args)