diff --git a/mttl/models/containers/base.py b/mttl/models/containers/base.py
index 1f1406f2f..94aaeaddb 100644
--- a/mttl/models/containers/base.py
+++ b/mttl/models/containers/base.py
@@ -35,6 +35,11 @@ def __init__(self, config, layer, selector=None):
         self.selector = selector or TaskNameSelector()
         self._default_expert_name = None
         self.expert_infos = {}
+        self.experts = nn.ModuleDict({})
+
+    @property
+    def num_experts(self):
+        return len(self.experts)
 
     @property
     def default_expert_name(self):
diff --git a/mttl/models/containers/peer_container.py b/mttl/models/containers/peer_container.py
index 99a9a3045..6be1ae7df 100644
--- a/mttl/models/containers/peer_container.py
+++ b/mttl/models/containers/peer_container.py
@@ -1,3 +1,5 @@
+from dataclasses import dataclass
+
 import torch
 from torch import nn
 
@@ -7,14 +9,26 @@
     MultiheadBatchSequenceExpertsAndWeightsSelectorOutput,
 )
 from mttl.models.library.expert import Expert
-from mttl.models.modifiers.mlp import PEERConfig, PEERModifier
+from mttl.models.modifiers.base import Modifier, ModifierConfig
+
+# from mttl.models.modifiers.mlp import PEERConfig, PEERModifier
 
 # diff architectures name those layers differently
 DOWN_NAMES = ["fc1", "c_fc"]
 UP_NAMES = ["fc2", "c_proj"]
 
 
