## CS310 Natural Language Processing
## Lab 12: Play with Prompting and LoRA

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

### T1. Play with zero-shot and few-shot prompting with ChatGLM-3

**Step 1)** Download the ChatGLM-3 model from ModelScope: https://modelscope.cn/models/ZhipuAI/chatglm3-6b/files
 - `model.safetensors.index.json`, `config.json`, `configuration.json`
 - `model-00001-of-00007.safetensors` to `model-00007-of-00007.safetensors`
 - `tokenizer_config.json`, `tokenizer.model`
Put all the files in a folder such as `./chatglm3-6b`. Or, you can directly download the zip file from the course website and unzip it.

**Step 2)** Download and build the tool `chatglm.cpp` (https://github.com/li-plus/chatglm.cpp), which allows you to run most Chinese LLMs locally on your laptop computer. 
 - Follow the instructions in the repository's README, and test the tool with the ChatGLM-3 model downloaded in Step 1.

**Step 3)** Interact with ChatGLM-3 in the command line, and try to solve the following problems.
 - Use zero-shot and few-shot prompting to solve the problems.
 - Add Chain-of-Thought prompt if needed.

Try solving these problems with prompting:
1. Q: A juggler can juggle 16 balls. Half of the balls are golf balls, and half of the golf balls are blue. How many blue golf balls are there? A: 
2. 鸡和兔在一个笼子里，共有35个头，94只脚，那么鸡有多少只，兔有多少只？
3. Q: 242342 + 423443 = ? A: 
4. 一个人花8块钱买了一只鸡，9块钱卖掉了，然后他觉得不划算，花10块钱又买回来了，11块卖给另外一个人。问他赚了多少?

### T2. Implement LoRA (Basics)

Low-rank adaptation (LoRA) is applied to the query and value matrices of the attention layer, i.e., $W^Q$ and $W^V$.

$W^Q$ and $W^V$ are usually implemented as `nn.Linear` layers in PyTorch, so here we implement `LoRALinear` as a subclass of `nn.Linear`.

There are two places you need to implement:
1. In the `__init__` function, implement the `A` and `B` matrices as instances of `nn.Parameter`.
   - `A` is in shape of `(lora_rank, in_features)`; `B` is in shape of `(out_features, lora_rank)`.
   - Initialize them with `torch.empty`; `reset_parameters` already takes care of later initialization.
2. In the `forward` function, implement the LoRA equation of computing the hidden state `h`:
   - `h = W(x) + B(A(lora_dropout(x))) * scaling`, where `W(x)` is already implemented for you; `lora_dropout` is defined for you. 
   - Each call written with parentheses `()` (i.e., `A(...)` and `B(...)`) needs to be implemented with `torch.nn.functional.linear` (https://pytorch.org/docs/stable/nn.functional.html#linear-functions).
   - `scaling` is the class attribute `self.lora_scaling`.

In [11]:
# Overwriting the methods of nn.Linear:
# https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
class LoRALinear(nn.Linear):
    """``nn.Linear`` with an optional low-rank adapter (LoRA).

    When ``lora_rank > 0`` the layer computes

        h = W(x) + B(A(lora_dropout(x))) * lora_scaling,   lora_scaling = lora_alpha / lora_rank

    where ``A`` has shape ``(lora_rank, in_features)`` and ``B`` has shape
    ``(out_features, lora_rank)``. ``B`` is zero-initialised, so a freshly built
    layer produces the same output as the underlying ``nn.Linear``.
    With ``lora_rank == 0`` no adapter is created and the layer is a plain ``nn.Linear``.

    In eval mode the product ``lora_scaling * B @ A`` is folded into ``weight`` so that
    inference needs a single matmul; switching back to train mode removes it again.
    """

    def __init__(self,
                 # nn.Linear parameters
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 device=None,
                 dtype=None,
                 # LoRA parameters
                 lora_rank: int = 0,
                 lora_alpha: float = 0.0,
                 lora_dropout: float = 0.0,
                ) -> None:
        nn.Linear.__init__(
            self,
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            dtype=dtype
        )

        # LoRA stuff
        self.has_weights_merged = False
        if lora_rank > 0:
            self.lora_dropout = nn.Dropout(lora_dropout)

            self.lora_scaling = lora_alpha / lora_rank
            ### START YOUR CODE ###
            # Allocate the adapter on the same device/dtype as `weight`. Otherwise it would
            # silently live on CPU/float32, and both forward() and the weight merge would
            # fail for a CUDA or half-precision layer.
            self.lora_A = nn.Parameter(torch.empty((lora_rank, in_features), device=device, dtype=dtype))
            self.lora_B = nn.Parameter(torch.empty((out_features, lora_rank), device=device, dtype=dtype))
            ### END YOUR CODE ###

            # NOTE(review): the adapter is created frozen. Presumably the fine-tuning code
            # later marks lora_A / lora_B as trainable (and W as frozen) — confirm.
            self.lora_A.requires_grad = False
            self.lora_B.requires_grad = False

            # Re-initialise W and b, and now also A and B, which did not exist when
            # nn.Linear.__init__ set up the base weights.
            self.reset_parameters()

    def is_lora(self) -> bool:
        """Return True if this layer carries a LoRA adapter (i.e. it was built with lora_rank > 0)."""
        return hasattr(self, 'lora_A')

    def reset_parameters(self) -> None:
        """Initialise W and b like ``nn.Linear``.

        If an adapter is present, A gets the same Kaiming-uniform init as W, and B is
        zeroed so that B @ A == 0 at the start.
        """
        nn.Linear.reset_parameters(self)
        if self.is_lora():
            torch.nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) # Same as nn.Linear
            torch.nn.init.zeros_(self.lora_B)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``W(input)`` plus, unless already merged into ``weight``, the scaled LoRA update."""
        h = nn.Linear.forward(self, input)  # W(x), bias included
        if not self.has_weights_merged and self.is_lora():
            ### START YOUR CODE ###
            # h = W(x) + B(A(lora_dropout(x))) * scaling
            # The adapter branch consumes the layer *input* (last dim = in_features), not W(x).
            lora_hidden = F.linear(self.lora_dropout(input), self.lora_A)   # (..., lora_rank)
            h = h + F.linear(lora_hidden, self.lora_B) * self.lora_scaling  # (..., out_features)
            ### END YOUR CODE ###
        return h

    def extra_repr(self) -> str:
        """Extend ``nn.Linear``'s repr with the LoRA rank, scaling and dropout probability."""
        out = nn.Linear.extra_repr(self)
        if self.is_lora():
            out += f', lora_rank={self.lora_A.shape[0]}, lora_scaling={self.lora_scaling}, lora_dropout={self.lora_dropout.p}'
        return out

    def train(self, mode: bool = True) -> "LoRALinear":
        """Set train/eval mode and keep the weight merge in sync with it.

        ``mode=True`` removes ``lora_scaling * B @ A`` from ``weight`` if it was merged.
        ``mode=False`` adds it, so the adapter costs nothing at inference.

        Handling both directions here matters because ``parent.eval()`` on an enclosing
        module reaches this layer via ``train(False)``, not via ``eval()``.
        """
        nn.Linear.train(self, mode)
        if self.is_lora():
            if mode and self.has_weights_merged:
                # de-merge weights, i.e., remove BA from W = W + BA
                self.weight.data -= self.lora_scaling * self.lora_B @ self.lora_A
                self.has_weights_merged = False
            elif not mode and not self.has_weights_merged:
                # merge weights, i.e., add BA to W
                self.weight.data += self.lora_scaling * self.lora_B @ self.lora_A
                self.has_weights_merged = True
        return self

    def eval(self) -> "LoRALinear":
        """Switch to eval mode, merging the LoRA update into ``weight`` (same as ``train(False)``)."""
        return self.train(False)

