In [None]:
# === PART 1/3: Final Corrected Code ===
import os
import fitz
import logging
import re
import json
import numpy as np
import shutil
import random
from tqdm import tqdm
from paddleocr import PPStructureV3
from PIL import Image

# === CONFIG ===
# Folder scanned (non-recursively) for *.pdf books.
PDF_FOLDER = r"C:\Users\Admin\Desktop\deep books"
# Output root; process_books creates one sub-folder per book here.
BASE_OUTPUT = r"C:\Users\Admin\Desktop\vbooks"
# Resolution used when rendering PDF pages to PNG.
DPI = 300
# Fraction of page images copied to the train split (the rest go to val).
TRAIN_RATIO = 0.9
# Seed for the train/val shuffle, so the split is reproducible.
RANDOM_SEED = 42

# === LOGGING ===
def setup_logger():
    """Configure root logging at INFO with ``<time> [LEVEL] message`` lines."""
    log_format = "%(asctime)s [%(levelname)s] %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)

def safe_mkdir(path):
    """Ensure directory *path* exists (parents included) and return *path* unchanged."""
    target = path
    os.makedirs(name=target, exist_ok=True)
    return target

# === 1) PDF → IMAGES ===
def extract_images_from_pdf(pdf_path, output_dir, dpi=DPI):
    """Render every page of *pdf_path* to ``page_NNNN.png`` files in *output_dir*.

    Pages are numbered from 1 and zero-padded to four digits, so a plain
    lexicographic sort of the file names matches page order. A page that fails
    to render is logged and skipped; the remaining pages are still rendered.

    Args:
        pdf_path: Path to the source PDF.
        output_dir: Destination folder (created if missing).
        dpi: Render resolution passed to ``Page.get_pixmap``.

    Returns:
        *output_dir*, for chaining.
    """
    os.makedirs(output_dir, exist_ok=True)
    pdf_name = os.path.basename(pdf_path)
    with fitz.open(pdf_path) as doc:
        for page_num in tqdm(range(len(doc)), desc=f"[PDF→IMG] {pdf_name}"):
            try:
                pix = doc.load_page(page_num).get_pixmap(dpi=dpi)
                out_path = os.path.join(output_dir, f"page_{page_num + 1:04d}.png")
                pix.save(out_path)
            except Exception as e:
                # Best effort: one bad page should not abort the whole book.
                logging.error(f"Failed page {page_num+1} of {pdf_name}: {e}")
    logging.info(f"Images → {output_dir}")
    return output_dir

# === 2) OCR (Paddle PPStructureV3) ===
def run_structure_pipeline(pipeline, img_dir, save_dir):
    """Run *pipeline* on every ``.png`` in *img_dir*, saving one JSON per image.

    Images whose ``<stem>.json`` already exists in *save_dir* are skipped, so an
    interrupted run can be resumed. Per-image failures are logged, not raised.

    Args:
        pipeline: Object exposing ``predict(input=path)`` whose result supports
            ``result[0].save_to_json(save_path=...)`` (a PPStructureV3 here).
        img_dir: Folder of page images.
        save_dir: Destination folder for the per-page JSON files.
    """
    os.makedirs(save_dir, exist_ok=True)
    pngs = sorted(name for name in os.listdir(img_dir) if name.lower().endswith(".png"))
    for image_name in tqdm(pngs, desc="[OCR] PPStructureV3"):
        stem = os.path.splitext(image_name)[0]
        target_json = os.path.join(save_dir, f"{stem}.json")
        if os.path.exists(target_json):
            continue  # already produced by an earlier run
        try:
            predictions = pipeline.predict(input=os.path.join(img_dir, image_name))
            if not predictions:
                continue
            predictions[0].save_to_json(save_path=target_json)
        except Exception as e:
            logging.error(f"OCR error on {image_name}: {e}")

# === 3) POSTPROCESS OCR JSON (Corrected and Enhanced) ===
def clean_block_content(content):
    """Normalize the OCR text of one layout block.

    Non-string values are returned unchanged. For strings:
      * a space is inserted at every lower-to-upper case transition
        (e.g. "theModel" -> "the Model"), a heuristic for words the OCR glued;
      * runs of two or more whitespace characters collapse to one space;
      * leading/trailing whitespace is stripped.

    The infinity sign (U+221E) is deliberately left untouched. An earlier
    version replaced it with a mis-encoded (mojibake) string, which corrupted
    every block that contained infinity.
    """
    if not isinstance(content, str):
        return content
    content = re.sub(r'([a-z])([A-Z])', r'\1 \2', content)
    content = re.sub(r'\s{2,}', ' ', content)
    return content.strip()

def clean_math_expression(expr):
    r"""Apply light LaTeX normalization to an OCR'd equation string.

    Steps, in order:
      * bare ``exp(``, ``sum(``, ``log(`` become ``\exp(``, ``\sum(``, ``\log(``.
        Names that are already LaTeX commands (preceded by a backslash) are
        left alone;
      * an unbraced subscript after a letter, e.g. ``x_ij``, becomes ``x_{ij}``;
      * a book-specific heuristic restores the differential ``d\mathbf{x}`` after
        ``p_{\mathbf{X}}(\mathbf{x}|\mathcal{C}_1)`` (or ``_2``) at the end of
        the string, where OCR tends to drop the "d";
      * runs of two or more whitespace characters collapse to one space.

    Returns the normalized string; leading/trailing whitespace is stripped first.
    """
    expr = expr.strip()
    # General LaTeX cleanup. The (?<!\\) guard matters: \b alone also matches
    # between "\" and "e", so "\exp(" used to become "\\exp(" (a LaTeX line
    # break followed by the word "exp").
    expr = re.sub(r'(?<!\\)\bexp\s*\((.*?)\)', r'\\exp(\1)', expr)
    expr = re.sub(r'(?<!\\)\bsum\s*\((.*?)\)', r'\\sum(\1)', expr)
    expr = re.sub(r'(?<!\\)\blog\s*\((.*?)\)', r'\\log(\1)', expr)
    expr = re.sub(r'([a-zA-Z])_([a-zA-Z0-9:,]+)', r'\1_{\2}', expr)

    # Heuristic to correct a missing 'd\mathbf{x}' at the end of integrals:
    # the OCR often reads 'dx' as 'x' (or drops it entirely).
    expr = re.sub(r'(p_{\\mathbf{X}}\(\\mathbf{x}\|\\mathcal{C}_[12]\)\s*)(\\mathbf{x}|\s*x|)\s*$', r'\1d\\mathbf{x}', expr)

    return re.sub(r'\s{2,}', ' ', expr)

def group_equations(parsing_blocks):
    """Merge runs of consecutive equation blocks into ``equation_group`` blocks.

    Blocks labelled ``"formula"`` or ``"equation"`` count as equations. Their
    label is set to ``"equation"`` in place. A run of two or more consecutive
    equation blocks is replaced by one new block with:
      * ``block_label``: ``"equation_group"``;
      * ``block_content``: the members' contents joined with newlines;
      * ``block_bbox``: the union box ``[min x1, min y1, max x2, max y2]`` over
        members that have a 4-value bbox, or ``None`` when none do;
      * ``equation_blocks``: the original member blocks.
    A lone equation, and every non-equation block, passes through unchanged.

    Returns:
        A new list with block order preserved.
    """
    grouped, buffer = [], []

    def flush():
        if len(buffer) > 1:
            joined_content = "\n".join((b.get("block_content") or "") for b in buffer)
            # Only members with a usable box contribute. Previously min()/max() on an
            # empty list (or indexing a None bbox) raised and aborted the whole page.
            boxes = [b["block_bbox"] for b in buffer
                     if b.get("block_bbox") and len(b["block_bbox"]) == 4]
            group_bbox = None
            if boxes:
                group_bbox = [
                    min(bb[0] for bb in boxes),
                    min(bb[1] for bb in boxes),
                    max(bb[2] for bb in boxes),
                    max(bb[3] for bb in boxes),
                ]
            grouped.append({
                "block_label": "equation_group",
                "block_content": joined_content,
                "block_bbox": group_bbox,
                "equation_blocks": buffer.copy()
            })
        else:
            grouped.extend(buffer)
        buffer.clear()

    for b in parsing_blocks:
        if b.get("block_label") in ("equation", "formula"):
            b["block_label"] = "equation"
            buffer.append(b)
        else:
            flush()
            grouped.append(b)
    flush()
    return grouped

def fix_footnotes(entry):
    """Fold footnote blocks into the page's last ``text`` block, in place.

    The content of each ``footnote`` block is cleaned with ``clean_block_content``
    and appended (space-separated) to the last block labelled ``text``. The
    footnote blocks are then removed from ``entry["parsing_res_list"]``.
    The list object is mutated in place.

    If the page has no ``text`` block, the footnotes are kept as they are.
    The previous version removed them anyway, silently losing their text.
    """
    blocks = entry.get("parsing_res_list", [])
    footnotes = [b for b in blocks if b.get("block_label") == "footnote"]
    if not footnotes:
        return
    target = next((b for b in reversed(blocks) if b.get("block_label") == "text"), None)
    if target is None:
        return
    for fn in footnotes:
        text = clean_block_content(fn.get("block_content") or "")
        target["block_content"] = ((target.get("block_content") or "") + " " + text).strip()
    # One identity-based pass instead of repeated list.remove() (O(n^2), and
    # equality-based, so it could hit a different but equal-looking block).
    footnote_ids = {id(fn) for fn in footnotes}
    blocks[:] = [b for b in blocks if id(b) not in footnote_ids]

