In [6]:
# ============================================
# 🚀 Colab-Ready: C Bug Detector + Explainer (T5/FLAN + LoRA)
# Instruction-tuned; local metrics; robust decode; anti-echo generation
# ============================================

# ---------- STEP 0: Install compatible packages ----------
!pip install -q "transformers==4.42.4" "accelerate==0.33.0" "datasets==2.20.0" \
               "peft==0.11.1" "gradio==4.44.0" \
               "rouge-score==0.1.2" "bert-score==0.3.13"

# ---------- STEP 1: Imports, toggles & quiet mode ----------
import os, json, random, numpy as np, re, html, shutil, warnings
from typing import List, Dict, Any

import torch
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    EarlyStoppingCallback,
    DataCollatorForSeq2Seq,
    set_seed,
)
from peft import LoraConfig, get_peft_model

# Delete stray local folders whose names could shadow the metric packages on import.
for stale_dir in ("/content/rouge", "/content/bertscore"):
    if not os.path.isdir(stale_dir):
        continue
    shutil.rmtree(stale_dir, ignore_errors=True)

# Hide FutureWarning noise; the second filter targets the torch GradScaler message explicitly.
warnings.filterwarnings(action="ignore", category=FutureWarning)
warnings.filterwarnings(
    action="ignore",
    message="`torch.cuda.amp.GradScaler",
    category=FutureWarning,
)

# One seed for Python's RNG, NumPy's RNG and the transformers helper.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
set_seed(SEED)

# Prefer the GPU when one is visible to torch.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("✅ Device:", device)

# ---- Model toggle ----
# Single switch that selects the base checkpoint; LR and EPOCHS below follow it.
USE_FLAN_BASE = False         # set True to use "google/flan-t5-base", else T5-small
MODEL_NAME = "google/flan-t5-base" if USE_FLAN_BASE else "t5-small"

# Learning rate and epoch budget per base model (passed to Seq2SeqTrainingArguments later).
LR = 8e-4 if USE_FLAN_BASE else 1e-3
EPOCHS = 8 if USE_FLAN_BASE else 5

# ---- Output format instruction (boosts detection accuracy) ----
# Prepended to every prompt, both when building the dataset and at inference time.
# The 'Bug:' / 'No bug:' prefixes requested here are what the detection-accuracy
# metric keys on, so keep the wording and the metric in sync.
INSTRUCTION = (
    "You are a C static analysis assistant. "
    "Given the code, respond with exactly one sentence that starts with either "
    "'Bug:' or 'No bug:'. If 'Bug:', briefly explain and include a fix. "
    "If 'No bug:', briefly justify why it is safe."
)

