In [6]:
# ==== 0) 환경 준비 ====
import os, random, math
import pandas as pd
import torch
import transformers

model_name = "jxm/cde-small-v2"
# trust_remote_code executes model.py downloaded from the Hub; as the warning printed by
# this cell shows, that file is re-fetched whenever the repo changes. Pin the revision to
# the commit transformers resolved for this run (the hash in the module path of the
# "Disabled ... dropout modules" log line) so the executed code cannot change underneath us.
model_revision = "4e1d021a6c3fd7ce8aa0a7204057eee5ae61d390"
model = transformers.AutoModel.from_pretrained(
    model_name, trust_remote_code=True, revision=model_revision
)
# cde-small-v2 is built on a ModernBERT backbone and must use the ModernBERT tokenizer.
# "bert-base-uncased" is the tokenizer of cde-small-v1: its ids come from a different
# vocabulary, so using it with v2 silently produces meaningless embeddings.
tokenizer = transformers.AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model.eval()
# Inference only: disable autograd globally for the rest of the notebook.
torch.set_grad_enabled(False)

# Task prefixes prepended to queries vs. documents before tokenization.
query_prefix = "search_query: "
document_prefix = "search_document: "

# Number of context (minicorpus) documents the model expects, read from the model
# config; falls back to 512 if the attribute is missing.
minicorpus_size = getattr(model.config, "transductive_corpus_size", 512)
print("CDE v2 minicorpus size:", minicorpus_size)

# ==== Data loading ====
# Documents come from the first column of the training CSV and queries from the first
# column of the validation CSV; missing cells are dropped, the rest cast to str.
df_train = pd.read_csv("/home/alpaco/sryang/Training.csv")
df_valid = pd.read_csv("/home/alpaco/sryang/validation.csv")

train_docs = df_train.iloc[:, 0].dropna().astype(str).tolist()
val_queries = df_valid.iloc[:, 0].dropna().astype(str).tolist()

# Fixed seed: every second-stage embedding is conditioned on the sampled minicorpus,
# so an unseeded sample makes the saved doc/query embeddings irreproducible.
minicorpus_seed = 42
batch_size = 32


def _sample_minicorpus(docs, size, seed):
    """Return exactly `size` documents from `docs` to use as the CDE context corpus.

    Samples without replacement (deterministically, via `seed`) when `docs` has at
    least `size` entries; otherwise repeats `docs` cyclically and truncates.

    Raises:
        ValueError: if `docs` is empty (no minicorpus can be built).
    """
    if not docs:
        raise ValueError("cannot build a minicorpus from an empty document list")
    if len(docs) >= size:
        return random.Random(seed).sample(docs, k=size)
    reps = math.ceil(size / len(docs))
    return (docs * reps)[:size]


def _embed(texts, prefix, max_length, dataset_embeddings=None):
    """Embed `texts` (each prefixed with `prefix`) in batches of `batch_size`.

    Each batch is tokenized separately, so it is padded only to its own longest
    sequence rather than to the longest text in the whole list, and only one batch
    of token tensors lives on the device at a time.

    With `dataset_embeddings=None` the first-stage model is used and its raw output
    is returned. Otherwise the second-stage model is conditioned on
    `dataset_embeddings` and each output row is L2-normalized.

    Returns the per-batch outputs concatenated along dim 0.

    Raises:
        ValueError: if `texts` is empty.
    """
    if not texts:
        raise ValueError("no texts to embed")
    chunks = []
    for start in range(0, len(texts), batch_size):
        batch = tokenizer(
            [prefix + t for t in texts[start:start + batch_size]],
            truncation=True, padding=True, max_length=max_length, return_tensors="pt",
        ).to(device)
        with torch.no_grad():
            if dataset_embeddings is None:
                emb = model.first_stage_model(**batch)
            else:
                emb = model.second_stage_model(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    dataset_embeddings=dataset_embeddings,
                )
                emb = torch.nn.functional.normalize(emb, p=2, dim=1)
        chunks.append(emb)
    return torch.cat(chunks, dim=0)


# `minicorpus_size` was already read from the model config in the setup section.
minicorpus_docs = _sample_minicorpus(train_docs, minicorpus_size, minicorpus_seed)

# ==== Stage 1: context embeddings of the minicorpus ====
dataset_embeddings = _embed(minicorpus_docs, document_prefix, max_length=768)

# ==== Stage 2: document embeddings ====
doc_embeddings = _embed(
    train_docs, document_prefix, max_length=768, dataset_embeddings=dataset_embeddings
)

# ==== Stage 2: query embeddings ====
query_embeddings = _embed(
    val_queries, query_prefix, max_length=512, dataset_embeddings=dataset_embeddings
)

print("문서 임베딩 크기:", doc_embeddings.shape)
print("질의 임베딩 크기:", query_embeddings.shape)


A new version of the following files was downloaded from https://huggingface.co/jxm/cde-small-v2:
- model.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


Disabled 23 dropout modules from model type <class 'transformers_modules.jxm.cde-small-v2.4e1d021a6c3fd7ce8aa0a7204057eee5ae61d390.model.BiEncoder'>
Disabled 46 dropout modules from model type <class 'transformers_modules.jxm.cde-small-v2.4e1d021a6c3fd7ce8aa0a7204057eee5ae61d390.model.ContextualDocumentEmbeddingTransformer'>
CDE v2 minicorpus size: 512




문서 임베딩 크기: torch.Size([51628, 768])
질의 임베딩 크기: torch.Size([6640, 768])


In [None]:
# Persist the embeddings as CPU tensors so they can be reloaded without a GPU.
out_dir = "/home/alpaco/sryang/embedding_result"
# torch.save does not create missing parent directories.
os.makedirs(out_dir, exist_ok=True)
torch.save(dataset_embeddings.cpu(), os.path.join(out_dir, "cde_minicorpus.pt"))
torch.save(doc_embeddings.cpu(), os.path.join(out_dir, "cde_doc.pt"))
torch.save(query_embeddings.cpu(), os.path.join(out_dir, "cde_query_emb.pt"))