In [None]:
import os

# Restrict this notebook to GPU 0. The mask is set *before* calling jpt_setup() because
# CUDA reads CUDA_VISIBLE_DEVICES only once, when it is first initialized; if the helper
# happens to initialize CUDA, a mask set afterwards would be ignored.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from rosemary import jpt_setup
jpt_setup()

In [None]:

import logging
import math
from pathlib import Path
import os
import sys
from dataclasses import dataclass, field
from itertools import chain
from typing import Optional
import json
import numpy as np
import pickle
from functools import partial

import pyarrow
import datasets
import evaluate
import torch
from datasets import load_dataset, IterableDataset

import numpy as np

import transformers
from transformers import (
    CONFIG_MAPPING,
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    DataCollatorForLanguageModeling,
    is_torch_tpu_available,
    set_seed,
)
from transformers.testing_utils import CaptureLogger
from transformers.trainer_utils import get_last_checkpoint
from transformers.trainer_callback import TrainerState
from transformers.trainer import TRAINER_STATE_NAME
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
from transformers.models.gpt2.configuration_gpt2 import GPT2Config

from doremi.training_args import ModelArguments, DataTrainingArguments, FullTrainingArguments
import doremi.dataloader as data_utils
from doremi.trainer import DoReMiTrainer
from doremi.dataloader import determine_skip_per_domain
from doremi.dataloader import interleave_datasets


# Created first so that failures of the optional imports below can be reported.
logger = logging.getLogger(__name__)

# Optional model code: the DoReMi model classes and the flash-attn GPT-NeoX config converter
# are only needed for some `model_type`s, so a failed import is deliberately not fatal here.
# The failure is logged (instead of silently swallowed) so that a later NameError on
# `doremi_models` / `gpt_neox_config_to_gpt2_config` can be traced back to its cause.
try:
    import doremi.models as doremi_models
except Exception as exc:
    logger.warning(f"Optional import `doremi.models` failed: {exc!r}")
try:
    from flash_attn.models.gpt_neox import gpt_neox_config_to_gpt2_config
except Exception as exc:
    logger.warning(f"Optional import `flash_attn.models.gpt_neox` failed: {exc!r}")


# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.27.0")

require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

In [3]:
# Root of the DoReMi checkout; every other location below is derived from it.
package_dir = "/gpfs/u/scratch/PTFM/PTFMqngp/github/mitibm2023/external/doremi"
cache_dir = os.path.join(package_dir, 'cache')
preprocessed_data = os.path.join(package_dir, 'data', 'processed')

# Environment variables read by the DoReMi scripts and by the HF / torch / wandb libraries.
# The HF caches, torch extensions and TMPDIR all share `cache_dir`.
envs = dict(
    CACHE=cache_dir,
    DOREMI_DIR=package_dir,
    PILE_DIR=os.path.join(package_dir, "data", 'raw'),
    PREPROCESSED_PILE_DIR=preprocessed_data,
    MODEL_OUTPUT_DIR=os.path.join(package_dir, 'results'),
    PARTITION="el8",
    HF_HOME=cache_dir,
    TRANSFORMERS_CACHE=cache_dir,
    HF_DATASETS_CACHE=cache_dir,
    HF_DATASETS_IN_MEMORY_MAX_SIZE="0",
    TORCH_EXTENSIONS_DIR=cache_dir,
    TMPDIR=cache_dir,
    WANDB_DIR=os.path.join(cache_dir, "wandb"),
    PREPROCESSED_DATA=preprocessed_data,
    PREPROCESSED_CACHE=os.path.join(cache_dir, 'preprocessed_cache', 'perdomain_pile_preprocessed'),
)

# Export every entry into this process's environment.
os.environ.update(envs)

os.makedirs(cache_dir, exist_ok=True)


In [4]:
# Populate the domain-weight config later passed as --domain_config_path.
# Weights are keyed by domain name; train and eval use the same mixture here.
import json
import os

domain_config_path = os.path.abspath('../configs/humanmix_baseline_50kvocab.json')
# Alternative uniform mixture: {"cot": .25, "flan_v2": .25, "dolly": .25, "oasst1": .25}
domain_weights = {'cot': 0.5, 'flan_v2': 0.25, 'dolly': 0.12, 'oasst1': 0.13}