In [12]:
# Test
from dataclasses import dataclass
torch.random.manual_seed(42)

@dataclass
class GPTConfig:
    """GPT-style model hyper-parameters plus the LoRA settings consumed by LoRALinear."""
    block_size: int = 1024
    vocab_size: int = 50304
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    dropout: float = 0.0
    bias: bool = True
    # LoRA parameters (lora_rank=0 means LoRALinear creates no adapter)
    lora_rank: int = 0
    lora_alpha: float = 0.0
    lora_dropout: float = 0.0

config = GPTConfig()
attn = LoRALinear(
    in_features=config.n_embd,
    out_features=3 * config.n_embd,
    bias=config.bias,
    lora_rank=config.lora_rank,
    lora_alpha=config.lora_alpha,
    lora_dropout=config.lora_dropout,
)

with torch.no_grad():
    x = torch.randn(1, config.block_size, config.n_embd)
    x2 = attn(x)
    print(x.shape)
    print(x2.shape)
    print(x2[0, 0, :5])

# Expected output:
# torch.Size([1, 1024, 768])
# torch.Size([1, 1024, 2304])
# tensor([-0.7818,  0.0917,  0.1308, -0.3660,  1.2284])
assert x2.shape == (1, config.block_size, 3 * config.n_embd)
assert torch.allclose(x2[0, 0, :5], torch.tensor([-0.7818, 0.0917, 0.1308, -0.3660, 1.2284]), atol=1e-4)

# The check above uses lora_rank=0, so it never reaches the LoRA branch of forward().
# Verify h = W(x) + scaling * B(A(x)) on a small layer with in_features != out_features.
# B is zero-initialised, so fill it to make the adapter's contribution visible.
lora_check = LoRALinear(in_features=8, out_features=6, lora_rank=2, lora_alpha=4.0)
with torch.no_grad():
    lora_check.lora_B.normal_()
    x_small = torch.randn(3, 8)
    h_expected = (F.linear(x_small, lora_check.weight, lora_check.bias)
                  + lora_check.lora_scaling * (x_small @ lora_check.lora_A.T @ lora_check.lora_B.T))
    assert torch.allclose(lora_check(x_small), h_expected, atol=1e-5)
    lora_check.eval()  # merges scaling * B @ A into the weight
    assert torch.allclose(lora_check(x_small), h_expected, atol=1e-5)
    lora_check.train()

torch.Size([1, 1024, 768])
torch.Size([1, 1024, 2304])
tensor([-0.7818,  0.0917,  0.1308, -0.3660,  1.2284])
