In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Seed PyTorch's global random number generator so that the random tensors
# and randomly initialised layers created below are repeatable between runs.
SEED = 1
torch.manual_seed(SEED)

<torch._C.Generator at 0x4a694fc470>

In [3]:
# Toy LSTM with 3 input features and 3 hidden units.
lstm = nn.LSTM(3, 3)
# A sequence of five time steps, each a random (1, 3) tensor.
inputs = [torch.randn(1, 3) for _ in range(5)]

# Random initial (h_0, c_0) pair; each tensor is (1, 1, 3).
hidden = tuple(torch.randn(1, 1, 3) for _ in range(2))

# Variant 1: feed one time step at a time, threading the returned state
# back in as the initial state for the next step.
for step in inputs:
    # (1, 3) -> (1, 1, 3): the LSTM is called with a 3-D input here.
    out, hidden = lstm(step.unsqueeze(0), hidden)

# Variant 2: feed the whole sequence in a single call.
# Stacking the five (1, 3) steps yields one (5, 1, 3) tensor.
inputs = torch.stack(inputs)

# Draw a fresh random initial state so the state left over from the
# step-by-step run above is not reused.
hidden = tuple(torch.randn(1, 1, 3) for _ in range(2))
out, hidden = lstm(inputs, hidden)
print(out)
print(hidden)

tensor([[[-0.0187,  0.1713, -0.2944]],

        [[-0.3521,  0.1026, -0.2971]],

        [[-0.3191,  0.0781, -0.1957]],

        [[-0.1634,  0.0941, -0.1637]],

        [[-0.3368,  0.0959, -0.0538]]])
(tensor([[[-0.3368,  0.0959, -0.0538]]]), tensor([[[-0.9825,  0.4715, -0.0633]]]))


In [4]:
class LSTMNumSeq(nn.Module):
    """Embedding -> LSTM -> linear -> log-softmax model over a sequence of indices.

    The recurrent state is kept on the instance as ``self.hidden`` and is
    overwritten by every ``forward`` call, so callers that want each sequence
    to start from a zero state must assign ``model.hidden = model.init_hidden()``
    before calling the model.

    Args:
        embedding_dim: Size of each embedding vector.
        hidden_dim: Number of LSTM hidden units.
        num_size: Number of distinct symbols; used both as the embedding
            vocabulary size and as the number of output classes.
    """

    def __init__(self, embedding_dim, hidden_dim, num_size):
        super().__init__()
        self.hidden_dim = hidden_dim

        # Layers are created in the order embedding, LSTM, linear.
        self.word_embeddings = nn.Embedding(num_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)
        self.hidden2tag = nn.Linear(hidden_dim, num_size)

        self.hidden = self.init_hidden()

    def init_hidden(self):
        """Return a new all-zero ``(h_0, c_0)`` pair, each of shape ``(1, 1, hidden_dim)``."""
        state_shape = (1, 1, self.hidden_dim)
        return (torch.zeros(state_shape), torch.zeros(state_shape))

    def forward(self, seq):
        """Score every position of ``seq``.

        Args:
            seq: 1-D tensor of symbol indices.

        Returns:
            Tensor of shape ``(len(seq), num_size)`` holding log-probabilities
            (log-softmax over the class dimension) for each position.

        Side effects:
            Replaces ``self.hidden`` with the state returned by the LSTM.
        """
        steps = len(seq)
        # (steps, embedding_dim) -> (steps, 1, embedding_dim): batch size of one.
        embedded = self.word_embeddings(seq).view(steps, 1, -1)
        lstm_out, self.hidden = self.lstm(embedded, self.hidden)
        # Drop the singleton batch axis before projecting to class scores.
        scores = self.hidden2tag(lstm_out.view(steps, -1))
        return F.log_softmax(scores, dim=1)
        

In [5]:
# Training pairs (input, target). Each input is a window of 9 consecutive
# digits (wrapping 9 -> 0) starting at k; its target is the same window
# shifted by one, i.e. starting at k + 1.  For k = 0 this gives
# ("012345678", "123456789").
input_seq = [
    tuple(
        "".join(str((start + offset) % 10) for offset in range(9))
        for start in (k, k + 1)
    )
    for k in range(5)
]
# A window length of 10 instead of 9 gives the full-cycle variant
# ("0123456789" -> "1234567890", ...).

# Digit character -> integer index used for the embedding lookup and as the class label.
num_to_ix = {str(digit): digit for digit in range(10)}

def prepare_sequence(seq, to_ix):
    """Convert a sequence of symbols into a 1-D ``torch.long`` tensor of indices.

    Args:
        seq: Iterable of symbols, e.g. a string of digit characters.
        to_ix: Mapping from each symbol to its integer index.

    Returns:
        A ``torch.long`` tensor with one index per symbol, in order.

    Raises:
        KeyError: If ``to_ix`` is a dict and a symbol of ``seq`` is not a key.
    """
    # Iterate over the symbols directly instead of by position, so any
    # iterable works (not only objects supporting len() and indexing).
    idxs = [to_ix[symbol] for symbol in seq]
    return torch.tensor(idxs, dtype=torch.long)


In [6]:
# 6-dimensional embeddings, 6 LSTM hidden units, 10 symbols/classes (digits 0-9).
model = LSTMNumSeq(6, 6, 10)
# The model returns log-probabilities, which is what NLLLoss expects.
loss_function = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

NUM_EPOCHS = 3000

# The training pairs never change, so convert them to index tensors once
# instead of rebuilding identical tensors on every epoch.
training_pairs = [
    (prepare_sequence(sequence, num_to_ix), prepare_sequence(ty, num_to_ix))
    for sequence, ty in input_seq
]

for epoch in range(NUM_EPOCHS):
    for seq_in, target in training_pairs:
        # Gradients accumulate across backward() calls; clear them per example.
        model.zero_grad()

        # Start every sequence from a zero state. This also drops the
        # previous example's state, which is attached to an old graph.
        model.hidden = model.init_hidden()

        out = model(seq_in)

        loss = loss_function(out, target)
        loss.backward()
        optimizer.step()

In [7]:
# Inverse of num_to_ix: predicted class index -> digit character.
ix_to_num = {ix: num for num, ix in num_to_ix.items()}

with torch.no_grad():
    # Start from the same zero state every training sequence started from;
    # otherwise the state left behind by the previous forward pass is used
    # as the initial state for this prediction.
    model.hidden = model.init_hidden()
    inputs = prepare_sequence("612359048", num_to_ix)
    output_seq = model(inputs)

# Most likely class at each position, decoded back into digit characters.
output = "".join(ix_to_num[ix] for ix in torch.argmax(output_seq, dim=1).tolist())

print(output)

323464059


In [12]:
# Stand-alone embedding table (10 symbols, 6 dimensions each) used only to
# inspect the shapes that flow into the LSTM.
word_embeddings = nn.Embedding(10, 6)

seq = "012345678"
# Use the shared helper instead of repeating the symbol -> index conversion inline.
tseq = prepare_sequence(seq, num_to_ix)

embeds = word_embeddings(tseq)
# One embedding row per symbol: (len(seq), 6).
print(embeds)
# Same values with a singleton middle axis inserted: (len(seq), 1, 6).
print(embeds.view(len(tseq), 1, -1))

tensor([[ 0.7068, -1.8458, -0.8011, -1.5776, -0.9171, -0.2311],
        [ 1.0470, -1.5918,  0.0556, -0.6261, -0.5794, -0.5948],
        [ 0.0714,  0.3420,  0.8866, -0.8954,  0.0848,  0.2620],
        [-0.9102, -0.1423,  0.2989,  1.4571,  0.2304, -0.1479],
        [-0.5929, -0.3364, -0.0321, -0.5684, -1.4244, -1.3247],
        [-2.0823, -0.6323, -0.5450,  0.3116,  0.5931,  1.8194],
        [-0.8792, -1.1781,  0.2504,  0.3679,  0.6677,  0.6348],
        [-0.5083,  1.2397, -0.2392,  1.7528, -1.3151,  0.1726],
        [-0.0877, -0.4218, -1.1414, -1.9074, -0.0156,  1.0395]])
tensor([[[ 0.7068, -1.8458, -0.8011, -1.5776, -0.9171, -0.2311]],

        [[ 1.0470, -1.5918,  0.0556, -0.6261, -0.5794, -0.5948]],

        [[ 0.0714,  0.3420,  0.8866, -0.8954,  0.0848,  0.2620]],

        [[-0.9102, -0.1423,  0.2989,  1.4571,  0.2304, -0.1479]],

        [[-0.5929, -0.3364, -0.0321, -0.5684, -1.4244, -1.3247]],

        [[-2.0823, -0.6323, -0.5450,  0.3116,  0.5931,  1.8194]],

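        [[-0.8792, -1.1781,  0.2504,  0.3679,  0.6677,  0.6348]],

        [[-0.5083,  1.2397, -0.2392,  1.7528, -1.3151,  0.1726]],

        [[-0.0877, -0.4218, -1.1414, -1.9074, -0.0156,  1.0395]]])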

In [15]:
# view() only changes the shape, not the values: the (2, 2) tensor becomes
# (2, 1, 2) by inserting a singleton middle axis.
x = torch.randn(2, 2)
print(x)
y = x.view(x.size(0), 1, -1)
print(y)

tensor([[ 0.4298, -0.3652],
        [-0.7078,  0.2642]])
tensor([[[ 0.4298, -0.3652]],

        [[-0.7078,  0.2642]]])
