LowRankAdapter not working with Bert models #23

Closed
MrigankRaman opened this issue Jun 13, 2022 · 1 comment
MrigankRaman commented Jun 13, 2022

I am trying to use `LowRankAdapterModel` with `bert-base-uncased` and `bert-large-uncased`, and I am getting the following error. Please look into it.

```
KeyError                                  Traceback (most recent call last)
in <module>()
      1 from opendelta import LowRankAdapterModel
----> 2 delta_model1 = LowRankAdapterModel(backbone_model=model)
      3 delta_model1.freeze_module(set_state_dict = True)
      4 delta_model1.log(delta_ratio=True, trainable_ratio=True, visualization=True)

5 frames
/usr/local/lib/python3.7/dist-packages/opendelta/delta_models/low_rank_adapter.py in __init__(self, backbone_model, reduction_factor, non_linearity, low_rank_w_init, low_rank_rank, modified_modules, exclude_modules, unfrozen_modules, common_structure, interactive_modify)
    167             unfrozen_modules=unfrozen_modules,
    168             common_structure=common_structure,
--> 169             interactive_modify=interactive_modify,
    170         )
    171         arg_names = get_arg_names_inside_func(self.__init__)

/usr/local/lib/python3.7/dist-packages/opendelta/basemodel.py in __init__(self, backbone_model, modified_modules, exclude_modules, unfrozen_modules, interactive_modify, common_structure)
    130         self.common_structure = common_structure
    131         if self.common_structure:
--> 132             self.structure_mapping = CommonStructureMap.load(self.backbone_model)
    133         else:
    134             self.structure_mapping = None

/usr/local/lib/python3.7/dist-packages/opendelta/utils/structure_mapping.py in load(cls, backbone_model, strict, warining, visualize)
    317         if backbone_class not in cls.Mappings:
    318             raise KeyError(backbone_class)
--> 319         mapping = cls.Mappings[backbone_class]
    320         if visualize:
    321             logger.info("Since you are using the common structure mapping, draw the transformed parameter structure for checking.")

/usr/local/lib/python3.7/dist-packages/opendelta/utils/structure_mapping.py in __getitem__(self, key)
    279             raise KeyError(key)
    280         value = self._mapping_string[key]
--> 281         self._mapping[key] = eval(value)
    282         return self._mapping[key]
    283

/usr/local/lib/python3.7/dist-packages/opendelta/utils/structure_mapping.py in ()

/usr/local/lib/python3.7/dist-packages/opendelta/utils/structure_mapping.py in mapping_for_SequenceClassification(mapping, type)
    252         }
    253     elif type == "bert":
--> 254         mapping.pop("lm_head")
    255         mapping["classifier"] = {"name": "classifier"}
    256     elif type == "deberta":

KeyError: 'lm_head'
```

This is how the model is defined:

```python
config = AutoConfig.from_pretrained(
    "bert-base-uncased",
    cache_dir=model_args.cache_dir,
    revision=model_args.model_revision,
    use_auth_token=True if model_args.use_auth_token else None,
)
config.dropout_rate = 0.0
tokenizer = AutoTokenizer.from_pretrained(
    "bert-base-uncased",
    cache_dir=model_args.cache_dir,
    use_fast=model_args.use_fast_tokenizer,
    revision=model_args.model_revision,
    use_auth_token=True if model_args.use_auth_token else None,
)
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    from_tf=bool(".ckpt" in model_args.model_name_or_path),
    config=config,
    cache_dir=model_args.cache_dir,
    revision=model_args.model_revision,
    use_auth_token=True if model_args.use_auth_token else None,
)
model.resize_token_embeddings(len(tokenizer))
```

@ShengdingHu (Collaborator) commented

This happens because the BertModel you used is a wrapped one (in this example, BertForSequenceClassification), so the common structure mapping expects an `lm_head` that the wrapped model does not have. You can solve it by providing customized `modified_modules`. Moreover, default delta-model configurations for wrapped models are supported systematically in #34. If you are still interested in this issue, you can have a look :)
