In [3]:
import torch

In [4]:
class RNNClassifier(torch.nn.Module):
    """Sequence classifier: embedding -> RNN -> mean over time steps -> linear.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Size of each embedding vector (also the RNN input size).
        hidden_dim: Size of the RNN hidden state.
        num_class: Number of logits produced per sequence.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_class):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embedding = torch.nn.Embedding(vocab_size, embed_dim)
        # batch_first=True: the RNN consumes and emits (batch, seq, feature).
        self.rnn = torch.nn.RNN(embed_dim, hidden_dim, batch_first=True)
        self.fc = torch.nn.Linear(hidden_dim, num_class)

    def forward(self, x):
        """Return one row of ``num_class`` logits per input sequence.

        Args:
            x: Integer token indices. A batched ``(batch, seq_len)`` tensor is
                expected, because the average below is taken over ``dim=1``.

        Returns:
            Tensor of shape ``(batch, num_class)``.
        """
        embedded = self.embedding(x)
        # The final hidden state is not needed here: the classifier averages
        # the per-step outputs over the sequence dimension instead.
        outputs, _ = self.rnn(embedded)
        return self.fc(outputs.mean(dim=1))

In [5]:
# Instantiate the RNN classifier; keyword arguments make each size explicit.
model = RNNClassifier(vocab_size=5000, embed_dim=64, hidden_dim=256, num_class=4)

In [7]:
# Feed one random (batch=1, seq=100, feature=64) tensor straight into the
# recurrent layer and display the pair it returns.
model.rnn(torch.randn((1, 100, 64)))

(tensor([[[ 0.0052, -0.4289, -0.0501,  ..., -0.3901,  0.0942,  0.2847],
          [-0.5030,  0.3298, -0.2329,  ...,  0.0495,  0.1546,  0.0312],
          [-0.5493,  0.0651,  0.6114,  ...,  0.4245,  0.4552,  0.6055],
          ...,
          [-0.1120, -0.1923, -0.0251,  ...,  0.0572,  0.0594, -0.2676],
          [ 0.6103,  0.2674, -0.4204,  ...,  0.0762, -0.4201,  0.2749],
          [-0.2596,  0.1526,  0.2913,  ...,  0.2150, -0.0217,  0.4369]]],
        grad_fn=<TransposeBackward1>),
 tensor([[[-0.2596,  0.1526,  0.2913, -0.0447, -0.0676, -0.2721, -0.1365,
           -0.5334,  0.3480,  0.2448,  0.3164,  0.5514,  0.5669, -0.2088,
           -0.5465, -0.1527,  0.3243, -0.6180,  0.3710,  0.5349,  0.1681,
           -0.6927, -0.2804,  0.0221, -0.4349,  0.2242,  0.2617,  0.5988,
           -0.1834,  0.0803, -0.0392, -0.3802, -0.0813, -0.0890, -0.4505,
            0.4620,  0.2371, -0.2049,  0.1971, -0.1879, -0.0042, -0.0465,
           -0.4820, -0.2911,  0.5664,  0.2060,  0.5074, -0.2285,  0.

In [9]:
class LSTMClassifier(torch.nn.Module):
    """Sequence classifier: embedding -> LSTM -> last hidden state -> linear.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Size of each embedding vector (also the LSTM input size).
        hidden_dim: Size of the LSTM hidden state.
        num_class: Number of logits produced per sequence.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_class):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embedding = torch.nn.Embedding(vocab_size, embed_dim)
        # Re-draw the embedding table from a normal distribution with mean -0.5
        # and unit standard deviation. Using the in-place initializer avoids
        # assigning to ``.weight.data``, which bypasses autograd bookkeeping.
        # NOTE(review): the non-zero mean is kept as originally written; if a
        # zero-centred init was intended (e.g. uniform ``rand - 0.5``), confirm
        # and adjust.
        torch.nn.init.normal_(self.embedding.weight, mean=-0.5, std=1.0)
        # batch_first=True: the LSTM consumes and emits (batch, seq, feature).
        self.rnn = torch.nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = torch.nn.Linear(hidden_dim, num_class)

    def forward(self, x):
        """Return one row of ``num_class`` logits per input sequence.

        Args:
            x: Integer token indices, expected as a batched
                ``(batch, seq_len)`` tensor.

        Returns:
            Tensor of shape ``(batch, num_class)``.
        """
        embedded = self.embedding(x)
        # Only the final hidden state is used; the per-step outputs and the
        # cell state are discarded.
        _, (h_n, _) = self.rnn(embedded)
        # h_n[-1] selects the last entry along the first axis of h_n,
        # giving a (batch, hidden_dim) tensor.
        return self.fc(h_n[-1])

In [10]:
# Rebind `model` to the LSTM variant and display its module summary.
model = LSTMClassifier(vocab_size=5000, embed_dim=64, hidden_dim=256, num_class=4)
model

LSTMClassifier(
  (embedding): Embedding(5000, 64)
  (rnn): LSTM(64, 256, batch_first=True)
  (fc): Linear(in_features=256, out_features=4, bias=True)
)

In [12]:
# Run the LSTM layer alone on one random (1, 100, 64) input and compare the
# shapes of the per-step outputs with the final hidden and cell states.
lstm_result = model.rnn(torch.randn((1, 100, 64)))
output, (h_n, c_n) = lstm_result
(output.shape, h_n.shape, c_n.shape)

(torch.Size([1, 100, 256]), torch.Size([1, 1, 256]), torch.Size([1, 1, 256]))

In [13]:
# Example batch of token ids: 5 sequences of 100 indices drawn from [0, 5000).
torch.randint(low=0, high=5000, size=(5, 100))

tensor([[3460,  697,   98, 4403, 3865, 2044, 1700, 4818, 4937, 1590, 3791, 3876,
         2028,  416, 3175, 1616, 2536, 2316, 4504,  441, 3030, 4312, 4846, 4235,
         4805, 1129,  227, 2332,  743, 3270, 4908, 3780, 4767, 2682,  998, 2784,
         2982, 4757, 3418, 4648, 1404, 3081, 3416, 2440, 2625, 4127, 4999, 3656,
         4650,  965, 4371, 4032, 3184, 4707, 2107, 3681, 1013, 1516, 1470, 1067,
         1492, 3633, 2820, 4604,  406, 3093, 4339, 2097, 3427, 3852, 3549, 4581,
         1166,  684, 3147, 1087,   45, 1145, 1068,  925, 4456, 4280, 3419, 2425,
         2344, 2373,  922, 3334, 3357,  418, 3324,    2,  715, 4245,  119,  452,
         2442,  793, 2672, 2593],
        [1958, 4833, 1599, 1278, 2572, 2781, 4132, 2200, 2897, 1489, 4735,  372,
         4765, 3138, 4661,  610,  787,  828, 4662, 3194, 3157, 2165, 4791, 2725,
         1551, 4964, 4342, 1184, 3207, 1476, 2868,  876, 2048,  133, 3764, 2938,
          244,  251,  664, 3326, 1075, 2590,  687, 4669, 2262,  641, 4496, 

In [14]:
# Full forward pass on a random batch of token ids; only the logits' shape
# is displayed.
model(torch.randint(low=0, high=5000, size=(5, 100))).shape

torch.Size([5, 4])

In [15]:
# Shape of the last entry of h_n, i.e. the tensor the classifier feeds to `fc`.
h_n[-1].size()

torch.Size([1, 256])