<a href="https://colab.research.google.com/github/tanatet8/Colab_Script/blob/main/Reasoning_Format_Fix_And_QC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ============================================
# FORMAT FIXER COMPLETE VERSION - For Colab
# Fixes every problem found so far - clearly split into cells
# ============================================

# %% [markdown]
# # Format Fixer Complete - fix MD files so that their format is complete

# %%
# ============================================
# Cell 1: Mount Drive & Import Libraries
# ============================================

# Colab-only helper: mounts Google Drive at /content/drive so that the
# source file configured in Cell 2 can be read and the fixed copy written.
from google.colab import drive
drive.mount('/content/drive')

# Standard-library modules used by the cells below.
import re
import collections
from pathlib import Path

print("✅ Libraries imported")

In [None]:
# %%
# ============================================
# Cell 2: Configuration - edit the path here
# ============================================

# ✏️ Set the path of the markdown file to fix
SRC = '/content/drive/MyDrive/Dataset_Curation/causal_batch_01.md'

# Output path: same folder, "<name>_FIXED_COMPLETE<ext>".
# Only the file name is rewritten (never a ".md" that happens to appear in a
# directory name), and the result always differs from SRC, so the save step
# cannot overwrite the source even when SRC has no ".md" extension.
_src_path = Path(SRC)
DST_FIXED = str(
    _src_path.with_name(f"{_src_path.stem}_FIXED_COMPLETE{_src_path.suffix}")
)

# The batch id every block is expected to carry
BATCH_ID_EXPECT = 'causal_batch_01'

print("🎯 Configuration loaded")
print(f"Source: {SRC}")
print(f"Output: {DST_FIXED}")

# %%
# ============================================
# Cell 3: Define Mappings & Constants
# ============================================

# Alias -> canonical name table for reasoning_type values.
REASONING_TYPE_CANONICAL = dict([
    ('causal', 'causal_reasoning'),
    ('counterfactual', 'counterfactual_reasoning'),
    ('symbolic', 'symbolic_reasoning'),
    ('deductive', 'deductive_reasoning'),
    ('meta', 'meta_reasoning'),
    ('probabilistic', 'probabilistic_reasoning'),
    ('commonsense', 'commonsense_reasoning'),
    ('analogical', 'analogical_reasoning'),
    ('temporal', 'temporal_reasoning'),
    ('spatial', 'spatial_reasoning'),
])

# reasoning_type -> tier number, written as one group of names per tier and
# flattened into a single lookup table (insertion order: tier 1 first).
TIER_MAPPING = {
    name: tier
    for tier, names in (
        (1, (
            'symbolic_reasoning', 'deductive_reasoning',
            'if_then_only_if_iff', 'contrapositive_xor',
            'contradiction_trap', 'logic_trap', 'fallacy_detection',
            'symbolic_recursion', 'logic_tree', 'belief_modeling',
        )),
        (2, (
            'causal_reasoning', 'counterfactual_reasoning',
            'emotional_behavioral_cause', 'daily_life_reasoning',
            'temporal_reasoning', 'spatial_reasoning',
            'probabilistic_reasoning', 'commonsense_reasoning',
        )),
        (3, (
            'meta_reasoning', 'ambiguity_detection',
            'ambiguity_resolution', 'weak_evidence_uncertainty',
            'language_driven_inference', 'structural_analogy',
            'multi_hop_justification',
        )),
        (4, (
            'belief_revision', 'epistemic_reasoning',
            'self_consistency_logic', 'ontological_shift',
        )),
        (5, (
            'multi_agent_simulation', 'perspective_reasoning',
            'recursive_inference', 'perspective_clash',
        )),
        (6, (
            'moral_ambiguity_tradeoff', 'identity_loop_reasoning',
            'ethical_dilemma_decomposition', 'planning_goal_based_reasoning',
            'deontic_reasoning', 'narrative_causal_reasoning',
            'philosophical_logic', 'analogical_reasoning',
            'heuristic_reasoning', 'multi_lingual_reasoning',
        )),
    )
    for name in names
}

print(f"✅ Loaded {len(REASONING_TYPE_CANONICAL)} canonical mappings")
print(f"✅ Loaded {len(TIER_MAPPING)} tier mappings")

In [None]:
# %%
# ============================================
# Cell 4: Define Metadata Field Order
# ============================================