def postprocess_json_file(src_path, dst_path):
    """Clean one OCR JSON file and write the result to *dst_path*.

    The file may hold a single page dict or a list of them. Block lists are
    taken from ``page["data"][*]["parsing_res_list"]`` when ``data`` is a
    non-empty list, otherwise from ``page["parsing_res_list"]``. For each list:
      1. every block's text is normalized with ``clean_block_content``;
      2. footnotes are folded into text (``fix_footnotes``);
      3. consecutive equations are merged (``group_equations``);
      4. ``equation`` / ``equation_group`` blocks get an ``equation_latex``
         field (``clean_math_expression``).
    An unreadable source file is logged and skipped (nothing is written).
    """
    try:
        with open(src_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except Exception as e:
        logging.warning(f"[POST] skip {src_path}: {e}")
        return

    pages = data if isinstance(data, list) else [data]
    for page in pages:
        if not isinstance(page, dict):
            continue
        entries = page.get("data", [])
        if not isinstance(entries, list):
            entries = []
        if not entries and "parsing_res_list" in page:
            # Use the page dict itself as the entry, so that reassigning
            # entry["parsing_res_list"] below changes what gets saved. The old code
            # wrapped the list in a throwaway dict, so for files with a top-level
            # parsing_res_list the equation groups never reached the output file.
            entries = [page]

        for entry in entries:
            if not isinstance(entry, dict) or not isinstance(entry.get("parsing_res_list"), list):
                continue

            for b in entry["parsing_res_list"]:
                if "block_content" in b:
                    b["block_content"] = clean_block_content(b["block_content"])

            fix_footnotes(entry)

            entry["parsing_res_list"] = group_equations(entry["parsing_res_list"])

            for b in entry["parsing_res_list"]:
                if b.get("block_label") in ("equation", "equation_group"):
                    # Add cleaned LaTeX field
                    b["equation_latex"] = clean_math_expression(b.get("block_content") or "")

    out_dir = os.path.dirname(dst_path)
    if out_dir:  # os.makedirs("") raises when dst_path is a bare file name
        os.makedirs(out_dir, exist_ok=True)
    with open(dst_path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2, ensure_ascii=False)

def batch_postprocess(input_folder, output_folder):
    """Post-process each ``*.json`` in *input_folder* into the same name under *output_folder*."""
    os.makedirs(output_folder, exist_ok=True)
    json_names = [name for name in os.listdir(input_folder) if name.endswith(".json")]
    for name in tqdm(json_names, desc="[POST] Clean JSON"):
        source = os.path.join(input_folder, name)
        destination = os.path.join(output_folder, name)
        postprocess_json_file(source, destination)

# === 4) YOLO LABELS === (Unchanged)
def convert_bbox_to_yolo(bbox, W, H, class_id=0):
    """Convert a pixel ``[x1, y1, x2, y2]`` box into one YOLO label line.

    Args:
        bbox: Corner coordinates in pixels.
        W: Image width in pixels.
        H: Image height in pixels.
        class_id: YOLO class index written first on the line. The default 0 is
            the only class this pipeline emits.

    Returns:
        ``"<class> <xc> <yc> <w> <h>"``. The box centre and size are normalized
        by the image size and formatted to 6 decimals.
    """
    x1, y1, x2, y2 = bbox
    xc = (x1 + x2) / 2 / W
    yc = (y1 + y2) / 2 / H
    w = (x2 - x1) / W
    h = (y2 - y1) / H
    return f"{class_id} {xc:.6f} {yc:.6f} {w:.6f} {h:.6f}"

def generate_yolo_labels(cleaned_json_dir, img_dir, label_dir):
    """Write YOLO label files for the ``image`` (diagram) blocks of each page.

    Pages are processed when *cleaned_json_dir* has ``<page>.json`` and
    *img_dir* has a matching ``<page>.png``. Each block labelled ``"image"``
    with a 4-value ``block_bbox`` becomes one class-0 line in
    ``<label_dir>/<page>.txt``. Pages without such blocks get no label file.
    Unreadable JSON or images are logged and skipped.
    """
    os.makedirs(label_dir, exist_ok=True)
    count = 0
    for fname in tqdm(os.listdir(cleaned_json_dir), desc="[LBL] YOLO from JSON"):
        if not fname.endswith(".json"):
            continue
        page_id = os.path.splitext(fname)[0]
        img_path = os.path.join(img_dir, f"{page_id}.png")
        if not os.path.exists(img_path):
            continue
        try:
            with open(os.path.join(cleaned_json_dir, fname), "r", encoding="utf-8") as f:
                data = json.load(f)
            # Only the size is needed. The context manager closes the file at once;
            # a bare Image.open(...).size left one handle open per page until GC.
            with Image.open(img_path) as im:
                W, H = im.size
        except Exception as e:
            logging.warning(f"[LBL] {fname}: {e}")
            continue

        pages = data if isinstance(data, list) else [data]
        yolo_lines = []
        for page in pages:
            if not isinstance(page, dict):
                continue
            entries = page.get("data", [])
            if not entries and "parsing_res_list" in page:
                entries = [{"parsing_res_list": page.get("parsing_res_list", [])}]
            for entry in entries:
                for b in entry.get("parsing_res_list") or []:
                    if b.get("block_label") != "image":
                        continue
                    bbox = b.get("block_bbox")
                    if bbox and len(bbox) == 4:
                        yolo_lines.append(convert_bbox_to_yolo(bbox, W, H))
        if yolo_lines:
            with open(os.path.join(label_dir, f"{page_id}.txt"), "w", encoding="utf-8") as out:
                out.write("\n".join(yolo_lines))
            count += 1
    logging.info(f"[LBL] wrote {count} label files → {label_dir}")

# === 5) Split to train/val === (Unchanged)
def split_dataset(img_dir, lbl_dir, dataset_dir, train_ratio=TRAIN_RATIO, seed=RANDOM_SEED):
    """Copy page images (and their labels, when present) into a YOLO train/val layout.

    Creates ``images/{train,val}`` and ``labels/{train,val}`` under *dataset_dir*.
    The sorted ``.png`` list from *img_dir* is shuffled with a private RNG seeded
    by *seed*. The first ``int(len * train_ratio)`` files go to train and the
    rest to val. ``<stem>.txt`` from *lbl_dir* is copied alongside an image
    when it exists.
    """
    # Private RNG: same permutation as random.seed(seed) + random.shuffle, but
    # without resetting the process-wide `random` state as a side effect.
    rng = random.Random(seed)
    for sub in ("images/train", "images/val", "labels/train", "labels/val"):
        safe_mkdir(os.path.join(dataset_dir, sub))
    image_files = sorted(f for f in os.listdir(img_dir) if f.lower().endswith(".png"))
    rng.shuffle(image_files)
    split_idx = int(len(image_files) * train_ratio)
    splits = {"train": image_files[:split_idx], "val": image_files[split_idx:]}
    for split, files in splits.items():
        for fname in tqdm(files, desc=f"[SPLIT] {split}"):
            shutil.copyfile(os.path.join(img_dir, fname), os.path.join(dataset_dir, "images", split, fname))
            lbl_src = os.path.join(lbl_dir, os.path.splitext(fname)[0] + ".txt")
            if os.path.exists(lbl_src):
                shutil.copyfile(lbl_src, os.path.join(dataset_dir, "labels", split, os.path.basename(lbl_src)))
    logging.info(f"[SPLIT] dataset ready → {dataset_dir}")

# === MASTER ===
def process_books(pdf_folder):
    """Run stages 1-5 for every PDF in *pdf_folder*.

    The stages are: render pages, run PPStructureV3, clean the JSON, write
    YOLO labels and build the train/val split. Output for each book goes to
    ``BASE_OUTPUT/<pdf name without extension>/``. The OCR model is loaded
    once and shared across books. Books are processed in sorted file-name
    order. If one book fails, the error is logged with its traceback and the
    remaining books still run.
    """
    setup_logger()
    logging.info("Loading PaddleOCR (PPStructureV3)…")
    pipeline = PPStructureV3()

    pdf_files = sorted(f for f in os.listdir(pdf_folder) if f.lower().endswith(".pdf"))
    for pdf_file in pdf_files:
        book = os.path.splitext(pdf_file)[0]
        logging.info(f"\n===== BOOK: {book} =====")
        try:
            book_dir = safe_mkdir(os.path.join(BASE_OUTPUT, book))
            img_dir = safe_mkdir(os.path.join(book_dir, "images"))
            ocr_dir = safe_mkdir(os.path.join(book_dir, "paddle_structure_jsons"))
            cleaned_dir = safe_mkdir(os.path.join(book_dir, "cleaned_semantic_jsons"))
            lbl_dir = safe_mkdir(os.path.join(book_dir, "labels"))
            dataset_dir = safe_mkdir(os.path.join(book_dir, "dataset"))
            # 1) PDF pages -> PNG
            extract_images_from_pdf(os.path.join(pdf_folder, pdf_file), img_dir, DPI)
            # 2) layout analysis + OCR
            run_structure_pipeline(pipeline, img_dir, ocr_dir)
            # 3) clean / group the OCR JSON
            batch_postprocess(ocr_dir, cleaned_dir)
            # 4) YOLO labels for diagrams
            generate_yolo_labels(cleaned_dir, img_dir, lbl_dir)
            # 5) train/val split
            split_dataset(img_dir, lbl_dir, dataset_dir)
        except Exception:
            # A long multi-book run should not die on one corrupt PDF.
            logging.exception(f"Failed: {book}")
            continue
        logging.info(f"Finished: {book}")

if __name__ == "__main__":
    # Run Part 1 (render, OCR, clean, label, split) for every PDF in PDF_FOLDER.
    process_books(PDF_FOLDER)

In [None]:
# === PART 2/3: Final Corrected Code (Includes Equation Cropping) ===
import os
import json
from PIL import Image

# Root folder holding one sub-folder per book (the output root used by Part 1).
BASE_OUTPUT = r"C:\Users\Admin\Desktop\vbooks"

def _crop_page(data, img, page_id, output_dir):
    """Save the diagram/equation crops described by one page's JSON *data*; return the count."""
    W, H = img.size
    saved = 0
    pages = data if isinstance(data, list) else [data]
    for page in pages:
        if not isinstance(page, dict):
            continue
        entries = page.get("data", [])
        if not entries and "parsing_res_list" in page:
            entries = [{"parsing_res_list": page.get("parsing_res_list", [])}]

        for entry in entries:
            plist = entry.get("parsing_res_list") or []
            for block_index, b in enumerate(plist):
                block_label = b.get("block_label")
                # Target diagrams, single equations, and grouped equations
                if block_label not in ("image", "equation", "equation_group"):
                    continue
                bbox = b.get("block_bbox")
                if not bbox or len(bbox) != 4:
                    continue
                # Clip to the page: PIL fills out-of-range crop areas with black.
                x1, y1, x2, y2 = map(int, bbox)
                x1, y1 = max(0, x1), max(0, y1)
                x2, y2 = min(W, x2), min(H, y2)
                if x1 >= x2 or y1 >= y2:
                    continue

                element_type = "diagram" if block_label == "image" else "equation"
                # Standardized name: Part 3 parses block_index back out of it.
                out_name = f"{page_id}_{element_type}_{block_index}.png"
                img.crop((x1, y1, x2, y2)).save(os.path.join(output_dir, out_name))
                saved += 1
    return saved


def crop_elements_from_paddle_json(cleaned_json_dir, img_dir, output_dir):
    """Crop diagram and equation regions out of the page images.

    Pages are processed when *cleaned_json_dir* has ``<page>.json`` and
    *img_dir* has a matching ``<page>.png``. Each block labelled ``image``,
    ``equation`` or ``equation_group`` that has a 4-value ``block_bbox`` is
    cropped (clipped to the page) and saved to *output_dir*. The file name is
    ``<page>_diagram_<i>.png`` for images and ``<page>_equation_<i>.png``
    otherwise, where ``i`` is the block's index in its ``parsing_res_list``.
    Pages with a missing image or unreadable input are reported and skipped.
    """
    os.makedirs(output_dir, exist_ok=True)
    total = 0
    for fname in os.listdir(cleaned_json_dir):
        if not fname.endswith(".json"):
            continue
        page_id = os.path.splitext(fname)[0]
        image_path = os.path.join(img_dir, f"{page_id}.png")
        if not os.path.exists(image_path):
            print(f"[CROP] missing image for {fname}")
            continue
        try:
            with open(os.path.join(cleaned_json_dir, fname), "r", encoding="utf-8") as f:
                data = json.load(f)
            img = Image.open(image_path)
        except Exception as e:
            print(f"[CROP] {fname}: {e}")
            continue
        # Close each page image once its crops are written; previously one file
        # handle per page stayed open until garbage collection.
        with img:
            total += _crop_page(data, img, page_id, output_dir)

    print(f"[CROP] saved {total} crops → {output_dir}")

def process_labels_and_diagrams(book_dir):
    """Run the Part 2 cropping step for one book folder.

    Reads ``<book_dir>/cleaned_semantic_jsons`` and ``<book_dir>/images`` and
    writes crops to ``<book_dir>/outputs/assets``. A folder lacking either input
    is reported and skipped without creating anything. Previously an empty
    ``outputs`` folder was created even in skipped folders.
    """
    cleaned_json_dir = os.path.join(book_dir, "cleaned_semantic_jsons")
    img_dir = os.path.join(book_dir, "images")

    if not (os.path.isdir(cleaned_json_dir) and os.path.isdir(img_dir)):
        print(f"[CROP] skip {book_dir}: missing cleaned_semantic_jsons or images")
        return

    # Output to a single 'assets' folder
    assets_dir = os.path.join(book_dir, "outputs", "assets")
    os.makedirs(os.path.join(book_dir, "outputs"), exist_ok=True)
    crop_elements_from_paddle_json(cleaned_json_dir, img_dir, assets_dir)

if __name__ == "__main__":
    # Run Part 2 (diagram/equation cropping) for every sub-folder of BASE_OUTPUT.
    print("=== Code 2: Element Cropping (Diagrams & Equations) ===")
    for name in os.listdir(BASE_OUTPUT):
        d = os.path.join(BASE_OUTPUT, name)
        # Plain files are ignored; folders without the Part 1 outputs are skipped inside.
        if os.path.isdir(d):
            process_labels_and_diagrams(d)
    print("=== Done ===")

In [None]:
# === PART 3/3: Final Corrected Code (Full Semantic Pipeline) ===
import os
import re
import json
import torch
import clip
import spacy
import numpy as np
from PIL import Image
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity

# Root folder holding one sub-folder per book (the output root of Parts 1 and 2).
BASE_OUTPUT = r"C:\Users\Admin\Desktop\vbooks"

# Initialize CLIP and SpaCy once.
# These statements run when the cell/module executes; the functions below
# share the resulting globals.
device = "cuda" if torch.cuda.is_available() else "cpu"
# clip.load gives the model plus its image preprocessing transform (used by embed_images).
clip_model, preprocess = clip.load("ViT-B/32", device=device)
# English spaCy pipeline: tokenization (truncate_for_clip) and sentence splitting (extract_sentences).
nlp = spacy.load("en_core_web_sm")

# Block labels whose text extract_sentences splits into sentences.
VALID_BLOCKS = {"text", "paragraph", "header", "caption", "figure_title"}
# In-text figure references, e.g. "Figure 1.2", "Fig. 3", "fig 4.1.2"; group(1) is the number.
# The leading \b stops words such as "configure 3" from matching.
figure_ref_pattern = re.compile(r"\bFig(?:ure|\.)?\s?(\d+(?:\.\d+)*)", re.IGNORECASE)
# In-text equation references, e.g. "Eq. (1.27)", "Eq (1.27)", "Equation (1.27)";
# group(2) is the number. The period is optional: the old pattern required it,
# so "Equation (1.27)" never matched. \b stops "freq (2)" from matching.
equation_ref_pattern = re.compile(r"\b(Eq(?:uation)?\.?)\s*\((\d+(?:\.\d+)*)\)", re.IGNORECASE)

# Max vertical-centre distance (bbox units, i.e. page-image pixels) for the
# layout-proximity sentence -> equation fallback.
LAYOUT_Y_TOL = 250
# Minimum combined (0.85 semantic + 0.15 layout) score for a sentence -> diagram link.
THRESH_SENT2DIAG = 0.22

# --- HELPER FUNCTIONS ---

def truncate_for_clip(text):
    """Return the first 75 spaCy tokens of *text*, joined with single spaces.

    This keeps CLIP input short. CLIP's text encoder has a 77-token context;
    ``clip.tokenize(..., truncate=True)`` in embed_texts trims any remainder.
    Only tokenization is needed, so ``nlp.make_doc`` is used. It gives the
    same tokens as ``nlp(text)`` without running the tagger, parser and NER
    on every sentence.
    """
    tokens = [t.text for t in nlp.make_doc(text)]
    return " ".join(tokens[:75])

def embed_texts(texts):
    """Encode *texts* with the CLIP text encoder.

    Each text is first shortened by truncate_for_clip. Returns a NumPy array
    with one row per text. An empty input returns ``np.zeros((0, 512))``
    (512 presumably matches the ViT-B/32 embedding width; verify if the
    model changes).
    """
    if not texts:
        return np.zeros((0, 512))
    shortened = [truncate_for_clip(t) for t in texts]
    token_ids = clip.tokenize(shortened, truncate=True).to(device)
    with torch.no_grad():
        features = clip_model.encode_text(token_ids)
    return features.cpu().numpy()

def embed_images(paths):
    """Encode each image file in *paths* with the CLIP image encoder.

    Returns a list aligned with *paths*. Each item is a 1-D NumPy embedding,
    or ``None`` for a file that could not be opened or encoded (callers skip
    ``None`` entries).
    """
    embs = []
    for p in paths:
        try:
            with Image.open(p) as im:
                image = preprocess(im.convert("RGB")).unsqueeze(0).to(device)
            with torch.no_grad():
                embs.append(clip_model.encode_image(image).cpu().numpy()[0])
        except Exception as e:
            # Still best effort per file, but no longer a bare `except:` (which
            # also swallowed KeyboardInterrupt), and no longer silent.
            print(f"[EMB] skip {p}: {e}")
            embs.append(None)
    return embs

def layout_scores(sentence_bbox, image_bboxes):
    """Score each image box by vertical closeness to *sentence_bbox*.

    Each score is ``1 / (1 + |dy|)``, where dy is the gap between the two
    boxes' vertical centres. Aligned centres score 1.0 and the score falls
    toward 0 with distance. A missing sentence box scores 0.0 for every image.
    """
    if not sentence_bbox:
        return [0.0] * len(image_bboxes)
    sentence_mid = (sentence_bbox[1] + sentence_bbox[3]) / 2
    scores = []
    for box in image_bboxes:
        image_mid = (box[1] + box[3]) / 2
        scores.append(1 / (1 + abs(sentence_mid - image_mid)))
    return scores

def hybrid_match(sentence_emb, diagram_embs, layout_sc):
    """Return the index of the best-matching diagram for one sentence.

    Each candidate scores 0.85 x cosine similarity (sentence vs. diagram
    embedding) plus 0.15 x its layout score. The arg-max index is returned.
    """
    semantic = cosine_similarity([sentence_emb], diagram_embs)[0]
    combined = [0.85 * sim + 0.15 * layout_sc[k] for k, sim in enumerate(semantic)]
    return np.argmax(combined)

def load_blocks(cleaned_json):
    """Return the flat block list (``parsing_res_list``) of the first page.

    Accepts the same layouts Parts 1 and 2 read: a page dict, or a list of
    page dicts. A page holds its blocks either in
    ``page["data"][0]["parsing_res_list"]``, which is preferred when ``data``
    is a non-empty list, or directly in ``page["parsing_res_list"]``. Anything
    else yields ``[]``.

    Only the first page/entry is used. Block indices therefore line up with
    the ``<page>_<type>_<index>.png`` crop names for single-entry pages.
    """
    page = cleaned_json
    if isinstance(page, list):
        # Previously a list was only checked for a top-level parsing_res_list, so
        # list-of-pages files using the "data" layout produced no blocks at all.
        page = page[0] if page else None
    if not isinstance(page, dict):
        return []
    entries = page.get("data")
    if isinstance(entries, list) and entries and isinstance(entries[0], dict):
        return entries[0].get("parsing_res_list") or []
    return page.get("parsing_res_list") or []

def extract_sentences(cleaned_json_dir, sentence_dir):
    """Split the text-like blocks of every page into sentences.

    For each ``<page>.json`` in *cleaned_json_dir*, writes ``<page>.json`` to
    *sentence_dir* as ``{"sentences": [...]}``. Each record holds:
      * the sentence text;
      * a page-wide running ``sentence_id``;
      * the source ``block_index`` and ``block_label``;
      * its position within the block (``sentence_index``);
      * the block's bbox (a sentence inherits its whole block's box).
    Only blocks whose label is in VALID_BLOCKS are used, and sentences shorter
    than 5 characters are dropped. Unreadable page files are reported and
    skipped instead of aborting the book.
    """
    os.makedirs(sentence_dir, exist_ok=True)
    for fname in os.listdir(cleaned_json_dir):
        if not fname.endswith(".json"):
            continue
        page_id = fname[:-5]
        try:
            with open(os.path.join(cleaned_json_dir, fname), "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers malformed JSON and undecodable bytes.
            print(f"[SENTS] skip {page_id}: {e}")
            continue
        blocks = load_blocks(data)

        out, sid = [], 0
        for bi, b in enumerate(blocks):
            if b.get("block_label", "") not in VALID_BLOCKS:
                continue
            text = (b.get("block_content") or "").strip()
            if not text:
                continue
            bbox = b.get("block_bbox")
            for sidx, sent in enumerate(nlp(text).sents):
                t = sent.text.strip()
                if len(t) < 5:
                    continue
                out.append({
                    "text": t, "sentence_id": sid, "block_index": bi,
                    "block_label": b.get("block_label"), "sentence_index": sidx,
                    "bbox": bbox
                })
                sid += 1
        with open(os.path.join(sentence_dir, f"{page_id}.json"), "w", encoding="utf-8") as f:
            json.dump({"sentences": out}, f, indent=2)
    print(f"[SENTS] → {sentence_dir}")

def map_formula_numbers(blocks):
    """Map equation numbers such as ``"1.27"`` to the index of the nearest equation block.

    Each ``formula_number`` / ``equation_number`` block whose text contains a
    parenthesized number like "(1.27)" is paired with the equation-like block
    (``equation``, ``equation_group`` or ``formula``) whose vertical centre is
    closest. Ties go to the earliest block. A missing bbox counts as centre
    y = 0. When several number blocks carry the same number, the later one wins.
    """
    def mid_y(box):
        return (box[1] + box[3]) / 2 if box else 0

    equation_like = ("equation", "equation_group", "formula")
    number_like = ("formula_number", "equation_number")
    candidates = [
        (idx, mid_y(blk.get("block_bbox")))
        for idx, blk in enumerate(blocks)
        if blk.get("block_label") in equation_like
    ]

    mapping = {}
    for blk in blocks:
        if blk.get("block_label") not in number_like:
            continue
        found = re.search(r"\((\d+(?:\.\d+)*)\)", (blk.get("block_content") or "").strip())
        if found is None:
            continue
        target_y = mid_y(blk.get("block_bbox"))
        nearest, nearest_dist = None, 1e9
        for idx, eq_y in candidates:
            dist = abs(eq_y - target_y)
            if dist < nearest_dist:
                nearest, nearest_dist = idx, dist
        if nearest is not None:
            mapping[found.group(1)] = nearest
    return mapping

# LaTeX layout/formatting commands whose names carry no mathematical meaning.
_LATEX_LAYOUT_COMMANDS = re.compile(
    r"\\(?:left|right|big|Big|bigg|Bigg|mathbf|mathcal|mathrm|mathit|mathbb|mathsf|"
    r"boldsymbol|operatorname|text|textbf|frac|dfrac|tfrac|begin|end|quad|qquad|"
    r"displaystyle|cdot|cdots|ldots|dots)(?![A-Za-z])"
)


def tokens_from_latex(s):
    r"""Return the lower-cased set of alphabetic words in *s* (LaTeX or plain text).

    Single letters are kept, so variables such as x, p and R are kept.
    Greek-letter commands contribute their names (``\alpha`` -> "alpha").
    Pure layout commands such as ``\left``, ``\right``, ``\mathbf`` and
    ``\frac`` are removed first. Their names ("left", "right", ...) used to
    match ordinary English words and create spurious sentence-equation links.
    Plain text without backslashes is tokenized as before.
    """
    s = _LATEX_LAYOUT_COMMANDS.sub(" ", s or "")
    return {t.lower() for t in re.findall(r"[A-Za-z]+", s)}

# --- MAIN LINKING FUNCTION ---
def _ycenter(bbox):
    """Return the vertical centre of an ``[x1, y1, x2, y2]`` box."""
    return (bbox[1] + bbox[3]) / 2


def _caption_mentions(num, caption):
    """Return True if *caption* contains figure number *num* as a whole number.

    "1" matches "Figure 1: ..." but not "Figure 1.12" or "Figure 11". The plain
    substring test used before linked such references to the wrong figure.
    """
    pattern = rf"(?<![\d.]){re.escape(num)}(?!\.?\d)"
    return re.search(pattern, caption or "") is not None


def _eq_link(s, i_eq, b_eq, match_type):
    """Build one sentence -> equation link record."""
    return {
        "sentence": s["text"], "sentence_id": s["sentence_id"], "bbox": s.get("bbox"),
        "equation_index": i_eq,
        "equation_latex": b_eq.get("equation_latex"),  # cleaned LaTeX added in Part 1
        "equation_bbox": b_eq.get("block_bbox"), "match_type": match_type
    }


def _collect_valid_diagrams(diag_files, assets_dir, blocks):
    """Embed one page's diagram crops and keep those that map back to an ``image`` block.

    Returns ``(file_name, block_index, embedding, bbox)`` tuples. A crop is
    dropped when any of these holds:
      * its index (parsed from ``<page>_diagram_<i>.png``) is out of range;
      * the index points at a non-image block;
      * that block lacks a 4-value bbox;
      * the crop failed to embed.
    """
    embs = embed_images([os.path.join(assets_dir, f) for f in diag_files])
    valid = []
    for fimg, emb in zip(diag_files, embs):
        if emb is None:
            continue
        try:
            bidx = int(fimg.split("_")[-1].split(".")[0])
        except ValueError:
            continue
        if not 0 <= bidx < len(blocks):
            continue
        block = blocks[bidx]
        bbox = block.get("block_bbox")
        if block.get("block_label") == "image" and bbox and len(bbox) == 4:
            valid.append((fimg, bidx, emb, bbox))
    return valid


def _link_diagrams(sentences, figures, valid):
    """Return ``(sentence_to_diagram, caption_to_diagram, textual_reference)`` for one page."""
    sent2diag, caption2diag, textref2diag = [], [], []
    if not valid:
        return sent2diag, caption2diag, textref2diag
    files = [v[0] for v in valid]
    emb_list = [v[2] for v in valid]
    vbboxes = [v[3] for v in valid]

    # Sentence -> diagram: 0.85 x CLIP similarity + 0.15 x vertical proximity.
    # Text is embedded only here, i.e. only on pages that actually have diagrams.
    if sentences:
        s_embs = embed_texts([s["text"] for s in sentences])
        for i, (s_obj, s_emb) in enumerate(zip(sentences, s_embs)):
            ls = layout_scores(s_obj.get("bbox"), vbboxes)
            best = hybrid_match(s_emb, emb_list, ls)
            cs = cosine_similarity([s_emb], emb_list)[0]
            comb = 0.85 * cs[best] + 0.15 * ls[best]
            if comb > THRESH_SENT2DIAG:
                sent2diag.append({
                    "sentence": s_obj["text"], "sentence_id": s_obj.get("sentence_id", i),
                    "bbox": s_obj.get("bbox"),
                    "diagram_file": files[best], "score": float(round(comb, 4))
                })

    # Caption -> diagram: the diagram whose vertical centre is nearest the caption's.
    for cap, cap_box in figures:
        if not cap_box:
            continue
        cy = _ycenter(cap_box)
        closest = min(valid, key=lambda v: abs(_ycenter(v[3]) - cy))
        caption2diag.append({
            "caption": cap, "bbox": cap_box, "diagram_file": closest[0], "score": 0.9
        })

    # Textual reference ("Figure N") -> diagram whose caption carries number N.
    for i, s_obj in enumerate(sentences):
        m = figure_ref_pattern.search(s_obj["text"])
        if not m:
            continue
        num = m.group(1)
        for cap in caption2diag:
            if _caption_mentions(num, cap.get("caption")):
                textref2diag.append({
                    "sentence": s_obj["text"], "sentence_id": s_obj.get("sentence_id", i),
                    "bbox": s_obj.get("bbox"),
                    "diagram_file": cap["diagram_file"], "score": 0.8
                })
                break
    return sent2diag, caption2diag, textref2diag


def _link_equations(sentences, blocks):
    """Return the sentence -> equation links for one page.

    For each sentence, the first rule that fires wins:
      1. an explicit "Eq. (n)" / "Equation (n)" reference to a numbered equation;
      2. the equation sharing the most words with the sentence (tokens_from_latex);
      3. the vertically nearest equation, if closer than LAYOUT_Y_TOL.
    """
    num_to_eqidx = map_formula_numbers(blocks)
    equations = [(i, b) for i, b in enumerate(blocks)
                 if b.get("block_label") in ("equation", "equation_group", "formula")]
    # Tokenize each equation once, not once per sentence.
    eq_tokens = [tokens_from_latex(b.get("equation_latex") or "") for _, b in equations]
    sent2eq = []

    for s in sentences:
        s_text, s_bbox = s["text"], s.get("bbox")

        # 1) explicit Eq.(x.x) reference
        m = equation_ref_pattern.search(s_text)
        if m and m.group(2) in num_to_eqidx:
            i_eq = num_to_eqidx[m.group(2)]
            sent2eq.append(_eq_link(s, i_eq, blocks[i_eq], "eq_number_ref"))
            continue

        # 2) token overlap: the largest overlap wins (ties -> earliest equation).
        # Previously the first equation with ANY overlap was taken.
        s_tokens = tokens_from_latex(s_text)
        best_k, best_overlap = None, 0
        for k, e_tok in enumerate(eq_tokens):
            overlap = len(s_tokens & e_tok)
            if overlap > best_overlap:
                best_k, best_overlap = k, overlap
        if best_k is not None:
            i_eq, b_eq = equations[best_k]
            sent2eq.append(_eq_link(s, i_eq, b_eq, "token_overlap"))
            continue

        # 3) layout proximity
        if s_bbox:
            sy = _ycenter(s_bbox)
            nearby = [(abs(sy - _ycenter(b_eq["block_bbox"])), i_eq, b_eq)
                      for i_eq, b_eq in equations if b_eq.get("block_bbox")]
            if nearby:
                bestd, i_eq, b_eq = min(nearby, key=lambda t: t[0])
                if bestd < LAYOUT_Y_TOL:
                    sent2eq.append(_eq_link(s, i_eq, b_eq, "layout_near"))
    return sent2eq


def semantic_linking(sentence_dir, cleaned_json_dir, assets_dir, links_dir):
    """Link each page's sentences to its diagrams and equations.

    Reads from:
      * ``<page>.json`` files in *sentence_dir* (extract_sentences output);
      * the matching cleaned page JSON in *cleaned_json_dir*;
      * the diagram crops ``<page>_diagram_<i>.png`` in *assets_dir*.
    Writes ``<links_dir>/<page>.json`` with four lists: ``sentence_to_diagram``,
    ``caption_to_diagram``, ``textual_reference`` and ``sentence_to_equation``.
    Pages whose inputs cannot be read are reported and skipped.
    """
    os.makedirs(links_dir, exist_ok=True)

    # Diagram crops only. A missing assets folder (Part 2 skipped) means no diagrams.
    if os.path.isdir(assets_dir):
        diag_asset_files_all = sorted(
            f for f in os.listdir(assets_dir) if "diagram" in f and f.endswith(".png")
        )
    else:
        diag_asset_files_all = []

    for fname in tqdm(sorted(os.listdir(sentence_dir)), desc="[LINK]"):
        if not fname.endswith(".json"):
            continue
        page_id = fname[:-5]

        try:
            with open(os.path.join(sentence_dir, fname), "r", encoding="utf-8") as f:
                sent_data = json.load(f)
            with open(os.path.join(cleaned_json_dir, f"{page_id}.json"), "r", encoding="utf-8") as f:
                page_data = json.load(f)
        except Exception as e:
            print(f"[LINK] skip {page_id}: {e}")
            continue

        sentences = sent_data.get("sentences", [])
        blocks = load_blocks(page_data)
        figures = [(b.get("block_content"), b.get("block_bbox"))
                   for b in blocks if b.get("block_label") == "figure_title"]

        # Use the exact "<page>_diagram_" prefix. A bare startswith(page_id) also
        # matches pages whose id merely begins with this one.
        prefix = f"{page_id}_diagram_"
        page_diagrams = [f for f in diag_asset_files_all if f.startswith(prefix)]
        valid = _collect_valid_diagrams(page_diagrams, assets_dir, blocks)

        sent2diag, caption2diag, textref2diag = _link_diagrams(sentences, figures, valid)
        sent2eq = _link_equations(sentences, blocks)

        with open(os.path.join(links_dir, f"{page_id}.json"), "w", encoding="utf-8") as f:
            json.dump({
                "page": page_id,
                "sentence_to_diagram": sent2diag,
                "caption_to_diagram": caption2diag,
                "textual_reference": textref2diag,
                "sentence_to_equation": sent2eq
            }, f, indent=2)
        print(f"[LINK] {page_id} → S2D {len(sent2diag)} | C2D {len(caption2diag)} | TRef {len(textref2diag)} | S2E {len(sent2eq)}")

# --- MERGE FUNCTION ---
def merge_semantic_pages(sentence_dir, cleaned_json_dir, links_dir, output_dir):
    """Combine each page's sentences, blocks and links into one JSON document.

    For every ``<page>.json`` in *sentence_dir*, reads the matching cleaned page
    JSON and links JSON. Writes ``<output_dir>/<page>.json`` with keys
    ``page_id``, ``sentences``, ``blocks`` (via load_blocks) and ``links``.
    Pages with a missing or unreadable input are reported and skipped.
    """
    os.makedirs(output_dir, exist_ok=True)
    for fname in tqdm(sorted(os.listdir(sentence_dir)), desc="[MERGE]"):
        if not fname.endswith(".json"):
            continue
        page_id = fname[:-5]
        try:
            with open(os.path.join(sentence_dir, fname), "r", encoding="utf-8") as f:
                sdata = json.load(f)
            with open(os.path.join(cleaned_json_dir, f"{page_id}.json"), "r", encoding="utf-8") as f:
                raw = json.load(f)
            with open(os.path.join(links_dir, f"{page_id}.json"), "r", encoding="utf-8") as f:
                ldata = json.load(f)
        except FileNotFoundError:
            print(f"[MERGE] missing file for {page_id}, skip")
            continue
        except (OSError, ValueError) as e:
            # Malformed JSON / bad encoding used to abort the whole merge.
            print(f"[MERGE] unreadable file for {page_id} ({e}), skip")
            continue
        blocks = load_blocks(raw)
        merged = {
            "page_id": page_id,
            "sentences": sdata.get("sentences", []),
            "blocks": blocks,
            "links": {
                "sentence_to_diagram": ldata.get("sentence_to_diagram", []),
                "caption_to_diagram": ldata.get("caption_to_diagram", []),
                "textual_reference": ldata.get("textual_reference", []),
                "sentence_to_equation": ldata.get("sentence_to_equation", [])
            }
        }
        with open(os.path.join(output_dir, f"{page_id}.json"), "w", encoding="utf-8") as f:
            json.dump(merged, f, indent=2)
    print(f"[MERGE] → {output_dir}")

# --- MASTER EXECUTION FOR PART 3 ---
def process_semantic_pipeline(book_dir):
    """Run Part 3 (sentences -> links -> merged pages) for one book folder.

    Inputs:
      * ``<book_dir>/cleaned_semantic_jsons`` (from Part 1);
      * ``<book_dir>/outputs/assets`` (from Part 2).
    Outputs under ``<book_dir>/outputs``: ``sm_sentences``, ``semantic_links``
    and ``semantic_pages``. A folder without cleaned JSON is reported and
    skipped; previously it raised inside extract_sentences after creating the
    output folders.
    """
    print(f"\n=== Semantic Pipeline: {book_dir} ===")
    cleaned_json_dir = os.path.join(book_dir, "cleaned_semantic_jsons")
    if not os.path.isdir(cleaned_json_dir):
        print(f"[SEM] skip {book_dir}: missing cleaned_semantic_jsons")
        return

    sentence_dir = os.path.join(book_dir, "outputs", "sm_sentences")
    links_dir = os.path.join(book_dir, "outputs", "semantic_links")
    assets_dir = os.path.join(book_dir, "outputs", "assets")
    merged_dir = os.path.join(book_dir, "outputs", "semantic_pages")
    for p in [os.path.join(book_dir, "outputs"), sentence_dir, links_dir, merged_dir]:
        os.makedirs(p, exist_ok=True)

    # 8) sentences
    extract_sentences(cleaned_json_dir, sentence_dir)
    # 9) links
    semantic_linking(sentence_dir, cleaned_json_dir, assets_dir, links_dir)
    # 10) merge
    merge_semantic_pages(sentence_dir, cleaned_json_dir, links_dir, merged_dir)
    print(f"=== Done: {book_dir} ===")

if __name__ == "__main__":
    # Run Part 3 (semantic pipeline) for every sub-folder of BASE_OUTPUT.
    for name in os.listdir(BASE_OUTPUT):
        d = os.path.join(BASE_OUTPUT, name)
        if os.path.isdir(d):
            # NOTE: Before running, ensure Part 1 and Part 2 have completed successfully.
            process_semantic_pipeline(d)

In [None]:
# ================================================
# CELL 1 — CONFIG + UTILITIES
# ================================================

import os
import json
import base64
import re
import requests
import time
import pathlib
from typing import Optional, List, Dict, Any

# ------------------------------------------------
# 🔧 MODEL NAMES (YOUR MODELS)
# ------------------------------------------------
VISION_MODEL = "qwen3-vl:8b"     # vision-language model
TEXT_MODEL   = "llama3:8b"       # text-only reasoning model

# Ollama generation endpoint called by ollama_call.
OLLAMA_URL = "http://localhost:11434/api/generate"
# ollama_call makes up to RETRY_COUNT attempts, sleeping RETRY_DELAY seconds between them.
RETRY_COUNT = 3
RETRY_DELAY = 3


# ------------------------------------------------
# 🔍 PATH + BASE64 HELPERS
# ------------------------------------------------
def to_posix(p: str) -> str:
    """Return path *p* in POSIX form (forward slashes), as wanted in the training JSONL.

    Conversion goes through ``pathlib.Path``, so on Windows backslash separators
    become ``/``.
    """
    as_path = pathlib.Path(p)
    return as_path.as_posix()

def load_image_b64(path: str) -> Optional[str]:
    """Return the file at *path* as base64 text, or None if *path* is empty or does not exist."""
    if path and os.path.exists(path):
        with open(path, "rb") as fh:
            raw = fh.read()
        return base64.b64encode(raw).decode("utf-8")
    return None


# ------------------------------------------------
# 🧠 ROBUST JSON EXTRACTOR (handles messy output)
# ------------------------------------------------
def extract_json_raw(text: str) -> Any:
    """Pull the first JSON value out of free-form model output.

    1. If the text contains ``<JSON> ... </JSON>`` and its body parses, that
       value is returned.
    2. Otherwise each ``{`` or ``[`` is tried left to right with
       ``json.JSONDecoder.raw_decode``, and the first one that decodes is
       returned.
    3. If nothing parses, or *text* is not a string, ``{}`` is returned.

    raw_decode understands JSON strings, so braces inside quoted values no
    longer break detection. A top-level list such as ``[{"q": ...}]`` is now
    returned whole; the old brace-first scan returned only its first inner
    object. Later candidates are tried when an earlier one does not parse.
    """
    if not isinstance(text, str):
        return {}

    # Look for <JSON> ... </JSON>
    block = re.search(r"<JSON>([\s\S]+?)</JSON>", text)
    if block:
        try:
            return json.loads(block.group(1).strip())
        except json.JSONDecodeError:
            pass  # fall through to scanning the whole text

    # Fallback: first position where a complete JSON object/array decodes.
    decoder = json.JSONDecoder()
    for match in re.finditer(r"[\[{]", text):
        try:
            value, _end = decoder.raw_decode(text, match.start())
        except json.JSONDecodeError:
            continue
        return value

    # Nothing worked -> return empty
    return {}


# ------------------------------------------------
# 🚀 UNIVERSAL OLLAMA CALL
# ------------------------------------------------
def ollama_call(prompt: str, model: str, images: Optional[List[str]] = None,
                max_tokens: int = 120) -> str:
    """Send one non-streaming generation request to the local Ollama server.

    Used for both the vision model (Qwen-VL) and the text model (LLaMA).

    Args:
        prompt: Full prompt text.
        model: Ollama model tag.
        images: Optional base64-encoded images; falsy entries are dropped.
        max_tokens: Generation cap, sent as ``options.num_predict``. It was
            previously sent as a top-level field, which /api/generate ignores,
            so it had no effect. Callers should size it for the full expected
            answer.

    Returns:
        The ``response`` field of the reply, or the whole reply re-serialized
        as JSON if that field is absent.

    Raises:
        RuntimeError: after RETRY_COUNT failed attempts (non-200 status,
            network error or unparsable body), chained to the last error.
    """
    payload = {
        "model": model,
        "prompt": prompt,
        "stream": False,
        # Model/generation parameters belong in "options".
        "options": {"num_predict": max_tokens},
    }
    if images:
        payload["images"] = [img for img in images if img]

    for attempt in range(RETRY_COUNT):
        try:
            r = requests.post(OLLAMA_URL, json=payload, timeout=240)
            if r.status_code != 200:
                raise RuntimeError(r.text[:300])
            out = r.json()
            return out.get("response", json.dumps(out))
        except Exception as e:
            if attempt < RETRY_COUNT - 1:
                time.sleep(RETRY_DELAY)
            else:
                raise RuntimeError(f"Ollama FAILED: {e}") from e
    # Only reachable when RETRY_COUNT < 1; never return None from a `-> str` function.
    raise RuntimeError("Ollama FAILED: no attempts made (RETRY_COUNT < 1)")


# ------------------------------------------------
# 📑 PAGE TYPE DETECTOR
# ------------------------------------------------
def detect_page_type(context: str, has_image: bool) -> str:
    """
    Classify a page so the caller can route it to the right model.

    Returns:
        "image"    when the page has an attached image (this always wins),
        "equation" when at least two distinct equation markers (\\begin{equation},
                   \\frac, \\sum, \\int, "=") occur in `context`,
        "text"     otherwise.
    """
    if not has_image:
        markers = (
            r"\\begin\{equation\}",
            r"\\frac",
            r"\\sum",
            r"\\int",
            r"=",
        )
        # Each marker counts at most once, however often it appears.
        hits = 0
        for marker in markers:
            if re.search(marker, context):
                hits += 1
        return "equation" if hits >= 2 else "text"
    return "image"


In [None]:
# ================================================
# CELL 2 ‚Äî MULTI-MODEL AGENTS
# ================================================

import random

class TextbookTrainingAgents:
    """
    Prompt wrappers around Ollama used to build the textbook training set.

    Routing: calls for pages of type "image" go to ``VISION_MODEL`` together
    with the page image; every other page type, and the validator, use
    ``TEXT_MODEL`` without an image.
    """

    def __init__(self):
        self.vision_model = VISION_MODEL
        self.text_model   = TEXT_MODEL

    # ------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------
    def _route(self, page_type: str, b64: Optional[str]):
        """Return ``(model, images)``: the vision model plus ``[b64]`` (when
        present) for image pages, otherwise the text model and no images."""
        if page_type == "image":
            return self.vision_model, ([b64] if b64 else None)
        return self.text_model, None

    @staticmethod
    def _normalize_questions(data: Any) -> List[str]:
        """Turn parsed model output into a clean list of question strings.

        Accepts a JSON list (of strings or ``{"question": ...}`` objects) or a
        dict that wraps such a list under ``"questions"``. Blank entries are
        dropped; anything else yields ``[]``.
        """
        if isinstance(data, dict):
            data = data.get("questions", [])
        if not isinstance(data, list):
            return []
        questions = []
        for item in data:
            if isinstance(item, dict):
                item = item.get("question", "")
            text = str(item).strip()
            if text:
                questions.append(text)
        return questions

    @staticmethod
    def _coerce_bool(value: Any, default: bool = True) -> bool:
        """Interpret an LLM-produced approval flag.

        Unlike ``bool()``, strings such as "false"/"no"/"0" map to False.
        Unrecognized values (including None) fall back to `default`, because
        the validator is meant to approve by default.
        """
        if isinstance(value, bool):
            return value
        if isinstance(value, (int, float)):
            return value != 0
        if isinstance(value, str):
            v = value.strip().lower()
            if v in ("true", "yes", "y", "1", "approve", "approved"):
                return True
            if v in ("false", "no", "n", "0", "reject", "rejected"):
                return False
        return default

    @staticmethod
    def _as_list(value: Any) -> List[Any]:
        """Return issues as a list: None -> [], scalar -> [scalar]."""
        if value is None:
            return []
        return value if isinstance(value, list) else [value]

    # ------------------------------------------------------------
    # QUESTION GENERATION (image -> vision model; otherwise text model)
    # ------------------------------------------------------------
    def generate_questions(self, context: str, page_type: str,
                           b64: Optional[str]) -> List[str]:
        """
        Ask the routed model for 3-5 short questions about the page.

        Returns the parsed questions, or ``["Explain this content."]`` when the
        call fails or yields nothing usable. This method does not raise.
        """
        model, img = self._route(page_type, b64)

        prompt = f"""
SYSTEM:
Generate 3‚Äì5 short questions (max 18 words).
Rules:
- Must be answerable ONLY from the provided TEXT and (if present) IMAGE.
- If an image is included, the question MUST reference or depend on the image.
- Avoid LaTeX in the question text.
- Output ONLY a JSON list inside <JSON>...</JSON>.

TEXT:
{context}
"""
        if page_type == "image":
            prompt += "\nIMAGE: <image provided>\n"

        try:
            resp = ollama_call(prompt, model, img, max_tokens=120)
            questions = self._normalize_questions(extract_json_raw(resp))
        except Exception:
            # Best-effort: any failure falls back to the generic question.
            questions = []
        return questions or ["Explain this content."]

    # ------------------------------------------------------------
    # STUDENT ANSWER (image -> vision model; otherwise text model)
    # ------------------------------------------------------------
    def student_answer(self, question: str, context: str,
                       page_type: str, b64: Optional[str]) -> str:
        """Return a short beginner-level answer. Errors from ``ollama_call``
        (RuntimeError after retries) propagate to the caller."""
        model, img = self._route(page_type, b64)

        prompt = f"""
SYSTEM:
You are a BEGINNER-LEVEL tutor.
Rules:
- Use ONLY the textbook TEXT (and IMAGE if provided).
- No outside knowledge.
- If the answer is missing, say:
  "The textbook context does not provide this information."
- 2‚Äì4 clear sentences.
- No markdown.
- End with [p.?].

QUESTION:
{question}

TEXT:
{context}
"""
        if page_type == "image":
            prompt += "\nIMAGE: <image provided>\n"

        return ollama_call(prompt, model, img, max_tokens=150)

    # ------------------------------------------------------------
    # EXPERT ANSWER (image -> vision model; otherwise text model)
    # ------------------------------------------------------------
    def expert_answer(self, question: str, context: str,
                      page_type: str, b64: Optional[str]) -> str:
        """Return a longer expert-level answer. Errors from ``ollama_call``
        (RuntimeError after retries) propagate to the caller."""
        model, img = self._route(page_type, b64)

        prompt = f"""
SYSTEM:
You are an EXPERT-LEVEL tutor strictly limited to the TEXTBOOK CONTEXT.
Rules:
-Use ONLY the TEXT (and IMAGE if available).
- DO NOT introduce equations or symbols not present in the context.
- LaTeX is allowed, but ONLY for equations that appear in the context.
- Provide a clear explanation using BOTH:
    ‚Ä¢ LaTeX (when present in context)
    ‚Ä¢ Verbal description of the math
- Provide 4‚Äì12 sentences depending on complexity.
- If the derivation steps are NOT shown in the context, say so.
- No markdown formatting.
- DO NOT use outside facts.

QUESTION:
{question}

TEXT:
{context}
"""
        if page_type == "image":
            prompt += "\nIMAGE: <image provided>\n"

        return ollama_call(prompt, model, img, max_tokens=220)

    # ------------------------------------------------------------
    # VALIDATOR (always the text model)
    # ------------------------------------------------------------
    def validate(self, question: str, student: str, expert: str,
                 context: str, b64: Optional[str]) -> Dict[str, Any]:
        """
        Soft-validate both answers against the page text.

        `b64` is accepted for interface symmetry but not used; validation is
        text-only. Returns a dict with ``student_approved``/``expert_approved``
        (bools) and ``student_issues``/``expert_issues`` (lists). Any failure
        (call error, unparseable output) approves both answers.
        """
        approve_all = {
            "student_approved": True,
            "expert_approved": True,
            "student_issues": [],
            "expert_issues": [],
        }

        # The key schema is spelled out so the model's JSON matches the keys
        # read below; without it the flags were almost always defaults.
        prompt = f"""
SYSTEM:
Soft validator.
- Approve by default.
- Reject ONLY if the answer contradicts the TEXT.
- LaTeX allowed.
- Output JSON inside <JSON>...</JSON> with exactly these keys:
  {{"student_approved": true, "expert_approved": true, "student_issues": [], "expert_issues": []}}

QUESTION: {question}
STUDENT: {student}
EXPERT: {expert}
TEXT: {context}
"""

        try:
            resp = ollama_call(prompt, self.text_model, None, max_tokens=120)
            data = extract_json_raw(resp)
        except Exception:
            return approve_all
        if not isinstance(data, dict):
            return approve_all

        return {
            "student_approved": self._coerce_bool(data.get("student_approved", True)),
            "expert_approved":  self._coerce_bool(data.get("expert_approved", True)),
            "student_issues":   self._as_list(data.get("student_issues")),
            "expert_issues":    self._as_list(data.get("expert_issues")),
        }


In [None]:
# ==========================================================
# CELL 3 ‚Äî PAGE-LEVEL HYBRID GENERATOR (FULL POWER)
# ==========================================================

import os
import json
import re
from tqdm import tqdm

class HybridLlavaGenerator:
    """
    Page-level training-data generator.

    For each semantic page JSON: build one text context, optionally attach one
    page image, ask the agents for up to 5 questions, then produce a student
    and an expert answer per question and keep those the validator approves.
    Model routing (vision vs. text) is decided by the agents from the page
    type returned by ``detect_page_type``.
    """

    def __init__(self, agents: "TextbookTrainingAgents"):
        self.agents = agents

    # --------------------------------------------------------------
    # checkpoint helpers
    # --------------------------------------------------------------
    def load_checkpoint(self, ckpt_path: str):
        """Return the records stored in a JSONL checkpoint ([] if missing).

        Lines that are not valid JSON objects are skipped and counted.
        """
        if not os.path.exists(ckpt_path):
            return []
        buf, bad = [], 0
        with open(ckpt_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    rec = json.loads(line)
                except json.JSONDecodeError:
                    bad += 1
                    continue
                if isinstance(rec, dict):
                    buf.append(rec)
                else:
                    bad += 1
        print(f"üîÑ Loaded checkpoint ({len(buf)} items)")
        if bad:
            print(f"WARNING: skipped {bad} malformed checkpoint line(s)")
        return buf

    def save_checkpoint(self, data, ckpt_path: str):
        """Rewrite the checkpoint via a temp file + ``os.replace`` so an
        interrupted write cannot leave a truncated checkpoint behind."""
        tmp_path = ckpt_path + ".tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            for item in data:
                f.write(json.dumps(item, ensure_ascii=False) + "\n")
        os.replace(tmp_path, ckpt_path)
        print(f"üíæ Saved checkpoint ({len(data)} items)")

    @staticmethod
    def _id_prefixes(records) -> set:
        """Return every '_'-delimited prefix of every record id.

        Sample ids are built as ``f"{page_id}_q{idx}..."`` and page ids may
        themselves contain '_' (e.g. "page_0001"), so the page id cannot be
        recovered with one split. Collecting all prefixes lets callers test
        ``page_id in prefixes`` directly.
        """
        prefixes = set()
        for rec in records:
            rid = str(rec.get("id", ""))
            pos = rid.find("_")
            while pos != -1:
                prefixes.add(rid[:pos])
                pos = rid.find("_", pos + 1)
        return prefixes

    @staticmethod
    def _sample_base_id(page_id: str, q_idx: int, question: str) -> str:
        """Build an id from the page id, the question index and up to 40 safe
        characters of the question. The index keeps two questions that share
        a 40-character prefix from colliding."""
        raw = re.sub(r"\s+", "_", question)
        raw = re.sub(r"[^A-Za-z0-9_-]", "", raw)[:40]
        return f"{page_id}_q{q_idx}_{raw}" if raw else f"{page_id}_q{q_idx}"

    @staticmethod
    def _dedupe_questions(questions, limit: int = 5) -> List[str]:
        """Strip questions, drop blanks and case-insensitive duplicates, and
        keep at most `limit` of them in their original order."""
        seen, out = set(), []
        for q in questions:
            q = str(q).strip()
            key = q.casefold()
            if q and key not in seen:
                seen.add(key)
                out.append(q)
            if len(out) >= limit:
                break
        return out

    @staticmethod
    def _make_record(sample_id: str, img_path: Optional[str], has_image: bool,
                     question: str, answer: str, level: str) -> Dict[str, Any]:
        """Build one LLaVA-style conversation record."""
        return {
            "id": sample_id,
            "image": to_posix(img_path) if img_path else None,
            "conversations": [
                {
                    "from": "human",
                    "value": f"<image>\n{question}" if has_image else question
                },
                {
                    "from": "gpt",
                    "value": answer
                }
            ],
            "level": level
        }

    # --------------------------------------------------------------
    # main book processor  (PAGE-LEVEL)
    # --------------------------------------------------------------
    def process_book(self, semantic_dir: str, assets_dir: str,
                     output_path: str, ckpt_path: str, max_pages: int):
        """
        Generate student/expert samples for every page JSON in `semantic_dir`.

        Args:
            semantic_dir: folder of per-page ``*.json`` files; the file stem is
                used as the page id.
            assets_dir: folder searched for a page image (see pick_page_image).
            output_path: final JSONL, written after all pages are processed.
            ckpt_path: JSONL checkpoint, rewritten after every page.
            max_pages: when truthy, only the first `max_pages` files (sorted by
                name) are processed.

        Resuming: a page is skipped when any checkpointed sample id starts with
        ``"<page_id>_"``. Pages that produced no accepted sample are retried.
        """
        out_dir = os.path.dirname(output_path)
        if out_dir:  # os.makedirs("") would raise for a bare filename
            os.makedirs(out_dir, exist_ok=True)

        training_data = self.load_checkpoint(ckpt_path)
        done_ids = {str(x["id"]) for x in training_data if "id" in x}
        finished_pages = self._id_prefixes(training_data)

        files = sorted(f for f in os.listdir(semantic_dir) if f.endswith(".json"))
        if max_pages:
            files = files[:max_pages]

        print(f"\nüìò PAGE-LEVEL processing for {len(files)} pages\n")

        for fname in tqdm(files):
            page_id = os.path.splitext(fname)[0]

            # skip pages that already have samples from a previous run
            if page_id in finished_pages:
                continue

            page_path = os.path.join(semantic_dir, fname)
            try:
                with open(page_path, "r", encoding="utf-8") as f:
                    page_data = json.load(f)
                if not isinstance(page_data, dict):
                    raise ValueError("page JSON is not an object")
            except Exception as e:
                print(f"‚ö†Ô∏è Failed to read {fname}: {e}")
                continue

            # always set page_id on the JSON (used in the context string)
            page_data["page_id"] = page_id

            context, _equations = self.build_page_context(page_data)

            img_path = self.pick_page_image(page_data, assets_dir)
            b64 = load_image_b64(img_path) if img_path else None
            has_image = b64 is not None

            page_type = detect_page_type(context, has_image)

            try:
                questions = self.agents.generate_questions(context, page_type, b64)
            except Exception as e:
                print(f"‚ö†Ô∏è Question gen failed for {page_id}: {e}")
                continue

            # drop blank/duplicate questions, cap at 5 per page
            questions = self._dedupe_questions(questions or [], limit=5)
            if not questions:
                continue

            for q_idx, q in enumerate(questions):
                base_id = self._sample_base_id(page_id, q_idx, q)
                sid, eid = f"{base_id}_s", f"{base_id}_e"

                if sid in done_ids and eid in done_ids:
                    continue

                try:
                    stu_ans = self.agents.student_answer(q, context, page_type, b64)
                    exp_ans = self.agents.expert_answer(q, context, page_type, b64)
                except Exception as e:
                    print(f"‚ö†Ô∏è Answer gen failed ({page_id} / '{q[:30]}'): {e}")
                    continue

                val = self.agents.validate(q, stu_ans, exp_ans, context, b64)

                if val.get("student_approved", True) and sid not in done_ids:
                    training_data.append(
                        self._make_record(sid, img_path, has_image, q, stu_ans, "student"))
                    done_ids.add(sid)

                if val.get("expert_approved", True) and eid not in done_ids:
                    training_data.append(
                        self._make_record(eid, img_path, has_image, q, exp_ans, "expert"))
                    done_ids.add(eid)

            self.save_checkpoint(training_data, ckpt_path)

        # ------------------ final save -------------------------
        with open(output_path, "w", encoding="utf-8") as f:
            for item in training_data:
                f.write(json.dumps(item, ensure_ascii=False) + "\n")

        print(f"\n‚úÖ PAGE-LEVEL DONE ‚Äî {len(training_data)} samples ‚Üí {output_path}")

    # --------------------------------------------------------------
    # build PAGE-LEVEL context (text + equations)
    # --------------------------------------------------------------
    def build_page_context(self, page_data: Dict[str, Any],
                           max_chars: int = 4000) -> "Tuple[str, List[str]]":
        """
        Build one prompt context for a page.

        The context is the page's sentence texts joined with spaces, then an
        "EQUATIONS ON THIS PAGE" block with the unique linked equations (if
        any), then a ``[Source: <page_id>]`` tag. The text/equation body is
        truncated so the result stays within `max_chars` while the source tag
        is always kept.

        Returns:
            (context, unique_equations)
        """
        txt_bits = []
        for s in page_data.get("sentences") or []:
            if not isinstance(s, dict):
                continue
            t = str(s.get("text") or "").strip()
            if t:
                txt_bits.append(t)
        body = " ".join(txt_bits)

        # gather unique equations from links (page-level), first-seen order
        eqs, seen = [], set()
        links = page_data.get("links") or {}
        for link in links.get("sentence_to_equation") or []:
            eq = link.get("equation_latex", "") if isinstance(link, dict) else ""
            if eq and eq not in seen:
                seen.add(eq)
                eqs.append(eq)

        if eqs:
            eq_lines = [f"[Eq {i}] {e}" for i, e in enumerate(eqs, 1)]
            body += "\n" + "\nEQUATIONS ON THIS PAGE:\n" + "\n".join(eq_lines)

        source_tag = f"\n[Source: {page_data.get('page_id', '')}]"
        # Trim the body, not the tag, so the source reference always survives.
        body = body[:max(0, max_chars - len(source_tag))]
        return body + source_tag, eqs

    # --------------------------------------------------------------
    # choose ONE image for the page (diagram preferred)
    # --------------------------------------------------------------
    def pick_page_image(self, page_data: Dict[str, Any], assets_dir: str) -> Optional[str]:
        """
        Return a path to one image for the page, or None.

        1) the first existing file referenced by ``links.sentence_to_diagram``;
        2) otherwise the first (by name) image in `assets_dir` whose stem equals
           the page id or starts with ``"<page_id>_"``. The separator check
           keeps "page_1" from matching "page_10_..." files.
        """
        links = page_data.get("links") or {}

        for dl in links.get("sentence_to_diagram") or []:
            fname = dl.get("diagram_file") if isinstance(dl, dict) else None
            if not fname:
                continue
            full = os.path.join(assets_dir, fname)
            if os.path.exists(full):
                return full

        page_id = page_data.get("page_id", "")
        if page_id and os.path.isdir(assets_dir):
            prefix = page_id + "_"
            exts = (".png", ".jpg", ".jpeg", ".webp")
            for f in sorted(os.listdir(assets_dir)):
                if not f.lower().endswith(exts):
                    continue
                if f.startswith(prefix) or os.path.splitext(f)[0] == page_id:
                    return os.path.join(assets_dir, f)

        return None


In [None]:
# ==========================================================
# CELL 4 ‚Äî RUN THE HYBRID GENERATOR
# ==========================================================

# Create the generator using the agents defined in Cell 2.
agents = TextbookTrainingAgents()
runner = HybridLlavaGenerator(agents)

# ----------------------------------------------------------
# CONFIGURE YOUR PATHS HERE
# ----------------------------------------------------------
BASE = "C:/Users/Admin/Desktop/vbooks"   # change if needed

# Every sub-folder of BASE is treated as a candidate book; folders without
# outputs/semantic_pages are reported and skipped below. Sorted so books are
# processed in a reproducible order (os.listdir() order is arbitrary).
books = sorted(b for b in os.listdir(BASE) if os.path.isdir(os.path.join(BASE, b)))

print("üìö Books detected:", books)

# SAMPLE mode processes only the first 3 pages of each book.
sample = input("Run SAMPLE mode (first 3 pages)? (y/n): ").lower().strip() == "y"
# None = no page limit (process_book only slices when max_pages is truthy).
max_pages = 3 if sample else None

# ----------------------------------------------------------
# RUN THROUGH ALL BOOKS
# ----------------------------------------------------------
for book in books:
    print(f"\nüìò Processing book: {book}")

    book_out     = os.path.join(BASE, book, "outputs")
    semantic_dir = os.path.join(book_out, "semantic_pages")
    assets_dir   = os.path.join(book_out, "assets")

    if not os.path.isdir(semantic_dir):
        print(f"‚ùå Missing semantic_pages for {book}")
        continue

    output_path = os.path.join(book_out, "llava_training.jsonl")
    ckpt_path   = os.path.join(book_out, "llava_checkpoint.jsonl")

    runner.process_book(
        semantic_dir=semantic_dir,
        assets_dir=assets_dir,
        output_path=output_path,
        ckpt_path=ckpt_path,
        max_pages=max_pages,
    )

print("\n‚úÖ ALL BOOKS COMPLETED!")


In [None]:
import os, json, random

BASE = r"C:/Users/Admin/Desktop/vbooks"
MERGE_SEED = 42  # fixed seed so the merged file is reproducible

all_lines = []
skipped_lines = 0

# sorted(): os.listdir() order is arbitrary, and a seeded shuffle is only
# reproducible when its input order is fixed too.
for bk in sorted(os.listdir(BASE)):
    fpath = os.path.join(BASE, bk, "outputs", "llava_training.jsonl")
    if not os.path.exists(fpath):
        continue

    print("collecting -->", fpath)

    with open(fpath, "r", encoding="utf-8") as f:
        for ln in f:
            ln = ln.strip()
            if not ln:
                continue
            # Keep only JSON objects: a single malformed line would make
            # load_dataset("json", ...) fail in the training cell.
            try:
                is_record = isinstance(json.loads(ln), dict)
            except json.JSONDecodeError:
                is_record = False
            if is_record:
                all_lines.append(ln)
            else:
                skipped_lines += 1

print("total samples:", len(all_lines))
if skipped_lines:
    print("skipped malformed lines:", skipped_lines)

# shuffle with a dedicated, seeded RNG (does not touch global random state)
random.Random(MERGE_SEED).shuffle(all_lines)

# final merged dataset
merged = os.path.join(BASE, "merged_llava_training.jsonl")
with open(merged, "w", encoding="utf-8") as f:
    f.writelines(line + "\n" for line in all_lines)

print("merged written to:", merged)


In [None]:
# ============================================================
# LLaVA 1.5‚Äì7B Heavy LoRA Fine-Tuning (Image + Text, 1 GPU)
# ============================================================

import os
import sys
import math
import glob
import random
from typing import Dict, Any, List

# ======== ENV VARS (NO BNB, NO NVML) ========
# Set before torch/transformers/peft are imported below. CUDA_VISIBLE_DEVICES
# only takes effect if CUDA has not already been initialised in this kernel.
os.environ.update({
    "BITSANDBYTES_NOWELCOME": "1",
    "BITSANDBYTES_DISABLE":   "1",
    "PYTORCH_NO_NVML":        "1",
    "CUDA_VISIBLE_DEVICES":   "0",
})
# A None entry in sys.modules makes any later `import bitsandbytes` raise
# ImportError (presumably so libraries skip their bitsandbytes code paths).
sys.modules["bitsandbytes"] = None

# ======== IMPORTS ========
import torch
from PIL import Image
import matplotlib.pyplot as plt

from datasets import load_dataset, Dataset, load_from_disk
from transformers import (
    LlavaProcessor,
    LlavaForConditionalGeneration,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)
from peft import LoraConfig, get_peft_model

# ============================================================
# 0) Seed & device
# ============================================================
SEED = 42
torch.manual_seed(SEED)
random.seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("üñ• Using device:", device)

# ============================================================
# 1) Paths
# ============================================================
RAW_JSONL   = "C:/Users/Admin/Desktop/vbooks/merged_llava_training.jsonl"
IMAGE_ROOT  = "C:/Users/Admin/Desktop/vbooks"
DATA_DIR    = "C:/Users/Admin/Desktop/vbooks/hf_llava_all_heavy"

OUTPUT_DIR  = "C:/Users/Admin/Desktop/vbooks/llava_outputs_bf16_heavy/ckpts"
ADAPTER_DIR = "C:/Users/Admin/Desktop/vbooks/llava_outputs_bf16_heavy/adapter"
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(ADAPTER_DIR, exist_ok=True)

MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
MAX_LEN = 1536  # max tokens per example (passed to the processor for truncation)

# ============================================================
# 2) Processor + tokenizer
# ============================================================
print("üîÅ Loading processor/tokenizer...")
processor = LlavaProcessor.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

# Share one tokenizer object between the processor and the trainer.
processor.tokenizer = tokenizer
image_token = processor.image_token
print("üîπ image_token:", image_token)

# `tokenizer` and `processor.tokenizer` are now the same object, so this
# single assignment sets right-padding for both.
tokenizer.padding_side = "right"

# ============================================================
# 3Ô∏è‚É£ Dataset loader
# ============================================================

def resolve_image_path(p):
    """
    Return an existing file path for image reference `p`, or None.

    An absolute path is returned as-is when it exists; otherwise `p` is joined
    onto IMAGE_ROOT and returned if that file exists. Falsy `p` gives None.
    """
    if p:
        if os.path.isabs(p) and os.path.exists(p):
            return p
        joined = os.path.join(IMAGE_ROOT, p)
        if os.path.exists(joined):
            return joined
    return None

def build_segments(conv, image_tok=None):
    """
    Convert a LLaVA-style conversation into prompt segments.

    Args:
        conv: list of ``{"from": ..., "value": ...}`` turns; "human" turns
            become "USER:" segments, every other role an "ASSISTANT:" segment
            terminated with ``</s>``.
        image_tok: placeholder emitted once, in the first human turn. Defaults
            to the processor's global ``image_token``.

    Returns:
        list of ``{"text": str, "is_human": bool}``.

    Any "<image>" markers already in the data are removed (not only a leading
    "<image>\\n"), so the final text never holds more image tokens than the
    single image passed to the processor, and "<image>Q" without a newline no
    longer loses the question.
    """
    segs = []
    first_human = True
    for turn in conv:
        role = turn["from"]
        text = turn["value"]

        if isinstance(text, str):
            text = text.replace("<image>", "").strip()

        if role == "human":
            if first_human:
                tok = image_token if image_tok is None else image_tok
                seg = f"USER: {tok}\n{text}\n"
                first_human = False
            else:
                seg = f"USER: {text}\n"
            segs.append({"text": seg, "is_human": True})
        else:
            segs.append({"text": f"ASSISTANT: {text}</s>\n", "is_human": False})
    return segs

def prepare_or_load(rebuild=False):
    """
    Return the processed training dataset, building it from RAW_JSONL if needed.

    Args:
        rebuild: when True, discard any processed copy in DATA_DIR and rebuild
            it from RAW_JSONL. Default False keeps the cached copy.

    When a cached copy is reused but RAW_JSONL is newer than it (e.g. the merge
    cell was re-run), a warning is printed instead of silently training on
    stale samples. The cache is not deleted automatically: its files may still
    be memory-mapped by datasets loaded earlier in the kernel.

    Each processed row has: id, image_path (resolved path or None),
    has_real_image, segments (see build_segments) and full_text.
    """
    if os.path.exists(DATA_DIR) and not rebuild:
        ds = load_from_disk(DATA_DIR)
        print("üì¶ Loaded existing dataset:", DATA_DIR)
        print("üìä Size:", len(ds))
        # NOTE: state.json is assumed to be part of the save_to_disk() layout;
        # fall back to the folder's own mtime if it is absent.
        marker = os.path.join(DATA_DIR, "state.json")
        cache_time = os.path.getmtime(marker if os.path.exists(marker) else DATA_DIR)
        if os.path.exists(RAW_JSONL) and os.path.getmtime(RAW_JSONL) > cache_time:
            print("WARNING:", RAW_JSONL, "is newer than the processed dataset;",
                  "call prepare_or_load(rebuild=True) or delete", DATA_DIR,
                  "to rebuild it.")
        return ds

    print("üì• Loading:", RAW_JSONL)
    raw = load_dataset("json", data_files=RAW_JSONL, split="train")
    print("üìä Raw examples:", len(raw))

    processed = []
    for ex in raw:
        img = resolve_image_path(ex.get("image"))
        segs = build_segments(ex["conversations"])
        processed.append({
            "id": ex.get("id", ""),
            "image_path": img,
            "has_real_image": img is not None,
            "segments": segs,
            "full_text": "".join(s["text"] for s in segs),
        })

    ds = Dataset.from_list(processed)
    if os.path.exists(DATA_DIR):
        # Only reached with rebuild=True; drop the old copy after the new
        # dataset was built successfully.
        import shutil
        shutil.rmtree(DATA_DIR)
    ds.save_to_disk(DATA_DIR)
    print("üíæ Saved processed dataset:", DATA_DIR)
    return ds

dataset = prepare_or_load()

# ============================================================
# 4) Train/val split: 5% held out, seeded for reproducibility
# ============================================================
split = dataset.train_test_split(test_size=0.05, seed=SEED)
train_ds, val_ds = split["train"], split["test"]

print(f"üß© Train: {len(train_ds)} | Val: {len(val_ds)}")

# ============================================================
# 5Ô∏è‚É£ Collate Function (FINAL FIXED VERSION)
# ============================================================

def collate_fn(batch):
    """
    Build one training batch: processor encoding plus labels that supervise
    only the assistant turns.

    Each example provides "full_text", "segments" (from build_segments) and
    "image_path". Examples without a readable image get a blank 336x336 white
    image, so every row has exactly one image for its single image token.

    Label masking:
      * Segment boundaries are found by tokenizing the cumulative text prefix
        with special tokens (as the processor does). This avoids the
        per-segment tokenization drift of summing isolated segment lengths.
      * Newer LlavaProcessor versions expand the image placeholder into many
        image tokens inside input_ids. The image tokens actually present in
        each row are counted, and every boundary after the placeholder is
        shifted by (count - 1). With a non-expanding processor the shift is 0.
      * Human segments, image-token positions and padding get label -100.
    """
    texts, images, all_segments = [], [], []

    for ex in batch:
        texts.append(ex["full_text"])
        all_segments.append(ex["segments"])

        img = None
        if ex["image_path"]:
            try:
                # `with` closes the file handle; convert() returns an
                # independent, fully loaded copy.
                with Image.open(ex["image_path"]) as src:
                    img = src.convert("RGB")
            except Exception:
                img = None
        images.append(img if img is not None else Image.new("RGB", (336, 336), "white"))

    # --- Let processor build the true multimodal encoding ---
    inputs = processor(
        text=texts,
        images=images,
        padding=True,
        truncation=True,
        max_length=MAX_LEN,
        return_tensors="pt",
    )

    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    labels = input_ids.clone()     # start fully supervised, then mask

    tok = processor.tokenizer
    image_token_id = tok.convert_tokens_to_ids(image_token)
    seq_len = labels.size(1)

    def _prefix_len(prefix, image_shift):
        # Token count of `prefix` as laid out in input_ids (BOS included).
        n = len(tok(prefix, add_special_tokens=True).input_ids)
        return n + (image_shift if image_token in prefix else 0)

    for i, segs in enumerate(all_segments):
        n_image_tokens = int((input_ids[i] == image_token_id).sum())
        image_shift = max(n_image_tokens - 1, 0)

        prefix = ""
        start = _prefix_len(prefix, image_shift)  # number of leading special tokens
        for seg in segs:
            if start >= seq_len:
                break
            prefix += seg["text"]
            end = _prefix_len(prefix, image_shift)
            if seg["is_human"]:
                labels[i, start:min(end, seq_len)] = -100  # ignore user text
            start = end

        # Never supervise padding or image placeholder positions.
        labels[i, attention_mask[i] == 0] = -100
        labels[i, input_ids[i] == image_token_id] = -100

    out = {
        "input_ids": input_ids,
        "labels": labels,
        "attention_mask": attention_mask,
        "pixel_values": inputs["pixel_values"],
    }
    if "image_sizes" in inputs:
        out["image_sizes"] = inputs["image_sizes"]

    return out

# ============================================================
# üîé DEBUG: Check if masking works
# ============================================================

from torch.utils.data import DataLoader

debug_dl = DataLoader(train_ds, batch_size=1, collate_fn=collate_fn)
dbg = next(iter(debug_dl))

# Number of label positions that contribute to the loss (not -100).
_supervised = (dbg["labels"] != -100).sum().item()

print("\n=== DEBUG MASKING ===")
print("labels shape:", dbg["labels"].shape)
print("supervised token count:", _supervised)
print("supervised fraction:", _supervised / dbg["labels"].numel())
print("======================\n")

# ============================================================
# 6) Load LLaVA and apply LoRA
# ============================================================

print("üß† Loading LLaVA base model...")
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16 if use_bf16 else torch.float16,
    device_map={"": 0}
)

model.config.use_cache = False

# Freeze vision tower
for n, p in model.named_parameters():
    if "vision_tower" in n:
        p.requires_grad = False
print("üéØ Vision tower frozen")

# Count projector parameters. Their trainability after PEFT wrapping comes
# from modules_to_save below, because get_peft_model() re-freezes every
# parameter that is not part of an adapter.
mm_train = 0
for n, p in model.named_parameters():
    if "multi_modal_projector" in n:
        p.requires_grad = True
        mm_train += p.numel()
print("üîπ mm_projector trainable:", mm_train)

lconf = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    # A string is matched as a full regex against module names: LoRA only on
    # the language model's attention/MLP projections. Suffixes such as
    # "self_attn.q_proj" would also match the CLIP vision tower's attention
    # layers and attach trainable adapters to the "frozen" tower.
    target_modules=r".*language_model.*\.(q_proj|k_proj|v_proj|o_proj|gate_proj|up_proj|down_proj)",
    # Keep a trainable copy of the projector and save it with the adapter.
    modules_to_save=["multi_modal_projector"],
)

model = get_peft_model(model, lconf)

if not use_bf16:
    # fp16 mixed precision cannot unscale fp16 gradients, so keep trainable
    # weights (e.g. the saved projector copy) in fp32.
    for p in model.parameters():
        if p.requires_grad and p.dtype == torch.float16:
            p.data = p.data.float()

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total     = sum(p.numel() for p in model.parameters())
print(f"üßÆ Trainable Params: {trainable} / {total} ({100*trainable/total:.4f}%)")

# ============================================================
# 7) Training arguments
# ============================================================

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    # schedule
    num_train_epochs=2,
    learning_rate=3e-4,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    # batching: 1 example per device per optimizer step
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=1,
    # evaluation / checkpointing / logging cadence
    eval_strategy="steps",
    eval_steps=500,
    save_steps=500,
    save_total_limit=3,
    logging_steps=20,
    report_to=["tensorboard"],
    # mixed precision: bf16 when supported, otherwise fp16
    bf16=use_bf16,
    fp16=(not use_bf16),
    # keep image_path/segments/full_text columns for collate_fn
    remove_unused_columns=False,
)

