In [None]:
import sys

# Locations of the project checkout on the machines this notebook is run from.
# Every entry is pushed to the front of sys.path in turn, so the entry listed
# last ends up with the highest import priority.
_PROJECT_DIRS = (
    "/root/autodl-tmp/Code/RLHF",
    "/Users/zeyesun/Documents/Code/RLHF",
    "D:\\Code\\RLHF",
    "/mnt/sfevol775196/sunzeye273/Code/chatgpt",
    "/mnt/share-pa002-vol682688-prd/sunzeye273/Code/chatgpt",
    "/mnt/pa002-28359-vol543625-private/Code/chatgpt",
)
for _project_dir in _PROJECT_DIRS:
    sys.path.insert(0, _project_dir)

import os, time, re, random, glob, json, jieba, copy
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoModelForMultipleChoice,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    default_data_collator,
    TextGenerationPipeline
)

from src.models.reward import RewardModel

# Run on the first GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# `platform` stays bound as a notebook-level name.
from sys import platform

# Root of the data directory for each supported operating system.
if platform.startswith("linux"):
    # linux (also covers the legacy "linux2" value)
    root = "/mnt/sfevol775196/sunzeye273/Data"
#     root = "/mnt/share-pa002-vol682688-prd/sunzeye273/Data"
#     root = "/mnt/pa002-28359-vol543625-private/Data"
#     root = "/root/autodl-tmp/Data"
elif platform == "darwin":
    # OS X
    root = "/Users/zeyesun/Documents/Data"
elif platform == "win32":
    # Windows
    root = "D:\\Data"
else:
    # Fail here with a clear message instead of a NameError on `root`
    # in a later cell.
    raise RuntimeError(
        f"Unsupported platform {platform!r}: set `root` to the data directory manually"
    )

In [None]:
# Checkpoint to inspect; change the name to switch models.
# Other checkpoints used with this notebook: "pangu-350M", "glm-small", "chatglm-6B"
model_name = "pangu-small"
model_name_or_path = os.path.join(root, "models", model_name)

tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    use_cache=False,
    trust_remote_code=True,
)

# Dump the special-token configuration of the loaded tokenizer.
print(tokenizer.special_tokens_map)
print(tokenizer.all_special_ids)

# One "<name>: <id>" entry per special token; eop, sop and cls ids are not printed.
special_id_report = [
    f"{kind}: {getattr(tokenizer, kind + '_token_id')}\n"
    for kind in ("unk", "pad", "bos", "eos", "sep", "mask")
]
print(*special_id_report)

In [None]:
# --- Single example used by the cells below ---
prompt = "你好，你是谁？"
# Text appended (after the sep token) as the second segment of the model input.
prefix = "答:"
label = "我是ChatGPT"

# --- Token budget: total length, generated part, and what is left for the prompt ---
max_length, max_gen_length = 128, 16
max_prompt_length = max_length - max_gen_length

# Values > 0 switch on the LoRA conversion branch when the model is loaded.
lora_rank = 0

In [None]:
# Checkpoints whose path contains "glm" are loaded as seq2seq models,
# everything else as a causal LM.
uses_glm = "glm" in model_name_or_path
if not uses_glm:
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, use_cache=False)
    # Make the embedding table match the tokenizer's vocabulary size.
    model.resize_token_embeddings(tokenizer.vocab_size)
else:
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, trust_remote_code=True)
    if "chatglm" in model_name_or_path:
        # ChatGLM checkpoints are cast to fp16.
        model = model.half()

# NOTE(review): `convert_to_lora_recursively`, `lora`, `lora_alpha` and
# `lora_train_bias` are not defined in the cells visible here; this branch
# only runs when `lora_rank` > 0 -- confirm they are provided before enabling it.
if lora_rank > 0:
    convert_to_lora_recursively(model, lora_rank, lora_alpha)
    lora.mark_only_lora_as_trainable(model, lora_train_bias)

