-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
34 lines (31 loc) · 1.38 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
### Build the network
import math
import mindspore as ms
import mindspore.nn as nn
import mindspore.numpy as mnp
import mindspore.ops as ops
from mindspore.common.initializer import Uniform, HeUniform
class RNN(nn.Cell):
    """LSTM classifier that maps pre-embedded sequences to sigmoid scores.

    No embedding lookup happens inside this cell: ``construct`` feeds its
    input straight into the LSTM, so callers must pass features whose last
    dimension is ``embedding_dim``.

    Args:
        embedding_dim: Input feature size given to the LSTM.
        hidden_dim: Hidden size of the LSTM (per direction).
        output_dim: Number of output units of the final dense layer.
        n_layers: Number of stacked LSTM layers.
        bidirectional: Whether the LSTM runs in both directions. This also
            decides how many final hidden states are fed to the dense layer.
        dropout: Dropout setting forwarded to the LSTM; ``1 - dropout`` is
            forwarded to the standalone ``nn.Dropout`` layer.
    """

    def __init__(self, embedding_dim, hidden_dim, output_dim, n_layers,
                 bidirectional, dropout):
        super().__init__()
        # NOTE: the embedding layer is intentionally not built here; inputs
        # are expected to be embedded before they reach this cell.
        self.bidirectional = bidirectional
        self.rnn = nn.LSTM(embedding_dim,
                           hidden_dim,
                           num_layers=n_layers,
                           bidirectional=bidirectional,
                           dropout=dropout,
                           batch_first=True)

        # The dense layer consumes one final hidden state per direction, so
        # its input width must follow `bidirectional` instead of always
        # assuming two directions.
        num_directions = 2 if bidirectional else 1
        fc_in_features = hidden_dim * num_directions

        weight_init = HeUniform(math.sqrt(5))
        bias_init = Uniform(1 / math.sqrt(fc_in_features))
        self.fc = nn.Dense(fc_in_features, output_dim,
                           weight_init=weight_init, bias_init=bias_init)
        # NOTE(review): nn.Dropout receives `1 - dropout` positionally;
        # presumably that argument is a keep probability in the targeted
        # MindSpore version - TODO confirm against the installed release.
        self.dropout = nn.Dropout(1 - dropout)
        self.sigmoid = ops.Sigmoid()

    def construct(self, inputs):
        """Run the forward pass and return sigmoid-activated outputs.

        Args:
            inputs: Already-embedded input tensor; assumed to be
                (batch, seq_len, embedding_dim) because the LSTM is built
                with ``batch_first=True`` - TODO confirm against callers.

        Returns:
            The dense-layer output after ``squeeze()`` and sigmoid. Because
            ``squeeze()`` is called without an axis, any size-1 dimension
            (presumably including a batch of one) is dropped.
        """
        embedded = self.dropout(inputs)
        _, (hidden, _) = self.rnn(embedded)
        if self.bidirectional:
            # The last two entries are taken as the final layer's two
            # directional states and joined along the feature axis.
            last_hidden = mnp.concatenate(
                (hidden[-2, :, :], hidden[-1, :, :]), axis=1)
        else:
            # Single direction: only the final entry belongs to the last
            # layer, so using hidden[-2] would index another layer (or fail
            # when there is just one layer).
            last_hidden = hidden[-1, :, :]
        last_hidden = self.dropout(last_hidden)
        output = self.fc(last_hidden)
        output = output.squeeze()
        return self.sigmoid(output)