# 目的
`monster.csv` の `description` 列のテキストから BGE モデル(sentence-transformers)で embedding を作成し、CSV として保存する

# Setting

In [16]:
# Directory the input CSV (monster.csv) is read from.
DATA_PATH = "../data"
# Directory the embedding CSV (monster_vec.csv) is written to.
OUTPUT_PATH = "../output"
# Model identifier passed to SentenceTransformer; alternative: BAAI/bge-large-en-v1.5
MODEL_PATH = "BAAI/bge-small-en-v1.5" # BAAI/bge-large-en-v1.5

# Import

In [17]:
import os

import polars as pl
import numpy as np

from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer

In [18]:
import sentence_transformers

# Fail fast when the installed library versions differ from the pinned ones.
# The message reports the installed version so a mismatch is easy to diagnose.
assert pl.__version__ == "1.17.1", (
    f"expected polars 1.17.1, got {pl.__version__}"
)
assert sentence_transformers.__version__ == "3.3.1", (
    f"expected sentence-transformers 3.3.1, got {sentence_transformers.__version__}"
)

# Data Load

In [19]:
# Read the monster table from the data directory.
monster_csv = f"{DATA_PATH}/monster.csv"
monster = pl.read_csv(monster_csv)

# train = pl.DataFrame(
#     {
#         "text": [
#             "I love apples",
#             "I hate apples",
#             "I love oranges",
#         ],
#         "num": [1, 2, 3],
#     }
# )

# test = pl.DataFrame(
#     {
#         "text": [
#             "I love bananas",
#             "I hate watermelons",
#             "I love oranges",
#         ],
#         "num": [4, 5, 6],
#     }
# )

# BGE

In [20]:
# Load the sentence-embedding model selected by MODEL_PATH.
model = SentenceTransformer(MODEL_PATH)

# Encode every value of the "description" column in one call;
# normalize_embeddings=True requests normalized (unit-length) vectors.
descriptions = monster["description"].to_list()
monster_vec = model.encode(descriptions, normalize_embeddings=True)

# Report the resulting (rows, dimensions) of the embedding matrix.
print(monster_vec.shape)

(19, 384)


In [23]:
# Wrap the embedding matrix in a DataFrame with one column per embedding
# dimension (vec_0, vec_1, ...). orient="row" states explicitly that each
# array row is one record instead of relying on orientation inference.
vec_columns = [f"vec_{i}" for i in range(monster_vec.shape[1])]
monster_vec_df = pl.DataFrame(monster_vec, schema=vec_columns, orient="row")

# write_csv does not create missing parent directories, so make sure the
# output directory exists before writing (no-op when it already does).
os.makedirs(OUTPUT_PATH, exist_ok=True)
monster_vec_df.write_csv(f"{OUTPUT_PATH}/monster_vec.csv")

In [22]:
# Display the embedding DataFrame as the cell's output.
monster_vec_df