# ---------- STEP 2: Build a balanced synthetic dataset ----------
# (C snippet, reference answer) pairs for the positive class.
# Every reference starts with "Bug:" and contains a "Fix:" clause.
bug_templates = [
    ("int main(){int arr[3]; arr[3] = 7; return 0;}",
     "Bug: Out-of-bounds array access at arr[3]. Fix: Valid indices are 0–2."),
    ("char *s=NULL; *s='a';",
     "Bug: Dereferencing NULL pointer. Fix: Allocate memory before dereference."),
    ("int *p=malloc(sizeof(int)); if(p) return 0;",
     "Bug: Memory leak. Fix: Call free(p) on all successful allocation paths."),
    ("int x; if(x==1){printf(\"ok\");}",
     "Bug: Use of uninitialized variable x. Fix: Initialize x before use."),
    ("char buf[4]; strcpy(buf,\"test\");",
     "Bug: Buffer overflow (5 bytes into 4). Fix: Use strncpy with size or increase buffer."),
    ("int *p=malloc(sizeof(int)); free(p); free(p);",
      "Bug: Double free. Fix: Avoid freeing the same pointer twice; set to NULL after free."),
    ("char buf[8]; gets(buf);",
     "Bug: gets() is unsafe and can overflow. Fix: Use fgets with size limit."),
    ("int main(){for(int i=0;i<10;i--){printf(\"%d\",i);} }",
     "Bug: Infinite loop due to wrong update direction. Fix: Use i++ to reach termination."),
    ("FILE *f=fopen(\"file.txt\",\"r\"); fclose(f); fclose(f);",
     "Bug: Double fclose. Fix: Ensure fclose is called once per opened file."),
    ("int a=2147483647; int b=a+1;",
     "Bug: Signed integer overflow when adding 1 to INT_MAX. Fix: Use wider type or check bounds."),
    ("char *s; strcpy(s, \"hello\");",
     "Bug: Using strcpy on uninitialized pointer 's'. Fix: Allocate memory before use."),
    ("int *p; *p = 5;",
     "Bug: Dereferencing uninitialized pointer p. Fix: Allocate or point p to valid memory."),
    ("int a=10; int *p=&a; free(p);",
     "Bug: Freeing stack memory. Fix: Only free heap allocations."),
    ("int i; for(i=0;i<=10;i++){ arr[i]=i; }",
     "Bug: Off-by-one writing arr[10] if arr has size 10. Fix: Use i<10."),
    ("pthread_mutex_t m; int x=0; // write x in threads without locking",
     "Bug: Potential data race on shared variable x. Fix: Protect accesses with mutex lock/unlock."),
]

# (C snippet, reference answer) pairs for the negative class.
# Every reference starts with "No bug:".
clean_templates = [
    ("int main(){int x=0; printf(\"%d\", x); return 0;}",
     "No bug: Code initializes and prints x correctly."),
    ("char s[6]; strcpy(s, \"hi\");",
     "No bug: Buffer is large enough for the copied string and null terminator."),
    ("int arr[3]={1,2,3}; for(int i=0;i<3;i++){ printf(\"%d\", arr[i]); }",
     "No bug: Proper array bounds in loop."),
    ("FILE *f=fopen(\"file.txt\",\"w\"); if(f){ fprintf(f,\"ok\"); fclose(f);} return 0;",
     "No bug: File opened, used, and closed safely."),
    ("int *p=(int*)malloc(sizeof(int)); if(p){ *p=5; free(p);} return 0;",
     "No bug: Heap memory allocated, used, and freed correctly."),
    ("char buf[8]; snprintf(buf, sizeof(buf), \"%s\", \"ok\");",
     "No bug: snprintf prevents overflow by size limiting."),
    ("struct S{int a;}; struct S s={.a=1}; printf(\"%d\", s.a);",
     "No bug: Struct is initialized before use."),
    ("int sum(int a,int b){return a+b;} int main(){printf(\"%d\",sum(1,2));}",
     "No bug: Simple addition function is correct."),
    ("for(int i=0;i<10;i++){ /* work */ }",
     "No bug: Loop bounds and update are correct."),
    ("const char* msg = \"hello\"; puts(msg);",
     "No bug: String literal used safely with puts."),
]

# Number of prompts sampled (with replacement) from the templates; split evenly
# between the buggy and the clean class.
TOTAL_EXAMPLES = 6000
half = TOTAL_EXAMPLES // 2
examples = []
PREFIX = "Find bug in this C code:"

# Buggy half first, then clean half, so the sequence of random.choice() calls (and
# therefore the seeded dataset) is the same as with two separate loops.
# Each entry: (template list, lowercase prefix a response must start with, tag to prepend otherwise).
for templates, required_start, tag in (
    (bug_templates, "bug:", "Bug:"),
    (clean_templates, "no bug", "No bug:"),
):
    for _ in range(half):
        code, expl = random.choice(templates)
        resp = expl if expl.lower().startswith(required_start) else f"{tag} {expl}"
        prompt = f"{INSTRUCTION}\n\n{PREFIX}\n\n{code}"
        examples.append({"prompt": prompt, "response": resp})