# Inference only: move to the selected device and switch to eval mode.
model = model.to(device)
model.eval()
print(model.device)

In [None]:
from src.data.data import SFTDataset
from torch.utils.data import RandomSampler, DataLoader
class dotdict(dict):
    """Dictionary whose keys can also be read, written and deleted as attributes.

    Reading a missing attribute returns ``None`` (same as ``dict.get``) rather
    than raising ``AttributeError``; deleting a missing one raises ``KeyError``.
    """

    def __getattr__(self, key):
        # Only reached when normal attribute lookup fails.
        return dict.get(self, key)

    def __setattr__(self, key, value):
        dict.__setitem__(self, key, value)

    def __delattr__(self, key):
        dict.__delitem__(self, key)

# Arguments object handed to SFTDataset; only these two fields are set here
# (any other attribute read from it yields None, see `dotdict`).
args = dotdict(model_name_or_path=model_name_or_path, max_length=max_length)

# Resolve the dataset below the per-platform data `root` instead of a
# hard-coded macOS home directory, so the cell also runs on the Linux and
# Windows setups. On macOS this is the same path as before.
sft_data_path = os.path.join(root, "chatgpt", "processed", "test_data_external_v1.jsonl")

train_dataset = SFTDataset(args, sft_data_path, tokenizer)
# Draw examples in random order, four per batch.
train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(
    train_dataset,
    sampler=train_sampler,
    batch_size=4,
)

In [None]:
# Pad on the left so that everything after `prompt_length` in the generated
# sequence is model output.
tokenizer.padding_side = "left"

# Encode "<prompt>" and "<sep><prefix>" as a pair, padded/truncated to the prompt budget.
encode_options = dict(
    max_length=max_prompt_length,
    padding="max_length",
    truncation="longest_first",
    add_special_tokens=False,
    return_tensors="pt",
    return_token_type_ids=False,
)
inputs = tokenizer(prompt, tokenizer.sep_token + prefix, **encode_options)

batch_size, prompt_length = inputs["input_ids"].shape

# Greedy decoding (do_sample=False) of at most `max_gen_length` new tokens.
generation_options = dict(
    max_new_tokens=max_gen_length,
    pad_token_id=tokenizer.pad_token_id,
    do_sample=False,
    num_return_sequences=1,
    top_p=0.9,
    temperature=1.0,
)
with torch.no_grad():
    # Move every encoded tensor to the model's device before generating.
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    seq = model.generate(**inputs, **generation_options)

print(f"seq: {seq}")
print(tokenizer.batch_decode(seq))
# Filter out sequences with no answer (or only a very short one). This happens when the pre-trained checkpoint is used directly, without supervised fine-tuning.
# NOTE: this causes each GPU to end up with a different number of examples.

# print(f"prompt_length: {prompt_length}")
# ans = seq[:, prompt_length:]
# print(f"ans: {ans}")
# valid_ans_len = (ans != tokenizer.pad_token_id).sum(dim=-1)
# print(f"valid_ans_len: {valid_ans_len}")
# out_seq = []
# for i in range(batch_size):
#     # if the answer is shorter than 1 token, drop it
#     if valid_ans_len[i] <= 1:
#         continue
#     else:
#         out_seq.append(seq[i:i + 1])
# out_seq = torch.cat(out_seq, dim=0) 
# print(f"out_seq: {out_seq}")

In [None]:
# Split every generated sequence into prompt and answer, strip the left
# padding from the prompt, and decode both back to text.
pad_id = tokenizer.pad_token_id
prompts = []
for i in range(batch_size):
    sequence = seq[i]
    prompt_ids = sequence[:prompt_length]
    # Position of the first non-pad token, i.e. where the real prompt starts.
    prompt_start_index = (prompt_ids != pad_id).nonzero()[0].item()
    prompt_ids = sequence[prompt_start_index:prompt_length]
    answer_ids = sequence[prompt_length:]
    # NOTE: this rebinds the notebook-level `prompt` to the decoded text.
    prompt = tokenizer.decode(prompt_ids, skip_special_tokens=False)
    answer = tokenizer.decode(answer_ids, skip_special_tokens=False)
    prompts.append(prompt + answer)
