In [1]:
# Lab 12 Character Sequence RNN
import torch
import torch.nn as nn
from torch.autograd import Variable

torch.manual_seed(777)  # reproducibility

sample = "hello"
idx2char = list(set(sample))  # index -> char
char2idx = {c: i for i, c in enumerate(idx2char)}  # char -> index

In [2]:
idx2char

['o', 'h', 'e', 'l']

In [3]:
char2idx

{'e': 2, 'h': 1, 'l': 3, 'o': 0}

In [7]:
# hyperparameters
learning_rate = 0.1
num_epochs = 50
input_size = len(char2idx)  # RNN input size (one hot size)
hidden_size = len(char2idx)  # RNN output size
num_classes = len(char2idx)  # final output size (RNN or softmax, etc.)
batch_size = 1  # one sample data, one batch
sequence_length = len(sample) - 1  # number of lstm rollings (unit #)
num_layers = 1  # number of layers in RNN

sample_idx = [char2idx[c] for c in sample]  # char to index
x_data = [sample_idx[:-1]]  # X data sample (0 ~ n-1) hello: hell
y_data = [sample_idx[1:]]   # Y label sample (1 ~ n) hello: ello

x_data = torch.Tensor(x_data)
y_data = torch.LongTensor(y_data)

In [8]:
x_data, y_data

(
  1  2  3  3
 [torch.FloatTensor of size 1x4], 
  2  3  3  0
 [torch.LongTensor of size 1x4])

In [9]:
def one_hot(x, num_classes):
    idx = x.long()
    idx = idx.view(-1, 1)
    x_one_hot = torch.zeros(x.size()[0] * x.size()[1], num_classes)
    x_one_hot.scatter_(1, idx, 1)
    x_one_hot = x_one_hot.view(x.size()[0], x.size()[1], num_classes)
    return x_one_hot


In [12]:
x_one_hot = one_hot(x_data, num_classes)
print(x_one_hot)
print(y_data)
inputs = Variable(x_one_hot)
labels = Variable(y_data)


(0 ,.,.) = 
  0  1  0  0
  0  0  1  0
  0  0  0  1
  0  0  0  1
[torch.FloatTensor of size 1x4x4]


 2  3  3  0
[torch.LongTensor of size 1x4]



In [116]:
class LSTM(nn.Module):

    def __init__(self, num_classes, input_size, hidden_size, num_layers):
        super(LSTM, self).__init__()
        self.num_classes = num_classes
        self.num_layers = num_layers
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.sequence_length = sequence_length
        # Set parameters for RNN block
        # Note: batch_first=False by default.
        # When true, inputs are (batch_size, sequence_length, input_dimension)
        # instead of (sequence_length, batch_size, input_dimension)
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                            num_layers=num_layers, batch_first=True)
        # Fully connected layer to obtain outputs corresponding to the number
        # of classes
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # Initialize hidden and cell states
        h_0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size))
        c_0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size))

        # Reshape input
        x.view(x.size(0), self.sequence_length, self.input_size)

        # Propagate input through RNN
        # Input: (batch, seq_len, input_size)
        # h_0: (num_layers * num_directions, batch, hidden_size)
        out, _ = self.lstm(x, (h_0, c_0))

        # Reshape output from (batch, seq_len, hidden_size) to (batch *
        # seq_len, hidden_size)
        out = out.view(-1, self.hidden_size)
        # Return outputs applied to fully connected layer
        out = self.fc(out)
        return out


In [117]:

# Instantiate RNN model
lstm = LSTM(num_classes, input_size, hidden_size, num_layers)

# Set loss and optimizer function
criterion = torch.nn.CrossEntropyLoss()    # Softmax is internally computed.
optimizer = torch.optim.Adam(lstm.parameters(), lr=learning_rate)


In [118]:
# Train the model
for epoch in range(num_epochs):
    outputs = lstm(inputs)
    optimizer.zero_grad()
    loss = criterion(outputs, labels.view(-1))
    loss.backward()
    optimizer.step()
    _, idx = outputs.max(1)
    idx = idx.data.numpy()
    result_str = [idx2char[c] for c in idx.squeeze()]
    print("epoch: %d, loss: %1.3f" % (epoch + 1, loss.data[0]))
    print("Predicted string: ", ''.join(result_str))

print("Learning finished!")

epoch: 1, loss: 1.239
Predicted string:  eeee
epoch: 2, loss: 1.112
Predicted string:  llll
epoch: 3, loss: 1.019
Predicted string:  llll
epoch: 4, loss: 0.924
Predicted string:  llll
epoch: 5, loss: 0.821
Predicted string:  lllo
epoch: 6, loss: 0.719
Predicted string:  lllo
epoch: 7, loss: 0.608
Predicted string:  lllo
epoch: 8, loss: 0.488
Predicted string:  lllo
epoch: 9, loss: 0.381
Predicted string:  ello
epoch: 10, loss: 0.287
Predicted string:  ello
epoch: 11, loss: 0.218
Predicted string:  ello
epoch: 12, loss: 0.175
Predicted string:  ello
epoch: 13, loss: 0.135
Predicted string:  ello
epoch: 14, loss: 0.109
Predicted string:  ello
epoch: 15, loss: 0.082
Predicted string:  ello
epoch: 16, loss: 0.067
Predicted string:  ello
epoch: 17, loss: 0.052
Predicted string:  ello
epoch: 18, loss: 0.041
Predicted string:  ello
epoch: 19, loss: 0.034
Predicted string:  ello
epoch: 20, loss: 0.026
Predicted string:  ello
epoch: 21, loss: 0.022
Predicted string:  ello
epoch: 22, loss: 0.018