From 142ad5a62ea3735a9086bc4ff5a7b0c4fa07b03c Mon Sep 17 00:00:00 2001 From: ShengdingHu Date: Mon, 4 Oct 2021 11:50:54 +0000 Subject: [PATCH] add checkpointing & modify logging dir --- experiments/cli.py | 17 +- experiments/generation_manual_template.yaml | 7 +- experiments/generation_prefixtuning.yaml | 11 +- openprompt/config.py | 7 +- openprompt/config_default.yaml | 84 +++++-- openprompt/data_utils/__init__.py | 7 + openprompt/pipeline_base.py | 41 +++- openprompt/trainer.py | 229 ++++++++++++++------ openprompt/utils/logging.py | 76 ++++++- openprompt/utils/metrics.py | 51 +++-- openprompt/utils/utils.py | 33 ++- 11 files changed, 419 insertions(+), 144 deletions(-) diff --git a/experiments/cli.py b/experiments/cli.py index ae7740c0..6c0f42c7 100644 --- a/experiments/cli.py +++ b/experiments/cli.py @@ -1,3 +1,4 @@ +import os import sys sys.path.append(".") @@ -29,6 +30,8 @@ + + def get_config(): parser = argparse.ArgumentParser("classification config") parser.add_argument('--config_yaml', type=str, help='the configuration file for this experiment.') @@ -47,15 +50,26 @@ def build_dataloader(dataset, template, tokenizer, config, split): shuffle=config[split].shuffle_data, teacher_forcing=config[split].teacher_forcing \ if hasattr(config[split],'teacher_forcing') else None, + predict_eos_token=True if config.task=="generation" else False, **config.dataloader ) return dataloader +def save_config_to_yaml(config): + from contextlib import redirect_stdout + saved_yaml_path = os.path.join(config.logging.path, "config.yaml") + + with open(saved_yaml_path, 'w') as f: + with redirect_stdout(f): print(config.dump()) + logger.info("Config saved as {}".format(saved_yaml_path)) def main(): - init_logger(log_file_level=logging.DEBUG, log_level=logging.INFO) config = get_config() + # init logger, create log dir and set log level, etc. + init_logger(config=config) + # save config to the logger directory + save_config_to_yaml(config) # set seed set_seed(config) # load the pretrained models, its model, tokenizer, and config. @@ -63,7 +77,6 @@ def main(): # load dataset. The valid_dataset can be None train_dataset, valid_dataset, test_dataset, Processor = load_dataset(config) - if config.task == "classification": # define prompt template = load_template(config=config, model=plm_model, tokenizer=plm_tokenizer, plm_config=plm_config) diff --git a/experiments/generation_manual_template.yaml b/experiments/generation_manual_template.yaml index 19531c68..fcd73c35 100644 --- a/experiments/generation_manual_template.yaml +++ b/experiments/generation_manual_template.yaml @@ -8,11 +8,11 @@ train: num_epochs: 5 batch_size: 2 teacher_forcing: True + gradient_accumulation_step: 2 generation: # Adding any arguments for generation here. 
parent_config: task max_length: 512 - result_path: ./outputs/generation/webnlg/ plm: model_name: gpt2 @@ -21,7 +21,7 @@ plm: freeze_para: False ## LEARINING SETTING #################################################### -learning_setting: full # selecting from "full", "zero-shot", "few-shot" +learning_setting: full # selecting from "full", "zero_shot", "few_shot" # few_shot: # parent_config: learning_setting @@ -34,9 +34,6 @@ learning_setting: full # selecting from "full", "zero-shot", "few-shot" # num_examples_per_label_dev: 100 # seed: 123 -dataloader: - predict_eos_token: True - template: manual_template verbalizer: diff --git a/experiments/generation_prefixtuning.yaml b/experiments/generation_prefixtuning.yaml index abda6fa5..ecce2f64 100644 --- a/experiments/generation_prefixtuning.yaml +++ b/experiments/generation_prefixtuning.yaml @@ -6,23 +6,20 @@ task: generation train: num_epochs: 5 - batch_size: 2 + batch_size: 3 teacher_forcing: True + gradient_accumulation_step: 2 generation: # Adding any arguments for generation here. parent_config: task max_length: 512 - result_path: ./outputs/generation/webnlg/ plm: model_name: gpt2 model_path: gpt2-medium optimize: - freeze_para: False - -dataloader: - predict_eos_token: True # this is necessary for generation. + freeze_para: True ## LEARINING SETTING #################################################### learning_setting: full # selecting from "full", "zero-shot", "few-shot" @@ -52,7 +49,7 @@ prefix_tuning_template: : text_b prefix_dropout: 0.0 optimize: - lr: 0.0004 + lr: 0.00005 diff --git a/openprompt/config.py b/openprompt/config.py index aeb919db..daed52db 100644 --- a/openprompt/config.py +++ b/openprompt/config.py @@ -33,7 +33,7 @@ def get_conditional_config(config): deeper_config[key] = config[key] config.pop(key) - # depth search over all config nodes + # breadth first search over all config nodes queue = [config] while len(queue) > 0: @@ -45,13 +45,16 @@ def get_conditional_config(config): leaf[1] in deeper_config.keys(): retrieved = deeper_config[leaf[1]] setattr(config, leaf[1], retrieved) - if isinstance(retrieved, CfgNode): + if isinstance(retrieved, CfgNode): + # also BFS the newly added CfgNode. queue.append(retrieved) elif isinstance(leaf[1], CfgNode): queue.append(leaf[1]) return config + + _VALID_TYPES = {tuple, list, str, int, float, bool, type(None)} diff --git a/openprompt/config_default.yaml b/openprompt/config_default.yaml index 73a8ef7b..63d5c4e2 100644 --- a/openprompt/config_default.yaml +++ b/openprompt/config_default.yaml @@ -16,7 +16,8 @@ reproduce: # seed for reproduction cuda_seed: -1 # seed for cuda plm: # plm parameters - model_name: + model_name: # the model name, e.g. bert, roberta, gpt2, ... + # for all the available model, please check the ./plms directory. model_path: optimize: freeze_para: False @@ -29,51 +30,94 @@ plm: # plm parameters type: # by default, it will choose get_linear_schedule_with_warmup num_warmup_steps: 500 +## LOGGIN and CHECKPOINTING ################################################## +## in logging, each experiment will create a seperate folder for saving log.txt +## , (full) config.json, and the checkpoint (if use the same path). +## logging is in the following format: +## ./log +## - DIR_NAME_1 +## - log.txt +## - config.yaml +## - checkpoint.pt +## - ... +## - DIR_NAME_2 +## - ... +## +logging: + path_base: ./logs # the path base of all the logs. 
+  file_level: NOTSET # make sure it's a valid level of the logging package
+  console_level: INFO # make sure it's a valid level of the logging package
+  unique_string:  # the generated (or user-defined) unique string for one experiment.
+  unique_string_keys: # the keys used to generate the unique string for saving
+    - dataset.name
+    - plm.model_path # only the last folder name is kept in the code,
+                     # i.e., ../.cache/roberta-large/ will be saved as roberta-large
+    - template
+    - verbalizer
+    - datetime # a 12-digit string recording the datetime of running the experiment, i.e., YYMMDDHHMMSS.
+  datetime_format: "%y%m%d%H%M%S" # only useful when unique_string_keys includes `datetime`.
+                    # make sure it's a valid format for the datetime package.
+  path: # leave empty to let the config generate the full path from
+        # path_base and unique_string.
+  overwrite: True # if the same log path already exists, overwrite it.
+
+checkpoint: # checkpoints use the same directory as logging.
+  save_lastest: False # Normally set to False to save disk space; set
+                      # to True to allow resuming the training process.
+  save_best: False # keep a copy of the best-performing checkpoint.
+  higher_better: True # is a higher value of the main metric better (used to pick the best checkpoint)?
+
+
 ## PIPELINE #######################################################
 train:
-  num_epochs: 5
-  batch_size: 2
-  shuffle_data: True
-  teacher_forcing: False
-  
+  num_epochs: 5 # the number of training epochs.
+  batch_size: 2 # the training batch size.
+  shuffle_data: True # whether to shuffle the training data.
+  teacher_forcing: False # whether to perform teacher forcing in training:
+                         # if True, the gold answer for each mask is
+                         # filled into the mask position.
+  gradient_accumulation_step: 1 # update the weights every N training steps;
+                                # set to 1 to disable gradient accumulation.
+
 dev:
-  batch_size: 2
-  shuffle_data: False
+  batch_size: 2 # evaluation batch size; can be a bit larger than the training batch size
+  shuffle_data: False # whether to shuffle the data in evaluation
 
 test:
-  batch_size: 2
-  shuffle_data: False
+  batch_size: 2 # evaluation batch size; can be a bit larger than the training batch size
+  shuffle_data: False # whether to shuffle the data in evaluation
 
 ## TASK ###########################################################
 task: classification
 
-
 classification:
   parent_config: task
-  metric:
-    - micro-f1
+  metric: # the first metric is the main one used to pick the best checkpoint.
+    - micro-f1
   loss_function: cross_entropy ## the loss function for classification
 
 generation: # Adding any arguments for generation here.
   parent_config: task
   max_length: 512 # the max_length of the generated sentence. INCLUDING the input_ids. So: generation.max_length > dataloader.max_seq_length
-  result_path: # the path to save the generated sentences.
-  
+  metric:
+    - sentence_bleu
 
 relation_classification:
   parent_config: task
 
 ## DATASET #########################################################
 dataset:
-  name:
-  path:
+  name: # the name of the dataset; for the supported choices,
+        # please see the processors in ./data_utils/
+  path: # where the dataset is saved on your local machine.
 
 
 ## DATALOADER ######################################################
 dataloader:
-  max_seq_length: 256
-  decoder_max_length: 256
-  predict_eos_token: False # necessary to set to true in generation.
+  max_seq_length: 256 # the maximum length of the input sequence
+  decoder_max_length: 256 # the max length used to truncate the decoder input sequence
+                          # if it is an encoder-decoder architecture.
 Note that it's not equivalent
+                          # to generation.max_length which is used merely in the generation phase.
   truncate_method: "head" # choosing from balanced, head, tail
 
 ## LEARNING SETTING ####################################################
diff --git a/openprompt/data_utils/__init__.py b/openprompt/data_utils/__init__.py
index 8a1476bc..ca5017d9 100644
--- a/openprompt/data_utils/__init__.py
+++ b/openprompt/data_utils/__init__.py
@@ -62,6 +62,13 @@ def load_dataset(config: CfgNode, return_class=True):
     except FileNotFoundError:
         logger.warning("Has no test dataset.")
         test_dataset = None
+    # check whether the dataset has been downloaded.
+    if (train_dataset is None) and \
+       (valid_dataset is None) and \
+       (test_dataset is None):
+        logger.error("Dataset is empty. Either it has not been downloaded or the path is wrong. "+ \
+                     "If not downloaded, please `cd datasets/` and `bash download_xxx.sh`")
+        exit()
     if return_class:
         return train_dataset, valid_dataset, test_dataset, processor
     else:
diff --git a/openprompt/pipeline_base.py b/openprompt/pipeline_base.py
index 5a3f5fec..35dbc4af 100644
--- a/openprompt/pipeline_base.py
+++ b/openprompt/pipeline_base.py
@@ -270,6 +270,29 @@ def tokenizer(self):
         r'''Utility property, to get the tokenizer more easily.
         '''
         return self.verbalizer.tokenizer
+
+    def state_dict(self):
+        r""" Return a state dict that gathers the states of the plm, the template,
+        and the verbalizer, so that they can be checkpointed together.
+        """
+        _state_dict = {}
+        _state_dict['plm'] = self.model.state_dict()
+        _state_dict['template'] = self.template.state_dict()
+        _state_dict['verbalizer'] = self.verbalizer.state_dict()
+        return _state_dict
+
+    def load_state_dict(self, state_dict):
+        # the plm state may be absent from the checkpoint when the plm is frozen.
+        if 'plm' in state_dict:
+            self.model.load_state_dict(state_dict['plm'])
+        self.template.load_state_dict(state_dict['template'])
+        self.verbalizer.load_state_dict(state_dict['verbalizer'])
+
+
+
 
 class PromptForGeneration(nn.Module, GenerationMixin):
@@ -475,5 +498,21 @@ def _prepare_encoder_decoder_kwargs_for_generation(
         model_inputs = self.prompt_model.prepare_model_inputs(batch)
         model_kwargs["encoder_outputs"] = encoder(return_dict=True, **model_inputs)
         return model_kwargs
-    
+
+    def state_dict(self):
+        r""" Return a state dict that gathers the states of the plm and the template
+        (the generation pipeline has no verbalizer).
+        """
+        _state_dict = {}
+        _state_dict['plm'] = self.model.state_dict()
+        _state_dict['template'] = self.template.state_dict()
+        return _state_dict
+
+    def load_state_dict(self, state_dict):
+        if 'plm' in state_dict:
+            self.model.load_state_dict(state_dict['plm'])
+        self.template.load_state_dict(state_dict['template'])
diff --git a/openprompt/trainer.py b/openprompt/trainer.py
index d7d3539f..d76eaadf 100644
--- a/openprompt/trainer.py
+++ b/openprompt/trainer.py
@@ -1,11 +1,12 @@
-
 import os
 import sys
+sys.path.append(".")
 
 from torch.utils.data import dataloader
-sys.path.append(".")
-from typing import Callable, Union
+
+from openprompt.utils.utils import load_checkpoint, save_checkpoint
+from typing import Callable, OrderedDict, Union
 from torch.nn.parallel.data_parallel import DataParallel
 from openprompt.pipeline_base import PromptForClassification, PromptForGeneration
 from tqdm import tqdm
@@ -19,8 +20,8 @@
 
 
-class ClassificationRunner(object):
-    r"""A runner for simple training without training tricks.
+class BaseRunner(object):
+    r"""A base runner for training without training tricks.
     Applying training tricks such as ensemble of template or verbalizer, or
     self-training can use other runner class.
     This class is specially implemented for classification.
@@ -41,7 +42,6 @@ def __init__(self,
                  valid_dataloader: Optional[PromptDataLoader] = None,
                  test_dataloader: Optional[PromptDataLoader] = None,
                  config: CfgNode = None,
-                 loss_function: Optional[Callable] = None,
                  ):
         self.prompt_model = prompt_model
         self.inner_model = prompt_model.module if isinstance(prompt_model, DataParallel) else prompt_model
@@ -49,11 +49,94 @@ def __init__(self,
         self.valid_dataloader = valid_dataloader
         self.test_dataloader = test_dataloader
         self.config = config
+        self.config_optimize()
+
+    def config_loss_function(self,):
+        raise NotImplementedError
+
+    def config_optimize(self,):
+        raise NotImplementedError
+
+    def evaluate(self, dataloader, split, post_evaluate_hook=None):
+        raise NotImplementedError
+
+    def train_epoch(self, epoch):
+        raise NotImplementedError
+
+    def prompt_initialize(self):
+        r"""Optional initialization work (e.g., for the template or verbalizer);
+        subclasses may override.
+        """
+        pass
+
+    def run(self):
+        self.prompt_initialize()
+        max_score = None
+        for epoch in range(self.config.train.num_epochs):
+            total_loss = self.train_epoch(epoch)
+            scores = self.evaluate(self.valid_dataloader, "Valid")
+            model_state_dict = self.inner_model.state_dict()
+            if self.config.plm.optimize.freeze_para:
+                model_state_dict.pop('plm')
+            state_dict = {
+                "epoch": epoch+1,
+                "state_dict": model_state_dict, # the plm weights are dropped above when the plm is frozen.
+                "optimizer": [opt.state_dict() if isinstance(opt, torch.optim.Optimizer) else None for opt in self.optimizers],
+                "scheduler": self.schedulers,
+                "scores": scores
+            }
+            cur_score = next(iter(scores.values())) # the first metric is the main one; keep `scores` intact for saving.
+
+            is_best = (((cur_score - max_score) >= 0) == self.config.checkpoint.higher_better) \
+                if max_score is not None else True
+            if is_best:
+                max_score = cur_score
+            save_checkpoint(state_dict=state_dict,
+                            is_best=(is_best and self.config.checkpoint.save_best),
+                            save_path=self.config.logging.path)
+        state_dict = load_checkpoint(load_path=self.config.logging.path,
+                                     load_best=self.config.checkpoint.save_best,
+                                     map_location="cpu", # cpu to prevent CUDA out of memory.
+                                     )
+        self.inner_model.load_state_dict(state_dict['state_dict'])
+        self.inner_model.to("cuda:{}".format(self.config.environment.local_rank))
+        self.evaluate(self.test_dataloader, "Test")
+
+
+
+class ClassificationRunner(BaseRunner):
+    r"""A runner for simple training without training tricks.
+ Applying training tricks such as ensemble of template or verbalizer, + or self-training can use other runner class. + This class is specially implemented for classification. + For generation task, though it can be integrated in this class + via `task` option, we keep it as another class for simplicity. + + Args: + prompt_model (:obj:`Union[DataParallel, PromptForClassification]`): One ``PromptModel`` object. + train_dataloader (:obj:`PromptDataloader`, optional): The dataloader to bachify and process the training data. + valid_dataloader (:obj:`PromptDataloader`, optionla): The dataloader to bachify and process the val data. + test_dataloader (:obj:`PromptDataloader`, optional): The dataloader to bachify and process the test data. + config (:obj:`CfgNode`): A configuration object. + loss_function (:obj:`Callable`, optional): The loss function in the training process. + """ + def __init__(self, + prompt_model: Union[DataParallel, PromptForClassification], + train_dataloader: Optional[PromptDataLoader] = None, + valid_dataloader: Optional[PromptDataLoader] = None, + test_dataloader: Optional[PromptDataLoader] = None, + config: CfgNode = None, + loss_function: Optional[Callable] = None, + ): + super().__init__(prompt_model=prompt_model, + train_dataloader=train_dataloader, + valid_dataloader=valid_dataloader, + test_dataloader=test_dataloader, + config=config) + if loss_function is None: self.config_loss_function() else: self.loss_function = loss_function - self.config_optimize() def config_loss_function(self,): r"""config the loss function if it's not passed. @@ -157,35 +240,44 @@ def evaluate(self, dataloader, split, post_evaluate_hook=None): preds.extend(pred.cpu().tolist()) labels.extend(batch['label'].cpu().tolist()) self.prompt_model.train() - scores = {} + scores = OrderedDict() scores_str = "" for metric in self.config.classification.metric: score = classification_metrics(preds, labels, metric) scores[metric] = score scores_str += "{}: {}\n".format(metric, score) - logger.info("{} Performance: {}".format(split, scores_str)) + logger.info("{} Performance: {}".format(split, scores_str.strip())) + return scores def train_epoch(self, epoch): self.prompt_model.train() self.prompt_model.zero_grad() + accumulation_steps = self.config.train.gradient_accumulation_step total_loss = 0.0 - for batch in tqdm(self.train_dataloader, desc="Train"): - for optimizer in self.optimizers: - if optimizer is not None: - optimizer.zero_grad() + pbar = tqdm(self.train_dataloader, desc="Train epoch {}".format(epoch)) + for i, batch in enumerate(pbar): batch = batch.to("cuda:{}".format(self.config.environment.local_rank)).to_dict() logits = self.prompt_model(batch) loss = self.loss_function(logits, batch['label']) - total_loss += loss.item() + loss = loss / accumulation_steps loss.backward() - for optimizer in self.optimizers: - if optimizer is not None: - optimizer.step() - - for scheduler in self.schedulers: - if scheduler is not None: - scheduler.step() + if (i+1) % accumulation_steps == 0: + # do optimizer step + for optimizer in self.optimizers: + if optimizer is not None: + optimizer.step() + for scheduler in self.schedulers: + if scheduler is not None: + scheduler.step() + # zero_grad + self.prompt_model.zero_grad() + for optimizer in self.optimizers: + if optimizer is not None: + optimizer.zero_grad() + total_loss += loss.item() + pbar.set_postfix(loss = loss.item()) logger.info("Epoch {}, loss: {:.4f}".format(epoch, total_loss)) + return total_loss def prompt_initialize(self): verbalizer_config = 
self.config[self.config.verbalizer] @@ -216,15 +308,8 @@ def prompt_initialize(self): if hasattr(self.inner_model.template, "optimize_to_initialize" ): self.inner_model.template.optimize_to_initialize() - def run(self): - self.prompt_initialize() - for epoch in range(self.config.train.num_epochs): - self.train_epoch(epoch) - self.evaluate(self.valid_dataloader, "Valid") - self.evaluate(self.test_dataloader, "Test") - -class GenerationRunner(object): +class GenerationRunner(BaseRunner): r"""A runner for simple training without training tricks. Applying training tricks such as ensemble of template or verbalizer, or self-training can use other runner class. @@ -244,21 +329,16 @@ def __init__(self, test_dataloader: Optional[PromptDataLoader] = None, config: CfgNode = None, ): - self.prompt_model = prompt_model - self.inner_model = prompt_model.module if isinstance(prompt_model, DataParallel) else prompt_model - self.train_dataloader = train_dataloader - self.valid_dataloader = valid_dataloader - self.test_dataloader = test_dataloader - self.config = config - self.config_optimize() + super().__init__(prompt_model=prompt_model, + train_dataloader=train_dataloader, + valid_dataloader=valid_dataloader, + test_dataloader=test_dataloader, + config=config) def config_loss_function(self,): - r"""config the loss function if it's not passed. + r""" No need to config loss_function in generation. """ - if self.config.classification.loss_function == "cross_entropy": - self.loss_function = torch.nn.CrossEntropyLoss() - else: - raise NotImplementedError + pass def config_optimize(self,): r"""config the optimizer and scheduler for 1. model 2. template 3. verbalizer @@ -316,51 +396,56 @@ class Dummy: self.schedulers = [self.model_scheduler, self.template_scheduler] def evaluate(self, dataloader, split, post_evaluate_hook=None): - if not os.path.exists(self.config.generation.result_path): - raise FileNotFoundError("Can't find {}".format(self.config.generation.result_path)) - - # TODO: allow more flexible file name - ret_file_name= os.path.join(self.config.generation.result_path,"{}_{}.txt".format(self.config.template, split)) - fout = open(ret_file_name,'w') + ret_file_name= os.path.join(self.config.logging.path,"{}_generated_text.txt".format(split)) + tgt_texts = [] generated_sentences_all = [] - logger.info("Begin generation, result written at {}".format(ret_file_name)) for batch in tqdm(dataloader, desc=split): batch = batch.to("cuda:{}".format(self.config.environment.local_rank)).to_dict() output_sequences, generated_sentences = self.inner_model.generate(batch) tgt_texts.extend(batch['tgt_text']) generated_sentences_all.extend(generated_sentences) - for i in range(len(batch['tgt_text'])): - fout.write("[Gold]:"+batch['tgt_text'][i]+"\n") - fout.write("[Gen]: "+generated_sentences[i]+"\n\n") + + fout = open(ret_file_name,'w') + for i in range(len(tgt_texts)): + fout.write("[Gold]:"+ tgt_texts[i]+"\n") + fout.write("[Gen]: "+ generated_sentences_all[i]+"\n\n") fout.close() - score = generation_metric(tgt_texts, generated_sentences_all) - logger.info("Evaluate Bleu score: {:.3f}.".format(score*100)) + + scores = OrderedDict() + scores_str = "" + for metric in self.config.generation.metric: + score = generation_metric(tgt_texts, generated_sentences_all, metric) + scores[metric] = score + scores_str += "{}: {}\n".format(metric, score) + logger.info("{} Performance: {}".format(split, scores_str.strip())) + return scores def train_epoch(self, epoch): self.prompt_model.train() self.prompt_model.zero_grad() + 
accumulation_steps = self.config.train.gradient_accumulation_step total_loss = 0.0 - for batch in tqdm(self.train_dataloader, desc="Train"): - for optimizer in self.optimizers: - if optimizer is not None: - optimizer.zero_grad() + pbar = tqdm(self.train_dataloader, desc="Train epoch {}".format(epoch)) + for i, batch in enumerate(pbar): batch = batch.to("cuda:{}".format(self.config.environment.local_rank)).to_dict() loss = self.prompt_model(batch).sum() #TODO: parallel doesn't aggregate the result for some reason. to fix. - total_loss += loss.item() + loss = loss / accumulation_steps loss.backward() - for optimizer in self.optimizers: - if optimizer is not None: - optimizer.step() - - for scheduler in self.schedulers: - if scheduler is not None: - scheduler.step() + if (i+1) % accumulation_steps == 0: + # do optimizer step + for optimizer in self.optimizers: + if optimizer is not None: + optimizer.step() + for scheduler in self.schedulers: + if scheduler is not None: + scheduler.step() + # zero_grad + self.prompt_model.zero_grad() + for optimizer in self.optimizers: + if optimizer is not None: + optimizer.zero_grad() + total_loss += loss.item() + pbar.set_postfix(loss = loss.item()) logger.info("Epoch {}, loss: {:.4f}".format(epoch, total_loss)) - - def run(self): - # currently no methods support automatic template initialization for generation - for epoch in range(self.config.train.num_epochs): - self.train_epoch(epoch) - self.evaluate(self.valid_dataloader, "Valid") - self.evaluate(self.test_dataloader, "Test") + return total_loss \ No newline at end of file diff --git a/openprompt/utils/logging.py b/openprompt/utils/logging.py index 518b9923..99753441 100644 --- a/openprompt/utils/logging.py +++ b/openprompt/utils/logging.py @@ -1,15 +1,80 @@ # -*- coding: utf-8 -*- import logging -from logging.handlers import RotatingFileHandler +import os +import datetime + logger = logging.getLogger() +def config_logger(config): + r""" Automatic generate log directory for experiments. + First generate the unique_string of one experiment, if the user + didn't specify one, according + to the user-defined keys logging.unique_string_keys. + Then create the directory. + """ + if not os.path.exists(config.logging.path_base): + raise NotADirectoryError("No logging base directory") + + # generate unique string + temp_strs = [] + if config.logging.unique_string is None: + for item in config.logging.unique_string_keys: + if item == "datetime": + continue + item = item.split(".") # split into sub items. 
+            subconfig = config
+            for key in item:
+                try:
+                    subconfig = subconfig[key]
+                except:
+                    raise ValueError("The unique_string_key '{}' is not a config option".format(".".join(item)))
+            if not isinstance(subconfig, str):
+                try:
+                    subconfig = str(subconfig)
+                except:
+                    print("The value of the config key {} can't be converted to a string".format(".".join(item)))
+                    continue
+
+            subconfig = subconfig.split("/")[-1] # only keep the last part of a path-like value
+            temp_strs.append(subconfig)
+
+    if 'datetime' in config.logging.unique_string_keys:
+        if config.logging.datetime_format is None:
+            config.logging.datetime_format = '%y%m%d%H%M%S'
+        time_str = datetime.datetime.now().strftime(config.logging.datetime_format)
+        temp_strs.append(time_str)
+    config.logging.unique_string = "_".join(temp_strs)
+    config.logging.path = os.path.join(config.logging.path_base, config.logging.unique_string)
+
+    # create the log directory
+    if os.path.exists(config.logging.path):
+        if config.logging.overwrite:
+            import shutil
+            shutil.rmtree(config.logging.path)
+            os.mkdir(config.logging.path)
+        else:
+            raise FileExistsError("Log dir {} exists and overwriting is disabled!".format(config.logging.path))
+    else:
+        os.mkdir(config.logging.path)
+
+    kwargs = {}
+    kwargs['log_file'] = os.path.join(config.logging.path, "log.txt")
+    kwargs['log_file_level'] = getattr(logging, config.logging.file_level)
+    kwargs['log_level'] = getattr(logging, config.logging.console_level)
+    return kwargs
+
 def init_logger(
+    config=None,
     log_file=None,
     log_file_level=logging.NOTSET,
-    rotate=False,
     log_level=logging.INFO,
 ):
+    if config is not None:
+        # derive the log file and levels from the config, then configure recursively.
+        kwargs = config_logger(config)
+        logger = init_logger(**kwargs)
+        return logger
+
     log_format = logging.Formatter("[\033[032m%(asctime)s\033[0m %(levelname)s] %(module)s.%(funcName)s %(message)s")
     logger = logging.getLogger()
     logger.setLevel(log_level)
@@ -19,13 +84,8 @@ def init_logger(
         logger.handlers = [console_handler]
 
     if log_file and log_file != '':
-        if rotate:
-            file_handler = RotatingFileHandler(
-                log_file, maxBytes=1000000, backupCount=10)
-        else:
-            file_handler = logging.FileHandler(log_file)
+        file_handler = logging.FileHandler(log_file)
         file_handler.setLevel(log_file_level)
         file_handler.setFormatter(log_format)
         logger.addHandler(file_handler)
-
     return logger
\ No newline at end of file
diff --git a/openprompt/utils/metrics.py b/openprompt/utils/metrics.py
index 796c309c..5b201ca3 100644
--- a/openprompt/utils/metrics.py
+++ b/openprompt/utils/metrics.py
@@ -4,55 +4,60 @@
 
 def classification_metrics(preds: Sequence[int],
                            labels: Sequence[int],
-                           metric_type: Optional[str] = "micro-f1",
+                           metric: Optional[str] = "micro-f1",
                           ) -> float:
     """evaluation metrics for classification task.
 
     Args:
         preds (Sequence[int]): predicted label ids for each example
         labels (Sequence[int]): gold label ids for each example
-        type (str, optional): type of evaluation function, support 'micro-f1', 'macro-f1', 'accuracy', 'precision', 'recall'. Defaults to "micro-f1".
+        metric (str, optional): the evaluation function; supports 'micro-f1', 'macro-f1', 'accuracy', 'precision', 'recall'. Defaults to "micro-f1".

    
Returns: score (float): evaluation score """ - if metric_type == "micro-f1": + if metric == "micro-f1": score = f1_score(labels, preds, average='micro') - elif metric_type == "macro-f1": + elif metric == "macro-f1": score = f1_score(labels, preds, average='macro') - elif metric_type == "accuracy": + elif metric == "accuracy": score = accuracy_score(labels, preds) - elif metric_type == "precision": + elif metric == "precision": score = precision_score(labels, preds) - elif metric_type == "recall": + elif metric == "recall": score = recall_score(labels, preds) else: - raise ValueError("'{}' is not a valid evaluation type".format(metric_type)) + raise ValueError("'{}' is not a valid evaluation type".format(metric)) return score -def generation_metric(hypos, refs, metric_type: Optional[str] = "bleu"): +def generation_metric(hypos, + refs, + metric: Optional[str] = "sentence_bleu"): r"""Some basic metric function for generation. However, many generation tasks has their own evaluation bash scripts. Args: hypos (:obj:`str`) : the generated sentence. refs (:obj:`str`) : the referenced (ground-truth) sentence. - metric_type (:obj:`str`, `optional`) : the type of metric option + metric (:obj:`str`, `optional`) : the type of metric option Returns: score (float): evaluate score """ - # a simple criterion to visualize the performance, not rigorous. - from nltk.translate.bleu_score import sentence_bleu - from nltk.tokenize import word_tokenize - from nltk.translate.bleu_score import SmoothingFunction - smoothie = SmoothingFunction().method4 # a function for smooth - scores = [] - - for ref, hypo in zip(refs, hypos): - ref = word_tokenize(ref) - hypo = word_tokenize(hypo) - scores.append(sentence_bleu([ref], hypo, smoothing_function=smoothie)) - score = sum(scores)/len(scores) - return score \ No newline at end of file + if metric == "sentence_bleu": + # a simple criterion to visualize the performance, not rigorous. + from nltk.translate.bleu_score import sentence_bleu + from nltk.tokenize import word_tokenize + from nltk.translate.bleu_score import SmoothingFunction + smoothie = SmoothingFunction().method4 # a function for smooth + scores = [] + + for ref, hypo in zip(refs, hypos): + ref = word_tokenize(ref) + hypo = word_tokenize(hypo) + scores.append(sentence_bleu([ref], hypo, smoothing_function=smoothie)) + score = sum(scores)/len(scores) + return score + else: + raise ValueError("'{}' is not a valid metric type.".format(metric)) \ No newline at end of file diff --git a/openprompt/utils/utils.py b/openprompt/utils/utils.py index 428678bd..c97515bc 100644 --- a/openprompt/utils/utils.py +++ b/openprompt/utils/utils.py @@ -1,8 +1,13 @@ from math import ceil +import os +import shutil from typing import List import inspect from collections import namedtuple + +import torch from yacs.config import CfgNode +import dill from openprompt.utils.logging import logger @@ -67,15 +72,35 @@ def check_config_conflicts(config: CfgNode): if config.task == "generation": assert config['train'].teacher_forcing == True, "You should use teacher forcing to train generation!" - if config.task == "generation": - if config.dataloader.predict_eos_token == False: - logger.warning("You are recommended to predict eos token in generation training!") - if config.task == "generation": if config.dataloader.max_seq_length >= config.generation.max_length: logger.warning("In generation, your config.generation.max_length is shorter than config.max_seq_length" "This can lead to unexpected behavior. 
You should consider increasing ``config.generation.max_length``." )
             raise RuntimeError
 
+
+def save_checkpoint(state_dict, is_best, save_path, filename='checkpoint.pt'):
+    r"""Save the checkpoint to :obj:`save_path`.
+    """
+    full_file_path = os.path.join(save_path, filename)
+    logger.info("Saving the latest checkpoint.")
+    torch.save(state_dict, full_file_path, pickle_module=dill)
+    if is_best:
+        full_best_path = os.path.join(save_path, 'best.'+filename)
+        logger.info("Saving the best checkpoint.")
+        shutil.copyfile(full_file_path, full_best_path)
+
+def load_checkpoint(load_path, load_best, filename="checkpoint.pt", map_location=None):
+    r"""Load the checkpoint from :obj:`load_path`.
+    """
+    if load_best:
+        full_file_path = os.path.join(load_path, "best."+filename)
+        logger.info("Loading the best checkpoint.")
+    else:
+        full_file_path = os.path.join(load_path, filename)
+        logger.info("Loading the latest checkpoint.")
+    state_dict = torch.load(full_file_path, pickle_module=dill, map_location=map_location)
+    return state_dict
+
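
Note (not part of the patch): below is a minimal sketch of how the checkpoint helpers introduced above are meant to be used together with `checkpoint.save_best` and `logging.path`. The helper signatures (`save_checkpoint(state_dict, is_best, save_path)`, `load_checkpoint(load_path, load_best, map_location)`) come from the diff; the log directory name, the toy model, and the score values are illustrative assumptions, not OpenPrompt APIs.

# Sketch of the checkpoint round-trip added by this patch: save every epoch,
# copy to best.checkpoint.pt on improvement, reload the best one afterwards.
# Assumes openprompt (with this patch), torch and dill are installed;
# `log_dir`, the toy model and the scores are placeholders.
import os
import torch
from openprompt.utils.utils import save_checkpoint, load_checkpoint

log_dir = "./logs/webnlg_gpt2-medium_example"   # normally config.logging.path
os.makedirs(log_dir, exist_ok=True)
model = torch.nn.Linear(4, 2)                   # stand-in for the prompt model
best_score, higher_better = None, True

for epoch in range(3):
    score = float(epoch)                        # stand-in for the validation metric
    is_best = best_score is None or ((score - best_score) >= 0) == higher_better
    if is_best:
        best_score = score
    state = {"epoch": epoch + 1,
             "state_dict": model.state_dict(),
             "scores": {"sentence_bleu": score}}
    # writes <log_dir>/checkpoint.pt and copies it to best.checkpoint.pt when is_best
    save_checkpoint(state_dict=state, is_best=is_best, save_path=log_dir)

# reload the best checkpoint on CPU, as BaseRunner.run() does before the test pass
state = load_checkpoint(load_path=log_dir, load_best=True, map_location="cpu")
model.load_state_dict(state["state_dict"])

Because the checkpoint lives in the same experiment folder as log.txt and config.yaml, removing one directory under logging.path_base cleans up everything that run produced.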