print(prompts)

# Re-encode the concatenated prompt+answer strings to the full `max_length`.
retokenize_options = {
    "max_length": max_length,
    "truncation": "longest_first",
    "padding": "max_length",
    "return_tensors": "pt",
    "return_token_type_ids": False,
}
outputs = tokenizer(prompts, **retokenize_options)
print(outputs)
print(tokenizer.batch_decode(outputs['input_ids']))

In [None]:
# Display the device the re-tokenized ids live on.
outputs["input_ids"].device

In [None]:
# Locate the first pad token in the first encoded input; fall back to the
# full sequence length when the sequence contains no padding at all.
chosen_id = inputs["input_ids"][0]
print(chosen_id)
seq_len = len(chosen_id)
c_inds = (chosen_id == tokenizer.pad_token_id).nonzero()
if len(c_inds) > 0:
    c_ind = c_inds[0].item()
else:
    c_ind = seq_len

In [None]:
# Tokenize the current `prompt` with the tokenizer's default settings.
prompt_token = tokenizer(prompt, return_tensors="pt")

# Number of leading tokens that exceed the `max_length - 1` budget
# (zero or negative means nothing has to be cut). `length` is computed first,
# so the cell no longer depends on a value left over from an earlier run.
length = prompt_token["input_ids"].size()[-1]
print(length - (max_length - 1))

# For ids and mask alike: drop the batch dimension, keep at most the trailing
# `max_length - 1` positions, and store the sequence reversed (the collate
# cell flips the padded batch back). The same limit, `max_length`, is used in
# both the condition and the slice; `max_seq_len` was never defined.
for key_word in ["input_ids", "attention_mask"]:
    length = prompt_token[key_word].size()[-1]
    if length > max_length:
        y = prompt_token[key_word].squeeze(0)[length - (max_length - 1):].flip(0)
    else:
        y = prompt_token[key_word].squeeze(0).flip(0)
    prompt_token[key_word] = y

In [None]:
# Start from an empty list so the cell works on a fresh kernel and re-running
# it does not keep appending duplicates of the same sample.
data = []
# Each sample is [reversed input ids, reversed attention mask, pad token id].
data.append([prompt_token["input_ids"], prompt_token["attention_mask"], tokenizer.pad_token_id])

In [None]:
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F

batch = {}

# Every sample stores the pad id as its last field; take it from the last sample.
pad_token_id = data[-1][-1]

# The samples are stored reversed, so padding at the END here becomes
# LEFT padding once the batch is flipped back below.
prompt = pad_sequence([f[0] for f in data],
                      padding_value=pad_token_id,
                      batch_first=True)
prompt_mask = pad_sequence([f[1] for f in data],
                           padding_value=0,
                           batch_first=True)

# Make sure the final output has a fixed length of `max_length`
# (`max_seq_len` was never defined in this notebook).
length = prompt.size()[-1]
pad_length = max_length - length
if pad_length > 0:
    # Pad on the right of the still-reversed batch, the same side
    # pad_sequence uses, so all padding ends up on the left after the flip
    # instead of being split across both sides of the shorter samples.
    batch["prompt"] = F.pad(prompt,
                            pad=(0, pad_length),
                            mode='constant',
                            value=pad_token_id)
    batch["prompt_att_mask"] = F.pad(prompt_mask,
                                     pad=(0, pad_length),
                                     mode='constant',
                                     value=0)
else:
    batch["prompt"] = prompt
    batch["prompt_att_mask"] = prompt_mask

# Undo the per-sample reversal: tokens are back in reading order.
batch["prompt"] = batch["prompt"].flip(1)
batch["prompt_att_mask"] = batch["prompt_att_mask"].flip(1)