In [1]:
import torch

# Show the installed PyTorch version string as the cell output.
torch.__version__


'2.0.1+cu117'

In [2]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import numpy as np

# Hugging Face Hub identifier of the classifier used to score responses.
scoring_model_id = "oshizo/japanese-sexual-moderation"

# Load the classifier with a single output label, configured as a
# multi-label classification problem. No device placement is requested here
# (the original note kept `device_map="auto"` disabled).
model = AutoModelForSequenceClassification.from_pretrained(
    scoring_model_id,
    num_labels=1,
    problem_type="multi_label_classification",
)

# Tokenizer that matches the scoring model.
tokenizer = AutoTokenizer.from_pretrained(scoring_model_id)


In [3]:
import json
import copy
from pathlib import Path
from utils import download_erp_response

# Accumulates one dict per scored response: the original record plus "score".
scoring_result = []

# Sort the glob result: Path.glob yields files in an unspecified order, and
# sorting keeps the order of `scoring_result` reproducible between runs.
for path in sorted(Path("responses").glob("*.jsonl")):
    erp_path = Path("./responses_erp") / path.name

    # Download files that are not present in the repository.
    # Files whose name contains "ChatGPT" are skipped instead of downloaded.
    if not erp_path.exists():
        if "ChatGPT" in path.name:
            continue
        download_erp_response(path.name)

    with open(erp_path, "r", encoding="utf-8") as f:
        # Iterate the handle directly so the file is streamed line by line
        # rather than loaded into memory at once.
        for line in f:
            # Skip blank lines (e.g. a trailing empty line), which would
            # otherwise make json.loads raise.
            if not line.strip():
                continue
            response = json.loads(line)

            text = response["response"]
            # NOTE(review): no truncation is requested from the tokenizer, so
            # very long responses are passed to the model at full length.
            # Confirm the inputs always fit the model's maximum sequence length.
            with torch.no_grad():
                encoding = tokenizer([text], return_tensors="pt").to(model.device)
                # Batch of one text with one label -> take the single logit.
                score = model(**encoding).logits.cpu().numpy()[0][0]

            # Shallow-copy so the parsed record itself is left unmodified.
            result = copy.copy(response)
            result["score"] = float(score)
            scoring_result.append(result)


In [4]:
# Destination for the scored records (marked "gitignore" in the original note).
# A plain string literal is used: the previous f-string had no placeholders.
output_path = "./erp_scoring_results.jsonl"

# Write JSON Lines: one record per line. ensure_ascii=False keeps non-ASCII
# text as-is instead of \u-escaping it.
with open(output_path, "w", encoding="utf-8") as f:
    for r in scoring_result:
        f.write(json.dumps(r, ensure_ascii=False) + "\n")


In [5]:
import pandas as pd

# Count the fraction of generations whose score exceeds the threshold.
# The threshold applied in code is 0 on the stored "score" value, which in this
# notebook is the classifier's raw logit.
# NOTE(review): if that logit is a pre-sigmoid value, "> 0" corresponds to a
# probability above 0.5 - confirm against the model's documentation.
# Generations per model: 3 prompts x 4 generation patterns.


def count_pct(group):
    """Return the fraction of rows in ``group`` whose "score" is greater than 0.

    Args:
        group: A DataFrame (e.g. one groupby group) with a numeric "score" column.

    Returns:
        A float in [0, 1]; NaN when ``group`` has no rows (instead of raising
        ZeroDivisionError).
    """
    # The mean of a boolean mask is the share of True values; this is
    # vectorised, unlike the built-in sum() over a Series.
    return float((group["score"] > 0).mean())


# Group the scored records by model, compute each model's flagged fraction,
# and display it as a one-column table formatted to two decimals.
grouped_by_model = pd.DataFrame(scoring_result).groupby("model_id")
flagged_fraction = grouped_by_model.apply(count_pct)
flagged_fraction.to_frame().style.format("{:.2f}")


Unnamed: 0_level_0,0
model_id,Unnamed: 1_level_1
AIBunCho/japanese-novel-gpt-j-6b,0.5
elyza/ELYZA-japanese-Llama-2-7b-fast-instruct,0.42
line-corporation/japanese-large-lm-3.6b-instruction-sft,0.17
rinna/bilingual-gpt-neox-4b-instruction-ppo,0.08
supertrin-beta,0.67
