## Load model with PEFT

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import get_peft_model, LoraConfig, TaskType

LOAD_IN_BITS = False  # Set to True if you want to load in 8/4 bits

model_name = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

#### Load in 8/4 bits if needed only
if LOAD_IN_BITS:
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        load_in_8bit=True,  # or load_in_4bit=True
        device_map="auto"
    )

lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],  # may vary with model
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()



tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/184 [00:00<?, ?B/s]

trainable params: 3,407,872 || all params: 8,033,669,120 || trainable%: 0.0424


## Set up data format for training

In [4]:
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from peft import get_peft_model, LoraConfig, TaskType
import json
import json
import datasets

# === Load and flatten dataset ===
def load_flat_json(path):
    with open(path, "r", encoding="utf-8") as f:
        raw = json.load(f)
    return [list(e.values())[0] for e in raw]

# === Format dataset into prompt + completion style ===
def format_for_completion(example):
    incident = example.get("incident_description", "")
    logs = example.get("attack_logs", [])
    mitigations = example.get("ground_truth_mitigations", [])
    playbook = example.get("playbook", {})

    logs_text = "\n".join([
        f"- [{log['timestamp']}] {log['host']}: {log['action']} — {log['details']}"
        for log in logs if all(k in log for k in ["timestamp", "host", "action", "details"])
    ])

    mitig_text = "\n".join(mitigations)
    playbook_text = json.dumps(playbook, indent=2)

    full_text = f"""### Incident:
{incident}

### Logs:
{logs_text}

### Predicted Mitigations:
{mitig_text}

### Generated CACAO Playbook:
{playbook_text}"""

    return {"text": full_text}

# === Tokenizer and label masking ===
def tokenize_with_labels(example):
    tokens = tokenizer(example["text"], truncation=True, padding="max_length", max_length=2048)
    input_ids = tokens["input_ids"]

    # Find where target starts
    target_start = example["text"].find("### Predicted Mitigations:")
    if target_start == -1:
        target_start = 0

    prompt_ids = tokenizer(example["text"][:target_start], truncation=True, max_length=2048)["input_ids"]
    prompt_len = len(prompt_ids)

    labels = input_ids.copy()
    labels[:prompt_len] = [-100] * prompt_len

    tokens["labels"] = labels
    return tokens

# === Load and preprocess dataset ===
train_raw = load_flat_json("Dataset/Main/dataset_merged_train.json")
val_raw = load_flat_json("Dataset/Main/dataset_merged_val.json")

train_dataset = Dataset.from_list([format_for_completion(e) for e in train_raw])
val_dataset = Dataset.from_list([format_for_completion(e) for e in val_raw])

train_dataset = train_dataset.map(tokenize_with_labels)
val_dataset = val_dataset.map(tokenize_with_labels)


FileNotFoundError: [Errno 2] No such file or directory: 'Dataset/Main/dataset_merged_train.json'

## Tokenize with labels

Set up HF Trainer

In [None]:
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./llama3-cacao-checkpoints",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=2e-5,
    num_train_epochs=3,
    logging_dir="./logs",
    logging_steps=10,
    save_steps=500,
    evaluation_strategy="steps",
    eval_steps=500,
    fp16=True,  # or bf16=True
    save_total_limit=2,
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
)
