# In [1]:
import argparse
import configparser
import itertools
import json
import logging
import os
from collections import defaultdict
import torch
from torch.utils import data
from torch.utils.data import DataLoader

import transformers
from transformers import AutoConfig, AutoTokenizer, HfArgumentParser, AutoModelForSeq2SeqLM, EncoderDecoderModel, \
    BertConfig, EncoderDecoderConfig, Trainer
from transformers import T5ForConditionalGeneration, T5Config, default_data_collator

from arguments import ModelArguments, DataTrainingArguments, TrainingArguments
from datasets import load_dataset
from evaluate import evaluate, get_avg_results, print_results
from utils import get_episode_indices

# Ask the Hugging Face `tokenizers` library not to use its own parallelism
# (presumably to avoid its fork-related warning once data loading forks).
os.environ.update(TOKENIZERS_PARALLELISM="false")

# [captured cell output, not code]   from .autonotebook import tqdm as notebook_tqdm


# In [None]:
# Command-line interface. The positional ``job`` names a section of the config
# file; every argument not declared here is returned in ``remaining_args`` so it
# can be handed to the HfArgumentParser further down.
parser = argparse.ArgumentParser()
for _flags, _options in (
    (('job',), {}),
    (('-c', '--config_file'), dict(type=str, default='config.ini', help='configuration file')),
    (('-e', '--eval'), dict(action='store_true', default=False, help='run evaluation only')),
    (('--evaluate_checkpoints',),
     dict(action='store_true', default=False,
          help='evaluate intermediate checkpoints instead of the final model')),
    (('--evaluate_last_checkpoint',),
     dict(action='store_true', default=False,
          help='evaluate the last intermediate checkpoint instead of the final model')),
    (('--evaluate_checkpoint_in_dir',),
     dict(type=str, default=None,
          help='evaluate the checkpoint in the given directory')),
    (('-a', '--evaluate_all'),
     dict(action='store_true', default=False,
          help='evaluate intermediate checkpoints together with the final model')),
    (('-g', '--gpu'), dict(type=int, default=0, help='which GPU to use for evaluation')),
    (('-v', '--verbose_results'),
     dict(action='store_true', default=False,
          help='print results for each evaluation run')),
):
    # declaration order is preserved, so --help output is unchanged
    parser.add_argument(*_flags, **_options)
args, remaining_args = parser.parse_known_args()

# Read the configuration file; the section named by ``job`` supplies default
# values for the arguments parsed further down.
config = configparser.ConfigParser(allow_no_value=False)
# ConfigParser.read silently skips files it cannot open and returns the list of
# files it did read, so an empty result means the config file was not loaded.
if not config.read(args.config_file):
    parser.error(f'could not read configuration file {args.config_file!r}')
job = args.job
# Validate explicitly rather than with ``assert``, which is stripped under ``python -O``.
if job not in config:
    parser.error(f'job {job!r} is not a section of {args.config_file!r}')

# Built-in defaults for the dataclass arguments. The job's config section
# overrides them, and explicit command-line flags override both.
defaults = {
    'overwrite_output_dir': True,
    'overwrite_cache': True,
    'per_device_eval_batch_size': 4,
    'learning_rate': 5e-4,
    'logging_steps': 1,  # log every step ('do not log' corresponds to 0)
    'save_steps': 0,  # do not save checkpoints by default
}

# The config file gives default values for the command-line arguments.
# Values read from it (including its DEFAULT section) arrive as strings.
defaults.update(dict(config.items(job)))
for key, value in defaults.items():
    if not isinstance(value, str):
        continue  # built-in defaults above are already typed
    if value.lower() in ('true', 'false'):
        # Interpret true/false in any capitalisation as a boolean: an unconverted
        # 'false' would reach the dataclasses as a non-empty, truthy string.
        defaults[key] = value.lower() == 'true'
    elif value == 'None':
        # interpret as None
        defaults[key] = None

if args.eval:
    # run evaluation only
    defaults['do_train'] = False

# Parse the remaining command-line arguments into the three argument dataclasses.
# Values in ``defaults`` (built-ins, overridden by the config file) apply unless
# the command line overrides them.
second_parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
second_parser.set_defaults(**defaults)
model_args, data_args, training_args = second_parser.parse_args_into_dataclasses(remaining_args)

# Make sure the base output directory exists. makedirs also creates missing
# parents, where os.mkdir would raise FileNotFoundError; an existing directory is fine.
os.makedirs(training_args.output_dir, exist_ok=True)

