# Data Generation

In [None]:
import pandas as pd
from datetime import datetime, timedelta
import os
from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline
import torch
from tqdm import tqdm
import json

model_name = 'Qwen/Qwen3-0.6B'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to('cuda')
generator = TextGenerationPipeline(model=model, tokenizer=tokenizer, device=0)

# Example data
template_data = {
    "date": [f"2024-05-{str(day).zfill(2)}" for day in range(1, 32)],
    "hivenode01_memory_usage": [0.383553, 0.418137, 0.492768, 0.565933, 0.563794, 0.530966, 0.542266, 0.49038, 0.495031, 0.563613, 
                                0.495269, 0.618042, 0.536223, 0.613919, 0.508449, 0.576679, 0.521551, 0.622268, 0.611154, 0.518589, 
                                0.654723, 0.529004, 0.634645, 0.537238, 0.534131, 0.536508, 0.543266, 0.575944, 0.53557, 0.611579, 0.551645],
    "hivenode01_cpu_load_5min": [8.53, 12.07, 9.09, 8.55, 12.5, 7.98, 8.93, 7.85, 7.22, 8.98, 
                                 9.34, 9.6, 10.06, 8.68, 9.03, 8.08, 9.02, 7.89, 7.6, 11.22, 
                                 9.58, 9.53, 8.96, 10.89, 6.7, 9.73, 9.5, 10.5, 9.11, 8.75, 6.51],
    "hivenode01_cpu_load_10min": [9.12, 9.94, 9.3, 8.8, 11.66, 9.5, 8.63, 8.67, 7.77, 8.18, 
                                  8.99, 8.14, 9.02, 8.62, 8.54, 8.71, 8.91, 8.47, 7.66, 9.91, 
                                  8.67, 8.72, 8.77, 9.19, 7.58, 9.44, 8.89, 9.83, 8.32, 8.91, 8.17],
    "hivenode01_cpu_load_15min": [9.49, 9.07, 9.6, 7.92, 9.79, 9.62, 8.36, 9.32, 8.07, 7.91, 
                                  8.99, 8.12, 8.4, 8.46, 8.25, 8.73, 9.04, 8.54, 8.02, 9.58, 
                                  8.62, 8.84, 8.31, 8.68, 8.11, 9.01, 8.3, 8.95, 8.64, 8.83, 8.73],
    "hivenode02_memory_usage": [0.342719, 0.374206, 0.404097, 0.427319, 0.435087, 0.43736, 0.437884, 0.438484, 0.447014, 0.450261, 
                                0.456589, 0.457304, 0.461043, 0.459707, 0.456962, 0.460743, 0.470204, 0.466149, 0.474047, 0.476227, 
                                0.477294, 0.479754, 0.482177, 0.471509, 0.475243, 0.47675, 0.487594, 0.490805, 0.492136, 0.500023, 0.502623],
    "hivenode02_cpu_load_5min": [6.51, 8.39, 10.08, 12.8, 12.48, 10.14, 12.26, 8.85, 5.73, 9.06, 
                                 7.78, 7.96, 7.47, 6.82, 5.79, 7.28, 6.23, 7.19, 6.14, 7.79, 
                                 6.3, 6.57, 7.69, 6.68, 6, 6.59, 7.59, 9.85, 7.96, 9.24, 7.76],
    "hivenode02_cpu_load_10min": [6.27, 8.1, 10.61, 11.74, 12.44, 10.47, 11.24, 9.63, 7.08, 8.45, 
                                  7.22, 6.18, 6.94, 7, 6.41, 7.03, 6.72, 6.58, 6.04, 8.39, 
                                  6.34, 6.72, 6.91, 6.78, 6.35, 6.98, 6.84, 9.35, 8.14, 9.07, 7.54],
    "hivenode02_cpu_load_15min": [6.74, 7.63, 10.64, 10.35, 11.11, 10.6, 10.01, 9.69, 7.04, 8.02, 
                                  6.8, 6.12, 6.77, 7.13, 6.7, 7.07, 7.05, 6.9, 6.66, 8.08, 
                                  7, 7.39, 6.79, 7.14, 6.82, 7.05, 7.04, 8.52, 8, 8.46, 7.58],
    "hivenode03_memory_usage": [0.348359, 0.375609, 0.40562, 0.428764, 0.439686, 0.441799, 0.445848, 0.442726, 0.455341, 0.451364, 
                                0.458707, 0.458583, 0.463021, 0.474306, 0.468443, 0.46892, 0.47674, 0.473441, 0.477693, 0.477164, 
                                0.478697, 0.480722, 0.4821, 0.486745, 0.489345, 0.488128, 0.493571, 0.493586, 0.494228, 0.502349, 0.503208],
    "hivenode03_cpu_load_5min": [8.1, 9.61, 7.42, 8.05, 10.37, 7.72, 11.72, 7.67, 6.14, 9.78, 
                                 7.17, 7.31, 10.6, 6.07, 5.34, 8.05, 6.25, 6.13, 6.1, 8.68, 
                                 10.09, 7.21, 7.5, 7.04, 7.07, 8.59, 10.04, 7.41, 6.68, 12.27, 6.57],
    "hivenode03_cpu_load_10min": [7.6, 8.99, 8.52, 8.7, 10.56, 8.82, 9.46, 7.53, 7.36, 8.75, 
                                  7.09, 7.13, 9.56, 6.75, 6.73, 7.91, 6.71, 6.66, 7.11, 8.29, 
                                  8.12, 7, 6.84, 7.52, 7.84, 8.83, 8.43, 8.22, 7.42, 8.92, 7.05],
    "hivenode03_cpu_load_15min": [7.77, 8.52, 8.85, 8.12, 9.56, 9.05, 8.9, 8.19, 7.57, 8.34, 
                                  7.3, 7.36, 8.66, 7.17, 7.57, 7.68, 7.33, 7.05, 7.46, 8.48, 
                                  7.77, 7.61, 7.16, 7.84, 8.22, 8.38, 8.05, 8.11, 7.66, 8.29, 7.49],
    "hivenode04_memory_usage": [0.406619, 0.419903, 0.433678, 0.441757, 0.452172, 0.455543, 0.4598, 0.460577, 0.459536, 0.461846, 
                                0.465693, 0.474606, 0.478739, 0.470002, 0.465657, 0.467247, 0.473648, 0.473736, 0.478283, 0.482861, 
                                0.483788, 0.488319, 0.489811, 0.485238, 0.493503, 0.487962, 0.496605, 0.49354, 0.490805, 0.485683, 0.487579],
    "hivenode04_cpu_load_5min": [8, 10.23, 11.02, 7.19, 14.19, 9.73, 8, 7.59, 5.16, 7.4, 
                                 6.75, 11.49, 7.46, 6.14, 5.92, 8.56, 8.31, 5.56, 5.3, 8.13, 
                                 9.03, 7.09, 6.19, 9.83, 7.2, 6.55, 8.41, 8.59, 8.73, 7.54, 8.48],
    "hivenode04_cpu_load_10min": [6.98, 9.55, 9.47, 6.98, 10.99, 9.18, 7.47, 8.12, 6.7, 7.94, 
                                  6.55, 8.76, 7.3, 6.44, 7.34, 8.53, 8.74, 5.8, 6.15, 7.7, 
                                  8.45, 7, 6.08, 8.91, 7.49, 6.32, 8.62, 8.64, 9.39, 7.11, 8.12],
    "hivenode04_cpu_load_15min": [7.1, 8.69, 8.87, 7.09, 9.12, 8.57, 7.35, 8.13, 7.05, 7.85, 
                                  6.75, 7.59, 7.14, 6.73, 7.68, 8.18, 8.23, 6.5, 6.73, 7.81, 
                                  7.99, 7.58, 6.5, 8.22, 7.47, 6.78, 7.95, 8.17, 9.28, 7.25, 7.87],
    "hivenode05_memory_usage": [0.351445, 0.383579, 0.411094, 0.438981, 0.445999, 0.448852, 0.449878, 0.451535, 0.458583, 0.462726, 
                                0.466429, 0.469966, 0.471406, 0.469826, 0.481328, 0.483001, 0.476077, 0.489852, 0.49036, 0.485502, 
                                0.491012, 0.492307, 0.493674, 0.49574, 0.50239, 0.502286, 0.507123, 0.507206, 0.507957, 0.515078, 0.519858],
    "hivenode05_cpu_load_5min": [6.35, 10.45, 7.17, 7.47, 6.62, 7.21, 9.36, 7.46, 6.07, 7.37, 
                                 7.32, 7.54, 7.96, 7.12, 6.64, 7.61, 9.68, 7.25, 8.29, 9.44, 
                                 11.2, 7.9, 9.29, 8.73, 10.4, 6.47, 7.19, 11.26, 9.63, 7.66, 9.04],
    "hivenode05_cpu_load_10min": [6.28, 8.91, 7.31, 7.82, 7.7, 7.96, 7.92, 7.33, 6.2, 8.48, 
                                  6.69, 7.32, 7.58, 7.44, 6.99, 7.93, 7.84, 7.75, 7.74, 8.82, 
                                  10.24, 7.76, 7.84, 8.24, 8.31, 7.4, 7.27, 9.23, 8.78, 7.81, 8.21],
    "hivenode05_cpu_load_15min": [6.8, 8.32, 7.67, 7.49, 7.61, 8.17, 7.36, 7.52, 6.51, 7.85, 
                                  7.04, 7.15, 7.65, 7.48, 7.28, 7.67, 7.78, 8.08, 7.52, 8.79, 
                                  9.57, 7.87, 7.51, 8.07, 7.82, 7.67, 7.49, 8.74, 8.16, 8.06, 7.79],
    "hivenode06_memory_usage": [0.36604, 0.269375, 0.307352, 0.364142, 0.173979, 0.139898, 0.192349, 0.218628, 0.207623, 0.156441, 
                                0.164164, 0.205858, 0.18286, 0.156721, 0.154286, 0.194161, 0.153307, 0.17226, 0.154831, 0.158268, 
                                0.206682, 0.174523, 0.213316, 0.193095, 0.21925, 0.20669, 0.219678, 0.257834, 0.30988, 0.199916, 0.160641],
    "hivenode06_cpu_load_5min": [3.42, 3.33, 2.71, 2.62, 3.36, 3.65, 2.88, 3, 3.22, 3.28, 
                                 2.28, 3.02, 3.02, 2.89, 3.22, 3.55, 3.16, 3.36, 3.44, 2.8, 
                                 2.95, 2.28, 3.14, 2.43, 3.47, 2.78, 2.93, 3.51, 2.69, 2.55, 3.2],
    "hivenode06_cpu_load_10min": [3.09, 3.3, 3.01, 2.83, 3.27, 3.3, 3.22, 2.9, 3.19, 3.34, 
                                  2.94, 3.04, 3.19, 3.11, 3.57, 3.37, 3.31, 3.19, 3.12, 3.04, 
                                  3.05, 2.67, 3.22, 2.84, 3.73, 3.18, 3.1, 3.21, 3.08, 3.09, 3.39],
    "hivenode06_cpu_load_15min": [3.02, 3.19, 3.03, 2.87, 3.27, 3.17, 3.19, 2.94, 3.12, 3.12, 
                                  3.06, 3.01, 3.18, 3.38, 3.54, 3.31, 3.35, 3.09, 3.03, 3.12, 
                                  3.01, 2.76, 3.19, 2.87, 3.45, 3.13, 3.14, 3.13, 3.04, 3.22, 3.33],
    "hivenode07_memory_usage": [0.381104, 0.141959, 0.15592, 0.144891, 0.183825, 0.221047, 0.255345, 0.270316, 0.25585, 0.255843, 
                                0.186189, 0.215183, 0.259902, 0.310836, 0.185046, 0.190949, 0.17576, 0.179073, 0.207616, 0.205337, 
                                0.248384, 0.240304, 0.201106, 0.180418, 0.189253, 0.195343, 0.205485, 0.193422, 0.213659, 0.200243, 0.198259],
    "hivenode07_cpu_load_5min": [2.87, 2.82, 2.36, 2.04, 2.55, 2.15, 2.05, 2.06, 2.34, 2.16, 
                                 2.96, 2.44, 2.46, 2.43, 2.34, 2.76, 2.37, 2.58, 2.23, 2.6, 
                                 2.36, 2.35, 2.13, 2.58, 2.24, 2.66, 2.89, 2.58, 2.12, 3.02, 2.59],
    "hivenode07_cpu_load_10min": [2.41, 2.74, 2.3, 2.23, 2.55, 2.46, 2.54, 2.23, 2.37, 2.36, 
                                  2.7, 2.6, 2.57, 2.32, 2.67, 2.5, 2.62, 2.64, 2.25, 2.46, 
                                  2.78, 2.58, 2.36, 2.64, 2.39, 2.79, 2.56, 2.81, 2.43, 2.74, 2.75],
    "hivenode07_cpu_load_15min": [2.32, 2.75, 2.42, 2.36, 2.45, 2.48, 2.51, 2.27, 2.39, 2.43, 
                                  2.58, 2.56, 2.54, 2.48, 2.74, 2.52, 2.62, 2.59, 2.43, 2.47, 
                                  2.71, 2.67, 2.38, 2.56, 2.53, 2.64, 2.42, 2.84, 2.59, 2.66, 2.72]
}
template_df = pd.DataFrame(template_data)

