Skip to content

Commit

Permalink
start to add bert_fine_tuning_token_classification task
Browse files Browse the repository at this point in the history
  • Loading branch information
yifding committed Dec 11, 2020
1 parent 6f425c1 commit f22e6ac
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,33 @@ def load_dataset(self, split, **kwargs):
print('| loading finished')


class BertFineTuningTask(Task):
def __init__(self, dictionary):
super(BertFineTuningTask, self).__init__(args)
self.dictionary = dictionary

@classmethod
def setup_task(cls, args, **kwargs):
"""Setup the task (e.g., load dictionaries).
Args:
args (argparse.Namespace): parsed command-line arguments
"""
dictionary = cls.load_dictionary(cls, args.dict)

return cls(args, dictionary)

def build_model(self, args):
if args.task == 'BertForTokenClassification':
from bert_modeling import BertForPreTraining, BertConfig
config = BertConfig.from_json_file(args.config_file)
# mention detection, num_label is by default 3
num_label = args.num_label if hasattr(args, 'num_label') else 3
model = BertForTokenClassification(config, num_label)
else:
raise ValueError('Unknown fine_tunning task!')
return model


class MNISTTask(Task):
def __init__(self, args):
super(MNISTTask, self).__init__(args)
Expand Down

0 comments on commit f22e6ac

Please sign in to comment.