-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: GentleSmile <478691929@qq.com>
- Loading branch information
Showing
14 changed files
with
886 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,231 @@ | ||
## LM -- a tensorflow implementation | ||
|
||
|
||
|
||
An implementation of LM(language model). | ||
|
||
|
||
|
||
### Require Packages | ||
|
||
- cotk | ||
- TensorFlow == 1.13.1 | ||
- TensorBoardX >= 1.4 | ||
|
||
|
||
|
||
### Quick Start | ||
|
||
- Downloading dataset and save it to ``./data``. (Dataset will be released soon.) | ||
- Execute ``python run.py`` to train the model. | ||
- The default dataset is ``MSCOCO``. You can use ``--dataset`` to specify other ``dataloader`` class. | ||
- It use `gloves` pretrained word vector by default setting. You can use ``--wvclass`` to specify ``wordvector`` class. | ||
- If you don't have GPUs, you can add `--cpu` for switching to CPU, but it may cost very long time. | ||
- You can view training process by tensorboard, the log is at `./tensorboard`. | ||
- For example, ``tensorboard --logdir=./tensorboard``. (You have to install tensorboard first.) | ||
- After training, execute ``python run.py --mode test --restore best`` for test. | ||
- You can use ``--restore filename`` to specify checkpoints files, which are in ``./model``. | ||
- ``--restore last`` means last checkpoint, ``--restore best`` means best checkpoints on dev. | ||
- Find results at ``./output``. | ||
|
||
|
||
|
||
### Arguments | ||
|
||
``` | ||
usage: run.py [-h] [--name NAME] [--restore RESTORE] [--mode MODE] | ||
[--dataset DATASET] [--datapath DATAPATH] [--epoch EPOCH] | ||
[--wvclass WVCLASS] [--wvpath WVPATH] [--out_dir OUT_DIR] | ||
[--log_dir LOG_DIR] [--model_dir MODEL_DIR] | ||
[--cache_dir CACHE_DIR] [--cpu] [--debug] [--cache] | ||
A language model | ||
optional arguments: | ||
-h, --help show this help message and exit | ||
--name NAME The name of your model, used for variable scope and | ||
tensorboard, etc. Default: runXXXXXX_XXXXXX | ||
(initialized by current time) | ||
--restore RESTORE Checkpoints name to load. "last" for last checkpoints, | ||
"best" for best checkpoints on dev. Attention: "last" | ||
and "best" wiil cause unexpected behaviour when run 2 | ||
models in the same dir at the same time. Default: None | ||
(don't load anything) | ||
--mode MODE "train" or "test". Default: train | ||
--dataset DATASET Dataloader class. Default: MSCOCO | ||
--datapath DATAPATH Directory for data set. Default: ./data | ||
--epoch EPOCH Epoch for trainning. Default: 10 | ||
--wvclass WVCLASS Wordvector class, None for using Glove pretrained | ||
wordvec. Default: None | ||
--wvpath WVPATH Path for pretrained wordvector. Default: wordvec | ||
--out_dir OUT_DIR Output directory for test output. Default: ./output | ||
--log_dir LOG_DIR Log directory for tensorboard. Default: ./tensorboard | ||
--model_dir MODEL_DIR | ||
Checkpoints directory for model. Default: ./model | ||
--cache_dir CACHE_DIR | ||
Checkpoints directory for cache. Default: ./cache | ||
--cpu Use cpu. | ||
--debug Enter debug mode (using ptvsd). | ||
--cache Use cache for speeding up load data and wordvec. (It | ||
may cause problems when you switch dataset.) | ||
``` | ||
|
||
For hyperparameter settings, please refer to `run.py`. | ||
|
||
|
||
|
||
#### For developer | ||
|
||
- Arguments above (except ``cache``\\``debug``) are required. You should remain the same behavior (not for implementation). | ||
- You can add more arguments if you want. | ||
|
||
|
||
|
||
### An example of tensorboard | ||
|
||
Execute ``tensorboard --logdir=./tensorboard``, you will see the plot in tensorboard pages: | ||
|
||
Following plot are shown in this model: | ||
|
||
- loss: reconstruction loss. | ||
|
||
![loss](image/loss.png) | ||
|
||
- perplexity: reconstruction perplexity. | ||
|
||
![perplexity](image/perplexity.png) | ||
|
||
|
||
And text output: | ||
|
||
```{ | ||
{ | ||
"epochs": 10, | ||
"lr": 0.1, | ||
"log_dir": "./tensorboard", | ||
"name": "LM", | ||
"max_sen_length": 50, | ||
"checkpoint_max_to_keep": 5, | ||
"embedding_size": 300, | ||
"momentum": 0.9, | ||
"checkpoint_steps": 1000, | ||
"datapath": "resources://MSCOCO~tsinghua", | ||
"cache": false, | ||
"debug": false, | ||
"wvclass": null, | ||
"restore": "last", | ||
"show_sample": [ | ||
0 | ||
], | ||
"wvpath": null, | ||
"dh_size": 200, | ||
"batch_size": 128, | ||
"lr_decay": 0.995, | ||
"model_dir": "./model", | ||
"out_dir": "./output", | ||
"cache_dir": "./cache", | ||
"softmax_samples": 512, | ||
"mode": "train", | ||
"grad_clip": 5.0, | ||
"dataset": "MSCOCO", | ||
"cuda": true | ||
} | ||
``` | ||
|
||
|
||
|
||
Following text are shown in this model: | ||
|
||
- args | ||
|
||
|
||
|
||
### An example of test output | ||
|
||
Execute ``python run.py --mode test --restore last`` | ||
|
||
The following output will be in `./output/[name]_test.txt`: | ||
|
||
``` | ||
self-bleu-3: 0.709417 | ||
bw-bleu-3: 0.513164 | ||
self-bleu-4: 0.515631 | ||
bw-bleu-4: 0.336216 | ||
self-bleu-2: 0.870640 | ||
fw-bw-bleu-3: 0.550495 | ||
perplexity: 13.409582 | ||
fw-bleu-4: 0.371952 | ||
fw-bleu-2: 0.836639 | ||
fw-bw-bleu-4: 0.353182 | ||
fw-bleu-3: 0.593684 | ||
bw-bleu-2: 0.723493 | ||
fw-bw-bleu-2: 0.775963 | ||
A man and motorcycle parked on front of a building . | ||
A old desk with a computers computers of computers . | ||
A people in to a table in to a tree sign | ||
A man dog with a in in a bathroom bathroom . | ||
A man young plate to be served by a people . | ||
A man is on a bench next front park of a woods . | ||
A man is a piece feeder 's a in | ||
A man of a old man in a small . | ||
A man is on a bench next looking dog is on her fence . the park station . | ||
A old man is standing a orange tie . | ||
A man is in a dog in to a . | ||
A old desk a and and a and desk . | ||
A of and and a black and a . | ||
A man is a man sitting to a table sign | ||
A up of a plate plate with with a stove pot . the . | ||
A man plate top oven in of kitchen . | ||
A man is on a small cars in a street field . | ||
A man of people are down a street lined street . | ||
A man is in two legs and in a couch . a . | ||
A man is two legs of a person on out a television . | ||
A in a helmet a on front of a building of people . | ||
A man is standing on a dog on a laptop . | ||
... | ||
``` | ||
|
||
#### For developer | ||
|
||
- You should remain similar output in this task. | ||
|
||
|
||
|
||
### Performance | ||
|
||
| | Reconstruction Perplexity | | ||
| ------ | ------------------------- | | ||
| MSCOCO | 13.41 | | ||
|
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import os | ||
|
||
import numpy as np | ||
import tensorflow as tf | ||
from cotk.dataloader import LanguageGeneration | ||
from cotk.wordvector import WordVector, Glove | ||
from utils import debug, try_cache | ||
|
||
from model import LMModel | ||
|
||
def create_model(sess, data, args, embed): | ||
with tf.variable_scope(args.name): | ||
model = LMModel(data, args, embed) | ||
model.print_parameters() | ||
latest_dir = '%s/checkpoint_latest' % args.model_dir | ||
best_dir = '%s/checkpoint_best' % args.model_dir | ||
if tf.train.get_checkpoint_state(latest_dir) and args.restore == "last": | ||
print("Reading model parameters from %s" % latest_dir) | ||
model.latest_saver.restore(sess, tf.train.latest_checkpoint(latest_dir)) | ||
else: | ||
if tf.train.get_checkpoint_state(best_dir) and args.restore == "best": | ||
print('Reading model parameters from %s' % best_dir) | ||
model.best_saver.restore(sess, tf.train.latest_checkpoint(best_dir)) | ||
else: | ||
print("Created model with fresh parameters.") | ||
global_variable = [gv for gv in tf.global_variables() if args.name in gv.name] | ||
sess.run(tf.variables_initializer(global_variable)) | ||
|
||
return model | ||
|
||
|
||
def main(args): | ||
if args.debug: | ||
debug() | ||
|
||
if args.cuda: | ||
config = tf.ConfigProto() | ||
config.gpu_options.allow_growth = True | ||
else: | ||
config = tf.ConfigProto(device_count={'GPU': 0}) | ||
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" | ||
|
||
data_class = LanguageGeneration.load_class(args.dataset) | ||
wordvec_class = WordVector.load_class(args.wvclass) | ||
if wordvec_class == None: | ||
wordvec_class = Glove | ||
if args.cache: | ||
data = try_cache(data_class, (args.datapath,), args.cache_dir) | ||
vocab = data.vocab_list | ||
embed = try_cache(lambda wv, ez, vl: wordvec_class(wv).load(ez, vl), | ||
(args.wvpath, args.embedding_size, vocab), | ||
args.cache_dir, wordvec_class.__name__) | ||
else: | ||
data = data_class(args.datapath) | ||
wv = wordvec_class(args.wvpath) | ||
vocab = data.vocab_list | ||
embed = wv.load(args.embedding_size, vocab) | ||
|
||
embed = np.array(embed, dtype = np.float32) | ||
|
||
with tf.Session(config=config) as sess: | ||
model = create_model(sess, data, args, embed) | ||
if args.mode == "train": | ||
model.train_process(sess, data, args) | ||
else: | ||
model.test_process(sess, data, args) |
Oops, something went wrong.