domain_config = {"train_domain_weights": domain_weights, "eval_domain_weights": domain_weights}
# Make sure the configs directory exists before writing into it.
os.makedirs(os.path.dirname(domain_config_path), exist_ok=True)
with open(domain_config_path, 'w') as f:
    json.dump(domain_config, f)

In [5]:
job_name = 'train_baseline'

nodes = 1
num_gpus = 1

model_name_or_path = 'gpt2'; model_type = 'gpt2'
cache_dir = envs['CACHE']
domain_config_path = os.path.abspath('../configs/humanmix_baseline_50kvocab.json')
output_dir = os.path.join(envs['MODEL_OUTPUT_DIR'], job_name)
# Read the preprocessed data directory directly (alternative: envs['PREPROCESSED_CACHE']).
dataset_dir = preprocessed_data

# Global batch size = nodes * num_gpus * per_device_train_batch_size * gradient_accumulation_steps.
total_batch_size = 128
per_device_train_batch_size = 2
examples_per_micro_step = num_gpus * nodes * per_device_train_batch_size
if total_batch_size % examples_per_micro_step != 0:
    # Integer-divide explicitly instead of silently truncating a float quotient.
    raise ValueError(
        f"total_batch_size={total_batch_size} is not divisible by "
        f"num_gpus*nodes*per_device_train_batch_size={examples_per_micro_step}"
    )
gradient_accumulation_steps = total_batch_size // examples_per_micro_step

# NOTE(review): save_steps=5 writes a checkpoint every 5 steps, which looks like a debugging
# setting for a 200k-step run — confirm before launching a long job.
max_steps = 200000; save_steps = 5  # 200k steps.

# use `dataset_dir` instead of `dataset_name` to specify `preprocessed_dir`
# --dataset_name=pile \

## learning rate for pretraining, substituted with finetuning hyperparameters
# --learning_rate 1e-3 \
# --lr_end 1e-4 \
# --adam_epsilon 1e-8 \

## don't need cosine scheduling for finetuning
# --weight_decay 0.01 \
# --lr_scheduler_name linear_warmup_cosine \
# --warmup_ratio 0.06 \

## avoids grad scaling error
# --fp16 \
## for training model from scratch
# --config_overrides="n_positions=1024,n_embd=1024,n_layer=18,n_head=16" \

## added the following
# add_domain_id: for non-pile preprocessed dataset
# do_padding: true for variable size sequences, as in instruction tuning datasets.
# --max_train_samples 1000 \

# Set to True to run DoReMi domain reweighting against the reference model below.
use_doremi = False
# NOTE(review): this checkpoint is presumably the baseline run's `checkpoint-10` reached via a
# different mount path than `output_dir` — confirm it points at the intended run.
reference_model_name_or_path = '/gpfs/u/home/PTFM/PTFMqngp/scratch/github/mitibm2023/external/doremi/results/train_baseline/checkpoint-10'
doremi_options = f"""
--doremi_optimizer=doremiv1 \
--reweight_eta=1 \
--reweight_eps=1e-4 \
--train_domain_weights_tmp_file={os.path.join(output_dir, 'domain_weights')} \
--reweight_domains \
--remove_unused_columns=False \
--reference_model_name_or_path={reference_model_name_or_path} \
""" if use_doremi else ''


# Backslash-newline inside this (non-raw) f-string is a line continuation, so `cmd` is one
# long line of flags that shlex splits into argv-style tokens.
cmd = f"""
--model_name_or_path={model_name_or_path} \
--model_type={model_type} \
--tokenizer_name=gpt2 \
--do_train \
--cache_dir={cache_dir} \
--dataset_dir={dataset_dir} \
--domain_config_path={domain_config_path} \
--max_token_length=1024 \
--per_device_train_batch_size={per_device_train_batch_size} \
--gradient_accumulation_steps={gradient_accumulation_steps} \
--dataloader_num_workers=1 \
--learning_rate=2e-5 \
--lr_scheduler_type=linear \
--warmup_ratio=0.03 \
--weight_decay=0. \
--max_grad_norm=1.0 \
--max_steps={max_steps} \
--evaluation_strategy=no \
--save_strategy=steps \
--save_steps={save_steps} \
--save_total_limit=1 \
--run_name={job_name} \
--seed=1111 \
--logging_strategy=steps \
--logging_steps=10 \
--logging_first_step \
--report_to=tensorboard \
--optim=adamw_hf \
--adam_beta1=0.9 \
--adam_beta2=0.99 \
--add_domain_id=True \
--do_padding=True \
{doremi_options} \
--output_dir={output_dir} \
"""
# --overwrite_output_dir \