# 80/20 split after shuffling.
random.shuffle(examples)
train_size = int(0.8 * len(examples))
train_data, test_data = examples[:train_size], examples[train_size:]

# Persist both splits for inspection / reuse.
os.makedirs("c_bug_dataset", exist_ok=True)
for split_path, split_rows in (
    ("c_bug_dataset/c_bugs_train.json", train_data),
    ("c_bug_dataset/c_bugs_test.json", test_data),
):
    with open(split_path, "w", encoding="utf-8") as f:
        json.dump(split_rows, f, indent=2)

print(f"✅ Dataset: {len(train_data)} train | {len(test_data)} test "
      f"| Test buggy={sum(d['response'].lower().startswith('bug:') for d in test_data)}")

# The examples are drawn from a small fixed set of templates, so the same prompt can
# land in both splits. Report the overlap: scores on leaked prompts measure
# memorisation, not generalisation.
train_prompts = {d["prompt"] for d in train_data}
leaked = sum(d["prompt"] in train_prompts for d in test_data)
if leaked:
    print(f"⚠️ {leaked}/{len(test_data)} test prompts also appear in the training split; "
          "evaluation metrics will be inflated.")

# ---------- STEP 3: Model + LoRA ----------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
base_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

# Optional: gradient checkpointing lowers activation memory at the cost of extra compute.
# Best-effort — training proceeds without it, but say so instead of failing silently.
try:
    base_model.gradient_checkpointing_enable()
except Exception as exc:
    print("⚠️ Gradient checkpointing could not be enabled; continuing without it.\n", repr(exc))

# LoRA adapters on the attention projections named "q", "k", "v" and "o";
# bias="none" leaves bias terms out of the trainable set.
lora_cfg = LoraConfig(
    r=8, lora_alpha=16, lora_dropout=0.1, bias="none",
    target_modules=["q", "k", "v", "o"], task_type="SEQ_2_SEQ_LM"
)
model = get_peft_model(base_model, lora_cfg)
model.print_trainable_parameters()

# ---------- STEP 4: Preprocessing ----------
# Token budgets: prompts are truncated/padded to INPUT_MAX_LEN, targets to OUTPUT_MAX_LEN.
# OUTPUT_MAX_LEN is also used as the generation length limit during evaluation.
INPUT_MAX_LEN = 256
OUTPUT_MAX_LEN = 128

def preprocess_batch(batch: Dict[str, List[str]]) -> Dict[str, Any]:
    """Tokenize a batch of prompt/response pairs for seq2seq training.

    Args:
        batch: Column-wise batch with a "prompt" list and a "response" list.

    Returns:
        The tokenizer output for the prompts (truncated/padded to INPUT_MAX_LEN)
        with an added "labels" entry: the response token ids (truncated/padded to
        OUTPUT_MAX_LEN) where every pad id is replaced by -100.

    NOTE(review): the recorded run decodes "Use i<10." as "Use i10.", which suggests
    the tokenizer has no piece for "<" — confirm, and consider adding the missing
    C punctuation to the vocabulary if faithful code snippets matter.
    """
    model_inputs = tokenizer(
        batch["prompt"], truncation=True, padding="max_length", max_length=INPUT_MAX_LEN
    )
    # `text_target=` replaces the deprecated `as_target_tokenizer()` context manager.
    targets = tokenizer(
        text_target=batch["response"], truncation=True, padding="max_length", max_length=OUTPUT_MAX_LEN
    )
    # Replace padding with -100 in the labels so padded positions are masked out.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [(tid if tid != pad_id else -100) for tid in seq]
        for seq in targets["input_ids"]
    ]
    return model_inputs

# Tokenize both splits the same way and drop the raw text columns afterwards.
_RAW_TEXT_COLUMNS = ["prompt", "response"]
train_ds, test_ds = (
    Dataset.from_list(split_rows).map(
        preprocess_batch, batched=True, remove_columns=_RAW_TEXT_COLUMNS
    )
    for split_rows in (train_data, test_data)
)

