-
Notifications
You must be signed in to change notification settings - Fork 54
/
multi_channel_CNN.py
83 lines (62 loc) · 2.88 KB
/
multi_channel_CNN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# Version python3.6
# -*- coding: utf-8 -*-
# @Time : 2018/10/25 8:38 PM
# @Author : zenRRan
# @Email : zenrran@qq.com
# @File : multi_channel_CNN.py
# @Software: PyCharm Community Edition
import torch
import torch.nn as nn
import torch.nn.functional as F
import utils.Embedding as Embedding
import random
class Multi_Channel_CNN(nn.Module):
    """Two-channel text CNN for sentence classification.

    Channel 0 is a "static" embedding table (loaded from pretrained vectors
    when ``opts.pre_embed_path`` is set, otherwise uniform-initialized);
    channel 1 is a freely trained table. The two are stacked as a 2-channel
    "image" and convolved with one Conv2d per kernel size, max-pooled over
    time, concatenated, and fed through a 2-layer classifier head.

    forward(input): ``input`` is a LongTensor of word ids, shape
    (batch, seq_len); returns unnormalized label scores, shape
    (batch, label_num).
    """

    def __init__(self, opts, vocab, label_vocab):
        super(Multi_Channel_CNN, self).__init__()

        # Seed every RNG we may touch so weight init is reproducible.
        random.seed(opts.seed)
        torch.manual_seed(opts.seed)
        torch.cuda.manual_seed(opts.seed)

        self.embed_dim = opts.embed_size
        self.word_num = vocab.m_size
        self.pre_embed_path = opts.pre_embed_path
        self.string2id = vocab.string2id
        self.embed_uniform_init = opts.embed_uniform_init
        self.stride = opts.stride
        self.kernel_size = opts.kernel_size          # iterable of window heights K
        self.kernel_num = opts.kernel_num            # feature maps per kernel size
        self.label_num = label_vocab.m_size

        # Two embedding tables of identical shape: one static, one tuned.
        self.embeddings = nn.Embedding(self.word_num, self.embed_dim)
        self.embeddings_static = nn.Embedding(self.word_num, self.embed_dim)

        if opts.pre_embed_path != '':
            embedding = Embedding.load_predtrained_emb_zero(self.pre_embed_path, self.string2id)
            self.embeddings_static.weight.data.copy_(embedding)
        else:
            nn.init.uniform_(self.embeddings_static.weight.data, -self.embed_uniform_init, self.embed_uniform_init)
        nn.init.uniform_(self.embeddings.weight.data, -self.embed_uniform_init, self.embed_uniform_init)

        # One conv per kernel size over the 2-channel embedding stack.
        # BUG FIX: out_channels was self.embed_dim, but the classifier head
        # below is sized len(kernel_size) * kernel_num, so the model crashed
        # with a shape mismatch whenever kernel_num != embed_dim. The conv
        # must emit kernel_num feature maps.
        self.convs = nn.ModuleList(
            [nn.Conv2d(2, self.kernel_num, (K, self.embed_dim), stride=self.stride, padding=(K // 2, 0))
             for K in self.kernel_size])

        in_fea = len(self.kernel_size) * self.kernel_num
        self.linear1 = nn.Linear(in_fea, in_fea // 2)
        self.linear2 = nn.Linear(in_fea // 2, self.label_num)

        # Assign the Dropout modules directly (the original first stored the
        # raw probabilities in these attributes, then overwrote them).
        self.embed_dropout = nn.Dropout(opts.embed_dropout)
        self.fc_dropout = nn.Dropout(opts.fc_dropout)

    def forward(self, input):
        """Score ``input`` word-id batch; returns (batch, label_num) logits."""
        static_embed = self.embeddings_static(input)   # (batch, seq, embed_dim)
        embed = self.embeddings(input)                 # (batch, seq, embed_dim)
        x = torch.stack([static_embed, embed], 1)      # (batch, 2, seq, embed_dim)
        out = self.embed_dropout(x)

        # Convolve each kernel size, then max-pool each feature map over the
        # time dimension so every conv contributes a (batch, kernel_num) slab.
        pooled = []
        for conv in self.convs:
            feat = F.relu(conv(out)).squeeze(3)        # (batch, kernel_num, ~seq)
            pooled.append(F.max_pool1d(feat, kernel_size=feat.size(2)).squeeze(2))

        out = torch.cat(pooled, 1)                     # (batch, len(kernel_size)*kernel_num)
        out = self.fc_dropout(out)
        out = self.linear1(out)
        out = self.linear2(F.relu(out))
        return out