In [None]:
import sys
# Project checkouts on the different machines this notebook is run from.
# Each entry is pushed to the front of sys.path in turn, so the LAST entry
# ends up with the highest import priority (same order as inserting one by one).
for code_dir in (
    "/root/autodl-tmp/Code/RLHF",
    "/Users/zeyesun/Documents/Code/RLHF",
    "D:\\Code\\RLHF",
    "/mnt/sfevol775196/sunzeye273/Code/chatgpt",
    "/mnt/share-pa002-vol682688-prd/sunzeye273/Code/chatgpt",
    "/mnt/pa002-28359-vol543625-private/Code/chatgpt",
):
    sys.path.insert(0, code_dir)

import os, time, re, random, glob, json, jieba, copy
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoModelForMultipleChoice,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    default_data_collator,
    TextGenerationPipeline
)

from src.models.reward import RewardModel

# Run on the first CUDA GPU when torch can see one, otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

from sys import platform

# Pick the data root for the machine the notebook is running on.
# `startswith("linux")` covers both "linux" and the legacy "linux2" value.
if platform.startswith("linux"):
    # linux
    root = "/mnt/sfevol775196/sunzeye273/Data"
#     root = "/mnt/share-pa002-vol682688-prd/sunzeye273/Data"
#     root = "/mnt/pa002-28359-vol543625-private/Data"
#     root = "/root/autodl-tmp/Data"
elif platform == "darwin":
    # OS X
    root = "/Users/zeyesun/Documents/Data"
elif platform == "win32":
    # Windows...
    root = "D:\\Data"
else:
    # Fail here with a clear message instead of a NameError on `root`
    # in a later cell.
    raise RuntimeError(
        f"Unsupported platform {platform!r}: no data root is configured for it"
    )

In [None]:
# Select the checkpoint to debug; the alternatives are kept for quick switching.
# model_name = "pangu-small"
# model_name = "pangu-350M"
# model_name = "glm-small"
model_name = "chatglm-6B"
model_name_or_path = os.path.join(root, "models", model_name)

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_cache=False, trust_remote_code=True)
print(tokenizer.special_tokens_map)
print(tokenizer.all_special_ids)

# Print one "<name>: <id>" line per special token.
# `getattr(..., None)` is used because not every tokenizer class necessarily
# defines every attribute (e.g. `eop_token_id`); a missing one prints as None
# instead of aborting the cell. "sop" stays excluded, as it was commented out.
special_token_names = ("unk", "pad", "bos", "eos", "sep", "mask", "eop", "cls")
print("\n".join(
    f"{name}: {getattr(tokenizer, f'{name}_token_id', None)}"
    for name in special_token_names
))

In [None]:
# LoRA rank used by this cell. It is defined here (0 = LoRA disabled) so the
# cell runs on a fresh kernel; previously `lora_rank` was only assigned in a
# later cell and this cell raised NameError under "Restart & Run All".
lora_rank = 0

# Checkpoints whose path contains "glm" are loaded through the seq2seq auto
# class; everything else is loaded as a causal LM.
if "glm" in model_name_or_path:
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, trust_remote_code=True)
    if "chatglm" in model_name_or_path:
        # ChatGLM weights are cast to fp16.
        model = model.half()
else:
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, use_cache=False)
    model.resize_token_embeddings(tokenizer.vocab_size)
if lora_rank > 0:
    # NOTE(review): `convert_to_lora_recursively`, `lora`, `lora_alpha` and
    # `lora_train_bias` are not imported/defined in the visible cells; they
    # must be brought into scope before setting lora_rank > 0.
    convert_to_lora_recursively(model, lora_rank, lora_alpha)
    lora.mark_only_lora_as_trainable(model, lora_train_bias)
model = model.to(device)
model.eval()
print(model.device)

## Dataset Debug

In [None]:
from src.data.data import SFTDataset
from torch.utils.data import RandomSampler, DataLoader
class dotdict(dict):
    """A dict whose entries can also be read, set and removed as attributes."""

    def __getattr__(self, name, default=None):
        # Delegates to dict.get: an absent key yields None instead of raising.
        return dict.get(self, name, default)

    def __setattr__(self, name, value):
        # Attribute assignment stores the value under the same-named key.
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        # Delegates to dict.__delitem__: an absent key raises KeyError.
        dict.__delitem__(self, name)

# Minimal stand-in for the training arguments passed to SFTDataset.
args = dotdict({
    "model_name_or_path": model_name_or_path,
    "max_length": 128,
})

# Build the data path from the per-platform `root` chosen in the first cell
# instead of a macOS-only absolute path (on macOS this resolves to the same
# file as before).
# NOTE(review): assumes the processed data lives under
# <root>/chatgpt/processed on every machine — TODO confirm.
sft_data_path = os.path.join(root, "chatgpt", "processed", "test_data_external_v1.jsonl")

train_dataset = SFTDataset(args, sft_data_path, tokenizer)
# Draw samples in random order, 4 per batch.
train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(
    train_dataset,
    sampler=train_sampler,
    batch_size=4,
)

## Generation Debug

In [None]:
# Demo texts (prompt: "Hello, who are you?", prefix: "Answer:", label: "I am ChatGPT").
prompt, prefix, label = "你好，你是谁？", "答:", "我是ChatGPT"

