In [1]:
# Requires transformers>=4.51.0
import torch
import json
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm

def format_instruction(instruction, query, doc):
    if instruction is None:
        instruction = 'Given a web search query, retrieve relevant passages that answer the query'
    output = "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(instruction=instruction,query=query, doc=doc)
    return output

def process_batch_inputs(pairs):
    """Process a batch of query-document pairs"""
    inputs = tokenizer(
        pairs, padding=False, truncation='longest_first',
        return_attention_mask=False, max_length=max_length - len(prefix_tokens) - len(suffix_tokens)
    )
    for i, ele in enumerate(inputs['input_ids']):
        inputs['input_ids'][i] = prefix_tokens + ele + suffix_tokens
    inputs = tokenizer.pad(inputs, padding=True, return_tensors="pt", max_length=max_length)
    for key in inputs:
        inputs[key] = inputs[key].to(model.device)
    return inputs

def process_single_input(pair):
    """Process a single query-document pair"""
    return process_batch_inputs([pair])

@torch.no_grad()
def compute_batch_scores(pairs, batch_size=4):
    """Compute scores for a batch of query-document pairs
    Uses sum of probabilities for all Yes/No token variants to handle token space issue
    """
    all_scores = []
    
    for i in range(0, len(pairs), batch_size):
        batch_pairs = pairs[i:i+batch_size]
        inputs = process_batch_inputs(batch_pairs)
        logits = model(**inputs).logits[:, -1, :]
        
        # Tính tổng xác suất cho tất cả các biến thể Yes và No
        probs = torch.nn.functional.softmax(logits, dim=-1)
        
        batch_scores = []
        for j in range(len(batch_pairs)):
            # Cộng dồn xác suất của tất cả token "Yes" variants
            true_prob_sum = 0.0
            for token_id in token_true_ids:
                if token_id is not None and token_id < probs.shape[1]:
                    true_prob_sum += probs[j, token_id].item()
            
            # Cộng dồn xác suất của tất cả token "No" variants
            false_prob_sum = 0.0
            for token_id in token_false_ids:
                if token_id is not None and token_id < probs.shape[1]:
                    false_prob_sum += probs[j, token_id].item()
            
            # Tính score: P(Yes) / (P(Yes) + P(No))
            total_prob = true_prob_sum + false_prob_sum
            if total_prob > 0:
                score = true_prob_sum / total_prob
            else:
                score = 0.5
            
            batch_scores.append(score)
        
        all_scores.extend(batch_scores)
        
        # Free memory
        del inputs, logits, probs
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    
    return all_scores

@torch.no_grad()
def compute_single_score(pair, debug=False):
    """Compute score for a single query-document pair (wrapper for batch function)"""
    scores = compute_batch_scores([pair], batch_size=1)
    return scores[0]

print("Loading model...")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU device: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
    torch.cuda.empty_cache()
    print("Cleared GPU cache")

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Reranker-4B", padding_side='left')

# Load model on GPU if available
if torch.cuda.is_available():
    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen3-Reranker-4B",
        torch_dtype=torch.float16,
        device_map="auto"
    ).eval()
    print("Model loaded on GPU with float16")
else:
    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-Reranker-4B").eval()
    print("Model loaded on CPU")

# 1. Fix Token Space Issue - Dùng encode để lấy token IDs chính xác hơn
# Model có thể muốn sinh "Yes" hoặc " Yes" (có dấu cách trước)
token_yes_list = ["Yes", " Yes", "yes", " yes", "Có", " Có", "có", " có"]  # Thử cả tiếng Việt
token_no_list = ["No", " No", "no", " no", "Không", " Không", "không", " không"]

# Lấy tất cả token IDs hợp lệ bằng cách encode (chính xác hơn convert_tokens_to_ids)
token_true_ids = []
token_false_ids = []

