In [None]:
!pip install -U pymilvus
!pip install "pymilvus[model]"

In [None]:
from google.colab import drive
from pymilvus import MilvusClient, DataType, model
import pandas as pd

In [None]:
# Mount Google Drive at /content/drive; the syllabus CSV is later read from
# a path under this mount point.
drive.mount("/content/drive")

In [None]:
# Create the Milvus client backed by "sfc_syllabus.db". The path is relative,
# i.e. it does not point into the mounted /content/drive folder.
# NOTE(review): presumably this is a local Milvus Lite database file that lives
# in the (ephemeral) Colab working directory — confirm that is intended.
client = MilvusClient("sfc_syllabus.db")

In [None]:
# Embedding model definition: a SentenceTransformer-based dense embedding
# function using the "all-MiniLM-L6-v2" checkpoint, executed on the CPU.
# NOTE(review): the CSV columns are Japanese, so the embedded text is presumably
# Japanese as well — confirm this checkpoint handles it adequately.
embedding_fn = model.dense.SentenceTransformerEmbeddingFunction(
    model_name="all-MiniLM-L6-v2",
    device="cpu",
)
# Embedding dimensionality; reused as `dim` for every FLOAT_VECTOR field of the schema.
dim = embedding_fn.dim

In [None]:
# Schema definition.
#
# The primary key is generated by Milvus (auto_id=True), so inserted records do
# not carry an "id". Dynamic fields are disabled: every inserted key must match
# one of the fields declared below.
#
# VARCHAR `max_length` is a limit in BYTES, not characters. The values stored
# here are mostly Japanese text, which takes 3 bytes per character in UTF-8, so
# the limits are sized with that factor in mind (e.g. a single kanji such as the
# semester marker needs at least 3 bytes and would not fit in max_length=1).
schema = MilvusClient.create_schema(
    auto_id=True,
    enable_dynamic_field=False,
)

schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)  # ID (auto-generated primary key)
schema.add_field(field_name="subject_name", datatype=DataType.VARCHAR, max_length=512)  # course title
schema.add_field(field_name="faculty", datatype=DataType.BOOL)  # faculty flag: True for "総合政策・環境情報学部", False otherwise
schema.add_field(field_name="category", datatype=DataType.VARCHAR, max_length=256)  # field / category
schema.add_field(field_name="credits", datatype=DataType.INT8)  # number of credits
schema.add_field(field_name="year", datatype=DataType.INT16)  # academic year
schema.add_field(field_name="semester", datatype=DataType.VARCHAR, max_length=4)  # semester: one character, up to 4 UTF-8 bytes
# Not enabled yet: day of week / period as nullable array fields.
# schema.add_field(field_name="day", datatype=DataType.ARRAY, element_type=DataType.VARCHAR, max_capacity=4, max_length=4, nullable=True)  # day of week
# schema.add_field(field_name="period", datatype=DataType.ARRAY, element_type=DataType.INT8, max_capacity=4, nullable=True)  # period
schema.add_field(field_name="delivery_mode", datatype=DataType.VARCHAR, max_length=64)  # delivery mode
schema.add_field(field_name="language", datatype=DataType.VARCHAR, max_length=64)  # language of instruction
schema.add_field(field_name="english_support", datatype=DataType.BOOL)  # English support available
schema.add_field(field_name="selection", datatype=DataType.VARCHAR, max_length=64)  # enrollment screening / restriction
schema.add_field(field_name="giga", datatype=DataType.BOOL)  # GIGA flag
schema.add_field(field_name="summary", datatype=DataType.FLOAT_VECTOR, dim=dim)  # embedding of the course summary
schema.add_field(field_name="goals", datatype=DataType.FLOAT_VECTOR, dim=dim)  # embedding of the subject and goals
schema.add_field(field_name="schedule", datatype=DataType.FLOAT_VECTOR, dim=dim)  # embedding of the class schedule
schema.add_field(field_name="url", datatype=DataType.VARCHAR, max_length=512)  # syllabus URL

In [None]:
# Index definition.
#
# The schema declares three vector fields ("summary", "goals", "schedule").
# Milvus needs an index on every vector field before the collection can be
# loaded and searched, so each of them gets the same FLAT / COSINE index
# (previously only "summary" was indexed).
index_params = client.prepare_index_params()

for vector_field in ("summary", "goals", "schedule"):
    index_params.add_index(
        field_name=vector_field,
        metric_type="COSINE",
        index_type="FLAT",
    )

In [None]:
# Collection creation: always start from an empty collection.
collection_name = "sfc_syllabus_collection"

