In [5]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.mistral.modeling_mistral import *
from transformers.cache_utils import DynamicCache
import torch
from torch import nn
import copy
from types import MethodType
import datasets
from torch.utils.data import DataLoader
from itertools import islice
from tqdm import tqdm
import gc

In [3]:
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1").to('cuda')

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [2]:
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
Token = {v: k for k, v in tokenizer.get_vocab().items()}

In [6]:
dataset = datasets.load_from_disk(f'/home/wwu/msmarco_mistral_test').with_format('torch', device=torch.device('cuda'))
loader = DataLoader(dataset, batch_size=128)

In [7]:
class LoraLinear(nn.Module):
    def __init__(self, in_dim, out_dim, rank, bias):
        super().__init__()
        self.A = nn.Linear(in_dim, rank, bias=False)
        self.B = nn.Linear(rank, out_dim, bias=bias)

    def forward(self, x):
        return self.B(self.A(x))

In [8]:
def to_lora(model, rank, module_names):
    '''
    Returns a copy of model Linear switched to LoraLinear modules.
    '''
    modules = dict(model.named_modules())
    for name, module in modules.items():
        parent = '.'.join(name.split('.')[:-1])
        child = name.split('.')[-1]
        if type(module) == nn.Linear and child in module_names:
            lora_module = LoraLinear(
                module.in_features, module.out_features, rank, module.bias is not None
            ).to(model.device)
            setattr(modules[parent], child, lora_module)
    return model

In [9]:
def to_myopic(model, past_key_value):
    past_key_value = DynamicCache.from_legacy_cache(past_key_value)
    def forward(*args, **kwargs):
        # This is very hacky, but otherwise it's hard to provide past_key_values
        # to myopic_forward without breaking a lot of MistralModel
        nonlocal past_key_value
        kwargs.pop('past_key_value')
        return myopic_forward(*args, **kwargs, past_key_value=past_key_value)
    
    for module in model.modules():
        if type(module) == MistralAttention:
            module.forward = MethodType(forward, module)
            module.extra_repr = lambda: 'MYOPIC'
    return model

In [16]:
def myopic_forward_mistral(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    **kwargs,
):
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    
    kv_seq_len = key_states.shape[-2]
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

    past_key_states = past_key_value.key_cache[self.layer_idx].detach()
    past_value_states = past_key_value.value_cache[self.layer_idx].detach()

    assert key_states.shape == past_key_states.shape, \
        f'past_key_states is wrong shape: {past_key_states.shape} instead of {key_states.shape}'
    assert value_states.shape == past_value_states.shape, \
        f'past_value_states is wrong shape: {past_value_states.shape} instead of {value_states.shape}'

    # repeat k/v heads if n_kv_heads < n_heads
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)
    past_key_states = repeat_kv(past_key_states, self.num_key_value_groups)
    past_value_states = repeat_kv(past_value_states, self.num_key_value_groups)
    #print('KEY DIFF', torch.norm(key_states - past_key_states).item())
    #print('VAL DIFF', torch.norm(value_states - past_value_states).item())

    # query @ past_key on off-diagonal
    attn_weights = torch.matmul(query_states, past_key_states.transpose(2, 3))
    # query @ key on diagonal
    #print('ATTN DIFF', torch.norm(attn_weights.diagonal(dim1=2, dim2=3)-(query_states * key_states).sum(dim=3)).item())
    attn_weights.diagonal(dim1=2, dim2=3).copy_((query_states * key_states).sum(dim=3))
    attn_weights /= math.sqrt(self.head_dim)

    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
        raise ValueError(
            f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
            f" {attn_weights.size()}"
        )

    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
            )

        attn_weights = attn_weights + attention_mask

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
    # attn @ past_value on off-diagonal
    attn_output = torch.matmul(attn_weights, past_value_states)
    # attn @ value on diagonal
    #print('VAL DIFF', torch.norm(attn_weights.diagonal(dim1=2, dim2=3).unsqueeze(dim=3) * (value_states - past_value_states)).item())
    attn_output += attn_weights.diagonal(dim1=2, dim2=3).unsqueeze(dim=3) * (value_states - past_value_states)

    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value

In [42]:
model

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )
    (norm): MistralRMSNorm()
  

In [33]:
del myopic
gc.collect()
torch.cuda.empty_cache()

NameError: name 'myopic' is not defined

In [34]:
myopic = copy.deepcopy(model)

In [14]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f517548c910>

In [35]:
for name, param in myopic.named_parameters():
    #if 'k_proj' in name or 'v_proj' in name:
    if 'mlp' in name:
        param.zero_()

