# Domain Enhancement


## Example BERT Model


In [None]:
from transformers import AutoModel
from nlppets.torch import (
    count_parameters,
    nested_freeze_tensor,
    count_trainable_parameters,
)
from nlppets.transformers.model.bert import domain_enhance_att, domain_enhance_ffn

# Load the pretrained checkpoint that receives the domain enhancements.
model = AutoModel.from_pretrained("hfl/chinese-roberta-wwm-ext")

# Optional step: FFN enhancement (domain_enhance_ffn), keyed by domain name.
model = domain_enhance_ffn(model, {"domain_name": 1024})
# Optional step: attention enhancement (domain_enhance_att), keyed by domain name.
model = domain_enhance_att(model, {"domain_name": 4})

# Domain names whose parameters are excluded from freezing.
ENHANCEMENTS = {"domain_name"}

# Collect the nested-name patterns handed to ``exclude`` below.
# NOTE(review): ``AutoModel`` may return the bare encoder, whose parameter
# names would not start with ``bert.`` -- confirm these patterns actually
# match (the trainable count printed at the end should be non-zero).
trainable_patterns = set()
for domain in ENHANCEMENTS:
    trainable_patterns.update(
        (
            # patterns for the FFN enhancement
            f"bert.encoder.layer.*.intermediate.{domain}.*",
            f"bert.encoder.layer.*.output.{domain}.*",
            # patterns for the attention enhancement
            f"bert.encoder.layer.*.attention.self.{domain}_query.*",
            f"bert.encoder.layer.*.attention.self.{domain}_key.*",
            f"bert.encoder.layer.*.attention.self.{domain}_value.*",
            f"bert.encoder.layer.*.attention.output.{domain}.*",
        )
    )

# Freeze the model's tensors, skipping everything matched by the patterns.
model = nested_freeze_tensor(model, exclude=trainable_patterns)

# Report total vs. trainable parameter counts and their ratio as a percentage.
params = count_parameters(model)
trainable = count_trainable_parameters(model)
ratio = f"{trainable/params:.2%}"
print("Parameters:", params, "Trainable:", trainable, "Trainable/Parameters:", ratio)


## Example ChatGLM Model


In [None]:
from typing import Type
from transformers import AutoConfig, PreTrainedModel
from transformers.dynamic_module_utils import get_class_from_dynamic_module
from nlppets.torch import (
    count_parameters,
    nested_freeze_tensor,
    count_trainable_parameters,
)
from nlppets.transformers.model.chatglm import domain_enhance_att, domain_enhance_ffn

# Hub identifier shared by the config lookup, class lookup and weight loading.
CHECKPOINT = "THUDM/chatglm-6b"

# Look up the model class shipped with the checkpoint: the config's
# ``auto_map`` entry for "AutoModel" is the reference passed to the loader.
config = AutoConfig.from_pretrained(CHECKPOINT, trust_remote_code=True)
class_ref = config.auto_map["AutoModel"]
ChatGLMForConditionalGeneration: Type[PreTrainedModel] = get_class_from_dynamic_module(
    class_ref, CHECKPOINT
)

# Both enhancements are applied to the class itself, before any weights load.
# Optional step: FFN enhancement (domain_enhance_ffn), keyed by domain name.
ChatGLMForConditionalGeneration = domain_enhance_ffn(
    ChatGLMForConditionalGeneration, {"domain_name": 1024}
)
# Optional step: attention enhancement (domain_enhance_att), keyed by domain name.
ChatGLMForConditionalGeneration = domain_enhance_att(
    ChatGLMForConditionalGeneration, {"domain_name": 4}
)

# Instantiate the enhanced class from the pretrained checkpoint.
model: PreTrainedModel = ChatGLMForConditionalGeneration.from_pretrained(
    CHECKPOINT, trust_remote_code=True
)  # type: ignore

# Domain names whose parameters are excluded from freezing.
ENHANCEMENTS = {"domain_name"}

# Collect the nested-name patterns handed to ``exclude`` below.
trainable_patterns = set()
for domain in ENHANCEMENTS:
    trainable_patterns.update(
        (
            # patterns for the FFN enhancement
            f"transformer.layers.*.mlp.{domain}_up.*",
            f"transformer.layers.*.mlp.{domain}_down.*",
            # patterns for the attention enhancement
            f"transformer.layers.*.attention.{domain}.*",
            f"transformer.layers.*.attention.{domain}_output.*",
        )
    )

# Freeze the model's tensors, skipping everything matched by the patterns.
model = nested_freeze_tensor(model, exclude=trainable_patterns)

# Report total vs. trainable parameter counts and their ratio as a percentage.
params = count_parameters(model)
trainable = count_trainable_parameters(model)
ratio = f"{trainable/params:.2%}"
print("Parameters:", params, "Trainable:", trainable, "Trainable/Parameters:", ratio)

