In [53]:
import os
import torch
import torch.nn.functional as F
import torchtext.data as data
from src import obj_dict, movie_reviews_dataset, model_cnn, train

args = obj_dict.objdict({
    'batch_size': 1024,
    'cuda': False and torch.cuda.is_available(),
    'epochs': 250,
    'save_dir': './snapshots/cnn/',
    'static': False,
    'kernel_sizes': [3,4,5],
    'embed_dim': 300,
    'kernel_num': 100,
    'dropout': 0.5,
    'lr': 0.001,
    'log_interval': 10,
    'test_interval': 100,
    'save_interval': 1000
})

Setup Movie Reviews data set loader

In [54]:
# load MR dataset
def mr(text_field, label_field, **kargs):
    train_data, dev_data = movie_reviews_dataset.MR.splits(text_field, label_field, root='./data')
    text_field.build_vocab(train_data, dev_data)
    label_field.build_vocab(train_data, dev_data)
    train_iter, dev_iter = data.Iterator.splits(
        (train_data, dev_data),
        batch_sizes=(args.batch_size, len(dev_data)),
        **kargs)
    return train_iter, dev_iter

print(args)

# load data
print("\nLoading data...")
text_field = data.Field(lower=True)
label_field = data.Field(sequential=False)
train_iter, dev_iter = mr(text_field, label_field, device=args.device if args.cuda else -1, repeat=False)
print("Loaded", len(text_field.vocab), "samples")

{'batch_size': 1024, 'cuda': False, 'epochs': 250, 'save_dir': './snapshots/cnn/', 'static': False, 'kernel_sizes': [3, 4, 5], 'embed_dim': 300, 'kernel_num': 100, 'dropout': 0.5, 'lr': 0.001, 'log_interval': 10, 'test_interval': 100, 'save_interval': 1000}

Loading data...


Loaded 21109 samples


In [55]:
args.embed_num = len(text_field.vocab)
args.class_num = len(label_field.vocab) - 1

args.save_dir = os.path.join(args.save_dir, datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))

# model
cnn = model_cnn.CNN_Text(args)

train.train(train_iter, dev_iter, cnn, args)

Batch[10] - loss: 0.785570  acc: 50.2632%(191/380)

Batch[20] - loss: 0.772005  acc: 47.8947%(182/380)

Batch[30] - loss: 0.748966  acc: 51.5789%(196/380)

Batch[40] - loss: 0.693711  acc: 57.6316%(219/380)

Batch[50] - loss: 0.735481  acc: 48.9474%(186/380)

Batch[60] - loss: 0.719329  acc: 53.6842%(204/380)

KeyboardInterrupt: 

In [None]:
# train or predict
if args.predict is not None:
    label = train2.predict(args.predict, cnn, text_field, label_field)
    print('\n[Text]  {}[Label] {}\n'.format(args.predict, label))
elif args.test:
    try:
        train2.eval(test_iter, cnn, args)
    except Exception as e:
        print("\nSorry. The test dataset doesn't  exist.\n")
else:
    print()