Merge pull request #149 from aryamanarora/main
[P1] Add Gemma + minor fixes
frankaging committed Apr 24, 2024
2 parents f4b2fc9 + a6fe305 commit 7d94cdd
Showing 5 changed files with 108 additions and 3 deletions.
Empty file added pyvene/models/gemma/__init__.py
87 changes: 87 additions & 0 deletions pyvene/models/gemma/modelings_intervenable_gemma.py
@@ -0,0 +1,87 @@
"""
Each modeling file in this library is a mapping between
abstract naming of intervention anchor points and actual
model module defined in the huggingface library.
We also want to let the intervention library know how to
config the dimensions of intervention based on model config
defined in the huggingface library.
"""


import torch
from ..constants import *


gemma_type_to_module_mapping = {
"block_input": ("layers[%s]", CONST_INPUT_HOOK),
"block_output": ("layers[%s]", CONST_OUTPUT_HOOK),
"mlp_activation": ("layers[%s].mlp.act_fn", CONST_OUTPUT_HOOK),
"mlp_output": ("layers[%s].mlp", CONST_OUTPUT_HOOK),
"mlp_input": ("layers[%s].mlp", CONST_INPUT_HOOK),
"attention_value_output": ("layers[%s].self_attn.o_proj", CONST_INPUT_HOOK),
"head_attention_value_output": ("layers[%s].self_attn.o_proj", CONST_INPUT_HOOK),
"attention_output": ("layers[%s].self_attn", CONST_OUTPUT_HOOK),
"attention_input": ("layers[%s].self_attn", CONST_INPUT_HOOK),
"query_output": ("layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK),
"key_output": ("layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK),
"value_output": ("layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK),
"head_query_output": ("layers[%s].self_attn.q_proj", CONST_OUTPUT_HOOK),
"head_key_output": ("layers[%s].self_attn.k_proj", CONST_OUTPUT_HOOK),
"head_value_output": ("layers[%s].self_attn.v_proj", CONST_OUTPUT_HOOK),
}
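# Editorial note, not part of the original file: each value pairs a
# module-path template with a hook constant. The "%s" slot is filled
# with a layer index at intervention time, so ("layers[%s]",
# CONST_OUTPUT_HOOK) for layer 3 hooks the output of model.layers[3].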


gemma_type_to_dimension_mapping = {
"block_input": ("hidden_size",),
"block_output": ("hidden_size",),
"mlp_activation": ("intermediate_size",),
"mlp_output": ("hidden_size",),
"mlp_input": ("hidden_size",),
"attention_value_output": ("hidden_size",),
"head_attention_value_output": ("hidden_size/num_attention_heads",),
"attention_output": ("hidden_size",),
"attention_input": ("hidden_size",),
"query_output": ("hidden_size",),
"key_output": ("hidden_size",),
"value_output": ("hidden_size",),
"head_query_output": ("hidden_size/num_attention_heads",),
"head_key_output": ("hidden_size/num_attention_heads",),
"head_value_output": ("hidden_size/num_attention_heads",),
}
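# Illustrative arithmetic (config values assumed from google/gemma-2b):
# with hidden_size=2048 and num_attention_heads=8, an entry such as
# "hidden_size/num_attention_heads" resolves to 2048 / 8 = 256, the
# per-head slice of the projected hidden state.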


"""gemma model with LM head"""
gemma_lm_type_to_module_mapping = {}
for k, v in gemma_type_to_module_mapping.items():
gemma_lm_type_to_module_mapping[k] = (f"model.{v[0]}", v[1])


gemma_lm_type_to_dimension_mapping = gemma_type_to_dimension_mapping


"""gemma model with classifier head"""
gemma_classifier_type_to_module_mapping = {}
for k, v in gemma_type_to_module_mapping.items():
gemma_classifier_type_to_module_mapping[k] = (f"model.{v[0]}", v[1])


gemma_classifier_type_to_dimension_mapping = gemma_type_to_dimension_mapping
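# Editorial note: GemmaForCausalLM and GemmaForSequenceClassification
# both nest the decoder backbone under a `.model` attribute, which is
# why the LM and classifier mappings above reuse the base paths with a
# "model." prefix.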


def create_gemma(
name="google/gemma-2b-it", cache_dir=None, dtype=torch.bfloat16
):
"""Creates a Gemma Causal LM model, config, and tokenizer from the given name and revision"""
from transformers import GemmaForCausalLM, GemmaTokenizer, GemmaConfig

config = GemmaConfig.from_pretrained(name, cache_dir=cache_dir)
tokenizer = GemmaTokenizer.from_pretrained(name, cache_dir=cache_dir)
gemma = GemmaForCausalLM.from_pretrained(
name,
config=config,
cache_dir=cache_dir,
torch_dtype=dtype, # save memory
)
print("loaded model")
return config, tokenizer, gemma
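A quick usage sketch for the new helper (the prompt and generation settings here are illustrative, not from the diff):

config, tokenizer, gemma = create_gemma(name="google/gemma-2b-it")
inputs = tokenizer("Why is the sky blue?", return_tensors="pt")
outputs = gemma.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))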
9 changes: 9 additions & 0 deletions pyvene/models/intervenable_modelcard.py
@@ -1,6 +1,7 @@
from .constants import *
from .llama.modelings_intervenable_llama import *
from .mistral.modellings_intervenable_mistral import *
from .gemma.modelings_intervenable_gemma import *
from .gpt2.modelings_intervenable_gpt2 import *
from .gpt_neo.modelings_intervenable_gpt_neo import *
from .gpt_neox.modelings_intervenable_gpt_neox import *
@@ -39,12 +40,16 @@
hf_models.gpt2.modeling_gpt2.GPT2ForSequenceClassification: gpt2_classifier_type_to_module_mapping,
hf_models.llama.modeling_llama.LlamaModel: llama_type_to_module_mapping,
hf_models.llama.modeling_llama.LlamaForCausalLM: llama_lm_type_to_module_mapping,
hf_models.llama.modeling_llama.LlamaForSequenceClassification: llama_classifier_type_to_module_mapping,
hf_models.gpt_neo.modeling_gpt_neo.GPTNeoModel: gpt_neo_type_to_module_mapping,
hf_models.gpt_neo.modeling_gpt_neo.GPTNeoForCausalLM: gpt_neo_lm_type_to_module_mapping,
hf_models.gpt_neox.modeling_gpt_neox.GPTNeoXModel: gpt_neox_type_to_module_mapping,
hf_models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM: gpt_neox_lm_type_to_module_mapping,
hf_models.mistral.modeling_mistral.MistralModel: mistral_type_to_module_mapping,
hf_models.mistral.modeling_mistral.MistralForCausalLM: mistral_lm_type_to_module_mapping,
hf_models.gemma.modeling_gemma.GemmaModel: gemma_type_to_module_mapping,
hf_models.gemma.modeling_gemma.GemmaForCausalLM: gemma_lm_type_to_module_mapping,
hf_models.gemma.modeling_gemma.GemmaForSequenceClassification: gemma_classifier_type_to_module_mapping,
hf_models.blip.modeling_blip.BlipForQuestionAnswering: blip_type_to_module_mapping,
hf_models.blip.modeling_blip.BlipForImageTextRetrieval: blip_itm_type_to_module_mapping,
BlipWrapper: blip_wrapper_type_to_module_mapping,
@@ -65,12 +70,16 @@
hf_models.gpt2.modeling_gpt2.GPT2ForSequenceClassification: gpt2_classifier_type_to_dimension_mapping,
hf_models.llama.modeling_llama.LlamaModel: llama_type_to_dimension_mapping,
hf_models.llama.modeling_llama.LlamaForCausalLM: llama_lm_type_to_dimension_mapping,
hf_models.llama.modeling_llama.LlamaForSequenceClassification: llama_classifier_type_to_dimension_mapping,
hf_models.gpt_neo.modeling_gpt_neo.GPTNeoModel: gpt_neo_type_to_dimension_mapping,
hf_models.gpt_neo.modeling_gpt_neo.GPTNeoForCausalLM: gpt_neo_lm_type_to_dimension_mapping,
hf_models.gpt_neox.modeling_gpt_neox.GPTNeoXModel: gpt_neox_type_to_dimension_mapping,
hf_models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM: gpt_neox_lm_type_to_dimension_mapping,
hf_models.mistral.modeling_mistral.MistralModel: mistral_type_to_dimension_mapping,
hf_models.mistral.modeling_mistral.MistralForCausalLM: mistral_lm_type_to_dimension_mapping,
hf_models.gemma.modeling_gemma.GemmaModel: gemma_type_to_dimension_mapping,
hf_models.gemma.modeling_gemma.GemmaForCausalLM: gemma_lm_type_to_dimension_mapping,
hf_models.gemma.modeling_gemma.GemmaForSequenceClassification: gemma_classifier_type_to_dimension_mapping,
hf_models.blip.modeling_blip.BlipForQuestionAnswering: blip_type_to_dimension_mapping,
hf_models.blip.modeling_blip.BlipForImageTextRetrieval: blip_itm_type_to_dimension_mapping,
BlipWrapper: blip_wrapper_type_to_dimension_mapping,
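A hedged sketch of how this registry is consumed (the enclosing dict names type_to_module_mapping and type_to_dimension_mapping are assumed from context, since the hunk starts mid-dict): pyvene keys both dicts by the concrete HuggingFace class of the wrapped model.

from transformers import GemmaForCausalLM

module_mapping = type_to_module_mapping[GemmaForCausalLM]        # -> gemma_lm_type_to_module_mapping
dimension_mapping = type_to_dimension_mapping[GemmaForCausalLM]  # -> gemma_lm_type_to_dimension_mapping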
9 changes: 9 additions & 0 deletions pyvene/models/llama/modelings_intervenable_llama.py
@@ -60,6 +60,15 @@
llama_lm_type_to_dimension_mapping = llama_type_to_dimension_mapping


"""llama model with classifier head"""
llama_classifier_type_to_module_mapping = {}
for k, v in llama_type_to_module_mapping.items():
llama_classifier_type_to_module_mapping[k] = (f"model.{v[0]}", v[1])


llama_classifier_type_to_dimension_mapping = llama_type_to_dimension_mapping


def create_llama(
name="sharpbai/alpaca-7b-merged", cache_dir=None, dtype=torch.bfloat16
):
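The llama classifier mapping above follows the same "model." prefixing pattern as the gemma file. To make the path templates concrete, here is a hypothetical helper (resolve_path is an illustrative name, not part of pyvene) that walks a filled-in template to the underlying nn.Module:

import re

def resolve_path(model, template, layer):
    # e.g. resolve_path(model, "model.layers[%s].self_attn.o_proj", 3)
    obj = model
    for part in (template % layer).split("."):
        m = re.match(r"(\w+)\[(\d+)\]$", part)
        if m:  # indexed access such as layers[3]
            obj = getattr(obj, m.group(1))[int(m.group(2))]
        else:  # plain attribute access
            obj = getattr(obj, part)
    return obj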
6 changes: 3 additions & 3 deletions pyvene/models/mistral/modellings_intervenable_mistral.py
@@ -51,7 +51,7 @@
}


"""llama model with LM head"""
"""mistral model with LM head"""
mistral_lm_type_to_module_mapping = {}
for k, v in mistral_type_to_module_mapping.items():
mistral_lm_type_to_module_mapping[k] = (f"model.{v[0]}", v[1])
@@ -68,11 +68,11 @@ def create_mistral(

config = AutoConfig.from_pretrained(name, cache_dir=cache_dir)
tokenizer = AutoTokenizer.from_pretrained(name, cache_dir=cache_dir)
-llama = AutoModelForCausalLM.from_pretrained(
+mistral = AutoModelForCausalLM.from_pretrained(
name,
config=config,
cache_dir=cache_dir,
torch_dtype=torch.bfloat16, # save memory
)
print("loaded model")
-return config, tokenizer, llama
+return config, tokenizer, mistral
