-
Notifications
You must be signed in to change notification settings - Fork 161
/
seq2sql_condition_predict.py
122 lines (104 loc) · 4.76 KB
/
seq2sql_condition_predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from net_utils import run_lstm
class Seq2SQLCondPredictor(nn.Module):
    """Seq2SQL WHERE-clause predictor.

    A bidirectional LSTM encodes the question; an LSTM decoder emits the
    condition token sequence.  Each decoder step is scored against every
    encoder position with a pointer-network-style additive attention
    (cond_out_h(enc) + cond_out_g(dec) -> tanh -> linear), so the output
    vocabulary at each step is the set of input positions.
    """

    def __init__(self, N_word, N_h, N_depth, max_col_num, max_tok_num, gpu):
        """
        N_word      -- word-embedding dimension of the encoder input.
        N_h         -- hidden dimension of encoder output / decoder state.
        N_depth     -- number of stacked LSTM layers.
        max_col_num -- max number of columns (stored, unused in this module).
        max_tok_num -- size of the decoder's one-hot token vocabulary.
        gpu         -- if truthy, tensors created here are moved to CUDA.
        """
        super(Seq2SQLCondPredictor, self).__init__()
        # print() call form is valid under both Python 2 and Python 3
        # (the original used a Python-2-only print statement).
        print("Seq2SQL where prediction")

        self.N_h = N_h
        self.max_tok_num = max_tok_num
        self.max_col_num = max_col_num
        self.gpu = gpu

        # Bidirectional encoder: N_h // 2 per direction so the concatenated
        # state is N_h.  Floor division keeps hidden_size an int under
        # Python 3 (N_h / 2 would be a float and nn.LSTM would reject it).
        self.cond_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h // 2,
                num_layers=N_depth, batch_first=True,
                dropout=0.3, bidirectional=True)
        # Decoder consumes one-hot token vectors of size max_tok_num.
        self.cond_decoder = nn.LSTM(input_size=self.max_tok_num,
                hidden_size=N_h, num_layers=N_depth,
                batch_first=True, dropout=0.3)

        # Additive attention scorer: score(enc_t, dec_s) =
        #   Linear(tanh(cond_out_h(enc_t) + cond_out_g(dec_s)))
        self.cond_out_g = nn.Linear(N_h, N_h)
        self.cond_out_h = nn.Linear(N_h, N_h)
        self.cond_out = nn.Sequential(nn.Tanh(), nn.Linear(N_h, 1))

        # NOTE(review): nn.Softmax() without dim= relies on the old implicit
        # default (last dim for 2-D input); kept as-is for the torch version
        # this code targets -- confirm before upgrading torch.
        self.softmax = nn.Softmax()

    def gen_gt_batch(self, tok_seq, gen_inp=True):
        """Build a one-hot ground-truth batch from token-id sequences.

        If gen_inp: generate the decoder INPUT sequence (drop the trailing
        <END> token).  Otherwise: generate the decoder OUTPUT sequence
        (drop the leading <BEG> token).

        tok_seq -- list (len B) of token-id lists; each is assumed to be
                   <BEG> ... <END>, so the result has length len(seq) - 1.

        Returns (one_hot_var, lengths):
          one_hot_var -- Variable of shape [B, max_len, max_tok_num],
                         float32 one-hot rows (all-zero rows pad the tail).
          lengths     -- np.array of the per-example sequence lengths.
        """
        B = len(tok_seq)
        ret_len = np.array([len(one_tok_seq) - 1 for one_tok_seq in tok_seq])
        max_len = max(ret_len)
        ret_array = np.zeros((B, max_len, self.max_tok_num), dtype=np.float32)
        for b, one_tok_seq in enumerate(tok_seq):
            out_one_tok_seq = one_tok_seq[:-1] if gen_inp else one_tok_seq[1:]
            for t, tok_id in enumerate(out_one_tok_seq):
                ret_array[b, t, tok_id] = 1

        ret_inp = torch.from_numpy(ret_array)
        if self.gpu:
            ret_inp = ret_inp.cuda()
        ret_inp_var = Variable(ret_inp)  # [B, max_len, max_tok_num]

        return ret_inp_var, ret_len

    def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len,
            col_num, gt_where, gt_cond, reinforce):
        """Score condition-token sequences over the input positions.

        x_emb_var -- [B, max_x_len, N_word] question embeddings.
        x_len     -- per-example question lengths.
        col_*     -- column inputs (accepted for interface parity; unused
                     here -- this module only attends over the question).
        gt_where  -- ground-truth token-id sequences for teacher forcing,
                     or None to decode greedily / by sampling.
        reinforce -- when decoding, sample tokens (and return the choices
                     for REINFORCE) instead of taking the argmax.

        Returns cond_score [B, dec_len, max_x_len]; with reinforce also the
        list of sampled token Variables, one per decode step.
        """
        max_x_len = max(x_len)
        B = len(x_len)

        h_enc, hidden = run_lstm(self.cond_lstm, x_emb_var, x_len)
        # Fold the bidirectional encoder's (2*N_depth, B, N_h//2) states
        # into (N_depth, B, N_h) to seed the unidirectional decoder.
        decoder_hidden = tuple(torch.cat((hid[:2], hid[2:]), dim=2)
                for hid in hidden)

        if gt_where is not None:
            # Teacher forcing: feed the whole ground-truth input sequence
            # through the decoder in one shot.
            gt_tok_seq, gt_tok_len = self.gen_gt_batch(gt_where, gen_inp=True)
            g_s, _ = run_lstm(self.cond_decoder,
                    gt_tok_seq, gt_tok_len, decoder_hidden)

            h_enc_expand = h_enc.unsqueeze(1)   # [B, 1, max_x_len, N_h]
            g_s_expand = g_s.unsqueeze(2)       # [B, dec_len, 1, N_h]
            cond_score = self.cond_out(self.cond_out_h(h_enc_expand) +
                    self.cond_out_g(g_s_expand)).squeeze()
            # Mask scores pointing past each example's real length.
            for idx, num in enumerate(x_len):
                if num < max_x_len:
                    cond_score[idx, :, num:] = -100
        else:
            # Free decoding: one token per step until every example has
            # emitted <END> or a hard cap of 100 steps is reached.
            h_enc_expand = h_enc.unsqueeze(1)
            scores = []
            choices = []
            done_set = set()

            t = 0
            init_inp = np.zeros((B, 1, self.max_tok_num), dtype=np.float32)
            init_inp[:, 0, 7] = 1  # index 7 is the <BEG> token
            if self.gpu:
                cur_inp = Variable(torch.from_numpy(init_inp).cuda())
            else:
                cur_inp = Variable(torch.from_numpy(init_inp))
            cur_h = decoder_hidden
            while len(done_set) < B and t < 100:
                g_s, cur_h = self.cond_decoder(cur_inp, cur_h)
                g_s_expand = g_s.unsqueeze(2)

                cur_cond_score = self.cond_out(self.cond_out_h(h_enc_expand) +
                        self.cond_out_g(g_s_expand)).squeeze()
                for b, num in enumerate(x_len):
                    if num < max_x_len:
                        cur_cond_score[b, num:] = -100
                scores.append(cur_cond_score)

                if not reinforce:
                    # Greedy: pick the highest-scoring input position.
                    _, ans_tok_var = cur_cond_score.view(B, max_x_len).max(1)
                    ans_tok_var = ans_tok_var.unsqueeze(1)
                else:
                    # REINFORCE: sample a position from the softmax and
                    # remember the choice for the policy-gradient update.
                    ans_tok_var = self.softmax(cur_cond_score).multinomial()
                    choices.append(ans_tok_var)
                ans_tok = ans_tok_var.data.cpu()
                # Re-encode the chosen token as a one-hot decoder input;
                # build once and move to GPU only if needed.
                cur_inp = Variable(torch.zeros(
                    B, self.max_tok_num).scatter_(1, ans_tok, 1))
                if self.gpu:
                    cur_inp = cur_inp.cuda()
                cur_inp = cur_inp.unsqueeze(1)

                # NOTE(review): ans_tok.squeeze() is 0-dim when B == 1, so
                # this loop assumes batch size > 1 -- confirm with callers.
                for idx, tok in enumerate(ans_tok.squeeze()):
                    if tok == 1:  # index 1 is the <END> token
                        done_set.add(idx)

                t += 1

            cond_score = torch.stack(scores, 1)

        if reinforce:
            return cond_score, choices
        else:
            return cond_score