In [1]:
import os
import json
import random
import colbert
import numpy as np
import pandas as pd

import matplotlib.cm as cm
import matplotlib.pyplot as plt

from sklearn.manifold import TSNE

from colbert import Indexer, Searcher
from colbert.data import Queries, Collection
from colbert.infra import Run, RunConfig, ColBERTConfig

from matplotlib import font_manager, rc
# Configure matplotlib to render Korean text with NanumGothic.
# The font file lives at a system-specific location, so a missing file is
# reported instead of crashing the very first cell of the notebook.
font_path = '/usr/share/fonts/truetype/nanum/NanumGothic.ttf'
if os.path.exists(font_path):
    # Register the file explicitly so the family name resolves even when the
    # font is not yet in matplotlib's font cache.
    font_manager.fontManager.addfont(font_path)
    font = font_manager.FontProperties(fname=font_path).get_name()
    rc('font', family=font)
    # NOTE(review): NanumGothic is commonly reported to lack the Unicode minus
    # glyph (U+2212); fall back to the ASCII hyphen for negative tick labels.
    rc('axes', unicode_minus=False)
else:
    font = None
    print(f"Font file not found: {font_path} - Korean text in plots may not render.")

In [2]:
# Read the labelled corpus (JSON Lines: one record per line) into a DataFrame,
# then pull out the passage texts and their paired questions as plain lists.
# Row i of `documents` and row i of `queries` come from the same record.
file_path = '../dataset/labeled_documents.jsonl'
data = pd.read_json(file_path, lines=True)

documents, queries = (data[column].tolist() for column in ('content', 'question'))

In [None]:
# Preview only the first few documents; displaying the whole list floods the
# notebook output and bloats the saved file.
documents[:5]

In [4]:
# Build the training triples as "query_idx, positive_idx, negative_idx" strings.
# The positive passage for query i is document i (same source record); the
# negative is a uniformly sampled *different* document.
TRIPLES_SEED = 42  # fixed seed so the sampled negatives are reproducible

if queries and len(documents) < 2:
    # With a single document there is no valid negative to sample.
    raise ValueError("At least two documents are required to sample a negative example.")

rng = random.Random(TRIPLES_SEED)  # local RNG: leaves the global `random` state untouched
triples_data = []
for q_idx in range(len(queries)):
    pos_idx = q_idx

    # Draw from the len(documents) - 1 indices that are not pos_idx: sample in
    # [0, n - 2] and shift values at or above pos_idx up by one. This is
    # uniform over the other documents and needs no retry loop.
    neg_idx = rng.randrange(len(documents) - 1)
    if neg_idx >= pos_idx:
        neg_idx += 1

    triples_data.append(f'{q_idx}, {pos_idx}, {neg_idx}')

In [5]:
# Persist the training data to disk for ColBERT training:
#   - collection / query files: one "<id>\t<text>" record per line
#   - triples file: one "[query_idx, positive_idx, negative_idx]" record per line
collection_file = '../dataset/collection.tsv'
query_file = '../dataset/query.tsv'
triples_file = '../dataset/triples'


def _tsv_field(text):
    """Return `text` as a single-line, tab-free string.

    Every run of whitespace (including tabs and newlines) is collapsed to one
    space, so an embedded tab or line break cannot split a record across
    fields or lines and shift the ids of the records that follow.
    """
    return ' '.join(str(text).split())


# Explicit UTF-8: the texts are Korean and the platform default encoding is
# not guaranteed to be able to represent them.
with open(collection_file, 'w', encoding='utf-8') as f:
    for i, item in enumerate(documents):
        f.write(f'{i}\t{_tsv_field(item)}\n')

with open(query_file, 'w', encoding='utf-8') as f:
    for i, item in enumerate(queries):
        f.write(f'{i}\t{_tsv_field(item)}\n')

with open(triples_file, 'w', encoding='utf-8') as f:
    for item in triples_data:
        f.write(f'[{item}]\n')

In [None]:
# Train a ColBERT model on the triples / queries / collection files written above.
from colbert import Trainer

# Everything below runs inside a ColBERT run context: single rank, experiment
# name "sample_ko_new".
with Run().context(RunConfig(nranks=1, experiment="sample_ko_new")):

    # bsize is the training batch size; root is the directory passed to ColBERT
    # for experiment output (the next cells look for the checkpoint under
    # ./experiments/sample_ko_new).
    config = ColBERTConfig(
        bsize=24,
        root="./experiments",
    )

    trainer = Trainer(
        triples=triples_file,
        queries=query_file,
        collection=collection_file,
        config=config,
    )

    # Start from a Korean-language pretrained model rather than an English one.
    # The value returned by train() is printed as the saved checkpoint path.
    checkpoint_path = trainer.train(checkpoint='hunkim/sentence-transformer-klue')
    print(f"Saved checkpoint to {checkpoint_path}...")

In [None]:
# Shell command (IPython `!`): list every entry named `colbert` under the
# experiment directory, to locate the checkpoint path used in the next cell.
!find experiments/sample_ko_new -name colbert

In [None]:
# Checkpoint directory produced by training. The timestamped segment
# (2024-10/06/13.57.50) belongs to one specific training run, so this path
# must be updated by hand after retraining.
checkpoint = './experiments/sample_ko_new/none/2024-10/06/13.57.50/checkpoints/colbert'
index_name = 'sample_ko_new'

