<a href="https://colab.research.google.com/github/resloved/RWKV-notebooks/blob/master/RWKV_v3_RNN_Pile_Fine_Tuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# RWKV-v3-RNN-Pile Fine-Tuning

[RWKV](https://github.com/BlinkDL/RWKV-LM) is an RNN with transformer-level performance.


This notebook aims to streamline fine-tuning the RWKV-v3 Pile models using the code in the [RWKV-v2-RNN-Pile](https://github.com/BlinkDL/RWKV-v2-RNN-Pile) repository, as detailed [here](https://github.com/BlinkDL/RWKV-v2-RNN-Pile).


## Setup

In [None]:
#@title Google Drive Options { display-mode: "form" }
save_models_to_drive = True #@param {type:"boolean"}
drive_mount = '/content/drive' #@param {type:"string"}
output_dir = 'rwkv-v3-rnn-pile-tuning' #@param {type:"string"}
tuned_model_name = 'tuned' #@param {type:"string"}

import os
from google.colab import drive

# Drive is (re)mounted here regardless of `save_models_to_drive`.
drive.mount(drive_mount, force_remount=True)

# Checkpoints live on Drive when requested, otherwise on the runtime's local disk.
if save_models_to_drive:
    output_path = f"{drive_mount}/MyDrive/{output_dir}"
else:
    output_path = f"/content/{output_dir}"

# Create the folder for tuned checkpoints and the one for base models up front.
for sub_dir in (tuned_model_name, "base_models/"):
    os.makedirs(f"{output_path}/{sub_dir}", exist_ok=True)

print(f"Saving models to {output_path}")

In [None]:
# Show the GPU attached to this runtime and its memory; the fine-tuning options further down
# (ctx_len / batch_size) have to fit into that VRAM.
!nvidia-smi

In [None]:
# Make sure Drive is available when checkpoints are written to it.
# Uses the configurable `drive_mount` from the options cell instead of a hard-coded path,
# so both mount calls in this notebook always refer to the same mount point.
if save_models_to_drive:
    from google.colab import drive
    drive.mount(drive_mount)

In [None]:
from google.colab import output
# Turn on Colab's custom widget manager.
# NOTE(review): presumably required so the tqdm.auto progress bar used by the Trainer renders
# as a widget — confirm.
output.enable_custom_widget_manager()

In [None]:
# Fetch the upstream repository. Later cells read the tokenizer file and the CUDA kernel
# sources from `repo_dir` (the RWKV-v3 sub-directory of the clone).
!git clone https://github.com/blinkdl/RWKV-v2-RNN-Pile
repo_dir = "/content/RWKV-v2-RNN-Pile/RWKV-v3"

In [None]:
# Python dependencies: `transformers` provides the tokenizer class, `wandb` the optional run tracking.
# NOTE(review): `ninja` is presumably needed by torch.utils.cpp_extension.load to build the CUDA
# extension — confirm.
!pip install transformers wandb ninja

## Load Base Model




In [None]:
#@title Base Model Options
#@markdown Using any of the listed options will download the checkpoint from huggingface

base_model_name = "rwkv-3-pile-430m" #@param ["rwkv-3-pile-1b5", "rwkv-3-pile-430m", "rwkv-3-pile-169m"]

!git lfs clone https://huggingface.co/BlinkDL/$base_model_name

from glob import glob
base_model_path = glob(f"{base_model_name}/RWKV*.pth")[0]

print(f"Using {base_model_path} as base")

## Generate Training Data

In [None]:
#@title Training Data Options
#@markdown `input_file` should be the path to a single file that contains the text you want to fine-tune with.
#@markdown Either upload a file to this notebook instance or reference a file in your Google drive.

import numpy as np
from transformers import PreTrainedTokenizerFast

# Tokenizer definition shipped with the cloned repository.
tokenizer = PreTrainedTokenizerFast(tokenizer_file=f'{repo_dir}/20B_tokenizer.json')

input_file = "/content/drive/MyDrive/training.txt" #@param {type:"string"}
output_file = 'train.npy'

print(f'Tokenizing {input_file} (VERY slow. please wait)')

# Read the whole corpus into memory; the context manager closes the file handle again.
with open(input_file, encoding="utf-8") as input_fh:
    data_raw = input_fh.read()
print(f'Raw length = {len(data_raw)}')

data_code = tokenizer.encode(data_raw)
print(f'Tokenized length = {len(data_code)}')

# Token ids are stored as uint16 and written to `output_file` for the training cells.
out = np.array(data_code, dtype='uint16')
np.save(output_file, out, allow_pickle=False)

## Fine-tune

In [None]:
#@title Fine-tuning Options { display-mode: "form" }

#@markdown By default the fine tuning is handled in GPT mode as it trains much faster,
#@markdown however it uses much more VRAM.
#@markdown
#@markdown The suggested settings for training the 430M parameter model are normally a `ctx_len` of 768, a `batch_size` of 8, and `B_GROUP_FORWARD` being 8.
#@markdown However with the limited VRAM you get with a P100 I've found `336`/`4`/`4` doable.
#@markdown
#@markdown As always your mileage may vary, fiddle with the numbers yourself.
#@markdown
#@markdown ---
#@markdown

#@markdown Enable `use_wandb` if you want to track your training run via [Weights & Biases](https://wandb.ai)
use_wandb = True #@param {type:"boolean"}
#@markdown Epochs in this context are really "mini-epochs" that are quite short. They have a fixed length of `ctx_len * epoch_length_fixed` tokens
n_epoch = 100#@param {type:"integer"}
epoch_save_frequency = 4 #@param {type:"integer"}
#@markdown `ctx_len` must be a multiple of 4
ctx_len = 384 #@param {type:"integer"}
#@markdown If finetuning OOMs, consider lowering the batch size before lowering `ctx_len`
batch_size =  4#@param {type:"integer"}
#@markdown `batch_size` must be divisible by both `B_GROUP_FORWARD` and `B_GROUP_BACKWARD`
B_GROUP_FORWARD =  4#@param {type:"integer"}
B_GROUP_BACKWARD =  2#@param {type:"integer"}

# Fail fast on settings that the TimeX wrapper would otherwise only reject (via a bare
# assert) after the CUDA kernel has been compiled and the first batch is pushed through.
if ctx_len % 4 != 0:
    raise ValueError(f"ctx_len must be a multiple of 4, got {ctx_len}")
if batch_size % B_GROUP_FORWARD != 0 or batch_size % B_GROUP_BACKWARD != 0:
    raise ValueError(
        f"batch_size ({batch_size}) must be divisible by both B_GROUP_FORWARD "
        f"({B_GROUP_FORWARD}) and B_GROUP_BACKWARD ({B_GROUP_BACKWARD})")

import logging
import datetime
import torch
import numpy as np
import math
import json
import sys

import wandb

import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import Dataset
from torch.utils.cpp_extension import load
from torch.utils.data.dataloader import DataLoader
from torch.optim.lr_scheduler import LambdaLR

from tqdm.auto import tqdm

torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.cuda.empty_cache()

# Model dimensions. These defaults apply when the name matches neither branch below
# (of the listed options that is the 1b5 checkpoint).
n_layer = 24
n_embd = 2048

if "430" in base_model_name:
    n_layer = 24
    n_embd = 1024
elif "169" in base_model_name:
    n_layer = 12
    n_embd = 768

vocab_size = 50277

model_type = 'RWKV'
# Token file written by the "Generate Training Data" cell.
datafile = 'train.npy'

#@markdown If your training data uses something similar to what is already in [The Pile](https://pile.eleuther.ai/)
#@markdown consider setting `lr_init` to `1e-5` otherwise `2e-5`
lr_init = 1e-5#@param {type:"number"}
lr_final = 1e-5#@param {type:"number"}

# Compiled into the CUDA kernel as `Tmax`; TimeX asserts that sequences are no longer than this.
T_MAX = ctx_len
# Number of samples drawn per mini-epoch (the Dataset reports this as its length).
epoch_length_fixed = 10000

epoch_save_path = f"{output_path}/"

grad_norm_clip = 1.0
warmup_tokens = 0

# Adam hyper-parameters handed to the optimizer through TrainerConfig.
eps=1e-8

betas = (0.9, 0.999)

num_workers = 0

logger = logging.getLogger(__name__)

In [None]:
#@title RWKV_GPT
# Build and load the `timex` extension from the C++/CUDA sources in the cloned repository.
# T_MAX, B_GROUP_FORWARD and B_GROUP_BACKWARD are handed to the CUDA compiler as the
# Tmax / BF / BB preprocessor definitions, i.e. the kernel is compiled with whatever values
# these variables hold when this cell runs.
timex_cuda = load(name="timex", sources=[f"{repo_dir}/cuda/timex_op.cpp", f"{repo_dir}/cuda/timex_cuda.cu"],
                  verbose=True, extra_cuda_cflags=['--use_fast_math', '--extra-device-vectorization', f'-DTmax={T_MAX}', f'-DBF={B_GROUP_FORWARD}', f'-DBB={B_GROUP_BACKWARD}'])


class TimeX(torch.autograd.Function):
    """Autograd bridge to the compiled `timex_cuda` extension.

    `forward` fills a freshly allocated (B, C, T) CUDA tensor through `timex_cuda.forward`;
    `backward` obtains the gradients for `w` and `k` from `timex_cuda.backward`.
    Both directions assert that T is a multiple of 4 and at most T_MAX, and that B is
    divisible by B_GROUP_FORWARD and B_GROUP_BACKWARD.
    """

    @staticmethod
    def forward(ctx, w, k, B, C, T, eps):
        """Run the forward kernel on `w` and `k` and return its (B, C, T) output."""
        # keep the problem size around for the backward pass
        ctx.B, ctx.C, ctx.T = B, C, T
        assert ctx.T % 4 == 0 and ctx.T <= T_MAX and ctx.B % B_GROUP_FORWARD == 0 and ctx.B % B_GROUP_BACKWARD == 0

        # the kernel receives contiguous inputs; the same tensors are saved for backward
        w_contig, k_contig = w.contiguous(), k.contiguous()
        ctx.save_for_backward(w_contig, k_contig)

        result = torch.empty((B, C, T), device='cuda',
                             memory_format=torch.contiguous_format)
        timex_cuda.forward(w_contig, k_contig, result, eps, B, C, T)
        return result

    @staticmethod
    def backward(ctx, gwk):
        """Return gradients for (w, k, B, C, T, eps); only `w` and `k` receive one."""
        assert ctx.T % 4 == 0 and ctx.T <= T_MAX and ctx.B % B_GROUP_FORWARD == 0 and ctx.B % B_GROUP_BACKWARD == 0
        w_saved, k_saved = ctx.saved_tensors

        dims = (ctx.B, ctx.C, ctx.T)
        grad_w = torch.empty(dims, device='cuda',
                             memory_format=torch.contiguous_format)
        grad_k = torch.empty(dims, device='cuda',
                             memory_format=torch.contiguous_format)
        timex_cuda.backward(w_saved, k_saved, gwk.contiguous(), grad_w,
                            grad_k, ctx.B, ctx.C, ctx.T)

        # the per-sample gradient for w is reduced over dim 0 before it is returned
        return (grad_w.sum(dim=0), grad_k, None, None, None, None)

########################################################################################################
# RWKV: RWKV Time-mix + RWKV Channel-mix
########################################################################################################

# Upper bound applied to k in RWKV_TimeMix.forward before it is exponentiated.
RWKV_K_CLAMP = 60  # e^60 = 1e26
# eps value handed to TimeX when the denominator term (wk) is computed in RWKV_TimeMix.forward.
RWKV_K_EPS = 1e-9
# NOTE(review): not referenced anywhere in the visible notebook code — confirm whether it is still needed.
RWKV_HEAD_QK_DIM = 256

class RWKV_TimeMix(nn.Module):
    """RWKV time-mixing layer built on the TimeX CUDA kernel.

    Args:
        config: object providing `ctx_len` and `n_embd`.
        layer_id: index of the layer this module belongs to (stored only).

    `forward` expects `x` of shape (B, T, C) with T equal to `config.ctx_len` and
    returns a tensor produced by the `output` projection (C -> n_embd).
    """

    def __init__(self, config, layer_id):
        super().__init__()
        self.layer_id = layer_id
        self.ctx_len = config.ctx_len
        self.n_embd = config.n_embd

        attn_sz = config.n_embd

        self.time_decay = nn.Parameter(torch.ones(attn_sz, 1))
        # Integer offsets -(ctx_len - 2) ... 0, shape (1, ctx_len - 1). This is a plain
        # tensor (not a parameter or buffer), so it is placed on the GPU explicitly.
        self.time_curve = torch.tensor(
            [-(config.ctx_len - 2 - i) for i in range(config.ctx_len-1)]).unsqueeze(0)
        self.time_curve = self.time_curve.to('cuda')
        self.time_first = nn.Parameter(torch.ones(attn_sz, 1) * math.log(0.3))
        # Shifts the sequence one step along dim 1: pads a zero row in front, drops the last.
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        # Initial mixing weights: 0 for the first half of the channels, 1 for the rest.
        half_shift = torch.ones(1, 1, config.n_embd)
        half_shift[0, 0, :config.n_embd // 2] = 0
        # nn.Parameter shares storage with the tensor it wraps, so every parameter needs
        # its own copy. Built from one tensor, k/v/r would alias the same memory and
        # loading a checkpoint (or an optimizer step) on one would overwrite the others.
        self.time_mix_k = nn.Parameter(half_shift.clone())
        self.time_mix_v = nn.Parameter(half_shift.clone())
        self.time_mix_r = nn.Parameter(half_shift.clone())

        self.key = nn.Linear(config.n_embd, attn_sz, bias=False)
        self.value = nn.Linear(config.n_embd, attn_sz, bias=False)
        self.receptance = nn.Linear(config.n_embd, attn_sz, bias=False)

        self.output = nn.Linear(attn_sz, config.n_embd, bias=False)

        # NOTE(review): `scale_init` is not read anywhere in the visible code; presumably it
        # is consumed by the (commented-out) RWKV_Init routine — confirm.
        self.key.scale_init = 0
        self.receptance.scale_init = 0
        self.output.scale_init = 0

    def forward(self, x):
        B, T, C = x.size()
        assert T == self.ctx_len

        # Blend every position with its predecessor, with separate per-channel weights
        # for the key, value and receptance inputs.
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        # k and v are transposed to (B, C, T), the layout handed to TimeX.
        k = self.key(xk).transpose(-1, -2)
        v = self.value(xv).transpose(-1, -2)
        r = self.receptance(xr)

        # RWKV_K_CLAMP can be removed if the CUDA kernel subtracts the correct k_max for each k (I will do this later)
        k = torch.clamp(k, max=RWKV_K_CLAMP)
        k = torch.exp(k)
        kv = k * v

        # (attn_sz, ctx_len): scaled time_curve columns followed by the time_first column.
        self.time_w = torch.cat(
            [torch.exp(self.time_decay) * self.time_curve, self.time_first], dim=-1)
        w = torch.exp(self.time_w)

        wkv = TimeX.apply(w, kv, B, C, T, 0)
        # RWKV_K_EPS can be removed if the CUDA kernel sets 0/0 = 0 (I will do this later)
        wk = TimeX.apply(w, k, B, C, T, RWKV_K_EPS)

        # Normalise, move back to (B, T, C) and gate with the sigmoid of the receptance.
        rwkv = torch.sigmoid(r) * (wkv / wk).transpose(-1, -2)
        rwkv = self.output(rwkv)
        return rwkv


class RWKV_ChannelMix(nn.Module):
    """RWKV channel-mixing layer: a gated feed-forward block with a one-step time shift.

    Args:
        config: object providing `n_embd`.
        layer_id: index of the layer this module belongs to (stored only).

    `forward` returns sigmoid(receptance(xr)) * value(relu(key(xk)) ** 2), where xk and xr
    are per-channel blends of the input and the input shifted by one step along dim 1.
    """

    def __init__(self, config, layer_id):
        super().__init__()
        self.layer_id = layer_id

        # Shifts the sequence one step along dim 1: pads a zero row in front, drops the last.
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        # init to "shift half of the channels": 0 for the first half, 1 for the rest
        half_shift = torch.ones(1, 1, config.n_embd)
        half_shift[0, 0, :config.n_embd // 2] = 0
        # nn.Parameter shares storage with the tensor it wraps, so each parameter gets its
        # own copy. Otherwise time_mix_k and time_mix_r would alias the same memory and
        # loading or updating one would silently overwrite the other.
        self.time_mix_k = nn.Parameter(half_shift.clone())
        self.time_mix_r = nn.Parameter(half_shift.clone())

        hidden_sz = 4 * config.n_embd
        self.key = nn.Linear(config.n_embd, hidden_sz, bias=False)
        self.receptance = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.value = nn.Linear(hidden_sz, config.n_embd, bias=False)

        # NOTE(review): `scale_init` is not read anywhere in the visible code; presumably it
        # is consumed by the (commented-out) RWKV_Init routine — confirm.
        self.value.scale_init = 0
        self.receptance.scale_init = 0

    def forward(self, x):
        # Blend every position with its predecessor, separately for key and receptance inputs.
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
        xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

        # squared-ReLU feed-forward path
        k = self.key(xk)
        k = torch.square(torch.relu(k))
        kv = self.value(k)

        # gate the feed-forward output with the sigmoid of the receptance
        rkv = torch.sigmoid(self.receptance(xr)) * kv
        return rkv

########################################################################################################
# The GPT Model with our blocks
########################################################################################################


class GPTConfig:
    """Plain attribute bag describing the model.

    `vocab_size` and `ctx_len` are mandatory; every additional keyword argument
    is stored as an attribute of the same name.
    """

    def __init__(self, vocab_size, ctx_len, **kwargs):
        self.vocab_size = vocab_size
        self.ctx_len = ctx_len
        # remaining options become attributes as-is
        self.__dict__.update(kwargs)


class Block(nn.Module):
    """One residual layer: LayerNorm + time-mix, followed by LayerNorm + channel-mix.

    The layer with `layer_id == 0` additionally applies its own LayerNorm (`ln0`) to the
    input before anything else.
    """

    def __init__(self, config, layer_id):
        super().__init__()
        self.config = config
        self.layer_id = layer_id

        width = config.n_embd
        self.ln1 = nn.LayerNorm(width)
        self.ln2 = nn.LayerNorm(width)
        # only the very first layer owns the extra input normalisation
        if layer_id == 0:
            self.ln0 = nn.LayerNorm(width)

        self.att = RWKV_TimeMix(config, layer_id)
        self.ffn = RWKV_ChannelMix(config, layer_id)

    def forward(self, x):
        hidden = self.ln0(x) if self.layer_id == 0 else x
        # residual connection around each of the two sub-layers
        hidden = hidden + self.att(self.ln1(hidden))
        return hidden + self.ffn(self.ln2(hidden))


class GPT(nn.Module):
    """Embedding -> stack of `Block`s -> LayerNorm -> linear head over the vocabulary.

    `config` must provide `vocab_size`, `n_embd`, `n_layer` and `ctx_len`.
    """

    def __init__(self, config):
        super().__init__()
        self.step = 0
        self.config = config

        self.emb = nn.Embedding(config.vocab_size, config.n_embd)

        layers = [Block(config, layer_id) for layer_id in range(config.n_layer)]
        self.blocks = nn.Sequential(*layers)

        self.ln_out = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.ctx_len = config.ctx_len

        # RWKV_Init(self, config)

        n_params = sum(p.numel() for p in self.parameters())
        logger.info("number of parameters: %e", n_params)

    def get_ctx_len(self):
        """Return the context length the model was configured with."""
        return self.ctx_len

    def _init_weights(self, module):
        """Re-initialise a Linear or Embedding module in place (other modules are untouched)."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.01)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=1e-5)

    def configure_optimizers(self, train_config):
        """Build an Adam optimizer over all parameters, with weight decay disabled.

        `train_config` supplies `learning_rate`, `betas` and `eps`.
        """
        # Every parameter ends up in the no-decay set; the decay set stays empty.
        decay = set()
        no_decay = {
            '%s.%s' % (mod_name, p_name) if mod_name else p_name  # full param name
            for mod_name, mod in self.named_modules()
            for p_name, _ in mod.named_parameters()
        }

        # sanity checks: the two sets are disjoint and together cover every parameter
        param_dict = dict(self.named_parameters())
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(
            inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
            % (str(param_dict.keys() - union_params), )

        # a single group, parameters ordered by name
        ordered_params = [param_dict[name] for name in sorted(no_decay)]
        optim_groups = [{"params": ordered_params, "weight_decay": 0.0}]

        return torch.optim.Adam(
            optim_groups, lr=train_config.learning_rate, betas=train_config.betas, eps=train_config.eps)

    def forward(self, idx, targets=None):
        """Return `(logits, loss)`; `loss` is None unless `targets` is given.

        `idx` must be two-dimensional with a second dimension of at most `ctx_len`.
        """
        self.step += 1
        _, seq_len = idx.size()
        assert seq_len <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."

        logits = self.head(self.ln_out(self.blocks(self.emb(idx))))

        if targets is None:
            return logits, None
        # cross-entropy over all positions, flattened
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

In [None]:
#@title Trainer { display-mode: "form" }

# Append-mode log that Trainer.train() writes one line to after every mini-epoch.
# NOTE(review): the handle is never closed in the visible code; it stays open for the
# lifetime of the kernel.
log_file = open("mylog.txt", "a")

class TrainerConfig:
    """Training hyper-parameters.

    The class attributes are the defaults; any keyword passed to the constructor is
    stored on the instance and thereby overrides (or adds to) them.
    """

    # --- optimisation ---
    max_epochs = 10
    batch_size = 64
    learning_rate = 4e-4
    betas = (0.9, 0.99)
    eps = 1e-8
    grad_norm_clip = 1.0

    # --- learning-rate schedule (the Trainer applies linear warmup, then exponential decay) ---
    lr_decay = True
    warmup_tokens = 0
    final_tokens = 0

    # --- checkpointing ---
    epoch_save_frequency = 0
    epoch_save_path = 'trained-'

    # --- data loading ---
    num_workers = 0  # for DataLoader

    def __init__(self, **kwargs):
        # overrides are written into the instance dictionary
        self.__dict__.update(kwargs)


class Trainer:
    """Runs the fine-tuning loop: mini-epochs, learning-rate schedule, logging, checkpoints.

    Args:
        model: the model to train; must expose `config` and `configure_optimizers`
            (directly or through a `.module` wrapper attribute).
        train_dataset: dataset yielding `(x, y)` pairs.
        test_dataset: optional evaluation dataset (not used by `train`).
        config: a TrainerConfig-like object.

    Relies on the notebook globals `use_wandb`, `log_file`, `output_path`,
    `tuned_model_name` and `base_model_name`.
    """

    def __init__(self, model, train_dataset, test_dataset, config):
        self.model = model
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.config = config
        self.avg_loss = -1  # negative means "no loss recorded yet"
        self.steps = 0

        if use_wandb and 'wandb' in sys.modules:
            cfg = model.config
            for k in config.__dict__:
                setattr(cfg, k, config.__dict__[k])  # combine cfg
            wandb.init(project="RWKV-LM", name=f"{tuned_model_name}-{self.get_run_name()}", config=cfg, save_code=False)

        self.device = 'cpu'
        if torch.cuda.is_available():  # take over whatever gpus are on the system
            self.device = torch.cuda.current_device()

    def get_run_name(self):
        """Return '<vocab_size>-<ctx_len>-<n_layer>-<n_embd>' for the wrapped model."""
        raw_model = self.model.module if hasattr(
            self.model, "module") else self.model
        cfg = raw_model.config
        run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd)
        return run_name

    def train(self):
        """Train for `config.max_epochs` mini-epochs, logging and checkpointing along the way."""
        model, config = self.model, self.config
        raw_model = model.module if hasattr(self.model, "module") else model
        optimizer = raw_model.configure_optimizers(config)

        def run_epoch(split):
            is_train = split == 'train'
            model.train(is_train)
            data = self.train_dataset if is_train else self.test_dataset

            # memory pinning is only requested when worker processes are used
            loader = DataLoader(data, shuffle=False,
                                pin_memory=config.num_workers > 0,
                                batch_size=config.batch_size,
                                num_workers=config.num_workers)

            pbar = tqdm(enumerate(loader), total=len(
                loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader)

            for it, (x, y) in pbar:
                x = x.to(self.device)  # place data on the correct device
                y = y.to(self.device)

                with torch.set_grad_enabled(is_train):
                    _, loss = model(x, y)  # forward the model

                if is_train:  # backprop and update the parameters
                    model.zero_grad()
                    loss.backward()

                    if config.grad_norm_clip > 0:
                        torch.nn.utils.clip_grad_norm_(
                            model.parameters(), config.grad_norm_clip)

                    optimizer.step()

                    # Schedule progress shown in the progress bar; it stays 0 when lr_decay is
                    # disabled (previously the name was undefined in that case).
                    progress = 0
                    if config.lr_decay:  # decay the learning rate based on our progress
                        # number of tokens processed this step (i.e. label is not -100)
                        self.tokens += (y >= 0).sum()
                        lr_final_factor = config.lr_final / config.learning_rate
                        if self.tokens < config.warmup_tokens:
                            # linear warmup
                            lr_mult = lr_final_factor + \
                                (1 - lr_final_factor) * float(self.tokens) / \
                                float(config.warmup_tokens)
                        else:
                            # exponential learning rate decay
                            progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
                            if progress >= 1:
                                lr_mult = lr_final_factor
                            else:
                                lr_mult = math.exp(math.log(lr_final_factor) * progress)
                        lr = config.learning_rate * lr_mult
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                    else:
                        lr = config.learning_rate

                    now_loss = loss.item()  # report progress
                    self.lr = lr

                    if use_wandb and 'wandb' in sys.modules:
                        # the x-axis is the number of samples processed so far
                        wandb.log({"loss": now_loss},
                                  step=self.steps * self.config.batch_size)
                    self.steps += 1

                    # running average of the loss within the current mini-epoch
                    if self.avg_loss < 0:
                        self.avg_loss = now_loss
                    else:
                        factor = 1 / (it + 1)
                        self.avg_loss = self.avg_loss * \
                            (1.0 - factor) + now_loss * factor
                    pbar.set_description(
                        f"mini-epoch {epoch+1} prog {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}")

        self.tokens = 0  # counter used for learning rate decay
        for epoch in range(config.max_epochs):

            run_epoch('train')

            log_file.write(
                f'{epoch+1} {self.avg_loss:.6f} {math.exp(self.avg_loss):.4f} {self.lr:.8f} {datetime.datetime.now()} \n')
            log_file.flush()

            if (self.config.epoch_save_frequency > 0 and epoch % self.config.epoch_save_frequency == 0) or (epoch == config.max_epochs - 1):
                # DataParallel wrappers keep raw model object in .module
                raw_model = self.model.module if hasattr(
                    self.model, "module") else self.model
                # use this instance's run name rather than the global `trainer` variable
                torch.save(raw_model.state_dict(), f"{output_path}/{tuned_model_name}/{tuned_model_name}-{base_model_name}-{self.get_run_name()}-{epoch}.pth")

In [None]:
#@title Dataset
from torch.utils.data import Dataset

class Dataset(Dataset):
    """Random-window sampler over a 1-D token array.

    Args:
        data: sequence of token ids; must hold at least `ctx_len + 1` tokens.
        vocab_size: vocabulary size (stored and printed, not derived from `data`).
        ctx_len: number of tokens per training sample.
        epoch_length_fixed: number of samples reported as the dataset length.

    Raises:
        ValueError: if `data` is too short to cut a single window from.
    """

    def __init__(self, data, vocab_size, ctx_len, epoch_length_fixed):
        data_size = len(data)
        if data_size < ctx_len + 1:
            raise ValueError(
                f"training data has {data_size} tokens but at least ctx_len + 1 = {ctx_len + 1} are required")
        print('data has %d tokens, %d unique.' % (data_size, vocab_size))
        self.ctx_len = ctx_len
        self.epoch_length_fixed = epoch_length_fixed
        self.vocab_size = vocab_size
        self.data = data
        # Same device selection rule as the Trainer: the GPU when there is one, otherwise the CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def __len__(self):
        # fixed mini-epoch length, independent of the amount of data
        return self.epoch_length_fixed

    def __getitem__(self, idx):
        """Return `(x, y)`: a random window of `ctx_len` tokens and the same window shifted by one.

        `idx` is ignored; every call samples a fresh position.
        """
        # cheat: pick a random spot in dataset
        # randint's upper bound is exclusive, so the last valid start (len - ctx_len - 1)
        # is included and the final token of the data can be sampled as a target.
        i = np.random.randint(0, len(self.data) - self.ctx_len)
        dix = self.data[i:i+self.ctx_len+1]
        x = torch.tensor(dix[:-1], dtype=torch.long, device=self.device)
        y = torch.tensor(dix[1:], dtype=torch.long, device=self.device)
        return x, y

# Load the token file written by the tokenization cell, convert it to integer dtype and wrap it
# in the random-window Dataset defined above.
print('loading data... ' + datafile)
train_dataset = Dataset(np.load(datafile).astype('int'), vocab_size, ctx_len, epoch_length_fixed)

In [None]:
#@title Start fine-tuning { display-mode: "form" }
np.set_printoptions(precision=4, suppress=True, linewidth=200)
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
                    datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO,)

# Build the model with the dimensions chosen in the options cell and move it to the GPU.
model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=model_type,
                      n_layer=n_layer, n_embd=n_embd)).cuda()

print('loading ' + base_model_path)
# Load the checkpoint into host memory: `m2` stays alive as a notebook global, and without
# map_location a checkpoint saved from GPU tensors would be restored onto the GPU, holding a
# second copy of the weights in VRAM. load_state_dict copies the values into the GPU model.
# NOTE(review): torch.load unpickles the file - only load checkpoints from sources you trust.
m2 = torch.load(base_model_path, map_location='cpu')
model.load_state_dict(m2)

print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas',
      betas, 'eps', eps, 'ctx', ctx_len, 'layer', n_layer, 'embd', n_embd, )
# final_tokens is the total number of tokens seen over all mini-epochs; it drives the LR decay.
tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size,
                      learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps, grad_norm_clip=grad_norm_clip,
                      warmup_tokens=warmup_tokens, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=num_workers, epoch_save_frequency=epoch_save_frequency, epoch_save_path=epoch_save_path)
trainer = Trainer(model, train_dataset, None, tconf)

trainer.train()

# Save the final weights under a name without the epoch suffix.
tuned_model_path = f"{output_path}/{tuned_model_name}/{tuned_model_name}-{base_model_name}-{trainer.get_run_name()}.pth"
torch.save(model.state_dict(), tuned_model_path)