# Batch collator bound to the tokenizer and the (LoRA-wrapped) model.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# ---------- STEP 5: Metrics (Local-only; no evaluate.load) ----------
from rouge_score import rouge_scorer
from bert_score import score as bert_score_fn

def _normalize(xs):
    return [re.sub(r"\s+", " ", x).strip() for x in xs]

class LocalRouge:
    """ROUGE-1/2/L F-measure averaged over prediction/reference pairs, computed locally."""

    def compute(self, predictions, references, use_stemmer=True):
        """Score each (prediction, reference) pair and return the mean F-measure per ROUGE type.

        The sums are divided by ``len(predictions)`` (or 1 when it is empty), so an
        empty input yields 0.0 for every metric.
        """
        rouge_types = ["rouge1", "rouge2", "rougeL"]
        scorer = rouge_scorer.RougeScorer(rouge_types, use_stemmer=use_stemmer)
        denominator = max(len(predictions), 1)
        totals = dict.fromkeys(rouge_types, 0.0)
        for pred_text, ref_text in zip(predictions, references):
            # rouge_score expects (target, prediction) argument order.
            pair_scores = scorer.score(ref_text, pred_text)
            for rouge_type in rouge_types:
                totals[rouge_type] += pair_scores[rouge_type].fmeasure
        return {rouge_type: total / denominator for rouge_type, total in totals.items()}

class LocalBERTScore:
    """Thin wrapper around ``bert_score.score`` that returns mean precision/recall/F1."""

    # A lighter backbone is the default to avoid Roberta pooler warnings.
    def __init__(self, model_type="microsoft/deberta-base-mnli"):
        self.model_type = model_type

    def compute(self, predictions, references, lang="en"):
        """Run BERTScore (baseline-rescaled) and return the batch means as plain floats."""
        precision, recall, f1 = bert_score_fn(
            predictions,
            references,
            lang=lang,
            model_type=self.model_type,
            rescale_with_baseline=True,
        )
        return {
            "precision": float(precision.mean()),
            "recall": float(recall.mean()),
            "f1": float(f1.mean()),
        }

# Module-level metric objects shared by compute_metrics() and manual_eval().
rouge = LocalRouge()
bertscore = LocalBERTScore(model_type="microsoft/deberta-base-mnli")

def _to_int_ids(arr, pad_id: int):
    """Map negatives to pad_id; ensure int dtype; return as list[list[int]]."""
    arr = np.asarray(arr[0] if isinstance(arr, tuple) else arr)
    if arr.dtype.kind not in "iu":
        arr = arr.astype(np.int64)
    arr = np.where(arr < 0, pad_id, arr)
    if arr.ndim == 1:
        arr = arr[None, :]
    return [[int(x) for x in row.tolist()] for row in arr]

def compute_metrics(eval_pred):
    """Trainer metric hook: ROUGE, BERTScore F1 and bug/no-bug detection accuracy.

    Args:
        eval_pred: ``(predictions, labels)`` pair of token-id arrays.

    Returns:
        dict with "rouge1", "rouge2", "rougeL", "bertscore_f1" and
        "detection_accuracy".
    """
    raw_preds, raw_labels = eval_pred
    pad_id = tokenizer.pad_token_id

    def _decode(raw_ids):
        # Negative ids (e.g. the -100 label mask) become pad so the tokenizer can decode them.
        texts = tokenizer.batch_decode(_to_int_ids(raw_ids, pad_id=pad_id), skip_special_tokens=True)
        return _normalize(texts)

    decoded_preds = _decode(raw_preds)
    decoded_labels = _decode(raw_labels)

    metrics = dict(rouge.compute(predictions=decoded_preds, references=decoded_labels))
    bert = bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en")
    metrics["bertscore_f1"] = float(bert["f1"])

    # Derived detection accuracy
    def detect_flag(text: str) -> int:
        """1 when the text reads as a bug report, 0 otherwise; "no bug" anywhere wins."""
        lowered = text.lower()
        if "no bug" in lowered:
            return 0
        return int("bug:" in lowered or "fix" in lowered)

    predicted = np.array([detect_flag(pred) for pred in decoded_preds])
    expected = np.array([int(lbl.lower().startswith("bug:")) for lbl in decoded_labels])
    metrics["detection_accuracy"] = float((predicted == expected).mean().item())
    return metrics