print(f"\nToken ID check (using encode method):")
for token in token_yes_list:
    # Dùng encode và lấy token ID đầu tiên
    encoded = tokenizer.encode(token, add_special_tokens=False)
    if encoded:
        token_id = encoded[0]  # Lấy token ID đầu tiên
        if token_id not in token_true_ids:
            token_true_ids.append(token_id)
            decoded = tokenizer.decode([token_id])
            print(f"  Found 'Yes' variant: '{token}' -> ID: {token_id} (decoded: '{decoded}')")

for token in token_no_list:
    encoded = tokenizer.encode(token, add_special_tokens=False)
    if encoded:
        token_id = encoded[0]
        if token_id not in token_false_ids:
            token_false_ids.append(token_id)
            decoded = tokenizer.decode([token_id])
            print(f"  Found 'No' variant: '{token}' -> ID: {token_id} (decoded: '{decoded}')")

if not token_true_ids or not token_false_ids:
    print("\n⚠️ Warning: No valid token IDs found! Trying fallback...")
    # Fallback: thử encode trực tiếp
    try:
        yes_encoded = tokenizer.encode("Yes", add_special_tokens=False)
        no_encoded = tokenizer.encode("No", add_special_tokens=False)
        if yes_encoded:
            token_true_ids = [yes_encoded[0]]
        if no_encoded:
            token_false_ids = [no_encoded[0]]
    except:
        pass

print(f"\nUsing {len(token_true_ids)} 'Yes' token(s) and {len(token_false_ids)} 'No' token(s)")
if token_true_ids:
    print(f"  Yes token IDs: {token_true_ids}")
if token_false_ids:
    print(f"  No token IDs: {token_false_ids}")

max_length = 8192

prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
# 2. Sửa Suffix - Đơn giản hóa (bỏ <think>)
suffix = "<|im_end|>\n<|im_start|>assistant\n"
prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
suffix_tokens = tokenizer.encode(suffix, add_special_tokens=False)
        
task = 'Given a web search query, retrieve relevant passages that answer the query'

print(f"\nPrefix tokens length: {len(prefix_tokens)}")
print(f"Suffix tokens length: {len(suffix_tokens)}")
print("Model loaded successfully!")

Loading model...
CUDA available: True
GPU device: Tesla T4
GPU memory: 14.74 GB
Cleared GPU cache


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/613 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/727 [00:00<?, ?B/s]

2025-12-07 12:53:28.638448: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1765112008.825553      47 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1765112008.878319      47 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.06G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.98G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

Model loaded on GPU with float16

