Permalink
Browse files

Fix snli device to work with newest torchtext (#394)

* fix snli device to work with newest torchtext

* fix some warnings too
  • Loading branch information...
SsnL authored and soumith committed Jul 26, 2018
1 parent b5f3612 commit 75e7c75469d21cb76ada070058889124b850a632
Showing with 6 additions and 5 deletions.
  1. +1 −1 mnist/main.py
  2. +1 −1 mnist_hogwild/train.py
  3. +4 −3 snli/train.py
@@ -46,7 +46,7 @@ def test(args, model, device, test_loader):
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, size_average=False).item() # sum up batch loss
test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()

@@ -49,7 +49,7 @@ def test_epoch(model, data_loader):
with torch.no_grad():
for data, target in data_loader:
output = model(data)
test_loss += F.nll_loss(output, target, size_average=False).item() # sum up batch loss
test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
pred = output.max(1)[1] # get the index of the max log-probability
correct += pred.eq(target).sum().item()

@@ -15,6 +15,7 @@

args = get_args()
torch.cuda.set_device(args.gpu)
device = torch.device('cuda:{}'.format(args.gpu))

inputs = data.Field(lower=args.lower)
answers = data.Field(sequential=False)
@@ -32,7 +33,7 @@
answers.build_vocab(train)

train_iter, dev_iter, test_iter = data.BucketIterator.splits(
(train, dev, test), batch_size=args.batch_size, device=args.gpu)
(train, dev, test), batch_size=args.batch_size, device=device)

config = args
config.n_embed = len(inputs.vocab)
@@ -44,12 +45,12 @@
config.n_cells *= 2

if args.resume_snapshot:
model = torch.load(args.resume_snapshot, map_location=lambda storage, locatoin: storage.cuda(args.gpu))
model = torch.load(args.resume_snapshot, map_location=device)
else:
model = SNLIClassifier(config)
if args.word_vectors:
model.embed.weight.data.copy_(inputs.vocab.vectors)
model.cuda(args.gpu)
model.to(device)

criterion = nn.CrossEntropyLoss()
opt = O.Adam(model.parameters(), lr=args.lr)

0 comments on commit 75e7c75

Please sign in to comment.