<a href="https://colab.research.google.com/github/ymdysk/chatrwkv-notebook/blob/main/ChatRWKV.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Notebook for running ChatRWKV with Google Colab
## How it works

- Mount Google Drive and clone the ChatRWKV Git repository into the src folder.
- Go to the ChatRWKV/v2/ folder and download a model.
- Set the model-related variables in the form, substitute them into chat.py, and save the result as chat-notebook.py, which is then run to load the model.
- Set variables related to text generation and run the chat.

## Tips
- To use a GPU, open Google Colab's "Runtime" menu -> "Change runtime type", select "GPU" under "Hardware accelerator", and save the setting.
- If no GPU is available, set strategy = 'cpu fp32' and RWKV_CUDA_ON = 0. This works, but generation is slower.
- Choose the model with MODEL_URL: either a 1.5B-parameter model (about 3GB) or a 3B-parameter model (about 6GB), or enter any other model URL. A model with more parameters gives more intelligent responses, but it also needs more storage and a runtime with more VRAM and RAM.
- The free tier of Google Colaboratory appears to run the 1.5B and 3B models as is.
- Google's free storage is 15GB, so it fills up quickly if you keep models in Google Drive.
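The sizes quoted in the tips follow from a rule of thumb: the weights take roughly (number of parameters) × (bytes per weight). The sketch below is only that arithmetic; the helper name `estimate_weight_gb` and the bytes-per-dtype table are illustrative assumptions, and real memory use is higher because the estimate ignores activations, the recurrent state, and any tensors a strategy keeps in higher precision.

```python
# Rule of thumb: weight memory ~= number of parameters x bytes per weight.
# Hypothetical helper for illustration only; real usage is higher because this
# ignores activations, the recurrent state and tensors kept in higher precision.
BYTES_PER_WEIGHT = {
    "fp32": 4,    # e.g. strategy 'cpu fp32'
    "fp16": 2,    # e.g. strategy 'cuda fp16'
    "fp16i8": 1,  # e.g. strategy 'cuda fp16i8' (int8-quantized weights)
}


def estimate_weight_gb(n_params: float, dtype: str) -> float:
    """Approximate size of the weights in GB (1 GB = 10**9 bytes)."""
    return n_params * BYTES_PER_WEIGHT[dtype] / 1e9


for n_params, label in [(1.5e9, "1.5B"), (3e9, "3B")]:
    print(label, {dtype: estimate_weight_gb(n_params, dtype) for dtype in BYTES_PER_WEIGHT})
```

At 2 bytes per weight this gives about 3GB for the 1.5B model and 6GB for the 3B model; running on CPU with 'cpu fp32' roughly doubles the memory the weights need.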

## License
- This notebook  
  Copyright 2023 Yosuke Yamada  
  Licensed under the Apache License, Version 2.0  
  http://www.apache.org/licenses/LICENSE-2.0
- Please check the license of software and models used/downloaded from the notebook individually.



# Check the environment

In [None]:
# Check the status of the CUDA environment on NVIDIA's system management interface (for GPUs, not necessary if only CPU is used)
!nvidia-smi

Sat Apr 22 16:38:49 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   47C    P8    10W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
# Check the version of the CUDA compiler (for GPUs, not necessary if only CPU is used)
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Sep_21_10:33:58_PDT_2022
Cuda compilation tools, release 11.8, V11.8.89
Build cuda_11.8.r11.8/compiler.31833905_0


# Environment Settings

In [None]:
# The torch preinstalled in Google Colab may raise an error when RWKV_CUDA_ON = 1 (which speeds up inference), so reinstall torch (first time only).
!pip uninstall -y torch
!pip install torch==2.0.0+cu118 -f https://download.pytorch.org/whl/torch_stable.html

Found existing installation: torch 2.0.0+cu118
Uninstalling torch-2.0.0+cu118:
  Successfully uninstalled torch-2.0.0+cu118
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://download.pytorch.org/whl/torch_stable.html
Collecting torch==2.0.0+cu118
  Downloading https://download.pytorch.org/whl/cu118/torch-2.0.0%2Bcu118-cp39-cp39-linux_x86_64.whl (2267.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 GB[0m [31m772.7 kB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch
Successfully installed torch-2.0.0+cu118


In [None]:
# Check whether CUDA is available in PyTorch. If it is, this prints "True" and the device number. If it fails, run the cell above again. On a CPU-only runtime (no GPU), there is no need to run this cell.
import torch
print(torch.cuda.is_available())
print(torch.cuda.current_device())

True
0


In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')
# Create the source code folder if it does not exist, then cd into it.
import os
os.makedirs("/content/drive/My Drive/src", exist_ok=True)
%cd '/content/drive/My Drive/src'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/My Drive/src


In [None]:
# Get the ChatRWKV source code with git clone (first time only)
!git clone https://github.com/BlinkDL/ChatRWKV

Cloning into 'ChatRWKV'...
remote: Enumerating objects: 1340, done.[K
remote: Counting objects: 100% (216/216), done.[K
remote: Compressing objects: 100% (149/149), done.[K
remote: Total 1340 (delta 88), reused 162 (delta 52), pack-reused 1124[K
Receiving objects: 100% (1340/1340), 26.97 MiB | 17.77 MiB/s, done.
Resolving deltas: 100% (723/723), done.
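Because the notebook keeps the repository on Google Drive, running the clone cell a second time fails once the ChatRWKV folder exists (git refuses to clone into an existing, non-empty directory). A minimal sketch that skips the clone when the folder is already there; `clone_if_missing` is a hypothetical helper, not part of ChatRWKV.

```python
import os
import subprocess


def clone_if_missing(repo_url: str, dest: str) -> bool:
    """Clone repo_url into dest unless dest already exists.

    Returns True if a clone was performed, False if it was skipped.
    """
    if os.path.isdir(dest):
        return False
    subprocess.run(["git", "clone", repo_url, dest], check=True)
    return True


# In the notebook, from /content/drive/My Drive/src:
# clone_if_missing("https://github.com/BlinkDL/ChatRWKV", "ChatRWKV")
```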


In [None]:
# Move to ChatRWKV/v2 folder
%cd 'ChatRWKV/v2'

/content/drive/My Drive/src/ChatRWKV/v2


In [None]:
# Install the rwkv and ninja packages with pip
!pip install rwkv ninja

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


# Prepare Model

In [None]:
# Specify model and download
# Raven  https://huggingface.co/BlinkDL/rwkv-4-raven/tree/main
# Others https://huggingface.co/BlinkDL
MODEL_URL = 'https://huggingface.co/BlinkDL/rwkv-4-raven/resolve/main/RWKV-4-Raven-1B5-v10-Eng99%25-Other1%25-20230418-ctx4096.pth' #@param ['https://huggingface.co/BlinkDL/rwkv-4-raven/resolve/main/RWKV-4-Raven-1B5-v10-Eng99%25-Other1%25-20230418-ctx4096.pth', 'https://huggingface.co/BlinkDL/rwkv-4-raven/resolve/main/RWKV-4-Raven-3B-v8-EngAndMore-20230408-ctx4096.pth']  {allow-input: true}
!curl -OLC - $MODEL_URL
# MODEL_NAME is the file name after the last "/" in MODEL_URL, with the ".pth" extension removed
# (removesuffix removes exactly one ".pth"; it needs Python 3.9+)
MODEL_NAME = MODEL_URL[MODEL_URL.rfind('/') + 1:].removesuffix('.pth')

** Resuming transfer from byte position 3030279730
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100  1217  100  1217    0     0   4791      0 --:--:-- --:--:-- --:--:--  4791
100    49  100    49    0     0    173      0 --:--:-- --:--:-- --:--:--   173
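Stripping the extension with `str.rstrip('.pth')` is a common pitfall: `rstrip()` removes any trailing run of the characters `.`, `p`, `t` and `h`, not the literal suffix. It happens to work for the Raven file names above, but not in general. A small demonstration with a hypothetical URL; `str.removesuffix()` (Python 3.9+, and the runtime in this notebook is Python 3.9) removes exactly one `.pth`.

```python
# str.rstrip() strips trailing characters from a *set*, not a suffix.
url = "https://example.com/models/model-depth.pth"  # hypothetical URL
file_name = url[url.rfind("/") + 1:]

print(file_name.rstrip(".pth"))        # also eats the "pth" at the end of "depth"
print(file_name.removesuffix(".pth"))  # removes only the extension
```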


In [None]:
# Settings for RWKV model behavior
strategy = 'cuda fp16' #@param ['cpu fp32', 'cuda fp16', 'cuda:0 fp16 -> cuda:1 fp16', 'cuda fp16i8 *10 -> cuda fp16', 'cuda fp16i8', 'cuda fp16i8 -> cpu fp32 *10', 'cuda fp16i8 *10+'] {allow-input: true}
# 'cpu fp32' # If CUDA is not available and CPU is used
# 'cuda fp16' # Default value when CUDA is available
# 'cuda:0 fp16 -> cuda:1 fp16' # If two GPUs can be used
# 'cuda fp16i8 *10 -> cuda fp16' # first 10 layers cuda fp16i8 (int8 quantization), rest cuda fp16
# 'cuda fp16i8' # all layers cuda fp16i8 (int8 quantization)
# 'cuda fp16i8 -> cpu fp32 *10' # last 10 layers cpu fp32, the layers before them cuda fp16i8
# 'cuda fp16i8 *10+' # first 10 layers cuda fp16i8 (int8 quantization), the rest streamed to the GPU as needed

# 1 if CUDA is available, 0 if not
RWKV_CUDA_ON = 1 #@param [0, 1]

# Language used in chat
CHAT_LANG = 'English' #@param ["Japanese", "English", "Chinese"]
# English
# Chinese
# Japanese

# Context length of model
ctx_len = 1024 #@param {type:"integer"}

# Create a dictionary of regex patterns -> replacement lines for chat.py
# (raw strings, so regex escapes such as \. and \s reach re unchanged)
replacements = {
    r'args.strategy = .*' : f"args.strategy = '{strategy}'",
    r'os\.environ\["RWKV_CUDA_ON"\] = .*' : f"os.environ[\"RWKV_CUDA_ON\"] = '{RWKV_CUDA_ON}'",
    r'CHAT_LANG = .*' : f"CHAT_LANG = '{CHAT_LANG}'",
    r'args.MODEL_NAME = .*' : f"args.MODEL_NAME = '{MODEL_NAME}'",
    r'args.ctx_len = .*' : f'args.ctx_len = {ctx_len}',  # an integer, not a quoted string
    r'current_path = os\.path\.dirname\(os\.path\.abspath\(__file__\)\)' : 'current_path = os.getcwd()',
    # Remove chat.py's interactive input loop; the notebook runs its own loop in the Chat section
    r'while True:\s+msg = prompt.+\s+if len\(msg.+\s+on_message\(msg\)\s+else:\s+print\(.+' : '',
}
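To see how the `strategy` strings listed in the settings cell divide the model, here is a simplified sketch of the `->` / `*N` / `*N+` syntax. `plan_layers` is a hypothetical helper written for this notebook, not the rwkv package's parser, which may treat edge cases differently; the layer count of 25 is taken from the model loading log ("total 24+1=25 layers").

```python
def plan_layers(strategy: str, n_layers: int) -> list[tuple[str, str, int, int]]:
    """Return (device, dtype, resident_layers, streamed_layers) for each stage.

    Hypothetical helper: a simplified reading of the strategy syntax, not rwkv's parser.
    """
    stages = []
    for part in strategy.split("->"):
        fields = part.split()
        device, dtype = fields[0], fields[1]
        count, stream = None, False
        if len(fields) > 2:  # optional '*N' (fixed count) or '*N+' (count, then stream)
            spec = fields[2]
            stream = spec.endswith("+")
            count = int(spec[1:-1] if stream else spec[1:])
        stages.append((device, dtype, count, stream))

    fixed = sum(count for _, _, count, _ in stages if count is not None)
    if fixed > n_layers:
        raise ValueError("strategy assigns more layers than the model has")
    leftover = n_layers - fixed

    plan = [[device, dtype, count or 0, 0] for device, dtype, count, _ in stages]
    stream_idx = [i for i, (_, _, _, stream) in enumerate(stages) if stream]
    open_idx = [i for i, (_, _, count, _) in enumerate(stages) if count is None]
    if stream_idx:
        # '*N+': the remaining layers are streamed to that stage's device on demand
        plan[stream_idx[0]][3] = leftover
    elif open_idx:
        # Stages without '*N' share the leftover; the last of them takes any remainder
        share = leftover // len(open_idx)
        for i in open_idx:
            plan[i][2] = share
        plan[open_idx[-1]][2] += leftover - share * len(open_idx)
    else:
        plan[-1][2] += leftover  # this sketch's choice: unassigned layers go to the last stage
    return [tuple(p) for p in plan]


for s in ['cuda fp16', 'cuda fp16i8 *10 -> cuda fp16', 'cuda fp16i8 -> cpu fp32 *10', 'cuda fp16i8 *10+']:
    print(f'{s!r:34} {plan_layers(s, 25)}')
```

Reading the strategies this way makes the trade-off visible: every layer moved off `cuda fp16` (to int8, to the CPU, or to streaming) saves VRAM at the cost of speed or precision.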


In [None]:
# Replace variables in chat.py and save in chat-notebook.py
import re

with open('chat.py', encoding='utf-8') as f:
    text = f.read()

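for old, new in replacements.items():
    # ^(\s*) keeps each line's original indentation; MULTILINE makes ^ match at every line start
    pattern = re.compile(r'^(\s*)' + old, flags=re.MULTILINE)
    # A function replacement inserts `new` literally (no backslash or group-reference processing)
    text = pattern.sub(lambda m: m.group(1) + new, text)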

with open('chat-notebook.py', 'w', encoding='utf-8') as f:
    f.write(text)
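The substitution cell relies on two details: `re.MULTILINE` makes `^` match at the start of every line, and the captured `^(\s*)` group puts the original indentation back in front of the new text, so a line inside an `if` block stays inside it. A self-contained demonstration on a made-up snippet in the style of chat.py (the sample text and the name `demo_replacements` are assumptions, not the real file). Passing a function as the replacement, as done here, inserts `new` literally, whereas a replacement string would have its backslashes interpreted by `re.sub`.

```python
import re

# Hypothetical excerpt in the style of chat.py (not the real file).
sample = (
    "args.strategy = 'cpu fp32'\n"
    "if CHAT_LANG == 'English':\n"
    "    args.MODEL_NAME = '/path/to/old-model'\n"
)

demo_replacements = {
    r"args.strategy = .*": "args.strategy = 'cuda fp16'",
    r"args.MODEL_NAME = .*": "args.MODEL_NAME = 'new-model'",
}

for old, new in demo_replacements.items():
    # ^(\s*) captures the indentation so the replacement line keeps it
    pattern = re.compile(r"^(\s*)" + old, flags=re.MULTILINE)
    sample = pattern.sub(lambda m: m.group(1) + new, sample)

print(sample)
```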

In [None]:
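# Load the model: run chat-notebook.py in this notebook's global namespace, so later cells can use
# pipeline, on_message, user and interface (execfile() is not a built-in in Python 3)
with open("chat-notebook.py", encoding="utf-8") as f:
    exec(compile(f.read(), "chat-notebook.py", "exec"))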



ChatRWKV v2 https://github.com/BlinkDL/ChatRWKV

English - cuda fp16 - /content/drive/MyDrive/src/ChatRWKV/v2/prompt/default/English-2.py


Using /root/.cache/torch_extensions/py39_cu118 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /root/.cache/torch_extensions/py39_cu118/wkv_cuda/build.ninja...
Building extension module wkv_cuda...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
Loading extension module wkv_cuda...


Loading model - RWKV-4-Raven-1B5-v10-Eng99%25-Other1%25-20230418-ctx4096
RWKV_JIT_ON 1 RWKV_CUDA_ON 1 RESCALE_LAYER 6

Loading RWKV-4-Raven-1B5-v10-Eng99%25-Other1%25-20230418-ctx4096.pth ...
Strategy: (total 24+1=25 layers)
* cuda [float16, float16], store 25 layers
0-cuda-float16-float16 1-cuda-float16-float16 2-cuda-float16-float16 3-cuda-float16-float16 4-cuda-float16-float16 5-cuda-float16-float16 6-cuda-float16-float16 7-cuda-float16-float16 8-cuda-float16-float16 9-cuda-float16-float16 10-cuda-float16-float16 11-cuda-float16-float16 12-cuda-float16-float16 13-cuda-float16-float16 14-cuda-float16-float16 15-cuda-float16-float16 16-cuda-float16-float16 17-cuda-float16-float16 18-cuda-float16-float16 19-cuda-float16-float16 20-cuda-float16-float16 21-cuda-float16-float16 22-cuda-float16-float16 23-cuda-float16-float16 24-cuda-float16-float16 
emb.weight                        f16      cpu  50277  2048 
blocks.0.ln1.weight               f16   cuda:0   2048       
blocks.0.ln1.bias  

# Chat

In [None]:
# Settings for text generation

# Short response length for chat
CHAT_LEN_SHORT = 40 #@param {type:"integer"}
# Long response length for chat
CHAT_LEN_LONG = 150 #@param {type:"integer"}
# Length of free-form generated text
FREE_GEN_LEN = 256 #@param {type:"integer"}

# For better chat & QA quality: reduce temp, reduce top-p, increase repetition penalties
# Explanation: https://platform.openai.com/docs/api-reference/parameter-details

# GEN_TEMP and GEN_TOP_P: smaller values increase accuracy, larger values increase diversity
GEN_TEMP = 1.1 #@param {type:"number"} # sometimes it's a good idea to increase temp. try it
GEN_TOP_P = 0.7 #@param {type:"number"}
# GEN_alpha_presence, GEN_alpha_frequency: Penalty for presence and frequency of repeated strings. Larger values suppress repetition.
GEN_alpha_presence = 0.2 #@param {type:"number"} # Presence Penalty
GEN_alpha_frequency = 0.2 #@param {type:"number"} # Frequency Penalty
# AVOID_REPEAT: characters that must not be generated twice in a row (here, full-width punctuation)
AVOID_REPEAT = '，：？！' #@param {type:"string"}
# Chunk length to split input
CHUNK_LEN = 256 #@param {type:"integer"} # split input into chunks to save VRAM (shorter -> slower)

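# Each AVOID_REPEAT character must encode to exactly one token (pipeline comes from chat-notebook.py)
AVOID_REPEAT_TOKENS = []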
for i in AVOID_REPEAT:
    dd = pipeline.encode(i)
    assert len(dd) == 1
    AVOID_REPEAT_TOKENS += dd
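The comments in the cell above describe what each knob does. Below is a minimal, dependency-free sketch of one common way such settings are applied at each step: the presence/frequency penalty formula from the linked OpenAI parameter documentation (subtract alpha_presence once if a token has appeared, plus alpha_frequency per occurrence), a block on immediately repeating an AVOID_REPEAT token, then nucleus (top-p) sampling with temperature. It is an illustration under those assumptions, not ChatRWKV's or the rwkv package's actual sampler, which works on PyTorch tensors and may order the steps differently; `penalize` and `sample_top_p` are hypothetical names.

```python
import math
import random
from collections import Counter


def penalize(logits, generated, alpha_presence, alpha_frequency, avoid_repeat_tokens):
    """Return a copy of logits with repetition penalties applied.

    Every token that already appeared loses alpha_presence once plus
    alpha_frequency per occurrence; a token listed in avoid_repeat_tokens is
    blocked if it was the last token generated.
    """
    out = list(logits)
    for tok, count in Counter(generated).items():
        out[tok] -= alpha_presence + count * alpha_frequency
    if generated and generated[-1] in avoid_repeat_tokens:
        out[generated[-1]] = -math.inf
    return out


def sample_top_p(logits, temperature, top_p, rng):
    """Nucleus (top-p) sampling with temperature on a plain list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Keep the most probable tokens until their cumulative probability reaches top_p.
    order = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)
    kept, cumulative = [], 0.0
    for tok in order:
        kept.append(tok)
        cumulative += probs[tok]
        if cumulative >= top_p:
            break
    # p ** (1 / T): T > 1 flattens the kept distribution, T < 1 sharpens it.
    weights = [probs[tok] ** (1.0 / temperature) for tok in kept]
    return rng.choices(kept, weights=weights, k=1)[0]


# Toy 4-token vocabulary; token 2 plays the role of an AVOID_REPEAT character.
toy_logits = [2.0, 1.0, 0.5, -1.0]
toy_history = [0, 0, 2]
adjusted = penalize(toy_logits, toy_history, alpha_presence=0.2, alpha_frequency=0.2,
                    avoid_repeat_tokens=[2])
next_token = sample_top_p(adjusted, temperature=1.1, top_p=0.7, rng=random.Random(0))
print(adjusted, next_token)
```

Lowering `top_p` shrinks the kept set (more predictable output), while the penalties push down tokens the model has already used, which is why the cell's advice for chat and QA is to reduce temperature and top-p and raise the penalties.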


In [None]:
# Running a chat
while True:
    msg = input(f'{user}{interface} ')
    if len(msg.strip()) > 0:
        on_message(msg)
    else:
        print('Error: please say something')

Bob: Hi. Please introduce yourself.
Alice: Hello, my name is Alice and I am a strong believer in logic and rationality. I love the beauty of the world and all its complexity. I am interested in politics, economics, and science. I have an excellent understanding of philosophy and I am an active member of a few philosophical societies.
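The chat cell loops forever, so the only way out is to interrupt the cell. A minimal variant that also returns when a quit word is typed; `run_chat` and the `/quit` word are hypothetical choices for this sketch, so pick a word that does not collide with anything chat.py itself treats as a command. The `read` parameter only exists so the loop can be exercised without a keyboard.

```python
def run_chat(on_message, prompt, read=input, quit_words=("/quit",)):
    """Same loop as the chat cell, but returns when the user types a quit word."""
    while True:
        msg = read(prompt)
        if msg.strip() in quit_words:
            return
        if len(msg.strip()) > 0:
            on_message(msg)
        else:
            print('Error: please say something')


# In the notebook (user, interface and on_message come from chat-notebook.py):
# run_chat(on_message, f'{user}{interface} ')
```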