# Date range
start_date = datetime(2024, 3, 1)
end_date = datetime(2025, 5, 31)

# List of nodes
nodes = ['hivenode01', 'hivenode02', 'hivenode03', 'hivenode04', 'hivenode05', 'hivenode06', 'hivenode07']

# Days in a month
def days_in_month(year, month):
    if month == 2:
        # Check for leap year
        if (year % 4 == 0 and year % 100 != 0) or (year % 400 == 0):
            return 29
        return 28
    return 31 if month in [1, 3, 5, 7, 8, 10, 12] else 30

current_date = start_date
num_data = []
text_data = []
total_num_samples = 0
total_text_samples = 0
total_num_abnormalities = 0
total_days = (end_date - start_date).days + 1

for _ in tqdm(range(total_days), desc="Processing dates"):
    days = days_in_month(current_date.year, current_date.month)

    if current_date.day == 1:
        text_data = []

    template_day_idx = (current_date.day - 1) % 31
    template_row = template_df.iloc[template_day_idx]
    
    for node in tqdm(nodes, desc="Processing nodes", leave=False):
        # Use template data directly
        memory_usage = template_row[f"{node}_memory_usage"]
        cpu_load_5min = template_row[f"{node}_cpu_load_5min"]
        cpu_load_10min = template_row[f"{node}_cpu_load_10min"]
        cpu_load_15min = template_row[f"{node}_cpu_load_15min"]

        # Determine status and additional fields
        if memory_usage > 0.7 or cpu_load_5min > 10 or cpu_load_10min > 10 or cpu_load_15min > 10:
            status = "abnormal"
            total_num_abnormalities += 1
            
            # Calculate alert severity
            alert_severity = "CRITICAL" if memory_usage > 0.9 or cpu_load_5min > 15 else "HIGH" if memory_usage > 0.8 or cpu_load_5min > 12 else "MEDIUM"
            
            # Determine likely root cause
            likely_root_cause = "high memory usage" if memory_usage > 0.7 else "high CPU load"
            
            # Determine impact scope
            impact_scope = "node" if node == "hivenode01" else "cluster"
            
            # Suggest remediation steps
            remediation_steps = "restart node, increase resources, optimize code"
            
            prompt = (
                f"You are Qwen3-0.6B, a language model for generating operation logs. "
                f"Generate a single log entry with these requirements:\n"
                f"1. Retain all numerical values unchanged.\n"
                f"2. Use ISO 8601 timestamp (to the second).\n"
                f"3. Include fields: log_level (INFO/WARN/ERROR), event_type, node_id, memory_usage, cpu_load_5min, cpu_load_10min, cpu_load_15min, status, "
                f"alert_severity, likely_root_cause, impact_scope, remediation_steps.\n"
                f"4. For status=abnormal: highlight issue with severity, cause, impact, remediation.\n"
                f"5. Vary syntax, use domain terms (OOM, jitter, throughput), keep entry 50-150 characters.\n"
                f"6. Output only the log entry.\n\n"
                f"Input Data:\n"
                f"  timestamp: {current_date.strftime('%Y-%m-%dT%H:%M:%SZ')}\n"
                f"  node: {node}\n"
                f"  memory_usage: {memory_usage:.4f}\n"
                f"  cpu_load_5min: {cpu_load_5min:.2f}\n"
                f"  cpu_load_10min: {cpu_load_10min:.2f}\n"
                f"  cpu_load_15min: {cpu_load_15min:.2f}\n"
                f"  status: {status}\n"
                f"  alert_severity: {alert_severity}\n"
                f"  likely_root_cause: {likely_root_cause}\n"
                f"  impact_scope: {impact_scope}\n"
                f"  remediation_steps: {remediation_steps}\n"
            )
        else:
            status = "normal"
            prompt = (
                f"You are Qwen3-0.6B, a language model for generating operation logs. "
                f"Generate a single log entry with these requirements:\n"
                f"1. Retain all numerical values unchanged.\n"
                f"2. Use ISO 8601 timestamp (to the second).\n"
                f"3. Include fields: log_level (INFO/WARN/ERROR), event_type, node_id, memory_usage, cpu_load_5min, cpu_load_10min, cpu_load_15min, status.\n"
                f"4. For status=normal: emphasize stability and optimal performance.\n"
                f"5. Vary syntax, use domain terms (OOM, jitter, throughput), keep entry 50-150 characters.\n"
                f"6. Output only the log entry.\n\n"
                f"Input Data:\n"
                f"  timestamp: {current_date.strftime('%Y-%m-%dT%H:%M:%SZ')}\n"
                f"  node: {node}\n"
                f"  memory_usage: {memory_usage:.4f}\n"
                f"  cpu_load_5min: {cpu_load_5min:.2f}\n"
                f"  cpu_load_10min: {cpu_load_10min:.2f}\n"
                f"  cpu_load_15min: {cpu_load_15min:.2f}\n"
                f"  status: {status}\n"
            )

        # Generate text log
        generated = generator(prompt, max_new_tokens=150, num_return_sequences=1, do_sample=True, temperature=0.9)
        text = generated[0]['generated_text'].strip()
        text_data.append({
            "date": current_date.strftime("%Y-%m-%d"),
            "node": node,
            "text": text,
        })
        total_text_samples += 1

    if current_date.day == days or current_date == end_date:
        # Save text data
        year_month = current_date.strftime("%Y-%m")
        os.makedirs('data/DevOpsLogs', exist_ok=True)
        with open(f'data/DevOpsLogs/{year_month}.log', 'w', encoding='utf-8') as f:
            for record in text_data:
                json.dump(record, f, ensure_ascii=False)
                f.write('\n')
        text_data = []
    
    current_date += timedelta(days=1)

print(f'Total text samples generated: {total_text_samples}, Total abnormalities: {total_num_abnormalities}, Abnormality rate: {total_num_abnormalities/total_text_samples:.2%}')

torch.cuda.empty_cache()

Device set to use cuda:0
Processing dates:   1%|          | 1/182 [00:44<2:14:17, 44.52s/it]You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset
Processing dates: 100%|██████████| 182/182 [2:04:02<00:00, 40.89s/it] 

Total num samples generated: 1274, Total num abnormalities: 244 , Num abnormality rate: 19.15%
Total text samples generated: 1274





# Data preprocessing

In [None]:
import pandas as pd
import os
from tqdm import tqdm
import json

# Merge num data
num_files = sorted([f for f in os.listdir('data/pre_train_num') if f.endswith('.csv')])
all_num_data = []
for file in tqdm(num_files, desc="Merging num data"):
    df = pd.read_csv(os.path.join('data/pre_train_num', file))
    all_num_data.append(df)
num_df = pd.concat(all_num_data, ignore_index=True)
num_df['date'] = pd.to_datetime(num_df['date'], format='%Y-%m-%d')
num_df = num_df.sort_values('date')
num_df.to_csv('data/pre_train_num/num_202412-202505.csv', index=False)

# Merge text data
text_files = sorted([f for f in os.listdir('data/pre_train_text') if f.endswith('.jsonl')])
all_text_data = []
for file in tqdm(text_files, desc="Merging text data"):
    with open(os.path.join('data/pre_train_text', file), 'r', encoding='utf-8') as f:
        records = [json.loads(line) for line in f]
        all_text_data.extend(records)
