Update generate_batch function and add a run script. (#578)
zhangguanheng66 authored and cpuhrsch committed Aug 1, 2019
1 parent 33b48bc commit 0230b2b
Showing 2 changed files with 40 additions and 20 deletions.
6 changes: 6 additions & 0 deletions examples/text_classification/run_script.sh
@@ -0,0 +1,6 @@
+if [ ! -d ".data" ]; then
+    mkdir .data
+fi
+
+python train.py AG_NEWS --device cuda --save-model-path model.i --dictionary vocab.i
+cut -f 2- -d "," .data/ag_news_csv/test.csv | python predict.py model.i vocab.i > predict_script.o
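Note: the script trains on AG_NEWS with --device cuda, then uses cut -f 2- -d "," to strip the label column (field 1) from the test CSV so that predict.py receives only the raw text. predict.py itself is not part of this diff; below is a minimal sketch of how its inputs could be loaded back, assuming only the file names used above (model.i, vocab.i):

import torch

# Hypothetical loading step; predict.py's actual code is not shown in this diff.
model = torch.load('model.i')   # the full serialized model (train.py saves it on CPU)
vocab = torch.load('vocab.i')   # the vocabulary saved via --dictionary
model.eval()                    # switch to inference mode for prediction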
54 changes: 34 additions & 20 deletions examples/text_classification/train.py
@@ -9,6 +9,7 @@
from torch.utils.data import DataLoader

from model import TextSentiment
+from torch.utils.data.dataset import random_split

r"""
This file shows the training process of the text classification model.
@@ -34,29 +35,23 @@ def generate_batch(batch):
            index of the individual sequence in the text tensor.
        cls: a tensor saving the labels of individual text entries.
    """
-    def generate_offsets(data_batch):
-        offsets = [0]
-        for entry in data_batch:
-            offsets.append(offsets[-1] + len(entry))
-        offsets = torch.tensor(offsets[:-1])
-        return offsets
-
-    cls = torch.tensor([entry[0] for entry in batch])
+    label = torch.tensor([entry[0] for entry in batch])
    text = [entry[1] for entry in batch]
-    offsets = generate_offsets(text)
+    offsets = [0] + [len(entry) for entry in text]
+    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text = torch.cat(text)
-    return text, offsets, cls
+    return text, offsets, label
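Note: offsets records the starting index of each entry in the flattened text tensor, the format an nn.EmbeddingBag-style model consumes. A toy walk-through of the new cumsum logic (the batch values here are made up):

import torch

# Hypothetical batch of (label, token-id tensor) pairs.
batch = [(2, torch.tensor([4, 8, 15])),
         (0, torch.tensor([16, 23])),
         (3, torch.tensor([42]))]

label = torch.tensor([entry[0] for entry in batch])   # tensor([2, 0, 3])
text = [entry[1] for entry in batch]
offsets = [0] + [len(entry) for entry in text]        # [0, 3, 2, 1]
offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)    # tensor([0, 3, 5])
text = torch.cat(text)                                # tensor([ 4,  8, 15, 16, 23, 42])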


r"""
torch.utils.data.DataLoader is recommended for PyTorch users to load data.
-We use DataLoader here to load datasets and send it to the train()
+We use DataLoader here to load datasets and send it to the train_and_valid()
and test() functions.
"""


-def train(lr_, num_epoch, data_):
+def train_and_valid(lr_, num_epoch, sub_train_, sub_valid_):
r"""
We use a SGD optimizer to train the model here and the learning rate
decreases linearly with the progress of the training process.
@@ -67,13 +62,16 @@ def train(lr_, num_epoch, data_):
        data_: the data used to train the model
    """

-    data = DataLoader(data_, batch_size=batch_size, shuffle=True,
-                      collate_fn=generate_batch, num_workers=args.num_workers)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr_)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=args.lr_gamma)
-    num_lines = num_epochs * len(data)
+    train_data = DataLoader(sub_train_, batch_size=batch_size, shuffle=True,
+                            collate_fn=generate_batch, num_workers=args.num_workers)
+    num_lines = num_epochs * len(train_data) * split_ratio

    for epoch in range(num_epochs):
-        for i, (text, offsets, cls) in enumerate(data):
+
+        # Train the model
+        for i, (text, offsets, cls) in enumerate(train_data):
            optimizer.zero_grad()
            text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)
            output = model(text, offsets)
@@ -88,7 +86,10 @@ def train(lr_, num_epoch, data_):
                    progress * 100, scheduler.get_lr()[0], loss))
        # Adjust the learning rate
        scheduler.step()
-        print("")
+
+        # Test the model on valid set
+        print("")
+        print("Valid - Accuracy: {}".format(test(sub_valid_)))


def test(data_):
@@ -104,7 +105,7 @@ def test(data_):
            output = model(text, offsets)
            accuracy = (output.argmax(1) == cls).float().mean().item()
            total_accuracy.append(accuracy)
-    print("Test - Accuracy: {}".format(sum(total_accuracy) / len(total_accuracy)))
+    return sum(total_accuracy) / len(total_accuracy)


if __name__ == "__main__":
@@ -117,6 +118,8 @@ def test(data_):
                        help='embed dim. (default=128)')
    parser.add_argument('--batch-size', type=int, default=64,
                        help='batch size (default=64)')
+    parser.add_argument('--split-ratio', type=float, default=0.95,
+                        help='train/valid split ratio (default=0.95)')
    parser.add_argument('--lr', type=float, default=4.0,
                        help='learning rate (default=4.0)')
    parser.add_argument('--lr-gamma', type=float, default=0.8,
@@ -129,6 +132,8 @@ def test(data_):
                        help='device (default=cpu)')
    parser.add_argument('--data', default='.data',
                        help='data directory (default=.data)')
+    parser.add_argument('--dictionary',
+                        help='path to save vocab')
    parser.add_argument('--save-model-path')
    parser.add_argument('--logging-level', default='WARNING',
                        help='logging level (default=WARNING)')
@@ -140,6 +145,7 @@ def test(data_):
    lr = args.lr
    device = args.device
    data = args.data
+    split_ratio = args.split_ratio

    logging.basicConfig(level=getattr(logging, args.logging_level))

@@ -154,9 +160,17 @@ def test(data_):
                          embed_dim, len(train_dataset.get_labels())).to(device)
    criterion = torch.nn.CrossEntropyLoss().to(device)

-    train(lr, num_epochs, train_dataset)
-    test(test_dataset)
+    # split train_dataset into train and valid
+    train_len = int(len(train_dataset) * split_ratio)
+    sub_train_, sub_valid_ = \
+        random_split(train_dataset, [train_len, len(train_dataset) - train_len])
+    train_and_valid(lr, num_epochs, sub_train_, sub_valid_)
+    print("Test - Accuracy: {}".format(test(test_dataset)))

    if args.save_model_path:
        print("Saving model to {}".format(args.save_model_path))
        torch.save(model.to('cpu'), args.save_model_path)

+    if args.dictionary is not None:
+        print("Save vocab to {}".format(args.dictionary))
+        torch.save(train_dataset.get_vocab(), args.dictionary)
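Note: random_split takes absolute lengths that must sum to the dataset length, which is why train_len is computed from split_ratio first. A self-contained sketch of the split logic (the 100-element list is a stand-in for train_dataset; AG_NEWS actually has 120,000 training examples):

from torch.utils.data.dataset import random_split

# Toy dataset standing in for train_dataset; sizes are illustrative.
dataset = list(range(100))
split_ratio = 0.95
train_len = int(len(dataset) * split_ratio)            # 95
sub_train_, sub_valid_ = random_split(
    dataset, [train_len, len(dataset) - train_len])    # 95 / 5 examples
print(len(sub_train_), len(sub_valid_))                # 95 5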