# A collection left over from an earlier run is dropped first, so this cell can
# be re-executed without hitting an "already exists" situation.
if client.has_collection(collection_name=collection_name):
    client.drop_collection(collection_name=collection_name)

# Build the collection from the schema and index parameters prepared above.
client.create_collection(collection_name=collection_name, schema=schema, index_params=index_params)

In [None]:
# Load the syllabus CSV, convert every usable row into a Milvus record, embed
# the three free-text sections, and insert all records into the collection.
df = pd.read_csv("/content/drive/MyDrive/sfc-llm/sfc_syllabus.csv")

# Only rows whose "学部・研究科" is one of these two values are ingested; the
# first is stored as faculty=True, the second as faculty=False.
FACULTY_LABEL = "総合政策・環境情報学部"
GRADUATE_LABEL = "政策・メディア研究科"

# Free-text columns, in the order their embeddings are assigned to the
# "summary", "goals" and "schedule" vector fields.
TEXT_COLUMNS = ["授業概要", "主題と目標", "授業計画"]

# Columns that are parsed or stored as strings. A missing value in any of them
# would otherwise fail later (while parsing, or when the whole batch is inserted).
REQUIRED_COLUMNS = TEXT_COLUMNS + [
    "科目名",
    "分野",
    "単位",
    "開講年度・学期",
    "実施形態",
    "授業で使う言語",
    "履修制限",
    "URL",
]


def _leading_int(value):
    """Return the integer made of the leading decimal digits of ``value``.

    Unlike reading only the first character, this keeps multi-digit numbers
    intact and also accepts values that pandas parsed as numbers.

    Raises:
        ValueError: if ``value`` does not start with a decimal digit.
    """
    digits = ""
    for ch in str(value).strip():
        if not ch.isdecimal():
            break
        digits += ch
    if not digits:
        raise ValueError(f"no leading integer in {value!r}")
    return int(digits)


data_list = []
docs = []  # flat list, len(TEXT_COLUMNS) texts per accepted row, aligned with data_list

for index, row in df.iterrows():
    if row["学部・研究科"] not in (FACULTY_LABEL, GRADUATE_LABEL):
        print(f"error {index}: unexpected 学部・研究科 {row['学部・研究科']!r}")
        continue

    # Exclude rows with missing values
    missing = [column for column in REQUIRED_COLUMNS if pd.isna(row[column])]
    if missing:
        print(f"error {index}: missing {missing}")
        continue

    # "開講年度・学期" is split on whitespace once: the first token is the year,
    # the first character of the second token is the semester marker.
    term_parts = row["開講年度・学期"].split()

    data_list.append({
        "subject_name": row["科目名"],
        "faculty": bool(row["学部・研究科"] == FACULTY_LABEL),
        "category": row["分野"],
        "credits": _leading_int(row["単位"]),
        "year": int(term_parts[0]),
        "semester": term_parts[1][0],
        "delivery_mode": row["実施形態"],
        "language": row["授業で使う言語"],
        "english_support": bool(row["英語サポート"] == "あり"),
        "selection": row["履修制限"],
        "giga": bool(row["GIGA"] == "対象"),
        "url": row["URL"],
    })
    docs.extend(row[column] for column in TEXT_COLUMNS)

print(f"{len(data_list)} of {len(df)} rows accepted")

res = None
if data_list:
    # Embed all texts in one call instead of one call per row, then hand each
    # record its three consecutive vectors.
    vectors = embedding_fn.encode_documents(docs)
    assert len(vectors) == len(docs), "embedding count does not match text count"
    for position, data in enumerate(data_list):
        start = position * len(TEXT_COLUMNS)
        data["summary"], data["goals"], data["schedule"] = vectors[start:start + len(TEXT_COLUMNS)]

    res = client.insert(collection_name=collection_name, data=data_list)
    print(res)
else:
    # Nothing passed the filters; skip the insert rather than sending an empty batch.
    print("no rows to insert")

In [None]:
# Search query: embed the query text(s) and retrieve the 5 nearest courses by
# cosine similarity on the "summary" vector field.
# NOTE(review): the only query is an empty string — looks like a placeholder to
# be filled in before running; confirm.
queries = [""]
query_vectors = embedding_fn.encode_queries(queries)

res = client.search(
    collection_name=collection_name,
    data=query_vectors,
    anns_field="summary",
    search_params={"metric_type": "COSINE"},
    limit=5,
    output_fields=["subject_name", "url"],
)

# `res` holds one list of hits per query; print every hit on its own line.
for hit in (match for per_query in res for match in per_query):
    print(hit)