Merge pull request #2 from ShengdingHu/main
add checkpointing & modify logging dir

ShengdingHu committed Oct 4, 2021
2 parents 2b5263a + 142ad5a commit 52b5f26
Showing 11 changed files with 419 additions and 144 deletions.
17 changes: 15 additions & 2 deletions experiments/cli.py
@@ -1,3 +1,4 @@
import os
import sys
sys.path.append(".")

@@ -29,6 +30,8 @@





def get_config():
parser = argparse.ArgumentParser("classification config")
parser.add_argument('--config_yaml', type=str, help='the configuration file for this experiment.')
@@ -47,23 +50,33 @@ def build_dataloader(dataset, template, tokenizer, config, split):
shuffle=config[split].shuffle_data,
teacher_forcing=config[split].teacher_forcing \
if hasattr(config[split],'teacher_forcing') else None,
predict_eos_token=True if config.task=="generation" else False,
**config.dataloader
)
return dataloader

def save_config_to_yaml(config):
from contextlib import redirect_stdout
saved_yaml_path = os.path.join(config.logging.path, "config.yaml")

with open(saved_yaml_path, 'w') as f:
with redirect_stdout(f): print(config.dump())
logger.info("Config saved as {}".format(saved_yaml_path))


def main():
init_logger(log_file_level=logging.DEBUG, log_level=logging.INFO)
config = get_config()
# init logger, create log dir and set log level, etc.
init_logger(config=config)
# save config to the logger directory
save_config_to_yaml(config)
# set seed
set_seed(config)
# load the pretrained language model, along with its tokenizer and config.
plm_model, plm_tokenizer, plm_config = load_plm(config)
# load the dataset; valid_dataset can be None
train_dataset, valid_dataset, test_dataset, Processor = load_dataset(config)


if config.task == "classification":
# define prompt
template = load_template(config=config, model=plm_model, tokenizer=plm_tokenizer, plm_config=plm_config)
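
Note on the new save_config_to_yaml above: it dumps the full YACS config into the experiment's log directory, so every run keeps a reproducible record of its settings. The following is only a hedged sketch of how such a saved config.yaml could be read back with yacs (the path and variable names are illustrative, not part of this commit):

from yacs.config import CfgNode

# path is illustrative; each run writes config.yaml into its own log folder
with open("./logs/DIR_NAME_1/config.yaml") as f:
    restored = CfgNode.load_cfg(f)  # parse the dumped YAML back into a CfgNode

print(restored.train.batch_size)    # the restored config can drive a new run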
7 changes: 2 additions & 5 deletions experiments/generation_manual_template.yaml
@@ -8,11 +8,11 @@ train:
num_epochs: 5
batch_size: 2
teacher_forcing: True
gradient_accumulation_step: 2

generation: # Adding any arguments for generation here.
parent_config: task
max_length: 512
result_path: ./outputs/generation/webnlg/

plm:
model_name: gpt2
@@ -21,7 +21,7 @@ plm:
freeze_para: False

## LEARNING SETTING ####################################################
learning_setting: full # selecting from "full", "zero-shot", "few-shot"
learning_setting: full # selecting from "full", "zero_shot", "few_shot"

# few_shot:
# parent_config: learning_setting
@@ -34,9 +34,6 @@ learning_setting: full # selecting from "full", "zero-shot", "few-shot"
# num_examples_per_label_dev: 100
# seed: 123

dataloader:
predict_eos_token: True

template: manual_template
verbalizer:

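
The gradient_accumulation_step: 2 added above means the optimizer steps only every second mini-batch, so the effective batch size is batch_size * 2. A self-contained sketch of the general pattern with a toy model (illustrative only, not this repository's trainer):

import torch
from torch import nn

model = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
accum_steps = 2  # corresponds to gradient_accumulation_step: 2
batches = [(torch.randn(3, 4), torch.randint(0, 2, (3,))) for _ in range(8)]

optimizer.zero_grad()
for step, (x, y) in enumerate(batches):
    loss = nn.functional.cross_entropy(model(x), y)
    (loss / accum_steps).backward()        # scale so gradients average over the window
    if (step + 1) % accum_steps == 0:      # update weights every N mini-batches
        optimizer.step()
        optimizer.zero_grad()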
11 changes: 4 additions & 7 deletions experiments/generation_prefixtuning.yaml
@@ -6,23 +6,20 @@ task: generation

train:
num_epochs: 5
batch_size: 2
batch_size: 3
teacher_forcing: True
gradient_accumulation_step: 2


generation: # Adding any arguments for generation here.
parent_config: task
max_length: 512
result_path: ./outputs/generation/webnlg/

plm:
model_name: gpt2
model_path: gpt2-medium
optimize:
freeze_para: False

dataloader:
predict_eos_token: True # this is necessary for generation.
freeze_para: True

## LEARNING SETTING ####################################################
learning_setting: full # selecting from "full", "zero-shot", "few-shot"
@@ -52,7 +49,7 @@ prefix_tuning_template:
<text_b>: text_b
prefix_dropout: 0.0
optimize:
lr: 0.0004
lr: 0.00005



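
With freeze_para switched to True and the prefix-tuning learning rate lowered to 0.00005 above, only the prefix template's parameters are trained while the GPT-2 backbone stays frozen. A hedged sketch of that pattern with stand-in modules (illustrative only, not this repository's trainer):

import torch
from torch import nn

plm = nn.Linear(8, 8)                  # stands in for the frozen GPT-2 backbone
prefix_template = nn.Embedding(5, 8)   # stands in for the trainable prefix parameters

for p in plm.parameters():
    p.requires_grad = False            # plm.optimize.freeze_para: True

# only the prefix parameters go into the optimizer, at the configured lr
optimizer = torch.optim.AdamW(prefix_template.parameters(), lr=5e-5)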
7 changes: 5 additions & 2 deletions openprompt/config.py
@@ -33,7 +33,7 @@ def get_conditional_config(config):
deeper_config[key] = config[key]
config.pop(key)

# depth search over all config nodes
# breadth first search over all config nodes
queue = [config]

while len(queue) > 0:
@@ -45,13 +45,16 @@
leaf[1] in deeper_config.keys():
retrieved = deeper_config[leaf[1]]
setattr(config, leaf[1], retrieved)
if isinstance(retrieved, CfgNode):
if isinstance(retrieved, CfgNode):
# also BFS the newly added CfgNode.
queue.append(retrieved)
elif isinstance(leaf[1], CfgNode):
queue.append(leaf[1])
return config




_VALID_TYPES = {tuple, list, str, int, float, bool, type(None)}


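
The change above documents that get_conditional_config walks the config breadth-first and that a sub-config retrieved by name is itself pushed onto the queue, so conditional nodes nested inside retrieved nodes are also resolved. A stripped-down sketch of that idea with toy values (not the library's actual code):

from yacs.config import CfgNode

# a leaf whose string value names a deeper node gets that node attached in place
deeper = {"manual_template": CfgNode({"text": "It was a good movie."})}
config = CfgNode({"task": "classification", "template": "manual_template"})

queue = [config]                        # breadth-first search over all config nodes
while queue:
    node = queue.pop(0)
    for key, value in node.items():
        if isinstance(value, str) and value in deeper:
            retrieved = deeper[value]
            setattr(node, key, retrieved)
            queue.append(retrieved)     # also BFS the newly attached CfgNode
        elif isinstance(value, CfgNode):
            queue.append(value)

print(config.template.text)             # the named sub-config is now nested under template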
84 changes: 64 additions & 20 deletions openprompt/config_default.yaml
@@ -16,7 +16,8 @@ reproduce: # seed for reproduction
cuda_seed: -1 # seed for cuda

plm: # plm parameters
model_name:
model_name: # the model name, e.g. bert, roberta, gpt2, ...
# for all available models, please check the ./plms directory.
model_path:
optimize:
freeze_para: False
@@ -29,51 +30,94 @@ plm: # plm parameters
type: # by default, it will choose get_linear_schedule_with_warmup
num_warmup_steps: 500

## LOGGING and CHECKPOINTING ##################################################
## In logging, each experiment creates a separate folder for saving log.txt,
## the (full) config.yaml, and the checkpoint (if they use the same path).
## The logging directory has the following layout:
## ./log
## - DIR_NAME_1
## - log.txt
## - config.yaml
## - checkpoint.pt
## - ...
## - DIR_NAME_2
## - ...
##
logging:
path_base: ./logs # the base path of all the logs.
file_level: NOTSET # make sure it's a valid level of the logging package
console_level: INFO # make sure it's a valid level of the logging package
unique_string: # the generated (or user-defined) unique string for one experiment.
unique_string_keys: # the keys used to generate the unique string for saving
- dataset.name
- plm.model_path # only the last folder name is kept in code,
# i.e., ../.cache/roberta-large/ will be saved as roberta-large
- template
- verbalizer
- datetime # a 12-digit string recording when the experiment was run, i.e., YYMMDDHHMMSS.
datetime_format: "%y%m%d%H%M%S" # only used when unique_string_keys includes `datetime`.
# make sure it's a valid format string for the datetime package.
path: # always keep this empty to let the config generate a full path from
# path_base and unique_string.
overwrite: True # if the same log path already exists, overwrite it.

checkpoint: # checkpoints use the same directory as logging.
save_lastest: False # normally set to False to reduce disk use; set
# to True to allow resuming the learning process.
save_best: False # keep saving the checkpoint of the best-performing epoch.
higher_better: True # is a higher value of the checkpoint-selection metric better?


## PIPELINE #######################################################

train:
num_epochs: 5
batch_size: 2
shuffle_data: True
teacher_forcing: False

num_epochs: 5 # the number of training epochs.
batch_size: 2 # the batch_size.
shuffle_data: True # whether to shuffle the training data.
teacher_forcing: False # whether to perform teacher forcing in training.
# if True, the desired prediction for each mask will
# be filled into the mask.
gradient_accumulation_step: 1 # update weights every N training steps;
# set to 1 to disable gradient accumulation.

dev:
batch_size: 2
shuffle_data: False
batch_size: 2 # evaluation batch_size; can be a bit larger than the training batch_size
shuffle_data: False # whether to shuffle data in evaluation

test:
batch_size: 2
shuffle_data: False
batch_size: 2 # evaluation batch_size; can be a bit larger than the training batch_size
shuffle_data: False # whether to shuffle data in evaluation

## TASK ###########################################################
task: classification


classification:
parent_config: task
metric:
- micro-f1
metric: # the first metric listed is the one used to determine the best checkpoint;
- micro-f1 # whether a higher value is better is set by checkpoint.higher_better.
loss_function: cross_entropy ## the loss function for classification

generation: # Adding any arguments for generation here.
parent_config: task
max_length: 512 # the max length of the generated sequence, INCLUDING the input_ids, so generation.max_length should be larger than dataloader.max_seq_length
result_path: # the path to save the generated sentences.

metric:
- sentence_bleu

relation_classification:
parent_config: task

## DATASET #########################################################
dataset:
name:
path:
name: # the name of the dataset; for the supported choices,
# please see the processors in ./data_utils/
path: # where the dataset is saved on your local machine.

## DATALOADER ######################################################
dataloader:
max_seq_length: 256
decoder_max_length: 256
predict_eos_token: False # necessary to set to true in generation.
max_seq_length: 256 # the maximum input sequence length
decoder_max_length: 256 # the max length used to truncate the decoder input sequence
# if it is an encoder-decoder architecture. Note that it's not equivalent
# to generation.max_length, which is used only in the generation phase.
truncate_method: "head" # choose from balanced, head, tail

## LEARNING SETTING ####################################################
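
The new logging block above derives each experiment's directory from path_base, the values listed in unique_string_keys, and a datetime stamp, then places log.txt, config.yaml, and the checkpoint inside it. The diff does not show how the unique string is assembled, so the following is only an illustrative sketch of the described behaviour (the function name and the handling of each key are assumptions):

import os
from datetime import datetime

def build_log_path(config):
    # Illustrative only: join the configured key values and a timestamp into
    # one unique directory name under logging.path_base.
    parts = []
    for key in config.logging.unique_string_keys:
        if key == "datetime":
            parts.append(datetime.now().strftime(config.logging.datetime_format))
            continue
        node = config
        for field in key.split("."):  # e.g. "plm.model_path"
            node = getattr(node, field)
        # keep only the last folder name, e.g. ../.cache/roberta-large/ -> roberta-large
        parts.append(os.path.basename(str(node).rstrip("/")))
    return os.path.join(config.logging.path_base, "_".join(parts))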
7 changes: 7 additions & 0 deletions openprompt/data_utils/__init__.py
@@ -62,6 +62,13 @@ def load_dataset(config: CfgNode, return_class=True):
except FileNotFoundError:
logger.warning("Has no test dataset.")
test_dataset = None
# check whether the dataset has been downloaded.
if (train_dataset is None) and \
(valid_dataset is None) and \
(test_dataset is None):
logger.error("Dataset is empty. Either there is no download or the path is wrong. "+ \
"If not downloaded, please `cd datasets/` and `bash download_xxx.sh`")
exit()
if return_class:
return train_dataset, valid_dataset, test_dataset, processor
else:
41 changes: 40 additions & 1 deletion openprompt/pipeline_base.py
@@ -270,6 +270,29 @@ def tokenizer(self):
r'''Utility property, to get the tokenizer more easily.
'''
return self.verbalizer.tokenizer

def state_dict(self):
r""" Save the model using template and verbalizer's save methods.
Args:
path (:obj:`str`): the full path of the checkpoint.
save_plm (:obj:`bool`): whether saving the pretrained language model.
kwargs: other information, such as the achieved metric value.
"""
_state_dict = {}
_state_dict['plm'] = self.model.state_dict()
_state_dict['template'] = self.template.state_dict()
_state_dict['verbalizer'] = self.verbalizer.state_dict()
return _state_dict

def load_state_dict(self, state_dict):
if 'plm' in state_dict:
self.model.load_state_dict(state_dict['plm'])
self.template.load_state_dict(state_dict['template'])
self.verbalizer.load_state_dict(state_dict['verbalizer'])






class PromptForGeneration(nn.Module, GenerationMixin):
@@ -475,5 +498,21 @@ def _prepare_encoder_decoder_kwargs_for_generation(
model_inputs = self.prompt_model.prepare_model_inputs(batch)
model_kwargs["encoder_outputs"] = encoder(return_dict=True, **model_inputs)
return model_kwargs


def state_dict(self):
r""" Save the model using template and verbalizer's save methods.
Args:
path (:obj:`str`): the full path of the checkpoint.
save_plm (:obj:`bool`): whether saving the pretrained language model.
kwargs: other information, such as the achieved metric value.
"""
_state_dict = {}
_state_dict['plm'] = self.model.state_dict()
_state_dict['template'] = self.template.state_dict()
return _state_dict

def load_state_dict(self, state_dict):
if 'plm' in state_dict:
self.model.load_state_dict(state_dict['plm'])
self.template.load_state_dict(state_dict['template'])
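
The two state_dict/load_state_dict pairs above define exactly what the new checkpointing serializes: plm, template, and verbalizer for classification, and plm plus template for generation. A hedged sketch of how they could be wired to the checkpoint.save_lastest and checkpoint.save_best options (the helper names and the metric argument are assumptions for illustration, not code from this commit):

import os
import torch

def save_checkpoint(prompt_model, save_dir, epoch, metric, is_best=False):
    state = {
        "state_dict": prompt_model.state_dict(),  # plm + template (+ verbalizer)
        "epoch": epoch,
        "metric": metric,
    }
    torch.save(state, os.path.join(save_dir, "checkpoint.pt"))  # latest
    if is_best:
        torch.save(state, os.path.join(save_dir, "best.pt"))    # best so far

def resume_from_checkpoint(prompt_model, save_dir):
    state = torch.load(os.path.join(save_dir, "checkpoint.pt"), map_location="cpu")
    prompt_model.load_state_dict(state["state_dict"])
    return state["epoch"], state["metric"]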