In [36]:
num_batches = 10
it = islice(iter(loader), num_batches)
loss = 0
myopic_loss = 0
for batch in tqdm(it, total=num_batches):
    out = model(
        input_ids=batch['input_ids'], 
        labels=batch['input_ids'], 
        attention_mask=batch['attention_mask']
    )
    myopic = to_myopic(myopic, out.past_key_values)
    myopic_out = myopic(
        input_ids=batch['input_ids'], 
        labels=batch['input_ids'], 
        attention_mask=batch['attention_mask']
    )
    loss += out.loss.item()
    myopic_loss += myopic_out.loss.item()
loss /= num_batches
myopic_loss /= num_batches

100%|██████████| 10/10 [02:14<00:00, 13.44s/it]


full: 2.247
attn only: 7.708
mlp only: 6.929
mlp + q: 2.309

In [37]:
loss, myopic_loss

(2.246629166603088, 7.708231830596924)

In [231]:
to_myopic(myopic, out.past_key_values)

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralAttention(
          MYOPIC
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )
    (norm): Mi

In [162]:
out = model(**input)

In [233]:
(out.logits - myopic_out.logits).norm()

tensor(0.0003, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)

In [29]:
model = model.to('cuda')

In [30]:
lora_model = lora_model.to('cuda')

In [26]:
model.device

device(type='cpu')

In [28]:
input = tokenizer("Hello my name is", return_tensors='pt').to('cuda')

In [240]:
lora = to_lora(myopic, 8, ['k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'])

In [245]:
lora_out = lora(**input)

KEY DIFF 142.73130798339844
VAL DIFF 12.946982383728027
KEY DIFF 217.15673828125
VAL DIFF 18.909976959228516
KEY DIFF 186.13214111328125
VAL DIFF 67.96466827392578
KEY DIFF 188.8489227294922
VAL DIFF 53.43036651611328
KEY DIFF 174.39666748046875
VAL DIFF 78.77042388916016
KEY DIFF 186.60130310058594
VAL DIFF 68.82740020751953
KEY DIFF 188.38514709472656
VAL DIFF 74.36092376708984
KEY DIFF 195.12734985351562
VAL DIFF 70.07445526123047
KEY DIFF 199.66671752929688
VAL DIFF 72.25618743896484
KEY DIFF 182.92108154296875
VAL DIFF 97.33354949951172
KEY DIFF 212.75222778320312
VAL DIFF 94.85765075683594
KEY DIFF 205.39918518066406
VAL DIFF 87.05265045166016
KEY DIFF 209.25807189941406
VAL DIFF 113.35973358154297
KEY DIFF 208.14002990722656
VAL DIFF 117.33596801757812
KEY DIFF 206.93878173828125
VAL DIFF 114.59712219238281
KEY DIFF 204.29835510253906
VAL DIFF 153.5397186279297
KEY DIFF 214.47869873046875
VAL DIFF 134.12245178222656
KEY DIFF 222.3302764892578
VAL DIFF 126.56486511230469
KEY DIFF

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 MiB. GPU 0 has a total capacty of 79.15 GiB of which 2.12 MiB is free. Process 1531090 has 79.13 GiB memory in use. Of the allocated memory 78.63 GiB is allocated by PyTorch, and 15.09 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [62]:
out = model(
    **input,
)
#lora_out = lora_model(**input)

In [85]:
cache = DynamicCache.from_legacy_cache(out.past_key_values)

In [91]:
len(cache.key_cache)

32

In [36]:
lora_out.logits.shape

torch.Size([1, 5, 32000])

In [18]:
type(model.model.layers[0].self_attn.q_proj) == nn.Linear

True

In [6]:
[x for x in dir(model) if 'module' in x]

['__module__',
 '_get_no_split_modules',
 '_keep_in_fp32_modules',
 '_keep_in_fp32_modules',
 '_modules',
 '_no_split_modules',
 'add_module',
 'get_submodule',
 'modules',
 'named_modules',
 'register_module',
 'retrieve_modules_from_names']

In [15]:
model.get_submodule('model.layers.0.self_attn.q_proj') = nn.Linear(5, 5)

SyntaxError: cannot assign to function call here. Maybe you meant '==' instead of '='? (3646139102.py, line 1)

In [16]:
dict(model.named_modules())

{'': MistralForCausalLM(
   (model): MistralModel(
     (embed_tokens): Embedding(32000, 4096)
     (layers): ModuleList(
       (0-31): 32 x MistralDecoderLayer(
         (self_attn): MistralAttention(
           (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
           (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
           (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
           (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
           (rotary_emb): MistralRotaryEmbedding()
         )
         (mlp): MistralMLP(
           (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
           (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
           (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
           (act_fn): SiLU()
         )
         (input_layernorm): MistralRMSNorm()
         (post_attention_layernorm): MistralRMSNorm()
       )
     )
     