# Metadata fields every block must have, in output order.
REQUIRED_METADATA_FIELDS = """
    prompt_id
    batch_id
    reasoning_type
    sub_type
    difficulty
    language
    domain_context
    tier
    model_size
""".split()

# Optional metadata fields; when present they are written in this order.
OPTIONAL_METADATA_FIELDS = """
    contains_statistics
    has_numerical_estimate
    requires_visualization
    symbolic_risk
    contains_fallacy_risk
    confidence_level_expected
    is_behavior_driven
    concept_tags
    fallacy
    fallacy_type
    chain_depth
    tone_style
    self_critique
    belief_tracking
    eval_standard
    reasoning_path_trace
""".split()

print(f"✅ Required fields: {len(REQUIRED_METADATA_FIELDS)}")
print(f"✅ Optional fields: {len(OPTIONAL_METADATA_FIELDS)}")

In [None]:
# %%
# ============================================
# Cell 5: Helper Functions - Basic
# ============================================

def normalize_reasoning_type(rtype: str) -> str:
    """Return the canonical reasoning-type name for ``rtype``.

    The value is stripped and lower-cased, then looked up in
    ``REASONING_TYPE_CANONICAL``. Values that are not a known alias (already
    canonical, or unknown) come back in their stripped, lower-cased form.
    """
    cleaned = rtype.strip().lower()
    return REASONING_TYPE_CANONICAL.get(cleaned, cleaned)

def infer_tier(reasoning_type: str) -> str:
    """Look up the tier for a reasoning type and return it as a string.

    The stripped, lower-cased name is tried first; if it is not in
    ``TIER_MAPPING`` its canonical form (via ``normalize_reasoning_type``) is
    tried next. When neither is known the default tier ``'2'`` is returned.
    """
    key = reasoning_type.strip().lower()

    tier = TIER_MAPPING.get(key)
    if tier is None:
        # Not a direct hit: the name may be an alias of a mapped type.
        tier = TIER_MAPPING.get(normalize_reasoning_type(key))

    return '2' if tier is None else str(tier)

def parse_metadata(meta_text: str) -> dict:
    """Turn ``key: value`` lines into a dictionary.

    Each line is split at its first colon; key and value are stripped.
    Lines without a colon are ignored. Values spelling ``true``/``false`` in
    any letter case are stored lower-cased. A key that appears more than once
    keeps its last value.
    """
    parsed = {}

    for raw_line in meta_text.split('\n'):
        name, colon, rest = raw_line.partition(':')
        if not colon:
            continue

        rest = rest.strip()
        lowered = rest.lower()
        # Booleans are normalised to lower case; everything else is kept verbatim.
        parsed[name.strip()] = lowered if lowered in ('true', 'false') else rest

    return parsed

print("✅ Basic helper functions defined")

In [None]:
# %%
# ============================================
# Cell 6: Helper Functions - Formatting
# ============================================