Token ID check (using encode method):
  Found 'Yes' variant: 'Yes' -> ID: 9454 (decoded: 'Yes')
  Found 'Yes' variant: ' Yes' -> ID: 7414 (decoded: ' Yes')
  Found 'Yes' variant: 'yes' -> ID: 9693 (decoded: 'yes')
  Found 'Yes' variant: ' yes' -> ID: 9834 (decoded: ' yes')
  Found 'Yes' variant: 'Có' -> ID: 34 (decoded: 'C')
  Found 'Yes' variant: ' Có' -> ID: 129016 (decoded: ' Có')
  Found 'Yes' variant: 'có' -> ID: 129133 (decoded: 'có')
  Found 'Yes' variant: ' có' -> ID: 28776 (decoded: ' có')
  Found 'No' variant: 'No' -> ID: 2753 (decoded: 'No')
  Found 'No' variant: ' No' -> ID: 2308 (decoded: ' No')
  Found 'No' variant: 'no' -> ID: 2152 (decoded: 'no')
  Found 'No' variant: ' no' -> ID: 902 (decoded: ' no')
  Found 'No' variant: 'Không' -> ID: 142899 (decoded: 'Không')
  Found 'No' variant: ' Không' -> ID: 129182 (decoded: ' Không')
  Found 'No' variant: 'không' -> ID: 30664 (decoded: 'kh')
  Found 'No' variant: ' không' -> ID: 53037 (decoded

In [2]:
# Load data files
print("Loading data files...")

# Load legal corpus
LEGAL_CORPUS_PATH = "/kaggle/input/vlqa-dataset/legal_corpus.json"
with open(LEGAL_CORPUS_PATH, 'r', encoding='utf-8') as f:
    legal_corpus = json.load(f)

# Build article_id to content mapping
article_id_to_content = {}
for doc in legal_corpus:
    for article in doc.get('content', []):
        aid = article.get('aid')
        content = article.get('content_Article', '')
        if aid is not None and content:
            # Convert aid to string for consistency, but also support int lookup
            aid_str = str(aid)
            article_id_to_content[aid_str] = content
            # Also store with int key if aid is int
            if isinstance(aid, int):
                article_id_to_content[aid] = content

print(f"Loaded {len(article_id_to_content)} articles from legal corpus")

# Load private_test.json to get questions
PRIVATE_TEST_PATH = "/kaggle/input/test-vlqa/private_test.json"
with open(PRIVATE_TEST_PATH, 'r', encoding='utf-8') as f:
    private_test_data = json.load(f)

# Build qid to question mapping
qid_to_question = {}
for item in private_test_data:
    qid = item.get('qid')
    question = item.get('question', '')
    if qid is not None and question:
        qid_to_question[qid] = question

print(f"Loaded {len(qid_to_question)} questions from private_test.json")

# Load results file
RESULTS_PATH = "/kaggle/input/reranker-output/results_ensemble_k20.json"
with open(RESULTS_PATH, 'r', encoding='utf-8') as f:
    results_data = json.load(f)

print(f"Loaded {len(results_data)} samples from results file")
print(f"First sample: qid={results_data[0]['qid']}, num_laws={len(results_data[0]['relevant_laws'])}")

Loading data files...
Loaded 119270 articles from legal corpus
Loaded 627 questions from private_test.json
Loaded 627 samples from results file
First sample: qid=1, num_laws=20


In [3]:
# Process all samples and generate top-10 results
print("="*80)
print("Processing all samples...")
print("="*80)

final_results = []
final_results_with_scores = []  # Results with scores included
top_k = 10  # Top-10 results
batch_size = 1  # Batch size for faster processing (adjust based on GPU memory)

for sample_idx, sample in enumerate(tqdm(results_data, desc="Processing samples")):
    qid = sample['qid']
    relevant_laws = sample['relevant_laws']  # List of 20 aid values
    
    # Load question from private_test.json based on qid
    query = qid_to_question.get(qid, "")
    if not query:
        # Try with int/string conversion
        query = qid_to_question.get(int(qid), "") or qid_to_question.get(str(qid), "")
    
    if not query:
        # Still add empty result
        final_results.append({
            'qid': qid,
            'relevant_laws': []
        })
        final_results_with_scores.append({
            'qid': qid,
            'relevant_laws': []
        })
        continue
    
    # Get documents for the 20 relevant laws
    documents = []
    valid_aids = []
    for aid in relevant_laws:
        content = article_id_to_content.get(aid) or article_id_to_content.get(str(aid))
        if content:
            documents.append(content)
            valid_aids.append(aid)
    
    if not documents:
        final_results.append({
            'qid': qid,
            'relevant_laws': []
        })
        final_results_with_scores.append({
            'qid': qid,
            'relevant_laws': []
        })
        continue
    
    # Format pairs for reranking
    pairs = [format_instruction(task, query, doc) for doc in documents]
    
    # Compute scores in batches (faster!)
    scores = compute_batch_scores(pairs, batch_size=batch_size)
    
    # Create results with aid and score
    results_with_scores = []
    for aid, score in zip(valid_aids, scores):
        results_with_scores.append({
            'aid': aid,
            'score': score
        })
    
    # Sort by score descending and take top-k
    results_with_scores.sort(key=lambda x: x['score'], reverse=True)
    top_k_aids = [item['aid'] for item in results_with_scores[:top_k]]
    
    # Add to final results (without scores)
    final_results.append({
        'qid': qid,
        'relevant_laws': top_k_aids
    })
    
    # Add to final results with scores
    final_results_with_scores.append({
        'qid': qid,
        'relevant_laws': [{'aid': item['aid'], 'score': item['score']} for item in results_with_scores[:top_k]]
    })

print(f"\n{'='*80}")
print(f"Processing complete! Processed {len(final_results)} samples.")
print(f"{'='*80}")

Processing all samples...


Processing samples:   0%|          | 0/627 [00:00<?, ?it/s]You're using a Qwen2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Processing samples: 100%|██████████| 627/627 [1:19:33<00:00,  7.61s/it]


Processing complete! Processed 627 samples.





In [4]:
# Save results to file (without scores - for submission)
print("\n" + "="*80)
print("Saving results...")
print("="*80)

output_file = "results_qwen_reranker_top10.json"
with open(output_file, 'w', encoding='utf-8') as f:
    json.dump(final_results, f, ensure_ascii=False, indent=2)

print(f"Results saved to: {output_file}")
print(f"Total samples: {len(final_results)}")
print(f"Format: [{{'qid': int, 'relevant_laws': [int, ...]}}, ...]")
print(f"Each sample has top-{top_k} relevant laws")

# Save results with scores (for analysis)
output_file_with_scores = "results_qwen_reranker_top10_with_scores.json"
with open(output_file_with_scores, 'w', encoding='utf-8') as f:
    json.dump(final_results_with_scores, f, ensure_ascii=False, indent=2)

print(f"\nResults with scores saved to: {output_file_with_scores}")
print(f"Format: [{{'qid': int, 'relevant_laws': [{{'aid': int, 'score': float}}, ...]}}, ...]")

# Show statistics
samples_with_results = sum(1 for r in final_results if len(r.get('relevant_laws', [])) > 0)
avg_laws = sum(len(r.get('relevant_laws', [])) for r in final_results) / len(final_results) if final_results else 0

print(f"\nStatistics:")
print(f"  Samples with results: {samples_with_results}/{len(final_results)} ({samples_with_results/len(final_results)*100:.1f}%)")
print(f"  Average laws per sample: {avg_laws:.2f}")
print("="*80)


Saving results...
Results saved to: results_qwen_reranker_top10.json
Total samples: 627
Format: [{'qid': int, 'relevant_laws': [int, ...]}, ...]
Each sample has top-10 relevant laws

Results with scores saved to: results_qwen_reranker_top10_with_scores.json
Format: [{'qid': int, 'relevant_laws': [{'aid': int, 'score': float}, ...]}, ...]

Statistics:
  Samples with results: 627/627 (100.0%)
  Average laws per sample: 10.00


In [5]:
# Preview first few results
print("\n" + "="*80)
print("Preview of first 5 results:")
print("="*80)
for i, result in enumerate(final_results[:5]):
    print(f"\n{i+1}. qid={result['qid']}, num_laws={len(result['relevant_laws'])}, top_laws={result['relevant_laws'][:5]}")
print("="*80)


Preview of first 5 results:

1. qid=1, num_laws=10, top_laws=[2759, 58057, 420, 124, 122]

2. qid=58, num_laws=10, top_laws=[1473, 1480, 504, 1483, 515]

3. qid=60, num_laws=10, top_laws=[643, 1699, 2039, 1453, 1451]

4. qid=114, num_laws=10, top_laws=[12801, 150, 142, 148, 149]

5. qid=124, num_laws=10, top_laws=[36, 2325, 53106, 9189, 2145]


In [6]:
# Optional: Show sample statistics
print("\n" + "="*80)
print("Sample Statistics")
print("="*80)

# Count distribution of number of laws
law_counts = {}
for result in final_results:
    count = len(result.get('relevant_laws', []))
    law_counts[count] = law_counts.get(count, 0) + 1

print("\nDistribution of number of laws per sample:")
for count in sorted(law_counts.keys()):
    print(f"  {count} laws: {law_counts[count]} samples")

print("="*80)


Sample Statistics

Distribution of number of laws per sample:
  10 laws: 627 samples
