<a href="https://colab.research.google.com/github/vivekdurai/bluepencil-beta/blob/main/RWKV_v4_RNN_Pile_Fine_Tuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# RWKV-v4-RNN-Pile Fine-Tuning

[RWKV](https://github.com/BlinkDL/RWKV-LM) is an RNN with transformer-level performance.


This notebook streamlines fine-tuning RWKV-v4 Pile models on your own text.


## Setup

In [1]:
#@title Google Drive Options { display-mode: "form" }
save_models_to_drive = True #@param {type:"boolean"}
drive_mount = '/content/drive' #@param {type:"string"}
output_dir = 'rwkv-v4-rnn-pile-tuning' #@param {type:"string"}
tuned_model_name = 'tuned' #@param {type:"string"}

import os

if save_models_to_drive:
    # google.colab is only needed to mount Drive, so import it on this path only
    # (it was previously also imported unconditionally, which was redundant).
    from google.colab import drive
    drive.mount(drive_mount)

# Tuned checkpoints go under MyDrive when Drive is mounted, otherwise under /content.
output_path = f"{drive_mount}/MyDrive/{output_dir}" if save_models_to_drive else f"/content/{output_dir}"
os.makedirs(f"{output_path}/{tuned_model_name}", exist_ok=True)
os.makedirs(f"{output_path}/base_models/", exist_ok=True)

print(f"Saving models to {output_path}")

Mounted at /content/drive
Saving models to /content/drive/MyDrive/rwkv-v4-rnn-pile-tuning


In [2]:
# Show which GPU this runtime was assigned and its memory usage.
!nvidia-smi

Sun May 14 14:59:22 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-SXM...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   31C    P0    41W / 400W |      0MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [3]:
import os

repo_root = "/content/RWKV-LM"
repo_dir = f"{repo_root}/RWKV-v4"

# Clone to an absolute path and skip if it already exists. This cell changes the
# working directory below, so re-running a relative `git clone` would nest a second
# copy of the repo inside RWKV-v4.
# NOTE(review): upstream is not pinned to a commit; the train.py/src/model.py line
# patching in the Training section depends on those files' current layout.
if not os.path.isdir(repo_root):
    !git clone https://github.com/blinkdl/RWKV-LM $repo_root
# `!` does not raise on failure, so check for the script the later cells patch and run.
if not os.path.isfile(f"{repo_dir}/train.py"):
    raise FileNotFoundError(f"{repo_dir}/train.py not found; did the clone fail?")
%cd $repo_dir

Cloning into 'RWKV-LM'...
remote: Enumerating objects: 1776, done.[K
remote: Counting objects: 100% (330/330), done.[K
remote: Compressing objects: 100% (174/174), done.[K
remote: Total 1776 (delta 212), reused 255 (delta 154), pack-reused 1446[K
Receiving objects: 100% (1776/1776), 10.75 MiB | 8.95 MiB/s, done.
Resolving deltas: 100% (1111/1111), done.
/content/RWKV-LM/RWKV-v4


In [4]:
# Pinned to the versions resolved when this notebook was last run, so the upstream
# training script keeps running against the same library APIs. ninja is only a build
# tool for the CUDA extension and is left unpinned.
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install transformers==4.29.1 pytorch-lightning==1.9.0 deepspeed==0.9.2 wandb==0.15.2 ninja

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.29.1-py3-none-any.whl (7.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.1/7.1 MB[0m [31m82.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting pytorch-lightning==1.9
  Downloading pytorch_lightning-1.9.0-py3-none-any.whl (825 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m825.8/825.8 kB[0m [31m64.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting deepspeed
  Downloading deepspeed-0.9.2.tar.gz (779 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m779.3/779.3 kB[0m [31m68.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting wandb
  Downloading wandb-0.15.2-py3-none-any.whl (2.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m101.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting ninja
  Downl

## Load Base Model




In [5]:
#@title Base Model Options
#@markdown Using any of the listed options will download the checkpoint from Hugging Face

import os
from glob import glob

base_model_name = "RWKV-4-Pile-169M" #@param ["RWKV-4-Pile-1B5", "RWKV-4-Pile-430M", "RWKV-4-Pile-169M"]
base_model_url = f"https://huggingface.co/BlinkDL/{base_model_name.lower()}"
# Same directory name `git clone` derives from the URL; relative to the cwd (repo_dir).
base_model_dir = base_model_name.lower()

# This may take a while. `git lfs clone` is deprecated: after `git lfs install`, plain
# `git clone` fetches the LFS weights just as fast. Skip if the checkout already exists
# so the cell can be re-run.
if not os.path.isdir(base_model_dir):
    !git lfs install
    !git clone $base_model_url $base_model_dir

# sorted() makes the choice deterministic if several checkpoints match.
checkpoints = sorted(glob(f"{base_model_dir}/{base_model_name}*.pth"))
if not checkpoints:
    raise FileNotFoundError(f"No {base_model_name}*.pth checkpoint found in {base_model_dir}/")
base_model_path = checkpoints[0]

# Git LFS pointer files are always < 1024 bytes. A file that small means the weights
# were never downloaded.
if os.path.getsize(base_model_path) < 1024:
    raise RuntimeError(f"{base_model_path} looks like a Git LFS pointer, not model weights; "
                       f"run `git lfs pull` inside {base_model_dir}/")

print(f"Using {base_model_path} as base")

          with new flags from 'git clone'

'git clone' has been updated in upstream Git to have comparable
speeds to 'git lfs clone'.
Cloning into 'rwkv-4-pile-169m'...
remote: Enumerating objects: 59, done.[K
remote: Counting objects: 100% (34/34), done.[K
remote: Compressing objects: 100% (23/23), done.[K
remote: Total 59 (delta 20), reused 20 (delta 11), pack-reused 25[K
Unpacking objects: 100% (59/59), 6.94 KiB | 646.00 KiB/s, done.
Using rwkv-4-pile-169m/RWKV-4-Pile-169M-20220807-8023.pth as base


## Generate Training Data

In [6]:
#@title Training Data Options
#@markdown `input_file` should be the path to a single file that contains the text you want to fine-tune with.
#@markdown Either upload a file to this notebook instance or reference a file in your Google Drive.

import numpy as np
from transformers import PreTrainedTokenizerFast

# The tokenizer definition is read from the cloned RWKV-LM repo (repo_dir is set in the clone cell).
tokenizer = PreTrainedTokenizerFast(tokenizer_file=f'{repo_dir}/20B_tokenizer.json')

input_file = "/content/sample.txt" #@param {type:"string"}
# Relative to the working directory (repo_dir), which is where train.py loads it from.
output_file = 'train.npy'

print(f'Tokenizing {input_file} (VERY slow. please wait)')

# `with` closes the file handle instead of leaking it for the kernel's lifetime.
with open(input_file, encoding="utf-8") as input_fh:
    data_raw = input_fh.read()
print(f'Raw length = {len(data_raw)}')

data_code = tokenizer.encode(data_raw)
print(f'Tokenized length = {len(data_code)}')

if not data_code:
    raise ValueError(f'{input_file} produced no tokens; nothing to train on')
# Token ids are stored as uint16. Refuse ids that would not fit rather than letting them wrap.
if max(data_code) > np.iinfo(np.uint16).max:
    raise ValueError(f'token id {max(data_code)} does not fit in uint16')

out = np.array(data_code, dtype='uint16')
np.save(output_file, out, allow_pickle=False)

Tokenizing /content/sample.txt (VERY slow. please wait)
Raw length = 71693
Tokenized length = 17253


## Training

In [7]:
#@title Training Options { display-mode: "form" }
from shutil import copy
import os

def training_options():
    """Collect the Colab-form values that are patched into train.py by name.

    Returns ``locals()``, so every local assigned here becomes an assignment
    rewritten in train.py by ``replace_lines``. Do not add helper locals to this
    function. Reads the notebook globals ``base_model_path``,
    ``base_model_name``, ``output_path`` and ``tuned_model_name`` set by
    earlier cells.
    """
    EXPRESS_PILE_MODE = True
    # Strip only the final extension. split(".")[0] cut the path at the first dot
    # anywhere, e.g. in a directory or model name containing ".".
    EXPRESS_PILE_MODEL_NAME = os.path.splitext(base_model_path)[0]
    EXPRESS_PILE_MODEL_TYPE = base_model_name
    n_epoch = 100 #@param {type:"integer"}
    epoch_save_frequency = 25 #@param {type:"integer"}
    batch_size = 11 #@param {type:"integer"}
    ctx_len = 384 #@param {type:"integer"}
    epoch_save_path = f"{output_path}/{tuned_model_name}"
    return locals()

def model_options():
    """Colab-form values to patch into src/model.py, keyed by assignment name."""
    T_MAX = 384 #@param {type:"integer"}
    return dict(T_MAX=T_MAX)

def env_vars():
    """Colab-form environment settings, keyed by their ``os.environ[...]`` target.

    Each local below becomes a key of the form ``os.environ['NAME']`` so that
    ``replace_lines`` can find the matching assignment line in train.py.
    """
    RWKV_FLOAT_MODE = 'fp16' #@param ['fp16', 'bf16', 'bf32'] {type:"string"}
    RWKV_DEEPSPEED = '0' #@param ['0', '1'] {type:"string"}
    env = dict(locals())
    return {"os.environ['%s']" % name: env[name] for name in env}

def replace_lines(file_name, to_replace):
    """Rewrite ``name = value`` assignment lines in a Python source file in place.

    A line is rewritten when it contains `` =`` and the text before the first
    `` =`` (whitespace-stripped) is a key of ``to_replace``. It becomes
    ``<original left-hand side> = <repr(value)>``, which keeps the original
    indentation and drops any trailing comment. All other lines are copied
    unchanged.

    Args:
        file_name: path of the file to patch.
        to_replace: mapping from assignment target (e.g. ``"ctx_len"`` or
            ``"os.environ['RWKV_FLOAT_MODE']"``) to its new value.

    Returns:
        The set of keys that matched at least one line. Keys that matched no
        line are also reported with a warning, because that option was not
        applied (e.g. the upstream file changed).
    """
    import warnings

    with open(file_name, 'r', encoding='utf-8') as src:
        lines = src.readlines()

    replaced = set()
    tmp_name = f'{file_name}.tmp'
    with open(tmp_name, 'w', encoding='utf-8') as dst:
        for line in lines:
            lhs, sep, _ = line.partition(" =")
            target = lhs.strip()
            # Require the " =" separator so a bare line equal to a key is not mangled.
            if sep and target in to_replace:
                # repr() yields a valid Python literal for str/int/float/bool and
                # escapes quotes/backslashes, which f'"{value}"' did not.
                dst.write(f'{lhs} = {to_replace[target]!r}\n')
                replaced.add(target)
            else:
                dst.write(line)
    # Single rename instead of copy-then-delete of the temporary file.
    os.replace(tmp_name, file_name)

    missing = set(to_replace) - replaced
    if missing:
        warnings.warn(f"{file_name}: no assignment found for {sorted(missing)}; these options were not applied")
    return replaced

# Patch the chosen options into the upstream scripts before running train.py.
values = training_options()
values.update(env_vars())
model_values = model_options()
# NOTE(review): upstream RWKV-v4 src/model.py appears to build its WKV CUDA kernel for
# sequences of at most T_MAX tokens and asserts T <= T_MAX in the forward pass.
# Confirm this if the repo changes. Failing here surfaces a ctx_len/T_MAX mismatch
# before any file is rewritten, not after the extension build at the start of training.
if model_values['T_MAX'] < values['ctx_len']:
    raise ValueError(f"T_MAX ({model_values['T_MAX']}) must be >= ctx_len ({values['ctx_len']})")
replace_lines('train.py', values)
replace_lines('src/model.py', model_values)

In [None]:
# Run the upstream training script, as patched by the previous cell, from the working
# directory repo_dir. It loads train.npy from there.
!python train.py 

loading numpy data... train.npy
current vocab size = 50277 (make sure it's correct)
data has 17253 tokens.
2023-05-14 15:03:22 - INFO - torch.distributed.nn.jit.instantiator - Created a temporary directory at /tmp/tmpm_bw_37c
2023-05-14 15:03:22 - INFO - torch.distributed.nn.jit.instantiator - Writing /tmp/tmpm_bw_37c/_remote_module_non_scriptable.py
2023-05-14 15:03:22 - INFO - numexpr.utils - Note: NumExpr detected 12 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
2023-05-14 15:03:22 - INFO - numexpr.utils - NumExpr defaulting to 8 threads.

RWKV_HEAD_QK_DIM 0

Using /root/.cache/torch_extensions/py310_cu118 as PyTorch extensions root...
Creating extension directory /root/.cache/torch_extensions/py310_cu118/wkv...
Detected CUDA files, patching ldflags
Emitting ninja build file /root/.cache/torch_extensions/py310_cu118/wkv/build.ninja...
Building extension module wkv...
Allowing ninja to set a default number of workers... (overridable by setting the environment