In [None]:
# Upgrade the google-cloud-aiplatform package in the kernel's environment before `vertexai` is imported.
%pip install --upgrade google-cloud-aiplatform


# Setup

In [1]:
from google.colab import auth

# GCP project used both for Colab authentication here and for vertexai.init();
# replace the placeholder with a real project id before running.
PROJECT_ID = "<PUT YOUR GCP PJ>"
# Authenticate the current Colab user against that project.
auth.authenticate_user(project_id=PROJECT_ID)


In [2]:
import vertexai

# Point the Vertex AI SDK at PROJECT_ID in the us-central1 location for all later model calls.
vertexai.init(project=PROJECT_ID, location='us-central1')


# Model

In [13]:
from vertexai.language_models import TextEmbeddingModel
import numpy as np

# Model version reference:
# https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/text-embeddings#model_versions
# Module-level embedding model; get_embeddings() reads this global.
model = TextEmbeddingModel.from_pretrained("textembedding-gecko-multilingual@001")

def get_embeddings(sentences: list[str], auto_truncate: bool = False) -> np.ndarray:
    """Embed ``sentences`` with the module-level ``model`` in one request.

    Args:
        sentences: Texts to embed; passed to ``model.get_embeddings`` as-is.
        auto_truncate: Forwarded unchanged to ``model.get_embeddings``.

    Returns:
        An array built from the ``.values`` of every returned embedding,
        one row per embedding.
    """
    response = model.get_embeddings(sentences, auto_truncate=auto_truncate)
    vectors = []
    for result in response:
        vectors.append(result.values)
    return np.array(vectors)

