In [None]:
import json
import os

import gokart
import luigi
import pandas as pd
import vertexai
from tqdm import tqdm
from vertexai.language_models import TextGenerationModel
from qrelllm.llm.prompt import gen_rel_prompt
from qrelllm.eval import CohenKappa
from qrelllm.queries import LoadQueries
from qrelllm.format import clean_json
from qrelllm.llm.vertex import TestCollection

In [None]:
# Vertex AI project id and region, read from the environment (None when unset).
project, location = (
    os.getenv(env_var)
    for env_var in ("GOOGLE_CLOUD_PROJECT_ID", "GOOGLE_CLOUD_LOCATION")
)

In [None]:


class RelDecision(gokart.TaskOnKart):
    """Ask a Vertex AI text model to judge the titles attached to each query.

    Loads the data frame produced by ``testcollection`` (it must contain
    ``query`` and ``title`` columns), groups the titles by query, and sends one
    ``gen_rel_prompt(query, titles)`` request per query. Each response is passed
    through ``clean_json`` and parsed as JSON. Parsed records from every query
    are collected into a single data frame, which is dumped as the task output.
    The record schema is whatever the prompt asks the model to emit; presumably
    relevance judgements — TODO confirm against ``gen_rel_prompt``.

    Responses that cannot be parsed into a JSON list/object are skipped and
    only counted (printed at the end).
    """

    testcollection = gokart.TaskInstanceParameter()
    # Significant luigi parameter, so it takes part in the task's cache key;
    # presumably bumped by hand to force a fresh output.
    _version: int = luigi.Parameter(default=1)

    # Plain class attributes, not luigi parameters: they are NOT part of the
    # task's unique id, so changing them does not invalidate cached outputs.
    _MODEL_NAME = "text-bison@002"
    _GENERATION_PARAMETERS = {
        "temperature": 1.0,
        "max_output_tokens": 1000,
        "top_p": 0.8,
        "top_k": 40,
    }

    def run(self):
        df = self.load_data_frame(required_columns={'query', 'title'})

        # `project` and `location` are the notebook-level globals set from the
        # environment in an earlier cell.
        vertexai.init(project=project, location=location)
        model = TextGenerationModel.from_pretrained(self._MODEL_NAME)

        results = []
        errors = []

        # query -> list of its titles (groupby sorts by query).
        titles_by_query = df.groupby('query')['title'].apply(list)
        for query, titles in tqdm(titles_by_query.items(), total=len(titles_by_query)):
            response = model.predict(
                gen_rel_prompt(query, titles),
                **self._GENERATION_PARAMETERS,
            )

            # clean_json is applied twice, as in the original pipeline;
            # presumably the second pass strips wrapping the first one left
            # behind — confirm clean_json is idempotent before removing a call.
            json_str = clean_json(clean_json(response.text))
            try:
                parsed = json.loads(json_str)
            except (ValueError, TypeError):
                # ValueError covers json.JSONDecodeError; TypeError covers a
                # non-string coming back from clean_json.
                errors.append(json_str)
                continue

            if isinstance(parsed, list):
                results.extend(parsed)
            elif isinstance(parsed, dict):
                # A single JSON object is one record; list.extend() would
                # otherwise spread its keys into `results`.
                results.append(parsed)
            else:
                # JSON scalar (number/string/bool/null): not a record. The
                # original extend() would crash or split strings into chars.
                errors.append(json_str)

        df = pd.DataFrame(results)
        print(f'{len(errors)} errors')
        if errors:
            # Show one failing payload to make prompt/format problems visible.
            print(repr(errors[0][:300]) if isinstance(errors[0], str) else repr(errors[0]))
        self.dump(df)



In [None]:
# Pipeline: seed queries -> TestCollection (size=300) -> RelDecision.
# rerun=True makes gokart regenerate RelDecision's output instead of loading a
# cached one.
queries = LoadQueries(csv_file_path="../data/queries.csv")
testcollection = TestCollection(project=project, location=location, queries=queries, size=300)
rel_decision_task = RelDecision(testcollection=testcollection, rerun=True)
df = gokart.build(rel_decision_task)

In [None]:
# Persist the RelDecision output built above (row index not written).
rel_csv_path = "../data/rel.csv"
df.to_csv(rel_csv_path, index=False)

In [None]:
# Cohen's kappa between the test collection (side A) and the LLM output of
# RelDecision over that same collection (side B). Presumably A carries the
# collection's own relevance labels — confirm against CohenKappa.
testcollection_a = testcollection
# No rerun=True here: `rerun` does not affect gokart's output path, so this
# reuses the RelDecision output already built above (and saved to
# ../data/rel.csv). With rerun=True gokart would delete it and query the model
# again; at temperature=1.0 that run differs from the saved judgements and
# costs another full round of LLM calls.
testcollection_b = RelDecision(testcollection=testcollection_a)

gokart.build(CohenKappa(testcollection_a=testcollection_a, testcollection_b=testcollection_b, rerun=True))