text_df = pd.DataFrame(all_text_data)
text_df['date'] = pd.to_datetime(text_df['date'], format='%Y-%m-%d')
text_df = text_df.sort_values('date')
# Save text data as JSONL
with open('data/pre_train_text/text_202412-202505.jsonl', 'w', encoding='utf-8') as f:
    for _, row in text_df.iterrows():
        record = row.to_dict()
        if isinstance(record['date'], pd.Timestamp):
            record['date'] = record['date'].strftime('%Y-%m-%d')
        json.dump(record, f, ensure_ascii=False)
        f.write('\n')

# Text train-test split (80/20)
split_idx = int(len(num_df) * 0.8)
train_num_df = num_df.iloc[:split_idx]
test_num_df = num_df.iloc[split_idx:]

split_idx_text = int(len(text_df) * 0.8)
train_text_df = text_df.iloc[:split_idx_text]
test_text_df = text_df.iloc[split_idx_text:]


# Save datasets as Parquet to reduce storage space
os.makedirs('data/traindata', exist_ok=True)
os.makedirs('data/testdata', exist_ok=True)
train_num_df.to_parquet('data/traindata/train_num.parquet')
test_num_df.to_parquet('data/testdata/test_num.parquet')
train_text_df.to_parquet('data/traindata/train_text.parquet')
test_text_df.to_parquet('data/testdata/test_text.parquet')

# Calculate abnormality for numerical data
nodes = ['hivenode01', 'hivenode02', 'hivenode03', 'hivenode04', 'hivenode05', 'hivenode06', 'hivenode07']
def count_abnormalities(df):
    total_samples = len(df) * len(nodes)
    total_abnormal = 0
    for node in nodes:
        abnormal = ((df[f"{node}_memory_usage"] > 0.7) | 
                    (df[f"{node}_cpu_load_5min"] > 10) | 
                    (df[f"{node}_cpu_load_10min"] > 10) | 
                    (df[f"{node}_cpu_load_15min"] > 10)).sum()
        total_abnormal += abnormal
    return total_samples, total_abnormal

# Numerical data statistics
train_num_samples, train_num_abnormal = count_abnormalities(train_num_df)
test_num_samples, test_num_abnormal = count_abnormalities(test_num_df)
train_abn_rate = train_num_abnormal / train_num_samples 
test_abn_rate = test_num_abnormal / test_num_samples 

print(f"Numerical data: train set {train_num_samples} samples, abnormalities {train_num_abnormal}, abnormality rate {train_abn_rate:.2%}")
print(f"Numerical data: test set {test_num_samples} samples, abnormalities {test_num_abnormal}, abnormality rate {test_abn_rate:.2%}")
print(f"Text data: train set {train_text_df.shape[0]} samples, test set {test_text_df.shape[0]} samples")

Merging num data: 100%|██████████| 6/6 [00:00<00:00, 705.18it/s]
Merging text data: 100%|██████████| 6/6 [00:00<00:00, 703.17it/s]

Numerical data: train set 1015 samples, abnormalities 33, abnormality rate 3.25%
Numerical data: test set 259 samples, abnormalities 8, abnormality rate 3.09%
Text data: train set 1019 samples, test set 255 samples





# LLM & Fine-Tunning -- Qwen2.5-0.5B

In [1]:
import pandas as pd
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix
from tqdm import tqdm
import re
import os
import gc
from  torch.cuda.amp import autocast

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

os.makedirs("models/qwen3-0.6B", exist_ok=True)
os.makedirs("models/best-qwen3-0.6B", exist_ok=True)

torch.cuda.empty_cache()

train_text_df = pd.read_parquet('data/traindata/train_text.parquet')
test_text_df = pd.read_parquet('data/testdata/test_text.parquet')

def extract_status(text):
    match = re.search(r"(normal|abnormal)", str(text), re.IGNORECASE)
    return 0 if match and match.group(0).lower() == "normal" else 1

train_text_df['label'] = train_text_df['text'].apply(extract_status)
test_text_df['label'] = test_text_df['text'].apply(extract_status)
print("Train label distribution:", train_text_df['label'].value_counts().to_dict())
print("Test label distribution:", test_text_df['label'].value_counts().to_dict())

model_name = 'Qwen/Qwen3-0.6B'
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load model with 4-bit quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_type=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) 

#  LoRA(low-Rank Adaptation) configuration
peft_config = LoraConfig(
    r=2,
    lora_alpha=8,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, peft_config)

for name, param in model.named_parameters():
    if "lora" in name:
        param.requires_grad = True
model.train()

def prepare_history_context(df, date, node, window_days=3):
    start_date = pd.to_datetime(date) - pd.Timedelta(days=window_days)
    history = df[(df['date'] >= start_date) & (df['date'] < date) & (df['node'] == node)]
    context = "\n".join(history['text'].tolist())
    return context if context else "No historical data."

class TextDataset(torch.utils.data.Dataset):
    def __init__(self, df, tokenizer, max_length=150, window_days=3, days_ahead=1):
        self.df = df.sort_values(by='date')
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.window_days = window_days
        self.days_ahead = days_ahead
        self.dates = df['date'].unique()[window_days:]
        self.nodes = ['hivenode01', 'hivenode02', 'hivenode03', 'hivenode04', 'hivenode05', 'hivenode06', 'hivenode07']
        print(f"Dataset: {len(df['date'].unique())} unique dates, {len(self.dates)} dates after {window_days} days, {len(self.nodes)} nodes")
        if len(self.dates) == 0:
            print(f"Warning: No valid dates found after {window_days} days. Check dataset time range.")

    def __len__(self):
        return len(self.dates) * len(self.nodes) * self.days_ahead

    def __getitem__(self, idx):
        day_idx = idx % self.days_ahead
        node_idx = (idx // self.days_ahead) % len(self.nodes)
        date_idx = idx // (len(self.nodes) * self.days_ahead)
        date = pd.to_datetime(self.dates[date_idx])
        node = self.nodes[node_idx]
        future_date = date + pd.Timedelta(days=day_idx + 1)
        
        context = prepare_history_context(self.df, date, node)
        future_data = self.df[(self.df['date'] == future_date) & (self.df['node'] == node)]
        label = future_data['label'].iloc[0] if not future_data.empty else 0
        
        prompt = f"Past 3 days:\n{context}\nPredict day {day_idx + 1} status: normal or abnormal"
        encoding = self.tokenizer(
            prompt,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        input_ids = encoding['input_ids'].squeeze()
        attention_mask = encoding['attention_mask'].squeeze()
        labels = input_ids.clone()
        label_start = len(self.tokenizer.encode(prompt.rsplit(":", 1)[0] + ":")) - 1
        labels[:label_start] = -100
        target_token = self.tokenizer.encode("normal" if label == 0 else "abnormal", add_special_tokens=False)
        target_token = target_token[-1] if target_token else self.tokenizer.pad_token_id
        labels[label_start:] = target_token
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }

def custom_collate_fn(batch):
    return {
        'input_ids': torch.stack([item['input_ids'] for item in batch]),
        'attention_mask': torch.stack([item['attention_mask'] for item in batch]),
        'labels': torch.stack([item['labels'] for item in batch])
    }

train_dataset = TextDataset(train_text_df, tokenizer, window_days=3, days_ahead=1)
test_dataset = TextDataset(test_text_df, tokenizer, window_days=3, days_ahead=1)

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits[:, -1, :2], axis=-1)
    labels = [label[-1] for label in labels if label[-1] in [0, 1]]
    predictions = predictions[:len(labels)]
    return {
        'accuracy': accuracy_score(labels, predictions),
        'f1': f1_score(labels, predictions, zero_division=0),
        'precision': precision_score(labels, predictions, zero_division=0),
        'recall': recall_score(labels, predictions, zero_division=0)
    }

training_args = TrainingArguments(
    output_dir="models/qwen3-0.6B",
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_strategy="epoch",
    logging_steps=1,
    warmup_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model="loss",
    greater_is_better=False,
    gradient_accumulation_steps=16,
    report_to="none",
    fp16=True,
    eval_accumulation_steps=1
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=custom_collate_fn,
    compute_metrics=compute_metrics
)

