In [None]:
import json
import os
import random

import gokart
import luigi
import pandas as pd
import vertexai
from tqdm import tqdm
from openai import OpenAI
from vertexai.language_models import TextGenerationModel
from qrelllm.llm.prompt import gen_rel_prompt, gen_doc_rel_prompt
from qrelllm.eval import CohenKappa

# Google Cloud project and region handed to vertexai.init(); None when the
# corresponding environment variable is not set.
project = os.environ.get("GOOGLE_CLOUD_PROJECT_ID")
location = os.environ.get("GOOGLE_CLOUD_LOCATION")


def load_queries() -> list[str]:
    """Read the candidate queries from ``../data/queries.csv``.

    The file is read as UTF-8 text with one query per line. Line terminators
    are removed so they do not leak into the prompts built from the queries,
    and lines that are empty after that are skipped.

    Returns:
        The queries in file order, without trailing newlines.

    Raises:
        OSError: If the file cannot be opened.
    """
    # Explicit encoding: the platform default is not UTF-8 everywhere.
    with open("../data/queries.csv", "r", encoding="utf-8") as f:
        stripped = (line.rstrip("\r\n") for line in f)
        return [query for query in stripped if query]

def clean_json(s: str) -> str:
    """Strip Markdown code fences from an LLM reply so it can be fed to ``json.loads``.

    Surrounding whitespace is trimmed, one leading fence (```` ```json ````,
    ```` ```JSON ```` or a bare ```` ``` ````) and one trailing ```` ``` ```` are removed,
    and newlines are dropped. Spaces are left alone: the JSON parser ignores
    them between tokens, and removing them would alter string values such as
    multi-word queries or titles.

    Args:
        s: The raw reply text. A response-like object exposing a ``text``
            attribute is also accepted, for callers that pass the response
            itself instead of its text.

    Returns:
        The reply with fences and newlines removed.
    """
    # Plain strings have no ``text`` attribute, so they pass through unchanged.
    text = getattr(s, "text", s).strip()

    for fence in ("```json", "```JSON", "```"):
        if text.startswith(fence):
            text = text[len(fence):]
            break
    text = text.removesuffix("```")

    return text.replace("\n", "").strip()

class TestCollection(gokart.TaskOnKart):
    """Task that generates related and unrelated article titles for given queries.

    ``size`` queries are sampled at random from ``load_queries()``. For each one,
    the prompt built by ``gen_doc_rel_prompt`` is sent to the ``text-bison@002``
    Vertex AI text model and the JSON in the reply is parsed. All parsed records
    are gathered into a single DataFrame, which is dumped as the task output.
    """

    # Number of queries to sample from the query file.
    size: int = luigi.IntParameter(default=10)

    def run(self):
        """Query the model once per sampled query and dump the parsed records.

        Replies that are not valid JSON, or whose top-level value is neither a
        list nor an object, are skipped; only their count is printed.

        Raises:
            ValueError: If ``size`` exceeds the number of available queries
                (raised by ``random.sample``).
        """
        vertexai.init(project=project, location=location)
        parameters = {
            "temperature": 0.5,
            "max_output_tokens": 1000,
            "top_p": 0.95,
            "top_k": 40,
        }

        model = TextGenerationModel.from_pretrained("text-bison@002")

        queries = random.sample(load_queries(), self.size)

        results = []
        errors = []
        for q in tqdm(queries):
            response = model.predict(
                gen_doc_rel_prompt(q),
                **parameters,
            )

            json_str = clean_json(response.text)
            try:
                result = json.loads(json_str)
            except json.JSONDecodeError:
                errors.append(json_str)
                continue

            # A lone object would make ``extend`` add its keys as rows.
            if isinstance(result, dict):
                result = [result]
            if not isinstance(result, list):
                errors.append(json_str)
                continue
            results.extend(result)

        df = pd.DataFrame(results)
        print(f'{len(errors)} errors')
        self.dump(df)


class RelDecision(gokart.TaskOnKart):
    """Task that asks an OpenAI chat model about each query and its titles.

    The upstream ``testcollection`` output is grouped by query. For every query,
    the message built by ``gen_rel_prompt(query, titles)`` is sent to
    ``gpt-3.5-turbo`` and the JSON in the reply is parsed. All parsed records
    are gathered into a single DataFrame, which is dumped as the task output.
    """

    # Upstream task; its output must provide the 'query' and 'title' columns.
    testcollection = gokart.TaskInstanceParameter()
    # NOTE(review): presumably bumped by hand to get a fresh task output — confirm.
    _version: int = luigi.Parameter(default=1)

    def run(self):
        """Call the chat model once per query and dump the parsed records.

        Replies that are empty, are not valid JSON, or whose top-level value is
        neither a list nor an object are skipped; only their count is printed.
        """
        df = self.load_data_frame(required_columns={'query', 'title'})

        client = OpenAI()

        results = []
        errors = []

        # One request per query, carrying every title recorded for it.
        titles_by_query = df.groupby('query')['title'].apply(list).to_dict()
        for query, titles in tqdm(titles_by_query.items()):
            completion = client.chat.completions.create(
                model='gpt-3.5-turbo',
                messages=[
                    gen_rel_prompt(query, titles)
                ]
            )

            # The API may return no content; treat that as an unparsable reply.
            content = completion.choices[0].message.content or ""
            json_str = clean_json(content)
            try:
                result = json.loads(json_str)
            except json.JSONDecodeError:
                errors.append(json_str)
                continue

            # A lone object would make ``extend`` add its keys as rows.
            if isinstance(result, dict):
                result = [result]
            if not isinstance(result, list):
                errors.append(json_str)
                continue
            results.extend(result)

        df = pd.DataFrame(results)
        print(f'{len(errors)} errors')
        self.dump(df)


def main():
    """Build RelDecision on a 300-query TestCollection and save the result.

    The built output is written to ``../data/rel.csv`` without the index column.
    """
    source_task = TestCollection(size=300)
    decision_task = RelDecision(testcollection=source_task, rerun=True)
    decisions = gokart.build(decision_task)
    decisions.to_csv("../data/rel.csv", index=False)


# Run the pipeline only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()


In [None]:
# Build the two tasks to compare: the generated collection itself and the
# RelDecision task that consumes it.
testcollection_a = TestCollection(size=300)
testcollection_b = RelDecision(testcollection=testcollection_a, rerun=True)

# CohenKappa (from qrelllm.eval) receives both tasks; presumably it reports the
# agreement between their outputs — confirm against its implementation.
gokart.build(CohenKappa(testcollection_a=testcollection_a, testcollection_b=testcollection_b, rerun=True))