diff --git a/mttl/arguments.py b/mttl/arguments.py index cc10e5788..29cd24231 100644 --- a/mttl/arguments.py +++ b/mttl/arguments.py @@ -410,6 +410,9 @@ class TrainingArgs(DataArgs): profile: bool = False # if 'True' will profile the model training + # dpo + rl_training: str = "dpo" + @property def dataset_config(self): if self.dataset_type is not None: diff --git a/mttl/config.py b/mttl/config.py new file mode 100644 index 000000000..79bc8751c --- /dev/null +++ b/mttl/config.py @@ -0,0 +1,352 @@ +import argparse +import ast +import json +import os +from string import Template +from typing import Dict + +from mttl.utils import logger, setup_logging + + +class Config: + def __init__(self, filenames=None, kwargs=None, raise_error=True, silent=False): + # Stores personalization of the config file in a dict (json serializable) + + self._updated_kwargs = {} + self.filenames = filenames + self._set_defaults() + + overwrite_logs = [] + if filenames: + for filename in filenames.split("+"): + if not os.path.exists(filename): + filename = os.path.join( + os.getenv("CONFIG_PATH", default="configs"), filename + ) + + if not os.path.exists(filename) and ".json" not in filename: + filename = filename + ".json" + + overwrite_logs += self.update_kwargs( + json.load(open(filename)), + eval=False, + raise_error=raise_error, + silent=silent, + ) + + if kwargs: + overwrite_logs += self.update_kwargs( + kwargs, raise_error=raise_error, silent=silent + ) + + # setup logging to the output dir + try: + setup_logging(self.output_dir) + except Exception as e: + if raise_error: + raise ValueError("Error setting up logging") from e + elif not silent: + logger.warning(f"Error setting up logging to {self.output_dir}") + + # log the overwrites + for log in overwrite_logs: + logger.warning(log) + + self.post_init(silent=silent) + + def post_init(self, silent=False): + pass + + @classmethod + def fromdict(cls, data): + _ = data.pop("_updated_kwargs", None) + return cls(kwargs=data, raise_error=False, silent=True) + + def asdict(self) -> Dict: + from mttl.models.utils import convert_hps_to_dict + + return convert_hps_to_dict(self.__dict__) + + def was_overridden(self, key): + return key in self._updated_kwargs + + def was_default(self, key): + return key not in self._updated_kwargs + + def update_kwargs(self, kwargs, eval=True, raise_error=True, silent=False): + overwrites_log = [] + for k, v in kwargs.items(): + if eval: + try: + v = ast.literal_eval(v) + except (ValueError, SyntaxError): + v = v + else: + v = v + + if not hasattr(self, k) and raise_error: + raise ValueError(f"{k} is not in the config") + + if eval and not silent: + overwrites_log.append(f"Overwriting {k} to {v}") + + if type(v) == str and "$" in v: + # this raises an error if the env. 
var does not exist + v = Template(v).substitute(os.environ) + + setattr(self, k, v) + self._updated_kwargs[k] = v + return overwrites_log + + def __getitem__(self, item): + return getattr(self, item, None) + + def to_json(self): + """ + Converts parameter values in config to json + :return: json + """ + import copy + + to_save = copy.deepcopy(self.__dict__) + to_save.pop("_updated_kwargs") + + return json.dumps(to_save, indent=4, sort_keys=False) + + def save_config(self, output_dir): + """ + Saves the config + """ + os.makedirs(output_dir, exist_ok=True) + + with open(os.path.join(output_dir, "config.json"), "w+") as fout: + fout.write(self.to_json()) + fout.write("\n") + + @classmethod + def parse( + cls, + extra_kwargs=None, + raise_error=True, + parent=None, + return_parser=False, + c=None, + ): + import itertools + + # dont do it if called from jupyter notebook + if c is None: + parser = ( + argparse.ArgumentParser(parents=[parent]) + if parent + else argparse.ArgumentParser() + ) + parser.add_argument("-c", "--config_files", required=False) + parser.add_argument("-k", "--kwargs", nargs="*", action="append") + args = parser.parse_args() + else: + args = argparse.Namespace() + args.kwargs = None + args.config_files = c + + kwargs = {} + if args.kwargs: + kwargs_opts = list(itertools.chain(*args.kwargs)) + for value in kwargs_opts: + key, _, value = value.partition("=") + + # allows multiple values for a given option when specified in the command line! + if key in kwargs: + if type(kwargs[key]) != list: + kwargs[key] = [kwargs[key]] + kwargs[key].append(value) + else: + kwargs[key] = value + + args.kwargs = kwargs + if extra_kwargs: + args.kwargs.update(extra_kwargs) + + config = cls( + filenames=args.config_files, kwargs=args.kwargs, raise_error=raise_error + ) + + if return_parser: + return config, args + return config + + def _set_defaults(self): + self.cache_dir = os.getenv("CACHE_DIR", "./cache") + + # Data config + self.dataset = None + self.custom_tasks_splits = None + + self.data_dir = os.getenv("TRAIN_DIR", "/tmp/") + self.output_dir = os.getenv("OUTPUT_DIR", "./output") + + self.finetune_task_name = None + self.example_to_ids_path = None # path to clustering of data + self.embeddings_path = None + + # NI related configs + self.use_task_descriptions = False # Use task descriptions + self.max_num_instances_per_task = ( + 100 # Max instances per training task (applies to NI) + ) + self.num_pos_examples = ( + 0 # Use some few-shot examples if possible (applies to NI) + ) + + self.task_prefix = None # xfit has task prefixes detailing # of shots, seed, etc; this is automatically filled in at fine-tuning time + self.exp_name = None + self.wandb_project = None + self.padding_side = "right" + self.truncation_side = "right" + self.max_input_length = 512 + self.max_output_length = 64 + self.num_beams = 4 + self.append_another_bos = False + self.do_lowercase = False + self.freeze_embeds = False + + # T0 related configs + self.use_t0_templates_as_tasks = ( + False # if True, then t0 consists of 313 tasks, otherwise 38 + ) + self.use_t0_few_shot_training_set = False # if True, then use 100 examples per task during training + 100 examples per validation task + + # Filtering configs, useful for flan flat, etc. 
+ self.include_template_type = "zs_noopt" + self.include_task_source = "P3,Flan2021,CoT" + self.remove_phi_eval_tasks = False + + # Training config + self.compute_strategy = None + self.padding_side = "right" + self.scheduler = "linear_decay_with_warmup" + self.checkpoint = None # load from checkpoint + self.checkpoint_step = None # load from checkpoint in format of global_stepX.pt + self.backbone_checkpoint = None # load the backbone from here + self.train_batch_size = 8 + self.predict_batch_size = 32 + self.learning_rate = 1e-3 + self.warmup_proportion = 0.06 + self.trainable_param_names = ".*" + self.non_trainable_param_names = None + self.weight_decay = 0.01 + self.adam_epsilon = 1e-8 + self.max_grad_norm = 0.1 + self.gradient_accumulation_steps = 1 + self.optimizer = "adamw" + self.adafactor_scale_parameter = True + self.adafactor_warmup_init = False + self.adafactor_relative_step = False + self.num_train_epochs = -1 + self.warmup_steps = -1 + self.total_steps = -1 + self.num_tasks_per_batch = None + self.save_every = None + self.eval_every = None + self.eval_every_n_epoch = 1 + self.debug = False + self.seed = 42 + self.eval_before_training = True + + self.subsample_train = None + self.subsample_dev = None + self.subsample_test = None + self.subsample_per_task = False + + self.ni_online_eval = False # zero-shot online eval for ni + self.t0_online_eval = False # zero-shot eval for t0 + self.early_stop_on_zero_shot = False # zero-shot early stopping + + # auxiliary losses + self.ortho_loss = 0.0 # orthogonality between skills + self.task_loss = 0.0 # task prediction loss (mi between tasks and skills) + self.l1_loss = 0.0 # sparsity of the logits + self.mi_loss = ( + 0.0 # mi between tasks and skills (difference of entropies method) + ) + self.mc_loss = 0.0 # T-Few + self.length_norm = 0.0 # T-Few + self.unlikely_loss = 0.0 # T-Few + self.poly_unlikely_loss = 0.0 # poly unlikelihood loss + self.finetune_type = None # ["F", "A", "Z", "MuZ", "Poly", "PolyRand"] + self.finetune_skip_es = False # skip early stopping while fine-tuning + self.finetune_use_last_checkpoint = ( + False # use always the best valid_perf checkpoint if available + ) + + self.model = None + self.model_family = None # model family, either "gpt" or "encdec" + + self.precision = "32" + self.monitor_grad_alignment_on = None + + self.model_modifier = None + self.adapter_type = None + + self.lora_rank = 16 + self.lora_dropout = 0.0 + self.lora_init_scale = 0.01 + self.lora_alpha = 1.0 + self.lora_warmup = False + self.lora_init_b_random = False + self.lora_dropout = 0.0 + + # n-skills for router-based methods + self.n_skills = 8 + self.n_tasks = None + self.task_names = None + + # which modules to modify and which layers for adapters + self.modify_modules = None + self.modify_layers = None + self.tie_params = None # "q_proj\\.lora_a|k_proj\\.lora_a|v_proj\\.lora_a" to share lora_a for q,k,v + + """ + router_granularity : how granular is the module selection + coarsegrained : 1 single selector across all linear layers + coderwise : 2 selectors (1 for encoder, 1 for decoder) + blockwise : 1 selector for each block of K attention layers (and layernorm) + layerwise : 1 selector for each attention layer (and layernorm) + finegrained : 1 selector for every linear layer + """ + self.router_granularity = "finegrained" # router granularity + self.router_selector = None # router selector + self.router_weight_decay = None # router weight decay + self.router_learning_rate = None + + # Polytropon related hyper-parameters + 
self.n_splits = 1 # number of splits for poly-s + self.router_selector_cluster_temp = 1.0 # temperature for the cluster selector + self.poly_average_correction = False # correct the poly average + self.poly_use_shared_skill = False # use one skill shared by all tasks + self.skip_unseen_tokens = ( + True # skip unseen tokens in PerTokenPoly during evaluation + ) + + self.module_logits_relaxed_bernoulli = True + self.module_logits_straight_through = False + self.module_logits_learning_rate = 0.1 + self.adapters_learning_rate = None + self.adapters_weight_decay = None + self.module_logits_dropout = 0.0 + self.module_logits_l2_norm = False + + self.augment_mmlu: bool = False + + # soft prompts + self.soft_prompt_length: int = 10 + self.patch_last_k_layers: int = -1 + self.soft_prompt_mlp_dim: int = None + self.soft_prompt_hidden_dim: int = None + self.soft_prompt_learn_kv: bool = False + self.prompt_placement: str = "prefix" + self.add_routing_token: bool = False + + # rl training + self.rl_training = "dpo" + self.beta = 0.5 diff --git a/mttl/datamodule/base.py b/mttl/datamodule/base.py index dc9dcfab4..4747c752f 100644 --- a/mttl/datamodule/base.py +++ b/mttl/datamodule/base.py @@ -968,6 +968,12 @@ def get_datamodule(args, for_generation=False, dataset_override=None): WinograndeMultiChoiceDataModule, ) + from mttl.datamodule.ultrafeedback_data_module import ( + UltrafeedbackSFTmodule, + ) + + from mttl.datamodule.orca_data_module import OrcaDataModule + # if we have a DataArgs object, we can directly create the datamodule if isinstance(args, DataArgs) and args.dataset_type is not None: dataset_config = args.dataset_config @@ -1063,6 +1069,16 @@ def get_datamodule(args, for_generation=False, dataset_override=None): pack_sequences=args.pack_sequences, ) dm = FlatMultiTaskModule(config, for_generation=for_generation) + elif "ultrachat" in dataset: + config = DatasetConfig( + **common_kwargs, + ) + dm = UltrafeedbackSFTmodule(config, for_generation=for_generation) + elif "orca" in dataset: + config = DatasetConfig( + **common_kwargs, + ) + dm = OrcaDataModule(config, for_generation=for_generation) elif "mmlu" in dataset: config = MMLUDataConfig( **common_kwargs, diff --git a/mttl/datamodule/orca_data_module.py b/mttl/datamodule/orca_data_module.py new file mode 100644 index 000000000..aa6248034 --- /dev/null +++ b/mttl/datamodule/orca_data_module.py @@ -0,0 +1,99 @@ +from dataclasses import dataclass + +from mttl.datamodule.base import DatasetConfig, DefaultCollator, DataModule +from mttl.models.library.dataset_library import DatasetLibrary +import ast + +# @dataclass +# class OrcaDataModuleCollator(DefaultCollator): +# def __call__(self, batch): +# sources = [] +# labels = [] +# for item in batch: +# if type(item["messages"]) == str: +# item["messages"] = ast.literal_eval(item["messages"]) + +# sources.append("You are a helpful assistant.") +# labels.append( +# self.tokenizer.apply_chat_template(item["messages"], tokenize=False) +# ) +# output_batch = ( +# self.prepare_inputs_for_gpt_family(sources, labels) +# if self.model_family == "gpt" +# else self.prepare_inputs_for_seq2seq_family(sources, labels) +# ) + +# output_batch["sources_texts"] = sources +# output_batch["labels_texts"] = labels + +# return output_batch + + +@dataclass +class OrcaDataModule(DataModule): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def setup_dataset(self): + train_dataset = DatasetLibrary.pull_dataset_with_retry( + "zhan1993/orca_sqs_dataset" + )["train"] + if 
self.config.finetune_task_name: + train_dataset = train_dataset.filter( + lambda example: example["task_name"] == self.config.finetune_task_name + ) + self.train_dataset, self.dev_dataset = self.create_train_valid_split( + train_dataset, 0.1 + ) + self.test_dataset = self.dev_dataset + + self.print_infos() + + @property + def collate_fn(self): + return DefaultCollator( + tokenizer=self.tokenizer, + padding="longest", + max_input_length=self.config.max_input_length, + max_output_length=self.config.max_output_length, + return_tensors="pt", + model_family=self.config.model_family, + for_generation=self.for_generation, + ) + + +if __name__ == "__main__": + config = DatasetConfig( + model="microsoft/Phi-3-mini-4k-instruct", finetune_task_name="creative_content" + ) + from transformers import AutoTokenizer + + # <|assistant|> + tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct") + message = [ + { + "content": "Create an outline for a presentation that includes five engaging and informative slides to test the audience's understanding of the concepts discussed in your talk, such as p-value functions, compatibility intervals, and the misconceptions about p-values. Each slide should present a multiple-choice question related to these topics. Prepare talking points that explain the correct answers and why they are correct, ensuring the explanations are clear and suitable for an educational presentation.", + "role": "user", + }, + { + "content": "Title: Understanding P-Values and Statistical Inference\n\nSlide 1: Introduction to P-Values\n- Multiple-Choice Question: What does a p-value indicate in hypothesis testing?\n A) The probability that the null hypothesis is true\n B) The probability of observing the data, or something more extreme, if the null hypothesis is true\n C) The probability that the alternative hypothesis is true\n D) The likelihood of making a Type I error\n\nTalking Points:\n- The correct answer is B. The p-value is the probability of observing the data, or something more extreme, if the null hypothesis is true. It is not a direct statement about the truth of the null or alternative hypothesis (A and C are incorrect). While it is related to the risk of a Type I error, it is not the likelihood of making this error (D is incorrect), but rather a measure of evidence against the null hypothesis given the data.\n\nSlide 2: Understanding Compatibility Intervals\n- Multiple-Choice Question: What does a 95% compatibility interval (CI) for a parameter estimate indicate?\n A) 95% of the data falls within this interval\n B) There is a 95% chance that the interval contains the true parameter value\n C) If we were to repeat the experiment many times, 95% of the calculated intervals would contain the true parameter value\n D) The parameter has a 95% probability of being within this interval\n\nTalking Points:\n- The correct answer is C. A 95% compatibility interval means that if we were to repeat the experiment many times, we would expect 95% of the calculated intervals to contain the true parameter value. 
It does not mean that 95% of the data falls within this interval (A is incorrect), nor does it imply a probability statement about the parameter itself (B and D are incorrect), as frequentist intervals do not assign probabilities to fixed parameters.\n\nSlide 3: Misconceptions About P-Values\n- Multiple-Choice Question: Which of the following is a common misconception about p-values?\n A) A low p-value confirms the research hypothesis\n B) A high p-value proves the null hypothesis is true\n C) P-values can indicate the size of an effect\n D) All of the above\n\nTalking Points:\n- The correct answer is D. All of the options listed are common misconceptions about p-values. A low p-value does not confirm the research hypothesis (A is incorrect); it simply suggests that the data are unlikely under the null hypothesis. A high p-value does not prove the null hypothesis is true (B is incorrect); it suggests that the data are not sufficiently unusual under the null hypothesis. P-values do not provide information about the size or importance of an effect (C is incorrect); they only indicate how incompatible the data are with the null hypothesis.\n\nSlide 4: P-Value Functions\n- Multiple-Choice Question: What information does a p-value function provide that a single p-value does not?\n A) The probability that the null hypothesis is true for different parameter values\n B) The p-values for a range of different alternative hypotheses\n C) The likelihood of the data under various parameter values\n D) The confidence level of the compatibility interval\n\nTalking Points:\n- The correct answer is B. A p-value function provides the p-values for a range of different alternative hypotheses, not just a single point alternative. It does not give the probability that the null hypothesis is true for different parameter values (A is incorrect), nor does it directly provide the likelihood of the data (C is incorrect). The confidence level of the compatibility interval is a separate concept (D is incorrect).\n\nSlide 5: Interpreting P-Values and Compatibility Intervals Together\n- Multiple-Choice Question: How should p-values and compatibility intervals be used together to interpret statistical results?\n A) To determine the probability that the null hypothesis is true\n B) To assess the evidence against the null hypothesis and the range of parameter values that are compatible with the data\n C) To calculate the effect size and its significance\n D) To prove the research hypothesis\n\nTalking Points:\n- The correct answer is B. P-values and compatibility intervals should be used together to assess the evidence against the null hypothesis (provided by the p-value) and to understand the range of parameter values that are compatible with the data (provided by the compatibility interval). They do not determine the probability that the null hypothesis is true (A is incorrect), nor do they prove the research hypothesis (D is incorrect). 
While they can help contextualize the effect size, they do not calculate it (C is incorrect).", + "role": "assistant", + }, + ] + output = tokenizer.apply_chat_template(message, tokenize=False) + + # dataset = DatasetLibrary.pull_dataset_with_retry( + # "zhan1993/orca_sqs_dataset" + # ) + + # dataset = dataset.filter(lambda example: example["task_name"] == config.finetune_task_name) + # splited_dataset = dataset.map(get_source_target,remove_columns=["messages","split"]) + # breakpoint() + + datamodule = OrcaDataModule(config) + train_dataloader = datamodule.train_dataloader() + val_dataloder = datamodule.val_dataloader() + for batch in val_dataloder: + print(batch) + breakpoint() diff --git a/mttl/datamodule/preference_data_module.py b/mttl/datamodule/preference_data_module.py new file mode 100644 index 000000000..630a0b637 --- /dev/null +++ b/mttl/datamodule/preference_data_module.py @@ -0,0 +1,110 @@ +from dataclasses import dataclass + +import torch + +from mttl.datamodule.base import DatasetConfig, DefaultCollator, DataModule +from mttl.models.library.dataset_library import DatasetLibrary + + +@dataclass +class DataCollatorForDPO(DefaultCollator): + def __call__(self, batch): + prompts = ["Instruct: " + item["prompt"] + "\n" for item in batch] + chosen_responses = ["Output: " + item["chosen"] for item in batch] + rejected_responses = ["Output: " + item["rejected"] for item in batch] + + prompt_ids = self.tokenizer.batch_encode_plus( + prompts, + padding=True, + return_tensors="pt", + max_length=self.max_input_length, + truncation=True, + )["input_ids"] + + prefered_tokenize = self.tokenizer.batch_encode_plus( + chosen_responses, + padding=True, + return_tensors="pt", + max_length=self.max_input_length, + truncation=True, + ) + prefered_ids = prefered_tokenize["input_ids"] + + disprefered_tokenize = self.tokenizer.batch_encode_plus( + rejected_responses, + padding=True, + return_tensors="pt", + max_length=self.max_input_length, + truncation=True, + ) + disprefered_ids = disprefered_tokenize["input_ids"] + + prompt_prefered_ids = torch.cat([prompt_ids, prefered_ids], dim=-1) + prompt_disprefered_ids = torch.cat([prompt_ids, disprefered_ids], dim=-1) + + prompt_prefered_mask = torch.cat( + [torch.ones_like(prompt_ids), torch.zeros_like(prefered_ids)], dim=-1 + ) + # compute the each length of the prefered + prefered_y_len = prefered_tokenize["attention_mask"].sum(dim=1) + disprefered_y_len = disprefered_tokenize["attention_mask"].sum(dim=1) + + prompt_disprefered_mask = torch.cat( + [torch.ones_like(prompt_ids), torch.zeros_like(disprefered_ids)], dim=-1 + ) + + return { + "prompt_prefered_ids": prompt_prefered_ids, + "prompt_disprefered_ids": prompt_disprefered_ids, + "prompt_prefered_mask": prompt_prefered_mask, + "prompt_disprefered_mask": prompt_disprefered_mask, + "prefered_y_len": prefered_y_len, + "disprefered_y_len": disprefered_y_len, + } + + +@dataclass +class Preferencemodule(DataModule): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def setup_dataset(self): + train_dataset = DatasetLibrary.pull_dataset_with_retry( + "jondurbin/truthy-dpo-v0.1" + )["train"] + + self.train_dataset, self.dev_dataset = self.create_train_valid_split( + train_dataset, 0.1 + ) + self.test_dataset = self.dev_dataset + + self.print_infos() + + @property + def collate_fn(self): + return DataCollatorForDPO( + tokenizer=self.tokenizer, + padding="longest", + max_input_length=self.config.max_input_length, + max_output_length=self.config.max_output_length, + 
return_tensors="pt", + model_family=self.config.model_family, + for_generation=self.for_generation, + ) + + +if __name__ == "__main__": + config = DatasetConfig(model="microsoft/phi-2") + datamodule = Preferencemodule(config) + train_dataloader = datamodule.train_dataloader() + val_dataloder = datamodule.val_dataloader() + for batch in val_dataloder: + prompt_prefered_mask = batch["prompt_prefered_mask"] + prompt_disprefered_mask = batch["prompt_disprefered_mask"] + + # get the length of the response + prefered_y_len = batch["prefered_y_len"] + disprefered_y_len = batch["disprefered_y_len"] + print(prefered_y_len, disprefered_y_len) + breakpoint() diff --git a/mttl/datamodule/ultrafeedback_data_module.py b/mttl/datamodule/ultrafeedback_data_module.py new file mode 100644 index 000000000..df3ad6e57 --- /dev/null +++ b/mttl/datamodule/ultrafeedback_data_module.py @@ -0,0 +1,199 @@ +from dataclasses import dataclass +from typing import Any +import torch + +from mttl.datamodule.base import DatasetConfig, DefaultCollator, DataModule +from mttl.models.library.dataset_library import DatasetLibrary + + +def is_openai_format(messages: Any) -> bool: + """ + Check if the input messages are in OpenAI format. + Args: + messages (`Any`): + Messages to check. + Returns: + `bool`: Whether the messages are in OpenAI format. + """ + if isinstance(messages, list) and all( + isinstance(message, dict) for message in messages + ): + return all("role" in message and "content" in message for message in messages) + return False + + +@dataclass +class UltrafeedbackDPOCollator(DefaultCollator): + def __call__(self, batch): + + # For DPO/ORPO, the inputs are triples of (prompt, chosen, rejected), where `chosen` and `rejected` are the final turn of a dialogue + # We therefore need to extract the N-1 turns to form the prompt + prompts = [] + chosen_responses = [] + rejected_responses = [] + for example in batch: + if "prompt" in example and is_openai_format(example["prompt"]): + prompt_messages = example["prompt"] + chosen_messages = example["chosen"] + rejected_messages = example["rejected"] + else: + prompt_messages = example["chosen"][:-1] + # Now we extract the final turn to define chosen/rejected responses + chosen_messages = example["chosen"][-1:] + rejected_messages = example["rejected"][-1:] + prompts.append( + self.tokenizer.apply_chat_template(prompt_messages, tokenize=False) + ) + chosen_responses.append( + self.tokenizer.apply_chat_template(chosen_messages, tokenize=False) + ) + rejected_responses.append( + self.tokenizer.apply_chat_template(rejected_messages, tokenize=False) + ) + + prompt_ids = self.tokenizer.batch_encode_plus( + prompts, + padding=True, + return_tensors="pt", + max_length=self.max_input_length, + truncation=True, + )["input_ids"] + + prefered_tokenize = self.tokenizer.batch_encode_plus( + chosen_responses, + padding=True, + return_tensors="pt", + max_length=self.max_input_length, + truncation=True, + ) + prefered_ids = prefered_tokenize["input_ids"] + + disprefered_tokenize = self.tokenizer.batch_encode_plus( + rejected_responses, + padding=True, + return_tensors="pt", + max_length=self.max_input_length, + truncation=True, + ) + disprefered_ids = disprefered_tokenize["input_ids"] + + prompt_prefered_ids = torch.cat([prompt_ids, prefered_ids], dim=-1) + prompt_disprefered_ids = torch.cat([prompt_ids, disprefered_ids], dim=-1) + + prompt_prefered_mask = torch.cat( + [torch.ones_like(prompt_ids), torch.zeros_like(prefered_ids)], dim=-1 + ) + # compute the each length of the prefered + 
prefered_y_len = prefered_tokenize["attention_mask"].sum(dim=1) + disprefered_y_len = disprefered_tokenize["attention_mask"].sum(dim=1) + + prompt_disprefered_mask = torch.cat( + [torch.ones_like(prompt_ids), torch.zeros_like(disprefered_ids)], dim=-1 + ) + + return { + "prompt_prefered_ids": prompt_prefered_ids, + "prompt_disprefered_ids": prompt_disprefered_ids, + "prompt_prefered_mask": prompt_prefered_mask, + "prompt_disprefered_mask": prompt_disprefered_mask, + "prefered_y_len": prefered_y_len, + "disprefered_y_len": disprefered_y_len, + } + + +@dataclass +class UltrafeedbackDPOmodule(DataModule): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def setup_dataset(self): + dataset = DatasetLibrary.pull_dataset_with_retry( + "princeton-nlp/gemma2-ultrafeedback-armorm" + ) + + # format the ultrafeedback dataset to chatbot format + self.train_dataset = dataset["train"] + self.test_dataset = dataset["test"] + self.dev_dataset = self.test_dataset + + self.print_infos() + + @property + def collate_fn(self): + return UltrafeedbackDPOCollator( + tokenizer=self.tokenizer, + padding="longest", + max_input_length=self.config.max_input_length, + max_output_length=self.config.max_output_length, + return_tensors="pt", + model_family=self.config.model_family, + for_generation=self.for_generation, + ) + + +@dataclass +class UltrafeedbackSFTCollator(DefaultCollator): + def __call__(self, batch): + + # For SFT, the inputs are triples of (prompt, message), where `chosen` and `rejected` are the final turn of a dialogue + # We therefore need to extract the N-1 turns to form the prompt + prompts = [] + messages = [] + for example in batch: + prompt_messages = example["prompt"] + chosen_messages = example["messages"] + prompts.append(prompt_messages) + messages.append( + self.tokenizer.apply_chat_template(chosen_messages, tokenize=False) + ) + + return { + "sources_texts": prompts, + "labels_texts": messages, + } + + +@dataclass +class UltrafeedbackSFTmodule(DataModule): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def setup_dataset(self): + dataset = DatasetLibrary.pull_dataset_with_retry("HuggingFaceH4/ultrachat_200k") + + # format the ultrafeedback dataset to chatbot format + self.train_dataset = dataset["train_sft"] + self.test_dataset = dataset["test_sft"] + self.dev_dataset = self.test_dataset + + self.print_infos() + + @property + def collate_fn(self): + return UltrafeedbackSFTCollator( + tokenizer=self.tokenizer, + padding="longest", + max_input_length=self.config.max_input_length, + max_output_length=self.config.max_output_length, + return_tensors="pt", + model_family=self.config.model_family, + for_generation=self.for_generation, + ) + + +if __name__ == "__main__": + config = DatasetConfig(model="microsoft/Phi-3-mini-4k-instruct") + datamodule = UltrafeedbackSFTmodule(config) + train_dataloader = datamodule.train_dataloader() + val_dataloder = datamodule.val_dataloader() + for batch in val_dataloder: + # prompt_prefered_mask = batch["prompt_prefered_mask"] + # prompt_disprefered_mask = batch["prompt_disprefered_mask"] + + # get the length of the response + # prefered_y_len = batch["prefered_y_len"] + # disprefered_y_len = batch["disprefered_y_len"] + print(batch) + breakpoint() diff --git a/mttl/evaluators/loglike_evaluator.py b/mttl/evaluators/loglike_evaluator.py index ee7b33adc..668081080 100644 --- a/mttl/evaluators/loglike_evaluator.py +++ b/mttl/evaluators/loglike_evaluator.py @@ -70,7 +70,11 @@ def evaluate( ) 
loss_per_option = loss_per_option.cpu() - if loss_per_option.dtype in [torch.bfloat16, torch.float16]: + if loss_per_option.dtype in [ + torch.bfloat16, + torch.float16, + torch.float32, + ]: loss_per_option = loss_per_option.float().numpy() loss_per_example = [ @@ -105,6 +109,5 @@ def evaluate( "predictions": all_predictions, "accuracy": float(np.mean(all_accuracies)) if all_accuracies else None, } - self.save_metrics(metrics, output_path) return metrics["accuracy"] diff --git a/mttl/models/containers/selectors/arrow_selector.py b/mttl/models/containers/selectors/arrow_selector.py index fbfdb3877..3121f1586 100644 --- a/mttl/models/containers/selectors/arrow_selector.py +++ b/mttl/models/containers/selectors/arrow_selector.py @@ -13,7 +13,6 @@ def compute_arrow_embeddings( ab_only=True, tie_params=None, tie_op="concat", - add_base_proto=False, recompute_prototypes=False, ) -> str: from mttl.models.library.library_transforms import ( @@ -26,7 +25,6 @@ def compute_arrow_embeddings( ab_only=ab_only, tie_params=tie_params or "default", tie_op=tie_op, - add_base_proto=add_base_proto, ) ArrowTransform(cfg).transform( library, diff --git a/mttl/models/expert_model.py b/mttl/models/expert_model.py index 1fb2a28bc..fc08c8f1c 100644 --- a/mttl/models/expert_model.py +++ b/mttl/models/expert_model.py @@ -32,6 +32,8 @@ Modifier, ) from mttl.models.modifiers.modify_model import modify_transformer +from mttl.models.lightning.base_module import LightningEfficientCheckpoint +import torch.nn.functional as F @contextlib.contextmanager @@ -64,7 +66,6 @@ def __init__( **loading_kwargs, ): super().__init__(config, model_object=model_object, **loading_kwargs) - if config.modifier_config is not None: modify_transformer(self.model, config.modifier_config) @@ -544,6 +545,512 @@ def from_pretrained_peft( self.add_expert_instance(expert, is_default=set_as_default) return self + def set_routing_infos(self, batch, generate=False): + self.model.info_container["routing_infos"] = RoutingInfo.from_batch(batch) + + +def calculate_DPO_loss( + original_prefered_logprob, + original_disprefered_logprob, + ref_prefered_logprob, + ref_disprefered_logprob, + beta=2.0, +): + """ + Calculate the DPO loss. 
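+
+    Concretely, with the (token-averaged) log-probabilities produced by get_log_prob,
+    the loss is the standard DPO objective
+        -log sigmoid( beta * [ (logp_model(y_w|x) - logp_ref(y_w|x))
+                               - (logp_model(y_l|x) - logp_ref(y_l|x)) ] ),
+    averaged over the batch, where y_w / y_l are the preferred / dispreferred responses.
+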
+ original_prefered_logprob: the logprob of the prefered expert in the original model + original_disprefered_logprob: the logprob of the disprefered expert in the original model + ref_prefered_logprob: the logprob of the prefered expert in the reference model + ref_disprefered_logprob: the logprob of the disprefered expert in the reference model + """ + + original_prefered_relative_logprob = ( + original_prefered_logprob - ref_prefered_logprob + ) + disprefered_relative_logprob = ( + original_disprefered_logprob - ref_disprefered_logprob + ) + + reward_accuracies = ( + (original_prefered_relative_logprob > disprefered_relative_logprob) + .float() + .mean(dim=-1) + ) + reward_margins = ( + original_prefered_relative_logprob - disprefered_relative_logprob + ).mean(dim=-1) + + loss = -F.logsigmoid( + beta * (original_prefered_relative_logprob - disprefered_relative_logprob) + ).mean(dim=-1) + + return loss, reward_accuracies, reward_margins + + +def get_log_prob(logits, labels): + log_probs = F.log_softmax(logits, dim=-1) + return torch.gather(log_probs, -1, labels.unsqueeze(-1)).squeeze(-1).mean(-1) + + +class ExpertModelSimPO(LightningEfficientCheckpoint): + def __init__(self, preference_model, **kwargs): + super().__init__(**kwargs) + self.preference_model = preference_model + self.trainable_param_names = kwargs.get("trainable_param_names", None) + self.beta = kwargs.get("beta", 0.5) + self.loss_type = kwargs.get("loss_type", "sigmoid") + self.label_smoothing = kwargs.get("label_smoothing", 0.1) + self.gamma_beta_ratio = kwargs.get("gamma_beta_ratio", 0.5) + # log hyperparameters + self.save_hyperparameters(kwargs) + + def simpo_loss( + self, + original_prefered_logprob, + original_disprefered_logprob, + gamma_beta_ratio, + prefered_y_len, + disprefered_y_len, + ): + """ + Compute the SIMPO loss. + + ref: https://github.com/princeton-nlp/SimPO/blob/main/scripts/simpo_trainer.py + + args: original_prefered_logps: log probabiliteis of the prefered expert in the original model + original_disprefered_logps: log probabiliteis of the disprefered expert in the original model + """ + + # normalize the log probabilities with the length of the response + original_prefered_logprob = original_prefered_logprob / prefered_y_len + original_disprefered_logprob = original_disprefered_logprob / disprefered_y_len + + pi_logratios = original_prefered_logprob - original_disprefered_logprob + logits = pi_logratios - gamma_beta_ratio + + if self.loss_type == "sigmoid": + losses = ( + -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + - F.logsigmoid(-self.beta * logits) * self.label_smoothing + ) + elif self.loss_type == "hinge": + losses = torch.relu(1 - self.beta * logits) + else: + raise ValueError( + f"Loss type {self.loss_type} not supported. 
Choose from ['sigmoid', 'hinge']"
+            )
+
+        # normalize the log probabilities with the length of the response
+        chosen_rewards = self.beta * original_prefered_logprob.detach()
+
+        reject_rewards = self.beta * original_disprefered_logprob.detach()
+
+        return losses, chosen_rewards, reject_rewards
+
+    def forward(self, batch):
+        prompt_prefered_ids = batch["prompt_prefered_ids"]
+        prompt_disprefered_ids = batch["prompt_disprefered_ids"]
+
+        prompt_prefered_mask = batch["prompt_prefered_mask"]
+        prompt_disprefered_mask = batch["prompt_disprefered_mask"]
+
+        # get the length of the response
+        prefered_y_len = batch["prefered_y_len"]
+        disprefered_y_len = batch["disprefered_y_len"]
+
+        # get the log probabilities of the preferred and dispreferred responses
+        model_prefered_log_prob = get_log_prob(
+            self.preference_model.model.forward(
+                prompt_prefered_ids, attention_mask=prompt_prefered_mask
+            ).logits,
+            labels=prompt_prefered_ids,
+        )
+
+        model_disprefered_log_prob = get_log_prob(
+            self.preference_model.model.forward(
+                prompt_disprefered_ids, attention_mask=prompt_disprefered_mask
+            ).logits,
+            labels=prompt_disprefered_ids,
+        )
+
+        loss, chosen_rewards, rejected_rewards = self.simpo_loss(
+            model_prefered_log_prob,
+            model_disprefered_log_prob,
+            gamma_beta_ratio=self.gamma_beta_ratio,
+            prefered_y_len=prefered_y_len,
+            disprefered_y_len=disprefered_y_len,
+        )
+
+        return loss.mean(), chosen_rewards, rejected_rewards
+
+    def training_step(self, batch, _):
+
+        loss, chosen_rewards, rejected_rewards = self.forward(batch)
+        reward_accuracies = (chosen_rewards > rejected_rewards).float()
+
+        metrics = {
+            "loss": loss.mean(),
+            "reward_accuracies": reward_accuracies.mean().cpu(),
+            "chosen_rewards": chosen_rewards.mean(),
+            "rejected_rewards": rejected_rewards.mean(),
+            "reward_margins": (chosen_rewards - rejected_rewards).mean().cpu(),
+        }
+
+        for key, value in metrics.items():
+            self.log(f"train/{key}", value, on_step=True, on_epoch=True, prog_bar=True)
+
+        return loss.mean()
+
+    def validation_step(self, batch, _):
+
+        loss, chosen_rewards, rejected_rewards = self.forward(batch)
+
+        reward_accuracies = (chosen_rewards > rejected_rewards).float()
+
+        metrics = {
+            "loss": loss.mean(),
+            "reward_accuracies": reward_accuracies.mean().cpu(),
+            "chosen_rewards": chosen_rewards.mean(),
+            "rejected_rewards": rejected_rewards.mean(),
+            "reward_margins": (chosen_rewards - rejected_rewards).mean().cpu(),
+        }
+        for key, value in metrics.items():
+            self.log(f"val/{key}", value, on_step=True, on_epoch=True, prog_bar=True)
+
+        return loss.mean()
+
+    def test_step(self, batch, _):
+        loss, chosen_rewards, rejected_rewards = self.forward(batch)
+
+        reward_accuracies = (chosen_rewards > rejected_rewards).float()
+        metrics = {
+            "loss": loss.mean(),
+            "reward_accuracies": reward_accuracies.mean().cpu(),
+            "chosen_rewards": chosen_rewards.mean(),
+            "rejected_rewards": rejected_rewards.mean(),
+            "reward_margins": (chosen_rewards - rejected_rewards).mean().cpu(),
+        }
+        for key, value in metrics.items():
+            self.log(f"test/{key}", value, on_step=True, on_epoch=True, prog_bar=True)
+
+        return loss.mean()
+
+
+from functools import wraps
+
+
+def gpu_memory_usage_decorator(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        # clear the GPU cache so the measurement is accurate
+        torch.cuda.empty_cache()
+
+        # GPU memory usage before running the wrapped function
+        memory_allocated_before = torch.cuda.memory_allocated(device)
+        memory_reserved_before = torch.cuda.memory_reserved(device)
+        print(
+            f"[Before] Memory allocated: {memory_allocated_before / (1024 ** 2):.2f} MB"
+        )
+        print(
+            f"[Before] Memory reserved: {memory_reserved_before / (1024 ** 2):.2f} MB"
+        )
+
+        # run the wrapped function
+        result = func(*args, **kwargs)
+
+        # GPU memory usage after running the wrapped function
+        memory_allocated_after = torch.cuda.memory_allocated(device)
+        memory_reserved_after = torch.cuda.memory_reserved(device)
+        print(
+            f"[After] Memory allocated: {memory_allocated_after / (1024 ** 2):.2f} MB"
+        )
+        print(f"[After] Memory reserved: {memory_reserved_after / (1024 ** 2):.2f} MB")
+
+        # difference in memory usage
+        memory_allocated_diff = memory_allocated_after - memory_allocated_before
+        memory_reserved_diff = memory_reserved_after - memory_reserved_before
+        print(
+            f"Memory allocated difference: {memory_allocated_diff / (1024 ** 2):.2f} MB"
+        )
+        print(
+            f"Memory reserved difference: {memory_reserved_diff / (1024 ** 2):.2f} MB"
+        )
+
+        return result
+
+    return wrapper
+
+
+class ExpertModelDPO(LightningEfficientCheckpoint):
+
+    def __init__(self, preference_model, ref_expert_model, **kwargs):
+        super().__init__(**kwargs)
+        self.preference_model = preference_model
+        self.ref_expert_model = ref_expert_model
+        self.trainable_param_names = kwargs.get("trainable_param_names", None)
+        self.beta = kwargs.get("beta", 2.0)
+        # log hyperparameters
+        self.save_hyperparameters(kwargs)
+
+    def training_step(self, batch, _):
+
+        prompt_prefered_ids = batch["prompt_prefered_ids"]
+        prompt_disprefered_ids = batch["prompt_disprefered_ids"]
+
+        prompt_prefered_mask = batch["prompt_prefered_mask"]
+        prompt_disprefered_mask = batch["prompt_disprefered_mask"]
+
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        # Measure GPU memory before forward pass
+        memory_allocated_before = torch.cuda.memory_allocated(device)
+        memory_reserved_before = torch.cuda.memory_reserved(device)
+
+        # logits = self.preference_model.model.forward(prompt_prefered_ids).logits
+        # loss = torch.mean(logits)
+
+        # original model
+        model_prefered_log_prob = get_log_prob(
+            self.preference_model.model.forward(
+                prompt_prefered_ids, attention_mask=prompt_prefered_mask
+            ).logits,
+            labels=prompt_prefered_ids,
+        )
+
+        model_disprefered_log_prob = get_log_prob(
+            self.preference_model.model.forward(
+                prompt_disprefered_ids, attention_mask=prompt_disprefered_mask
+            ).logits,
+            labels=prompt_disprefered_ids,
+        )
+
+        # reference model
+        ref_prefered_log_prob = get_log_prob(
+            self.ref_expert_model.model.forward(
+                prompt_prefered_ids, attention_mask=prompt_prefered_mask
+            ).logits,
+            labels=prompt_prefered_ids,
+        )
+
+        ref_disprefered_log_prob = get_log_prob(
+            self.ref_expert_model.model.forward(
+                prompt_disprefered_ids, attention_mask=prompt_disprefered_mask
+            ).logits,
+            labels=prompt_disprefered_ids,
+        )
+
+        # Measure GPU memory after forward pass
+        memory_allocated_after = torch.cuda.memory_allocated(device)
+        memory_reserved_after = torch.cuda.memory_reserved(device)
+
+        # Calculate the difference in memory usage
+        memory_allocated_diff = memory_allocated_after - memory_allocated_before
+        memory_reserved_diff = memory_reserved_after - memory_reserved_before
+
+        print(
+            f"Memory allocated before forward pass: {memory_allocated_before / (1024 ** 2):.2f} MB"
+        )
+        print(
+            f"Memory allocated after forward pass: {memory_allocated_after / (1024 ** 2):.2f} MB"
+        )
+        print(
+            f"Memory allocated difference: {memory_allocated_diff / (1024 ** 2):.2f} MB"
+        )
+
+        print(
+            f"Memory reserved before forward pass: {memory_reserved_before / (1024 ** 2):.2f} MB"
+        )
+        print(
+            f"Memory reserved after 
forward pass: {memory_reserved_after / (1024 ** 2):.2f} MB" + ) + print( + f"Memory reserved difference: {memory_reserved_diff / (1024 ** 2):.2f} MB" + ) + + loss = -F.logsigmoid( + self.beta * (model_prefered_log_prob - model_disprefered_log_prob) + - (ref_prefered_log_prob - ref_disprefered_log_prob) + ).mean() + + loss, reward_accuracies, reward_margins = calculate_DPO_loss( + model_prefered_log_prob, + model_disprefered_log_prob, + ref_prefered_log_prob, + ref_disprefered_log_prob, + beta=self.beta, + ) + self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True) + + # self.log( + # "train/reward_accuracies", + # reward_accuracies, + # on_step=True, + # on_epoch=True, + # prog_bar=True, + # ) + # self.log( + # "train/reward_margins", + # reward_margins, + # on_step=True, + # on_epoch=True, + # prog_bar=True, + # ) + # clear the gpu memory + + return loss + + def validation_step(self, batch, _): + prompt_prefered_ids = batch["prompt_prefered_ids"] + prompt_disprefered_ids = batch["prompt_disprefered_ids"] + + prompt_prefered_mask = batch["prompt_prefered_mask"] + prompt_disprefered_mask = batch["prompt_disprefered_mask"] + + # original model + model_prefered_log_prob = get_log_prob( + self.preference_model.model.forward( + prompt_prefered_ids, attention_mask=prompt_prefered_mask + ).logits, + labels=prompt_prefered_ids, + ) + + model_disprefered_log_prob = get_log_prob( + self.preference_model.model.forward( + prompt_disprefered_ids, attention_mask=prompt_disprefered_mask + ).logits, + labels=prompt_disprefered_ids, + ) + + # reference model + ref_prefered_log_prob = get_log_prob( + self.ref_expert_model.model.forward( + prompt_prefered_ids, attention_mask=prompt_prefered_mask + ).logits, + labels=prompt_prefered_ids, + ) + + ref_disprefered_log_prob = get_log_prob( + self.ref_expert_model.model.forward( + prompt_disprefered_ids, attention_mask=prompt_disprefered_mask + ).logits, + labels=prompt_disprefered_ids, + ) + + loss, reward_accuracies, reward_margins = calculate_DPO_loss( + model_prefered_log_prob, + model_disprefered_log_prob, + ref_prefered_log_prob, + ref_disprefered_log_prob, + beta=self.beta, + ) + + self.log("val/loss", loss, on_step=True, on_epoch=True, prog_bar=True) + self.log( + "val/reward_accuracies", + reward_accuracies, + on_step=True, + on_epoch=True, + prog_bar=True, + ) + self.log( + "val/reward_margins", + reward_margins, + on_step=True, + on_epoch=True, + prog_bar=True, + ) + + return loss + + def test_step(self, batch, _): + prompt_prefered_ids = batch["prompt_prefered_ids"] + prompt_disprefered_ids = batch["prompt_disprefered_ids"] + + prompt_prefered_mask = batch["prompt_prefered_mask"] + prompt_disprefered_mask = batch["prompt_disprefered_mask"] + + # original model + model_prefered_log_prob = get_log_prob( + self.preference_model.model.forward( + prompt_prefered_ids, attention_mask=prompt_prefered_mask + ).logits, + labels=prompt_prefered_ids, + ) + + model_disprefered_log_prob = get_log_prob( + self.preference_model.model.forward( + prompt_disprefered_ids, attention_mask=prompt_disprefered_mask + ).logits, + labels=prompt_disprefered_ids, + ) + + # reference model + ref_prefered_log_prob = get_log_prob( + self.ref_expert_model.model.forward( + prompt_prefered_ids, attention_mask=prompt_prefered_mask + ).logits, + labels=prompt_prefered_ids, + ) + + ref_disprefered_log_prob = get_log_prob( + self.ref_expert_model.model.forward( + prompt_disprefered_ids, attention_mask=prompt_disprefered_mask + ).logits, + labels=prompt_disprefered_ids, + ) 
+ + loss, reward_accuracies, reward_margins = calculate_DPO_loss( + model_prefered_log_prob, + model_disprefered_log_prob, + ref_prefered_log_prob, + ref_disprefered_log_prob, + beta=self.beta, + ) + self.log("test/loss", loss, on_step=True, on_epoch=True, prog_bar=True) + self.log( + "test/reward_accuracies", + reward_accuracies, + on_step=True, + on_epoch=True, + prog_bar=True, + ) + self.log( + "test/reward_margins", + reward_margins, + on_step=True, + on_epoch=True, + prog_bar=True, + ) + return loss + + +class MoEModel(MultiExpertModel): + def __init__(self, expert_library: ExpertLibrary = None, **kwargs): + kwargs["top_k"] = kwargs["moe_top_k"] + kwargs["emb_dim"] = kwargs["moe_emb_dim"] + kwargs["rkhs_dim"] = kwargs["moe_rkhs_dim"] + init_from_scratch = kwargs.get("init_from_scratch", False) + + super().__init__(**kwargs) + + if not self.hparams.library_id and expert_library is None or init_from_scratch: + for i in range(self.hparams.moe_num_experts): + # Adding a Skilled LoRA with 1 skill. + exp_config = SkilledLoRAConfig( + n_skills=1, + modify_layers=self.hparams.modify_layers, + modify_modules=self.hparams.modify_modules, + lora_alpha=self.hparams.lora_alpha, + lora_dropout=self.hparams.lora_dropout, + lora_rank=self.hparams.lora_rank, + lora_init_b_random=True, + n_splits=self.hparams.n_splits, + phi_2_align_heads=self.hparams.phi_2_align_heads, + ) + self.add_empty_expert(f"e{i}", exp_config) + self.moe_num_experts = kwargs["moe_num_experts"] + @classmethod def from_pretrained_library( cls, diff --git a/mttl/models/library/expert_library.py b/mttl/models/library/expert_library.py index 8065dccdc..96c88f544 100644 --- a/mttl/models/library/expert_library.py +++ b/mttl/models/library/expert_library.py @@ -15,7 +15,8 @@ CommitOperationCopy, CommitOperationDelete, ) -from huggingface_hub.errors import RepositoryNotFoundError + +# from huggingface_hub.errors import RepositoryNotFoundError from mttl.logging import logger from mttl.models.library.backend_engine import ( diff --git a/projects/modular_llm/eval_library.py b/projects/modular_llm/eval_library.py index 8a387cb89..af8b1eea9 100644 --- a/projects/modular_llm/eval_library.py +++ b/projects/modular_llm/eval_library.py @@ -117,7 +117,6 @@ def fetch_prototypes(args: EvaluationConfig, library: ExpertLibrary) -> str: ab_only=args.ab_only, tie_params=args.tie_params, tie_op=args.tie_op, - add_base_proto=args.add_base_proto, recompute_prototypes=args.recompute_prototypes, ) elif args.merge_or_route == "hidden": diff --git a/projects/modular_llm/get_clusters.py b/projects/modular_llm/get_clusters.py new file mode 100644 index 000000000..efd8cf32f --- /dev/null +++ b/projects/modular_llm/get_clusters.py @@ -0,0 +1,187 @@ +from datasets import load_dataset +from torch.utils.data import Dataset, DataLoader, Subset +from mttl.models.ranker.adapter_ranker import AdapterRankerHelper +from tqdm import tqdm +from nomic import atlas +import numpy as np +from sklearn.cluster import KMeans +import argparse +import logging +from mttl.datamodule.mt_seq_to_seq_module import FlanModule, FlanConfig +from sentence_transformers import SentenceTransformer +from datasets import concatenate_datasets +import os +import huggingface_hub + +huggingface_token = os.environ.get("HF_TOKEN") +huggingface_hub.login(token=huggingface_token) + + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +parser = argparse.ArgumentParser() +parser.add_argument("--batch_size", type=int, default=64) + +parser.add_argument("--subsample", 
type=float, default=0.2) + +parser.add_argument("--num_clusters", type=int, default=256) + +parser.add_argument("--dataset", type=str, default="orca") + +parser.add_argument("--output_file", type=str, default="orca_cluster.json") + +parser.add_argument("--encoding", type=str, default="classifier") + +parser.add_argument( + "--new_dataname", type=str, default="zhan1993/flan-10k-flat-cluster-embedding" +) + +parser.add_argument( + "--model", type=str, default="sentence-transformers/sentence-t5-xxl" +) +args = parser.parse_args() +np.random.seed(42) + + +def get_text_encode(text, model): + + if args.encoding == "classifier": + return model.get_text_encode(text).cpu().detach().numpy() + elif args.encoding == "embedding": + return model.encode(text) + + +if args.encoding == "classifier": + model = AdapterRankerHelper.get_ranker_instance( + ranker_model="classifier", + ranker_path="zhan1993/classifier_ranker_t5_v3", + ) +elif args.encoding == "embedding": + model = SentenceTransformer(args.model) + +def get_dataset(args): + dataset = DatasetLibrary.pull_dataset(args.dataset, split="train") + + # create the subsample of the dataset if cutoff is set. + if args.cutoff > 0: + dataset = dataset.shuffle(seed=args.seed) + dataset = dataset.select(range(args.cutoff)) + + dataset_size = len(dataset) + indices = list(range(dataset_size)) + np.random.shuffle(indices) + split = int(np.floor(args.subsample * dataset_size)) + subset_indices = indices[:split] + subset_dataset = dataset.select(subset_indices) + + train_dataloader = DataLoader( + subset_dataset, batch_size=args.batch_size, num_workers=args.num_workers + ) + all_dataloader = DataLoader( + dataset, batch_size=args.batch_size, num_workers=args.num_workers + ) + + return train_dataloader, all_dataloader, dataset, subset_dataset + +# def get_orca_dataset(): + +# dataset = load_dataset("Open-Orca/OpenOrca") + +# # create the subsample of the dataset. +# dataset_size = len(dataset["train"]) +# indices = list(range(dataset_size)) +# # random indices +# np.random.shuffle(indices) +# split = int(np.floor(args.subsample * dataset_size)) +# subset_indices = indices[:split] +# subset_dataset = Subset(dataset["train"], subset_indices) + +# train_dataloader = DataLoader( +# subset_dataset, batch_size=args.batch_size, num_workers=8 +# ) +# all_dataloader = DataLoader( +# dataset["train"], batch_size=args.batch_size, num_workers=8 +# ) + +# return train_dataloader, all_dataloader, dataset["train"] + + +# def get_flan_dataset(): + +# flan = FlanModule( +# FlanConfig( +# model="EleutherAI/gpt-neo-125m", +# model_family="gpt", +# train_batch_size=4, +# predict_batch_size=4, +# dataset="sordonia/flan-10k-flat", +# remove_phi_eval_tasks=True, +# ) +# ) + +# dataset = flan.train_dataset +# # create the subsample of the dataset. 
+# dataset_size = len(dataset) +# indices = list(range(dataset_size)) +# # random indices +# np.random.shuffle(indices) +# split = int(np.floor(args.subsample * dataset_size)) +# subset_indices = indices[:split] +# subset_dataset = Subset(dataset, subset_indices) + +# train_dataloader = DataLoader( +# subset_dataset, batch_size=args.batch_size, num_workers=8 +# ) +# all_dataloader = flan.train_dataloader() + +# all_dataset = concatenate_datasets( +# [flan.train_dataset, flan.dev_dataset, flan.test_dataset] +# ) + +# return train_dataloader, all_dataloader, all_dataset + + +if __name__ == "__main__": + + train_dataloader, all_dataloader, all_dataset, subset_dataset = get_dataset(args) + breakpoint() + + embedding_list = [] + + for i, batch in tqdm( + enumerate(train_dataloader), total=len(train_dataloader), desc="dataset" + ): + if "source" in batch: + embedding = get_text_encode(batch["source"], model) + else: + embedding = get_text_encode(batch["question"], model) + embedding_list.append(embedding) + + all_embedding = np.concatenate(embedding_list, axis=0).reshape(-1, 768) + logger.info(f"all_embedding shape: {all_embedding.shape}") + kmeans = KMeans( + n_clusters=args.num_clusters, + init="k-means++", + n_init=10, + random_state=42, + ).fit(all_embedding) + + # map the new item with kmeans cluster + + def add_cluster_id(example): + if "source" in example: + embedding = get_text_encode(example["source"], model) + else: + embedding = get_text_encode(example["question"], model) + embedding = embedding.reshape(1, -1) + example["cluster_id"] = str(kmeans.predict(embedding)[0]) + return example + + # all_dataset = all_dataset.select(list(range(100))) + + dataset = all_dataset.map(add_cluster_id) + + # Push the merged dataset back to Hugging Face Hub + dataset.push_to_hub(args.new_dataname) diff --git a/projects/modular_llm/train_dpo.py b/projects/modular_llm/train_dpo.py new file mode 100644 index 000000000..9fd4928e1 --- /dev/null +++ b/projects/modular_llm/train_dpo.py @@ -0,0 +1,209 @@ +import copy +import os +import shutil +import sys +from tempfile import TemporaryDirectory + +import torch +from pytorch_lightning import Trainer, seed_everything + + +from mttl.models.lightning.callbacks import LiveCheckpointCallback +from mttl.datamodule.base import DatasetConfig +from mttl.datamodule.preference_data_module import Preferencemodule +from mttl.datamodule.ultrafeedback_data_module import UltrafeedbackDPOmodule + +# from mttl.datamodule.base import get_datamodule +from mttl.arguments import ExpertConfig, MultiExpertConfig +from mttl.models.expert_model import ExpertModelConfig +from mttl.models.expert_model import ( + ExpertModel, + ExpertModelDPO, + MoEModel, + ExpertModelSimPO, +) +from mttl.models.library.expert import Expert, load_expert +from mttl.models.library.expert_library import ExpertLibrary, LocalExpertLibrary +from mttl.models.monitors import get_monitors +from mttl.models.lightning.loggers import get_pl_loggers +from mttl.logging import logger, setup_logging +from mttl.utils import generate_random_string, rank_zero_only_and_wait, remote_login + + +def run_multitask(args: ExpertConfig): + seed_everything(args.seed, workers=True) + + # get directory of the current file + setup_logging(args.output_dir) + logger.info("Args: {}".format(args.to_json())) + + remote_login(args.remote_token) + expert_library = None + if args.library_id: + + @rank_zero_only_and_wait(before=False, after=True) + def create_library(args): + expert_library = ExpertLibrary.get_expert_library( + 
repo_id=args.library_id, + create=True, + destination_id=args.destination_library_id, + ) + return expert_library + + expert_library = create_library(args) + + loggers = get_pl_loggers(args) + # select dataloader + if args.model_modifier == "poly": + model_class = MoEModel + else: + model_class = ExpertModel + config = DatasetConfig( + model=args.model, + train_batch_size=args.train_batch_size, + predict_batch_size=args.predict_batch_size, + max_input_length=args.max_input_length, + max_output_length=args.max_output_length, + ) + + if "ultrafeedback" in args.dataset: + dm = UltrafeedbackDPOmodule(config) + else: + dm = Preferencemodule(config) + + model_config = ExpertModelConfig( + base_model=args.model, + modifier_config=args.modifier_config, + ) + + model = model_class( + config=model_config, + expert_library=expert_library, + **vars(args), + ) + if args.rl_training == "dpo": + # args.trainable_param_names = "^(?=.*preference_model)(?=.*prototypes).*" + ref_model = model_class( + config=model_config, + expert_library=expert_library, + **vars(args), + ) + # eval mode + ref_model.eval() + module = ExpertModelDPO( + **vars(args), preference_model=model, ref_expert_model=ref_model + ) + elif args.rl_training == "simpo": + # args.trainable_param_names = "^(?=.*preference_model)(?=.*prototypes).*" + module = ExpertModelSimPO(**vars(args), preference_model=model) + else: + module = model + # get metric monitors for models + callbacks = get_monitors(args) + if "mbpp" in args.dataset: + monitor = "downstream/mbpp" + mode = "max" + else: + monitor = "val/loss" + mode = "min" + + checkpoint_callback = LiveCheckpointCallback( + dirpath=args.output_dir, + monitor=monitor, + save_last=True, + mode=mode, + save_each_epoch=args.save_each_epoch, + ) + callbacks.append(checkpoint_callback) + + val_check_interval = args.eval_every + if val_check_interval == -1 or val_check_interval is None: + val_check_interval = None + else: + val_check_interval = args.gradient_accumulation_steps * args.eval_every + if val_check_interval > len(dm.train_dataloader()): + val_check_interval = len(dm.train_dataloader()) + elif val_check_interval > args.total_steps and args.total_steps != -1: + val_check_interval = args.total_steps + + trainer = Trainer( + devices=-1, + accelerator="gpu", + logger=loggers, + num_sanity_val_steps=0, + default_root_dir=args.output_dir, + max_epochs=args.num_train_epochs, + max_steps=args.total_steps + 1 if args.total_steps != -1 else -1, + gradient_clip_val=args.max_grad_norm, + strategy=args.compute_strategy if args.compute_strategy else "auto", + callbacks=callbacks, + enable_checkpointing=False, + log_every_n_steps=args.gradient_accumulation_steps, + accumulate_grad_batches=args.gradient_accumulation_steps, + precision=( + int(args.precision) if args.precision in ["16", "32"] else args.precision + ), + val_check_interval=val_check_interval, + ) + + # initial validation only for a bunch of datasets... ? + if args.eval_before_training: + # validating before training fails with deepspeed + trainer.validate(module, dm) + + if args.do_train: + trainer.fit(module, dm) + + torch.cuda.empty_cache() + + # reload best model before pushing! 
+ checkpoint = ( + checkpoint_callback.best_model_path or checkpoint_callback.last_model_path + ) + if args.compute_strategy == "deepspeed": + from deepspeed.utils.zero_to_fp32 import ( + convert_zero_checkpoint_to_fp32_state_dict, + ) + + new_path = checkpoint.replace(".ckpt", "_fp32.ckpt") + + @rank_zero_only_and_wait(before=True, after=True) + def convert_ckpt(path, new_path): + convert_zero_checkpoint_to_fp32_state_dict(path, new_path) + + convert_ckpt(checkpoint, new_path) + checkpoint = torch.load(new_path) + else: + checkpoint = torch.load(checkpoint)["state_dict"] + + module.load_state_dict(checkpoint) + trainer.test(module, dm) + + @rank_zero_only_and_wait(before=False, after=True) + def upload_library(expert_library, module): + if expert_library is not None: + # refresh expert library: so we dont overwrite the readme if the remote has changed. + expert_library.refresh_from_remote() + + if isinstance(module, MoEModel): + with expert_library.batched_commit(): + for expert_name in module.experts_names: + expert = module.get_expert_instance(expert_name) + expert_library.add_expert(expert, expert_name) + elif isinstance(module, ExpertModel): + expert = module.as_expert() + expert_name = ( + args.expert_name + or args.finetune_task_name + or generate_random_string() + ) + expert_library.add_expert(expert, expert_name) + else: + raise ValueError("Model class not recognized") + + # upload_library(expert_library, module) + + +if __name__ == "__main__": + args = MultiExpertConfig.parse() ## in case we only train the routing + run_multitask(args)
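
For reference, a minimal, self-contained sketch of the DPO objective that calculate_DPO_loss (mttl/models/expert_model.py above) implements; the tensors and the beta value below are illustrative dummies, not values taken from this diff:

import torch
import torch.nn.functional as F

beta = 2.0  # same default as ExpertModelDPO

# token-averaged log-probabilities of the preferred (y_w) and dispreferred (y_l)
# responses under the trained model and the frozen reference model (dummy values)
model_pref = torch.tensor([-1.0, -0.5])
model_dispref = torch.tensor([-2.0, -1.5])
ref_pref = torch.tensor([-1.2, -0.7])
ref_dispref = torch.tensor([-1.8, -1.4])

# relative log-probabilities w.r.t. the reference model
pref_rel = model_pref - ref_pref
dispref_rel = model_dispref - ref_dispref

# DPO loss and the reward accuracy that the training loop logs
loss = -F.logsigmoid(beta * (pref_rel - dispref_rel)).mean()
reward_accuracy = (pref_rel > dispref_rel).float().mean()
print(loss.item(), reward_accuracy.item())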