From 0f2c6f9a8bfc0a58c9e677e122adeba9983982f0 Mon Sep 17 00:00:00 2001
From: oleksost <ostapy2@gmail.com>
Date: Fri, 18 Oct 2024 10:59:25 -0400
Subject: [PATCH 1/6] peer container is a modifier

---
 mttl/models/containers/base.py           |  5 +++
 mttl/models/containers/peer_container.py | 56 +++++++++++++++++++-----
 mttl/models/expert_model.py              |  2 +-
 mttl/models/modifiers/mlp.py             | 29 +-----------
 4 files changed, 51 insertions(+), 41 deletions(-)
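
Note for reviewers: the headline change in this patch is that the PEER container is itself registered
as a modifier via @Modifier.register("peer", config_cls=PEERConfig) and now inherits from both
ExpertContainer and Modifier. A generic, self-contained sketch of that register-by-name pattern is
below; the toy names (_REGISTRY, register, ToyPEERConfig, ToyPEERContainer, build_modifier) are
illustrative only and do not mirror mttl's actual Modifier.register implementation.

    from dataclasses import dataclass

    _REGISTRY = {}  # maps a short name to (modifier class, config class)

    def register(name, config_cls):
        """Class decorator: associate an implementation with a name and a config type."""
        def wrap(cls):
            _REGISTRY[name] = (cls, config_cls)
            return cls
        return wrap

    @dataclass
    class ToyPEERConfig:
        moe__num_experts: int = 100
        emb_dim: int = 128

    @register("peer", config_cls=ToyPEERConfig)
    class ToyPEERContainer:
        def __init__(self, config):
            self.config = config

    def build_modifier(config):
        """Pick the registered implementation whose config type matches the given config."""
        for name, (cls, cfg_cls) in _REGISTRY.items():
            if isinstance(config, cfg_cls):
                return cls(config)
        raise KeyError(f"no modifier registered for {type(config).__name__}")

    container = build_modifier(ToyPEERConfig())  # -> ToyPEERContainer instance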

diff --git a/mttl/models/containers/base.py b/mttl/models/containers/base.py
index 1f1406f2f..6e7915742 100644
--- a/mttl/models/containers/base.py
+++ b/mttl/models/containers/base.py
@@ -35,7 +35,12 @@ def __init__(self, config, layer, selector=None):
         self.selector = selector or TaskNameSelector()
         self._default_expert_name = None
         self.expert_infos = {}
+        self.experts = nn.ModuleDict({})
 
+    @property
+    def num_experts(self):
+        return len(self.experts)
+    
     @property
     def default_expert_name(self):
         return self._default_expert_name
diff --git a/mttl/models/containers/peer_container.py b/mttl/models/containers/peer_container.py
index 99a9a3045..b945eca0e 100644
--- a/mttl/models/containers/peer_container.py
+++ b/mttl/models/containers/peer_container.py
@@ -1,3 +1,5 @@
+from dataclasses import dataclass
+
 import torch
 from torch import nn
 
@@ -7,14 +9,26 @@
     MultiheadBatchSequenceExpertsAndWeightsSelectorOutput,
 )
 from mttl.models.library.expert import Expert
-from mttl.models.modifiers.mlp import PEERConfig, PEERModifier
+from mttl.models.modifiers.base import Modifier, ModifierConfig
+
+# from mttl.models.modifiers.mlp import PEERConfig, PEERModifier
 
 # diff architectures name those layers differently
 DOWN_NAMES = ["fc1", "c_fc"]
 UP_NAMES = ["fc2", "c_proj"]
 
 
-class PEERMLPContainer(ExpertContainer):
+@dataclass
+class PEERConfig(ModifierConfig):
+    n_heads: int = 8
+    moe__num_experts: int = 100
+    emb_dim: int = 128
+    down_proj_layer: str = "fc1"
+    up_proj_layer: str = "fc2"
+
+
+@Modifier.register("peer", config_cls=PEERConfig)
+class PEERMLPContainer(ExpertContainer, Modifier):
     """
     PEER layer from Mixture of A Million Experts (https://arxiv.org/pdf/2407.04153)
 
@@ -33,7 +47,7 @@ def __init__(
         **kwargs,
     ):
         super().__init__(config, module)
-        self.num_experts = 0
+        self._num_experts = 0
         down_names = DOWN_NAMES + [
             config.down_proj_layer
         ]  # names of the up and down projection layers in the MLP block
@@ -55,21 +69,25 @@ def __init__(
         self.dtype = next(self.layer.parameters()).dtype
 
         self.layer = nn.Identity()
+        self.expert_name = None
         self.layer.in_features = self.input_dim
-        self.experts = PEERModifier(config)
+
+    @property
+    def num_experts(self):
+        return self._num_experts
 
     def initialize_experts(self, expert_config: PEERConfig) -> None:
-        self.num_experts = expert_config.moe_num_experts
+        self._num_experts = expert_config.moe__num_experts
         assert (
-            self.num_experts**0.5
+            self._num_experts**0.5
         ).is_integer(), "Number of experts must be a square number"
         self.peer_weight_down_embed = nn.Embedding(
-            num_embeddings=self.num_experts,
+            num_embeddings=self._num_experts,
             embedding_dim=self.input_dim,
             dtype=self.dtype,
         )
         self.peer_weight_up_embed = nn.Embedding(
-            num_embeddings=self.num_experts,
+            num_embeddings=self._num_experts,
             embedding_dim=self.output_dim,
             dtype=self.dtype,
         )
@@ -96,10 +114,24 @@ def add_expert(self, expert: Expert, **kwargs) -> None:
         return self.on_add_expert(expert, **kwargs)
 
     def on_add_expert(self, expert: Expert, **kwargs) -> None:
+        """
+        'initialize_experts' is called from here instead of __init__ to allow loading expert weights from the expert object that is passed here
+        """
         expert_config: PEERConfig = expert.expert_config
-        if self.num_experts == expert_config.moe_num_experts:
+        if self._num_experts == expert_config.moe__num_experts:
             raise ContainerFullException()
         self.initialize_experts(expert_config)
-
-    def __getitem__(self, key):
-        pass
+        self.expert_infos[expert.name] = expert.expert_info
+        if expert.expert_weights:
+            self.load_state_dict(expert.expert_weights)
+        self.expert_name = expert.name
+
+    def __getitem__(self, name):
+        if name != self.expert_name:
+            raise ValueError(
+                f"Expert with name {name} does not exist in this container."
+            )
+        return self
+
+    def __len__(self):
+        return self._num_experts
diff --git a/mttl/models/expert_model.py b/mttl/models/expert_model.py
index 288861522..5aebfd71d 100644
--- a/mttl/models/expert_model.py
+++ b/mttl/models/expert_model.py
@@ -163,7 +163,7 @@ def experts_containers(self) -> List[ExpertContainer]:
         containers = []
         for _, module in self.model.named_modules():
             for _, child in dict(module.named_children()).items():
-                if isinstance(child, ExpertContainer) and len(child.experts) > 0:
+                if isinstance(child, ExpertContainer) and child.num_experts > 0:
                     containers.append(child)
         return containers
 
diff --git a/mttl/models/modifiers/mlp.py b/mttl/models/modifiers/mlp.py
index 8d2b4cf35..66ac3d1eb 100644
--- a/mttl/models/modifiers/mlp.py
+++ b/mttl/models/modifiers/mlp.py
@@ -68,31 +68,4 @@ def parallel_linear_forward(
             hidden_states = input[indices]
             hidden_states = mlp._modifier_forward(hidden_states)
             output.index_add_(0, indices, hidden_states.to(input.dtype))
-        return mlps[0].layer(input) + output
-
-
-@dataclass
-class PEERConfig(ModifierConfig):
-    n_heads: int = 8
-    moe_num_experts: int = 100
-    emb_dim: int = 128
-    down_proj_layer: str = "fc1"
-    up_proj_layer: str = "fc2"
-
-
-@Modifier.register("peer", config_cls=PEERConfig)
-class PEERModifier(Modifier):
-    """
-    Peer modifier basically does nothing, the job is done in the container.
-    """
-
-    def __init__(
-        self,
-        config: PEERConfig,
-        **kwargs,
-    ):
-        super().__init__()
-        self.config = config
-
-    def __len__(self):
-        return self.config.moe_num_experts
+        return mlps[0].layer(input) + output
\ No newline at end of file
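
Note for reviewers: the peer_weight_down_embed / peer_weight_up_embed tables created above hold one
down-projection and one up-projection vector per expert, following the PEER paper referenced in the
class docstring. A minimal, self-contained sketch of how retrieved experts could be applied to token
hidden states is shown below; peer_forward and its argument names are illustrative, and the selector
that produces the top-k indices and routing weights is not part of this sketch.

    import torch
    import torch.nn.functional as F

    def peer_forward(x, down_embed, up_embed, indices, router_weights):
        """Apply retrieved PEER experts to a batch of token representations.

        x:              (tokens, input_dim)  hidden states
        down_embed:     nn.Embedding(num_experts, input_dim)
        up_embed:       nn.Embedding(num_experts, output_dim)
        indices:        (tokens, k)          top-k expert ids per token
        router_weights: (tokens, k)          normalized routing scores
        """
        w_down = down_embed(indices)                # (tokens, k, input_dim)
        w_up = up_embed(indices)                    # (tokens, k, output_dim)
        h = torch.einsum("td,tkd->tk", x, w_down)   # each expert is a rank-1 MLP: one scalar per expert
        h = F.gelu(h) * router_weights              # nonlinearity, then weight by routing scores
        return torch.einsum("tk,tko->to", h, w_up)  # expand back through the up-projection vectors

    # toy usage: 64 tokens, 256 experts (a square number, as asserted above), top-8 routing
    down = torch.nn.Embedding(256, 512)
    up = torch.nn.Embedding(256, 512)
    x = torch.randn(64, 512)
    idx = torch.randint(0, 256, (64, 8))
    w = torch.softmax(torch.randn(64, 8), dim=-1)
    out = peer_forward(x, down, up, idx, w)         # (64, 512)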

From 2ff6d899aac94dc98393fd944d706b47cf8d5c8b Mon Sep 17 00:00:00 2001
From: oleksost <ostapy2@gmail.com>
Date: Fri, 18 Oct 2024 11:55:20 -0400
Subject: [PATCH 2/6] make sure peer experts can be stored in the library

---
 mttl/models/lightning/expert_module.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)
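
Note for reviewers: the MoEModule base classes are reordered below so that LightningTrainingMixin
comes before LightningEfficientCheckpoint. Python resolves attributes left to right through the MRO,
so the mixin's experts_names / get_expert_instance take precedence over anything the Lightning base
defines with the same name. A small standalone illustration of that rule (the class names here are
toys, not the real mttl or Lightning classes):

    class TrainingMixin:
        @property
        def experts_names(self):
            return ["expert_a", "expert_b"]  # delegate-to-model behaviour we want exposed

    class EfficientCheckpoint:
        @property
        def experts_names(self):
            return []                        # a base-class attribute that would shadow the mixin

    class MixinLast(EfficientCheckpoint, TrainingMixin):
        pass

    class MixinFirst(TrainingMixin, EfficientCheckpoint):
        pass

    print(MixinLast().experts_names)   # []                        -- EfficientCheckpoint wins
    print(MixinFirst().experts_names)  # ['expert_a', 'expert_b']  -- mixin listed first wins
    print(MixinFirst.__mro__)          # (MixinFirst, TrainingMixin, EfficientCheckpoint, object)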

diff --git a/mttl/models/lightning/expert_module.py b/mttl/models/lightning/expert_module.py
index 5d2918588..dc274ded8 100644
--- a/mttl/models/lightning/expert_module.py
+++ b/mttl/models/lightning/expert_module.py
@@ -24,6 +24,14 @@
 
 
 class LightningTrainingMixin:
+    
+    @property
+    def experts_names(self):
+        return self.model.experts_names
+    
+    def get_expert_instance(self, name):
+        return self.model.get_expert_instance(name)
+    
     @property
     def _log_pref(self):
         return getattr(self.hparams, "logging_prefix", "")
@@ -215,7 +223,7 @@ def load_from_checkpoint(
         return model
 
 
-class MoEModule(LightningEfficientCheckpoint, LightningTrainingMixin):
+class MoEModule(LightningTrainingMixin, LightningEfficientCheckpoint):
     def __init__(
         self,
         model_object: PreTrainedModel = None,
@@ -269,7 +277,7 @@ def load_from_checkpoint(
         )
         model.load_state_dict(ckpt["state_dict"], strict=False)
         return model
-
+    
     def training_step(self, batch, _):
         output, context = self.forward(**batch, return_context=True)
         loss = output.loss

From 8abb4de4cf47cb001ae5a2cfb95db3fcaa5a0a8c Mon Sep 17 00:00:00 2001
From: oleksost <ostapy2@gmail.com>
Date: Fri, 18 Oct 2024 11:57:29 -0400
Subject: [PATCH 3/6] black

---
 mttl/models/modifiers/mlp.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mttl/models/modifiers/mlp.py b/mttl/models/modifiers/mlp.py
index 66ac3d1eb..2ce59fef3 100644
--- a/mttl/models/modifiers/mlp.py
+++ b/mttl/models/modifiers/mlp.py
@@ -68,4 +68,4 @@ def parallel_linear_forward(
             hidden_states = input[indices]
             hidden_states = mlp._modifier_forward(hidden_states)
             output.index_add_(0, indices, hidden_states.to(input.dtype))
-        return mlps[0].layer(input) + output
\ No newline at end of file
+        return mlps[0].layer(input) + output

From 9082f0a1cdcfb3575650f89a3a03d49c2547e685 Mon Sep 17 00:00:00 2001
From: oleksost <ostapy2@gmail.com>
Date: Fri, 18 Oct 2024 12:01:46 -0400
Subject: [PATCH 4/6] black

---
 mttl/models/containers/base.py         | 2 +-
 mttl/models/lightning/expert_module.py | 8 ++++----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/mttl/models/containers/base.py b/mttl/models/containers/base.py
index 6e7915742..94aaeaddb 100644
--- a/mttl/models/containers/base.py
+++ b/mttl/models/containers/base.py
@@ -40,7 +40,7 @@ def __init__(self, config, layer, selector=None):
     @property
     def num_experts(self):
         return len(self.experts)
-    
+
     @property
     def default_expert_name(self):
         return self._default_expert_name
diff --git a/mttl/models/lightning/expert_module.py b/mttl/models/lightning/expert_module.py
index dc274ded8..6d5da6b63 100644
--- a/mttl/models/lightning/expert_module.py
+++ b/mttl/models/lightning/expert_module.py
@@ -24,14 +24,14 @@
 
 
 class LightningTrainingMixin:
-    
+
     @property
     def experts_names(self):
         return self.model.experts_names
-    
+
     def get_expert_instance(self, name):
         return self.model.get_expert_instance(name)
-    
+
     @property
     def _log_pref(self):
         return getattr(self.hparams, "logging_prefix", "")
@@ -277,7 +277,7 @@ def load_from_checkpoint(
         )
         model.load_state_dict(ckpt["state_dict"], strict=False)
         return model
-    
+
     def training_step(self, batch, _):
         output, context = self.forward(**batch, return_context=True)
         loss = output.loss

From 8937c66da7ec667981c2579c3b7152bbccbac2bc Mon Sep 17 00:00:00 2001
From: oleksost <ostapy2@gmail.com>
Date: Fri, 18 Oct 2024 13:09:21 -0400
Subject: [PATCH 5/6] save and load peer expert

---
 mttl/models/containers/peer_container.py |  4 +---
 projects/modular_llm/eval_library.py     | 12 ++++++++++--
 projects/modular_llm/train_experts.py    |  4 +++-
 3 files changed, 14 insertions(+), 6 deletions(-)
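
Note for reviewers: the eval path below restores PEER weights with load_state_dict(..., strict=False).
As a reminder of how that flag behaves in PyTorch (the toy module here is illustrative, not the mttl
model): keys that do not line up are skipped instead of raising, and the returned named tuple can be
inspected to confirm that the expert tensors were actually consumed.

    import torch.nn as nn

    model = nn.Sequential(nn.Linear(8, 8))
    partial_state = {"0.weight": model[0].weight.detach().clone()}  # "0.bias" intentionally absent

    # strict=False tolerates missing/unexpected keys and reports them instead of raising
    result = model.load_state_dict(partial_state, strict=False)
    print(result.missing_keys)     # ['0.bias']
    print(result.unexpected_keys)  # []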

diff --git a/mttl/models/containers/peer_container.py b/mttl/models/containers/peer_container.py
index b945eca0e..e2fb29418 100644
--- a/mttl/models/containers/peer_container.py
+++ b/mttl/models/containers/peer_container.py
@@ -115,15 +115,13 @@ def add_expert(self, expert: Expert, **kwargs) -> None:
 
     def on_add_expert(self, expert: Expert, **kwargs) -> None:
         """
-        'initialize_experts' is called from here instead of __init__ to allow loading expert weights from the expert object that is passed here
+        'initialize_experts' is called from here
         """
         expert_config: PEERConfig = expert.expert_config
         if self._num_experts == expert_config.moe__num_experts:
             raise ContainerFullException()
         self.initialize_experts(expert_config)
         self.expert_infos[expert.name] = expert.expert_info
-        if expert.expert_weights:
-            self.load_state_dict(expert.expert_weights)
         self.expert_name = expert.name
 
     def __getitem__(self, name):
diff --git a/projects/modular_llm/eval_library.py b/projects/modular_llm/eval_library.py
index 8ad43c300..d5152ad76 100644
--- a/projects/modular_llm/eval_library.py
+++ b/projects/modular_llm/eval_library.py
@@ -3,9 +3,9 @@
 from copy import deepcopy
 
 import torch
-import wandb
 from pytorch_lightning import seed_everything
 
+import wandb
 from mttl.arguments import EvaluationConfig
 from mttl.datamodule.base import get_datamodule
 from mttl.evaluators.base import EvaluatorRunner, setup_evaluators
@@ -20,7 +20,11 @@
     WeightedLinearMergeConfig,
 )
 from mttl.models.lightning.callbacks import LossCallback
-from mttl.models.lightning.expert_module import ExpertModule, MultiExpertModule
+from mttl.models.lightning.expert_module import (
+    ExpertModule,
+    MoEModule,
+    MultiExpertModule,
+)
 from mttl.models.modifiers.lora import LoRAConfig
 from mttl.utils import remote_login
 
@@ -186,6 +190,10 @@ def run_eval(args: EvaluationConfig):
         module = MultiExpertModule(**vars(expert.training_config)).to("cuda")
         module.add_expert_instance(expert, is_default=True)
 
+    elif args.merge_or_route in ["peer"]:
+        module: MoEModule = MoEModule(**vars(an_expert.training_config)).to("cuda")
+        module.model.model.load_state_dict(an_expert.expert_weights, strict=False)
+
     elif args.merge_or_route == "uniform_lora_after_op":
         # Here we merge the LoRA experts after the outer product we cannot really do it
         # with the lib transform, cause this would require storing large matrices in memory
diff --git a/projects/modular_llm/train_experts.py b/projects/modular_llm/train_experts.py
index bb73ca7fe..025e6e493 100644
--- a/projects/modular_llm/train_experts.py
+++ b/projects/modular_llm/train_experts.py
@@ -7,6 +7,7 @@
 from mttl.arguments import Args, ExpertConfig
 from mttl.datamodule.base import get_datamodule
 from mttl.logging import logger, setup_logging
+from mttl.models.library.expert import Expert
 from mttl.models.library.expert_library import ExpertLibrary
 from mttl.models.lightning.callbacks import (
     DownstreamEvalCallback,
@@ -200,7 +201,8 @@ def upload_library(expert_library, module):
                 if isinstance(module, MoEModule):
                     with expert_library.batched_commit():
                         for expert_name in module.experts_names:
-                            expert = module.get_expert_instance(expert_name)
+                            expert: Expert = module.get_expert_instance(expert_name)
+                            expert.expert_info.training_config = args
                             expert_library.add_expert(expert, expert_name)
                 elif isinstance(module, ExpertModule):
                     expert = module.as_expert()

From 002cb280eac19849ba6514f63a2207e2567d6756 Mon Sep 17 00:00:00 2001
From: oleksost <ostapy2@gmail.com>
Date: Fri, 18 Oct 2024 13:11:08 -0400
Subject: [PATCH 6/6] comment

---
 mttl/models/containers/peer_container.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mttl/models/containers/peer_container.py b/mttl/models/containers/peer_container.py
index e2fb29418..6be1ae7df 100644
--- a/mttl/models/containers/peer_container.py
+++ b/mttl/models/containers/peer_container.py
@@ -115,7 +115,7 @@ def add_expert(self, expert: Expert, **kwargs) -> None:
 
     def on_add_expert(self, expert: Expert, **kwargs) -> None:
         """
-        'initialize_experts' is called from here
+        'initialize_experts' is called from here in order not to break the logic in the expert model
         """
         expert_config: PEERConfig = expert.expert_config
         if self._num_experts == expert_config.moe__num_experts: