In [None]:
!pip install -U bitsandbytes



In [None]:
import os
import re
import warnings
from dataclasses import dataclass
from typing import List, Dict, Optional

import torch
from torch.utils.data import Dataset

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

from peft import LoraConfig, get_peft_model, PeftModel, PeftConfig

In [None]:
from huggingface_hub import login

# Authenticate with the Hugging Face Hub so gated model weights can be downloaded.
# Prefer a token from the HF_TOKEN environment variable so "Restart & Run All"
# can run unattended; otherwise fall back to the interactive login prompt.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)
else:
    login()

In [None]:
ALLOWED_KEYWORDS = [
    # Core networking
    "tcp", "udp", "ip", "bgp", "ospf", "isis", "mpls", "vxlan", "evpn",
    "ethernet", "mac", "arp", "dhcp", "dns", "nat", "acl", "subnet",
    "routing", "switching", "firewall", "load balancer", "sdn", "segment routing",
    # Interconnection / peering / datacenter
    "datacenter", "leaf spine", "underlay", "overlay", "spine", "leaf",
    "rdma", "roce", "fabric", "top of rack", "tor", "edge", "colo",
    "interconnect", "peering", "ixp", "lag", "link aggregation", "qos",
    "throughput", "latency", "packet", "mtu", "jumbo frame", "flow control",
    "netflow", "sflow", "telemetry", "gre", "ipsec", "tls", "ssl", "anycast",
    "multicast", "unicast", "broadcast", "fhrp", "hsrp", "vrrp", "glbp",
    "l2", "l3", "layer 2", "layer 3", "sd-wan", "wan", "lan", "vxlan evpn",
]

# Fixed reply returned (and used as the training target) for out-of-scope requests.
REFUSAL_MESSAGE = (
    "I am restricted to Equinix Related stuff dude. "
    "Your request appears out of scope. Please rephrase within that domain."
)

def is_in_scope(text: str, min_hits: int = 1) -> bool:
    """Return True if ``text`` mentions at least ``min_hits`` distinct allowed keywords.

    Matching is case-insensitive and on whole words. An optional trailing "s" is
    accepted for plurals such as "packets" or "datacenters". Short keywords therefore
    do not fire inside unrelated words, e.g. "ip" in "recipe" or "tor" in "history".

    Args:
        text: Prompt or generated text to check. ``None`` is treated as empty.
        min_hits: Number of distinct keywords required (default 1).

    Returns:
        True when enough distinct keywords are found, otherwise False.
    """
    if min_hits <= 0:
        return True
    lowered = (text or "").lower()
    hits = 0
    for keyword in ALLOWED_KEYWORDS:
        # \b anchors on both sides; re.escape handles keywords like "sd-wan".
        if re.search(rf"\b{re.escape(keyword)}s?\b", lowered):
            hits += 1
            if hits >= min_hits:
                return True
    return False

In [None]:
@dataclass
class Sample:
    # One supervised pair: the user instruction and the desired model response.
    instruction: str
    response: str

class GuardDataset(Dataset):
    """Fixed-length causal-LM training examples built from ``Sample`` pairs.

    Each item is a dict of 1-D ``torch.long`` tensors of length ``max_len``:
    ``input_ids``, ``attention_mask`` and ``labels``. Labels are ``-100`` (ignored by
    the loss) on the prompt tokens and on padding, so only the response is learned.
    An EOS token is appended when the example fits, so the model also learns to stop.
    Padding is added on the right.
    """

    def __init__(self, samples: List[Sample], tokenizer: "AutoTokenizer", max_len: int = 1024):
        self.samples = samples
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        s = self.samples[idx]
        # Simple instruction formatting (adjust to base model's chat template if needed).
        # Must match the template used at inference time.
        prompt = f"Instruction: {s.instruction}\nResponse:"
        full = prompt + " " + s.response

        tok = self.tokenizer
        prompt_ids = list(tok(prompt)["input_ids"])
        ids = list(tok(full)["input_ids"])

        eos_id = tok.eos_token_id
        # Only add EOS when the whole example fits; otherwise a truncated response
        # would teach the model to stop early. Skip it if the tokenizer already added one.
        if eos_id is not None and len(ids) < self.max_len and (not ids or ids[-1] != eos_id):
            ids.append(eos_id)
        ids = ids[: self.max_len]

        # Prompt length = shared prefix of the prompt-only and full tokenizations.
        # This stays correct even if a token merges across the prompt/response boundary.
        n_prompt = 0
        for prompt_tok, full_tok in zip(prompt_ids, ids):
            if prompt_tok != full_tok:
                break
            n_prompt += 1

        n_pad = self.max_len - len(ids)
        pad_id = tok.pad_token_id or 0  # padded positions are masked anyway
        input_ids = torch.tensor(ids + [pad_id] * n_pad, dtype=torch.long)
        attention_mask = torch.tensor([1] * len(ids) + [0] * n_pad, dtype=torch.long)
        labels = torch.tensor(
            [-100] * n_prompt + ids[n_prompt:] + [-100] * n_pad,
            dtype=torch.long,
        )
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }


In [None]:
# ------------------------
# Synthetic Data Helpers (placeholder — replace with real curated data)
# ------------------------

import random

# In-scope (instruction, reference answer) pairs.
POSITIVE_EXAMPLES = [
    ("Explain the difference between OSPF and BGP.",
     "OSPF is a link-state IGP optimizing intra-domain convergence; BGP is a path-vector EGP handling inter-domain policy-based routing."),
    ("What is VXLAN EVPN used for in datacenters?",
     "VXLAN with EVPN control plane provides scalable L2/L3 segmentation over an IP underlay using BGP for MAC+IP advertisement."),
    ("How does a leaf-spine architecture improve scalability?",
     "It provides predictable latency and equal-cost multi-pathing (ECMP) between any two endpoints by using a non-blocking fabric."),
    ("Describe typical datacenter fabric telemetry approaches.",
     "Operators combine streaming telemetry (gNMI), flow data (NetFlow/sFlow/IPFIX), and in-band network telemetry for real-time visibility."),
]

# Out-of-scope prompts; each is paired with REFUSAL_MESSAGE for training.
NEGATIVE_PROMPTS = [
    "Write a poem about the sunrise.",
    "Explain quantum entanglement.",
    "Give a recipe for lasagna.",
    "Summarize a medieval history event.",
    "Discuss stock trading strategies.",
    "How to bake sourdough bread?",
    "Tell me a joke about cats.",
]

def build_training_samples(
    positive_scale: int = 1,
    negative_scale: int = 2,
    seed: Optional[int] = None,
) -> List[Sample]:
    """Build the shuffled guard training set.

    In-scope examples keep their reference answers. Every out-of-scope prompt is
    paired with ``REFUSAL_MESSAGE`` so the adapter learns to refuse.

    Args:
        positive_scale: How many times to repeat ``POSITIVE_EXAMPLES``.
        negative_scale: How many times to repeat ``NEGATIVE_PROMPTS``.
        seed: If given, shuffle with a private ``random.Random(seed)`` so the order
            is reproducible and the global RNG is untouched. ``None`` shuffles with
            the module-level ``random`` state.

    Returns:
        List of ``Sample`` objects in shuffled order.
    """
    samples: List[Sample] = [
        Sample(inst, resp) for inst, resp in POSITIVE_EXAMPLES * positive_scale
    ]
    # Negative prompts become refusal responses.
    samples.extend(Sample(p, REFUSAL_MESSAGE) for p in NEGATIVE_PROMPTS * negative_scale)
    shuffle = random.shuffle if seed is None else random.Random(seed).shuffle
    shuffle(samples)
    return samples

In [None]:
def load_tokenizer(model_name: str):
    """Load the tokenizer for ``model_name``, using EOS as the pad token if it has none."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is not None:
        return tokenizer
    # No dedicated pad token: reuse EOS so padded batches can be built.
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer

def prepare_model(
    model_name: str,
    lora_r: int = 16,
    lora_alpha: int = 32,
    lora_dropout: float = 0.05,
    target_modules: Optional[List[str]] = None,
):
    """Load ``model_name`` as a causal LM and wrap it in a trainable LoRA adapter.

    Weights are loaded in bfloat16 when CUDA is available (float32 otherwise) and
    placed with ``device_map="auto"``. The trainable-parameter summary is printed.

    Args:
        model_name: Hub id or local path of the base model.
        lora_r: LoRA rank.
        lora_alpha: LoRA scaling factor.
        lora_dropout: Dropout applied inside the LoRA layers.
        target_modules: Names of the sub-modules to adapt. Defaults to separate
            q/k/v/o attention and gate/up/down MLP projections. Models that use
            other layer names (e.g. a fused ``qkv_proj``) need their own list.

    Returns:
        The PEFT-wrapped model.
    """
    if target_modules is None:
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
    )
    lora_cfg = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=list(target_modules),  # copy so the caller's list is never shared
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_cfg)
    model.print_trainable_parameters()
    return model

def train(
    model_name: str,
    output_dir: str,
    epochs: int = 2,
    batch_size: int = 4,
    lr: float = 2e-4,
    max_len: int = 1024,
):
    tokenizer = load_tokenizer(model_name)
    samples = build_training_samples()
    dataset = GuardDataset(samples, tokenizer, max_len=max_len)

    model = prepare_model(model_name)

    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=4,
        warmup_ratio=0.05,
        learning_rate=lr,
        logging_steps=10,
        save_steps=200,
        save_total_limit=2,
        bf16=torch.cuda.is_available(),
        fp16=False,
        weight_decay=0.01,
        report_to=[],
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=dataset,
        data_collator=data_collator,
    )

    trainer.train()
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    print("Training complete. Adapter saved at:", output_dir)

In [None]:
# Fine-tune the LoRA guard adapter on the small Gemma 3 base model and write it to disk.
guard_training_config = {
    "model_name": "google/gemma-3-270m",
    "output_dir": "./guard_adapter-mini",
    "epochs": 2,
    "batch_size": 4,
    "lr": 2e-4,
    "max_len": 1024,
}
train(**guard_training_config)

# Inference

In [None]:
import os

# `load_tokenizer` is defined once, in the training-helpers cell above (next to
# `prepare_model` and `train`); GuardedModel below reuses that definition rather
# than a second, identical copy.

class GuardedModel:
    """Keyword-gated wrapper around a causal LM, optionally with a LoRA guard adapter.

    Prompts that ``is_in_scope`` rejects are answered with ``REFUSAL_MESSAGE``
    without calling the model. Generations that fail the same check are also
    replaced by ``REFUSAL_MESSAGE``.
    """

    def __init__(self, base_model: str, adapter_path: Optional[str] = None, temperature: float = 0.2):
        """Load the tokenizer and model.

        Args:
            base_model: Hub id or path of the base causal LM. When an adapter is used,
                this must be the model the adapter was trained on.
            adapter_path: Directory of a saved PEFT adapter. If it is given but does
                not exist, a warning is emitted and the bare base model is used.
            temperature: Sampling temperature; ``<= 0`` means greedy decoding.
        """
        self.tokenizer = load_tokenizer(base_model)
        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        base = AutoModelForCausalLM.from_pretrained(
            base_model,
            torch_dtype=dtype,
            device_map="auto",
        )
        if adapter_path and os.path.exists(adapter_path):
            self.model = PeftModel.from_pretrained(base, adapter_path)
        else:
            if adapter_path:
                # Falling back silently would make the un-tuned base model look like the guard.
                warnings.warn(
                    f"Adapter path {adapter_path!r} not found; using the base model without the guard adapter.",
                    stacklevel=2,
                )
            self.model = base
        self.temperature = temperature

    def generate_guarded(self, prompt: str, max_new_tokens: int = 256) -> str:
        """Generate a response to ``prompt``, refusing out-of-scope input or output.

        Args:
            prompt: Raw user request.
            max_new_tokens: Upper bound on generated tokens.

        Returns:
            The generated answer (stripped), or ``REFUSAL_MESSAGE`` if either the
            prompt or the answer fails ``is_in_scope``.
        """
        if not is_in_scope(prompt):
            return REFUSAL_MESSAGE
        # Same "Instruction/Response" template used to build the training data.
        formatted = f"Instruction: {prompt}\nResponse:"
        encoded = self.tokenizer(formatted, return_tensors="pt")
        inputs = {k: v.to(self.model.device) for k, v in encoded.items()}

        gen_kwargs = {
            "do_sample": self.temperature > 0,
            "max_new_tokens": max_new_tokens,
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        if self.temperature > 0:
            # temperature only applies when sampling; passing it to greedy decoding triggers a warning.
            gen_kwargs["temperature"] = self.temperature
        with torch.no_grad():
            out = self.model.generate(**inputs, **gen_kwargs)

        # Decode only the newly generated tokens, so prompt text (e.g. a user-supplied
        # "Response:") cannot leak into the answer.
        prompt_len = inputs["input_ids"].shape[1]
        text = self.tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
        # The model may continue with a fresh "Instruction:" turn; keep only the first response.
        text = text.split("\nInstruction:", 1)[0].strip()
        # Safety fallback
        if not is_in_scope(text):
            return REFUSAL_MESSAGE
        return text

In [None]:
# Use the same base model the adapter was trained on (see the train(...) call above).
# A LoRA adapter saved for google/gemma-3-270m cannot be applied to a different architecture.
gm = GuardedModel(base_model="google/gemma-3-270m", adapter_path="./guard_adapter-mini", temperature=0.2)

In [None]:
# Out-of-scope request: the keyword gate should answer it with REFUSAL_MESSAGE.
demo_prompt = "write a Shakespeare style poem"
gm.generate_guarded(demo_prompt)