In [1]:
import torch
from torch.autograd._functions import tensor
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class fasttext(nn.Module):
    def __init__(self, vocab_size, twoGrams_size, threeGrams_size, embed_size, hidden_size, output_size,
                 embedding_pretrained=None):
        super(fasttext, self).__init__()
        if embedding_pretrained is None:
            self.embedding_word = nn.Embedding(vocab_size, embed_size)
        else:
            self.embedding_word = nn.Embedding.from_pretrained(embedding_pretrained, freeze=False)

        self.embedding_2gram = nn.Embedding(twoGrams_size, embed_size)
        self.embedding_3gram = nn.Embedding(threeGrams_size, embed_size)
        self.dropout = nn.Dropout(p=0.5)

        self.hidden = nn.Linear(embed_size, hidden_size)

        self.output = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        e_word = self.embedding_word(x[0])
        e_2gram = self.embedding_2gram(x[1])
        e_3gram = self.embedding_3gram(x[2])
        e_cat = torch.cat((e_word, e_2gram, e_3gram), dim=1)
        e_avg = e_cat.mean(dim=1)
        h = self.hidden(self.dropout(e_avg))
        o = F.softmax(self.output(h))
        return o, {
            "embedding_word": e_word,
            "embedding_2gram": e_2gram,
            "embedding_3gram": e_3gram,
            "e_cat": e_cat,
            "e_avg": e_avg,
            "hidden": h
        }

In [3]:
vocab_size = 10
twoGrams_size = 20
threeGrams_size = 30
embed_size = 128
hidden_size = 256
output_size = 16
ft = fasttext(vocab_size, twoGrams_size, threeGrams_size, embed_size, hidden_size, output_size)

In [4]:
x_0 = torch.LongTensor([[1, 2, 3, 3, 5]])  #batch_size = 1, seq_len = 5
x_1 = torch.LongTensor([[1, 2, 3, 4]])  #batch_size =1, seq_len = 4
x_2 = torch.LongTensor([[1, 2, 3]])  #batch_size=1, seq_len=3
x = (x_0, x_1, x_2)
output, tmp = ft(x)
print("embedding_word:", tmp["embedding_word"].size())
print("embedding_2gram:", tmp["embedding_2gram"].size())
print("embedding_3gram:", tmp["embedding_3gram"].size())
print("e_cat:", tmp["e_cat"].size())
print("e_avg:", tmp["e_avg"].size())
print("hidden:", tmp["hidden"].size())
print("output", output.size())

embedding_word: torch.Size([1, 5, 128])
embedding_2gram: torch.Size([1, 4, 128])
embedding_3gram: torch.Size([1, 3, 128])
e_cat: torch.Size([1, 12, 128])
e_avg: torch.Size([1, 128])
hidden: torch.Size([1, 256])
output torch.Size([1, 16])


  o = F.softmax(self.output(h))
