In [None]:
import os
import json
import torch
import random

from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import DPRQuestionEncoder, DPRContextEncoder, DPRQuestionEncoderTokenizer, DPRContextEncoderTokenizer, Trainer, TrainingArguments

In [3]:
# Load the processed corpus: one JSON document record per line (JSONL).
DATA_PATH = '../dataset/processed_documents_queries.jsonl'

dataset = []
with open(DATA_PATH, 'r', encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        if not line:
            # Skip blank lines (e.g. an extra empty line at the end of the file);
            # json.loads('') would otherwise raise JSONDecodeError.
            continue
        dataset.append(json.loads(line))

In [4]:
# Seed the RNGs so negative sampling below, the later shuffle of the combined
# sample list, and torch-driven randomness (e.g. DataLoader shuffling) are
# repeatable when the notebook is run top to bottom on a fresh kernel.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Map each document ID to its content.
doc_dict = {doc['docid']: doc['content'] for doc in dataset}

# docid of every document, in file order.
all_docids = [doc['docid'] for doc in dataset]

# 2. Build the (query, passage, label) training pairs.

# A negative must come from a *different* document, so at least two distinct
# docids are required (otherwise no valid negative exists).
if all_docids and len(set(all_docids)) < 2:
    raise ValueError("Need at least two distinct docids to sample negatives.")

positive_samples = []  # query paired with the document it was written for (label 1)
negative_samples = []  # same query paired with a randomly chosen other document (label 0)

for doc in dataset:
    docid = doc['docid']
    content = doc['content']
    questions = [doc['question1'], doc['question2'], doc['question3']]

    for question in questions:
        # Positive sample.
        positive_samples.append({
            'query': question,
            'passage': content,
            'label': 1
        })

        # Negative sample: rejection-sample a docid different from this one.
        # This is uniform over all entries of all_docids that differ from docid
        # (the same distribution as filtering the list first), but avoids
        # rebuilding an O(n) list for every question, which made the loop O(n^2).
        negative_docid = random.choice(all_docids)
        while negative_docid == docid:
            negative_docid = random.choice(all_docids)
        negative_content = doc_dict[negative_docid]

        negative_samples.append({
            'query': question,
            'passage': negative_content,
            'label': 0
        })

In [5]:
# Merge positives and negatives into one training list and interleave them.
# NOTE: this rebinds `dataset`, replacing the raw document records loaded
# earlier, so the sample-building cell cannot be re-run after this one.
dataset = list(positive_samples)
dataset.extend(negative_samples)
random.shuffle(dataset)

In [6]:
class RetrievalDataset(Dataset):
    """Map-style dataset of (query, passage, label) pairs for bi-encoder training.

    Each element of ``data`` is a dict with ``'query'``, ``'passage'`` and
    ``'label'`` keys. Queries and passages are tokenized by their own
    tokenizers and padded/truncated to a fixed length, so the default
    DataLoader collate function can stack items into batches.

    Args:
        data: indexable sequence of sample dicts.
        question_tokenizer: tokenizer callable applied to queries.
        context_tokenizer: tokenizer callable applied to passages.
        max_length: pad/truncate length for passages, and for queries when
            ``query_max_length`` is not given.
        query_max_length: optional separate pad/truncate length for queries.
            Queries are typically much shorter than passages, so a smaller
            value avoids encoding hundreds of padding tokens per query.
            Defaults to ``max_length`` (the original behaviour).
    """

    def __init__(self, data, question_tokenizer, context_tokenizer, max_length=512, query_max_length=None):
        self.data = data
        self.question_tokenizer = question_tokenizer
        self.context_tokenizer = context_tokenizer
        self.max_length = max_length
        self.query_max_length = max_length if query_max_length is None else query_max_length

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _encode(tokenizer, text, length):
        """Tokenize one string into fixed-length 1-D (input_ids, attention_mask) tensors."""
        encoded = tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=length,
            return_tensors='pt'
        )
        # return_tensors='pt' yields a leading batch dimension of size 1;
        # drop it so the DataLoader can stack per-item tensors itself.
        return encoded['input_ids'].squeeze(0), encoded['attention_mask'].squeeze(0)

    def __getitem__(self, idx):
        """Return the tokenized query/passage tensors and a float label for sample ``idx``."""
        item = self.data[idx]

        query_ids, query_mask = self._encode(self.question_tokenizer, item['query'], self.query_max_length)
        passage_ids, passage_mask = self._encode(self.context_tokenizer, item['passage'], self.max_length)

        return {
            'query_input_ids': query_ids,
            'query_attention_mask': query_mask,
            'passage_input_ids': passage_ids,
            'passage_attention_mask': passage_mask,
            # Float label so it can be fed directly to BCEWithLogitsLoss.
            'label': torch.tensor(item['label'], dtype=torch.float)
        }

