In [3]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from torch import nn
import copy
from types import MethodType
import datasets
from torch.utils.data import DataLoader
from itertools import islice
from tqdm import tqdm
import gc
import pandas as pd
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
from transformers.pytorch_utils import Conv1D

In [4]:
%load_ext autoreload
%autoreload 2

In [5]:
model = AutoModelForCausalLM.from_pretrained("gpt2").to('cuda')

In [6]:
tokenizer = AutoTokenizer.from_pretrained("gpt2")
Token = {v: k for k, v in tokenizer.get_vocab().items()}

In [7]:
dataset = datasets.load_from_disk(f'/workspace/corpus/msmarco/msmarco_GPT2_64tokens_1m').with_format('torch', device=torch.device('cuda'))
loader = DataLoader(dataset['test'], batch_size=128)

In [37]:
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [16]:
def to_myopic_gpt2(model, past_key_values):    
    def forward(self, *args, **kwargs):
        nonlocal past_key_values
        kwargs.pop('layer_past')
        return myopic_forward_gpt2(self, *args, **kwargs, layer_past=past_key_values[self.layer_idx])
    for name, module in model.named_modules():
        #if type(module) == GPT2Attention:  # type doesn't match? idk why
        if name.split('.')[-1] == 'attn':
            layer_past = past_key_values[module.layer_idx]            
            module.forward = MethodType(forward, module)
            module.extra_repr = lambda: 'MYOPIC'
    return model

In [20]:
def myopic_attn_gpt2(
    query, key, value, past_key, past_value, attention_mask, head_mask,
    bias,
    attn_dropout,
    scale_attn_weights=True,
):
    #import pdb; pdb.set_trace()
    attn_weights = torch.matmul(query, past_key.transpose(-1, -2))
    attn_weights.diagonal(dim1=2, dim2=3).copy_((query * key).sum(dim=3))

    if scale_attn_weights:
        attn_weights = attn_weights / torch.full(
            [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
        )

    query_length, key_length = query.size(-2), key.size(-2)
    causal_mask = bias[:, :, key_length - query_length : key_length, :key_length]
    mask_value = torch.finfo(attn_weights.dtype).min
    # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
    # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
    mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
    attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)

    if attention_mask is not None:
        # Apply the attention mask
        attn_weights = attn_weights + attention_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1)

    # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
    attn_weights = attn_weights.type(value.dtype)
    attn_weights = attn_dropout(attn_weights)

    # Mask heads if we want to
    if head_mask is not None:
        attn_weights = attn_weights * head_mask

    attn_output = torch.matmul(attn_weights, past_value)
    attn_output += attn_weights.diagonal(dim1=2, dim2=3).unsqueeze(dim=3) * (value - past_value)

    return attn_output, attn_weights

In [19]:
def myopic_forward_gpt2(
    self,
    hidden_states,
    layer_past,
    attention_mask=None,
    head_mask=None,
    output_attentions=False,
    **kwargs,
):
    assert kwargs.get('encoder_hidden_states') is None, 'Only decoder is supported'
    assert layer_past is not None, 'layer_past must be provided'
    
    query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
    query = self._split_heads(query, self.num_heads, self.head_dim)
    key = self._split_heads(key, self.num_heads, self.head_dim)
    value = self._split_heads(value, self.num_heads, self.head_dim)

    past_key, past_value = layer_past
    past_key, past_value = past_key.detach(), past_value.detach()
    #import pdb; pdb.set_trace()
    present = (key, value)

    assert not self.reorder_and_upcast_attn, 'Not supported!'
    assert not self.is_cross_attention, 'Not supported!'
    assert not self.scale_attn_by_inverse_layer_idx, 'Not supported!'
    attn_output, attn_weights = myopic_attn_gpt2(
        query, key, value, past_key, past_value, attention_mask, head_mask,
        self.bias, self.attn_dropout,
        scale_attn_weights=self.scale_attn_weights,
    )

    attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
    attn_output = self.c_proj(attn_output)
    attn_output = self.resid_dropout(attn_output)

    outputs = (attn_output, present)
    if output_attentions:
        outputs += (attn_weights,)

    return outputs  # a, present, (attentions)

In [11]:
def topk(v, k=40, aux=None):
    # Takes in logits
    #v = softmax(v.flatten())
    if type(v) == torch.Tensor:
        v = v.detach().cpu().numpy()
    v = v.flatten()
    idxs = v.argsort()[-k:][::-1]
    if aux:
        ret = [(Token[i], v[i]) + tuple(aux[i]) for i in idxs]
        return pd.DataFrame(ret, columns=['token', 'logit'] + list(range(len(aux[0]))))
    else:
        ret = [(Token[i], v[i]) for i in idxs]
        return pd.DataFrame(ret, columns=['token', 'logit'])

In [12]:
input = tokenizer('My favorite element of the periodic table is', return_tensors='pt').to('cuda')

In [13]:
out = model(**input)

In [21]:
myopic = to_myopic_gpt2(copy.deepcopy(model), out.past_key_values)

In [15]:
out.past_key_values[0][0][0, 0, 0, :5]

tensor([-1.0961,  1.8475,  0.8989, -0.1387,  0.9979], device='cuda:0',
       grad_fn=<SliceBackward0>)

In [22]:
myopic_out = myopic(**input)

In [23]:
(out.logits - myopic_out.logits).norm()

tensor(0.0073, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)

In [24]:
out.logits[0,0,:10]

tensor([-33.0735, -32.3349, -35.2379, -34.7751, -33.8666, -34.4521, -33.0241,
        -33.5888, -32.0457, -34.4160], device='cuda:0',
       grad_fn=<SliceBackward0>)

In [25]:
myopic_out.logits[0,0,:10]

