## Install libraries

In [2]:
%%capture
# Install Unsloth and its pinned dependencies. %%capture (which must stay on the
# first line of the cell) suppresses the very long pip logs.
import os, re
# Heuristic runtime check: a Colab runtime is detected by any environment
# variable name containing "COLAB_".
if "COLAB_" not in "".join(os.environ.keys()):
    # Not Colab: a plain install lets pip resolve Unsloth's dependencies itself.
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    # Pick the xformers build from the preinstalled torch major.minor version:
    # 2.9 -> 0.0.33.post1, 2.8 -> 0.0.32.post2, anything else -> 0.0.29.post3.
    import torch; v = re.match(r"[0-9]{1,}\.[0-9]{1,}", str(torch.__version__)).group(0)
    xformers = "xformers==" + ("0.0.33.post1" if v=="2.9" else "0.0.32.post2" if v=="2.8" else "0.0.29.post3")
    # --no-deps keeps pip from replacing Colab's preinstalled torch.
    !pip install --no-deps bitsandbytes accelerate {xformers} peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets==4.3.0" "huggingface_hub>=0.34.0" hf_transfer
    !pip install --no-deps unsloth
# Pin transformers / TRL to the versions this notebook was run with (the
# Unsloth banner in the model-loading output reports Transformers 4.56.2).
!pip install transformers==4.56.2
!pip install --no-deps trl==0.22.2

In [3]:
# Extra dependencies: scikit-learn (imported by the evaluation cell), tqdm
# (progress bar in the inference loop) and vllm.
# NOTE(review): vllm is never imported in this notebook, and the model is loaded
# without fast_inference. The install is large (a 474.9 MB wheel in the output
# below) and, without pins, may upgrade packages pinned in the previous cell.
# Confirm whether it is needed.
# The message below is printed unconditionally; it does not check pip's exit status.
!pip install --quiet vllm scikit-learn tqdm
print("Required libraries unsloth, vllm, scikit-learn, and tqdm installed successfully.")

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m87.9/87.9 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m474.9/474.9 MB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m355.0/355.0 kB[0m [31m31.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m183.0/183.0 kB[0m [31m17.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m45.5/45.5 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.0/7.0 MB[0m [31m126.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m111.0/111.0 kB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m45.4/45.4 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

## Mount Google Drive (test data, base model and LoRA checkpoint)

In [1]:
from google.colab import drive

# Mount Google Drive at the standard Colab location. Later cells read the test
# set, the base model and the LoRA checkpoint from MyDrive.
drive.mount(mountpoint="/content/drive")

Mounted at /content/drive


## Load model and tokenizer

In [4]:
# --- Run configuration --------------------------------------------------------
# All paths are absolute, so the notebook no longer depends on the current
# working directory. The relative "drive/MyDrive/..." paths only worked while
# the cwd was /content, whereas lora_path was already absolute.
DRIVE_ROOT = "/content/drive/MyDrive"

test_data_path = f"{DRIVE_ROOT}/mipd_test.jsonl"    # JSONL test split
MAX_NEW_TOKENS = 256                                # generation budget per example
max_seq_length = 16384                              # context window passed to Unsloth
base_model_dir = f"{DRIVE_ROOT}/bielik-4.5b-base"   # local copy of the base model
TEST_ROWS = None                                    # int: evaluate only the first N rows; None for whole dataset
lora_path = f"{DRIVE_ROOT}/checkpoint-2686"         # LoRA checkpoint; None evaluates the base model

In [None]:
from unsloth import FastLanguageModel
from google.colab import userdata
import torch

# Load the base model and its tokenizer. Any LoRA adapter is attached in a
# separate cell, so that lora_path = None evaluates the plain base model.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=base_model_dir,
    max_seq_length=max_seq_length,        # longer prompts trigger Unsloth's truncation warning
    full_finetuning=False,
    load_in_8bit=False,                   # 8-bit is more accurate but uses ~2x the memory
    load_in_4bit=True,                    # 4-bit quantization to reduce GPU memory
    use_gradient_checkpointing="unsloth",
)

In [7]:
# Attach the fine-tuned LoRA adapter. This is skipped when lora_path is None,
# i.e. when evaluating the base model.
LORA_ADAPTER_NAME = "original_dataset_adapter"

if lora_path:
    # Re-run guard: only wrap the model and load the checkpoint once per kernel.
    if LORA_ADAPTER_NAME not in getattr(model, "peft_config", {}):
        # get_peft_model wraps the model in a PEFT model with a new, *untrained*
        # adapter named "default". Its r/alpha/targets do not matter for
        # evaluation, because the trained checkpoint is loaded with its own
        # saved adapter config.
        model = FastLanguageModel.get_peft_model(model)
        model.load_adapter(lora_path, adapter_name=LORA_ADAPTER_NAME)
    # PEFT does not activate an adapter on load_adapter(). Without this call the
    # untrained "default" adapter stays active, and the evaluation silently
    # measures (presumably, with zero-initialised LoRA B) the base model instead
    # of the checkpoint.
    model.set_adapter(LORA_ADAPTER_NAME)

==((====))==  Unsloth 2026.1.3: Fast Llama patching. Transformers: 4.56.2. vLLM: 0.13.0.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu126. CUDA: 7.5. CUDA Toolkit: 12.6. Triton: 3.5.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.33.post1. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


## Load and prepare dataset

In [8]:
import json
from datasets import load_dataset

# 1. Read the JSONL test file into a DatasetDict whose only split is 'test'.
dataset_test = load_dataset(path="json", data_files={"test": test_data_path})

# 2. Turn one raw record into a chat prompt plus its ground-truth tags.
def format_prompt(example, chat_tokenizer=None):
    """Build the generation prompt and parse the ground-truth technique tags.

    Args:
        example: one dataset row with 'instruction' (used as the system
            message), 'input' (used as the user message) and 'output' (a JSON
            string, optionally wrapped in a ```json fence, expected to hold a
            "discovered_techniques" list).
        chat_tokenizer: object providing ``apply_chat_template``. Defaults to
            the notebook-level ``tokenizer`` loaded together with the model.

    Returns:
        The same ``example`` dict with two added keys:
        'prompt' (the chat template rendered as text, ending with the assistant
        generation prompt) and 'tags' (the ground-truth tag list, or ``[]`` when
        'output' cannot be parsed into that shape; a warning is printed).
    """
    if chat_tokenizer is None:
        chat_tokenizer = tokenizer

    messages = [
        {"role": "system", "content": example['instruction']},
        {"role": "user", "content": example['input']},
    ]
    example['prompt'] = chat_tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # 3. Parse the ground truth. Markdown fences are stripped first, because
    #    fenced JSON is not valid JSON.
    try:
        clean_json = example['output'].replace("```json", "").replace("```", "").strip()
        tags = json.loads(clean_json)['discovered_techniques']
        if not isinstance(tags, list):
            raise TypeError("'discovered_techniques' is not a list")
        example['tags'] = tags
    # JSONDecodeError: malformed JSON; KeyError: key missing; TypeError: JSON is
    # not an object, or the value is not a list; AttributeError: 'output' is None.
    # None of these should abort the whole Dataset.map() call.
    except (json.JSONDecodeError, KeyError, TypeError, AttributeError):
        example['tags'] = []
        print(f"Warning: Could not parse output: {example['output']}")

    return example


# Optionally evaluate only the first TEST_ROWS examples. The count is clamped to
# the split size, because Dataset.select raises on out-of-range indices.
if TEST_ROWS:
    n_rows = min(TEST_ROWS, len(dataset_test['test']))
    small_test_dataset = dataset_test['test'].select(range(n_rows))
else:
    small_test_dataset = dataset_test['test']

# 4. Apply format_prompt to every row and keep only the derived 'prompt'/'tags'
#    columns (the raw instruction/input/output columns are dropped).
original_columns = small_test_dataset.column_names
dataset_test_formatted = small_test_dataset.map(format_prompt, remove_columns=original_columns)

print("Formatted prompts and ground truth tags generated for the test dataset.")
print(dataset_test_formatted)

Generating test split: 0 examples [00:00, ? examples/s]

Map:   0%|          | 0/1521 [00:00<?, ? examples/s]

Formatted prompts and ground truth tags generated for the test dataset.
Dataset({
    features: ['prompt', 'tags'],
    num_rows: 1521
})


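## Define evaluation function
Metric 1: Parsing Success Rate — did the model output valid JSON?

Metric 2: Format Correction Rate — how many invalid JSON outputs were recovered by the regex fallback?

Metric 3: Classification Performance — per-document F1 between predicted and ground-truth tags. Outputs that cannot be parsed count as an empty prediction.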


In [9]:
import json
import re
from sklearn.metrics import f1_score

def evaluate_response(response_text: str, ground_truth_tags: list):
    """Parse a model response and score its tags against the ground truth.

    Accepts the trained format ``{"discovered_techniques": [...]}`` or a bare
    JSON list, optionally wrapped in a ```json Markdown fence. If the whole
    text is not valid JSON, the first ``[...]`` span is tried as a fallback.

    Args:
        response_text: decoded model output.
        ground_truth_tags: expected tags for this example (``None`` is treated
            as no tags).

    Returns:
        dict with
          - 'parsing_status': 'Strict Success' (whole text parsed as JSON),
            'Recovered' (only the extracted list parsed) or 'Failed';
          - 'parsed_tags': predicted tags as strings, with ``None`` entries dropped;
          - 'f1_score': set-based F1 between predicted and ground-truth tags,
            1.0 when both are empty.
    """
    parsed_tags = []
    parsing_status = 'Failed'

    # Strip Markdown fences first; otherwise valid JSON wrapped in ```json would
    # fail strict parsing.
    clean_text = response_text.replace("```json", "").replace("```", "").strip()

    # Attempt 1: strict JSON parsing of the whole response.
    try:
        parsed_output = json.loads(clean_text)

        if isinstance(parsed_output, dict):
            # Expected format: take the key the model was trained to emit.
            parsed_tags = parsed_output.get("discovered_techniques", [])
            if not isinstance(parsed_tags, list):
                # Key present but not a list: count it as "no predictions".
                parsed_tags = []
            parsing_status = 'Strict Success'
        elif isinstance(parsed_output, list):
            # A bare list instead of the wrapping dict.
            parsed_tags = parsed_output
            parsing_status = 'Strict Success'
        else:
            raise ValueError("Parsed output is not a Dict or List.")

    except (json.JSONDecodeError, ValueError):
        # Attempt 2: extract the first (non-nested) [...] span and parse that.
        match = re.search(r'\[(.*?)\]', clean_text, re.DOTALL)
        if match:
            try:
                recovered = json.loads(f"[{match.group(1)}]")
            except (json.JSONDecodeError, ValueError):
                recovered = None
            if isinstance(recovered, list):
                parsed_tags = recovered
                parsing_status = 'Recovered'

    # --- F1 CALCULATION ---
    parsed_tags = [str(tag) for tag in parsed_tags if tag is not None]
    ground_truth_tags = [str(tag) for tag in (ground_truth_tags or []) if tag is not None]

    # Multi-label set F1: 2|P ∩ T| / (|P| + |T|), ignoring duplicates.
    # Do not macro-average binary indicators over the union of tags: that also
    # averages in the F1 of the "tag absent" class. That class has no true
    # negatives by construction, so its F1 is 0 whenever the prediction is
    # imperfect, which roughly halves every partial-match score.
    predicted_set = set(parsed_tags)
    gold_set = set(ground_truth_tags)
    if not predicted_set and not gold_set:
        f1 = 1.0
    else:
        f1 = 2 * len(predicted_set & gold_set) / (len(predicted_set) + len(gold_set))

    return {
        'parsing_status': parsing_status,
        'parsed_tags': parsed_tags,
        'f1_score': f1,
    }

## Run inference and evaluate every test example

In [None]:
from unsloth import FastLanguageModel
import torch
from tqdm import tqdm
import json

# The model was already loaded above; this cell only switches it to inference mode.
print(f"Preparing model for inference with max_seq_length = {max_seq_length}...")

FastLanguageModel.for_inference(model)  # Unsloth's inference mode (advertised as ~2x faster)

# --- 2. INFERENCE LOOP: one example at a time, whole prompt in a single pass ---
print("Starting Long-Context Inference...")
evaluation_results = []

for example in tqdm(dataset_test_formatted, desc="Processing"):
    prompt = example['prompt']
    ground_truth_tags = example['tags']

    # NOTE(review): the prompt is already rendered by apply_chat_template. If that
    # template emits a BOS token, the default add_special_tokens=True here adds a
    # second one. Confirm against the tokenizer's chat template.
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    prompt_len = inputs.input_ids.shape[1]

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,   # use the config value instead of a hard-coded 256
            use_cache=True,
            do_sample=False,                 # greedy decoding; temperature has no effect without sampling
            pad_token_id=tokenizer.pad_token_id,
        )

    # Keep only the newly generated tokens (drop the echoed prompt).
    generated_ids = output_ids[:, prompt_len:]
    raw_output = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # --- 3. STANDARD EVALUATION ---
    # Cut everything after the first closing brace. The expected answer is a flat
    # {"discovered_techniques": [...]} object, and trailing chatter would break
    # strict JSON parsing.
    if '}' in raw_output:
        raw_output = raw_output[:raw_output.find('}') + 1]

    result = evaluate_response(raw_output, ground_truth_tags)

    evaluation_results.append({
        'original_prompt_len': prompt_len,
        # Unsloth warns and truncates over-long inputs itself (see the cell
        # output), so scores for these rows may not reflect the full document.
        'exceeds_max_seq_length': prompt_len > max_seq_length,
        'ground_truth': ground_truth_tags,
        'predicted': result['parsed_tags'],
        'parsing_status': result['parsing_status'],  # needed for Metrics 1 and 2
        'f1_score': result['f1_score'],
        'raw_output': raw_output,
    })

    # Release cached blocks after each long-context example, to reduce the risk
    # of OOM on the next one.
    torch.cuda.empty_cache()


print(f"Done. Processed {len(evaluation_results)} documents.")
n_too_long = sum(r['exceeds_max_seq_length'] for r in evaluation_results)
if n_too_long:
    print(f"Warning: {n_too_long} prompt(s) exceeded max_seq_length = {max_seq_length} and were truncated.")

Loading Model with max_seq_length = 16384...
Starting Long-Context Inference...


Processing:  43%|████▎     | 661/1521 [1:13:53<3:16:59, 13.74s/it]Unsloth: Input IDs of shape torch.Size([1, 18832]) with length 18832 > the model's max sequence length of 16384.
We shall truncate it ourselves. It's imperative if you correct this issue first.
Processing:  61%|██████▏   | 934/1521 [1:45:24<1:32:11,  9.42s/it]

In [None]:
# --- 4. REPORTING ---
# Summarise the run: mean per-document F1, the parsing metrics described in the
# evaluation section (when the inference loop recorded them), and a short
# sample table.
print("\n" + "=" * 60)
print(f"INFERENCE REPORT: {len(evaluation_results)} documents")
print("=" * 60)

if evaluation_results:
    n_docs = len(evaluation_results)

    # 1. Global metric: unweighted mean of the per-document F1 scores.
    avg_f1 = sum(r['f1_score'] for r in evaluation_results) / n_docs
    print(f"\nGlobal Average F1 Score: {avg_f1:.4f}")

    # 2. Parsing metrics (Metric 1 and Metric 2). .get() keeps this cell working
    #    on results produced by a loop that did not store 'parsing_status'.
    statuses = [r.get('parsing_status') for r in evaluation_results]
    if None not in statuses:
        n_strict = statuses.count('Strict Success')
        n_recovered = statuses.count('Recovered')
        n_invalid = n_docs - n_strict  # responses that were not valid JSON as a whole
        print(f"Parsing Success Rate (strict JSON): {n_strict / n_docs:.2%} ({n_strict}/{n_docs})")
        if n_invalid:
            print(f"Format Correction Rate (recovered / invalid): "
                  f"{n_recovered / n_invalid:.2%} ({n_recovered}/{n_invalid})")
        else:
            print("Format Correction Rate: n/a (no invalid outputs)")
    else:
        print("Parsing metrics unavailable: results have no 'parsing_status'.")

    n_too_long = sum(1 for r in evaluation_results if r.get('exceeds_max_seq_length'))
    if n_too_long:
        print(f"Prompts truncated for exceeding max_seq_length: {n_too_long}")

    # 3. Sample table of the first 10 documents.
    print("\n--- Sample Results (First 10) ---")
    print(f"{'F1 Score':<10} | {'Ground Truth':<30} | {'Predicted':<30}")
    print("-" * 80)

    for res in evaluation_results[:10]:
        # Clip long tag lists so each row fits its 30-character column.
        gt_str = str(res['ground_truth'])
        pred_str = str(res['predicted'])
        gt_display = (gt_str[:27] + '..') if len(gt_str) > 27 else gt_str
        pred_display = (pred_str[:27] + '..') if len(pred_str) > 27 else pred_str

        print(f"{res['f1_score']:.4f}     | {gt_display:<30} | {pred_display:<30}")

    print("-" * 80)
else:
    print("No results generated.")