def format_metadata(metadata: dict, prompt_num: int) -> str:
    """Render one block's metadata as ordered ``key: value`` lines.

    The input mapping is copied (the caller's dict is left untouched),
    normalised, and serialised in this order: fields listed in
    ``REQUIRED_METADATA_FIELDS``, then fields listed in
    ``OPTIONAL_METADATA_FIELDS``, then any other keys in their original
    order. Only ``prompt_id``, ``batch_id`` and ``tier`` are ever filled in;
    other required fields that are absent are simply not written.

    Normalisation:
      * ``prompt_id`` - kept when it already starts with ``BATCH_ID_EXPECT``.
        Otherwise rebuilt as ``<batch>_pNNN`` from the id's own sequence
        number: the last ``p<digits>`` group, else the last digit run, else
        ``prompt_num``. The *last* number is used because a wrong batch prefix
        such as ``causal_batch_02_p005`` puts the batch number first.
      * ``batch_id`` - always forced to ``BATCH_ID_EXPECT``.
      * ``reasoning_type`` - passed through ``normalize_reasoning_type``.
      * ``tier`` - inferred from ``reasoning_type`` when missing or empty. A
        value that looks like a model size (``13B``, ``1.5b``) or a bare
        number above 6 is moved to ``model_size`` and the tier is inferred
        instead. Any other value is written unchanged.

    Args:
        metadata: Parsed metadata of one block (string keys and values).
        prompt_num: Block number used when ``prompt_id`` carries no number.

    Returns:
        The metadata lines joined with ``\\n`` (no trailing newline).
    """
    meta = dict(metadata)

    # --- prompt_id -------------------------------------------------------
    prompt_id = str(meta.get('prompt_id') or '')
    if not prompt_id or not prompt_id.startswith(BATCH_ID_EXPECT):
        numbers = (re.findall(r'[pP](\d+)', prompt_id)
                   or re.findall(r'\d+', prompt_id))
        sequence = int(numbers[-1]) if numbers else prompt_num
        meta['prompt_id'] = f"{BATCH_ID_EXPECT}_p{sequence:03d}"

    # --- batch_id / reasoning_type ----------------------------------------
    meta['batch_id'] = BATCH_ID_EXPECT
    if 'reasoning_type' in meta:
        meta['reasoning_type'] = normalize_reasoning_type(meta['reasoning_type'])

    # --- tier / model_size --------------------------------------------------
    inferred_tier = infer_tier(meta.get('reasoning_type', ''))
    tier_val = str(meta.get('tier', '')).strip()
    if not tier_val:
        meta['tier'] = inferred_tier
    elif re.fullmatch(r'\d+(?:\.\d+)?\s*[bB]', tier_val):
        # A model size such as "13B" was entered in the tier field.
        meta['model_size'] = tier_val.upper()
        meta['tier'] = inferred_tier
    elif tier_val.isdecimal() and int(tier_val) > 6:
        # Tiers only go up to 6; a larger bare number is a model size in billions.
        meta['model_size'] = f"{tier_val}B"
        meta['tier'] = inferred_tier

    # --- serialise ----------------------------------------------------------
    known_fields = REQUIRED_METADATA_FIELDS + OPTIONAL_METADATA_FIELDS
    formatted = [f"{field}: {meta[field]}" for field in known_fields if field in meta]

    # Keys outside both lists are kept, after the known ones.
    known = set(known_fields)
    formatted.extend(f"{key}: {value}" for key, value in meta.items() if key not in known)

    return '\n'.join(formatted)

print("✅ Formatting functions defined")

In [None]:
# %%
# ============================================
# Cell 7: Helper Functions - Content Processing
# ============================================

def fix_multilang_section(text: str, section_header: str) -> str:
    """Normalise the language markers of one section of a block.

    The section starts after the first occurrence of ``section_header`` and
    runs up to the next line starting with ``###`` or ``---`` (or the end of
    ``text``). Within it:

      * bold markers ``**TH**:``, ``**EN**:`` and ``**ZH**:`` are rewritten
        to ``(TH)``, ``(EN)`` and ``(ZH)``, and the rewritten section is
        written back into the returned text;
      * if the section carries no language marker at all, its content is
        assumed to be Thai and wrapped as ``(TH) <content>`` followed by
        ``[To be translated]`` placeholders for EN and ZH.

    Args:
        text: Block content to search.
        section_header: Literal header text, e.g. ``'### Reasoning'``. It is
            matched literally, not as a regular expression.

    Returns:
        ``text`` with the section rewritten, or ``text`` unchanged when the
        header is not found or the section is already in ``(XX)`` form.
    """
    languages = ('TH', 'EN', 'ZH')

    pattern = rf'({re.escape(section_header)}\s*\n)(.*?)(?=\n###|\n---|\Z)'
    match = re.search(pattern, text, re.DOTALL)
    if not match:
        return text

    original_content = match.group(2)

    # Convert the bold style (**TH**:) to the parenthesised style ((TH)).
    converted = original_content
    for lang in languages:
        converted = converted.replace(f'**{lang}**:', f'({lang})')

    if any(f'({lang})' in converted for lang in languages):
        new_content = converted
    else:
        # No language markers at all: assume the text is Thai.
        new_content = f"(TH) {original_content.strip()}\n"
        new_content += "(EN) [To be translated]\n"
        new_content += "(ZH) [To be translated]"

    if new_content == original_content:
        return text

    return text[:match.start(2)] + new_content + text[match.end(2):]

print("✅ Content processing functions defined")

In [None]:
# %%
# ============================================
# Cell 8: Main Fix Block Function
# ============================================