# Token budgets: the total length is split into a prompt part and a generated part.
max_length, max_gen_length = 32, 16
max_prompt_length = max_length - max_gen_length

# LoRA rank; 0 means the `lora_rank > 0` branch of the model-loading cell is skipped.
lora_rank = 0

In [None]:
# Pad on the left, to exactly `max_prompt_length` tokens.
tokenizer.padding_side = "left"

# Encoding options for the prompt. Variants tried earlier and left disabled:
# passing `label` or `tokenizer.sep_token + prefix` as a second text, and
# `add_special_tokens=False`.
encode_kwargs = dict(
    max_length=max_prompt_length,
    padding="max_length",
    truncation="longest_first",
    return_tensors="pt",
    return_token_type_ids=False,
)
inputs = tokenizer(prompt, **encode_kwargs)
print(inputs)

In [None]:
# The padded prompt width marks where generated tokens begin in `seq`.
batch_size, prompt_length = inputs["input_ids"].shape

# Decoding settings.
# NOTE(review): with do_sample=False, `top_p` / `temperature` presumably have
# no effect on the output — confirm against the transformers generation docs.
generation_kwargs = dict(
    max_new_tokens=max_gen_length,
    pad_token_id=tokenizer.pad_token_id,
    do_sample=False,
    num_return_sequences=1,
    top_p=0.9,
    temperature=1.0,
)

with torch.no_grad():
    # Move every encoded tensor onto the model's device before generating.
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    seq = model.generate(**inputs, **generation_kwargs)

print(f"seq: {seq}")
print(tokenizer.batch_decode(seq))

In [None]:
# Split each generated sequence back into its prompt and answer, decode both,
# and re-tokenize the concatenated text to the full `max_length`.
prompts = []
for i in range(batch_size):
    padded_prompt_ids = seq[i, :prompt_length]
    # Index of the first non-pad token in the prompt part (used to skip the
    # leading padding). If the prompt part is entirely padding, treat the
    # prompt as empty instead of failing on an empty index tensor.
    non_pad_positions = (padded_prompt_ids != tokenizer.pad_token_id).nonzero()
    if non_pad_positions.numel() > 0:
        prompt_start_index = non_pad_positions[0].item()
    else:
        prompt_start_index = prompt_length
    prompt_ids = seq[i, prompt_start_index:prompt_length]
    answer_ids = seq[i, prompt_length:]
    # Use dedicated names here: assigning to `prompt` would overwrite the
    # configured prompt string and change what the tokenization cell encodes
    # when it is re-run.
    prompt_text = tokenizer.decode(prompt_ids, skip_special_tokens=False)
    answer_text = tokenizer.decode(answer_ids, skip_special_tokens=False)
    prompts.append(prompt_text + answer_text)
print(prompts)
outputs = tokenizer(prompts, max_length=max_length,
                    truncation="longest_first", padding="max_length",
                    return_tensors="pt", return_token_type_ids=False)
print(outputs)
print(tokenizer.batch_decode(outputs['input_ids']))

In [None]:
# Show which device the re-tokenized `input_ids` tensor lives on
# (bare last expression, so it is displayed as the cell output).
outputs['input_ids'].device

## GLM attention mask and position ids Debug

In [None]:
# ChatGLM-style attention mask, rebuilt by hand for inspection.
input_ids = inputs["input_ids"]
batch_size, seq_length = input_ids.shape

# Per row: index of the first bos token (list.index raises ValueError if absent).
context_lengths = [row.tolist().index(tokenizer.bos_token_id) for row in input_ids]

# Start from an all-ones (batch, seq, seq) tensor ...
attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
print(attention_mask.shape)

# ... keep only the lower triangle ...
attention_mask = torch.tril(attention_mask)

# ... then re-open every column before the bos position for all query rows.
for i, context_length in enumerate(context_lengths):
    attention_mask[i, :, :context_length] = 1
print(attention_mask.shape)

# Insert a singleton dimension at position 1 -> (batch, 1, seq, seq).
attention_mask = attention_mask[:, None, :, :]
print(attention_mask.shape)
# Left disabled: conversion to a boolean mask via `(attention_mask < 0.5).bool()`.

In [None]:
# ChatGLM-style position ids, rebuilt by hand for inspection.
# Only the 2D position-encoding path is active here; the `gmask` handling and
# the non-2D alternative from the reference code are intentionally left out.
batch_size, seq_length = input_ids.shape

# Per row: index of the first bos token (list.index raises ValueError if absent).
context_lengths = [row.tolist().index(tokenizer.bos_token_id) for row in input_ids]

# Channel 0: plain positions 0 .. seq_length-1, shared by every row.
absolute_positions = torch.arange(seq_length, dtype=torch.long, device=device)
position_ids = absolute_positions.expand(batch_size, seq_length)

# Channel 1: zeros up to the bos position, then 1, 2, 3, ... for the remainder.
block_position_ids = torch.stack(
    [
        torch.cat(
            (
                torch.zeros((context_length,), dtype=torch.long, device=device),
                torch.arange(1, seq_length - context_length + 1, dtype=torch.long, device=device),
            )
        )
        for context_length in context_lengths
    ],
    dim=0,
)

# Combine both channels -> (batch, 2, seq).
position_ids = torch.stack((position_ids, block_position_ids), dim=1)