# In [1]:
import json
import os
import random

import gokart
import luigi
import pandas as pd
import vertexai
from tqdm import tqdm
from vertexai.language_models import TextGenerationModel
from qrelllm.llm.prompt import gen_doc_rel_prompt

# Google Cloud settings passed to vertexai.init(); read from the environment
# once at import time. Each value is None when its variable is not set.
project = os.environ.get("GOOGLE_CLOUD_PROJECT_ID")
location = os.environ.get("GOOGLE_CLOUD_LOCATION")


def load_queries(path: str = "../data/queries.csv") -> list[str]:
    """Read the query file and return one query per non-blank line.

    Args:
        path: Location of the query file. Defaults to the path the original
            script used, so existing callers need no change.

    Returns:
        The queries in file order, with the line terminator removed. Lines
        that contain only whitespace are skipped.
    """
    # Explicit UTF-8 so the result does not depend on the platform's default
    # locale encoding.
    with open(path, "r", encoding="utf-8") as f:
        # Drop the trailing newline so it does not leak into the prompt that
        # is built from each query, and ignore empty lines.
        stripped = (line.rstrip("\r\n") for line in f)
        return [query for query in stripped if query.strip()]


def _parse_generated_records(text):
    """Turn one raw model response into a list of parsed JSON records.

    The response is expected to be JSON, optionally wrapped in a Markdown
    code fence (```json ... ```).

    Args:
        text: The raw text of a model response.

    Returns:
        A list of records: the parsed list itself, or a one-element list when
        the response is a single JSON object. Returns None when the text is
        not valid JSON or parses to something that is neither a list nor an
        object.
    """
    cleaned = text.strip()
    for fence in ("```json", "```JSON"):
        cleaned = cleaned.removeprefix(fence)
    cleaned = cleaned.removesuffix("```")

    # First parse the text as-is: JSON already tolerates whitespace between
    # tokens, and this keeps spaces inside string values intact. Only if that
    # fails, fall back to removing every newline and space, which can rescue
    # responses that contain raw line breaks inside strings.
    candidates = (cleaned, cleaned.replace("\n", "").replace(" ", ""))
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            continue
        if isinstance(parsed, list):
            return parsed
        if isinstance(parsed, dict):
            # A lone object is a single record, not an iterable of records.
            return [parsed]
        return None
    return None


class TestCollection(gokart.TaskOnKart):
    """
    Task that generates article titles related and unrelated to the given queries.

    A random sample of queries is sent to the text generation model one by
    one; every response that parses as JSON contributes its records to a
    single DataFrame, which is dumped as the task output. Responses that
    cannot be parsed are counted and reported, not raised.
    """

    # Number of queries to sample; capped at the number of available queries.
    size: int = luigi.IntParameter(default=10)

    def run(self):
        vertexai.init(project=project, location=location)
        parameters = {
            "temperature": 0.5,
            "max_output_tokens": 1000,
            "top_p": 0.95,
            "top_k": 40,
        }

        model = TextGenerationModel.from_pretrained("text-bison@002")

        queries = load_queries()
        # random.sample raises ValueError when asked for more items than the
        # population holds, so never request more than there are.
        queries = random.sample(queries, min(self.size, len(queries)))

        results = []
        errors = []
        for q in tqdm(queries):
            response = model.predict(
                gen_doc_rel_prompt(q),
                **parameters,
            )

            records = _parse_generated_records(response.text)
            if records is None:
                # Keep the raw response so failures can be inspected.
                errors.append(response.text)
                continue
            results.extend(records)

        df = pd.DataFrame(results)
        print(f'{len(errors)} errors')
        self.dump(df)


def main():
    """Run the TestCollection task for 300 queries and save the result as CSV."""
    task = TestCollection(size=300)
    dataset = gokart.build(task)
    dataset.to_csv("../data/dataset_vertex_v1_1.csv", index=False)


# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