# Fill in unset length / chunking arguments from their base counterparts.
if data_args.max_output_seq_length_eval is None:
    # Use the first truthy value among max_output_seq_length, max_seq_length_eval
    # and max_seq_length, else max_seq_length itself. This must run before the
    # loop below so it sees the values as given, not the filled-in fallbacks.
    data_args.max_output_seq_length_eval = next(
        (
            length
            for length in (
                data_args.max_output_seq_length,
                data_args.max_seq_length_eval,
                data_args.max_seq_length,
            )
            if length
        ),
        data_args.max_seq_length,
    )

# (setting that may be unset, setting it falls back to), applied in this order
for _unset_name, _fallback_name in (
    ('max_output_seq_length', 'max_seq_length'),
    ('max_seq_length_eval', 'max_seq_length'),
    ('chunk_size_eval', 'chunk_size'),
    ('chunk_overlap_eval', 'chunk_overlap'),
):
    if getattr(data_args, _unset_name) is None:
        setattr(data_args, _unset_name, getattr(data_args, _fallback_name))

# Construct a descriptive name for this run's output directory,
# for example: conll04-t5-base-ep200-len256-b4-train
# Optional components are appended only when they differ from the hard-coded
# reference values below, so typical runs get short names.
# rstrip('/') keeps a model path given with a trailing slash (e.g. 'models/t5-base/')
# from producing an empty model component.
_model_short_name = model_args.model_name_or_path.rstrip('/').split('/')[-1]
_run_name = f'{args.job}-{_model_short_name}'
if data_args.exp:
    _run_name += f'-{data_args.exp}'
_run_name += f'-ep{round(training_args.num_train_epochs)}-len{data_args.max_seq_length}'
output_dir = os.path.join(training_args.output_dir, _run_name)

if data_args.max_output_seq_length != data_args.max_seq_length:
    output_dir += f'-{data_args.max_output_seq_length}'

if training_args.learning_rate != 5e-4:
    output_dir += f'-lr{training_args.learning_rate}'

output_dir += f'-b{training_args.per_device_train_batch_size}-{data_args.train_split}'

if data_args.chunk_size != 128:
    output_dir += f'-chunk{data_args.chunk_size}'
if data_args.chunk_overlap != 64:
    output_dir += f'-overlap{data_args.chunk_overlap}'

if data_args.output_format is not None:
    output_dir += f'-{data_args.output_format}'
if data_args.input_format is not None:
    output_dir += f'-{data_args.input_format}'
if data_args.train_subset < 1:
    output_dir += f'-size{data_args.train_subset:.2f}'
if data_args.output_format_type is not None:
    output_dir += f'-{data_args.output_format_type}'
if data_args.comment is not None:
    output_dir += f'-{data_args.comment}'

# Create the run directory (and any missing parents); reusing an existing one is fine.
os.makedirs(output_dir, exist_ok=True)

# Set up logging: root-logger records go to <output_dir>/logs.log and are also
# echoed to the console (stderr, the StreamHandler default).
logging.basicConfig(
    filename=os.path.join(output_dir, 'logs.log'),
    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
)
_root_logger = logging.getLogger()
# Attach the console handler only once. Otherwise re-running this notebook cell
# adds another StreamHandler each time and every message is printed repeatedly.
# FileHandler subclasses StreamHandler, hence the exact type check.
if not any(type(handler) is logging.StreamHandler for handler in _root_logger.handlers):
    _root_logger.addHandler(logging.StreamHandler())

# Base name for the evaluation results, tagged with the beam count and the
# evaluation input length whenever those are set.
_result_tags = []
if data_args.num_beams is not None:
    _result_tags.append(f'-{data_args.num_beams}beams')
if data_args.max_seq_length_eval is not None:
    _result_tags.append(f'-len{data_args.max_seq_length_eval}')
evaluation_output_filename = 'results' + ''.join(_result_tags)

# Create the model config, from --config_name if given, otherwise from the model checkpoint.
# NOTE(review): this rebinds ``config``, which until here held the ConfigParser for
# the .ini file; any .ini values needed later must already have been read.
config = AutoConfig.from_pretrained(
    model_args.config_name or model_args.model_name_or_path,
    cache_dir=model_args.cache_dir,
)

# Create the tokenizer, from --tokenizer_name if given, otherwise from the model checkpoint.
# It uses the same cache directory as the config above, so downloads end up in one place.
tokenizer = AutoTokenizer.from_pretrained(
    model_args.tokenizer_name or model_args.model_name_or_path,
    cache_dir=model_args.cache_dir,
)

# Get the list of dataset names from the comma-separated --datasets value.
# Whitespace around each name is ignored (so "a, b" works), and empty entries
# from stray or trailing commas are dropped.
dataset_names = [name.strip() for name in data_args.datasets.split(',') if name.strip()]

# Construct the list of episode indices (parsing is delegated to utils.get_episode_indices).
episode_indices = get_episode_indices(data_args.episodes)