In [19]:
train_datasources = (
        # './datasets/stack_exchange_preferences/train.pkl',  # this demands more GPU RAM because we have more than 2 responses per each sample
        './datasets/hh-rlhf/train.pkl',
    )

In [20]:
from torch.utils.data import DataLoader
import torch
from typing import Tuple


In [21]:
def custom_collate_fn(batch, pad_id: int, max_seq_len: int, full_pad: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
    assert len(batch) == 1, 'This script only support one item at a time to validate RM model'
    tokens_list = batch[0]

    # Note we assume the tokens are for the ordered completions for a given prompt,
    # where the best completion is the first, and worst completion is the last.

    max_batch_seq_len = max([len(tokens) for tokens in tokens_list])
    assert max_batch_seq_len <= max_seq_len

    if full_pad:
        max_batch_seq_len = max_seq_len

    # concatenate prompt, completion together
    batch_size = len(tokens_list)

    batch_sequences = torch.full((batch_size, max_batch_seq_len), pad_id, dtype=torch.long)

    # record the terminal index of the completion, often referred to as the terminal time step in RL
    terminal_steps = torch.zeros((batch_size), dtype=torch.long)
    for i, tokens in enumerate(tokens_list):
        seq = torch.tensor(tokens, dtype=torch.long)
        seq_len = len(seq)

        batch_sequences[i, :seq_len] = seq
        terminal_steps[i] = seq_len - 1  # minus 1 because indexing starts from zero

    return batch_sequences, terminal_steps

In [22]:
import functools
from instruct_llama.tokenizer import Tokenizer
from instruct_llama.configs.rm_lora import config as cfg
from instruct_llama.utils.custom_dataset import ComparisonsDataset

In [23]:
tokenizer = Tokenizer("/media/ivirse/Data1/Project_ISOFH/Materials/NLP/llama/tokenizer.model")

_collate_fn = functools.partial(
        custom_collate_fn,
        pad_id=tokenizer.eos_id,
        max_seq_len=cfg.max_seq_len,
        full_pad=cfg.full_pad,
    )

In [24]:
train_dataset = ComparisonsDataset(data_sources=train_datasources, max_seq_len=cfg.max_seq_len)

In [54]:
class Args:
    dim = 4096
    n_layers: int = 32
    n_heads: int = 32
    vocab_size: int = 32000
    

args = Args()

In [42]:
cuda_kwargs = {
        'num_workers': 0,
        'batch_size': 1,  # always work on one sample at a time
        'pin_memory': True,
        'shuffle': True,
        'sampler': None,
    }


In [43]:
train_loader = DataLoader(dataset=train_dataset, collate_fn=_collate_fn, **cuda_kwargs)

In [46]:
import itertools

In [None]:
"""flow
tokens -> embedding -> attention -> ff
[2, tokensize]
"""


In [50]:
embedding = torch.nn.Embedding(32000, 4096)

In [None]:
## Output transformer layer
dim = 4096
hidden_dim = 4 * dim


In [48]:
scalar_head = torch.nn.Linear(4096, 1, bias=True)

In [55]:
head_dim = args.dim // args.n_heads
n_heads = args.n_heads

In [56]:
wq = torch.nn.Linear(args.dim, args.n_heads * head_dim, bias=False)
wk = torch.nn.Linear(args.dim, n_heads * head_dim, bias=False)
wv = torch.nn.Linear(args.dim, n_heads * head_dim, bias=False)
wo = torch.nn.Linear(args.n_heads * head_dim, args.dim, bias=False)

In [61]:
import math
import torch.nn.functional as F

In [69]:
for batch_tokens, terminal_steps in itertools.islice(train_loader, 64):
    x = embedding(batch_tokens)
    
    # Attention
    bsz, seqlen, _ = x.shape
    xq, xk, xv = wq(x), wk(x), wv(x)

    xq = xq.view(bsz, seqlen, n_heads, head_dim)
    xk = xk.view(bsz, seqlen, n_heads, head_dim)
    xv = xv.view(bsz, seqlen, n_heads, head_dim)
    
    keys = xk
    values = xv
    
    xq = xq.transpose(1, 2)  # (bs, n_heads, seqlen, head_dim)
    keys = keys.transpose(1, 2)
    values = values.transpose(1, 2)
    
    scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)
    
    scores = F.softmax(scores.float(), dim=-1).type_as(xq)
    
    output = torch.matmul(scores, values)  # (bs, n_heads, seqlen, head_dim)
    
    output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
    output = wo(output) # bs, seqlen, dim
    
    # End attention
    
    # Head
    output = scalar_head(output).float() # bs, seqlen, 1
    outputs = output.squeeze(-1) # [num_combinations, = 2 seq_length]
    
    rewards = torch.gather(outputs, dim=1, index=terminal_steps.unsqueeze(-1)).squeeze(1)  # [num_combinations]
    
    print(rewards)

tensor([-0.0288, -0.0127], grad_fn=<SqueezeBackward1>)
tensor([-0.0090, -0.0032], grad_fn=<SqueezeBackward1>)
tensor([-0.0376, -0.0077], grad_fn=<SqueezeBackward1>)
tensor([-0.0314, -0.0275], grad_fn=<SqueezeBackward1>)
tensor([-0.0619, -0.0093], grad_fn=<SqueezeBackward1>)
tensor([-0.0421, -0.0415], grad_fn=<SqueezeBackward1>)


KeyboardInterrupt: 