In [None]:
# Pretrained Korean DPR bi-encoder: separate checkpoints for the question and context towers.
question_model_name = "snumin44/biencoder-ko-bert-question"
context_model_name = "snumin44/biencoder-ko-bert-context"

# Each tower has its own tokenizer, loaded from its matching checkpoint.
question_tokenizer, context_tokenizer = (
    DPRQuestionEncoderTokenizer.from_pretrained(question_model_name),
    DPRContextEncoderTokenizer.from_pretrained(context_model_name),
)

# Wrap the (query, passage, label) list; shuffle=True reshuffles on every epoch.
train_dataset = RetrievalDataset(dataset, question_tokenizer, context_tokenizer)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True)

In [None]:
# Prefer the GPU when one is available, otherwise run on CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Load both towers and move them onto the selected device.
question_encoder = DPRQuestionEncoder.from_pretrained(question_model_name).to(device)
context_encoder = DPRContextEncoder.from_pretrained(context_model_name).to(device)

# The query/passage dot product is treated as a logit against 0/1 labels.
criterion = torch.nn.BCEWithLogitsLoss()

# A single AdamW optimizer updates the parameters of both towers jointly.
trainable_params = [*question_encoder.parameters(), *context_encoder.parameters()]
optimizer = torch.optim.AdamW(trainable_params, lr=2e-5)

In [None]:
epochs = 3

for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")
    # Both towers are optimized jointly, so both are put in training mode.
    question_encoder.train()
    context_encoder.train()

    total_loss = 0.0
    num_batches = 0

    for batch in tqdm(train_dataloader):
        optimizer.zero_grad()

        # Move the batch onto the training device.
        query_input_ids = batch['query_input_ids'].to(device)
        query_attention_mask = batch['query_attention_mask'].to(device)
        passage_input_ids = batch['passage_input_ids'].to(device)
        passage_attention_mask = batch['passage_attention_mask'].to(device)
        labels = batch['label'].to(device)

        # Encode queries and passages with their respective towers.
        query_embeddings = question_encoder(
            input_ids=query_input_ids,
            attention_mask=query_attention_mask
        ).pooler_output  # [batch_size, hidden_size]

        passage_embeddings = context_encoder(
            input_ids=passage_input_ids,
            attention_mask=passage_attention_mask
        ).pooler_output  # [batch_size, hidden_size]

        # Relevance logit: dot product of each query with its own paired passage.
        scores = torch.sum(query_embeddings * passage_embeddings, dim=-1)  # [batch_size]

        # Binary loss against the 0/1 pair labels.
        loss = criterion(scores, labels)

        # Backpropagate and take an optimizer step.
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Guard against an empty DataLoader, which would otherwise raise ZeroDivisionError.
    avg_loss = total_loss / max(num_batches, 1)
    print(f"Average Loss: {avg_loss:.4f}")

In [None]:
# Persist each fine-tuned tower together with its tokenizer, so each directory
# can be reloaded on its own via `from_pretrained(<dir>)` for both the model
# and the tokenizer, without needing the original hub checkpoint names.
os.makedirs("./dpr", exist_ok=True)
question_encoder.save_pretrained('./dpr/question_encoder')
question_tokenizer.save_pretrained('./dpr/question_encoder')
context_encoder.save_pretrained('./dpr/context_encoder')
context_tokenizer.save_pretrained('./dpr/context_encoder')