def batch_process_embeddings(sentences: list[str], batch_size: int = 5, auto_truncate: bool = False) -> np.ndarray:
    """Embed ``sentences`` in consecutive batches and stack the results.

    Args:
        sentences: Texts to embed.
        batch_size: Maximum number of texts passed to ``get_embeddings`` per
            call. Must be positive.
        auto_truncate: Forwarded unchanged to ``get_embeddings``.

    Returns:
        The per-batch arrays stacked vertically in input order. For empty
        input an empty array of shape ``(0, 0)`` is returned, because the
        embedding width cannot be known without calling the model.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    if len(sentences) == 0:
        # np.vstack([]) raises "need at least one array to concatenate";
        # return an explicit empty result instead of crashing.
        return np.empty((0, 0))

    all_embeddings = [
        get_embeddings(sentences[start:start + batch_size], auto_truncate=auto_truncate)
        for start in range(0, len(sentences), batch_size)
    ]
    return np.vstack(all_embeddings)


# JSTS

In [None]:
import json
import pandas as pd
from urllib.request import urlopen

jsts_url = "https://raw.githubusercontent.com/yahoojapan/JGLUE/main/datasets/jsts-v1.1/valid-v1.1.json"
# Each line of the file is parsed as its own JSON document. The response is
# opened in a `with` block so the HTTP connection is closed as soon as the
# lines have been read instead of being left to the garbage collector.
with urlopen(jsts_url) as response:
    df = pd.DataFrame([json.loads(line) for line in response.readlines()])
df.head(1)


In [None]:
# (rows, columns) of the JSTS frame loaded above.
df.shape


## Encode

In [None]:
import numpy as np

# Embed both sentence columns of the JSTS frame, first "sentence1" then "sentence2".
sentence1_embs, sentence2_embs = (
    batch_process_embeddings(df[column].tolist())
    for column in ("sentence1", "sentence2")
)
sentence1_embs.shape, sentence2_embs.shape


## Correlation Score

In [32]:
from scipy.spatial.distance import cosine
from scipy.stats import spearmanr

# Score every sentence pair by cosine similarity, i.e. one minus SciPy's cosine distance.
df["similarity"] = [
    1 - cosine(*emb_pair) for emb_pair in zip(sentence1_embs, sentence2_embs)
]
# Element 0 of the spearmanr result is the rank-correlation coefficient against the gold labels.
spearmanr(df["similarity"], df["label"])[0]


0.8006039095558688

# JSICK

In [None]:
# JSICK test split, a tab-separated file fetched directly from GitHub.
jsick_url = "https://github.com/verypluming/JSICK/raw/main/jsick/test.tsv"
df = pd.read_csv(jsick_url, sep="\t")
df.head(1)


In [None]:
# (rows, columns) of the JSICK frame loaded above.
df.shape


In [None]:
# batch_process_embeddings already splits its input into request-sized batches
# and stacks the results, so the whole column can be passed in one call; an
# extra outer chunking loop would only add a second, arbitrary batch size.
sentence1_embs = batch_process_embeddings(df["sentence_A_Ja"].tolist())
sentence2_embs = batch_process_embeddings(df["sentence_B_Ja"].tolist())

sentence1_embs.shape, sentence2_embs.shape


## Correlation Score

In [40]:
from scipy.spatial.distance import cosine
from scipy.stats import spearmanr

# Score every sentence pair by cosine similarity, i.e. one minus SciPy's cosine distance.
df["similarity"] = [
    1 - cosine(*emb_pair) for emb_pair in zip(sentence1_embs, sentence2_embs)
]
# Element 0 of the spearmanr result is the rank-correlation coefficient against the relatedness scores.
spearmanr(df["similarity"], df["relatedness_score_Ja"])[0]


0.803561121302977

# MIRACL
* Requires a Hugging Face access token.

In [4]:
# Number of hard-negative docids kept per query; the evaluation loop truncates each list to this length.
miracle_n_hard_negs = 300
# Top-k cutoff: how many nearest passages count as retrieved when computing recall.
miracle_n_recall = 30
# Strings prepended to every query / passage text before embedding; empty means no prefix.
query_prefix = ""
passage_prefix = ""

In [5]:
import os
import dotenv

# Load environment variables from the local file "huggingface_access_token";
# override=True lets values from the file replace variables that are already set.
# Presumably the file defines HF_ACCESS_TOKEN, which is read when loading the dataset — verify.
dotenv.load_dotenv("huggingface_access_token", override=True)

True

In [None]:
import datasets

# query and positives
# Loads the "dev" split of the "ja" config of miracl/miracl, authenticating with
# the token stored in the HF_ACCESS_TOKEN environment variable.
# NOTE(review): newer `datasets` releases reportedly replace `use_auth_token`
# with `token` — confirm against the installed version before changing it.
ds = datasets.load_dataset(
    "miracl/miracl", "ja", use_auth_token=os.environ["HF_ACCESS_TOKEN"], split="dev"
)
ds

In [None]:
# all corpus texts
# The "ja" config of miracl/miracl-corpus; its "train" split is what corpus_dict is built from.
corpus = datasets.load_dataset("miracl/miracl-corpus", "ja")
corpus

In [None]:
import json

# Hard negatives, keyed by query id; each entry exposes at least "docids" and "indices".
with open("./miracl_hard_negs_1000.json") as f:
    hn = json.load(f)

# Preview: entry count, first five keys, and a peek at the entry for query "0".
(
    len(hn),
    list(hn)[:5],
    hn["0"].keys(),
    hn["0"]["docids"][:2],
    hn["0"]["indices"][:2],
)

In [9]:
import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist


def get_text(corpus_item):
    """Return the passage text used for embedding: title and body joined by one space.

    Args:
        corpus_item: Mapping with string values under "title" and "text".

    Returns:
        ``"<title> <text>"`` as a single string.
    """
    parts = (corpus_item["title"], corpus_item["text"])
    return " ".join(parts)


# Lookup table from docid to its embeddable text, built over the corpus "train" split.
corpus_dict = dict(
    (passage["docid"], get_text(passage)) for passage in corpus["train"]
)

In [14]:
# Running totals over all evaluated queries: number of positive passages and
# how many of them were retrieved within the top-k.
n_total_pos = 0
n_total_tp = 0

# Passage embeddings are cached on disk per query. Create the cache directory
# up front so np.save below does not fail when "tmp/" does not exist yet.
os.makedirs("tmp", exist_ok=True)

# only evaluate first 100 queries
for item in ds.select(range(100)):
    # query
    query_emb = get_embeddings([query_prefix + item["query"]])

    # candidate passages = positives + up to miracle_n_hard_negs hard negatives
    positive_docids = [pp["docid"] for pp in item["positive_passages"]]
    positive_texts = [get_text(pp) for pp in item["positive_passages"]]
    hn_docids = hn[item["query_id"]]["docids"][:miracle_n_hard_negs]

    # drop hard negatives that are actually positives (set lookup instead of list scan)
    positive_docid_set = set(positive_docids)
    hn_docids = [docid for docid in hn_docids if docid not in positive_docid_set]

    # search target: positives come first, so their indices are 0..n_pos-1
    target_docids = positive_docids + hn_docids
    target_texts = positive_texts + [corpus_dict[docid] for docid in hn_docids]

    # embedding: reuse the cached array when one exists for this query
    tmppath = f'tmp/textembedding-gecko-multilingual@001_{item["query_id"]}.npy'
    target_embs = np.load(tmppath) if os.path.exists(tmppath) else None

    # The cache file name depends only on the query id, not on
    # miracle_n_hard_negs, so a file left over from a run with a different
    # candidate count would misalign embeddings and texts. Recompute when the
    # row count does not match. (A changed passage_prefix is NOT detected by
    # this check; clear tmp/ manually in that case.)
    if target_embs is None or len(target_embs) != len(target_texts):
        target_embs = batch_process_embeddings(
            [passage_prefix + text for text in target_texts], auto_truncate=True
        )
        np.save(tmppath, target_embs)

    # topK: indices of the miracle_n_recall candidates closest to the query
    distances = cdist(query_emb, target_embs, metric="cosine")
    topk_indices = np.argsort(distances)[0][:miracle_n_recall]

    n_pos = len(positive_docids)
    # positives occupy the first n_pos indices, so a hit is any top-k index below n_pos
    n_tp = int(np.sum(topk_indices < n_pos))

    n_total_pos += n_pos
    n_total_tp += n_tp

miracl_recall = n_total_tp / n_total_pos

n_total_pos, n_total_tp, miracl_recall

(195, 156, 0.8)

# Output

In [15]:
# Final scores for this model. The numeric values are hard-coded copies of the
# outputs displayed by the JSTS, JSICK and MIRACL cells, so they must be updated
# by hand after a re-run. Note that this reassigns miracl_recall, replacing the
# value computed by the evaluation loop.
model_id = "textembedding-gecko-multilingual@001"
jsts_score = 0.8006039095558688
jsick_score = 0.803561121302977
miracl_recall = 0.8
model_id, jsts_score, jsick_score, miracl_recall


('textembedding-gecko-multilingual@001',
 0.8006039095558688,
 0.803561121302977,
 0.8)

In [None]:
import json
import os

# open(..., "w") does not create parent directories, so make sure the output
# directory exists before writing (no-op when it is already there).
os.makedirs("./scores", exist_ok=True)

# One JSON object per model; "/" in the model id is replaced so it is a valid file name.
with open(f'./scores/{model_id.replace("/", "_")}.txt', "w") as f:
    json.dump(
        {
            "model_id": model_id,
            "jsts": jsts_score,
            "jsick": jsick_score,
            "miracl": miracl_recall,
        },
        f,
    )