def fix_block_complete(header: str, body: str, prompt_num: int) -> tuple:
    """Rebuild one prompt block in the expected layout.

    Steps:
      1. Drop leading junk lines (blank lines, or lines made only of dashes
         and/or ``#``). Real headers such as ``### Metadata`` are kept intact
         so that the metadata section can still be located.
      2. Take the first ``### Metadata`` section, parse it, and treat
         everything after it as the block content. Anything that precedes
         the metadata header is discarded. If there is no metadata section
         the whole body is content and the metadata starts out empty.
      3. Clean the content: remove repeated ``### Metadata`` headers, stray
         ``tier:`` / ``model_size:`` lines, and separator lines made only of
         dashes; then normalise the language markers of the Reasoning,
         Rejected Reasoning, Chosen Answer and Explanation sections.
      4. Emit ``### Metadata``, the formatted metadata, a blank line, the
         content, and a closing ``---``; runs of 3+ newlines are collapsed.

    Args:
        header: The ``## Prompt N`` header line; returned unchanged.
        body: Raw text of the block following the header.
        prompt_num: Block number, used if the metadata has no usable prompt_id.

    Returns:
        ``(header, fixed_body)``.
    """
    # Step 1: strip leading junk lines only. Each removed line consists solely
    # of spaces, tabs, dashes or '#', so "### Metadata" itself is untouched.
    body = re.sub(r'\A(?:[ \t\r\-#]*\n)*', '', body)

    # Step 2: extract and parse the metadata section.
    meta_match = re.search(r'###\s*Metadata\s*\n(.*?)(?=\n###|\n##|$)', body, re.DOTALL)

    if meta_match:
        meta_content = meta_match.group(1)
        after_meta = body[meta_match.end():]
    else:
        # No metadata found - a section is created from defaults below.
        meta_content = ""
        after_meta = body

    metadata = parse_metadata(meta_content)

    # Step 3: clean the content.
    # Duplicate Metadata headers.
    after_meta = re.sub(r'###\s*Metadata\s*\n', '', after_meta)

    # tier/model_size belong in the metadata only. [ \t]* (not \s*) keeps the
    # match on a single line so neighbouring blank lines are preserved.
    after_meta = re.sub(r'(?m)^[ \t]*tier[ \t]*:.*\n?', '', after_meta)
    after_meta = re.sub(r'(?m)^[ \t]*model_size[ \t]*:.*\n?', '', after_meta)

    # Separator lines in the middle of the content: whole lines of dashes
    # only, so lines that merely start with "---" are not truncated.
    after_meta = re.sub(r'(?m)^[ \t]*-{3,}[ \t]*$\n?', '', after_meta)

    # Multi-language sections.
    for section in ('### Reasoning', '### Rejected Reasoning',
                    '### Chosen Answer', '### Explanation'):
        after_meta = fix_multilang_section(after_meta, section)

    # Step 4: reconstruct the block.
    formatted_metadata = format_metadata(metadata, prompt_num)

    result = "### Metadata\n"
    result += formatted_metadata
    result += "\n\n"  # blank line after the metadata
    result += after_meta.lstrip()

    # Collapse runs of blank lines.
    result = re.sub(r'\n{3,}', '\n\n', result)

    # Every block ends with a separator.
    result = result.rstrip() + "\n\n---"

    return header, result

print("✅ Main fix function defined")

In [None]:
# %%
# ============================================
# Cell 9: Load and Process File
# ============================================

print("\n📂 Loading file...")
with open(SRC, 'r', encoding='utf-8') as f:
    content = f.read()

# Split on the "## Prompt N" header lines. Because the pattern is a capture
# group, re.split returns [preamble, header1, body1, header2, body2, ...].
print("✂️ Splitting into blocks...")
parts = re.split(r'(##\s*Prompt\s+\d+[^\n]*\n)', content)

blocks = []
for position, (raw_header, body) in enumerate(zip(parts[1::2], parts[2::2]), start=1):
    header = raw_header.rstrip('\n')

    # Prompt number from the header; fall back to the block's position.
    num_match = re.search(r'Prompt\s+(\d+)', header)
    prompt_num = position if num_match is None else int(num_match.group(1))

    blocks.append((header, body, prompt_num))

print(f"📦 Found {len(blocks)} blocks")

In [None]:
# %%
# ============================================
# Cell 10: Fix All Blocks
# ============================================

print("\n🔧 Fixing blocks...")
fixed_blocks = []

for block_header, block_body, block_num in blocks:
    # Each entry is a (header, fixed body) pair.
    fixed_header, fixed_body = fix_block_complete(block_header, block_body, block_num)
    fixed_blocks.append((fixed_header, fixed_body))

    # Progress line for every block whose number is a multiple of 10.
    if block_num % 10 == 0:
        print(f"  Processing block {block_num}...")