import shlex
args = shlex.split(cmd)

parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FullTrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses(args)
model_args, data_args, training_args

(ModelArguments(model_name_or_path='gpt2', model_type='gpt2', config_overrides=None, config_name=None, tokenizer_name='gpt2', cache_dir='/gpfs/u/scratch/PTFM/PTFMqngp/github/mitibm2023/external/doremi/cache', use_fast_tokenizer=True, model_revision='main', use_auth_token=False, torch_dtype=None),
 DataTrainingArguments(dataset_dir='/gpfs/u/scratch/PTFM/PTFMqngp/github/mitibm2023/external/doremi/data/processed', dataset_name='', max_train_samples=None, max_eval_samples=None, max_token_length=1024, block_size=None, overwrite_cache=False, do_padding=True, add_domain_id=True, preprocessing_num_workers=None, shuffle=True),

In [None]:

# Setup logging: stdout handler for the root logger, then align the logger / datasets /
# transformers verbosity with the level TrainingArguments assigns to this process.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)

if training_args.should_log:
    # The default of training_args.log_level is passive, so we set log level at info here to have that default.
    transformers.utils.logging.set_verbosity_info()

log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
datasets.utils.logging.set_verbosity(log_level)
transformers.utils.logging.set_verbosity(log_level)
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()


# Log on each process the small summary.
# The first fragment ends with ", " so the two halves don't run together ("n_gpu: 1distributed").
logger.warning(
    f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
    + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
logger.info(f"Training/evaluation parameters {training_args}")

In [None]:

# Detect a previous run in `output_dir`. If it holds a checkpoint and no explicit
# --resume_from_checkpoint was given, work out how many training examples that run already
# consumed, so the data loader can skip them when training resumes.
# NOTE(review): when --resume_from_checkpoint *is* given explicitly, num_skip_examples stays 0
# — confirm that is intended.
last_checkpoint = None
num_skip_examples = 0
reuse_output_dir = (
    os.path.isdir(training_args.output_dir)
    and training_args.do_train
    and not training_args.overwrite_output_dir
)
if reuse_output_dir:
    last_checkpoint = get_last_checkpoint(training_args.output_dir)
    if last_checkpoint is None:
        # A non-empty output dir without any checkpoint is refused rather than overwritten.
        if len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
    elif training_args.resume_from_checkpoint is None:
        logger.info(
            f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
            "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
        )
        state_path = Path(last_checkpoint) / TRAINER_STATE_NAME
        state = TrainerState.load_from_json(str(state_path))
        # Examples consumed per optimizer step across all processes.
        global_batch_size = (
            training_args.train_batch_size
            * training_args.gradient_accumulation_steps
            * training_args.world_size
        )
        num_skip_examples = state.global_step * global_batch_size
        logger.info(f"Skipping {num_skip_examples} examples")

last_checkpoint, num_skip_examples

In [None]:

# Seed the RNGs (via transformers.set_seed) before any model weights are created.
set_seed(seed=training_args.seed)

In [None]:

# Load pretrained model and tokenizer
#
# Distributed training:
# The .from_pretrained methods guarantee that only one local process can concurrently
# download model & vocab.


def _flash_gpt2_config_from_neox(neox_config):
    """Convert a GPT-NeoX config into the GPT2-style config used for `gpt_neox_flash`.

    Runs flash-attn's converter, then enables the flash-attn / fused-kernel flags and
    disables absolute position embeddings (`max_position_embeddings = 0`).
    Shared by the pretrained and from-scratch branches so the two cannot drift apart.
    """
    flash_config = gpt_neox_config_to_gpt2_config(neox_config)
    flash_config.use_flash_attn = True
    flash_config.fused_mlp = True
    flash_config.fused_bias_fc = True
    flash_config.fused_dropout_add_ln = True
    flash_config.pad_vocab_size_multiple = 8
    flash_config.activation_function = 'gelu_new'
    flash_config.n_inner = None
    # disable absolute position embeddings
    flash_config.max_position_embeddings = 0
    return flash_config


config_kwargs = {
    "cache_dir": model_args.cache_dir,
    "revision": model_args.model_revision,
    "use_auth_token": True if model_args.use_auth_token else None,
}
if model_args.config_name:
    config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
elif model_args.model_name_or_path:
    config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
    if model_args.model_type == 'gpt_neox_flash':
        config = _flash_gpt2_config_from_neox(config)
else:
    if model_args.model_type == 'gpt_flash':
        config = GPT2Config(
                vocab_size=50257, n_positions=2048, n_embd=2048,
                n_layer=24, n_head=16,
                scale_attn_by_inverse_layer_idx=True,
                rotary_emb_fraction=0.5,
                use_flash_attn=True, fused_mlp=True,
                fused_bias_fc=True, fused_dropout_add_ln=True,
                pad_vocab_size_multiple=8)
        # disable absolute position embeddings
        config.max_position_embeddings = 0
    elif model_args.model_type == 'gpt_neox_flash':
        # Start from a default GPT-NeoX config and convert it to a GPT2-style config.
        config = _flash_gpt2_config_from_neox(CONFIG_MAPPING['gpt_neox']())
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
    logger.warning("You are instantiating a new config instance from scratch.")
    if model_args.config_overrides is not None:
        logger.info(f"Overriding config: {model_args.config_overrides}")
        config.update_from_string(model_args.config_overrides)
        logger.info(f"New config: {config}")


In [None]:

# Load the tokenizer from --tokenizer_name, falling back to --model_name_or_path.
# Training a tokenizer from scratch is not supported here.
tokenizer_kwargs = {
    "cache_dir": model_args.cache_dir,
    "use_fast": model_args.use_fast_tokenizer,
    "revision": model_args.model_revision,
    "use_auth_token": True if model_args.use_auth_token else None,
}

if model_args.tokenizer_name:
    tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
elif model_args.model_name_or_path:
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
else:
    # The adjacent literals are concatenated, so the first one needs a trailing space.
    raise ValueError(
        "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
        "You can do it from another script, save it, and load it from here, using --tokenizer_name."
    )

tokenizer

In [None]:

# Instantiate the causal LM. GPT-2 and the flash-attn GPT variants use the DoReMi model
# classes; any other model_type goes through AutoModelForCausalLM.
flash_model_types = {'gpt_flash', 'gpt_neox_flash'}

if not model_args.model_name_or_path:
    if model_args.model_type in flash_model_types:
        model = doremi_models.GPTFlashAttnLMHeadModel(config)
    elif model_args.model_type == 'gpt2':
        model = doremi_models.GPT2LMHeadModelDoReMi(config)
    else:
        model = AutoModelForCausalLM.from_config(config)

    # Count each parameter storage once (keyed by data_ptr) so tied weights aren't double counted.
    unique_param_sizes = {param.data_ptr(): param.numel() for param in model.parameters()}
    n_params = sum(unique_param_sizes.values())
    logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")
else:
    if model_args.torch_dtype in ["auto", None]:
        torch_dtype = model_args.torch_dtype
    else:
        torch_dtype = getattr(torch, model_args.torch_dtype)

    # Keyword arguments shared by the GPT-2 DoReMi and AutoModel loaders.
    hub_kwargs = dict(
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
        torch_dtype=torch_dtype,
    )
    if model_args.model_type in flash_model_types:
        # The flash-attn model is loaded with the config only.
        model = doremi_models.GPTFlashAttnLMHeadModel.from_pretrained(
            model_args.model_name_or_path, config=config)
    elif model_args.model_type == 'gpt2':
        model = doremi_models.GPT2LMHeadModelDoReMi.from_pretrained(
            model_args.model_name_or_path, **hub_kwargs)
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            **hub_kwargs,
        )

model

In [None]:
# from typing import Tuple, Union
# from transformers import GPT2LMHeadModel
# # from doremi.models import CausalLMOutputWithDomainIDs
# from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
# from torch.nn import CrossEntropyLoss

# @dataclass
# class CausalLMOutputWithDomainIDs(CausalLMOutputWithCrossAttentions):
#     domain_ids: Optional[torch.LongTensor] = None
#     reference_pertoken_loss: Optional[torch.FloatTensor] = None  # corresponds to uniq_domain_ids
#     pertoken_loss: Optional[torch.FloatTensor] = None  # corresponds to uniq_domain_ids
#     token_mask: Optional[torch.BoolTensor] = None  # 1 for tokens that are not padding



# #             model_args.model_name_or_path,
# #             from_tf=bool(".ckpt" in model_args.model_name_or_path),
# #             config=config,
# #             cache_dir=model_args.cache_dir,
# #             revision=model_args.model_revision,
# #             use_auth_token=True if model_args.use_auth_token else None,
# #             torch_dtype=torch_dtype,

# # model = doremi_models.GPTFlashAttnLMHeadModel.from_pretrained(
# #     model_args.model_name_or_path, config=config)

        
# model = GPTLMHeadModelDoReMi.from_pretrained(
#     model_args.model_name_or_path, config=config)


In [None]:
# Debug: compare the model's device with the sample batch's device before the forward pass.
# `batch` is built by the batch-collation cell further down, so that cell must run first.
if 'batch' not in globals():
    raise NameError("`batch` is not defined yet: run the batch-collation cell (collate_fn / train_dataset.take) first.")
model.device, batch['input_ids'].device

In [None]:
# Debug forward pass on the sample batch, requesting per-token losses.
# `batch` is built by the batch-collation cell further down, so that cell must run first.
if 'batch' not in globals():
    raise NameError("`batch` is not defined yet: run the batch-collation cell (collate_fn / train_dataset.take) first.")
# Only the outputs are inspected, so skip building the autograd graph.
with torch.no_grad():
    out = model(**batch, return_pertoken_losses=True)

In [None]:
# Debug: shapes of the per-token loss and its token mask from the forward pass above.
tuple(tensor.shape for tensor in (out.pertoken_loss, out.token_mask))

In [None]:
# Debug: inspect the labels of the sample batch.
# `batch` is built by the batch-collation cell further down, so that cell must run first.
if 'batch' not in globals():
    raise NameError("`batch` is not defined yet: run the batch-collation cell (collate_fn / train_dataset.take) first.")
batch['labels']

In [None]:
# Build a small sample batch (3 training examples) for the debug cells above.
# `train_dataset` is created by the dataset cell further down, so that cell must run first.
if 'train_dataset' not in globals():
    raise NameError("`train_dataset` is not defined yet: run the dataset-construction cell first.")
collate_fn = data_utils.get_data_collator(tokenizer, do_padding=data_args.do_padding)

batch = collate_fn(list(train_dataset.take(3)))
# Move tensors to the model's current device rather than a hard-coded 'cuda': the model is not
# moved anywhere in this notebook before the Trainer takes over, so a CUDA batch would hit a
# device mismatch in the debug forward pass. Non-tensor entries are passed through unchanged.
batch = {k: v.to(model.device) if torch.is_tensor(v) else v for k, v in batch.items()}
batch['input_ids'].shape

In [None]:

# Read the domain mixture from --domain_config_path and build the train / eval datasets.
with open(training_args.domain_config_path, 'r') as f:
    domain_config = json.load(f)

train_domain_weights_dict = domain_config['train_domain_weights']
eval_domain_weights_dict = domain_config['eval_domain_weights']
# Convention: whenever a weights dict is turned into an array, domains are ordered by sorted key.
domain_list = sorted(train_domain_weights_dict)
num_domains = len(domain_list)

print(domain_list, num_domains, train_domain_weights_dict)

# Arguments common to the train and eval mixtures.
shared_dataset_kwargs = dict(
    preprocessed_dir=data_args.dataset_dir,
    dataset_name=data_args.dataset_name,
    cache_dir=model_args.cache_dir,
    add_domain_id=data_args.add_domain_id,
    tokenizer=tokenizer,
    training_args=training_args,
)

if training_args.do_train:
    # data script could change tokenizer shape
    train_dataset = data_utils.get_preprocessed_mixed_dataset(
        domain_weights_dict=train_domain_weights_dict,
        split='train',
        max_samples=data_args.max_train_samples,
        tmp_file=None,
        seed=training_args.seed,
        shuffle=data_args.shuffle,
        num_skip_examples=num_skip_examples,
        shard_reversal=training_args.reweight_domains,
        **shared_dataset_kwargs,
    )

if training_args.do_eval:
    # Eval data is read domain by domain (no_interleave=True) instead of being sampled by weight.
    eval_dataset = data_utils.get_preprocessed_mixed_dataset(
        domain_weights_dict=eval_domain_weights_dict,
        split='validation',
        max_samples=data_args.max_eval_samples,
        no_interleave=True,
        **shared_dataset_kwargs,
    )

In [None]:

# preprocessed_dir=data_args.dataset_dir
# domain_weights_dict=train_domain_weights_dict
# dataset_name=data_args.dataset_name
# cache_dir=model_args.cache_dir
# split='train'
# max_samples=data_args.max_train_samples
# add_domain_id=data_args.add_domain_id
# tmp_file=None
# seed=training_args.seed
# tokenizer=tokenizer
# shuffle=data_args.shuffle
# num_skip_examples=num_skip_examples
# shard_reversal=training_args.reweight_domains
# no_interleave=False

# print(preprocessed_dir)
# print(domain_weights_dict)
# print(dataset_name)
# print(cache_dir)
# print(split)
# print(max_samples)
# print(add_domain_id)
# print(seed)
# print(shuffle)
# print(num_skip_examples)
# print(shard_reversal)


# domain_names = list(sorted(domain_weights_dict.keys()))
# domain_to_idx = {domain_names[i]: i for i in range(len(domain_names))}
# domain_weights = np.asarray([domain_weights_dict[domain_name] for domain_name in domain_names])
# domain_weights = domain_weights / domain_weights.sum()

# print()
# print(json.dumps({'domain_names': domain_names, 
#                   'domain_to_idx': domain_to_idx, 
#                   'domain_weights': list(domain_weights)},
#                 indent=4))


# # write domain weights to file if tmp_file is set
# if tmp_file is not None:
#     probabilities_tmp_file = tmp_file

#     with open(str(probabilities_tmp_file), 'wb') as f:
#         pickle.dump(domain_weights, f)
#     probabilities = None
# else:
#     probabilities = domain_weights
#     probabilities_tmp_file = None


# print()
# print(json.dumps({'probabilities': list(probabilities)}, indent=4))

# # from doremi.dataloader import get_perdomain_datasets
# # all_ds = get_perdomain_datasets(
# #     preprocessed_dir, 
# #     domain_weights_dict,
# #     cache_dir=cache_dir,
# #     split=split,
# #     seed=seed,
# #     domain_weights=domain_weights,
# #     domain_names=domain_names,
# #     num_skip_examples=num_skip_examples,
# #     shuffle=shuffle,
# #     shard_reversal=shard_reversal
# # )

# domain_name_to_skip_num = determine_skip_per_domain(num_skip_examples, seed, domain_weights, domain_names)

# preprocessed_dir = Path(preprocessed_dir)
# if split is not None and (preprocessed_dir / split).exists():
#     preprocessed_dir = preprocessed_dir / split
# else:
#     logger.warn(f"No split used or split directory not found: using same data for all splits.")

# domains = list(sorted(domain_weights_dict.keys()))

# print(preprocessed_dir)
# print(domain_name_to_skip_num)
# print()
# print(json.dumps({'preprocessed_dir': str(preprocessed_dir), 
#                   'domain_name_to_skip_num': domain_name_to_skip_num}, indent=4))


# all_ds = {}
# for domain in domains:
#     domain_dir = preprocessed_dir / domain
    
#     ## wpq: read instruction tuning dataset off `jsonl` files
#     if (domain_dir / f'{domain}_data.jsonl').exists():
#         from datasets import load_dataset
#         from functools import partial
#         from open_instruct.finetune_trainer import encode_with_prompt_completion_format, encode_with_messages_format
#         from doremi.dataloader import skippable_data_gen_dataset

#         data_files = {'train': str(domain_dir / f'{domain}_data.jsonl')}
#         raw_datasets = load_dataset(
#             "json",
#             data_files=data_files,
#             cache_dir=cache_dir,
#             use_auth_token=True if model_args.use_auth_token else None,
#         )
#         # Preprocessing the datasets.
#         if "prompt" in raw_datasets["train"].column_names and "completion" in raw_datasets["train"].column_names:
#             encode_function = partial(
#                 encode_with_prompt_completion_format,
#                 tokenizer=tokenizer,
#                 max_seq_length=1024,
#             )
#         elif "messages" in raw_datasets["train"].column_names:
#             encode_function = partial(
#                 encode_with_messages_format,
#                 tokenizer=tokenizer,
#                 max_seq_length=1024,
#             )
#         else:
#             raise ValueError("You need to have either 'prompt'&'completion' or 'messages' in your column names.")

#         with training_args.main_process_first(local=False, desc="Processing instruction data"):
#             lm_datasets = raw_datasets.map(
#                 encode_function,
#                 num_proc=16,
#                 batched=False,
#             )
#             lm_datasets.set_format(type="pt")
#         ds = lm_datasets['train']
#         ds = IterableDataset.from_generator(
#                 skippable_data_gen_dataset,
#                 gen_kwargs={'ds': ds,
#                             'num_skip_examples': domain_name_to_skip_num[domain],
#                             'loop': (split == 'train'),
#                             'seed': seed,
#                             'shuffle': shuffle}
#                 )
#         seed += 1
#     elif (domain_dir / 'dataset_info.json').exists():
#         ds = load_from_disk(dataset_path=str(domain_dir))
#         logger.info(f"Loaded {domain_dir}. Length: {len(ds)}")
#     else:
#         curr_shards = list(domain_dir.iterdir())
#         if shard_reversal:
#             curr_shards = list(reversed(curr_shards))
#         # shuffle shard order
#         random.Random(seed).shuffle(curr_shards)
#         ds = IterableDataset.from_generator(
#                 skippable_data_gen,
#                 gen_kwargs={'shards': curr_shards,
#                             'num_skip_examples': domain_name_to_skip_num[domain],
#                             'loop': (split == 'train'),
#                             'seed': seed,
#                             'shuffle': shuffle}
#                 )
#         seed += 1
#     all_ds[domain] = ds
    

# if tokenizer.pad_token is None:
#     tokenizer.pad_token = tokenizer.eos_token

# def add_domain_id_generator(ds, domain_idx):
#     for ex in ds:
#         ex['domain_id'] = domain_idx
#         yield ex
        
# domain_ds_ls = []
# for domain_name in domain_names:
#     domain_idx = domain_to_idx[domain_name]
#     domain_ds = all_ds[domain_name]
#     # add domain_id if necessary
#     if add_domain_id:
#         domain_ds = IterableDataset.from_generator(
#             add_domain_id_generator, 
#             gen_kwargs={'ds': domain_ds, 'domain_idx': domain_idx})
#     domain_ds_ls.append(domain_ds)

# if no_interleave:
#     # instead of interleaving, run through each dataset
#     def data_generator(shards):
#         for shard in shards:
#             for ex in shard:
#                 yield ex
#     ds = IterableDataset.from_generator(data_generator, gen_kwargs={'shards': domain_ds_ls})
#     logger.info("Not interleaving dataset - will not sample according to domain weights")

# else:
#     ds = interleave_datasets(
#             domain_ds_ls,
#             probabilities=probabilities,
#             probabilities_file=probabilities_tmp_file,
#             seed=seed)
    

# def take_data_generator(ds, max_samples):
#     idx = 0
#     for ex in ds:
#         yield ex
#         idx += 1
#         if max_samples is not None and idx >= max_samples:
#             return

# ds = IterableDataset.from_generator(take_data_generator, gen_kwargs={'ds': ds, 'max_samples': max_samples})
# train_dataset = ds


In [None]:
# test_ds = load_dataset(
#     "json",
#     data_files='test.jsonl',
#     cache_dir=model_args.cache_dir)['train']
# test_ds[0]

# for x in test_ds.to_iterable_dataset():
#     print(x)
# for i, v in enumerate(ds):
#     if i == 10:
#         break
#     print(v)

In [None]:

# When reweighting domains (DoReMi), load a frozen reference model and attach it to `model`,
# together with the domain-weight state buffers. Otherwise there is no reference model.
if training_args.reweight_domains:
    reference_path = training_args.reference_model_name_or_path
    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
    if model_args.model_type in {'gpt_flash', 'gpt_neox_flash'}:
        model_cls = doremi_models.GPTFlashAttnLMHeadModel
        reference_model = model_cls.from_pretrained(
            reference_path,
            config=config)
    elif model_args.model_type in {'gpt2'}:
        model_cls = doremi_models.GPT2LMHeadModelDoReMi
        reference_model = model_cls.from_pretrained(
            reference_path,
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
            torch_dtype=torch_dtype,
        )
    else:
        model_cls = AutoModelForCausalLM

        reference_model = model_cls.from_pretrained(
            reference_path,
            # Decide TF-checkpoint loading from the path actually being loaded,
            # not from model_args.model_name_or_path.
            from_tf=bool(".ckpt" in reference_path),
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
            torch_dtype=torch_dtype,
        )
    # Freeze the reference model: no gradients, eval mode.
    for param in reference_model.parameters():
        param.requires_grad = False
    reference_model.eval()
    model.reference_model = reference_model
    # Domain-weight state, ordered like `domain_list` (sorted domain names).
    model.register_buffer('train_domain_weights', torch.tensor(
            [train_domain_weights_dict[domain] for domain in domain_list]))
    model.register_buffer('avg_domain_weights', model.train_domain_weights.clone())
    # Per-domain scores are initialized to log(len(tokenizer)).
    model.register_buffer('perdomain_scores', torch.ones(len(train_domain_weights_dict)) * np.log(len(tokenizer)))
    model.register_buffer('update_counter', torch.tensor(1))

else:
    reference_model = None


In [None]:

# turn off find unused parameters
training_args.ddp_find_unused_parameters = False

# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
# on a small vocab and want a smaller embedding size, remove this test.
# embedding_size = model.get_input_embeddings().weight.shape[0]
# if len(tokenizer) > embedding_size:
#     model.resize_token_embeddings(len(tokenizer))

torch.cuda.empty_cache()

# Initialize our Trainer
trainer = DoReMiTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset if training_args.do_train else None,
    eval_dataset=eval_dataset if training_args.do_eval else None,
    tokenizer=tokenizer,
    data_collator=data_utils.get_data_collator(tokenizer, do_padding=data_args.do_padding),
)

