From 0f2c6f9a8bfc0a58c9e677e122adeba9983982f0 Mon Sep 17 00:00:00 2001
From: oleksost <ostapy2@gmail.com>
Date: Fri, 18 Oct 2024 10:59:25 -0400
Subject: [PATCH 1/6] peer container is modifier
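
The PEER container is now registered as a modifier itself ("peer"): PEERConfig
moves next to the container and the no-op PEERModifier in mlp.py is dropped.
The container keeps two embedding tables with the per-expert down/up
projection vectors.

A rough, single-head sketch of the lookup these tables support (dense scoring
over all experts instead of the paper's product-key retrieval; query_proj, the
gating details and all shapes are assumptions, not part of this patch):

    import torch
    import torch.nn.functional as F

    def peer_forward_sketch(x, query_proj, down_embed, up_embed, k=16):
        # x: (batch, seq, input_dim); each expert is a rank-1 MLP given by one
        # row of down_embed (input_dim) and one row of up_embed (output_dim).
        queries = query_proj(x)                        # (b, s, input_dim)
        scores = queries @ down_embed.weight.T         # (b, s, num_experts)
        top_scores, top_idx = scores.topk(k, dim=-1)   # route each token to k experts
        w_down = down_embed(top_idx)                   # (b, s, k, input_dim)
        w_up = up_embed(top_idx)                       # (b, s, k, output_dim)
        h = torch.einsum("bsd,bskd->bsk", x, w_down)   # per-expert pre-activations
        h = F.gelu(h) * top_scores.softmax(dim=-1)     # gate by routing weights
        return torch.einsum("bsk,bsko->bso", h, w_up)  # combine expert outputs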
---
mttl/models/containers/base.py | 5 +++
mttl/models/containers/peer_container.py | 56 +++++++++++++++++++-----
mttl/models/expert_model.py | 2 +-
mttl/models/modifiers/mlp.py | 29 +-----------
4 files changed, 51 insertions(+), 41 deletions(-)
diff --git a/mttl/models/containers/base.py b/mttl/models/containers/base.py
index 1f1406f2f..6e7915742 100644
--- a/mttl/models/containers/base.py
+++ b/mttl/models/containers/base.py
@@ -35,7 +35,12 @@ def __init__(self, config, layer, selector=None):
self.selector = selector or TaskNameSelector()
self._default_expert_name = None
self.expert_infos = {}
+ self.experts = nn.ModuleDict({})
+ @property
+ def num_experts(self):
+ return len(self.experts)
+
@property
def default_expert_name(self):
return self._default_expert_name
diff --git a/mttl/models/containers/peer_container.py b/mttl/models/containers/peer_container.py
index 99a9a3045..b945eca0e 100644
--- a/mttl/models/containers/peer_container.py
+++ b/mttl/models/containers/peer_container.py
@@ -1,3 +1,5 @@
+from dataclasses import dataclass
+
import torch
from torch import nn
@@ -7,14 +9,26 @@
MultiheadBatchSequenceExpertsAndWeightsSelectorOutput,
)
from mttl.models.library.expert import Expert
-from mttl.models.modifiers.mlp import PEERConfig, PEERModifier
+from mttl.models.modifiers.base import Modifier, ModifierConfig
+
+# from mttl.models.modifiers.mlp import PEERConfig, PEERModifier
# diff architectures name those layers differently
DOWN_NAMES = ["fc1", "c_fc"]
UP_NAMES = ["fc2", "c_proj"]
-class PEERMLPContainer(ExpertContainer):
+@dataclass
+class PEERConfig(ModifierConfig):
+ n_heads: int = 8
+ moe__num_experts: int = 100
+ emb_dim: int = 128
+ down_proj_layer: str = "fc1"
+ up_proj_layer: str = "fc2"
+
+
+@Modifier.register("peer", config_cls=PEERConfig)
+class PEERMLPContainer(ExpertContainer, Modifier):
"""
PEER layer from Mixture of A Million Experts (https://arxiv.org/pdf/2407.04153)
@@ -33,7 +47,7 @@ def __init__(
**kwargs,
):
super().__init__(config, module)
- self.num_experts = 0
+ self._num_experts = 0
down_names = DOWN_NAMES + [
config.down_proj_layer
] # names of the up and down projection layers in the MLP block
@@ -55,21 +69,25 @@ def __init__(
self.dtype = next(self.layer.parameters()).dtype
self.layer = nn.Identity()
+ self.expert_name = None
self.layer.in_features = self.input_dim
- self.experts = PEERModifier(config)
+
+ @property
+ def num_experts(self):
+ return self._num_experts
def initialize_experts(self, expert_config: PEERConfig) -> None:
- self.num_experts = expert_config.moe_num_experts
+ self._num_experts = expert_config.moe__num_experts
assert (
- self.num_experts**0.5
+ self._num_experts**0.5
).is_integer(), "Number of experts must be a square number"
self.peer_weight_down_embed = nn.Embedding(
- num_embeddings=self.num_experts,
+ num_embeddings=self._num_experts,
embedding_dim=self.input_dim,
dtype=self.dtype,
)
self.peer_weight_up_embed = nn.Embedding(
- num_embeddings=self.num_experts,
+ num_embeddings=self._num_experts,
embedding_dim=self.output_dim,
dtype=self.dtype,
)
@@ -96,10 +114,24 @@ def add_expert(self, expert: Expert, **kwargs) -> None:
return self.on_add_expert(expert, **kwargs)
def on_add_expert(self, expert: Expert, **kwargs) -> None:
+ """
+ 'initialize_experts' is called from here instead of __init__ to allow loading expert weights from the expert object passed in
+ """
expert_config: PEERConfig = expert.expert_config
- if self.num_experts == expert_config.moe_num_experts:
+ if self._num_experts == expert_config.moe__num_experts:
raise ContainerFullException()
self.initialize_experts(expert_config)
-
- def __getitem__(self, key):
- pass
+ self.expert_infos[expert.name] = expert.expert_info
+ if expert.expert_weights:
+ self.load_state_dict(expert.expert_weights)
+ self.expert_name = expert.name
+
+ def __getitem__(self, name):
+ if name != self.expert_name:
+ raise ValueError(
+ f"Expert with name {name} does not exist in this container."
+ )
+ return self
+
+ def __len__(self):
+ return self._num_experts
diff --git a/mttl/models/expert_model.py b/mttl/models/expert_model.py
index 288861522..5aebfd71d 100644
--- a/mttl/models/expert_model.py
+++ b/mttl/models/expert_model.py
@@ -163,7 +163,7 @@ def experts_containers(self) -> List[ExpertContainer]:
containers = []
for _, module in self.model.named_modules():
for _, child in dict(module.named_children()).items():
- if isinstance(child, ExpertContainer) and len(child.experts) > 0:
+ if isinstance(child, ExpertContainer) and child.num_experts > 0:
containers.append(child)
return containers
diff --git a/mttl/models/modifiers/mlp.py b/mttl/models/modifiers/mlp.py
index 8d2b4cf35..66ac3d1eb 100644
--- a/mttl/models/modifiers/mlp.py
+++ b/mttl/models/modifiers/mlp.py
@@ -68,31 +68,4 @@ def parallel_linear_forward(
hidden_states = input[indices]
hidden_states = mlp._modifier_forward(hidden_states)
output.index_add_(0, indices, hidden_states.to(input.dtype))
- return mlps[0].layer(input) + output
-
-
-@dataclass
-class PEERConfig(ModifierConfig):
- n_heads: int = 8
- moe_num_experts: int = 100
- emb_dim: int = 128
- down_proj_layer: str = "fc1"
- up_proj_layer: str = "fc2"
-
-
-@Modifier.register("peer", config_cls=PEERConfig)
-class PEERModifier(Modifier):
- """
- Peer modifier basically does nothing, the job is done in the container.
- """
-
- def __init__(
- self,
- config: PEERConfig,
- **kwargs,
- ):
- super().__init__()
- self.config = config
-
- def __len__(self):
- return self.config.moe_num_experts
+ return mlps[0].layer(input) + output
\ No newline at end of file
From 2ff6d899aac94dc98393fd944d706b47cf8d5c8b Mon Sep 17 00:00:00 2001
From: oleksost <ostapy2@gmail.com>
Date: Fri, 18 Oct 2024 11:55:20 -0400
Subject: [PATCH 2/6] make sure peer can be stored to library
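
The training mixin now exposes experts_names and get_expert_instance by
delegating to the wrapped model, and MoEModule's bases are reordered so the
mixin comes first: Python resolves attributes left to right along the MRO, so
the mixin's delegating properties shadow whatever LightningEfficientCheckpoint
defines. A minimal illustration with throwaway classes (not from this
codebase):

    class Checkpoint:
        @property
        def experts_names(self):
            raise NotImplementedError

    class TrainingMixin:
        @property
        def experts_names(self):
            return ["peer"]

    class MoEModule(TrainingMixin, Checkpoint):  # mixin first, as in this patch
        pass

    assert MoEModule().experts_names == ["peer"]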
---
mttl/models/lightning/expert_module.py | 12 ++++++++++--
1 file changed, 10 insertions(+), 2 deletions(-)
diff --git a/mttl/models/lightning/expert_module.py b/mttl/models/lightning/expert_module.py
index 5d2918588..dc274ded8 100644
--- a/mttl/models/lightning/expert_module.py
+++ b/mttl/models/lightning/expert_module.py
@@ -24,6 +24,14 @@
class LightningTrainingMixin:
+
+ @property
+ def experts_names(self):
+ return self.model.experts_names
+
+ def get_expert_instance(self, name):
+ return self.model.get_expert_instance(name)
+
@property
def _log_pref(self):
return getattr(self.hparams, "logging_prefix", "")
@@ -215,7 +223,7 @@ def load_from_checkpoint(
return model
-class MoEModule(LightningEfficientCheckpoint, LightningTrainingMixin):
+class MoEModule(LightningTrainingMixin, LightningEfficientCheckpoint):
def __init__(
self,
model_object: PreTrainedModel = None,
@@ -269,7 +277,7 @@ def load_from_checkpoint(
)
model.load_state_dict(ckpt["state_dict"], strict=False)
return model
-
+
def training_step(self, batch, _):
output, context = self.forward(**batch, return_context=True)
loss = output.loss
From 8abb4de4cf47cb001ae5a2cfb95db3fcaa5a0a8c Mon Sep 17 00:00:00 2001
From: oleksost <ostapy2@gmail.com>
Date: Fri, 18 Oct 2024 11:57:29 -0400
Subject: [PATCH 3/6] black
---
mttl/models/modifiers/mlp.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mttl/models/modifiers/mlp.py b/mttl/models/modifiers/mlp.py
index 66ac3d1eb..2ce59fef3 100644
--- a/mttl/models/modifiers/mlp.py
+++ b/mttl/models/modifiers/mlp.py
@@ -68,4 +68,4 @@ def parallel_linear_forward(
hidden_states = input[indices]
hidden_states = mlp._modifier_forward(hidden_states)
output.index_add_(0, indices, hidden_states.to(input.dtype))
- return mlps[0].layer(input) + output
\ No newline at end of file
+ return mlps[0].layer(input) + output
From 9082f0a1cdcfb3575650f89a3a03d49c2547e685 Mon Sep 17 00:00:00 2001
From: oleksost <ostapy2@gmail.com>
Date: Fri, 18 Oct 2024 12:01:46 -0400
Subject: [PATCH 4/6] black
---
mttl/models/containers/base.py | 2 +-
mttl/models/lightning/expert_module.py | 8 ++++----
2 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/mttl/models/containers/base.py b/mttl/models/containers/base.py
index 6e7915742..94aaeaddb 100644
--- a/mttl/models/containers/base.py
+++ b/mttl/models/containers/base.py
@@ -40,7 +40,7 @@ def __init__(self, config, layer, selector=None):
@property
def num_experts(self):
return len(self.experts)
-
+
@property
def default_expert_name(self):
return self._default_expert_name
diff --git a/mttl/models/lightning/expert_module.py b/mttl/models/lightning/expert_module.py
index dc274ded8..6d5da6b63 100644
--- a/mttl/models/lightning/expert_module.py
+++ b/mttl/models/lightning/expert_module.py
@@ -24,14 +24,14 @@
class LightningTrainingMixin:
-
+
@property
def experts_names(self):
return self.model.experts_names
-
+
def get_expert_instance(self, name):
return self.model.get_expert_instance(name)
-
+
@property
def _log_pref(self):
return getattr(self.hparams, "logging_prefix", "")
@@ -277,7 +277,7 @@ def load_from_checkpoint(
)
model.load_state_dict(ckpt["state_dict"], strict=False)
return model
-
+
def training_step(self, batch, _):
output, context = self.forward(**batch, return_context=True)
loss = output.loss
From 8937c66da7ec667981c2579c3b7152bbccbac2bc Mon Sep 17 00:00:00 2001
From: oleksost <ostapy2@gmail.com>
Date: Fri, 18 Oct 2024 13:09:21 -0400
Subject: [PATCH 5/6] save and load peer expert
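
Weight loading is dropped from on_add_expert; evaluation instead rebuilds the
module from the expert's stored training config and loads the state dict
non-strictly, and uploaded experts get the run's args stamped on their
expert_info so the config round-trips through the library. Roughly (the
library lookup line is an assumption; the rest mirrors the eval_library.py
branch below):

    an_expert = expert_library[expert_name]  # hypothetical lookup
    module = MoEModule(**vars(an_expert.training_config)).to("cuda")
    module.model.model.load_state_dict(an_expert.expert_weights, strict=False)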
---
mttl/models/containers/peer_container.py | 4 +---
projects/modular_llm/eval_library.py | 12 ++++++++++--
projects/modular_llm/train_experts.py | 4 +++-
3 files changed, 14 insertions(+), 6 deletions(-)
diff --git a/mttl/models/containers/peer_container.py b/mttl/models/containers/peer_container.py
index b945eca0e..e2fb29418 100644
--- a/mttl/models/containers/peer_container.py
+++ b/mttl/models/containers/peer_container.py
@@ -115,15 +115,13 @@ def add_expert(self, expert: Expert, **kwargs) -> None:
def on_add_expert(self, expert: Expert, **kwargs) -> None:
"""
- 'initialize_experts' is called from here instead of __init__ to allow loading expert weights from the expert object passed in
+ 'initialize_experts' is called from here
"""
expert_config: PEERConfig = expert.expert_config
if self._num_experts == expert_config.moe__num_experts:
raise ContainerFullException()
self.initialize_experts(expert_config)
self.expert_infos[expert.name] = expert.expert_info
- if expert.expert_weights:
- self.load_state_dict(expert.expert_weights)
self.expert_name = expert.name
def __getitem__(self, name):
diff --git a/projects/modular_llm/eval_library.py b/projects/modular_llm/eval_library.py
index 8ad43c300..d5152ad76 100644
--- a/projects/modular_llm/eval_library.py
+++ b/projects/modular_llm/eval_library.py
@@ -3,9 +3,9 @@
from copy import deepcopy
import torch
-import wandb
from pytorch_lightning import seed_everything
+import wandb
from mttl.arguments import EvaluationConfig
from mttl.datamodule.base import get_datamodule
from mttl.evaluators.base import EvaluatorRunner, setup_evaluators
@@ -20,7 +20,11 @@
WeightedLinearMergeConfig,
)
from mttl.models.lightning.callbacks import LossCallback
-from mttl.models.lightning.expert_module import ExpertModule, MultiExpertModule
+from mttl.models.lightning.expert_module import (
+ ExpertModule,
+ MoEModule,
+ MultiExpertModule,
+)
from mttl.models.modifiers.lora import LoRAConfig
from mttl.utils import remote_login
@@ -186,6 +190,10 @@ def run_eval(args: EvaluationConfig):
module = MultiExpertModule(**vars(expert.training_config)).to("cuda")
module.add_expert_instance(expert, is_default=True)
+ elif args.merge_or_route in ["peer"]:
+ module: MoEModule = MoEModule(**vars(an_expert.training_config)).to("cuda")
+ module.model.model.load_state_dict(an_expert.expert_weights, strict=False)
+
elif args.merge_or_route == "uniform_lora_after_op":
# Here we merge the LoRA experts after the outer product we cannot really do it
# with the lib transform, cause this would require storing large matrices in memory
diff --git a/projects/modular_llm/train_experts.py b/projects/modular_llm/train_experts.py
index bb73ca7fe..025e6e493 100644
--- a/projects/modular_llm/train_experts.py
+++ b/projects/modular_llm/train_experts.py
@@ -7,6 +7,7 @@
from mttl.arguments import Args, ExpertConfig
from mttl.datamodule.base import get_datamodule
from mttl.logging import logger, setup_logging
+from mttl.models.library.expert import Expert
from mttl.models.library.expert_library import ExpertLibrary
from mttl.models.lightning.callbacks import (
DownstreamEvalCallback,
@@ -200,7 +201,8 @@ def upload_library(expert_library, module):
if isinstance(module, MoEModule):
with expert_library.batched_commit():
for expert_name in module.experts_names:
- expert = module.get_expert_instance(expert_name)
+ expert: Expert = module.get_expert_instance(expert_name)
+ expert.expert_info.training_config = args
expert_library.add_expert(expert, expert_name)
elif isinstance(module, ExpertModule):
expert = module.as_expert()
From 002cb280eac19849ba6514f63a2207e2567d6756 Mon Sep 17 00:00:00 2001
From: oleksost <ostapy2@gmail.com>
Date: Fri, 18 Oct 2024 13:11:08 -0400
Subject: [PATCH 6/6] comment
---
mttl/models/containers/peer_container.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mttl/models/containers/peer_container.py b/mttl/models/containers/peer_container.py
index e2fb29418..6be1ae7df 100644
--- a/mttl/models/containers/peer_container.py
+++ b/mttl/models/containers/peer_container.py
@@ -115,7 +115,7 @@ def add_expert(self, expert: Expert, **kwargs) -> None:
def on_add_expert(self, expert: Expert, **kwargs) -> None:
"""
- 'initialize_experts' is called from here
+ 'initialize_experts' is called from here in order not to break the logic in the expert model
"""
expert_config: PEERConfig = expert.expert_config
if self._num_experts == expert_config.moe__num_experts: