In [1]:
!pip3 -q install -U "transformers>=4.44.0" accelerate peft bitsandbytes datasets sentencepiece evaluate
!pip3 -q uninstall -y pyarrow
!pip3 install -U bitsandbytes
!pip3 -q install "pyarrow>=21,<22"
!pip3 -q install fastapi uvicorn nest-asyncio pyngrok



In [2]:
from google.colab import drive, files
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# --- Setup / Imports (run once)
import json
import re
import pathlib
import logging

import torch
from datasets import load_dataset
import transformers # Import transformers module directly
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    BitsAndBytesConfig,
    # transformers, # Removed from here
 )
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    PeftModel,
 )

import asyncio
import os
from uvicorn.config import Config
from uvicorn.server import Server
import nest_asyncio
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from pyngrok import ngrok
from google.colab import userdata



logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

print('Transformers:', transformers.__version__)
print('Torch:', torch.__version__)

Transformers: 4.57.3
Torch: 2.9.0+cu126


In [None]:
src = pathlib.Path("/content/drive/MyDrive/local/azure_policy_dataset_clean.json")
out_openai = pathlib.Path("/content/drive/MyDrive/local/azure_policies_ft_chat.jsonl")   # for OpenAI chat FT
out_hf = pathlib.Path("/content/drive/MyDrive/local/azure_policies_ft_supervised.jsonl") # for HF SFT

param_hint_defaults = {
    # sensible defaults when discovered
    "effect":
     {"type":"String",
      "allowedValues":["Audit","Deny","Disabled"],
      "defaultValue":"Audit"},
}

def find_params_in_template(text):
    return sorted(set(re.findall(r"parameters\\('([^']+)'\\)", text or "")))

def ensure_parameters(properties):
    # collect referenced params
    blob = json.dumps(properties, separators=(",",":"))
    names = find_params_in_template(blob)
    parameters = properties.get("parameters", {})
    if not isinstance(parameters, dict): parameters = {}
    for name in names:
        if name not in parameters:
            parameters[name] = param_hint_defaults.get(name, {"type":"String"})
    properties["parameters"] = parameters

def to_policy_envelope(instruction, target_obj):
    try:
        props = target_obj.get("properties", {})
        # description
        if "description" not in props:
            props["description"] = instruction
        # parameters
        ensure_parameters(props)
        # keep mode stable if absent
        if "mode" not in props:
            props["mode"] = "All"
        return {"properties": props}
    except Exception as e:
        logger.error(f"Error processing instruction '{instruction}': {e}")
        return None

try:
    openai_rows = []
    hf_rows = []

    with src.open() as f:
        for line_num, line in enumerate(f, 1):
            if not line.strip(): continue
            try:
                row = json.loads(line)
                instr = row["instruction"].strip()
                tgt_raw = row["target"]
                tgt_obj = json.loads(tgt_raw)
                policy_obj = to_policy_envelope(instr, tgt_obj)
                if policy_obj is None:
                    continue

                # --- OpenAI chat FT format ---
                openai_rows.append({
                    "messages":[
                        {"role":"system","content":"You are an assistant that generates valid Azure Policy JSON. Return ONLY JSON."},
                        {"role":"user","content": f"Instruction: {instr}\nReturn a full Azure Policy object with properties.displayName, properties.description, properties.parameters (define any referenced parameters), and properties.policyRule (with if/then)."},
                        {"role":"assistant","content": json.dumps(policy_obj, ensure_ascii=False)}
                    ]
                })

                # --- HF supervised format (prompt/completion) ---
                prompt = (
                    "You are an assistant that generates Azure Policy JSON.\n"
                    f"Instruction: {instr}\n"
                    "Return ONLY a single JSON object with `properties` containing displayName, description, parameters, and policyRule (if/then)."
                )
                hf_rows.append({"prompt": prompt, "completion": json.dumps(policy_obj, ensure_ascii=False)})

            except json.JSONDecodeError as e:
                logger.warning(f"Skipping invalid JSON at line {line_num}: {e}")
            except KeyError as e:
                logger.warning(f"Missing key in row at line {line_num}: {e}")

    # write files
    with out_openai.open("w") as w:
        for r in openai_rows: w.write(json.dumps(r, ensure_ascii=False) + "\n")
    with out_hf.open("w") as w:
        for r in hf_rows: w.write(json.dumps(r, ensure_ascii=False) + "\n")

    logger.info(f"Wrote {len(openai_rows)} rows to {out_openai}")
    logger.info(f"Wrote {len(hf_rows)} rows to {out_hf}")

except FileNotFoundError as e:
    logger.error(f"File not found: {e}")
except Exception as e:
    logger.error(f"Unexpected error in data processing: {e}")

## Hugging Face

In [None]:
base_model_name = "Qwen/Qwen2.5-7B-Instruct"  # swap if needed

try:
    tok = AutoTokenizer.from_pretrained(base_model_name, use_fast=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    MAX_LEN = 2048
    ds = load_dataset("json", data_files="/content/drive/MyDrive/local/azure_policies_ft_supervised.jsonl", split="train").train_test_split(test_size=0.05, seed=42)

    def tok_fn(r):
        # Train on completion only (mask prompt tokens to -100).
        # This avoids the model learning to "echo" the prompt.
        sep = "\n\n"
        prompt_text = f"{r['prompt']}{sep}"
        full_text = f"{prompt_text}{r['completion']}"

        enc = tok(full_text, truncation=True, max_length=MAX_LEN)
        prompt_enc = tok(prompt_text, truncation=True, max_length=MAX_LEN)
        prompt_len = len(prompt_enc.get("input_ids", []))

        labels = enc["input_ids"].copy()
        prompt_len = min(prompt_len, len(labels))
        labels[:prompt_len] = [-100] * prompt_len
        enc["labels"] = labels
        return enc

    ds_tok = ds.map(tok_fn, remove_columns=ds["train"].column_names)

    # NEW: BitsAndBytesConfig instead of load_in_4bit
    bnb_cfg = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True,
                                 bnb_4bit_compute_dtype=torch.bfloat16)

    model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        quantization_config=bnb_cfg,   # <-- new style
        device_map="auto"
    )

    model.config.use_cache = False

    # Prepare for k-bit training (adds input grads, etc.)
    model = prepare_model_for_kbit_training(model)

    peft_cfg = LoraConfig(
        r=16, lora_alpha=32, lora_dropout=0.05, bias="none",
        target_modules=["q_proj","k_proj","v_proj","o_proj"], task_type="CAUSAL_LM"
    )
    model = get_peft_model(model, peft_cfg)

    model.gradient_checkpointing_enable()

    # Improved: eval_strategy, early stopping, and logging
    args = TrainingArguments(
        output_dir="/content/drive/MyDrive/local/azure-policy-qlora",
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=16,
        learning_rate=2e-4,
        max_steps=-1,                       # or set a small max_steps for smoke test
        logging_steps=50,
        num_train_epochs=2,
        save_steps=500,
        eval_strategy="steps",         # Enable evaluation
        eval_steps=500,                # Evaluate every 500 steps
        save_strategy="steps",
        load_best_model_at_end=True,   # Load best model
        metric_for_best_model="eval_loss",  # Use eval loss for best model
        greater_is_better=False,       # Lower loss is better
        lr_scheduler_type="cosine",
        fp16=True,                          # T4 = FP16 (not bf16)
        optim="paged_adamw_8bit",           # memory-friendly optimizer
        report_to="none"  # Set to "wandb" if using Weights & Biases
    )

    trainer = Trainer(model=model, args=args, train_dataset=ds_tok["train"], eval_dataset=ds_tok["test"])
    trainer.train()

    model.save_pretrained("/content/drive/MyDrive/local/azure-policy-qlora-adapter")
    tok.save_pretrained("/content/drive/MyDrive/local/azure-policy-qlora-adapter")
    logger.info("Saved to /content/drive/MyDrive/local/azure-policy-qlora-adapter")

except Exception as e:
    logger.error(f"Error during training: {e}")
    raise

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Map:   0%|          | 0/4663 [00:00<?, ? examples/s]

Map:   0%|          | 0/246 [00:00<?, ? examples/s]

config.json:   0%|          | 0.00/663 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/3.56G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/3.95G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/243 [00:00<?, ?B/s]

Step,Training Loss,Validation Loss
500,0.13,0.180998


In [None]:
def evaluate_model(trainer, ds_tok):
    import evaluate

    try:
        results = trainer.evaluate(eval_dataset=ds_tok["test"])
        logger.info(f"Evaluation results: {results}")

        bleu = evaluate.load("bleu")
        predictions = []
        references = []

        for example in ds_tok["test"]:
            input_ids = example.get("input_ids")
            if not input_ids:
                continue
            attention_mask = example.get("attention_mask")
            if not attention_mask:
                attention_mask = [1] * len(input_ids)

            input_ids_t = torch.tensor([input_ids], device=model.device)
            attn_t = torch.tensor([attention_mask], device=model.device)

            with torch.inference_mode():
                out = model.generate(
                    input_ids=input_ids_t,
                    attention_mask=attn_t,
                    max_new_tokens=256,
                    do_sample=False,
                    temperature=0.0,
                    pad_token_id=tok.eos_token_id,
                )

            pred = tok.decode(out[0], skip_special_tokens=True)
            predictions.append(pred)

            # Labels may contain -100 where prompt was masked. Decode only the target tokens.
            labels = example.get("labels") or []
            ref_ids = [t for t in labels if isinstance(t, int) and t != -100]
            ref = tok.decode(ref_ids, skip_special_tokens=True)
            references.append([ref])

        bleu_score = bleu.compute(predictions=predictions, references=references)
        logger.info(f"BLEU score: {bleu_score}")
    except Exception as e:
        logger.error(f"Evaluation error: {e}")

# Call after training
evaluate_model(trainer, ds_tok)

Downloading builder script: 0.00B [00:00, ?B/s]

Downloading extra modules:   0%|          | 0.00/1.55k [00:00<?, ?B/s]

Downloading extra modules: 0.00B [00:00, ?B/s]

ERROR:__main__:Evaluation error: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


In [None]:
def extract_last_json(text: str) -> dict:
    """Extract the last complete JSON object from model output.

    This parser is lightweight but JSON-aware enough to ignore braces inside quoted strings.
    """
    if not isinstance(text, str) or not text:
        raise ValueError("Empty model output")

    starts = [i for i, ch in enumerate(text) if ch == "{"]
    if not starts:
        raise ValueError("No '{' found in model output.")

    last_good = None
    for s in starts:
        depth = 0
        in_string = False
        escape = False
        for i in range(s, len(text)):
            ch = text[i]
            if in_string:
                if escape:
                    escape = False
                    continue
                if ch == "\\":
                    escape = True
                    continue
                if ch == '"':
                    in_string = False
                continue

            # not in string
            if ch == '"':
                in_string = True
                continue
            if ch == "{":
                depth += 1
                continue
            if ch == "}":
                depth -= 1
                if depth == 0:
                    candidate = text[s : i + 1]
                    try:
                        obj = json.loads(candidate)
                        last_good = obj
                    except Exception:
                        pass
                    break
        # if depth never returned to 0, ignore this start index
    if last_good is None:
        raise ValueError("Could not parse any complete JSON object.")
    return last_good

def _is_effectively_empty_if(if_block) -> bool:
    if not isinstance(if_block, dict):
        return True
    if if_block == {}:
        return True
    if "allOf" in if_block and isinstance(if_block["allOf"], list) and len(if_block["allOf"]) == 0:
        return True
    if "anyOf" in if_block and isinstance(if_block["anyOf"], list) and len(if_block["anyOf"]) == 0:
        return True
    return False

def _suggest_if_from_instruction(instruction: str | None):
    s = (instruction or "").strip().lower()
    if not s:
        return None

    # Storage Accounts: public network access
    if "storage" in s and ("public network" in s or "publicnetworkaccess" in s):
        return {
            "allOf": [
                {"field": "type", "equals": "Microsoft.Storage/storageAccounts"},
                {"field": "Microsoft.Storage/storageAccounts/publicNetworkAccess", "notEquals": "Disabled"},
            ],
        }

    # Storage Accounts: secure transfer required / HTTPS only
    if "storage" in s and ("https" in s or "https-only" in s or "https only" in s or "secure transfer" in s):
        return {
            "allOf": [
                {"field": "type", "equals": "Microsoft.Storage/storageAccounts"},
                {"field": "Microsoft.Storage/storageAccounts/supportsHttpsTrafficOnly", "equals": False},
            ],
        }

    # Storage Accounts: disable public blob access
    if "storage" in s and ("blob" in s) and ("public" in s and "access" in s):
        return {
            "allOf": [
                {"field": "type", "equals": "Microsoft.Storage/storageAccounts"},
                {"field": "Microsoft.Storage/storageAccounts/allowBlobPublicAccess", "equals": True},
            ],
        }

    # Storage Accounts: Minimum TLS 1.2
    if "storage" in s and "tls" in s and ("1.2" in s or "tls1_2" in s or "tls 1.2" in s):
        return {
            "allOf": [
                {"field": "type", "equals": "Microsoft.Storage/storageAccounts"},
                {
                    "anyOf": [
                        {"field": "Microsoft.Storage/storageAccounts/minimumTlsVersion", "exists": False},
                        {"field": "Microsoft.Storage/storageAccounts/minimumTlsVersion", "notEquals": "TLS1_2"},
                    ]
                },
            ],
        }

    # Generic: required tag 'owner'
    if ("require" in s or "must" in s) and "tag" in s and ("owner" in s):
        return {
            "allOf": [
                {"field": "tags['owner']", "exists": False}
            ],
        }

    # App Configuration: CMK encryption (Key Vault key identifier must exist and be non-empty)
    if ("app configuration" in s or "appconfiguration" in s) and ("customer-managed key" in s or "cmk" in s or "key vault" in s) and "encryption" in s:
        return {
            "allOf": [
                {"field": "type", "equals": "Microsoft.AppConfiguration/configurationStores"},
                {
                    "anyOf": [
                        {"field": "Microsoft.AppConfiguration/configurationStores/encryption.keyVaultProperties.keyIdentifier", "exists": False},
                        {"field": "Microsoft.AppConfiguration/configurationStores/encryption.keyVaultProperties.keyIdentifier", "equals": ""},
                    ]
                },
            ],
        }

    return None

def check_and_fix(policy_json_or_dict, *, fallback_instruction: str | None = None, meta: dict | None = None):
    """Normalize model output into an Azure Policy object with `properties`.

    WIP features:
    - Migrates root-level fields into `properties`
    - Normalizes `then.effect` to `[parameters('effect')]`
    - If `policyRule.if` is missing or empty, attempts a small heuristic fill from the instruction
      (and records this in `meta` if provided).
    - Drops unexpected root-level keys (e.g. stray allOf/anyOf from the model).
    """
    if isinstance(policy_json_or_dict, str):
        policy = json.loads(policy_json_or_dict)
    else:
        policy = policy_json_or_dict

    if not isinstance(policy, dict):
        raise ValueError("Policy must be a JSON object/dict")

    meta = meta if isinstance(meta, dict) else None

    # --- If the model returned fields at the root, migrate them into `properties`
    root_keys = {"displayName", "description", "parameters", "policyRule", "mode"}
    if "properties" not in policy and any(k in policy for k in root_keys):
        policy["properties"] = {}
    if "properties" in policy and isinstance(policy["properties"], dict):
        props = policy["properties"]
    else:
        policy["properties"] = {}
        props = policy["properties"]

    for k in ["displayName", "description", "parameters", "policyRule", "mode"]:
        if k in policy and k not in props:
            props[k] = policy[k]
        if k in policy:
            del policy[k]

    if "effect" in policy:
        del policy["effect"]

    # --- Drop unexpected root keys (common model error: {'allOf': [], 'properties': {...}})
    allowed_root = {"properties", "name", "type", "id", "apiVersion"}
    for k in list(policy.keys()):
        if k not in allowed_root:
            del policy[k]

    # --- Validate required keys and fill defaults
    missing = []
    for k in ["displayName", "description", "policyRule"]:
        if k not in props:
            missing.append(k)

    pr = props.get("policyRule")
    if not isinstance(pr, dict):
        pr = {}
        props["policyRule"] = pr
        if "policyRule" not in missing:
            missing.append("policyRule")

    if "displayName" not in props:
        props["displayName"] = fallback_instruction or "Generated Azure Policy"
    if "description" not in props:
        props["description"] = fallback_instruction or "Auto-generated Azure Policy"

    # If 'if' is missing or empty-shell, try to infer a default from the instruction
    current_if = pr.get("if")
    empty_before = ("if" not in pr) or _is_effectively_empty_if(current_if)
    if empty_before:
        suggested_if = _suggest_if_from_instruction(fallback_instruction or props.get("displayName"))
        if suggested_if is not None:
            pr["if"] = suggested_if
            if meta is not None:
                meta["fallback_used"] = True
                meta["fallback_reason"] = "heuristic_if_from_instruction"
        else:
            pr.setdefault("if", {"allOf": []})
            if meta is not None:
                meta.setdefault("fallback_used", False)
    else:
        if meta is not None:
            meta.setdefault("fallback_used", False)

    if meta is not None:
        meta["empty_if_before_fix"] = bool(empty_before)

    if "then" not in pr:
        missing.append("policyRule.then")

    if missing:
        logger.warning("Policy missing fields (%s). Auto-filling defaults.", ", ".join(missing))

    if "then" not in pr:
        pr["then"] = {"effect": "[parameters('effect')]"}

    # --- Normalize then.effect
    then = pr.get("then")
    if isinstance(then, dict):
        effect_value = then.get("effect")
        if isinstance(effect_value, str):
            normalized = effect_value.strip()
            normalized_title = normalized[:1].upper() + normalized[1:].lower() if normalized else normalized
            if normalized_title in {"Audit", "Deny", "Disabled"}:
                then["effect"] = "[parameters('effect')]"
                default_effect = normalized_title
            else:
                default_effect = "Audit"
        else:
            default_effect = "Audit"
            then["effect"] = "[parameters('effect')]"
    else:
        default_effect = "Audit"
        pr["then"] = {"effect": "[parameters('effect')]"}

    # --- Ensure parameters referenced by template are defined
    text = json.dumps(policy)
    needed = set(re.findall(r"parameters\('([^']+)'\)", text))
    params = props.setdefault("parameters", {})
    if not isinstance(params, dict):
        props["parameters"] = {}
        params = props["parameters"]

    if "effect" in needed and "effect" not in params:
        params["effect"] = {
            "type": "String",
            "allowedValues": ["Audit", "Deny", "Disabled"],
            "defaultValue": default_effect,
        }
    elif "effect" in needed and isinstance(params.get("effect"), dict):
        params["effect"].setdefault("defaultValue", default_effect)

    props.setdefault("mode", "All")
    return policy

def normalize_fields(policy: dict):
    alias_map = {
        "Microsoft.AppConfiguration/configurationStores/keyVaultKeyUri":
        "Microsoft.AppConfiguration/configurationStores/encryption.keyVaultProperties.keyIdentifier",
    }
    s = json.dumps(policy)
    for a, b in alias_map.items():
        s = s.replace(a, b)
    return json.loads(s)

## Load base model + adapter for inference

In [None]:
base_model_name = "Qwen/Qwen2.5-7B-Instruct"
adapter_dir = "/content/drive/MyDrive/local/azure-policy-qlora-adapter"

tok = AutoTokenizer.from_pretrained(base_model_name, use_fast=True)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
 )

base = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=bnb_cfg,
    device_map="auto",
 )
base.config.use_cache = False

model = PeftModel.from_pretrained(base, adapter_dir)
model.eval()

print('Loaded adapter:', adapter_dir)

In [None]:
nest_asyncio.apply()

app = FastAPI()

class GenerateRequest(BaseModel):
    instruction: str

@app.get('/health')
def health():
    model_loaded = 'model' in globals()
    tok_loaded = 'tok' in globals()
    device = None
    try:
        device = str(getattr(model, 'device', None)) if model_loaded else None
    except Exception:
        device = None
    return {
        'ok': True,
        'model_loaded': model_loaded,
        'tokenizer_loaded': tok_loaded,
        'cuda_available': torch.cuda.is_available(),
        'device': device,
    }

def _build_prompt(instruction: str, *, strict: bool = True, feedback: str | None = None) -> str:
    base = [
        "You are an assistant that generates valid Azure Policy JSON.",
        "Return ONLY one JSON object. No markdown. No explanations.",
        "",
        "Required schema:",
        "{",
        '  "properties": {',
        '    "displayName": "string",',
        '    "description": "string",',
        '    "parameters": {',
        '      "effect": {"type":"String","allowedValues":["Audit","Deny","Disabled"],"defaultValue":"Audit"}',
        '    },',
        '    "policyRule": {"if": { ... }, "then": {"effect": "[parameters(\'effect\')]"}},',
        '    "mode": "All"',
        "  }",
        "}",
        "",
    ]
    if strict:
        base.extend([
            "Constraints:",
            "- properties.policyRule.if MUST contain at least one real condition (do NOT return {\"allOf\": []} or an empty object).",
            "- Include the target resource type when relevant (e.g., Storage Accounts -> Microsoft.Storage/storageAccounts).",
            "",])
    if feedback:
        base.extend([f"Feedback: {feedback}", ""])
    base.append(f"Instruction: {instruction}")
    return "\n".join(base)