# ---------- STEP 6: Training config (Seq2SeqTrainer) ----------
# Evaluate and checkpoint once per epoch; the best checkpoint by ROUGE-L is restored
# at the end of training (load_best_model_at_end + metric_for_best_model).
training_args = Seq2SeqTrainingArguments(
    output_dir="./gen_results",
    eval_strategy="epoch",            # current name of the deprecated `evaluation_strategy`
    save_strategy="epoch",            # must match the eval cadence for load_best_model_at_end
    learning_rate=LR,                 # auto-adjusted by toggle
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=2,    # effective batch size 16
    num_train_epochs=EPOCHS,          # longer for FLAN base
    weight_decay=0.01,
    fp16=torch.cuda.is_available(),   # mixed precision only when a GPU is present
    predict_with_generate=True,       # metrics are computed on generated text
    generation_max_length=OUTPUT_MAX_LEN,
    logging_dir="./gen_logs",
    logging_steps=20,
    label_smoothing_factor=0.05,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    report_to="none",
    load_best_model_at_end=True,
    metric_for_best_model="rougeL",
    greater_is_better=True,
    seed=SEED,
    save_total_limit=2,
)

# Trainer wiring. The test split doubles as the per-epoch evaluation set, and
# training stops early after 2 evaluations without improvement in the
# best-model metric configured in training_args.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)

# ---------- STEP 7: Train & evaluate (safe fallback) ----------
def _norm(xs): return [re.sub(r"\s+", " ", x).strip() for x in xs]

def generate_explanation(code_or_prompt: str,
                         num_beams: int = 6,
                         max_len: int = 128,
                         no_repeat_ngram_size: int = 3,
                         repetition_penalty: float = 1.15,
                         length_penalty: float = 0.9,
                         min_new_tokens: int = 8,
                         force_prefix: bool = False) -> str:
    """Generate a 'Bug: ...' / 'No bug: ...' explanation for a C snippet.

    The input is HTML-unescaped and stripped, then wrapped so the model sees the
    same layout it was trained on, without duplicating parts already present:

    * a full prompt that already starts with INSTRUCTION is used unchanged;
    * text that starts with PREFIX gets only INSTRUCTION prepended;
    * anything else is treated as bare code and gets INSTRUCTION and PREFIX.

    Args:
        code_or_prompt: Bare C code, a PREFIX-led prompt, or a full training-style
            prompt. ``None`` or an empty string is treated as empty code.
        num_beams: Beam width for beam search.
        max_len: Passed to ``generate`` as ``max_length``.
        no_repeat_ngram_size: N-gram size that may not repeat in the output.
        repetition_penalty: Penalty applied to repeated tokens.
        length_penalty: Beam-search length penalty.
        min_new_tokens: Minimum number of tokens to generate.
        force_prefix: Accepted for backward compatibility; currently has no effect.

    Returns:
        The decoded text of the top beam with special tokens removed.
    """
    text = html.unescape(code_or_prompt or "").strip()
    lowered = text.lower()
    if lowered.startswith(INSTRUCTION.lower()):
        # Already a complete prompt (e.g. a dataset example): do not wrap it again.
        prompt = text
    elif lowered.startswith(PREFIX.lower()):
        prompt = f"{INSTRUCTION}\n\n{text}"
    else:
        prompt = f"{INSTRUCTION}\n\n{PREFIX}\n\n{text}"

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=INPUT_MAX_LEN)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    gen_kwargs = dict(
        max_length=max_len,
        num_beams=num_beams,
        no_repeat_ngram_size=no_repeat_ngram_size,
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        min_new_tokens=min_new_tokens,
        early_stopping=True,
    )

    # Generate in eval mode so dropout is off even if the model was left in train
    # mode (e.g. after an interrupted training run); restore the previous mode after.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = model.generate(**inputs, **gen_kwargs)
    finally:
        if was_training:
            model.train()
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def manual_eval():
    """Evaluate on ``test_data`` by generating one example at a time.

    Returns the same metric keys as the Trainer hook, without the "eval_" prefix:
    "rouge1", "rouge2", "rougeL", "bertscore_f1" and "detection_accuracy".
    """
    generated = [
        generate_explanation(example["prompt"], num_beams=6, max_len=OUTPUT_MAX_LEN)
        for example in test_data
    ]
    expected_texts = [example["response"] for example in test_data]
    preds_n = _norm(generated)
    refs_n = _norm(expected_texts)

    metrics = dict(rouge.compute(predictions=preds_n, references=refs_n))
    bert = bertscore.compute(predictions=preds_n, references=refs_n, lang="en")
    metrics["bertscore_f1"] = float(bert["f1"])

    def detect_flag(t: str) -> int:
        """1 when the text reads as a bug report, 0 otherwise; "no bug" anywhere wins."""
        lowered = t.lower()
        if "no bug" in lowered:
            return 0
        return int("bug:" in lowered or "fix" in lowered)

    predicted = np.array([detect_flag(pred) for pred in preds_n])
    actual = np.array([int(ref.lower().startswith("bug:")) for ref in refs_n])
    metrics["detection_accuracy"] = float((predicted == actual).mean().item())
    return metrics

