Permalink
Browse files

dedupe creating RNN

  • Loading branch information...
spro committed May 19, 2017
1 parent 1b9a484 commit 0cc55f5aaed44e7903edb8842671411301fcf003
Showing with 9 additions and 21 deletions.
  1. +9 −21 char-rnn-classification/train.py
@@ -6,7 +6,10 @@
import math
n_hidden = 128
rnn = RNN(n_letters, n_hidden, n_categories)
n_epochs = 100000
print_every = 5000
plot_every = 1000
learning_rate = 0.005 # If you set this too high, it might explode. If too low, it might not learn
def categoryFromOutput(output):
top_n, top_i = output.data.topk(1) # Tensor out of Variable with .data
@@ -16,46 +19,31 @@ def categoryFromOutput(output):
def randomChoice(l):
return l[random.randint(0, len(l) - 1)]
def randomTrainingPair():
def randomTrainingPair():
category = randomChoice(all_categories)
line = randomChoice(category_lines[category])
category_tensor = Variable(torch.LongTensor([all_categories.index(category)]))
line_tensor = Variable(lineToTensor(line))
return category, line, category_tensor, line_tensor
rnn = RNN(n_letters, n_hidden, n_categories)
optimizer = torch.optim.SGD(rnn.parameters(), lr=learning_rate)
criterion = nn.NLLLoss()
learning_rate = 0.005 # If you set this too high, it might explode. If too low, it might not learn
def train(category_tensor, line_tensor):
hidden = rnn.initHidden()
rnn.zero_grad()
optimizer.zero_grad()
for i in range(line_tensor.size()[0]):
output, hidden = rnn(line_tensor[i], hidden)
loss = criterion(output, category_tensor)
loss.backward()
# Add parameters' gradients to their values, multiplied by learning rate
for p in rnn.parameters():
p.data.add_(-learning_rate, p.grad.data)
optimizer.step()
return output, loss.data[0]
n_epochs = 100000
print_every = 5000
plot_every = 1000
rnn = RNN(n_letters, n_hidden, n_categories)
n_epochs = 100000
print_every = 5000
plot_every = 1000
rnn = RNN(n_letters, n_hidden, n_categories)
# Keep track of losses for plotting
current_loss = 0
all_losses = []

0 comments on commit 0cc55f5

Please sign in to comment.