### Parameter Efficient Fine-Tuning
fine-tune LLM within limited GPU memory.

In [1]:
%pip install --quiet transformers==4.34.1 accelerate==0.24.0 sentencepiece==0.1.99 optimum==1.13.2 peft==0.5.0 bitsandbytes==0.41.2.post2

import torch
import torch.nn as nn
import torch.nn.functional as F

import transformers
from tqdm.auto import tqdm, trange
assert torch.cuda.is_available(), "you need cuda for this part"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Note: you may need to restart the kernel to use updated packages.


In [2]:
model_name = 'Enoch/llama-7b-hf'

tokenizer = transformers.LlamaTokenizer.from_pretrained(model_name, device_map=device)
tokenizer.pad_token_id = tokenizer.eos_token_id

model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name, device_map='auto', low_cpu_mem_usage=True, offload_state_dict=True,
    load_in_4bit=True, torch_dtype=torch.float32,  # weights are 4-bit; layernorms and activations are fp32
)
for param in model.parameters():
    param.requires_grad=False

model.gradient_checkpointing_enable()  # only store a small subset of activations, re-compute the rest.
model.enable_input_require_grads()     # override an implementation quirk in gradient checkpoints that disables backprop unless inputs require grad

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


Loading checkpoint shards:   0%|          | 0/33 [00:00<?, ?it/s]

### LoRA

In [3]:
from datasets import load_dataset
ds = load_dataset("codeparrot/codeparrot-clean-valid")

Repo card metadata block was not found. Setting CardData to empty.
  table = cls._concat_blocks(blocks, axis=0)


In [4]:
prompts =  ['', 'import', 'from', 'while', 'try', 'if', 'for', 'torch', 'array =']  # feel free to add a few more that are not 100% assiciated with Python

In [5]:
# cut all code to first 512 symbols
ds = ds.map(lambda x: {'content': x['content'][:512]})

In [6]:
ds['train']['content'][34]
len(ds['train'])

61373

In [7]:
list(model.base_model.modules())[0]

LlamaModel(
  (embed_tokens): Embedding(32000, 4096, padding_idx=0)
  (layers): ModuleList(
    (0-31): 32 x LlamaDecoderLayer(
      (self_attn): LlamaAttention(
        (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
        (k_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
        (v_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
        (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
        (rotary_emb): LlamaRotaryEmbedding()
      )
      (mlp): LlamaMLP(
        (gate_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
        (up_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
        (down_proj): Linear4bit(in_features=11008, out_features=4096, bias=False)
        (act_fn): SiLUActivation()
      )
      (input_layernorm): LlamaRMSNorm()
      (post_attention_layernorm): LlamaRMSNorm()
    )
  )
  (norm): LlamaRMSNorm()
)

In [24]:
class LayerLoRA(nn.Module):
    def __init__(self, module, rank):
        super().__init__()
        self.module = module  # frozen
        self.adapter_IN = nn.Parameter(torch.empty(4096, rank), requires_grad=True)
        nn.init.kaiming_uniform_(self.adapter_IN, a=5 ** 0.5)
        self.adapter_OUT = nn.Parameter(torch.zeros(rank, 32000), requires_grad=True)
        
    def forward(self, x):
        fc_x = self.module(x)
        lora_in = x @ self.adapter_IN
        lora_out = lora_in @ self.adapter_OUT
        return lora_out + fc_x
    
    
def replace_module(module, name):
    '''
    set module = net to start code.
    '''
    for attr_str in dir(module):
        target_attr = getattr(model, attr_str)
        if attr_str == 'lm_head':
            print('replaced: ', name, attr_str, type(target_attr))
            
            new_lora_layer = LayerLoRA(module=target_attr, rank=5).to(device)
            setattr(module, attr_str, new_lora_layer)
    return new_lora_layer

In [25]:
layer_lora = replace_module(model, 'model')

replaced:  model lm_head <class 'torch.nn.modules.linear.Linear'>


In [26]:
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear4bit(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )


In [28]:
# test LoRA Layer

x = torch.ones(4096).to('cuda:1')
layer_lora(x)

tensor([-0.4096,  2.9822, -0.7106,  ...,  1.8868, -1.8370, -1.9048],
       device='cuda:1', grad_fn=<AddBackward0>)

In [29]:
# show adapters
for name, param in model.named_parameters():
    if (param.requires_grad):
        print(name, param)

lm_head.adapter_IN Parameter containing:
tensor([[ 0.3421, -0.3496,  0.4111,  0.1866, -0.2498],
        [ 0.2834, -0.2159,  0.2024,  0.3993,  0.0184],
        [ 0.0933, -0.1524,  0.2940, -0.3534,  0.2739],
        ...,
        [-0.0834, -0.0594,  0.0977, -0.0876, -0.1664],
        [ 0.3234, -0.3376,  0.1591,  0.1747, -0.3358],
        [-0.1995,  0.2182,  0.2223,  0.1540, -0.1868]], device='cuda:1',
       requires_grad=True)
lm_head.adapter_OUT Parameter containing:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:1', requires_grad=True)


In [33]:
from accelerate import Accelerator
from torch.utils.data import DataLoader

train_loader = DataLoader(ds['train'], batch_size=1, shuffle=False)
accelerator = Accelerator(gradient_accumulation_steps=4)

In [30]:
model.device

device(type='cuda', index=0)

In [34]:
import gc
def flush():
  gc.collect()
  torch.cuda.empty_cache()
  torch.cuda.reset_peak_memory_stats()
flush()

In [36]:
# train cycle

device='cuda'
n_epochs = 5
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

model, optimizer, train_loader = accelerator.prepare(
    model, optimizer, train_loader
)

for i in range(1):
    for batch in tqdm(train_loader):
        preproc = tokenizer(batch['content'], return_tensors='pt', return_token_type_ids=False, padding=True).to(device)
        outputs = model(**preproc)
        next_word_logits = outputs.logits[:, :-1]
        true_next_tokens = preproc['input_ids'][:, 1:]
        loss = F.cross_entropy(next_word_logits.flatten(0, 1), true_next_tokens.flatten(0, 1))
        
        accelerator.backward(loss)

        optimizer.step()

        optimizer.zero_grad()
        flush()
        
    

  0%|          | 0/61373 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [37]:
# test tuned LLM

tokenizer = transformers.LlamaTokenizer.from_pretrained(model_name, device_map=device)
tokenizer.pad_token_id = tokenizer.eos_token_id

for prompt in prompts:
    batch = tokenizer(prompt, return_tensors='pt', return_token_type_ids=False).to(device)
    for i in range(25):
        next_token = model(**batch).logits[0, -1].argmax(-1).reshape(1, 1)
        batch['input_ids'] = torch.cat([batch['input_ids'], next_token], dim=-1)
        batch['attention_mask'] = torch.cat([batch['attention_mask'], torch.ones_like(next_token)], dim=-1)

    print("\nOutput:", tokenizer.decode(batch['input_ids'][0].cpu().numpy().tolist()))


Output: <s>#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os


Output: <s>import os
import sys
import time
import logging
import logging.handlers
import logging.config
import logging.

Output: <s>from __future__ import absolute_import
from __future__ import division
from __future__ import print_

Output: <s>while(1)
while(1) {
    // do something
}
\end{code}

The

Output: <s>try to find the best solution for your needs.
We are a team of 10 people, with a wide range of

Output: <s>if ( !( $post_ID = (int) $post_ID ) ) {
	return;
}


Output: <s>for the 2018-2019 school year.
The application process for the 2018

Output: <s>torchbearer 2017-07-12 17:27:27 UTC #

Output: <s>array = new Array(1000000000000000000000
