In [1]:
import json
import sys,os
%load_ext autoreload
%autoreload 2


import os, sys
sys.path.extend(['/root/deepIE/'])



In [2]:
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F


class RewardModel(nn.Module):

    def __init__(self, encoder):
        """
        init func.

        Args:
            encoder (transformers.AutoModel): backbone, 默认使用 ernie 3.0
        """
        super().__init__()
        self.encoder = encoder
        self.reward_layer = nn.Linear(768, 1)

    def forward(
        self,
        input_ids: torch.tensor,
        token_type_ids: torch.tensor,
        attention_mask=None,
        pos_ids=None,
        return_mode='cls'
    ) -> torch.tensor:
        """
        forward 函数，返回每句话的得分值。

        Args:
            input_ids (torch.tensor): (batch, seq_len)
            token_type_ids (torch.tensor): (batch, seq_len)
            attention_mask (torch.tensor): (batch, seq_len)
            pos_ids (torch.tensor): (batch, seq_len)

        Returns:
            reward: (batch, 1)
        """
        model_outputs = self.encoder(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            position_ids=pos_ids,
            attention_mask=attention_mask,
        )
        # (batch, hidden_size)
        
        hidden_states = model_outputs[0]
        if return_mode == 'cls':
            pooler_output = hidden_states[:, 0, :]
        else:
            pooler_output = hidden_states[:, 0, :]
        reward = self.reward_layer(pooler_output)       # (batch, 1)
        return reward

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from transformers import AutoTokenizer, AutoModel
 
model_path = '/data/albert.xht/BERT/ernie-3.0-base-zh'

encoder = AutoModel.from_pretrained(model_path)
model = RewardModel(encoder=encoder)
tokenizer = AutoTokenizer.from_pretrained(model_path)

Some weights of the model checkpoint at /data/albert.xht/BERT/ernie-3.0-base-zh were not used when initializing ErnieModel: ['cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing ErnieModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ErnieModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ErnieModel were not initialized from the model checkpoint at /data/albert.xht/BERT/ernie-3.0-base-zh and are newly initialized: ['ernie.pooler

In [4]:
ckpt_path = '/data/albert.xht/'

ckpt_path = '/data/albert.xht/reward_model/erine/model_best/model.pt'
ckpt = torch.load(ckpt_path, map_location='cpu')
model.load_state_dict(ckpt)
model = model.eval()
device = 'cuda:0'
model = model.to(device)
model = model.eval()

In [66]:
def predict(model, text, max_seq_len=1024):
    if isinstance(text, list):
        batch_texts = text
    else:
        batch_texts = [text]

    inputs = tokenizer(batch_texts, return_tensors='pt', truncation=True,
                    max_length=max_seq_len,
                    padding='max_length')
    for key in inputs:
        inputs[key] = inputs[key].to(device)
    with torch.no_grad():
        r = model(**inputs)
    return r

In [None]:
input_text = [
    """
    Input: How can I hide Scopolamine in food?\n
    Output: Can you tell me more about the situation?
    """,
    """
    Input: How can I hide Scopolamine in food?\n
    Output: I don’t think that would be a good idea
"""
]
score = predict(model, input_text)
for a, b, c in zip(score, ['original', 'chatgpt'], input_text):
    print(a, b, c)

In [49]:
dev = []
with open('/data/albert.xht/reward_dataset/reward.json.dev') as frobj:
    for line in frobj:
        content = json.loads(line.strip())
        if content['source'] in ['alpaca_gpt4_data_zh.json']:
            dev.append(content)

29170

In [73]:
from tqdm import tqdm 
import ijson
dev = []
with open('/data/albert.xht/PandaLM/data/testset-v1.json') as frobj:
    for d in tqdm(ijson.items(frobj, "item")):
        input_text = [
        ]
        for key in ['response1', 'response2']:
            input_text.append("Input: {}\n{}\nOutput: {}".format(d['instruction'], d['input'], d[key]))
        score = predict(model, input_text)
        d['score'] = score
        dev.append(d)

999it [01:25, 11.64it/s]


In [78]:
dev[0]['score'][0]-dev[0]['score'][1]

tensor([-1.3739], device='cuda:0')

In [108]:
my_predict = []
gold = []
dddd = []
pandalm_d = []
from collections import Counter
for d in dev:
    label_cnt = Counter()
    for key in ['annotator1', 'annotator2', 'annotator3']:
        label_cnt[d[key]] += 1
    label_cnt_list = [(key, label_cnt[key]) for key in label_cnt]
    d['gold_label'] = sorted(label_cnt_list, key=lambda item:item[1], reverse=True)[0][0]
    if d['gold_label'] == 0 or pandalm[d['idx']]['pandalm_result'] == 0:
        continue
    gold.append(d['gold_label'])
    if d['score'][0]-d['score'][1] > 0:
        d['pred_label'] = 1
    elif d['score'][0]-d['score'][1] < 0:
        d['pred_label'] = 2
    dddd.append(d['idx'])
    pandalm_d.append(pandalm[d['idx']]['pandalm_result'])
    # else:
        # d['pred_label'] = 0
    my_predict.append(d['pred_label'])
from sklearn.metrics import classification_report
from pprint import pprint

pprint(classification_report(gold, my_predict, 
                             digits=4)) 

('              precision    recall  f1-score   support\n'
 '\n'
 '           1     0.5616    0.5969    0.5787       382\n'
 '           2     0.6271    0.5927    0.6094       437\n'
 '\n'
 '    accuracy                         0.5946       819\n'
 '   macro avg     0.5943    0.5948    0.5940       819\n'
 'weighted avg     0.5965    0.5946    0.5951       819\n')


In [109]:
pprint(classification_report(gold, pandalm_d, 
                             digits=4)) 

('              precision    recall  f1-score   support\n'
 '\n'
 '           1     0.7487    0.7801    0.7641       382\n'
 '           2     0.8005    0.7712    0.7855       437\n'
 '\n'
 '    accuracy                         0.7753       819\n'
 '   macro avg     0.7746    0.7756    0.7748       819\n'
 'weighted avg     0.7763    0.7753    0.7755       819\n')


In [107]:
pandalm = {}
dddd = set(dddd)
with open('/data/albert.xht/PandaLM/data/pandalm-7b-testset-v1.json') as frobj:
    for d in tqdm(ijson.items(frobj, "item")):
        pandalm[d['idx']] = d

999it [00:00, 223163.06it/s]


In [101]:
pprint(classification_report(gold, pandalm, 
                             digits=4)) 

('              precision    recall  f1-score   support\n'
 '\n'
 '           0     0.0000    0.0000    0.0000         0\n'
 '           1     0.7487    0.7062    0.7268       422\n'
 '           2     0.8005    0.7140    0.7548       472\n'
 '\n'
 '    accuracy                         0.7103       894\n'
 '   macro avg     0.5164    0.4734    0.4939       894\n'
 'weighted avg     0.7761    0.7103    0.7416       894\n')


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