tensor([-33.0735, -32.3349, -35.2379, -34.7751, -33.8666, -34.4521, -33.0241,
        -33.5888, -32.0457, -34.4160], device='cuda:0',
       grad_fn=<SliceBackward0>)

In [87]:
(myopic_out.logits-out.logits).norm()

tensor(14419.7354, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)

In [32]:
out.past_key_values[0][1].shape

torch.Size([1, 12, 9, 64])

In [33]:
del myopic
gc.collect()
torch.cuda.empty_cache()

NameError: name 'myopic' is not defined

In [35]:
model.transformer.h[0].mlp.c_fc.weight.shape[0]

768

In [34]:
myopic = copy.deepcopy(model)

In [14]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f517548c910>

In [35]:
for name, param in myopic.named_parameters():
    #if 'k_proj' in name or 'v_proj' in name:
    if 'mlp' in name:
        param.zero_()

In [36]:
num_batches = 10
it = islice(iter(loader), num_batches)
loss = 0
myopic_loss = 0
for batch in tqdm(it, total=num_batches):
    out = model(
        input_ids=batch['input_ids'], 
        labels=batch['input_ids'], 
        attention_mask=batch['attention_mask']
    )
    myopic = to_myopic(myopic, out.past_key_values)
    myopic_out = myopic(
        input_ids=batch['input_ids'], 
        labels=batch['input_ids'], 
        attention_mask=batch['attention_mask']
    )
    loss += out.loss.item()
    myopic_loss += myopic_out.loss.item()
loss /= num_batches
myopic_loss /= num_batches

100%|██████████| 10/10 [02:14<00:00, 13.44s/it]


full: 2.247
attn only: 7.708
mlp only: 6.929
mlp + q: 2.309

In [37]:
loss, myopic_loss

(2.246629166603088, 7.708231830596924)

In [231]:
to_myopic(myopic, out.past_key_values)

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralAttention(
          MYOPIC
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )
    (norm): Mi

In [162]:
out = model(**input)

In [233]:
(out.logits - myopic_out.logits).norm()

tensor(0.0003, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)

In [29]:
model = model.to('cuda')

In [30]:
lora_model = lora_model.to('cuda')

In [26]:
model.device

device(type='cpu')

In [28]:
input = tokenizer("Hello my name is", return_tensors='pt').to('cuda')

In [240]:
lora = to_lora(myopic, 8, ['k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'])

In [245]:
lora_out = lora(**input)

KEY DIFF 142.73130798339844
VAL DIFF 12.946982383728027
KEY DIFF 217.15673828125
VAL DIFF 18.909976959228516
KEY DIFF 186.13214111328125
VAL DIFF 67.96466827392578
KEY DIFF 188.8489227294922
VAL DIFF 53.43036651611328
KEY DIFF 174.39666748046875
VAL DIFF 78.77042388916016
KEY DIFF 186.60130310058594
VAL DIFF 68.82740020751953
KEY DIFF 188.38514709472656
VAL DIFF 74.36092376708984
KEY DIFF 195.12734985351562
VAL DIFF 70.07445526123047
KEY DIFF 199.66671752929688
VAL DIFF 72.25618743896484
KEY DIFF 182.92108154296875
VAL DIFF 97.33354949951172
KEY DIFF 212.75222778320312
VAL DIFF 94.85765075683594
KEY DIFF 205.39918518066406
VAL DIFF 87.05265045166016
KEY DIFF 209.25807189941406
VAL DIFF 113.35973358154297
KEY DIFF 208.14002990722656
VAL DIFF 117.33596801757812
KEY DIFF 206.93878173828125
VAL DIFF 114.59712219238281
KEY DIFF 204.29835510253906
VAL DIFF 153.5397186279297
KEY DIFF 214.47869873046875
VAL DIFF 134.12245178222656
KEY DIFF 222.3302764892578
VAL DIFF 126.56486511230469
KEY DIFF

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 MiB. GPU 0 has a total capacty of 79.15 GiB of which 2.12 MiB is free. Process 1531090 has 79.13 GiB memory in use. Of the allocated memory 78.63 GiB is allocated by PyTorch, and 15.09 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [62]:
out = model(
    **input,
)
#lora_out = lora_model(**input)

In [85]:
cache = DynamicCache.from_legacy_cache(out.past_key_values)

In [91]:
len(cache.key_cache)

32

In [36]:
lora_out.logits.shape

torch.Size([1, 5, 32000])

In [18]:
type(model.model.layers[0].self_attn.q_proj) == nn.Linear

True

In [6]:
[x for x in dir(model) if 'module' in x]

['__module__',
 '_get_no_split_modules',
 '_keep_in_fp32_modules',
 '_keep_in_fp32_modules',
 '_modules',
 '_no_split_modules',
 'add_module',
 'get_submodule',
 'modules',
 'named_modules',
 'register_module',
 'retrieve_modules_from_names']

In [15]:
model.get_submodule('model.layers.0.self_attn.q_proj') = nn.Linear(5, 5)

SyntaxError: cannot assign to function call here. Maybe you meant '==' instead of '='? (3646139102.py, line 1)

In [16]:
dict(model.named_modules())

{'': MistralForCausalLM(
   (model): MistralModel(
     (embed_tokens): Embedding(32000, 4096)
     (layers): ModuleList(
       (0-31): 32 x MistralDecoderLayer(
         (self_attn): MistralAttention(
           (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
           (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
           (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
           (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
           (rotary_emb): MistralRotaryEmbedding()
         )
         (mlp): MistralMLP(
           (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
           (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
           (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
           (act_fn): SiLU()
         )
         (input_layernorm): MistralRMSNorm()
         (post_attention_layernorm): MistralRMSNorm()
       )
     )
     