# Doc2Query

In [None]:
!pip install -q transformers torch
!nvidia-smi

In [None]:
# Mount Google Drive so the input collection and the output TSV under
# MyDrive/hqf_de are reachable from this Colab runtime.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

# Hugging Face Hub id shared by the tokenizer and the model.
MODEL_NAME = "castorini/doc2query-t5-base-msmarco"
# Prefer the GPU, but fall back to CPU so the notebook still runs (slowly)
# on a runtime without CUDA instead of crashing at load time.
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, legacy=False)
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME).to(device)
model.eval()  # inference only: disables dropout
print(f"Model loaded on {device}")

In [None]:
# Load documents: each line is "<doc_id>\t<text>". Only the first tab separates
# the id from the text; any later tabs remain part of the text.
docs = []
skipped = 0  # lines without a tab separator (no text to expand)
with open('/content/drive/MyDrive/hqf_de/collection_100k.tsv', encoding='utf-8') as f:
    for line in f:
        doc_id, sep, text = line.strip().partition('\t')
        if sep:
            docs.append((doc_id, text))
        else:
            skipped += 1
print(f"Loaded {len(docs):,} documents")
if skipped:
    print(f"Skipped {skipped:,} malformed lines (no tab separator)")

In [None]:
import time
import sys

BATCH_SIZE = 8
NUM_QUERIES = 3
PROGRESS_EVERY = 5000  # print a progress line each time this many more docs are done
output_path = '/content/drive/MyDrive/hqf_de/doc2query_100k.tsv'

results = []
total = len(docs)
start = time.perf_counter()

# Stream each expanded document to disk as it is produced. If the Colab
# session disconnects mid-run, everything generated so far is kept instead
# of being lost with the in-memory buffer.
with open(output_path, 'w', encoding='utf-8') as out:
    for i in range(0, total, BATCH_SIZE):
        batch = docs[i:i + BATCH_SIZE]
        texts = [text for _, text in batch]

        # Send inputs to wherever the model lives rather than hard-coding "cuda".
        inputs = tokenizer(
            texts, max_length=512, truncation=True, padding=True, return_tensors="pt"
        ).to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=64,
                num_beams=NUM_QUERIES,
                num_return_sequences=NUM_QUERIES,
                early_stopping=True,
            )

        decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)

        # generate() returns NUM_QUERIES sequences per input, grouped consecutively.
        for j, (doc_id, text) in enumerate(batch):
            queries = decoded[j * NUM_QUERIES:(j + 1) * NUM_QUERIES]
            # dict.fromkeys drops duplicate queries while keeping beam order.
            # set() would order them by per-process string hashes, making the
            # output file differ from run to run.
            expanded = text + " " + " ".join(dict.fromkeys(queries))
            results.append((doc_id, expanded))
            out.write(f"{doc_id}\t{expanded}\n")

        # Count the docs actually processed; the final batch may be short.
        done = i + len(batch)
        # Report whenever a PROGRESS_EVERY boundary is crossed (this works for
        # any BATCH_SIZE, not only divisors of PROGRESS_EVERY) and on the last batch.
        if done // PROGRESS_EVERY > i // PROGRESS_EVERY or done == total:
            out.flush()
            elapsed = time.perf_counter() - start
            rate = done / elapsed if elapsed > 0 else 0.0
            eta = (total - done) / rate if rate > 0 else 0.0
            print(f"[{done / total * 100:5.1f}%] {done:,}/{total:,} | {rate:.1f} docs/sec | ETA: {eta / 60:.1f} min")
            sys.stdout.flush()

print(f"\nDone! {(time.perf_counter() - start) / 60:.1f} min total")