-class PEERMLPContainer(ExpertContainer):
+@dataclass
+class PEERConfig(ModifierConfig):
+    n_heads: int = 8
+    moe__num_experts: int = 100
+    emb_dim: int = 128
+    down_proj_layer: str = "fc1"
+    up_proj_layer: str = "fc2"
+
+
+@Modifier.register("peer", config_cls=PEERConfig)
+class PEERMLPContainer(ExpertContainer, Modifier):
     """
     PEER layer from Mixture of A Million Experts (https://arxiv.org/pdf/2407.04153)
 
@@ -33,7 +47,7 @@ def __init__(
         self,
         **kwargs,
     ):
         super().__init__(config, module)
-        self.num_experts = 0
+        self._num_experts = 0
         down_names = DOWN_NAMES + [
             config.down_proj_layer
         ]  # names of the up and down projection layers in the MLP block
@@ -55,21 +69,25 @@
         self.dtype = next(self.layer.parameters()).dtype
 
         self.layer = nn.Identity()
+        self.expert_name = None
         self.layer.in_features = self.input_dim
-        self.experts = PEERModifier(config)
+
+    @property
+    def num_experts(self):
+        return self._num_experts
 
     def initialize_experts(self, expert_config: PEERConfig) -> None:
-        self.num_experts = expert_config.moe_num_experts
+        self._num_experts = expert_config.moe__num_experts
         assert (
-            self.num_experts**0.5
+            self._num_experts**0.5
         ).is_integer(), "Number of experts must be a square number"
         self.peer_weight_down_embed = nn.Embedding(
-            num_embeddings=self.num_experts,
+            num_embeddings=self._num_experts,
             embedding_dim=self.input_dim,
             dtype=self.dtype,
         )
         self.peer_weight_up_embed = nn.Embedding(
-            num_embeddings=self.num_experts,
+            num_embeddings=self._num_experts,
             embedding_dim=self.output_dim,
             dtype=self.dtype,
         )
@@ -96,10 +114,22 @@ def add_expert(self, expert: Expert, **kwargs) -> None:
         return self.on_add_expert(expert, **kwargs)
 
     def on_add_expert(self, expert: Expert, **kwargs) -> None:
+        """
+        'initialize_experts' is called from here in order not to break logic in expert model
+        """
         expert_config: PEERConfig = expert.expert_config
-        if self.num_experts == expert_config.moe_num_experts:
+        if self._num_experts == expert_config.moe__num_experts:
             raise ContainerFullException()
         self.initialize_experts(expert_config)
-
-    def __getitem__(self, key):
-        pass
+        self.expert_infos[expert.name] = expert.expert_info
+        self.expert_name = expert.name
+
+    def __getitem__(self, name):
+        if name != self.expert_name:
+            raise ValueError(
+                f"Expert with name {name} does not exist in this container."
+            )
+        return self
+
+    def __len__(self):
+        return self._num_experts
diff --git a/mttl/models/expert_model.py b/mttl/models/expert_model.py
index 288861522..5aebfd71d 100644
--- a/mttl/models/expert_model.py
+++ b/mttl/models/expert_model.py
@@ -163,7 +163,7 @@ def experts_containers(self) -> List[ExpertContainer]:
         containers = []
         for _, module in self.model.named_modules():
             for _, child in dict(module.named_children()).items():
-                if isinstance(child, ExpertContainer) and len(child.experts) > 0:
+                if isinstance(child, ExpertContainer) and child.num_experts > 0:
                     containers.append(child)
         return containers
 
diff --git a/mttl/models/lightning/expert_module.py b/mttl/models/lightning/expert_module.py
index 5d2918588..6d5da6b63 100644
--- a/mttl/models/lightning/expert_module.py
+++ b/mttl/models/lightning/expert_module.py
@@ -24,6 +24,14 @@
 
 
 class LightningTrainingMixin:
+
+    @property
+    def experts_names(self):
+        return self.model.experts_names
+
+    def get_expert_instance(self, name):
+        return self.model.get_expert_instance(name)
+
     @property
     def _log_pref(self):
         return getattr(self.hparams, "logging_prefix", "")
@@ -215,7 +223,7 @@ def load_from_checkpoint(
         return model
 
 
-class MoEModule(LightningEfficientCheckpoint, LightningTrainingMixin):
+class MoEModule(LightningTrainingMixin, LightningEfficientCheckpoint):
     def __init__(
         self,
         model_object: PreTrainedModel = None,
diff --git a/mttl/models/modifiers/mlp.py b/mttl/models/modifiers/mlp.py
index 8d2b4cf35..2ce59fef3 100644
--- a/mttl/models/modifiers/mlp.py
+++ b/mttl/models/modifiers/mlp.py
@@ -69,30 +69,3 @@ def parallel_linear_forward(
         hidden_states = mlp._modifier_forward(hidden_states)
         output.index_add_(0, indices, hidden_states.to(input.dtype))
     return mlps[0].layer(input) + output
-
-
-@dataclass
-class PEERConfig(ModifierConfig):
-    n_heads: int = 8
-    moe_num_experts: int = 100
-    emb_dim: int = 128
-    down_proj_layer: str = "fc1"
-    up_proj_layer: str = "fc2"
-
-
-@Modifier.register("peer", config_cls=PEERConfig)
-class PEERModifier(Modifier):
-    """
-    Peer modifier basically does nothing, the job is done in the container.
-    """
-
-    def __init__(
-        self,
-        config: PEERConfig,
-        **kwargs,
-    ):
-        super().__init__()
-        self.config = config
-
-    def __len__(self):
-        return self.config.moe_num_experts
diff --git a/projects/modular_llm/eval_library.py b/projects/modular_llm/eval_library.py
index 8ad43c300..d5152ad76 100644
--- a/projects/modular_llm/eval_library.py
+++ b/projects/modular_llm/eval_library.py
@@ -3,9 +3,9 @@
 from copy import deepcopy
 
 import torch
-import wandb
 from pytorch_lightning import seed_everything
 
+import wandb
 from mttl.arguments import EvaluationConfig
 from mttl.datamodule.base import get_datamodule
 from mttl.evaluators.base import EvaluatorRunner, setup_evaluators
@@ -20,7 +20,11 @@
     WeightedLinearMergeConfig,
 )
 from mttl.models.lightning.callbacks import LossCallback
-from mttl.models.lightning.expert_module import ExpertModule, MultiExpertModule
+from mttl.models.lightning.expert_module import (
+    ExpertModule,
+    MoEModule,
+    MultiExpertModule,
+)
 from mttl.models.modifiers.lora import LoRAConfig
 from mttl.utils import remote_login
 
@@ -186,6 +190,10 @@ def run_eval(args: EvaluationConfig):
         module = MultiExpertModule(**vars(expert.training_config)).to("cuda")
         module.add_expert_instance(expert, is_default=True)
 
+    elif args.merge_or_route in ["peer"]:
+        module: MoEModule = MoEModule(**vars(an_expert.training_config)).to("cuda")
+        module.model.model.load_state_dict(an_expert.expert_weights, strict=False)
+
     elif args.merge_or_route == "uniform_lora_after_op":
         # Here we merge the LoRA experts after the outer product we cannot really do it
         # with the lib transform, cause this would require storing large matrices in memory
diff --git a/projects/modular_llm/train_experts.py b/projects/modular_llm/train_experts.py
index bb73ca7fe..025e6e493 100644
--- a/projects/modular_llm/train_experts.py
+++ b/projects/modular_llm/train_experts.py
@@ -7,6 +7,7 @@
 from mttl.arguments import Args, ExpertConfig
 from mttl.datamodule.base import get_datamodule
 from mttl.logging import logger, setup_logging
+from mttl.models.library.expert import Expert
 from mttl.models.library.expert_library import ExpertLibrary
 from mttl.models.lightning.callbacks import (
     DownstreamEvalCallback,
@@ -200,7 +201,8 @@ def upload_library(expert_library, module):
     if isinstance(module, MoEModule):
         with expert_library.batched_commit():
             for expert_name in module.experts_names:
-                expert = module.get_expert_instance(expert_name)
+                expert: Expert = module.get_expert_instance(expert_name)
+                expert.expert_info.training_config = args
                 expert_library.add_expert(expert, expert_name)
     elif isinstance(module, ExpertModule):
         expert = module.as_expert()