# Build an index named `index_name` over the in-memory `documents` list.
with Run().context(RunConfig(nranks=1, experiment='notebook')):
    # NOTE(review): nbits / kmeans_niters are ColBERT indexing settings
    # (presumably compression bits per dimension and k-means iterations) -
    # confirm against the ColBERT documentation.
    config = ColBERTConfig(nbits=2, kmeans_niters=4)
    indexer = Indexer(checkpoint=checkpoint, config=config)
    # overwrite=True is passed; presumably this replaces an existing index
    # with the same name - confirm.
    indexer.index(name=index_name, collection=documents, overwrite=True)

: 

In [None]:
# Open a Searcher on the index built above, inside the same 'notebook'
# experiment context. The same in-memory `documents` list that was indexed is
# passed as the collection.
with Run().context(RunConfig(experiment='notebook')):
    searcher = Searcher(index=index_name, collection=documents)

In [None]:
# Extract ONE embedding vector per document.
#
# ColBERT is a late-interaction model, so the encoder is expected to return a
# matrix of per-token vectors for each text (assumed shape (1, tokens, dim) -
# TODO confirm). Stacking those matrices directly would produce
# n_docs * tokens rows, which no longer line up with `documents` when the
# result is indexed per document downstream. Each token matrix is therefore
# mean-pooled into a single (dim,) vector.
#
# NOTE(review): Searcher.encode appears to be ColBERT's *query* encoder; if a
# true document encoding is wanted, the checkpoint's document encoder may be
# the better choice - confirm against the ColBERT source.
document_embeddings = []
for doc in documents:
    # Move the tensor from GPU to CPU and convert it to a float32 numpy array.
    token_embeddings = searcher.encode([doc]).cpu().numpy().astype(np.float32)
    # Flatten any leading batch dimension to (tokens, dim); this also copes
    # with an encoder that already returns a single (dim,) or (1, dim) vector.
    token_embeddings = token_embeddings.reshape(-1, token_embeddings.shape[-1])
    document_embeddings.append(token_embeddings.mean(axis=0))

# Convert to a single (n_docs, dim) numpy array.
document_embeddings = np.stack(document_embeddings)
assert document_embeddings.shape[0] == len(documents), "expected one embedding row per document"

# Sanity-check output.
print(f"Document embeddings shape: {document_embeddings.shape}")


In [None]:
# Reduce the document embeddings to 2-D with t-SNE and plot them coloured by domain.

def _domain_labels(frame, n_rows):
    """Return one domain label (str) per row of `frame`.

    `documents` holds plain strings, so a domain cannot be read from the
    documents themselves. The label is looked up in the source DataFrame:
    first a 'domain' column, then a 'metadata' column of dicts carrying a
    'domain' key (NOTE(review): the dataset schema is not visible here -
    confirm which one applies). Missing values, or a missing column, yield
    the label 'unknown'.
    """
    if 'domain' in frame.columns:
        raw = frame['domain'].tolist()
    elif 'metadata' in frame.columns:
        raw = [meta.get('domain') if isinstance(meta, dict) else None
               for meta in frame['metadata']]
    else:
        raw = [None] * n_rows
    # `value != value` is only true for NaN, pandas' missing-value marker.
    return ['unknown' if value is None or value != value else str(value)
            for value in raw]


n_docs = len(documents)
if len(document_embeddings) != n_docs:
    # Points are matched to documents by row index, so a mismatch would
    # silently colour the wrong points.
    raise ValueError(
        f"document_embeddings has {len(document_embeddings)} rows but there are "
        f"{n_docs} documents; one embedding row per document is required."
    )
if n_docs < 2:
    raise ValueError("t-SNE needs at least two documents.")

# t-SNE dimensionality reduction. Perplexity must be smaller than the number
# of samples, so it is capped for small corpora. random_state makes the
# layout reproducible.
perplexity = min(30, n_docs - 1)
tsne_kwargs = dict(n_components=2, perplexity=perplexity, random_state=42)
try:
    # Newer scikit-learn names the iteration cap `max_iter`.
    tsne = TSNE(max_iter=300, **tsne_kwargs)
except TypeError:
    # Older scikit-learn only accepts the previous name `n_iter`.
    tsne = TSNE(n_iter=300, **tsne_kwargs)
reduced_embeddings = tsne.fit_transform(document_embeddings)

# Assign one colour per domain; sorting keeps the mapping stable across runs.
domains = _domain_labels(data, n_docs)
unique_domains = sorted(set(domains))
colors = cm.rainbow(np.linspace(0, 1, len(unique_domains)))
domain_to_color = dict(zip(unique_domains, colors))

# t-SNE scatter plot: one scatter call per domain, which also gives exactly
# one legend entry per domain without de-duplicating handles.
fig, ax = plt.subplots(figsize=(12, 8))
for domain in unique_domains:
    rows = [i for i, label in enumerate(domains) if label == domain]
    ax.scatter(
        reduced_embeddings[rows, 0],
        reduced_embeddings[rows, 1],
        color=[domain_to_color[domain]],  # wrapped in a list: unambiguously one RGBA colour
        label=domain,
        alpha=0.6,
        s=10,
    )

# Legend and figure cosmetics.
ax.legend(title="Domains", bbox_to_anchor=(1.05, 1), loc='upper left', fontsize='small')
ax.set_title('ColBERT Document Embeddings Visualization', fontsize=16)
ax.set_xlabel('t-SNE Axis 1', fontsize=14)
ax.set_ylabel('t-SNE Axis 2', fontsize=14)
ax.grid(True)
fig.tight_layout()
plt.show()