# Train, then evaluate through the Trainer. Any exception raised by either call
# is reported and followed by the slower per-example manual evaluation.
try:
    trainer.train()
    eval_res = trainer.evaluate()
    print("📊 Evaluation:", json.dumps(eval_res, indent=2))
except Exception as e:
    # NOTE(review): this also catches failures inside trainer.train() itself, in which
    # case the manual evaluation below scores a partially trained model even though
    # the message mentions only evaluation — check repr(e) before trusting the numbers.
    print("⚠️ Training-time evaluation failed; doing manual post-training eval.\n", repr(e))
    eval_res = manual_eval()
    print("📊 Manual Evaluation:", json.dumps(eval_res, indent=2))

# ---------- STEP 8: Spot-check 10 samples ----------
# Print the model's answer next to the reference for the first ten test examples.
print("\n🔎 Spot-check predictions vs references")
for sample_idx, sample in enumerate(test_data[:10]):
    reference = sample["response"]
    pred = generate_explanation(sample["prompt"], max_len=OUTPUT_MAX_LEN, num_beams=6)
    print(f"\n[{sample_idx}] PRED: {pred}\n    REF:  {reference}")

# ---------- STEP 9: Sanity sample (code-only to avoid double prefix) ----------
sample_prompt = test_data[0]["prompt"]
# The prompt is multi-line (instruction, blank line, prefix, blank line, code), so the
# pattern needs the DOTALL flag ("s") for ".*" to cross the newlines; without it the
# substitution never matches and the "code only" sample is still the full prompt.
sample_code_only = re.sub(r"(?is)^.*find bug in this c code:\s*", "", sample_prompt).strip()
print("\n🔹 Sample test code:\n", sample_code_only)
print("\n💡 Model output:\n", generate_explanation(sample_code_only))

# ---------- STEP 10: Save artifacts (LoRA adapters + optional merged model) ----------
# Save the adapter weights and the tokenizer first, so they are on disk before the
# merge below is attempted.
ADAPTER_DIR = "./gen_lora_adapter"
os.makedirs(ADAPTER_DIR, exist_ok=True)
model.save_pretrained(ADAPTER_DIR)
tokenizer.save_pretrained(ADAPTER_DIR)
print(f"💾 Saved LoRA adapters to: {ADAPTER_DIR}")

