# LLM Expansion

In [None]:
# Quietly (-q) install the libraries this notebook imports; bitsandbytes is
# what backs the 4-bit quantized model load configured further down.
!pip install -q transformers accelerate bitsandbytes torch
# Show the GPU/driver status of the current runtime before loading the model.
!nvidia-smi

In [None]:
from google.colab import drive, userdata

# Hugging Face access token, looked up under the 'HF_TOKEN' key via
# google.colab.userdata; passed to from_pretrained() when loading the model.
HF_TOKEN = userdata.get('HF_TOKEN')

# Mount Google Drive so the input collection and the outputs under
# /content/drive are reachable as ordinary files.
drive.mount('/content/drive')

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"

# Quantization settings: 4-bit weights, float16 for the compute dtype.
quant = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# Tokenizer for the same checkpoint; the EOS token doubles as the pad token.
tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)
tokenizer.pad_token = tokenizer.eos_token

# Quantized causal LM; device placement is delegated to device_map="auto".
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    token=HF_TOKEN,
    quantization_config=quant,
    torch_dtype=torch.float16,
    device_map="auto",
)
print("Model loaded")

In [None]:
# Hand-written Llama 3 chat prompt: one user turn carrying the instruction and
# the passage, followed by an opened assistant turn for the model to complete.
# Note that it already starts with the <|begin_of_text|> (BOS) token.
PROMPT = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Add 2-3 factual sentences to expand this passage. Only add information that is factually consistent.

Passage: {text}

Expansion:<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""


def expand(text, max_chars=500, max_new_tokens=150, temperature=0.7):
    """Generate a short expansion of ``text`` with the loaded model.

    Args:
        text: Passage to expand. Only its first ``max_chars`` characters are
            placed into the prompt.
        max_chars: Character budget for the passage inside the prompt.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature (sampling is always enabled).

    Returns:
        The newly generated text only (the prompt is cut off), decoded without
        special tokens and stripped of surrounding whitespace. The passage
        itself is not included.

    Relies on the module-level ``tokenizer`` and ``model``. Output is sampled,
    so repeated calls on the same passage can return different text.
    """
    prompt = PROMPT.format(text=text[:max_chars])

    # PROMPT already contains <|begin_of_text|>. With the default
    # add_special_tokens=True the tokenizer would prepend a second BOS token,
    # so special-token insertion is switched off here.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        add_special_tokens=False,
    ).to(model.device)

    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
        )

    # generate() returns prompt + continuation; keep only the continuation.
    prompt_len = inputs['input_ids'].shape[1]
    response = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
    return response.strip()

In [None]:
# Load documents
# Each line is expected to look like "<doc_id>\t<text>"; only the first tab
# splits, so any later tabs stay inside the text. Lines that do not yield two
# fields after stripping are skipped -- this includes rows whose text is empty,
# because strip() removes the trailing tab before the split.
docs = []
# Read as UTF-8 explicitly instead of relying on the runtime's locale default.
with open('/content/drive/MyDrive/hqf_de/collection_100k.tsv', encoding='utf-8') as f:
    for line in f:
        parts = line.strip().split('\t', 1)
        if len(parts) == 2:
            docs.append((parts[0], parts[1]))
print(f"Loaded {len(docs):,} documents")

In [None]:
from tqdm import tqdm
import time

output_path = '/content/drive/MyDrive/hqf_de/expanded_100k.tsv'
checkpoint_path = '/content/drive/MyDrive/hqf_de/checkpoint.txt'

# Number of documents between two checkpoints.
CHECKPOINT_EVERY = 1000

# Resume from checkpoint
# The checkpoint file holds a single integer: how many documents from `docs`
# have already been written to the output file.
start = 0
try:
    with open(checkpoint_path) as f:
        start = int(f.read().strip())
    print(f"Resuming from {start}")
except FileNotFoundError:
    # No checkpoint yet: first run, start from the beginning.
    pass
except ValueError:
    # Unreadable checkpoint: keep the old behaviour (start at 0) but say so,
    # because the output file is opened in append mode and may already hold rows.
    print(f"WARNING: could not parse {checkpoint_path}; starting from 0 "
          f"(existing rows in {output_path} will be duplicated)")

# Process
# Rows are collected in memory and appended to the output file only when a
# checkpoint is written. The output file and the checkpoint therefore advance
# together, and an interrupted run can be resumed without duplicating the rows
# produced since the last checkpoint.
t0 = time.time()
failed = 0
pending = []
with open(output_path, 'a', encoding='utf-8') as out:
    for i, (doc_id, text) in enumerate(tqdm(docs[start:], initial=start, total=len(docs))):
        try:
            expansion = expand(text)
        except Exception as e:
            # Best effort: fall back to the unexpanded passage, but report it.
            # KeyboardInterrupt is not an Exception, so the run stays interruptible.
            failed += 1
            if failed <= 5:
                print(f"expand() failed for {doc_id}: {e!r}")
            expanded = text
        else:
            # Generated text may contain newlines or tabs, which would break
            # the one-row-per-line TSV layout; collapse all whitespace runs.
            expansion = " ".join(expansion.split())
            expanded = text + " " + expansion

        pending.append(f"{doc_id}\t{expanded}\n")

        # Checkpoint every CHECKPOINT_EVERY documents, and once more at the
        # very end so a final partial batch is also recorded as done.
        done = start + i + 1
        if done % CHECKPOINT_EVERY == 0 or done == len(docs):
            out.writelines(pending)
            out.flush()
            pending.clear()
            with open(checkpoint_path, 'w') as cp:
                cp.write(str(done))
            elapsed = time.time() - t0
            rate = (i + 1) / elapsed
            print(f"Checkpoint: {done}, Rate: {rate:.2f} docs/sec")

if failed:
    print(f"{failed:,} documents were written without expansion because expand() failed")
print(f"\nDone! Total time: {(time.time() - t0) / 3600:.1f} hours")