-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2 from yifding/bert_fine_tuning
Bert fine tuning
- Loading branch information
Showing
36 changed files
with
954 additions
and
59 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,8 @@ | ||
Dataset | ||
------- | ||
|
||
.. autoclass:: data.h5pyDataset.BertH5pyData | ||
.. autoclass:: hetseq.data.h5pyDataset.BertH5pyData | ||
|
||
.. autoclass:: data.h5pyDataset.ConBertH5pyData | ||
.. autoclass:: hetseq.data.h5pyDataset.ConBertH5pyData | ||
|
||
.. autoclass:: data.mnist_dataset.MNISTDataset | ||
.. autoclass:: hetseq.data.mnist_dataset.MNISTDataset |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
Learning Rate Scheduler | ||
----------------------- | ||
|
||
.. autoclass:: lr_scheduler.PolynomialDecayScheduler | ||
.. autoclass:: hetseq.lr_scheduler.PolynomialDecayScheduler |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,8 @@ | ||
Meters | ||
------ | ||
.. autoclass:: meters.AverageMeter | ||
.. autoclass:: hetseq.meters.AverageMeter | ||
|
||
.. autoclass:: meters.TimeMeter | ||
.. autoclass:: hetseq.meters.TimeMeter | ||
|
||
.. autoclass:: meters.StopwatchMeter | ||
.. autoclass:: hetseq.meters.StopwatchMeter | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,8 @@ | ||
Model | ||
----- | ||
|
||
.. autoclass:: tasks.MNISTNet | ||
.. autoclass:: hetseq.tasks.MNISTNet | ||
|
||
.. autoclass:: bert_modeling.BertConfig | ||
.. autoclass:: hetseq.bert_modeling.BertConfig | ||
|
||
.. autoclass:: bert_modeling.BertForPreTraining | ||
.. autoclass:: hetseq.bert_modeling.BertForPreTraining |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
Optimizer | ||
--------- | ||
|
||
.. autoclass:: optim.Adam | ||
.. autoclass:: hetseq.optim.Adam | ||
|
||
.. autoclass:: optim.Adadelta | ||
.. autoclass:: hetseq.optim.Adadelta |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,8 @@ | ||
Task | ||
---- | ||
|
||
.. autoclass:: tasks.Task | ||
.. autoclass:: hetseq.tasks.Task | ||
|
||
.. autoclass:: tasks.MNISTNet | ||
.. autoclass:: hetseq.tasks.MNISTNet | ||
|
||
.. autoclass:: tasks.LanguageModelingTask | ||
.. autoclass:: hetseq.tasks.LanguageModelingTask |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
import os | ||
import collections | ||
|
||
import h5py | ||
from tqdm import tqdm | ||
|
||
import numpy as np | ||
import torch.utils.data | ||
|
||
|
||
# # method is too large(memory 7% took 20GB) and ~25min to load data. | ||
class CombineBertData(torch.utils.data.Dataset): | ||
def __init__(self, files, max_pred_length=512, | ||
keys=('input_ids', 'input_mask', 'segment_ids', | ||
'masked_lm_positions', 'masked_lm_ids','next_sentence_labels')): | ||
# can potentially using multiple processing/thread to speed up reading | ||
# if input is too large, considering other strategies to load data | ||
self.max_pred_length = max_pred_length | ||
self.keys = keys | ||
self.inputs = collections.OrderedDict() | ||
|
||
for key in self.keys: | ||
self.inputs[key] = [] | ||
|
||
for input_file in tqdm(files): | ||
f = h5py.File(input_file, "r") | ||
for i, key in enumerate(keys): | ||
if i < 5: | ||
self.inputs[key].append(f[key][:]) | ||
else: | ||
self.inputs[key].append(np.asarray(f[key][:])) | ||
f.close() | ||
|
||
for key in self.inputs: | ||
self.inputs[key] = np.concatenate(self.inputs[key]) | ||
|
||
def __len__(self): | ||
return len(self.inputs[self.keys[0]]) | ||
|
||
def __getitem__(self, index): | ||
return [self.inputs[key][index] for key in self.keys] | ||
# ['input_ids', 'input_mask', 'segment_ids','masked_lm_positions', | ||
# 'masked_lm_ids','next_sentence_labels'] | ||
|
||
|
||
class BertTask(object): | ||
def __init__(self, args): | ||
self.args = self.load() | ||
self.dict = self.load_vocab(args.dict) | ||
self.seed = args.seed | ||
self.datasets = {} | ||
|
||
@staticmethod | ||
def load_vocab(vocab_file): | ||
"""Loads a vocabulary file into a dictionary.""" | ||
vocab = collections.OrderedDict() | ||
index = 0 | ||
with open(vocab_file, "r", encoding="utf-8") as reader: | ||
while True: | ||
token = reader.readline() | ||
if not token: | ||
break | ||
token = token.strip() | ||
vocab[token] = index | ||
index += 1 | ||
return vocab | ||
|
||
def load_dataset(self, split='train'): | ||
"""combine multiple files into one single dataset | ||
Args: | ||
split (str): name must included in the file(e.g., train, valid, test) | ||
""" | ||
path = self.args.data | ||
if not os.path.exists(path): | ||
raise FileNotFoundError( | ||
"Dataset not found: ({})".format(path) | ||
) | ||
|
||
files = os.listdir(path) if os.path.isdir(path) else [path] | ||
files = [f for f in files if split in f] | ||
assert len(files) > 0 | ||
|
||
self.datasets[split] = CombineBertData(files) | ||
|
||
|
||
""" | ||
dataset = data_utils.load_indexed_dataset( | ||
split_path, self.dictionary, self.args.dataset_impl, combine=combine | ||
) | ||
if dataset is None: | ||
raise FileNotFoundError( | ||
"Dataset not found: {} ({})".format(split, split_path) | ||
) | ||
dataset = TokenBlockDataset( | ||
dataset, | ||
dataset.sizes, | ||
self.args.tokens_per_sample, | ||
pad=self.dictionary.pad(), | ||
eos=self.dictionary.eos(), | ||
break_mode=self.args.sample_break_mode, | ||
include_targets=True, | ||
) | ||
add_eos_for_other_targets = ( | ||
self.args.sample_break_mode is not None | ||
and self.args.sample_break_mode != "none" | ||
) | ||
self.datasets[split] = MonolingualDataset( | ||
dataset, | ||
dataset.sizes, | ||
self.dictionary, | ||
self.output_dictionary, | ||
add_eos_for_other_targets=add_eos_for_other_targets, | ||
shuffle=True, | ||
targets=self.targets, | ||
add_bos_token=self.args.add_bos_token, | ||
) | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import os | ||
from BERT_DATA import CombineBertData | ||
|
||
split = "train" | ||
files = "/scratch365/yding4/bert_project/bert_prep_working_dir/" \ | ||
"hdf5_lower_case_1_seq_len_128_max_pred_20_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5/wikicorpus_en" | ||
|
||
|
||
|
||
path = files | ||
if not os.path.exists(path): | ||
raise FileNotFoundError( | ||
"Dataset not found: ({})".format(path) | ||
) | ||
|
||
files = [os.path.join(path, f) for f in os.listdir(path)] if os.path.isdir(path) else [path] | ||
print(files) | ||
files = [f for f in files if split in f] | ||
print(files) | ||
assert len(files) > 0 | ||
|
||
self.datasets[split] = CombineBertData(files) |
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
Oops, something went wrong.