# Optional second artifact: the result of merge_and_unload(), saved with the tokenizer.
# A failure here is reported but not fatal; the adapter directory above stays usable.
# NOTE(review): `model` is still used for generation after this call — confirm that
# merge_and_unload() leaves it in a usable state with the installed peft version.
MERGED_DIR = "./gen_merged_model"
try:
    merged = model.merge_and_unload()
    merged.save_pretrained(MERGED_DIR)
    tokenizer.save_pretrained(MERGED_DIR)
    print(f"💾 Saved merged model to: {MERGED_DIR}")
except Exception as e:
    print("⚠️ Could not merge LoRA into base model; continue with adapters.\n", e)

# ---------- STEP 11: Gradio UI ----------
import gradio as gr
def ui_predict(text):
    """Gradio callback: return the model's explanation for the text typed into the UI."""
    explanation = generate_explanation(text, max_len=OUTPUT_MAX_LEN, num_beams=6)
    return explanation

# Single-textbox UI: pasted code (or a full prompt) goes to ui_predict and the
# returned explanation is shown as plain text.
demo = gr.Interface(
    fn=ui_predict,
    inputs=gr.Textbox(lines=12, placeholder="Paste C code OR a full prompt like 'Find bug in this C code: ...'"),
    outputs="text",
    title="C Bug Detector + Explainer (T5/FLAN + LoRA)",
    description="Detects if there is a bug and explains the fix. Instruction-tuned; accepts code or full prompt.",
)
print("\n🌐 Launching Gradio…")
# share=True requests a public share link (the recorded run printed a *.gradio.live URL),
# so anyone with the link can query the model while the cell is running.
demo.launch(share=True)


✅ Device: cuda
✅ Dataset: 4800 train | 1200 test | Test buggy=585
trainable params: 589,824 || all params: 61,096,448 || trainable%: 0.9654


Map:   0%|          | 0/4800 [00:00<?, ? examples/s]



Map:   0%|          | 0/1200 [00:00<?, ? examples/s]

Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Bertscore F1,Detection Accuracy
1,1.5632,1.199999,0.764621,0.691207,0.760538,0.749344,0.8275
2,1.0295,0.99047,1.0,1.0,1.0,1.0,1.0
3,0.9557,0.943066,1.0,1.0,1.0,1.0,1.0
4,0.937,0.930508,1.0,1.0,1.0,1.0,1.0


📊 Evaluation: {
  "eval_loss": 0.9904703497886658,
  "eval_rouge1": 1.0,
  "eval_rouge2": 1.0,
  "eval_rougeL": 1.0,
  "eval_bertscore_f1": 1.0,
  "eval_detection_accuracy": 1.0,
  "eval_runtime": 77.0921,
  "eval_samples_per_second": 15.566,
  "eval_steps_per_second": 1.946,
  "epoch": 4.0
}

🔎 Spot-check predictions vs references

[0] PRED: No bug: Simple addition function is correct.
    REF:  No bug: Simple addition function is correct.

[1] PRED: No bug: String literal used safely with puts.
    REF:  No bug: String literal used safely with puts.

[2] PRED: Bug: Infinite loop due to wrong update direction. Fix: Use i++ to reach termination.
    REF:  Bug: Infinite loop due to wrong update direction. Fix: Use i++ to reach termination.

[3] PRED: Bug: gets() is unsafe and can overflow. Fix: Use fgets with size limit.
    REF:  Bug: gets() is unsafe and can overflow. Fix: Use fgets with size limit.

[4] PRED: Bug: Off-by-one writing arr[10] if arr has size 10. Fix: Use i10.
    REF: 

--------


Running on public URL: https://1a17cdd1f06fde7dbc.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