def _generate_raw(prompt: str) -> str:
    inputs = tok(prompt, return_tensors='pt').to(model.device)
    with torch.inference_mode():
        out = model.generate(
            **inputs,
            max_new_tokens=900,
            do_sample=False,
            temperature=0.0,
            pad_token_id=tok.eos_token_id,
        )
    return tok.decode(out[0], skip_special_tokens=True)

@app.post('/generate')
def generate(req: GenerateRequest):
    instruction = (req.instruction or '').strip()
    if not instruction:
        return {"error": "instruction is empty"}

    if 'model' not in globals() or 'tok' not in globals():
        return {
            'error': 'model/tokenizer not loaded. Run the model load cell before starting the API.',
            'hint': 'Run the notebook cells in order until the adapter is loaded into `model` and `tok` exists.',
        }

    # First attempt
    prompt1 = _build_prompt(instruction, strict=True)
    raw1 = _generate_raw(prompt1)
    meta1 = {}
    try:
        raw_policy1 = extract_last_json(raw1)
        fixed1 = check_and_fix(raw_policy1, fallback_instruction=instruction, meta=meta1)
        fixed1 = normalize_fields(fixed1)
    except Exception as e:
        return {"error": f"Failed to parse/validate model output: {e}", "raw_output": raw1}

    # Retry once if still empty-shell after fix
    try:
        if_block = fixed1.get('properties', {}).get('policyRule', {}).get('if')
    except Exception:
        if_block = None

    if _is_effectively_empty_if(if_block):
        prompt2 = _build_prompt(
            instruction,
            strict=True,
            feedback="Your previous output had an empty properties.policyRule.if. Regenerate with at least 1 concrete condition.",
        )
        raw2 = _generate_raw(prompt2)
        meta2 = {}
        try:
            raw_policy2 = extract_last_json(raw2)
            fixed2 = check_and_fix(raw_policy2, fallback_instruction=instruction, meta=meta2)
            fixed2 = normalize_fields(fixed2)
        except Exception as e:
            return {
                'raw_output': raw1,
                'raw_policy': raw_policy1,
                'fixed_policy': fixed1,
                'policy': fixed1,
                'meta': meta1,
                'retry': True,
                'retry_error': str(e),
                'retry_raw_output': raw2,
            }

        return {
            'raw_output': raw2,
            'raw_policy': raw_policy2,
            'fixed_policy': fixed2,
            'policy': fixed2,
            'meta': meta2,
            'retry': True,
        }

    return {
        'raw_output': raw1,
        'raw_policy': raw_policy1,
        'fixed_policy': fixed1,
        'policy': fixed1,
        'meta': meta1,
        'retry': False,
    }

# --- Start ngrok tunnel
token = userdata.get('NGROK_AUTH_TOKEN')
if not token:
    raise RuntimeError('Missing NGROK_AUTH_TOKEN. Add it in Colab Secrets and restart the runtime.')
ngrok.set_auth_token(token)
public_url = ngrok.connect(8000).public_url
print('\nPUBLIC API URL (paste into local Gradio app):')
print(public_url)
print('\nHealth check:')
print(f"curl {public_url}/health")
print('\nTest with:')
json_data_for_curl = {'instruction':'Disallow public network access on storage accounts'}
json_payload_for_curl = json.dumps(json_data_for_curl)
print(f"curl -X POST {public_url}/generate -H 'Content-Type: application/json' -d '{json_payload_for_curl}'")

# Run the API in the background using the existing event loop
config = Config(app, host='0.0.0.0', port=8000, log_level='info')
server = Server(config)

async def run_server_in_background():
    await server.serve()

loop = asyncio.get_event_loop()
if not loop.is_running():
    loop.run_until_complete(run_server_in_background())
else:
    loop.create_task(run_server_in_background())

print('Uvicorn server started in the background.')

### Output


PUBLIC API URL (paste into local Gradio app):
https:/xxxxxx.ngrok-free.app

Health check:
curl https://xxxxxx.ngrok-free.app/health

Test with:
curl -X POST https://xxxxxx.ngrok-free.app/generate -H 'Content-Type: application/json' -d '{"instruction": "Disallow public network access on storage accounts"}'
Uvicorn server started in the background.