# Quick look at grad clipping / precision settings. `sharded_ddp`, `do_grad_scaling` and
# `half_precision_backend` may not exist on every transformers release, so read them with
# getattr (None when absent) instead of failing after the trainer has been built.
print(
    trainer.args.max_grad_norm,
    getattr(trainer, 'sharded_ddp', None),
    getattr(trainer.args, 'half_precision_backend', None),
    getattr(trainer, 'do_grad_scaling', None),
)
trainer

In [None]:

# Training: resume from an explicit --resume_from_checkpoint or the auto-detected checkpoint,
# then save the model, the metrics (plus averaged domain weights when reweighting) and the state.
if training_args.do_train:
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    trainer.save_model()  # Saves the tokenizer too for easy upload

    metrics = train_result.metrics

    if training_args.reweight_domains:
        # `avg_domain_weights` is ordered like `domain_list` (sorted domain names).
        avg_domain_weights_dict = {}
        for domain_name, weight in zip(domain_list, model.avg_domain_weights.tolist()):
            metrics[f'avg_domain_weight:{domain_name}'] = weight
            avg_domain_weights_dict[domain_name] = weight

        # save avg domain weights to json
        avg_domain_weights_file = Path(training_args.output_dir) / 'avg_domain_weights.json'
        with open(avg_domain_weights_file, 'w') as f:
            json.dump(avg_domain_weights_dict, f, indent=2)

        # Also save to the configs dir, as a config usable via --domain_config_path.
        # A notebook has no `__file__`, so the configs dir is taken to be the directory that
        # holds the input domain config (--domain_config_path).
        config_dict = {"train_domain_weights": avg_domain_weights_dict,
                       "eval_domain_weights": avg_domain_weights_dict}
        configs_dir = Path(training_args.domain_config_path).parent
        config_dict_file = configs_dir / f"{Path(training_args.output_dir).name}.json"
        with open(config_dict_file, 'w') as f:
            json.dump(config_dict, f, indent=2)

    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()