In [21]:
from types import SimpleNamespace

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


class CNN_Text(nn.Module):
    """Convolutional sentence classifier.

    Token ids are embedded, then fed through one 2-D convolution per entry of
    ``args.kernel_sizes``. Each convolution spans the full embedding width and
    K time steps. Every feature map is max-pooled over time, the pooled
    features are concatenated, and the result passes through dropout and a
    linear layer to produce class logits.

    Args:
        args: attribute-style config object providing ``embed_num``
            (vocabulary size), ``embed_dim`` (embedding width), ``class_num``
            (number of output classes), ``kernel_num`` (filters per kernel
            size), ``kernel_sizes`` (iterable of convolution heights),
            ``dropout`` (drop probability) and ``static`` (when truthy, no
            gradient flows back into the embedding table).
    """

    def __init__(self, args):
        super(CNN_Text, self).__init__()
        self.args = args

        V = args.embed_num            # vocabulary size
        D = args.embed_dim            # embedding width
        C = args.class_num            # number of output classes
        Ci = 1                        # one input channel: the embedded sentence
        Co = args.kernel_num          # filters per kernel size
        Ks = list(args.kernel_sizes)  # materialise so len() works for any iterable

        self.embed = nn.Embedding(V, D)

        # Kernel (K, D) covers K consecutive tokens across the whole embedding,
        # so each conv output has width 1 along the embedding axis.
        self.convs1 = nn.ModuleList([nn.Conv2d(Ci, Co, (K, D)) for K in Ks])

        self.dropout = nn.Dropout(args.dropout)

        self.fc1 = nn.Linear(len(Ks) * Co, C)

    def forward(self, x):
        """Compute class logits.

        Args:
            x: integer tensor of token ids, shape (N, W).

        Returns:
            Unnormalised logits of shape (N, C).
        """
        x = self.embed(x)  # (N, W, D)

        if self.args.static:
            # Wrapping in `Variable` does not stop gradients on a tensor;
            # detach() actually keeps the embedding table out of backprop.
            x = x.detach()

        x = x.unsqueeze(1)  # (N, Ci, W, D)

        # A convolution of height K needs at least K time steps. Zero-pad the
        # time axis of short inputs so they do not raise; inputs that are
        # already long enough are left untouched.
        max_k = max((conv.kernel_size[0] for conv in self.convs1), default=0)
        if x.size(2) < max_k:
            x = F.pad(x, (0, 0, 0, max_k - x.size(2)))

        x = [F.relu(conv(x)).squeeze(3) for conv in self.convs1]  # [(N, Co, W-K+1), ...]*len(Ks)

        # Max over the whole time axis: one value per filter.
        x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x]  # [(N, Co), ...]*len(Ks)

        x = torch.cat(x, 1)  # (N, len(Ks)*Co)

        x = self.dropout(x)

        logit = self.fc1(x)  # (N, C)

        return logit

In [20]:
# Model hyper-parameters. A SimpleNamespace gives attribute access
# (args.embed_dim, ...) without defining a throwaway class whose name would
# then be shadowed by its own instance.
args = SimpleNamespace(
    embed_num=2000,                # vocabulary size (rows of the embedding table)
    embed_dim=300,                 # embedding width
    class_num=2,                   # number of output classes
    kernel_num=2,                  # convolution filters per kernel size
    kernel_sizes=[2, 3, 4, 5, 6],  # convolution heights, i.e. tokens spanned
    dropout=0.5,                   # dropout probability before the classifier
    static=True,                   # read by CNN_Text as args.static; intended to keep embeddings fixed
)

# Seed before building the model so weight initialisation is reproducible.
torch.manual_seed(0)
cnn = CNN_Text(args)

In [22]:
# Display the module tree (layers and their sizes) of the model built above.
cnn

CNN_Text(
  (embed): Embedding(2000, 300)
  (convs1): ModuleList(
    (0): Conv2d(1, 2, kernel_size=(2, 300), stride=(1, 1))
    (1): Conv2d(1, 2, kernel_size=(3, 300), stride=(1, 1))
    (2): Conv2d(1, 2, kernel_size=(4, 300), stride=(1, 1))
    (3): Conv2d(1, 2, kernel_size=(5, 300), stride=(1, 1))
    (4): Conv2d(1, 2, kernel_size=(6, 300), stride=(1, 1))
  )
  (dropout): Dropout(p=0.5)
  (fc1): Linear(in_features=10, out_features=2, bias=True)
)