Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions snli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from torchtext import datasets

from model import SNLIClassifier
from util import get_args
from util import get_args, makedirs


args = get_args()
Expand All @@ -27,7 +27,7 @@
inputs.vocab.vectors = torch.load(args.vector_cache)
else:
inputs.vocab.load_vectors(wv_dir=args.data_cache, wv_type=args.word_vectors, wv_dim=args.d_embed)
os.makedirs(os.path.dirname(args.vector_cache), exist_ok=True)
makedirs(os.path.dirname(args.vector_cache))
torch.save(inputs.vocab.vectors, args.vector_cache)
answers.build_vocab(train)

Expand Down Expand Up @@ -59,7 +59,7 @@
header = ' Time Epoch Iteration Progress (%Epoch) Loss Dev/Loss Accuracy Dev/Accuracy'
dev_log_template = ' '.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{:8.6f},{:12.4f},{:12.4f}'.split(','))
log_template = ' '.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{},{:12.4f},{}'.split(','))
os.makedirs(args.save_path, exist_ok=True)
makedirs(args.save_path)
print(header)

for epoch in range(args.epochs):
Expand Down
17 changes: 17 additions & 0 deletions snli/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,23 @@
import os
from argparse import ArgumentParser

def makedirs(name):
"""helper function for python 2 and 3 to call os.makedirs()
avoiding an error if the directory to be created already exists"""

import os, errno

try:
os.makedirs(name)
except OSError as ex:
if ex.errno == errno.EEXIST and os.path.isdir(name):
# ignore existing directory
pass
else:
# a different error happened
raise


def get_args():
parser = ArgumentParser(description='PyTorch/torchtext SNLI example')
parser.add_argument('--epochs', type=int, default=50)
Expand Down