vec_0,vec_1,vec_2,vec_3,vec_4,vec_5,vec_6,vec_7,vec_8,vec_9,vec_10,vec_11,vec_12,vec_13,vec_14,vec_15,vec_16,vec_17,vec_18,vec_19,vec_20,vec_21,vec_22,vec_23,vec_24,vec_25,vec_26,vec_27,vec_28,vec_29,vec_30,vec_31,vec_32,vec_33,vec_34,vec_35,vec_36,…,vec_347,vec_348,vec_349,vec_350,vec_351,vec_352,vec_353,vec_354,vec_355,vec_356,vec_357,vec_358,vec_359,vec_360,vec_361,vec_362,vec_363,vec_364,vec_365,vec_366,vec_367,vec_368,vec_369,vec_370,vec_371,vec_372,vec_373,vec_374,vec_375,vec_376,vec_377,vec_378,vec_379,vec_380,vec_381,vec_382,vec_383
f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,…,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32
-0.077965,0.022964,0.019223,-0.018083,0.032325,0.011745,0.032781,0.00406,0.00802,-0.00806,-0.024405,-0.091296,0.045534,0.054921,0.011585,-0.041806,0.002423,-0.020468,-0.020397,-0.001301,0.055973,-0.002063,-0.015292,-0.071581,0.029804,0.001119,0.016434,0.030572,0.020838,-0.188068,-0.007308,-0.022373,0.024564,-0.00475,0.021279,0.039242,-0.038893,…,-0.012664,0.019324,-0.038133,0.002214,0.096769,0.023403,0.079889,0.01341,0.006448,0.028502,0.004605,-0.017157,-0.002106,0.060498,0.004493,0.033911,0.056763,0.061261,-0.038509,-0.029108,-0.047004,-0.032966,-0.014701,0.004414,0.036839,0.038166,0.007389,0.038987,-0.019145,-0.00243,-0.014898,0.012385,0.036511,0.011133,-0.009406,0.009783,0.039788
-0.014884,0.022541,0.043756,0.002289,0.040177,0.010712,-0.003189,0.043763,0.02515,0.00011,-0.039374,-0.036549,0.052302,0.053889,0.017337,-0.035785,-0.016317,0.005062,-0.052683,0.001505,0.084764,0.031305,0.027686,-0.067246,0.013003,-0.026565,-0.011453,-0.033936,0.012275,-0.145089,-0.046922,0.009985,0.036808,-0.004146,0.01971,0.035762,-0.007991,…,-0.006943,0.015077,-0.043959,-0.008857,0.059743,0.011725,0.033811,-0.02486,0.041216,0.044299,0.027493,-0.026824,-0.016259,0.069942,-0.021852,0.054542,0.052302,0.035898,-0.004607,-0.066251,-0.07686,-0.003762,0.035128,0.066002,-0.024795,0.019826,0.054408,0.020847,-0.054969,-0.027924,-0.0247,-0.054911,0.010617,-0.010123,0.011126,-0.005766,0.012046
-0.054641,0.032748,-0.023494,-0.016202,0.073966,-0.023759,0.045149,0.041672,-0.001538,0.034817,-0.043592,-0.100374,0.06699,0.065444,0.070154,-0.016053,-0.000729,-0.017264,-0.06022,-0.036901,0.07751,-0.023372,0.02364,-0.028272,0.018902,0.056549,-0.018039,0.008967,0.023857,-0.146831,-0.023547,0.010424,0.018852,-0.003624,0.025339,0.059238,-0.035879,…,0.001137,0.02781,-0.03565,-0.028541,0.063592,0.02007,0.037839,-0.035122,0.028498,0.036762,-0.026525,-0.03442,-0.006664,0.043971,-0.026987,0.045071,0.069922,0.082999,-0.006075,-0.0491,-0.080801,-0.036335,-0.001847,0.02314,-0.000149,0.051705,0.039438,0.025008,-0.017194,-0.031689,0.011338,0.006165,0.052851,0.010329,0.030847,0.018496,0.048241
-0.016118,0.037495,0.005233,0.002177,0.034076,0.025638,0.027348,-0.01658,0.002697,0.023006,-0.004283,-0.073776,0.0445,0.077884,-0.02869,-0.017665,0.032185,-0.018933,-0.059059,-0.020062,0.060735,-0.028042,-0.003799,-0.035579,0.005873,0.009997,-0.007921,0.001245,-0.023354,-0.167889,-0.020511,0.011287,0.022884,-0.025696,0.028107,0.032134,-0.076514,…,-0.004377,0.068226,-0.045036,0.031887,0.056882,-0.027964,0.042912,0.001102,0.03235,0.035656,-0.020694,-0.013029,0.013029,0.040954,0.006401,0.030193,0.058806,0.096739,-0.045359,-0.070963,-0.042388,-0.032401,-0.003212,-0.015205,0.006064,0.047301,-0.00353,0.036536,-0.018317,0.012558,0.010876,0.020435,0.012063,0.017712,0.002896,-0.007478,0.035232
-0.030969,0.029671,0.025217,0.011273,0.030282,-0.027545,0.044776,-0.010657,0.027166,-0.017408,-0.009155,-0.072007,0.045808,0.057593,0.03264,-0.006629,0.046936,-0.035954,-0.086193,0.064729,0.023919,-0.039066,0.018369,-0.033967,-0.02726,-0.012141,-0.016199,-0.032499,0.007721,-0.099665,-0.009589,0.038251,0.045144,0.00308,0.016045,0.066797,-0.029871,…,0.036331,0.021912,-0.022992,-0.010958,0.056966,0.003956,0.019576,-0.052944,-0.007845,-0.032432,0.020516,-0.037194,0.028711,0.06906,0.007494,0.033332,0.084038,0.054585,-0.010468,-0.062249,-0.047304,0.009385,0.076781,-0.015291,0.006752,0.028967,0.077747,0.039335,-0.060094,0.00419,-0.017965,-0.028116,0.025174,0.036188,0.038595,0.032346,-0.022645
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
-0.043604,-0.01633,0.022437,0.008526,0.026973,-0.002945,0.041262,0.042914,0.001558,0.00992,-0.008774,-0.088082,0.068494,0.008362,0.01202,-0.006776,0.025306,0.025445,-0.093748,0.042723,0.049638,-0.039195,-0.006476,-0.018549,0.028023,-0.017319,-0.005304,-0.021703,0.022952,-0.129607,-0.039046,0.045541,0.069881,-0.026783,0.024974,0.027885,-0.03875,…,-0.021446,0.0423,-0.072221,0.030969,0.059142,-0.02162,0.020721,-0.001602,-0.050938,0.008403,-0.016305,-0.037387,0.008752,0.066907,-0.035045,0.039944,0.028628,0.012907,0.013316,-0.03104,-0.021013,-0.007422,-0.014356,0.020029,-0.005095,0.020885,0.055265,0.028703,-0.028946,-0.011472,-0.005227,-0.047324,0.029499,-0.014945,0.055293,0.044754,0.016241
-0.009148,-0.00945,-0.007281,0.008193,-0.007158,-0.008171,-0.029366,0.017687,-0.011854,-0.020912,0.006329,-0.087861,0.033623,0.055625,0.019467,-0.064392,0.002532,0.064619,-0.041591,0.041069,0.039769,0.027565,0.000028,-0.069784,0.011103,0.032302,0.039869,-0.019627,0.006082,-0.139265,-0.024163,-0.005552,0.020221,-0.035591,-0.033374,0.077971,-0.074883,…,-0.02211,0.025374,-0.058065,0.011001,0.027442,0.01347,0.061317,-0.01098,-0.011541,-0.009354,0.007802,-0.0343,-0.022024,0.042012,-0.004099,0.015622,0.099419,-0.005748,-0.017744,-0.039054,0.013192,0.009824,-0.028434,0.048996,-0.00759,0.046679,0.064307,0.029748,0.039456,-0.051768,-0.00688,-0.06079,0.033238,-0.004497,0.018037,0.026905,-0.020925
-0.018451,-0.006669,0.02224,0.017761,0.045688,-0.017289,0.073982,0.036411,0.02542,-0.00881,0.023831,-0.098281,0.057053,0.042081,0.037048,-0.042902,0.014583,-0.014831,-0.039172,0.004187,0.087928,-0.024699,0.018039,-0.013335,0.028231,-0.018123,0.013093,-0.037684,-0.015997,-0.163564,-0.052262,0.009002,0.040367,0.022159,-0.043369,0.046024,-0.080853,…,0.031598,0.001347,-0.038455,0.044549,0.092833,0.012755,0.001102,-0.018938,-0.0163,-0.007077,-0.027333,-0.00903,0.022856,0.035517,-0.018029,-0.010818,0.04974,0.010408,-0.037371,-0.131922,-0.034309,-0.020772,0.022874,0.014186,0.015303,0.024342,0.069979,0.029068,-0.03959,0.010617,-0.032146,-0.060421,0.00105,-0.046938,0.012893,0.049319,0.050983
0.01025,0.007741,-0.043485,0.00086,-0.010995,-0.018369,0.020118,-0.014985,0.023164,0.020806,0.006474,-0.14434,0.026635,0.077319,0.038857,-0.016749,-0.002867,0.044478,-0.058873,-0.021487,0.032312,-0.017969,0.033627,-0.069503,0.046677,0.01497,0.053011,-0.014348,-0.016007,-0.138957,-0.020349,0.029355,0.065407,-0.045179,-0.014726,0.044333,-0.051586,…,0.044937,0.00665,-0.041979,0.027439,0.049314,0.008817,0.007368,-0.03446,-0.051061,-0.022487,-0.041094,-0.052438,0.018063,0.04982,-0.040628,-0.015371,0.085144,-0.013856,-0.005302,-0.075235,-0.036939,0.018379,0.009535,0.048552,0.007761,0.023158,0.105621,0.037334,-0.030315,-0.01212,-0.038821,-0.078842,0.022445,0.053476,-0.018061,0.003009,0.03932
