Merge pull request #2 from yifding/bert_fine_tuning
Bert fine tuning
yifding committed Dec 18, 2020
2 parents d06fdf6 + 7b13178 commit 98391fe
Showing 36 changed files with 954 additions and 59 deletions.
5 changes: 3 additions & 2 deletions .gitignore
@@ -8,5 +8,6 @@ support/
 __pycache__/
 *.so
 *.cpp
-BERT_DATA.py
-BERT_DATA_test.py
+hetseq/data/BERT_DATA.py
+hetseq/data/BERT_DATA_test.py
+test/script
6 changes: 3 additions & 3 deletions docs/source/dataset.rst
@@ -1,8 +1,8 @@
 Dataset
 -------

-.. autoclass:: data.h5pyDataset.BertH5pyData
+.. autoclass:: hetseq.data.h5pyDataset.BertH5pyData

-.. autoclass:: data.h5pyDataset.ConBertH5pyData
+.. autoclass:: hetseq.data.h5pyDataset.ConBertH5pyData

-.. autoclass:: data.mnist_dataset.MNISTDataset
+.. autoclass:: hetseq.data.mnist_dataset.MNISTDataset
18 changes: 9 additions & 9 deletions docs/source/examples.rst
@@ -12,7 +12,7 @@ BERT Task
 .. code-block:: console

     $ DIST=~/hetseq
-    $ python3 ${DIST}/train.py \
+    $ python3 ${DIST}/hetseq/train.py \
             --task bert --data ${DIST}/preprocessing/test_128/ \
             --dict ${DIST}/preprocessing/uncased_L-12_H-768_A-12/vocab.txt \
             --config_file ${DIST}/preprocessing/uncased_L-12_H-768_A-12/bert_config.json \
@@ -27,7 +27,7 @@ BERT Task
 .. code:: console

     $ DIST=~/hetseq
-    $ python3 ${DIST}/train.py \
+    $ python3 ${DIST}/hetseq/train.py \
             --task bert --data ${DIST}/preprocessing/test_128/ \
             --dict ${DIST}/preprocessing/uncased_L-12_H-768_A-12/vocab.txt \
             --config_file ${DIST}/preprocessing/uncased_L-12_H-768_A-12/bert_config.json \
@@ -45,7 +45,7 @@ BERT Task
 .. code:: console

     $ DIST=~/hetseq
-    $ python3 ${DIST}/train.py \
+    $ python3 ${DIST}/hetseq/train.py \
             --task bert --data ${DIST}/preprocessing/test_128/ \
             --dict ${DIST}/preprocessing/uncased_L-12_H-768_A-12/vocab.txt \
             --config_file ${DIST}/preprocessing/uncased_L-12_H-768_A-12/bert_config.json \
@@ -61,7 +61,7 @@ BERT Task
 .. code:: console

     $ DIST=~/hetseq
-    $ python3 ${DIST}/train.py \
+    $ python3 ${DIST}/hetseq/train.py \
             --task bert --data ${DIST}/preprocessing/test_128/ \
             --dict ${DIST}/preprocessing/uncased_L-12_H-768_A-12/vocab.txt \
             --config_file ${DIST}/preprocessing/uncased_L-12_H-768_A-12/bert_config.json \
@@ -82,7 +82,7 @@ MNIST Task
 .. code:: console

     $ DIST=~/hetseq
-    $ python3 ${DIST}/train.py \
+    $ python3 ${DIST}/hetseq/train.py \
             --task mnist --optimizer adadelta --lr-scheduler PolynomialDecayScheduler \
             --data ${DIST} --clip-norm 100 \
             --max-sentences 64 --fast-stat-sync --max-epoch 20 --update-freq 1 \
@@ -96,7 +96,7 @@ MNIST Task
 .. code:: console

     $ DIST=~/hetseq
-    $ python3 ${DIST}/train.py \
+    $ python3 ${DIST}/hetseq/train.py \
             --task mnist --optimizer adadelta --lr-scheduler PolynomialDecayScheduler \
             --data ${DIST} --clip-norm 100 \
             --max-sentences 64 --fast-stat-sync --max-epoch 20 --update-freq 1 \
@@ -112,7 +112,7 @@ MNIST Task
 .. code:: console

     $ DIST=~/hetseq
-    $ python3 ${DIST}/train.py \
+    $ python3 ${DIST}/hetseq/train.py \
             --task mnist --optimizer adadelta --lr-scheduler PolynomialDecayScheduler \
             --data ${DIST} --clip-norm 100 \
             --max-sentences 64 --fast-stat-sync --max-epoch 20 --update-freq 1 \
@@ -127,7 +127,7 @@ MNIST Task
 .. code:: console

     $ DIST=~/hetseq
-    $ python3 ${DIST}/train.py \
+    $ python3 ${DIST}/hetseq/train.py \
             --task mnist --optimizer adadelta --lr-scheduler PolynomialDecayScheduler \
             --data ${DIST} --clip-norm 100 \
             --max-sentences 64 --fast-stat-sync --max-epoch 20 --update-freq 1 \
@@ -143,4 +143,4 @@ Evaluate MNIST Task
 .. code:: console

     $ DIST=~/hetseq
-    $ python3 ${DIST}/eval_mnist.py --model_ckpt /path/to/check/point --mnist_dir ${DIST}
+    $ python3 ${DIST}/hetseq/eval_mnist.py --model_ckpt /path/to/check/point --mnist_dir ${DIST}
2 changes: 1 addition & 1 deletion docs/source/lr_scheduler.rst
@@ -1,4 +1,4 @@
 Learning Rate Scheduler
 -----------------------

-.. autoclass:: lr_scheduler.PolynomialDecayScheduler
+.. autoclass:: hetseq.lr_scheduler.PolynomialDecayScheduler
6 changes: 3 additions & 3 deletions docs/source/meters.rst
@@ -1,8 +1,8 @@
 Meters
 ------
-.. autoclass:: meters.AverageMeter
+.. autoclass:: hetseq.meters.AverageMeter

-.. autoclass:: meters.TimeMeter
+.. autoclass:: hetseq.meters.TimeMeter

-.. autoclass:: meters.StopwatchMeter
+.. autoclass:: hetseq.meters.StopwatchMeter

6 changes: 3 additions & 3 deletions docs/source/model.rst
@@ -1,8 +1,8 @@
 Model
 -----

-.. autoclass:: tasks.MNISTNet
+.. autoclass:: hetseq.tasks.MNISTNet

-.. autoclass:: bert_modeling.BertConfig
+.. autoclass:: hetseq.bert_modeling.BertConfig

-.. autoclass:: bert_modeling.BertForPreTraining
+.. autoclass:: hetseq.bert_modeling.BertForPreTraining
4 changes: 2 additions & 2 deletions docs/source/optimizer.rst
@@ -1,6 +1,6 @@
 Optimizer
 ---------

-.. autoclass:: optim.Adam
+.. autoclass:: hetseq.optim.Adam

-.. autoclass:: optim.Adadelta
+.. autoclass:: hetseq.optim.Adadelta
6 changes: 3 additions & 3 deletions docs/source/progress_bar.rst
@@ -1,11 +1,11 @@
 Progress_bar
 ------------

-.. autoclass:: progress_bar.progress_bar
+.. autoclass:: hetseq.progress_bar.progress_bar

-.. autoclass:: progress_bar.noop_progress_bar
+.. autoclass:: hetseq.progress_bar.noop_progress_bar

-.. autoclass:: progress_bar.simple_progress_bar
+.. autoclass:: hetseq.progress_bar.simple_progress_bar



6 changes: 3 additions & 3 deletions docs/source/task.rst
@@ -1,8 +1,8 @@
 Task
 ----

-.. autoclass:: tasks.Task
+.. autoclass:: hetseq.tasks.Task

-.. autoclass:: tasks.MNISTNet
+.. autoclass:: hetseq.tasks.MNISTNet

-.. autoclass:: tasks.LanguageModelingTask
+.. autoclass:: hetseq.tasks.LanguageModelingTask
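
The `autoclass` targets in the docs now carry the `hetseq.` prefix because the modules moved into the `hetseq` package. Roughly speaking, Sphinx autodoc resolves such a target by importing the module part of the dotted path and looking up the final attribute. The sketch below illustrates that lookup with the standard library only. `resolve_dotted` is a hypothetical helper written for this note (it is not part of Sphinx or hetseq), and a standard-library class stands in for a target such as `hetseq.tasks.Task` so the snippet runs without hetseq installed.

```python
import importlib


def resolve_dotted(path):
    """Import the module part of ``path`` and return the named attribute."""
    module_name, _, attr = path.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, attr)


# A standard-library class stands in for a target such as hetseq.tasks.Task.
cls = resolve_dotted("collections.OrderedDict")
print(cls.__name__)
```

With the old, unprefixed paths (for example `tasks.Task`), the same lookup only succeeds if the repository root happens to be on `sys.path`, which is presumably why the docs were updated together with the package move.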
File renamed without changes.
2 changes: 1 addition & 1 deletion bert_modeling.py → hetseq/bert_modeling.py
@@ -18,7 +18,7 @@
 from torch.nn import CrossEntropyLoss
 from torch.utils import checkpoint

-from file_utils import cached_path
+from hetseq.file_utils import cached_path

 from torch.nn import Module
 from torch.nn.parameter import Parameter
File renamed without changes.
17 changes: 9 additions & 8 deletions controller.py → hetseq/controller.py
@@ -8,14 +8,15 @@
 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as DDP

-import utils
-import optim
-import lr_scheduler
-import checkpoint_utils
-import distributed_utils
-
-
-from meters import AverageMeter, StopwatchMeter, TimeMeter
+from hetseq import (
+    utils,
+    optim,
+    lr_scheduler,
+    checkpoint_utils,
+    distributed_utils,
+)
+
+from hetseq.meters import AverageMeter, StopwatchMeter, TimeMeter


 class Controller(object):
121 changes: 121 additions & 0 deletions hetseq/data/BERT_DATA.py
@@ -0,0 +1,121 @@
+import os
+import collections
+
+import h5py
+from tqdm import tqdm
+
+import numpy as np
+import torch.utils.data
+
+
+# NOTE: this method uses too much memory (7% of memory, about 20 GB)
+# and takes ~25 min to load the data.
+class CombineBertData(torch.utils.data.Dataset):
+    def __init__(self, files, max_pred_length=512,
+                 keys=('input_ids', 'input_mask', 'segment_ids',
+                       'masked_lm_positions', 'masked_lm_ids', 'next_sentence_labels')):
+        # Reading could potentially be sped up with multiple processes/threads.
+        # If the input is too large, consider other strategies to load the data.
+        self.max_pred_length = max_pred_length
+        self.keys = keys
+        self.inputs = collections.OrderedDict()
+
+        for key in self.keys:
+            self.inputs[key] = []
+
+        for input_file in tqdm(files):
+            f = h5py.File(input_file, "r")
+            for i, key in enumerate(keys):
+                if i < 5:
+                    self.inputs[key].append(f[key][:])
+                else:
+                    self.inputs[key].append(np.asarray(f[key][:]))
+            f.close()
+
+        # Join the per-file arrays into one array per key.
+        for key in self.inputs:
+            self.inputs[key] = np.concatenate(self.inputs[key])
+
+    def __len__(self):
+        return len(self.inputs[self.keys[0]])
+
+    def __getitem__(self, index):
+        # Returned in key order: ['input_ids', 'input_mask', 'segment_ids',
+        # 'masked_lm_positions', 'masked_lm_ids', 'next_sentence_labels']
+        return [self.inputs[key][index] for key in self.keys]
+
+
+class BertTask(object):
+    def __init__(self, args):
+        self.args = args  # load_dataset reads self.args.data
+        self.dict = self.load_vocab(args.dict)
+        self.seed = args.seed
+        self.datasets = {}
+
+    @staticmethod
+    def load_vocab(vocab_file):
+        """Loads a vocabulary file into a dictionary."""
+        vocab = collections.OrderedDict()
+        index = 0
+        with open(vocab_file, "r", encoding="utf-8") as reader:
+            while True:
+                token = reader.readline()
+                if not token:
+                    break
+                token = token.strip()
+                vocab[token] = index
+                index += 1
+        return vocab
+
+    def load_dataset(self, split='train'):
+        """Combine multiple files into a single dataset.
+
+        Args:
+            split (str): name that must be included in the file name
+                (e.g., train, valid, test)
+        """
+        path = self.args.data
+        if not os.path.exists(path):
+            raise FileNotFoundError(
+                "Dataset not found: ({})".format(path)
+            )
+
+        # os.listdir returns bare names, so join them with the directory.
+        files = [os.path.join(path, f) for f in os.listdir(path)] if os.path.isdir(path) else [path]
+        files = [f for f in files if split in f]
+        assert len(files) > 0
+
+        self.datasets[split] = CombineBertData(files)
+
+
+"""
+dataset = data_utils.load_indexed_dataset(
+    split_path, self.dictionary, self.args.dataset_impl, combine=combine
+)
+if dataset is None:
+    raise FileNotFoundError(
+        "Dataset not found: {} ({})".format(split, split_path)
+    )
+dataset = TokenBlockDataset(
+    dataset,
+    dataset.sizes,
+    self.args.tokens_per_sample,
+    pad=self.dictionary.pad(),
+    eos=self.dictionary.eos(),
+    break_mode=self.args.sample_break_mode,
+    include_targets=True,
+)
+add_eos_for_other_targets = (
+    self.args.sample_break_mode is not None
+    and self.args.sample_break_mode != "none"
+)
+self.datasets[split] = MonolingualDataset(
+    dataset,
+    dataset.sizes,
+    self.dictionary,
+    self.output_dictionary,
+    add_eos_for_other_targets=add_eos_for_other_targets,
+    shuffle=True,
+    targets=self.targets,
+    add_bos_token=self.args.add_bos_token,
+)
+"""
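
The `load_vocab` helper in this file maps each line of a vocabulary file to its zero-based line number. The sketch below shows the same mapping written with `enumerate`. It is an illustrative variant and not the code from the commit. The three-token vocabulary and the temporary file are made up for the example.

```python
import collections
import os
import tempfile


def load_vocab(vocab_file):
    """Map each token (one per line) to its zero-based line index."""
    vocab = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        for index, line in enumerate(reader):
            vocab[line.strip()] = index
    return vocab


# Hypothetical three-token vocabulary written to a temporary file.
with tempfile.TemporaryDirectory() as tmp:
    vocab_path = os.path.join(tmp, "vocab.txt")
    with open(vocab_path, "w", encoding="utf-8") as writer:
        writer.write("[PAD]\n[UNK]\nhello\n")
    vocab = load_vocab(vocab_path)

print(vocab["hello"])  # → 2
```

Iterating over the file object stops at end of file just as the `readline` loop does, so blank lines inside the file are still counted and mapped to the empty string in both versions.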
22 changes: 22 additions & 0 deletions hetseq/data/BERT_DATA_test.py
@@ -0,0 +1,22 @@
+import os
+from BERT_DATA import CombineBertData
+
+split = "train"
+files = "/scratch365/yding4/bert_project/bert_prep_working_dir/" \
+        "hdf5_lower_case_1_seq_len_128_max_pred_20_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5/wikicorpus_en"
+
+
+
+path = files
+if not os.path.exists(path):
+    raise FileNotFoundError(
+        "Dataset not found: ({})".format(path)
+    )
+
+files = [os.path.join(path, f) for f in os.listdir(path)] if os.path.isdir(path) else [path]
+print(files)
+files = [f for f in files if split in f]
+print(files)
+assert len(files) > 0
+
+dataset = CombineBertData(files)  # module-level script, so there is no `self` to assign to
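
The test script above collects the HDF5 shards for one split by listing a directory and keeping the entries whose name contains the split. A minimal, standard-library sketch of that step follows. `find_split_files` and the file names are hypothetical, chosen for illustration. Unlike the script, the sketch matches `split` against the file name only, so a directory that happens to contain the split string in its path does not select every file.

```python
import os
import tempfile


def find_split_files(path, split):
    """Return full paths under ``path`` whose file name contains ``split``."""
    if not os.path.exists(path):
        raise FileNotFoundError("Dataset not found: ({})".format(path))
    if os.path.isdir(path):
        # os.listdir yields bare names, so join each with the directory.
        candidates = [os.path.join(path, name) for name in sorted(os.listdir(path))]
    else:
        candidates = [path]
    return [f for f in candidates if split in os.path.basename(f)]


# Hypothetical shard names created as empty files in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("train_0.hdf5", "train_1.hdf5", "test_0.hdf5"):
        open(os.path.join(tmp, name), "w").close()
    found = [os.path.basename(f) for f in find_split_files(tmp, "train")]

print(found)  # → ['train_0.hdf5', 'train_1.hdf5']
```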
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion data/iterators.py → hetseq/data/iterators.py
@@ -5,7 +5,7 @@
 import numpy as np
 import torch

-from data import data_utils
+from hetseq.data import data_utils

 class CountingIterator(object):
     """Wrapper around an iterable that maintains the iteration count.
File renamed without changes.
2 changes: 1 addition & 1 deletion distributed_utils.py → hetseq/distributed_utils.py
@@ -5,7 +5,7 @@
 import torch
 import torch.distributed as dist

-import utils
+from hetseq import utils


 def distributed_init(args):
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion lr_scheduler.py → hetseq/lr_scheduler.py
@@ -1,6 +1,6 @@
 import math
 import torch
-from optim import _Optimizer
+from hetseq.optim import _Optimizer


 class _LRScheduler(object):
File renamed without changes.
File renamed without changes.
