In [1]:
import torch
from transformers import BloomForCausalLM
from transformers.models.bloom.configuration_bloom import BloomConfig
from collections import OrderedDict

from node_attribution.bloom_for_gradient_node_attribution import BloomForCausalLMForNodeAttribution

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the pretrained checkpoint once; its state dict is the source of the
# weights that are re-keyed in the cells below.
MODEL_NAME = "bigscience/bloom-560m"

temp_model = BloomForCausalLM.from_pretrained(MODEL_NAME)
temp_model_statedict = temp_model.state_dict()

# Read both dimensions from the same top-level config object instead of
# mixing `temp_model.transformer.config` and `temp_model.config`.
num_layers = temp_model.config.num_hidden_layers
hidden_size = temp_model.config.hidden_size

In [3]:
# Ordered container for the re-keyed parameters; it is filled by the
# splitting loop in the next cell and later passed to model.load_state_dict.
new_state_dict = OrderedDict()

In [4]:
# Split every fused `query_key_value` parameter of the checkpoint into a
# `query_key` entry and a `value` entry; all other parameters are copied over
# under their original names.
#
# The dict is rebuilt from scratch here so that re-running this cell always
# produces the same key order (the `move_to_end` call below would otherwise
# keep reshuffling keys left over from a previous run).
new_state_dict = OrderedDict()

# Row indices along dim 0 of the fused tensor:
#   first 2 * hidden_size rows -> query_key, last hidden_size rows -> value.
# NOTE(review): upstream BLOOM appears to lay out the fused projection
# interleaved per head (num_heads, 3, head_dim), in which case a contiguous
# row split does not separate Q/K from V. Confirm that
# BloomForCausalLMForNodeAttribution expects exactly this slicing.
qk_index = torch.arange(hidden_size * 2)
v_index = torch.arange(hidden_size * 2, hidden_size * 3)

for layer_name in temp_model_statedict:
    if "query_key_value" in layer_name:
        qk_name = layer_name.replace("query_key_value", "query_key")
        v_name = layer_name.replace("query_key_value", "value")

        qkv_weights = temp_model_statedict[layer_name]

        qk_weights = torch.index_select(qkv_weights, 0, qk_index)
        new_state_dict[qk_name] = qk_weights

        if "bias" in layer_name:
            # The fused weight was visited before the fused bias, so
            # `value.weight` is already in the dict. Moving it to the end now
            # yields the order: query_key.weight, query_key.bias,
            # value.weight, value.bias.
            prev_v_weights = layer_name.replace("query_key_value.bias", "value.weight")
            new_state_dict.move_to_end(prev_v_weights)

        v_weights = torch.index_select(qkv_weights, 0, v_index)
        new_state_dict[v_name] = v_weights
    else:
        # Not a fused attention parameter: keep name and tensor unchanged.
        layer_value = temp_model_statedict[layer_name]
        new_state_dict[layer_name] = layer_value
    

In [5]:
# Configuration for the node-attribution model.
#
# The size-related fields are taken from the checkpoint loaded above rather
# than repeated as literals, so they cannot drift from the `hidden_size` used
# when splitting the fused attention weights. For bigscience/bloom-560m they
# resolve to the same values as before (250880 / 1024 / 24 / 16).
#
# NOTE(review): several keywords below (attention_softmax_in_fp32,
# bias_dropout_fusion, masked_softmax_fusion, offset_alibi, seq_length,
# skip_bias_add, skip_bias_add_qkv, unk_token_id) are not among BloomConfig's
# documented parameters; presumably they are carried along as extra config
# attributes. Confirm whether the node-attribution model reads them.
bloom_config = BloomConfig(
    vocab_size=temp_model.config.vocab_size,
    hidden_size=hidden_size,
    n_layer=num_layers,
    n_head=temp_model.config.n_head,
    layer_norm_epsilon=1e-5,
    initializer_range=0.02,
    use_cache=True,
    bos_token_id=1,
    eos_token_id=2,
    apply_residual_connection_post_layernorm=False,
    hidden_dropout=0.0,
    attention_dropout=0.0,
    pretraining_tp=1,  # TP rank used when training with megatron
    slow_but_exact=False,
    attention_softmax_in_fp32=True,
    bias_dropout_fusion=True,
    masked_softmax_fusion=True,
    offset_alibi=100,
    pad_token_id=3,
    seq_length=2048,
    skip_bias_add=True,
    skip_bias_add_qkv=False,
    unk_token_id=0,
)

