In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# LSTM

<img src="fig/LSTM.png">

Math formulas:  

$
\begin{array}{ll}
i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\
f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\
g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{(t-1)} + b_{hg}) \\
o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\
c_t = f_t * c_{(t-1)} + i_t * g_t \\
h_t = o_t * \tanh(c_t)
\end{array}
$


By default (`batch_first=False`), LSTM expects its input to be a 3D tensor with the following layout:

|dim  |dim content|
|:---:|:---------:|
|1st  |sequence (along words in a sentence)|
|2nd  |mini-batch                          |
|3rd  |elements (embedding vector)         |

In [2]:
# Single-layer LSTM: 3 input features per time step, hidden/output size 3.
lstm = nn.LSTM(input_size=3, hidden_size=3)
# Show the learnable tensors. The four gates (i, f, g, o) are stacked along the
# first axis, which is why the weight matrices below are 12 x 3 (4 * hidden x 3).
list(lstm.parameters())

[Parameter containing:
 tensor([[ 0.2303,  0.2916, -0.5163],
         [-0.5510, -0.3449, -0.2393],
         [-0.4055, -0.5512, -0.5466],
         [ 0.3284,  0.2614, -0.3736],
         [ 0.4501,  0.0070, -0.2560],
         [-0.0455,  0.1008, -0.3602],
         [ 0.5728, -0.5570, -0.2382],
         [-0.1546, -0.1515, -0.1301],
         [-0.1647,  0.0114, -0.3590],
         [-0.4640, -0.0678,  0.5465],
         [ 0.4682, -0.2864, -0.4383],
         [ 0.0923,  0.3207,  0.0377]]), Parameter containing:
 tensor([[ 0.1018,  0.4423, -0.0289],
         [-0.4877,  0.2096,  0.1797],
         [ 0.5208, -0.2490,  0.4127],
         [ 0.5612,  0.3022,  0.5166],
         [ 0.3660,  0.5360,  0.1614],
         [-0.0994, -0.0956,  0.3715],
         [ 0.5119, -0.0050,  0.4271],
         [-0.3460,  0.3529,  0.0926],
         [-0.2276, -0.3582, -0.0635],
         [-0.4995,  0.0098,  0.0274],
         [-0.4142, -0.5609, -0.3524],
         [ 0.4259, -0.1725,  0.2226]]), Parameter containing:
 tensor([-0.3728,

### Step seq one element at a time

In [3]:
torch.manual_seed(0)
# Input sequence: 5 time steps, batch of 1, 3 features per step.
ins = torch.randn(5, 1, 3)

# Random initial state; this is the pair (h_0, c_0) from the picture above,
# each tensor shaped (1, 1, 3).
hidden = (torch.randn(1, 1, 3),
          torch.randn(1, 1, 3))

# Feed the sequence to the LSTM one time step per call, carrying the state along.
for step in range(ins.size(0)):
    # ins[step] is (1, 3); add a leading axis so the LSTM sees a length-1 sequence.
    out, hidden = lstm(ins[step].unsqueeze(0), hidden)
    # For a single step, `out` equals the new h_t, i.e. hidden[0];
    # `hidden` is the pair (h_t, c_t) from the picture above.
    print(out)
    print(hidden)

tensor([[[-0.0182, -0.5513,  0.4001]]])
(tensor([[[-0.0182, -0.5513,  0.4001]]]), tensor([[[-0.1637, -0.7133,  0.9240]]]))
tensor([[[ 0.0512,  0.1123,  0.3908]]])
(tensor([[[ 0.0512,  0.1123,  0.3908]]]), tensor([[[ 0.2024,  0.1343,  1.2729]]]))
tensor([[[-0.0490,  0.2408,  0.4710]]])
(tensor([[[-0.0490,  0.2408,  0.4710]]]), tensor([[[-0.1568,  0.4381,  1.0895]]]))
tensor([[[-0.1252,  0.3123,  0.3314]]])
(tensor([[[-0.1252,  0.3123,  0.3314]]]), tensor([[[-0.2275,  0.7763,  0.9236]]]))
tensor([[[-0.2550,  0.3341,  0.4322]]])
(tensor([[[-0.2550,  0.3341,  0.4322]]]), tensor([[[-0.6994,  0.9446,  1.0604]]]))


### Step seq all at once

In [4]:
torch.manual_seed(0)
# Same seed and same order of randn calls as in the step-by-step cell, so `ins`
# and the initial state below are identical to the ones used there.
# seq-length=5, batch-size=1, emb-dim=3
ins = torch.randn(5, 1, 3)

# initialize the hidden state
# it is (h_0, c_0) in picture above
hidden = (torch.randn(1, 1, 3),
          torch.randn(1, 1, 3))

# Process the entire sequence in a single call.
# `outs` stacks the per-step outputs (the printed values match the ones from
# the step-by-step loop); `hidden` is the state (h_t, c_t) after the last step.
outs, hidden = lstm(ins, hidden)
print(outs)
print(hidden)

tensor([[[-0.0182, -0.5513,  0.4001]],

        [[ 0.0512,  0.1123,  0.3908]],

        [[-0.0490,  0.2408,  0.4710]],

        [[-0.1252,  0.3123,  0.3314]],

        [[-0.2550,  0.3341,  0.4322]]])
(tensor([[[-0.2550,  0.3341,  0.4322]]]), tensor([[[-0.6994,  0.9446,  1.0604]]]))


In [5]:
# Feed the output sequence back in as a new input sequence, starting from the
# final state returned by the previous call.
# NOTE: this calls the SAME `lstm` module again, so both passes share one set of
# weights (it only works because input size == hidden size == 3). It merely
# illustrates how outputs and states can be chained; a stacked LSTM with
# separate weights per layer is what `num_layers=2` builds in the multi-layer
# section below.
outs, hidden = lstm(outs, hidden)
print(outs)
print(hidden)

tensor([[[-0.2576,  0.3855,  0.2870]],

        [[-0.2681,  0.3511,  0.2528]],

        [[-0.2961,  0.3242,  0.2180]],

        [[-0.3105,  0.3283,  0.2119]],

        [[-0.3320,  0.3066,  0.2043]]])
(tensor([[[-0.3320,  0.3066,  0.2043]]]), tensor([[[-0.6347,  0.8573,  0.4443]]]))


### Batched input and output of LSTM layer

In [6]:
# Batched input: 5 time steps, 16 sequences per mini-batch, 3 features per step.
ins = torch.randn(5, 16, 3)

# Zero initial state (h_0, c_0) with one row per sequence in the batch; each
# tensor is shaped (num_layers * num_directions, batch, hidden) = (1, 16, 3).
state_shape = (1, 16, 3)
hidden = (torch.zeros(state_shape), torch.zeros(state_shape))

outs, hidden = lstm(ins, hidden)
# Output keeps the (seq, batch, hidden) layout; the state keeps its own shape.
print(outs.size())
h_n, c_n = hidden
print(h_n.size(), c_n.size())

torch.Size([5, 16, 3])
torch.Size([1, 16, 3]) torch.Size([1, 16, 3])


### Batch-first

In [7]:
# Same 3 -> 3 LSTM, but input/output tensors are laid out batch-major.
lstm = nn.LSTM(input_size=3, hidden_size=3, batch_first=True)

# Batched input: 16 sequences, 5 time steps each, 3 features per step.
ins = torch.randn(16, 5, 3)

# Zero initial state (h_0, c_0), one row per sequence. Note that batch_first
# only affects the input/output layout: the state is still (1, batch, hidden).
batch_size = ins.size(0)
hidden = tuple(torch.zeros(1, batch_size, 3) for _ in range(2))

outs, hidden = lstm(ins, hidden)
print(outs.size())
print(*(state.size() for state in hidden))

torch.Size([16, 5, 3])
torch.Size([1, 16, 3]) torch.Size([1, 16, 3])


### Packed (masked) input and output of LSTM layer

In [8]:
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# 3 -> 3 LSTM working on batch-major tensors.
lstm = nn.LSTM(input_size=3, hidden_size=3, batch_first=True)

# Padded batch: 4 sequences, padded length 5, 3 features per step.
ins = torch.randn(4, 5, 3)
print(ins)

# Number of valid steps per sequence, listed in decreasing order as
# pack_padded_sequence expects here: the first two sequences use all 5 steps,
# the third has 4 valid steps and the fourth only 2; the rest is padding.
seq_lengths = [5, 5, 4, 2]
packed_ins = pack_padded_sequence(ins, seq_lengths, batch_first=True)
print(packed_ins)

tensor([[[ 0.2712, -1.0442,  1.8786],
         [ 1.8543,  0.7049,  0.0305],
         [-0.8542,  0.5388,  1.1896],
         [-0.9635, -0.3666,  1.3375],
         [-2.0546,  0.5259,  0.5995]],

        [[-0.4078,  0.7882, -1.2330],
         [ 0.4544, -1.4366, -0.0984],
         [ 0.4855,  0.7076,  0.0431],
         [ 0.5976, -0.2389, -0.3460],
         [ 1.7701, -1.0646, -0.2297]],

        [[-1.2564,  0.5570, -0.2271],
         [ 0.6575, -0.1183,  0.6762],
         [-1.8095,  0.6926,  1.1982],
         [ 1.3167,  1.0615, -0.6378],
         [ 0.4384,  0.9643,  0.5926]],

        [[-0.1689, -2.2863,  0.2011],
         [-0.3304, -1.5413, -1.2677],
         [-0.6478,  3.4754,  0.6533],
         [ 2.9936,  0.0164, -0.0544],
         [-0.9569,  0.0034, -1.0419]]])
PackedSequence(data=tensor([[ 0.2712, -1.0442,  1.8786],
        [-0.4078,  0.7882, -1.2330],
        [-1.2564,  0.5570, -0.2271],
        [-0.1689, -2.2863,  0.2011],
        [ 1.8543,  0.7049,  0.0305],
        [ 0.4544, -1.4366, 

In [9]:
# Zero initial state (h_0, c_0): one state row per sequence in the batch (4).
hidden = (torch.zeros(1, 4, 3),
          torch.zeros(1, 4, 3))

# A PackedSequence goes in and a PackedSequence comes out. The output data has
# 16 rows, one per valid time step (5 + 5 + 4 + 2); padding produces no rows.
packed_outs, hidden = lstm(packed_ins, hidden)
packed_outs

PackedSequence(data=tensor([[-0.1482, -0.0951, -0.1542],
        [ 0.1282, -0.0283,  0.0724],
        [ 0.0993, -0.0539,  0.0180],
        [-0.0716, -0.2412, -0.3592],
        [-0.2117, -0.0310, -0.1532],
        [ 0.0232, -0.1885, -0.3005],
        [ 0.0083, -0.1049, -0.1695],
        [ 0.0072, -0.2328, -0.5603],
        [-0.0351, -0.0481, -0.1594],
        [-0.0017, -0.1050, -0.1681],
        [ 0.0548, -0.1002, -0.0869],
        [-0.0363, -0.1343, -0.1835],
        [-0.0372, -0.1242, -0.3132],
        [ 0.0851, -0.0450, -0.1015],
        [ 0.0588, -0.1205, -0.0732],
        [-0.1482, -0.1685, -0.3908]]), batch_sizes=tensor([ 4,  4,  3,  3,  2]))

In [10]:
# Unpack into a padded (batch, seq, hidden) tensor. Steps beyond a sequence's
# valid length come back as zeros, and `lengths` returns the valid lengths.
outs, lengths = pad_packed_sequence(packed_outs, batch_first=True)
print(outs)
print(lengths)

tensor([[[-0.1482, -0.0951, -0.1542],
         [-0.2117, -0.0310, -0.1532],
         [-0.0351, -0.0481, -0.1594],
         [-0.0363, -0.1343, -0.1835],
         [ 0.0588, -0.1205, -0.0732]],

        [[ 0.1282, -0.0283,  0.0724],
         [ 0.0232, -0.1885, -0.3005],
         [-0.0017, -0.1050, -0.1681],
         [-0.0372, -0.1242, -0.3132],
         [-0.1482, -0.1685, -0.3908]],

        [[ 0.0993, -0.0539,  0.0180],
         [ 0.0083, -0.1049, -0.1695],
         [ 0.0548, -0.1002, -0.0869],
         [ 0.0851, -0.0450, -0.1015],
         [ 0.0000,  0.0000,  0.0000]],

        [[-0.0716, -0.2412, -0.3592],
         [ 0.0072, -0.2328, -0.5603],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000]]])
tensor([ 5,  5,  4,  2])


In [11]:
# Final state (h_n, c_n). Each row of h_n equals the output at that sequence's
# LAST VALID step (compare with `outs` above), not at the padded final position.
hidden

(tensor([[[ 0.0588, -0.1205, -0.0732],
          [-0.1482, -0.1685, -0.3908],
          [ 0.0851, -0.0450, -0.1015],
          [ 0.0072, -0.2328, -0.5603]]]), tensor([[[ 0.4140, -0.3131, -0.1218],
          [-0.2129, -0.2563, -1.1187],
          [ 0.1875, -0.0940, -0.3576],
          [ 0.0130, -0.6090, -0.9075]]]))

## Bidirectional LSTM

In [12]:
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# One bidirectional layer: 3 input features, hidden size 3 per direction.
bi_lstm = nn.LSTM(input_size=3, hidden_size=3, batch_first=True, bidirectional=True)

# Padded batch: 4 sequences, padded length 5, 3 features per step.
ins = torch.randn(4, 5, 3)
print(ins)

# Number of valid steps per sequence, listed in decreasing order as
# pack_padded_sequence expects here: the first two sequences use all 5 steps,
# the third has 4 valid steps and the fourth only 2; the rest is padding.
seq_lengths = [5, 5, 4, 2]
packed_ins = pack_padded_sequence(ins, seq_lengths, batch_first=True)

tensor([[[-1.0817, -0.3302,  0.2718],
         [-1.0898, -0.7467, -0.9425],
         [-1.2516, -1.7303, -0.7223],
         [ 0.5260, -0.4363, -0.1890],
         [-0.4996,  1.5021, -0.0141]],

        [[ 1.1941, -0.7962, -0.3146],
         [-1.0103, -0.3654,  0.0223],
         [ 0.9876, -0.5608,  0.1105],
         [ 0.7676,  0.3664, -0.5944],
         [-0.2100, -0.9980,  1.4642]],

        [[ 0.1264, -0.5862, -2.2658],
         [ 1.0991,  0.1639, -0.4202],
         [-0.8483,  0.3627,  0.5430],
         [-1.4339,  0.5070,  0.1936],
         [ 0.6681, -1.3739,  0.6759]],

        [[ 0.8625,  1.7968, -1.5077],
         [ 0.1628, -0.6383,  0.8980],
         [-0.3837, -1.4963, -0.9825],
         [ 0.7184,  0.4402, -0.1025],
         [-1.7265, -0.4846,  0.7121]]])


In [13]:
# Zero initial state (h_0, c_0). The first dim is
# num_layers * num_directions = 1 * 2; the second is the batch size (4).
hidden = (torch.zeros(2, 4, 3),
          torch.zeros(2, 4, 3))

# Each output row is 6 wide: both directions' 3-dim outputs side by side.
packed_outs, hidden = bi_lstm(packed_ins, hidden)
packed_outs

PackedSequence(data=tensor([[ 0.0152,  0.2150, -0.0308,  0.1314,  0.0960, -0.1464],
        [-0.2240, -0.0841, -0.1037, -0.1180, -0.2355, -0.0526],
        [-0.2538, -0.1220, -0.0648, -0.0039, -0.0957, -0.1879],
        [-0.5405, -0.0363,  0.0543, -0.2315, -0.0575, -0.0245],
        [ 0.0208,  0.2373, -0.0896,  0.1984,  0.0027, -0.3197],
        [-0.1198,  0.1778, -0.0723,  0.0445,  0.0171, -0.0628],
        [-0.4691, -0.1003, -0.0544, -0.1385, -0.1249, -0.0048],
        [-0.2003,  0.1206, -0.0544, -0.0235, -0.0424,  0.0745],
        [ 0.0574,  0.2476, -0.1822,  0.1377, -0.0361, -0.2957],
        [-0.2235,  0.0816, -0.1198, -0.1725, -0.2239, -0.0040],
        [-0.3047,  0.1316,  0.0159,  0.0477,  0.1810,  0.0579],
        [-0.0945,  0.1339, -0.1542, -0.0949, -0.1047, -0.0205],
        [-0.4488,  0.0103, -0.0749, -0.1540, -0.1302, -0.0087],
        [-0.1892,  0.2238,  0.0740,  0.0636,  0.1279, -0.0020],
        [-0.3035,  0.1788,  0.0027, -0.0440,  0.1726,  0.0255],
        [-0.0621,  0

In [14]:
# Unpack into a padded (batch, seq, 2 * hidden) tensor. In every step vector the
# first 3 values come from the forward direction and the last 3 from the
# backward direction (compare with the final states shown in the next cell);
# padded steps are all zeros.
outs, lengths = pad_packed_sequence(packed_outs, batch_first=True)
print(outs)
print(lengths)

tensor([[[ 0.0152,  0.2150, -0.0308,  0.1314,  0.0960, -0.1464],
         [ 0.0208,  0.2373, -0.0896,  0.1984,  0.0027, -0.3197],
         [ 0.0574,  0.2476, -0.1822,  0.1377, -0.0361, -0.2957],
         [-0.0945,  0.1339, -0.1542, -0.0949, -0.1047, -0.0205],
         [-0.3035,  0.1788,  0.0027, -0.0440,  0.1726,  0.0255]],

        [[-0.2240, -0.0841, -0.1037, -0.1180, -0.2355, -0.0526],
         [-0.1198,  0.1778, -0.0723,  0.0445,  0.0171, -0.0628],
         [-0.2235,  0.0816, -0.1198, -0.1725, -0.2239, -0.0040],
         [-0.4488,  0.0103, -0.0749, -0.1540, -0.1302, -0.0087],
         [-0.0621,  0.1729, -0.1092,  0.0092,  0.0130,  0.1138]],

        [[-0.2538, -0.1220, -0.0648, -0.0039, -0.0957, -0.1879],
         [-0.4691, -0.1003, -0.0544, -0.1385, -0.1249, -0.0048],
         [-0.3047,  0.1316,  0.0159,  0.0477,  0.1810,  0.0579],
         [-0.1892,  0.2238,  0.0740,  0.0636,  0.1279, -0.0020],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[-0.5405, 

In [15]:
# Final state (h_n, c_n), each shaped (2, batch, hidden).
# h_n[0] is the forward direction: it matches the first 3 output values at each
# sequence's last valid step. h_n[1] is the backward direction: it matches the
# last 3 output values at step 0, where the backward pass finishes.
hidden

(tensor([[[-0.3035,  0.1788,  0.0027],
          [-0.0621,  0.1729, -0.1092],
          [-0.1892,  0.2238,  0.0740],
          [-0.2003,  0.1206, -0.0544]],
 
         [[ 0.1314,  0.0960, -0.1464],
          [-0.1180, -0.2355, -0.0526],
          [-0.0039, -0.0957, -0.1879],
          [-0.2315, -0.0575, -0.0245]]]), tensor([[[-0.4221,  0.4370,  0.0066],
          [-0.1468,  0.3463, -0.4037],
          [-0.3316,  0.6430,  0.1905],
          [-0.3985,  0.2212, -0.1901]],
 
         [[ 0.3064,  0.1827, -0.2278],
          [-0.1824, -0.4024, -0.1472],
          [-0.0051, -0.2480, -0.6607],
          [-0.5946, -0.1037, -0.2454]]]))

## Multi-layer LSTM

In [16]:
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Two stacked bidirectional layers: 3 input features, hidden size 3 per direction.
bi_lstm = nn.LSTM(
    input_size=3,
    hidden_size=3,
    num_layers=2,
    batch_first=True,
    bidirectional=True,
)

# Padded batch: 4 sequences, padded length 5, 3 features per step.
ins = torch.randn(4, 5, 3)
print(ins)

# Number of valid steps per sequence, listed in decreasing order as
# pack_padded_sequence expects here: the first two sequences use all 5 steps,
# the third has 4 valid steps and the fourth only 2; the rest is padding.
seq_lengths = [5, 5, 4, 2]
packed_ins = pack_padded_sequence(ins, seq_lengths, batch_first=True)

tensor([[[ 0.4024,  1.0934,  1.0839],
         [-0.5964, -1.8560, -0.7325],
         [ 0.8716, -0.7590, -0.9404],
         [-0.4730,  0.2942,  0.4811],
         [ 0.5570,  1.5831, -0.4314]],

        [[ 0.3799,  2.2450,  0.6707],
         [-0.3891,  0.6506, -0.7701],
         [ 1.1734, -1.5961,  0.8083],
         [-0.3350,  0.9728, -1.0144],
         [-0.2099,  0.4682, -0.8873]],

        [[-0.0827,  0.8142,  0.1695],
         [ 1.3346,  0.5806,  1.2176],
         [-2.4879, -1.2186, -0.2566],
         [ 1.3973,  0.6949,  0.7323],
         [ 0.5744,  0.6738, -0.6111]],

        [[-0.5673,  1.0816, -0.4824],
         [ 0.5147, -1.1394, -0.0452],
         [-0.6360, -1.4588,  0.2359],
         [ 0.6541, -1.4844, -0.9906],
         [ 0.6978, -0.5004, -1.4782]]])


In [17]:
# Fisrt dim: num_layers * num_directions
hidden = (torch.zeros(4, 4, 3),
          torch.zeros(4, 4, 3))

packed_outs, hidden = bi_lstm(packed_ins, hidden)
packed_outs

PackedSequence(data=tensor([[ 0.0495, -0.0814,  0.0179,  0.1412,  0.0115,  0.0529],
        [ 0.0533, -0.0671,  0.0149,  0.1458,  0.0171,  0.0773],
        [ 0.0518, -0.0777,  0.0142,  0.1425,  0.0044,  0.0353],
        [ 0.0513, -0.0582,  0.0070,  0.1263,  0.0110,  0.0427],
        [ 0.0638, -0.1031,  0.0083,  0.1467,  0.0102,  0.0579],
        [ 0.0805, -0.1074,  0.0009,  0.1432,  0.0221,  0.0816],
        [ 0.0796, -0.1336,  0.0126,  0.1201, -0.0009,  0.0262],
        [ 0.0698, -0.0975, -0.0036,  0.0851,  0.0110,  0.0277],
        [ 0.0769, -0.1179, -0.0108,  0.1334,  0.0144,  0.0644],
        [ 0.0789, -0.1469, -0.0198,  0.1421,  0.0231,  0.0682],
        [ 0.0930, -0.1337,  0.0247,  0.1244, -0.0135,  0.0060],
        [ 0.0886, -0.1464, -0.0114,  0.1062,  0.0130,  0.0482],
        [ 0.0954, -0.1313, -0.0337,  0.1138,  0.0119,  0.0709],
        [ 0.1057, -0.1511,  0.0091,  0.0794,  0.0042,  0.0238],
        [ 0.1072, -0.1458, -0.0253,  0.0728,  0.0103,  0.0502],
        [ 0.1006, -0

In [18]:
# Unpack into a padded (batch, seq, 2 * hidden) tensor; padded steps are zeros.
# These are the outputs of the LAST layer only (they match the last two entries
# of the final hidden state shown in the next cell).
outs, lengths = pad_packed_sequence(packed_outs, batch_first=True)
print(outs)
print(lengths)

tensor([[[ 0.0495, -0.0814,  0.0179,  0.1412,  0.0115,  0.0529],
         [ 0.0638, -0.1031,  0.0083,  0.1467,  0.0102,  0.0579],
         [ 0.0769, -0.1179, -0.0108,  0.1334,  0.0144,  0.0644],
         [ 0.0886, -0.1464, -0.0114,  0.1062,  0.0130,  0.0482],
         [ 0.1072, -0.1458, -0.0253,  0.0728,  0.0103,  0.0502]],

        [[ 0.0533, -0.0671,  0.0149,  0.1458,  0.0171,  0.0773],
         [ 0.0805, -0.1074,  0.0009,  0.1432,  0.0221,  0.0816],
         [ 0.0789, -0.1469, -0.0198,  0.1421,  0.0231,  0.0682],
         [ 0.0954, -0.1313, -0.0337,  0.1138,  0.0119,  0.0709],
         [ 0.1006, -0.1246, -0.0345,  0.0776,  0.0086,  0.0403]],

        [[ 0.0518, -0.0777,  0.0142,  0.1425,  0.0044,  0.0353],
         [ 0.0796, -0.1336,  0.0126,  0.1201, -0.0009,  0.0262],
         [ 0.0930, -0.1337,  0.0247,  0.1244, -0.0135,  0.0060],
         [ 0.1057, -0.1511,  0.0091,  0.0794,  0.0042,  0.0238],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.0513, 

In [19]:
# Final state (h_n, c_n), each shaped (num_layers * num_directions, batch, hidden).
# h_n[2] matches the forward half of `outs` at each sequence's last valid step
# and h_n[3] matches the backward half of `outs` at step 0, so the last two
# entries belong to the top layer (forward, backward). Entries 0 and 1 are
# presumably the first layer's forward/backward states; they are not visible in `outs`.
hidden

(tensor([[[ 0.4615, -0.4215,  0.0138],
          [ 0.3393, -0.3394, -0.0693],
          [ 0.2398, -0.3425,  0.0323],
          [ 0.2457, -0.2500, -0.1831]],
 
         [[ 0.0051, -0.2633,  0.3119],
          [-0.0012, -0.1913,  0.1647],
          [ 0.0229, -0.2508,  0.2479],
          [-0.0173, -0.2054,  0.1188]],
 
         [[ 0.1072, -0.1458, -0.0253],
          [ 0.1006, -0.1246, -0.0345],
          [ 0.1057, -0.1511,  0.0091],
          [ 0.0698, -0.0975, -0.0036]],
 
         [[ 0.1412,  0.0115,  0.0529],
          [ 0.1458,  0.0171,  0.0773],
          [ 0.1425,  0.0044,  0.0353],
          [ 0.1263,  0.0110,  0.0427]]]), tensor([[[ 0.9150, -0.7301,  0.0473],
          [ 0.5388, -0.5898, -0.2358],
          [ 0.5136, -0.8072,  0.0817],
          [ 0.4390, -0.7334, -0.4809]],
 
         [[ 0.0211, -0.3232,  0.6118],
          [-0.0061, -0.2208,  0.3198],
          [ 0.0578, -0.3454,  0.4544],
          [-0.0329, -0.2898,  0.2157]],
 
         [[ 0.2166, -0.3387, -0.0453],
        

# An LSTM for Part-of-Speech Tagging

In [20]:
# Toy training corpus: (tokenised sentence, one POS tag per token).
train_data = [
    ("The dog ate the apple".split(), ["DET", "NN", "V", "DET", "NN"]),
    ("Everybody read that book".split(), ["NN", "V", "DET", "NN"])
]

# Assign consecutive integer ids in order of first appearance.
word2idx = {}
tag2idx = {}
for sent, tags in train_data:
    for word in sent:
        # setdefault only inserts unseen keys, so existing ids never change.
        word2idx.setdefault(word, len(word2idx))
    for tag in tags:
        tag2idx.setdefault(tag, len(tag2idx))
print(word2idx)
print(tag2idx)

# Model hyper-parameters and sizes derived from the vocabularies.
EMB_DIM = 6
HIDDEN_DIM = 6
VOC_SIZE = len(word2idx)
TAGSET_SIZE = len(tag2idx)

{'The': 0, 'dog': 1, 'ate': 2, 'the': 3, 'apple': 4, 'Everybody': 5, 'read': 6, 'that': 7, 'book': 8}
{'DET': 0, 'NN': 1, 'V': 2}


In [21]:
class LSTMTagger(nn.Module):
    """Tag every word of a single sentence with an embedding -> LSTM -> linear stack.

    Args:
        emb_dim: size of the word embedding vectors.
        hidden_dim: size of the LSTM hidden state.
        voc_size: number of distinct word ids the embedding table holds.
        tagset_size: number of distinct tags to score.
    """

    def __init__(self, emb_dim, hidden_dim, voc_size, tagset_size):
        super().__init__()
        self.word_emb = nn.Embedding(voc_size, emb_dim)
        self.lstm = nn.LSTM(emb_dim, hidden_dim)
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)

        # Fixed all-zero initial state (h_0, c_0), each shaped (1, 1, hidden_dim).
        # These are plain tensors (neither parameters nor buffers), so calls such
        # as model.double() or model.to(device) do NOT convert them; forward()
        # therefore casts them to match the embeddings before use.
        self.hidden_0 = (torch.zeros(1, 1, hidden_dim),
                         torch.zeros(1, 1, hidden_dim))

    def forward(self, sent):
        """Score the tags of one sentence.

        Args:
            sent: 1-D LongTensor of word ids, one per word.

        Returns:
            Tensor of shape (len(sent), tagset_size) holding log-probabilities
            (each row is a log-softmax over the tags).
        """
        emb = self.word_emb(sent)
        seq_len = len(sent)

        # Align the stored zero state with the embeddings' dtype and device so the
        # module keeps working after it has been cast or moved. This is a no-op
        # when they already match.
        hidden_0 = tuple(state.to(emb) for state in self.hidden_0)

        # The LSTM expects (seq, batch, features); the sentence is a batch of one.
        # The final state returned by the LSTM is not needed for tagging.
        lstm_outs, _ = self.lstm(emb.view(seq_len, 1, -1), hidden_0)

        # Drop the batch axis and project every step onto the tag space.
        tag_space = self.hidden2tag(lstm_outs.view(seq_len, -1))
        return F.log_softmax(tag_space, dim=-1)
    
# Build the tagger, its loss and its optimizer. NLLLoss is the matching loss for
# the log-probabilities (log_softmax output) that the model returns.
model = LSTMTagger(EMB_DIM, HIDDEN_DIM, VOC_SIZE, TAGSET_SIZE)
loss_func = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Scores of the still UNTRAINED model for the first training sentence: one row
# per word, one column per tag. Kept as a baseline to compare with after training.
sent_idxes = [word2idx[w] for w in train_data[0][0]]
sent_ins = torch.tensor(sent_idxes, dtype=torch.long)
# Pure inference: no gradients are needed, so skip building the autograd graph.
with torch.no_grad():
    tag_scores = model(sent_ins)
print(tag_scores)

tensor([[-1.1640, -0.8494, -1.3467],
        [-1.0539, -0.9754, -1.2933],
        [-1.1063, -0.8800, -1.3688],
        [-1.2029, -0.7195, -1.5479],
        [-1.1558, -0.7911, -1.4616]])


In [22]:
# Encode every training pair once, instead of rebuilding identical tensors in
# each of the 300 epochs: (word-id tensor, tag-id tensor) per sentence.
encoded_data = [
    (torch.tensor([word2idx[w] for w in sent], dtype=torch.long),
     torch.tensor([tag2idx[tag] for tag in tags], dtype=torch.long))
    for sent, tags in train_data
]

for epoch in range(300):
    for sent_ins, targets in encoded_data:
        # Gradients accumulate across backward() calls, so clear them first.
        model.zero_grad()

        tag_scores = model(sent_ins)
        loss = loss_func(tag_scores, targets)
        loss.backward()
        optimizer.step()

# Scores of the TRAINED model for the first training sentence; the largest value
# in each row marks the predicted tag.
# NOTE: `targets` still holds the tags of the LAST training sentence at this point.
sent_idxes = [word2idx[w] for w in train_data[0][0]]
sent_ins = torch.tensor(sent_idxes, dtype=torch.long)
# Pure inference: no gradients are needed, so skip building the autograd graph.
with torch.no_grad():
    tag_scores = model(sent_ins)
print(tag_scores)

tensor([[-0.1799, -2.6442, -2.3693],
        [-4.9944, -0.0135, -5.0169],
        [-2.4991, -4.7163, -0.0955],
        [-0.0562, -3.4060, -3.8390],
        [-3.9440, -0.0213, -6.3509]])


In [23]:
# NOTE: `targets` is left over from the last iteration of the training loop, so
# it holds the tag ids of the LAST training sentence ("Everybody read that
# book"), not of the first sentence whose scores were printed in the previous
# cell. The gold tags of that first sentence are [tag2idx[t] for t in train_data[0][1]].
targets

tensor([ 1,  2,  0,  1])