trainer.train()
trainer.save_model("models/best-qwen3-0.6B")
tokenizer.save_pretrained("models/best-qwen3-0.6B")

def predict_future_status(model, tokenizer, context, days_ahead=1, max_length=128):
    model.eval()
    predictions = []
    probabilities = []
    prompt = f"Past 3 days:\n{context}\nPredict day 1 status: normal or abnormal"
    inputs = tokenizer(
        prompt,
        truncation=True,
        padding='max_length',
        max_length=max_length,
        return_tensors='pt'
    ).to('cuda')
    with torch.no_grad():
        with autocast():
            outputs = model(**inputs)
            logits = outputs.logits[:, -1, :]
            normal_token = tokenizer.encode("normal", add_special_tokens=False)[-1]
            abnormal_token = tokenizer.encode("abnormal", add_special_tokens=False)[-1]
            probs = torch.softmax(logits[:, [normal_token, abnormal_token]], dim=-1).cpu().numpy()[0]
            pred = int(np.argmax(probs))
            predictions.append(pred)
            probabilities.append(probs[1])
    del inputs, outputs, logits, probs
    torch.cuda.empty_cache()
    gc.collect()
    return predictions, probabilities

test_dates = test_text_df['date'].unique()[3:]
nodes = ['hivenode01', 'hivenode02', 'hivenode03', 'hivenode04', 'hivenode05', 'hivenode06', 'hivenode07']
predictions = []
true_labels = []
probabilities = []

with tqdm(total=len(test_dates) * len(nodes), desc="Predicting future status") as pbar:
    for date in test_dates:
        date = pd.to_datetime(date)
        for node in nodes:
            context = prepare_history_context(test_text_df, date, node)
            if context != "No historical data.":
                pred, prob = predict_future_status(model, tokenizer, context)
                predictions.extend(pred)
                probabilities.extend(prob)
                future_date = date + pd.Timedelta(days=1)
                future_data = test_text_df[(test_text_df['date'] == future_date) & (test_text_df['node'] == node)]
                if not future_data.empty:
                    true_labels.append(future_data['label'].iloc[0])
                    print(f"Node {node}, Day 1 (Date: {future_date}): Predicted={pred[0]}, Prob(abnormal)={prob[0]:.4f}, True={true_labels[-1]}")
                del pred, prob, future_data            
            del context
            torch.cuda.empty_cache()
            gc.collect()
            pbar.update(1)

min_len = min(len(predictions), len(true_labels))
predictions = predictions[:min_len]
true_labels = true_labels[:min_len]
probabilities = probabilities[:min_len]

qwen_metrics = {
    'accuracy': accuracy_score(true_labels, predictions),
    'f1': f1_score(true_labels, predictions, zero_division=0),
    'precision': precision_score(true_labels, predictions, zero_division=0),
    'recall': recall_score(true_labels, predictions, zero_division=0),
    'confusion_matrix': confusion_matrix(true_labels, predictions).tolist()
}

print("True labels:", pd.Series(true_labels).value_counts().to_dict())
print("Predictions:", pd.Series(predictions).value_counts().to_dict())
print("Qwen3-0.6B Metrics:")
print(f"Accuracy: {qwen_metrics['accuracy']:.2%}")
print(f"F1: {qwen_metrics['f1']:.4f}")
print(f"Precision: {qwen_metrics['precision']:.4f}")
print(f"Recall: {qwen_metrics['recall']:.4f}")
print(f"Confusion Matrix: {qwen_metrics['confusion_matrix']}")
print(f"Average abnormal probability: {np.mean(probabilities):.4f}")

torch.cuda.empty_cache()
gc.collect()

Train label distribution: {0: 1019}
Test label distribution: {0: 255}


  _ = torch.tensor([0], device=i)
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


Dataset: 146 unique dates, 143 dates after 3 days, 7 nodes
Dataset: 37 unique dates, 34 dates after 3 days, 7 nodes


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
1,0.0,,,0.0,0.0,0.0


  avg = a.mean(axis, **keepdims_kw)
  ret = ret.dtype.type(ret / rcount)


KeyboardInterrupt: 

# Qwen3-1.7B + Fine-tunning + RAG