# ============================================================
# 8) Trainer
# ============================================================

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=tokenizer,
)

# ============================================================
# 9) Training (resume from the newest checkpoint, if any)
# ============================================================

def _checkpoint_step(path):
    """Return the step encoded in a 'checkpoint-<step>' directory name, or -1."""
    suffix = os.path.basename(os.path.normpath(path)).rpartition("-")[2]
    return int(suffix) if suffix.isdigit() else -1

# Choose the latest checkpoint by *step number*: a plain string sort puts
# "checkpoint-1000" before "checkpoint-500" and would resume from the older one.
ckpts = [c for c in glob.glob(os.path.join(OUTPUT_DIR, "checkpoint-*"))
         if os.path.isdir(c) and _checkpoint_step(c) >= 0]
resume_ckpt = max(ckpts, key=_checkpoint_step) if ckpts else None

trainer.train(resume_from_checkpoint=resume_ckpt)

# ============================================================
# 10) Save adapter + preprocessing configs
# ============================================================

trainer.save_model(ADAPTER_DIR)
tokenizer.save_pretrained(ADAPTER_DIR)
# Also save the processor (image-processor config + tokenizer) so the adapter
# folder alone is enough to rebuild the exact preprocessing at inference time.
processor.save_pretrained(ADAPTER_DIR)
print("‚úÖ LoRA adapter saved:", ADAPTER_DIR)

# ============================================================
# 11) Loss curves
# ============================================================

logs = trainer.state.log_history

# Training-step records carry "loss"; evaluation records carry "eval_loss".
train_steps  = [r["step"] for r in logs if "loss" in r]
train_losses = [r["loss"] for r in logs if "loss" in r]
eval_steps   = [r["step"] for r in logs if "eval_loss" in r]
eval_losses  = [r["eval_loss"] for r in logs if "eval_loss" in r]

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(train_steps, train_losses, label="train")
ax.plot(eval_steps, eval_losses, label="eval")
ax.legend()
ax.grid()
ax.set_xlabel("step")
ax.set_ylabel("loss")
plt.show()
