In [1]:
# Load model directly
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm.auto import tqdm
from datasets import load_dataset

import transformers
from tqdm.auto import tqdm, trange
assert torch.cuda.is_available()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model = AutoModelForCausalLM.from_pretrained(
    "ai-forever/ruGPT-3.5-13B",
    load_in_4bit=True,
    torch_dtype=torch.float16,
    device_map='auto',
    low_cpu_mem_usage=True,
    offload_state_dict=True, 
    cache_dir="/home/nikita_u/study/nir/repo/rugpt-memory/checkpoints/base/huggingface/"
)

Loading checkpoint shards: 100%|██████████| 6/6 [01:16<00:00, 12.77s/it]


In [3]:
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50272, 5120)
    (wpe): Embedding(2048, 5120)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-39): 40 x GPT2Block(
        (ln_1): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Linear4bit(in_features=5120, out_features=15360, bias=True)
          (c_proj): Linear4bit(in_features=5120, out_features=5120, bias=True)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Linear4bit(in_features=5120, out_features=20480, bias=True)
          (c_proj): Linear4bit(in_features=20480, out_features=5120, bias=True)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((5120,), eps=1e-05, elem

In [5]:
batch, sentence_length, embedding_dim = 20, 5, 5120
x = torch.randn(batch, sentence_length, embedding_dim, requires_grad=True, dtype=torch.float16)
x.shape

torch.Size([20, 5, 5120])

In [6]:
model.transformer.h[39].ln_1.normalized_shape

(5120,)

In [7]:
model.transformer.h[39]

GPT2Block(
  (ln_1): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)
  (attn): GPT2Attention(
    (c_attn): Linear4bit(in_features=5120, out_features=15360, bias=True)
    (c_proj): Linear4bit(in_features=5120, out_features=5120, bias=True)
    (attn_dropout): Dropout(p=0.1, inplace=False)
    (resid_dropout): Dropout(p=0.1, inplace=False)
  )
  (ln_2): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)
  (mlp): GPT2MLP(
    (c_fc): Linear4bit(in_features=5120, out_features=20480, bias=True)
    (c_proj): Linear4bit(in_features=20480, out_features=5120, bias=True)
    (act): NewGELUActivation()
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [8]:
x_o = model.transformer.h[39](x)
x_o.shape



AttributeError: 'tuple' object has no attribute 'shape'

In [9]:
x_o[0].shape

torch.Size([20, 5, 5120])

In [3]:
model.transformer.h[39].attn

GPT2Attention(
  (c_attn): Linear4bit(in_features=5120, out_features=15360, bias=True)
  (c_proj): Linear4bit(in_features=5120, out_features=5120, bias=True)
  (attn_dropout): Dropout(p=0.1, inplace=False)
  (resid_dropout): Dropout(p=0.1, inplace=False)
)

In [23]:
nn.MultiheadAttention(embed_dim=5120, num_heads=4, dropout=0.1)

MultiheadAttention(
  (out_proj): NonDynamicallyQuantizableLinear(in_features=5120, out_features=5120, bias=True)
)

In [22]:
attn = nn.MultiheadAttention(embed_dim=5120, num_heads=4, dropout=0.1, dtype=torch.float16)
attn

MultiheadAttention(
  (out_proj): NonDynamicallyQuantizableLinear(in_features=5120, out_features=5120, bias=True)
)

In [20]:
x.shape

torch.Size([20, 5, 5120])

In [24]:
o_attn = attn(
    query=x, 
    key=x, 
    value=x
)
o_attn.shape

AttributeError: 'tuple' object has no attribute 'shape'

In [26]:
len(o_attn)

2

In [27]:
o_attn[0].shape

torch.Size([20, 5, 5120])

In [28]:
o_attn[1].shape

torch.Size([5, 20, 20])

In [30]:
x.shape

torch.Size([20, 5, 5120])

In [32]:
len(nn.Linear(5120, 5120, dtype=torch.float16)(x))

20

In [39]:
y = nn.Linear(5120, 2*5120, dtype=torch.float16)(x)
y = nn.ReLU()(y)
print(y.shape)
y = nn.Linear(2*5120, 5120, dtype=torch.float16)(y)
y.shape

torch.Size([20, 5, 20480])


torch.Size([20, 5, 5120])

In [15]:
nn.Linear(5120, 5120, dtype=torch.float16)(x)

tensor([[[-3.6125e-03,  1.3330e+00,  4.0894e-01,  ..., -1.6333e-01,
           2.9639e-01, -3.5156e-01],
         [ 7.3242e-02, -7.8174e-01,  6.9189e-01,  ..., -4.7168e-01,
           6.6992e-01,  1.3662e+00],
         [-4.0967e-01, -7.3779e-01,  4.2676e-01,  ...,  7.9590e-01,
           4.1016e-01, -4.2432e-01],
         [-1.1914e+00,  3.3667e-01,  2.9102e-01,  ...,  6.7871e-01,
          -4.3359e-01,  3.1763e-01],
         [ 6.9434e-01,  1.4526e-02, -4.5471e-02,  ..., -1.5495e-02,
          -1.4783e-01, -1.9409e-02]],

        [[ 4.6533e-01, -7.1533e-02,  5.8008e-01,  ..., -5.9473e-01,
          -2.5977e-01,  4.3896e-01],
         [ 5.6201e-01, -5.0439e-01, -2.9810e-01,  ...,  1.0217e-01,
          -7.3128e-03,  9.5654e-01],
         [ 1.2002e+00, -2.9150e-01,  6.4893e-01,  ..., -1.8787e-01,
          -2.0312e-01, -2.8809e-01],
         [ 1.2549e-01,  4.6313e-01,  2.3267e-01,  ...,  6.6309e-01,
          -1.0712e-01,  6.1182e-01],
         [-9.1699e-01,  4.3427e-02, -8.5596e-01,  ...

In [14]:
x.shape

torch.Size([20, 5, 5120])

In [60]:
upper_level_embeddings = model.transformer.h[39].forward(x)

In [68]:
type(upper_level_embeddings[0])

torch.Tensor

In [69]:
upper_level_embeddings[0].shape

torch.Size([20, 5, 5120])

In [62]:
x.shape

torch.Size([20, 5, 5120])

In [41]:
(x+x).shape

torch.Size([20, 5, 5120])

In [68]:
x.float()

tensor([[[-0.7881,  1.1523, -0.0551,  ...,  0.7070, -0.1191,  0.7812],
         [-1.2559, -1.3057, -2.3125,  ..., -0.1284, -1.5430, -0.4556],
         [ 0.1301,  0.6895, -0.9297,  ...,  0.1860,  1.9502, -0.7021],
         [-0.6045,  1.0312,  0.0507,  ...,  1.4180, -0.6758, -0.0483],
         [-1.2959, -0.1399,  0.3103,  ..., -1.4834,  0.3401, -1.2705]],

        [[-2.0645, -0.9365, -1.3574,  ..., -1.3594,  0.8984,  0.3293],
         [ 0.5850,  1.0352, -0.6978,  ..., -1.3691,  0.6421, -0.7061],
         [ 0.9795, -1.0713, -1.1191,  ...,  1.2744,  0.3896, -1.7441],
         [ 0.4529, -0.2864,  1.1650,  ..., -0.8154,  0.8726,  1.0547],
         [ 0.2610,  0.3921, -0.4343,  ..., -0.5005,  0.0078,  0.5220]],

        [[-0.0681,  1.4385,  0.4812,  ...,  1.2188, -1.2520, -1.0400],
         [-0.6118,  1.1787, -0.3931,  ...,  1.1338,  0.8755,  0.1853],
         [-1.2627,  0.6382, -0.2644,  ..., -1.0850,  0.3674, -0.6143],
         [ 0.0480,  0.6294,  0.6533,  ...,  0.0059,  1.6172,  0.3538],
  

In [75]:
ln = nn.LayerNorm(5120)
ln(x.float()).type(torch.float16)

tensor([[[-0.8096,  1.1475, -0.0701,  ...,  0.6987, -0.1348,  0.7734],
         [-1.2646, -1.3145, -2.3164,  ..., -0.1418, -1.5508, -0.4675],
         [ 0.1290,  0.6880, -0.9302,  ...,  0.1848,  1.9482, -0.7031],
         [-0.6104,  1.0312,  0.0470,  ...,  1.4189, -0.6821, -0.0524],
         [-1.2842, -0.1288,  0.3213,  ..., -1.4717,  0.3511, -1.2588]],

        [[-2.0801, -0.9434, -1.3682,  ..., -1.3701,  0.9072,  0.3333],
         [ 0.5952,  1.0459, -0.6885,  ..., -1.3604,  0.6523, -0.6968],
         [ 0.9917, -1.0791, -1.1270,  ...,  1.2900,  0.3962, -1.7578],
         [ 0.4624, -0.2764,  1.1738,  ..., -0.8052,  0.8818,  1.0645],
         [ 0.2393,  0.3696, -0.4521,  ..., -0.5181, -0.0124,  0.4988]],

        [[-0.0587,  1.4541,  0.4929,  ...,  1.2334, -1.2471, -1.0342],
         [-0.5972,  1.1973, -0.3779,  ...,  1.1523,  0.8936,  0.2018],
         [-1.2852,  0.6694, -0.2588,  ..., -1.1025,  0.3911, -0.6187],
         [ 0.0564,  0.6387,  0.6631,  ...,  0.0141,  1.6289,  0.3628],
  

In [63]:
batch, sentence_length, embedding_dim = 20, 5, 10
embedding = torch.randn(batch, sentence_length, embedding_dim)
embedding.shape

torch.Size([20, 5, 10])

In [3]:
class DenseNetwork(nn.Module):
    """ DenseNetwork layer(FeedForward in original paper) """
    def __init__(
        self, 
        embed_dim=5120,
        hidden_size=10240, 
        dtype=torch.float16,
        initialize_with_zeros=False
    ):
        super().__init__()
        
        self.embed_dim = embed_dim
        self.hidden_size = hidden_size
        self.dtype = dtype
        
        self.ln1 = nn.Linear(self.embed_dim, self.hidden_size, dtype=self.dtype)
        self.relu = nn.ReLU()
        self.ln2 = nn.Linear(self.hidden_size, self.embed_dim, dtype=self.dtype)
        
        if initialize_with_zeros:
            nn.init.zeros_(self.ln1.weight)
            nn.init.zeros_(self.ln1.bias)
            nn.init.zeros_(self.ln2.weight)
            nn.init.zeros_(self.ln2.bias)
    
    def forward(self, x): # x: (sentence_length, batch_size, self.embed_dim)
        x = self.ln1(x)
        x = self.relu(x)
        x = self.ln2(x)
        return x

In [57]:
layer_full = DenseNetwork(initialize_with_zeros=True)

In [58]:
layer_full.ln1.weight

Parameter containing:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], dtype=torch.float16,
       requires_grad=True)

In [59]:
layer_full.ln1.bias

Parameter containing:
tensor([0., 0., 0.,  ..., 0., 0., 0.], dtype=torch.float16, requires_grad=True)

In [60]:
layer_full.ln2.weight

Parameter containing:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], dtype=torch.float16,
       requires_grad=True)

In [61]:
layer_full.ln2.bias

Parameter containing:
tensor([0., 0., 0.,  ..., 0., 0., 0.], dtype=torch.float16, requires_grad=True)

In [4]:
class LTMGPT2Block(nn.Module):
    """ Custom LTMGPT2Block layer with memory """
    def __init__(
        self, 
        gpt2_block,
        num_heads=4,
        attn_dropout=0.1,
        dense_network_hidden_size=10240,
        dtype=torch.float32
    ):
        super().__init__()
        self.gpt2_block = gpt2_block
        
        self.embed_dim = self.gpt2_block.ln_1.normalized_shape[0]
        self.dense_network_hidden_size = dense_network_hidden_size
        
        assert dtype in [torch.float16, torch.float32]
        
        # self.memory: ( , , ) / (target_sentence_length, batch_size, self.embed_dim) (5120) | torch.FloatTensor / nn.Embedding
        self.memory = None
        
        # goal: convert memory from ( , , ) to (source_sentence_length, batch_size, self.embed_dim)
        self.dense_network1 = DenseNetwork(
            embed_dim=self.embed_dim,
            hidden_size=self.dense_network_hidden_size, 
            dtype=dtype,
            initialize_with_zeros=False
        )
        
        self.attn = nn.MultiheadAttention( # TODO masked ????
            embed_dim=self.embed_dim, 
            num_heads=num_heads, 
            dropout=attn_dropout,
            batch_first=False,
            dtype=dtype
        )
        
        self.ln1 = nn.LayerNorm(self.embed_dim)
        
        self.dense_network2 = DenseNetwork(
            embed_dim=self.embed_dim,
            hidden_size=self.dense_network_hidden_size, 
            dtype=dtype,
            initialize_with_zeros=True
        )
        
        self.ln2 = nn.LayerNorm(self.embed_dim)
        
        
        
    
    def forward(self, x): # x: (sentence_length, batch_size, self.embed_dim)
        assert not self.memory
        
        # TransformerBlock
        query = self.gpt2_block(x) # query: (sentence_length, batch_size, self.embed_dim)
        residual = query
        
        # DenseNetowork
        memory = self.dense_network1(self.memory)
        
        # MultiHead Attention
        key, value = memory, memory
        x, _ = self.attn(
            query=query, 
            key=key, 
            value=value
        )
        
        # Norm & Concat
        x = x + residual
        if self.dtype == torch.float16:
            x = self.ln1(x.float()).type(torch.float16)
        else:
            x = self.ln1(x)
        
        # DenseNetowork initialized with zeroes
        x = self.dense_network2(x)
        
        # Norm & Concat
        x = x + residual
        if self.dtype == torch.float16:
            x = self.ln2(x.float()).type(torch.float16)
        else:
            x = self.ln2(x)
        
        return x
    
    def update_memory(new_memory):
        self.memory = new_memory

In [None]:
model

In [None]:
class LTM_GPT(nn.Module):
    """ Custom LTM GPT2 layer with memory """
    def __init__(self, model):
        self.model_freeze = model
        
        pass
    
    def forward(self, x):
        pass

In [7]:
for param in model.transformer.h[38:].parameters():
        param.requires_grad=True
        param.data = param.data.to(torch.float32)

    for param in model.transformer.ln_f.parameters():
        param.requires_grad=True
        param.data = param.data.to(torch.float32)

    for param in model.lm_head.parameters():
        param.requires_grad=True
        param.data = param.data.to(torch.float32)

IndentationError: unindent does not match any outer indentation level (<tokenize>, line 5)

In [8]:
# Init LTM gpt blocks
model.transformer.h[-2] = LTMGPT2Block(model.transformer.h[-2])
model.transformer.h[-1] = LTMGPT2Block(model.transformer.h[-1])

In [9]:
# Upcast
for param in model.transformer.h[-2:].parameters():
    param.data = param.data.to(torch.float32)
    
for param in model.transformer.ln_f.parameters():
    param.requires_grad=True
    param.data = param.data.to(torch.float32)

for param in model.lm_head.parameters():
    param.requires_grad=True
    param.data = param.data.to(torch.float32)

In [12]:
tokenizer = AutoTokenizer.from_pretrained("ai-forever/ruGPT-3.5-13B")

In [10]:
code_dataset = load_dataset("codeparrot/codeparrot-clean-valid")



In [13]:
prompts =  ['import', 'from', 'while', 'try', 'if', 'for', 'torch']  # feel free to add a few more that are not 100% assiciated with Python

MAX_STEPS = 100

for prompt in tqdm(prompts):
    print(tokenizer(prompt, return_tensors='pt', return_token_type_ids=False))

100%|██████████| 7/7 [00:00<00:00, 378.18it/s]

{'input_ids': tensor([[33076]]), 'attention_mask': tensor([[1]])}
{'input_ids': tensor([[34958]]), 'attention_mask': tensor([[1]])}
{'input_ids': tensor([[29631]]), 'attention_mask': tensor([[1]])}
{'input_ids': tensor([[  89, 2286]]), 'attention_mask': tensor([[1, 1]])}
{'input_ids': tensor([[1271]]), 'attention_mask': tensor([[1]])}
{'input_ids': tensor([[9949]]), 'attention_mask': tensor([[1]])}
{'input_ids': tensor([[23652,  1028]]), 'attention_mask': tensor([[1, 1]])}





In [14]:
def custom_generate(prompt, model, device, max_steps):
    batch = tokenizer(prompt, return_tensors='pt', return_token_type_ids=False).to(device)

    for i in range(max_steps):
        outputs = model(**batch)
        #print(outputs)
        probs = outputs.logits[0, -1].nan_to_num(nan=0.0).div(0.8).softmax(-1) #.argmax(-1).reshape(1, 1)
        old_token = outputs.logits[0, -1].argmax(-1).reshape(1, 1)
        #print(old_token)
        next_token = torch.multinomial(probs, 1).reshape(1, 1)
        #print(next_token)
        batch['input_ids'] = torch.cat([batch['input_ids'], next_token], dim=-1)
        batch['attention_mask'] = torch.cat([batch['attention_mask'], torch.ones_like(next_token)], dim=-1)

    return tokenizer.decode(batch['input_ids'][0].cpu().numpy().tolist()[1:])

In [15]:
after_finetuning_samples = []
for prompt in tqdm(prompts):
    after_finetuning_samples.append(custom_generate(prompt, model, device, MAX_STEPS))
after_finetuning_samples

  0%|          | 0/7 [00:00<?, ?it/s]


RuntimeError: expected scalar type Float but found Half