print(f"✅ Fixed {len(fixed_blocks)} blocks")

In [None]:
# %%
# ============================================
# Cell 11: Combine and Save
# ============================================

print("\n📝 Combining blocks...")
# Each block is "<header>\n<body>"; consecutive blocks are separated by a
# single newline.
fixed_content = "\n".join(f"{header}\n{body}" for header, body in fixed_blocks)

# Write the combined text to the output path.
print(f"💾 Saving to {DST_FIXED}")
with open(DST_FIXED, 'w', encoding='utf-8') as f:
    f.write(fixed_content)

print("✅ File saved!")

In [None]:
# %%
# ============================================
# Cell 12: Validation
# ============================================
# Checks every fixed block for: the expected batch_id, a prompt_id of the
# form "<batch>_pNNN", a trailing "---" separator, and a tier between 1 and 6.
# Also counts blocks per tier and per reasoning_type.

print("\n🔍 Validating fixed file...")

errors = collections.defaultdict(list)   # error name -> list of block positions
stats = collections.defaultdict(int)     # "tier_<n>" / "type_<name>" -> count

# BATCH_ID_EXPECT is plain text, so it is escaped before going into a regex.
prompt_id_pattern = re.compile(rf'prompt_id:\s*{re.escape(BATCH_ID_EXPECT)}_p\d{{3}}')
tier_pattern = re.compile(r'(?m)^\s*tier\s*:\s*([^\n]+)$')
rtype_pattern = re.compile(r'(?m)^\s*reasoning_type\s*:\s*([^\n]+)$')

for i, (header, body) in enumerate(fixed_blocks, 1):
    # Check batch_id
    if f"batch_id: {BATCH_ID_EXPECT}" not in body:
        errors['batch_id'].append(i)

    # Check prompt_id format
    if not prompt_id_pattern.search(body):
        errors['prompt_id_format'].append(i)

    # Check separator
    if not body.strip().endswith('---'):
        errors['separator'].append(i)

    # Check tier
    tier_match = tier_pattern.search(body)
    if tier_match:
        tier = tier_match.group(1).strip()
        if tier in ['1', '2', '3', '4', '5', '6']:
            stats[f'tier_{tier}'] += 1
        else:
            errors['tier_invalid'].append(i)
    else:
        errors['tier_missing'].append(i)

    # Count reasoning types
    rtype_match = rtype_pattern.search(body)
    if rtype_match:
        rtype = rtype_match.group(1).strip()
        stats[f'type_{rtype}'] += 1

print("\n📊 Validation Results:")
print("="*50)

if errors:
    print("❌ Issues found:")
    # NOTE: the loop variable must not be called `blocks` - that would
    # overwrite the block list built in Cell 9 and break re-running Cell 10.
    for error_type, block_numbers in errors.items():
        print(f"  {error_type}: {len(block_numbers)} blocks")
        if len(block_numbers) <= 5:
            print(f"    Blocks: {block_numbers}")
else:
    print("✅ No issues found!")

print("\n📈 Statistics:")
print(f"  Total blocks: {len(fixed_blocks)}")

# Show tier distribution
tier_dist = {k: v for k, v in stats.items() if k.startswith('tier_')}
if tier_dist:
    print("  Tier distribution:")
    for tier_key, count in sorted(tier_dist.items()):
        print(f"    {tier_key}: {count}")

# Show type distribution (top 10)
type_dist = {k: v for k, v in stats.items() if k.startswith('type_')}
if type_dist:
    print("  Top reasoning types:")
    for type_key, count in sorted(type_dist.items(), key=lambda x: x[1], reverse=True)[:10]:
        # Drop only the leading "type_" prefix; a replace() would also strip
        # "type_" from inside names such as "prototype_reasoning".
        print(f"    {type_key[len('type_'):]}: {count}")

In [None]:
# %%
# ============================================
# Cell 13: Preview Sample Output
# ============================================

print("\n👁️ Preview of first fixed block:")
print("="*50)
if fixed_blocks:
    first_header, first_body = fixed_blocks[0]
    preview = first_header + "\n" + first_body
    all_lines = preview.split('\n')
    lines = all_lines[:30]  # only the first 30 lines are shown
    print('\n'.join(lines))
    if len(all_lines) > 30:
        print("... (truncated)")

print("\n✅ COMPLETE! File saved to:", DST_FIXED)