In [6]:
# Build the node-attribution BLOOM variant from the config above; the
# checkpoint weights are copied in by load_state_dict in the following cell.
model = BloomForCausalLMForNodeAttribution(bloom_config)

In [7]:
# Load the re-keyed checkpoint tensors into the model. The value displayed
# below the cell reports whether every key was matched.
model.load_state_dict(new_state_dict)

<All keys matched successfully>

In [8]:
# Inspect the re-keyed state dict that was loaded into the model.
# NOTE(review): a bare display prints every tensor in the dict; a listing of
# names and shapes would be far more compact for a shared notebook.
new_state_dict

OrderedDict([('transformer.word_embeddings.weight',
              tensor([[-0.0099, -0.0048, -0.0111,  ..., -0.0426,  0.0099,  0.0212],
                      [ 0.0048, -0.0127,  0.0138,  ..., -0.0448,  0.0003, -0.0120],
                      [ 0.0065,  0.0239,  0.0050,  ..., -0.0431, -0.0067,  0.0137],
                      ...,
                      [-0.0028, -0.0038, -0.0012,  ..., -0.0252,  0.0013,  0.0012],
                      [-0.0028, -0.0038, -0.0012,  ..., -0.0252,  0.0013,  0.0012],
                      [-0.0028, -0.0038, -0.0012,  ..., -0.0252,  0.0013,  0.0012]])),
             ('transformer.word_embeddings_layernorm.weight',
              tensor([0.4409, 0.3167, 0.4749,  ..., 0.0816, 0.2927, 0.6006])),
             ('transformer.word_embeddings_layernorm.bias',
              tensor([-0.0513,  0.0164,  0.0052,  ...,  0.2412,  0.0072, -0.0292])),
             ('transformer.h.0.input_layernorm.weight',
              tensor([0.6621, 0.8457, 0.5884,  ..., 3.2090, 0.9180, 0.55

In [9]:
# List each parameter of the original checkpoint with its shape, for
# comparison against the re-keyed state dict.
for param_name, param_tensor in temp_model_statedict.items():
    print(f"{param_name} {param_tensor.shape}")

transformer.word_embeddings.weight torch.Size([250880, 1024])
transformer.word_embeddings_layernorm.weight torch.Size([1024])
transformer.word_embeddings_layernorm.bias torch.Size([1024])
transformer.h.0.input_layernorm.weight torch.Size([1024])
transformer.h.0.input_layernorm.bias torch.Size([1024])
transformer.h.0.self_attention.query_key_value.weight torch.Size([3072, 1024])
transformer.h.0.self_attention.query_key_value.bias torch.Size([3072])
transformer.h.0.self_attention.dense.weight torch.Size([1024, 1024])
transformer.h.0.self_attention.dense.bias torch.Size([1024])
transformer.h.0.post_attention_layernorm.weight torch.Size([1024])
transformer.h.0.post_attention_layernorm.bias torch.Size([1024])
transformer.h.0.mlp.dense_h_to_4h.weight torch.Size([4096, 1024])
transformer.h.0.mlp.dense_h_to_4h.bias torch.Size([4096])
transformer.h.0.mlp.dense_4h_to_h.weight torch.Size([1024, 4096])
transformer.h.0.mlp.dense_4h_to_h.bias torch.Size([1024])
transformer.h.1.input_layernorm.weight