In [None]:
import json
import os

import gokart
import luigi
import pandas as pd
from openai import OpenAI
from tqdm import tqdm
from qrelllm.llm.prompt import gen_rel_prompt
from qrelllm.eval import CohenKappa
from qrelllm.queries import LoadQueries
from qrelllm.format import clean_json
from qrelllm.llm.vertex import TestCollection
from qrelllm.load import LoadTestCollection

In [None]:


class RelDecision(gokart.TaskOnKart):
    """
    Query an OpenAI chat model once per query of a test collection and
    collect the JSON records it returns into a single data frame.

    The original (Japanese) description read: "Task that generates article
    titles that are relevant and not relevant to the given query".
    NOTE(review): the code below does not generate titles; it sends the
    existing titles of each query to the model via ``gen_rel_prompt`` --
    presumably to obtain relevance judgements, going by the class and prompt
    names. Confirm the intended wording.

    Parameters:
        testcollection: upstream task instance; presumably the source of the
            data frame loaded in ``run`` (it is the only task parameter
            declared here) -- TODO confirm.
        _version: task parameter, default 1. It is not read in this class.
    """

    testcollection = gokart.TaskInstanceParameter()
    _version: int = luigi.Parameter(default=1)

    def run(self):
        """
        Group titles by query, ask the model about each group, and dump the
        combined records as a data frame.

        A response counts as an error (and is skipped) when its cleaned text
        cannot be parsed as JSON, or when it parses to something other than a
        JSON array or object. Only the number of errors is printed.
        """
        df = self.load_data_frame(required_columns={"query", "title"})

        client = OpenAI()

        # One entry per distinct query, mapping it to the list of its titles.
        titles_by_query = df.groupby("query")["title"].apply(list).to_dict()

        results = []
        errors = []
        for query, titles in tqdm(titles_by_query.items()):
            completion = client.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=[
                    {
                        "role": "system",
                        "content": gen_rel_prompt(query, titles),
                    }
                ],
            )

            md_str = completion.choices[0].message.content
            json_str = clean_json(md_str)
            try:
                result = json.loads(json_str)
            except Exception:
                # Best effort: keep the unparsable payload and move on to the
                # next query instead of aborting the whole run.
                errors.append(json_str)
                continue

            if isinstance(result, list):
                results.extend(result)
            elif isinstance(result, dict):
                # A single JSON object is one record. Extending with it would
                # add its keys as bare strings and corrupt the record list.
                results.append(result)
            else:
                # Scalars / null cannot become rows; extending with them would
                # raise TypeError and lose every result collected so far.
                errors.append(json_str)

        df = pd.DataFrame(results)
        print(f"{len(errors)} errors")
        self.dump(df)



In [None]:
# Load the test collection (rerun=True is passed to the loader task) and run
# RelDecision on top of it; gokart.build's return value is kept in `df`.
testcollection = LoadTestCollection(rerun=True)
rel_decision_task = RelDecision(testcollection=testcollection)
df = gokart.build(rel_decision_task)

In [None]:
# Write the RelDecision result to CSV, omitting the index column.
rel_csv_path = "../data/rel.csv"
df.to_csv(rel_csv_path, index=False)

In [None]:
# Side A is the loaded test collection; side B is RelDecision run on that
# same collection. Both are handed to the CohenKappa task.
testcollection_a = testcollection
testcollection_b = RelDecision(testcollection=testcollection_a)

kappa_task = CohenKappa(
    testcollection_a=testcollection_a,
    testcollection_b=testcollection_b,
    rerun=True,
)
# Kept as the last expression so the cell still displays the build result.
gokart.build(kappa_task)