
MLE-Guided Parameter Search (AAAI 2021)


wellecks/mgs


MLE-Guided Parameter Search (MGS)

PyTorch implementation of the paper:

MLE-Guided Parameter Search for Task Loss Minimization in Neural Sequence Modeling
Sean Welleck, Kyunghyun Cho
AAAI 2021

Main Logic

For a quick overview of MGS's main logic, see this section from its training loop.
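The core idea can be illustrated with a small, framework-free sketch. This is not the repository's implementation. It assumes a flat parameter list, a black-box `task_loss` function, and a precomputed MLE gradient `mle_grad`. Candidate update directions are drawn from a mixture of Gaussian random search around the current parameters and around the MLE gradient step. They are then combined using softmax weights over their `beta`-scaled task-loss improvements. The helper `mgs_update` and its parameters (`num_candidates`, `noise_std`, `lr`, `mix`) are hypothetical. `beta` only loosely mirrors the `--ggs-beta` flag, and the paper's exact weighting scheme may differ.

```python
import math
import random
from typing import Callable, List, Optional

Vector = List[float]


def mgs_update(
    theta: Vector,
    mle_grad: Vector,
    task_loss: Callable[[Vector], float],
    num_candidates: int = 8,
    noise_std: float = 0.01,
    beta: float = 1.0,
    lr: float = 0.1,
    mix: float = 0.5,
    rng: Optional[random.Random] = None,
) -> Vector:
    """Hypothetical sketch of one MGS-style update (not the repository's code)."""
    rng = rng or random.Random(0)
    base_loss = task_loss(theta)
    mle_step = [-lr * g for g in mle_grad]  # descent step on the MLE loss
    deltas, improvements = [], []
    for _ in range(num_candidates):
        # Mixture proposal: with probability `mix`, search around the MLE step;
        # otherwise search around the current parameters (zero direction).
        center = mle_step if rng.random() < mix else [0.0] * len(theta)
        delta = [c + rng.gauss(0.0, noise_std) for c in center]
        candidate = [t + d for t, d in zip(theta, delta)]
        deltas.append(delta)
        # Positive improvement means the candidate lowers the task loss.
        improvements.append(base_loss - task_loss(candidate))
    # Softmax over beta-scaled improvements (max subtracted for numerical stability).
    scaled = [beta * imp for imp in improvements]
    top = max(scaled)
    exps = [math.exp(s - top) for s in scaled]
    total = sum(exps)
    weights = [e / total for e in exps]
    # The update is the weighted average of the candidate directions.
    update = [sum(w * d[i] for w, d in zip(weights, deltas)) for i in range(len(theta))]
    return [t + u for t, u in zip(theta, update)]
```

For example, with `task_loss = lambda p: sum(x * x for x in p)`, `theta = [1.0, 1.0]`, `mle_grad = [2.0, 2.0]` and `mix=1.0`, every candidate lies near the MLE step. The returned parameters are therefore close to `[0.8, 0.8]`.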

Installation

python setup.py develop

Data

For downloading the datasets below, it may be helpful to use gdown.pl.

Pretrained Models

We provide an example base MLE model and example models finetuned with MGS, PG, and MRT.
Note that metrics in the paper were computed using 5 models per method, each initialized with a different random seed.
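Because each reported number pools five seeds, a per-method summary is usually a mean together with a measure of spread. Here is a minimal sketch that assumes mean and sample standard deviation; the paper's exact aggregation may differ. The function name `summarize_seeds` is hypothetical.

```python
import statistics
from typing import Dict, List, Tuple


def summarize_seeds(per_seed: Dict[str, List[float]]) -> Dict[str, Tuple[float, float]]:
    """Map each method name to (mean, sample std) of its per-seed metric values."""
    summary = {}
    for method, values in per_seed.items():
        if len(values) < 2:
            raise ValueError(f"{method}: need at least two seeds for a std")
        summary[method] = (statistics.mean(values), statistics.stdev(values))
    return summary
```

A typical call is `summarize_seeds({"MGS-LM": mgs_scores, "MLE": mle_scores})`, where each list holds one metric's five per-seed values.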

Method
MLE
MGS-LM
MGS-LM (ancestral)
PG-LM
MRT-LM

Example commands

Below we show example commands for each stage of the pipeline.
The experiments in the paper were run with a script external to this repository.

Finetune starting from an MLE-finetuned model

# MGS
python seq_level/gpt2/train.py \
  --loss ggs \
  --ggs-metric lm \
  --ggs-beta 1.0 \
  --model-load-dir /path/to/mle_model

# PG
python seq_level/gpt2/train.py \
  --loss pg \
  --ggs-metric lm \
  --pg-normalize-distance 1 \
  --pg-mle-mix 0.1 \
  --pg-baseline avg \
  --model-load-dir /path/to/mle_model
  
# MRT
python seq_level/gpt2/train.py \
  --loss mrt \
  --ggs-metric lm \
  --mrt-normalize-distance 1 \
  --mrt-mle-mix 0.1 \
  --model-load-dir /path/to/mle_model
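The PG and MRT baselines differ mainly in how the task losses of sampled sequences become a training signal. The sketch below illustrates this under assumptions and is not the repository's code. It works on plain floats, whereas in practice the log-probabilities would be PyTorch tensors so the loss can be backpropagated. It also makes three interpretive choices:

- Reward is the negative task cost.
- The baseline is the batch-average reward, as `--pg-baseline avg` suggests.
- `--pg-mle-mix` / `--mrt-mle-mix` act as an interpolation weight between the sequence-level loss and the MLE loss.

The `--*-normalize-distance` flags are not modeled. All function names are hypothetical.

```python
import math
from typing import List


def pg_loss(logps: List[float], costs: List[float]) -> float:
    """REINFORCE-style loss with an average-reward baseline (hypothetical sketch).

    logps[k] is the sequence log-probability of sample k; costs[k] is its task loss.
    """
    rewards = [-c for c in costs]  # lower task loss -> higher reward
    baseline = sum(rewards) / len(rewards)  # "avg" baseline
    # Minimizing this raises the log-probs of samples that beat the baseline.
    return -sum((r - baseline) * lp for r, lp in zip(rewards, logps)) / len(logps)


def mrt_loss(logps: List[float], costs: List[float], alpha: float = 1.0) -> float:
    """Minimum risk training: expected cost under the model distribution,
    renormalized over the sampled candidates (hypothetical sketch)."""
    scaled = [alpha * lp for lp in logps]
    top = max(scaled)
    exps = [math.exp(s - top) for s in scaled]
    z = sum(exps)
    return sum((e / z) * c for e, c in zip(exps, costs))


def mixed_loss(seq_loss: float, mle_loss: float, mle_mix: float) -> float:
    """Interpolate a sequence-level loss with the token-level MLE loss (assumed form)."""
    return (1.0 - mle_mix) * seq_loss + mle_mix * mle_loss
```

When all sampled sequences have the same cost, the PG loss is zero because every advantage is zero. In that case only the MLE term contributes a gradient.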

Finetune MLE

python seq_level/gpt2/train.py \
  --loss mle \
  --valid-every 5000 \
  --print-every 100
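MLE finetuning minimizes the token-level negative log-likelihood of the training text. The sketch below shows that objective and the corresponding perplexity. It assumes per-step logits given as plain Python lists, whereas in the repository these would come from GPT-2 as tensors. The names `token_nll` and `perplexity` are hypothetical.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one vocabulary distribution."""
    top = max(logits)
    log_z = top + math.log(sum(math.exp(x - top) for x in logits))
    return [x - log_z for x in logits]


def token_nll(logits_per_step: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Mean negative log-likelihood of the target tokens (the MLE loss)."""
    losses = [-log_softmax(logits)[t] for logits, t in zip(logits_per_step, targets)]
    return sum(losses) / len(losses)


def perplexity(mean_nll: float) -> float:
    """Perplexity is the exponential of the mean per-token NLL."""
    return math.exp(mean_nll)
```

With uniform logits over a 4-token vocabulary, every step costs `log 4` nats, so the perplexity is 4.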

Evaluate

# --eval-split: valid | test
# --eval-decoder: greedy | temp-1.0
python seq_level/gpt2/train.py --mode eval \
  --eval-split valid \
  --score-model-load-dir /path/to/mle_model \
  --model-load-dir /path/to/model \
  --eval-decoder greedy \
  --token-limit-eval 500 \
  --eval-decode-max-length 500 \
  --chunk-size-valid 512 \
  --loss ggs \
  --ggs-metric lm
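The two `--eval-decoder` options compare greedy decoding with temperature sampling (`temp-1.0`, i.e. ancestral sampling from the model's distribution). The sketch below shows one decoding step for each. It assumes next-token logits given as a plain list. The function `next_token` is hypothetical and is not the repository's decoder.

```python
import math
import random
from typing import Optional, Sequence


def next_token(logits: Sequence[float], decoder: str = "greedy",
               rng: Optional[random.Random] = None) -> int:
    """Pick the next token id with greedy or temperature sampling (hypothetical sketch)."""
    if decoder == "greedy":
        # Deterministic: the highest-scoring token.
        return max(range(len(logits)), key=lambda i: logits[i])
    if decoder.startswith("temp-"):
        temperature = float(decoder[len("temp-"):])
        scaled = [x / temperature for x in logits]
        top = max(scaled)
        weights = [math.exp(s - top) for s in scaled]  # unnormalized softmax
        rng = rng or random.Random()
        return rng.choices(range(len(logits)), weights=weights, k=1)[0]
    raise ValueError(f"unknown decoder: {decoder}")
```

Greedy decoding always returns the same continuation for a given prefix, while `temp-1.0` can return different tokens across runs. Evaluations under the two decoders can therefore differ.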

Preprocess raw wikitext

Not needed if you downloaded the dataset above.

python seq_level/gpt2/prepare_wikitext.py --data-dir /path/to/wikitext-raw
