In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer, AutoModel, AutoConfig,
    TrainingArguments, Trainer,
    DataCollatorForTokenClassification,
    get_cosine_schedule_with_warmup
)
import json
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
from typing import List, Dict, Tuple, Optional
import logging
from pathlib import Path
import re
from collections import defaultdict, Counter
import pickle
from dataclasses import dataclass
import seaborn as sns
import matplotlib.pyplot as plt
import os
import math

# Install pytorch-crf if needed
# pip install pytorch-crf

try:
    from torchcrf import CRF
    CRF_AVAILABLE = True
except ImportError:
    CRF_AVAILABLE = False
    print("Warning: pytorch-crf not installed. CRF layer will be disabled.")
    print("Install with: pip install pytorch-crf")

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

Install with: pip install pytorch-crf


In [4]:
os.environ['WANDB_DISABLED'] = 'true'
os.environ['WANDB_MODE'] = 'disabled'

In [6]:
class FocalLoss(nn.Module):
    """Focal Loss implementation to handle class imbalance"""
    def __init__(self, alpha=1.0, gamma=2.0, ignore_index=-100, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ignore_index = ignore_index
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, ignore_index=self.ignore_index, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

class LabelSmoothingLoss(nn.Module):
    """Label Smoothing Loss for regularization"""
    def __init__(self, num_classes, smoothing=0.1, ignore_index=-100):
        super(LabelSmoothingLoss, self).__init__()
        self.num_classes = num_classes
        self.smoothing = smoothing
        self.ignore_index = ignore_index
        self.confidence = 1.0 - smoothing

    def forward(self, pred, target):
        pred = pred.log_softmax(dim=-1)
        with torch.no_grad():
            true_dist = torch.zeros_like(pred)
            true_dist.fill_(self.smoothing / (self.num_classes - 1))
            mask = target != self.ignore_index
            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
            true_dist = true_dist * mask.unsqueeze(1).float()

        return torch.mean(torch.sum(-true_dist * pred, dim=-1))

class IOCDataProcessor:
    """Enhanced IOC data processor for new dataset format"""

    def __init__(self):
        # Updated entity types to match new dataset
        self.entity_types = [
            # Technical IOCs
            'IP', 'Domain', 'URL', 'File', 'Email', 'Vulnerability',
            # Semantic Entities
            'Type', 'Device', 'Vendor', 'Version', 'Software',
            'Function', 'Platform', 'Malware', 'ThreatActor', 'Other'
        ]
        self.label_encoder = LabelEncoder()
        self.label_to_id = {}
        self.id_to_label = {}
        self.label_distribution = Counter()

        # Updated entity mapping for new dataset format
        self.entity_mapping = {
            'IP': 'IP',
            'Domain': 'DOMAIN',
            'URL': 'URL',
            'File': 'FILE',
            'Email': 'EMAIL',
            'Vulnerability': 'VULNERABILITY',
            'Type': 'TYPE',
            'Device': 'DEVICE',
            'Vendor': 'VENDOR',
            'Version': 'VERSION',
            'Software': 'SOFTWARE',
            'Function': 'FUNCTION',
            'Platform': 'PLATFORM',
            'Malware': 'MALWARE',
            'ThreatActor': 'THREATACTOR',
            'Other': 'OTHER'
        }

    def load_dataset(self, json_path: str) -> List[Dict]:
        """Load the enhanced IOC dataset"""
        logger.info(f"Loading dataset from {json_path}")
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        logger.info(f"Loaded {len(data)} samples")
        return data

    def find_entity_positions(self, text: str, entity_value: str) -> List[Tuple[int, int]]:
        """Find all positions of an entity value in text with improved matching"""
        positions = []
        text_lower = text.lower()
        entity_lower = entity_value.lower().strip()

        # Handle multi-word entities
        entity_clean = ' '.join(entity_lower.split())

        start = 0
        while True:
            pos = text_lower.find(entity_clean, start)
            if pos == -1:
                break

            # Verify word boundaries for better matching
            is_start_valid = pos == 0 or text[pos-1].isspace() or text[pos-1] in '.,;:()[]{}"\''
            end_pos = pos + len(entity_value)
            is_end_valid = end_pos >= len(text) or text[end_pos].isspace() or text[end_pos] in '.,;:()[]{}"\''

            if is_start_valid and is_end_valid:
                positions.append((pos, pos + len(entity_value)))

            start = pos + 1

        return positions

    def clean_entity_value(self, entity_value: str, entity_type: str) -> str:
        """Clean and normalize entity values"""
        if not entity_value or entity_value == "NULL":
            return None

        cleaned = entity_value.strip()

        # Remove common prefixes
        if ":" in cleaned:
            parts = cleaned.split(":", 1)
            if parts[0] in self.entity_types:
                cleaned = parts[1].strip()

        # Handle specific cleaning based on entity type
        if entity_type == "File":
            if cleaned.startswith("FileHash-"):
                if ":" in cleaned:
                    cleaned = cleaned.split(":", 1)[1].strip()

        # NEW: Remove trailing punctuation that shouldn't be part of entities
        if entity_type in ["Domain", "URL", "Email"]:
            cleaned = cleaned.rstrip('.,;:')

        # NEW: Handle multi-word entities better (especially for ThreatActor)
        if entity_type == "ThreatActor":
            # Keep full multi-word names
            cleaned = ' '.join(cleaned.split())

        return cleaned if cleaned else None

    def convert_new_format_to_bio(self, text: str, entities_dict: Dict) -> List[str]:
        """Convert new dataset format to BIO tagging format"""
        tokens = text.split()
        bio_tags = ['O'] * len(tokens)

        # Create character to token mapping
        char_to_token = {}
        char_pos = 0
        for token_idx, token in enumerate(tokens):
            while char_pos < len(text) and text[char_pos].isspace():
                char_pos += 1
            for i in range(len(token)):
                if char_pos + i < len(text):
                    char_to_token[char_pos + i] = token_idx
            char_pos += len(token)

        # Process each entity type
        all_entities = []
        for entity_type, entity_values in entities_dict.items():
            if entity_type not in self.entity_mapping:
                continue

            mapped_type = self.entity_mapping[entity_type]

            # Handle both list and single value formats
            if not entity_values or entity_values == "NULL" or entity_values == ["NULL"]:
                continue

            # Ensure entity_values is a list
            if not isinstance(entity_values, list):
                entity_values = [entity_values]

            for entity_value in entity_values:
                cleaned_value = self.clean_entity_value(entity_value, entity_type)

                if not cleaned_value:
                    continue

                # Find positions in text
                positions = self.find_entity_positions(text, cleaned_value)
                for start_char, end_char in positions:
                    all_entities.append({
                        'start_pos': start_char,
                        'end_pos': end_char,
                        'entity_type': mapped_type,
                        'entity_value': cleaned_value
                    })

        # Sort entities by start position to handle overlaps
        sorted_entities = sorted(all_entities, key=lambda x: x['start_pos'])

        # Convert to BIO tags with overlap handling
        used_tokens = set()
        for entity in sorted_entities:
            start_char = entity['start_pos']
            end_char = entity['end_pos']
            entity_type = entity['entity_type']

            # Find corresponding tokens
            start_token = None
            end_token = None

            # Find start token
            for offset in range(5):  # Look ahead up to 5 characters
                if start_char + offset in char_to_token:
                    start_token = char_to_token[start_char + offset]
                    break

            # Find end token
            for offset in range(5):  # Look back up to 5 characters
                if end_char - 1 - offset in char_to_token:
                    end_token = char_to_token[end_char - 1 - offset]
                    break

            if start_token is not None and end_token is not None:
                # Skip if tokens are already used (handle overlaps)
                if start_token in used_tokens:
                    continue

                if start_token == end_token:
                    if bio_tags[start_token] == 'O':
                        bio_tags[start_token] = f'B-{entity_type}'
                        used_tokens.add(start_token)
                else:
                    if bio_tags[start_token] == 'O':
                        bio_tags[start_token] = f'B-{entity_type}'
                        used_tokens.add(start_token)
                    for token_idx in range(start_token + 1, min(end_token + 1, len(bio_tags))):
                        if bio_tags[token_idx] == 'O' and token_idx not in used_tokens:
                            bio_tags[token_idx] = f'I-{entity_type}'
                            used_tokens.add(token_idx)

        return bio_tags

    def prepare_training_data(self, dataset: List[Dict]) -> Tuple[List[str], List[List[str]]]:
        """Prepare texts and BIO tags for training with label distribution analysis"""
        texts = []
        all_tags = []
        skipped = 0

        for sample in dataset:
            text = sample['text']
            entities_dict = sample['entities']

            # Skip very long texts
            if len(text.split()) > 450:
                skipped += 1
                continue

            bio_tags = self.convert_new_format_to_bio(text, entities_dict)

            # Update label distribution
            self.label_distribution.update(bio_tags)

            # Include samples with at least one entity OR with meaningful content
            if any(tag != 'O' for tag in bio_tags):
                texts.append(text)
                all_tags.append(bio_tags)

        # Log statistics
        logger.info(f"Skipped {skipped} samples (too long)")
        logger.info(f"Prepared {len(texts)} training samples")

        # Log label distribution
        logger.info("\nLabel Distribution Analysis:")
        total_labels = sum(self.label_distribution.values())
        for label, count in self.label_distribution.most_common():
            percentage = (count / total_labels) * 100
            logger.info(f"  {label}: {count} ({percentage:.2f}%)")

        return texts, all_tags

    def create_label_mappings(self, all_tags: List[List[str]]):
        """Create label to ID mappings"""
        unique_labels = set()
        for tags in all_tags:
            unique_labels.update(tags)

        unique_labels = sorted(list(unique_labels))
        self.label_to_id = {label: idx for idx, label in enumerate(unique_labels)}
        self.id_to_label = {idx: label for label, idx in self.label_to_id.items()}

        logger.info(f"\nCreated mappings for {len(unique_labels)} labels:")
        for label in unique_labels:
            logger.info(f"  {label}: {self.label_to_id[label]}")

        return unique_labels

    def get_class_weights(self, all_tags: List[List[str]]) -> torch.Tensor:
        """Calculate class weights for handling imbalanced data"""
        flat_tags = [tag for tags in all_tags for tag in tags]
        unique_labels = list(self.label_to_id.keys())
        label_counts = [flat_tags.count(label) for label in unique_labels]

        total_samples = len(flat_tags)
        weights = []

        for i, count in enumerate(label_counts):
            if count > 0:
                # NEW: More aggressive weighting for rare classes
                if unique_labels[i] == 'O':
                    weight = total_samples / (len(unique_labels) * count * 5)  # Changed from 3 to 5
                elif count < 100:  # NEW: Boost very rare classes
                    weight = total_samples / (len(unique_labels) * count * 0.2)  # Stronger boost
                else:
                    weight = total_samples / (len(unique_labels) * count * 0.5)
            else:
                weight = 1.0
            weights.append(weight)

        logger.info("\nClass weights calculated:")
        for label, weight in zip(unique_labels, weights):
            logger.info(f"  {label}: {weight:.4f}")

        return torch.FloatTensor(weights)

class IOCDataset(Dataset):
    """Enhanced PyTorch Dataset for IOC extraction"""

    def __init__(self, texts: List[str], tags: List[List[str]],
                 tokenizer, label_to_id: Dict, max_length: int = 512):
        self.texts = texts
        self.tags = tags
        self.tokenizer = tokenizer
        self.label_to_id = label_to_id
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        tags = self.tags[idx]

        # Tokenize
        encoding = self.tokenizer(
            text.split(),
            is_split_into_words=True,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        # Align labels with tokens
        word_ids = encoding.word_ids()
        aligned_labels = []

        for word_id in word_ids:
            if word_id is None:
                aligned_labels.append(-100)
            elif word_id >= len(tags):
                aligned_labels.append(self.label_to_id['O'])
            else:
                aligned_labels.append(self.label_to_id[tags[word_id]])

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(aligned_labels, dtype=torch.long)
        }

class EnhancedIOCExtractionModel(nn.Module):
    """Enhanced DeBERTa-v3-based IOC extraction model"""

    def __init__(self, model_name: str, num_labels: int, dropout_rate: float = 0.1,
                 use_crf: bool = False, class_weights: Optional[torch.Tensor] = None,
                 loss_type: str = 'focal'):
        super().__init__()
        self.deberta = AutoModel.from_pretrained(model_name)
        self.num_labels = num_labels
        self.use_crf = use_crf and CRF_AVAILABLE
        self.loss_type = loss_type

        hidden_size = self.deberta.config.hidden_size
        # Three-layer classifier with more capacity
        self.dropout = nn.Dropout(dropout_rate)
        self.layer_norm = nn.LayerNorm(hidden_size)

        # NEW: Larger intermediate layer
        self.classifier_1 = nn.Linear(hidden_size, hidden_size)  # Changed from hidden_size//2
        self.classifier_1_norm = nn.LayerNorm(hidden_size)
        self.classifier_1_dropout = nn.Dropout(dropout_rate)

        # NEW: Additional layer
        self.classifier_2 = nn.Linear(hidden_size, hidden_size // 2)
        self.classifier_2_norm = nn.LayerNorm(hidden_size // 2)
        self.classifier_2_dropout = nn.Dropout(dropout_rate)

        self.classifier_3 = nn.Linear(hidden_size // 2, num_labels)

        self._init_weights()

        if self.use_crf:
            self.crf = CRF(num_labels, batch_first=True)
            logger.info("CRF layer enabled")

        if class_weights is not None:
            self.register_buffer('class_weights', class_weights)
        else:
            self.class_weights = None

        if loss_type == 'focal':
            self.loss_fn = FocalLoss(alpha=1.0, gamma=2.0, ignore_index=-100)
        elif loss_type == 'label_smoothing':
            self.loss_fn = LabelSmoothingLoss(num_labels, smoothing=0.1, ignore_index=-100)
        else:
            self.loss_fn = nn.CrossEntropyLoss(weight=class_weights, ignore_index=-100)

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _init_weights(self):
        """Initialize classifier weights using Xavier uniform"""
        nn.init.xavier_uniform_(self.classifier_1.weight)
        nn.init.zeros_(self.classifier_1.bias)
        nn.init.xavier_uniform_(self.classifier_2.weight)
        nn.init.zeros_(self.classifier_2.bias)
        nn.init.xavier_uniform_(self.classifier_3.weight)  # NEW
        nn.init.zeros_(self.classifier_3.bias)  # NEW

    def forward(self, input_ids, attention_mask, labels=None, **kwargs):
        outputs = self.deberta(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state

        sequence_output = self.layer_norm(sequence_output)
        sequence_output = self.dropout(sequence_output)

        hidden = self.classifier_1(sequence_output)
        hidden = F.gelu(hidden)
        hidden = self.classifier_1_norm(hidden)
        hidden = self.classifier_1_dropout(hidden)

        # NEW: Second hidden layer
        hidden = self.classifier_2(hidden)
        hidden = F.gelu(hidden)
        hidden = self.classifier_2_norm(hidden)
        hidden = self.classifier_2_dropout(hidden)

        logits = self.classifier_3(hidden)

        loss = None
        if labels is not None:
            if self.use_crf:
                mask = attention_mask.bool()
                crf_labels = labels.clone()
                crf_labels[labels == -100] = 0
                loss = -self.crf(logits, crf_labels, mask=mask, reduction='mean')
            else:
                loss = self.loss_fn(logits.view(-1, self.num_labels), labels.view(-1))

        return {'loss': loss, 'logits': logits}

    def decode_predictions(self, logits, attention_mask):
        """Decode predictions using CRF if available"""
        if self.use_crf:
            mask = attention_mask.bool()
            batch_size = mask.size(0)
            for i in range(batch_size):
                mask[i, 0] = True
            predictions = self.crf.decode(logits, mask=mask)
            return predictions
        else:
            return torch.argmax(logits, dim=-1).tolist()

# [Rest of the code remains the same: EnhancedIOCModelTrainer, EnhancedIOCModelInference, and main() function]
# The key changes are in IOCDataProcessor class

class EnhancedIOCModelTrainer:
    """Enhanced training class"""

    def __init__(self, model_name: str = "microsoft/deberta-v3-base"):
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.processor = IOCDataProcessor()
        self.model = None

        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def train(self,
              dataset_path: str,
              output_dir: str = "enhanced_ioc_deberta_model",
              test_size: float = 0.2,
              batch_size: int = 8,
              num_epochs: int = 6,
              learning_rate: float = 2e-5,
              use_crf: bool = False,
              loss_type: str = 'focal'):
        """Train the enhanced IOC extraction model"""

        os.makedirs(output_dir, exist_ok=True)

        # Load and process data
        dataset = self.processor.load_dataset(dataset_path)
        texts, all_tags = self.processor.prepare_training_data(dataset)
        unique_labels = self.processor.create_label_mappings(all_tags)

        # Calculate class weights
        class_weights = self.processor.get_class_weights(all_tags)

        # Split data
        split_idx = int(len(texts) * (1 - test_size))
        train_texts, test_texts = texts[:split_idx], texts[split_idx:]
        train_tags, test_tags = all_tags[:split_idx], all_tags[split_idx:]

        logger.info(f"\nTrain samples: {len(train_texts)}, Test samples: {len(test_texts)}")

        # Create datasets
        train_dataset = IOCDataset(
            train_texts, train_tags, self.tokenizer, self.processor.label_to_id, max_length=512
        )
        test_dataset = IOCDataset(
            test_texts, test_tags, self.tokenizer, self.processor.label_to_id, max_length=512
        )

        # Initialize model
        self.model = EnhancedIOCExtractionModel(
            self.model_name,
            len(unique_labels),
            use_crf=use_crf,
            class_weights=class_weights,
            loss_type=loss_type
        )
        self.model.to(self.model.device)

        # Training arguments
        total_steps = len(train_dataset) // batch_size * num_epochs
        warmup_steps = int(0.1 * total_steps)

        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            learning_rate=learning_rate,
            weight_decay=0.01,
            warmup_steps=warmup_steps,
            logging_steps=50,
            eval_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            greater_is_better=False,
            remove_unused_columns=False,
            dataloader_pin_memory=False,
            report_to=None,
            fp16=torch.cuda.is_available(),
            gradient_accumulation_steps=2,
            adam_epsilon=1e-8,
            max_grad_norm=1.0,
            lr_scheduler_type="cosine",
            save_total_limit=3,
            dataloader_num_workers=2,
        )

        data_collator = DataCollatorForTokenClassification(
            tokenizer=self.tokenizer,
            pad_to_multiple_of=8,
            return_tensors="pt"
        )

        class EnhancedTrainer(Trainer):
            def log(self, logs, start_time=None):
                super().log(logs)
                if 'train_loss' in logs:
                    logger.info(f"Step {self.state.global_step}: Train Loss = {logs['train_loss']:.4f}")
                if 'eval_loss' in logs:
                    logger.info(f"Step {self.state.global_step}: Eval Loss = {logs['eval_loss']:.4f}")

        trainer = EnhancedTrainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=test_dataset,
            data_collator=data_collator,
            processing_class=self.tokenizer
        )

        logger.info(f"\nTraining Configuration:")
        logger.info(f"  Model: {self.model_name}")
        logger.info(f"  Loss Type: {loss_type}")
        logger.info(f"  Use CRF: {use_crf}")
        logger.info(f"  Batch Size: {batch_size}")
        logger.info(f"  Learning Rate: {learning_rate}")
        logger.info(f"  Epochs: {num_epochs}")

        logger.info("\nStarting training...")
        try:
            trainer.train()
            logger.info("Training completed successfully!")

            final_eval = trainer.evaluate()
            logger.info(f"Final Evaluation Loss: {final_eval['eval_loss']:.4f}")

        except Exception as e:
            logger.error(f"Training failed: {e}")
            raise

        # Save model
        trainer.save_model()
        self.tokenizer.save_pretrained(output_dir)

        # Save mappings
        with open(f"{output_dir}/label_mappings.pkl", 'wb') as f:
            pickle.dump({
                'label_to_id': self.processor.label_to_id,
                'id_to_label': self.processor.id_to_label,
                'entity_mapping': self.processor.entity_mapping,
                'model_name': self.model_name,
                'use_crf': use_crf,
                'loss_type': loss_type,
                'class_weights': class_weights.tolist(),
                'label_distribution': dict(self.processor.label_distribution)
            }, f)

        # Evaluate
        self.evaluate(test_texts, test_tags)

        logger.info(f"\nTraining complete! Model saved to {output_dir}")
        return trainer.state.log_history

    def evaluate(self, test_texts: List[str], test_tags: List[List[str]]):
        """Evaluate model with detailed metrics"""
        logger.info("\nEvaluating model...")

        predictions = []
        true_labels = []
        confidence_scores = []

        for text, tags in zip(test_texts, test_tags):
            try:
                pred_entities = self.predict(text)
                pred_tags = ['O'] * len(text.split())

                for entity in pred_entities:
                    start_token = entity.get('start_token', 0)
                    end_token = entity.get('end_token', start_token)
                    entity_type = entity['entity_type']
                    confidence = entity.get('confidence', 0.0)

                    if start_token < len(pred_tags) and start_token >= 0:
                        pred_tags[start_token] = f"B-{entity_type}"
                        confidence_scores.append(confidence)
                        for i in range(start_token + 1, min(end_token + 1, len(pred_tags))):
                            if i < len(pred_tags):
                                pred_tags[i] = f"I-{entity_type}"
                                confidence_scores.append(confidence)

                min_len = min(len(pred_tags), len(tags))
                predictions.extend(pred_tags[:min_len])
                true_labels.extend(tags[:min_len])

            except Exception as e:
                logger.warning(f"Error predicting: {e}")
                continue

        if not predictions or not true_labels:
            logger.error("No valid predictions generated")
            return

        try:
            print("\nClassification Report:")
            report = classification_report(true_labels, predictions, output_dict=True, zero_division=0)
            print(classification_report(true_labels, predictions, zero_division=0))

            logger.info("\nKey Metrics:")
            logger.info(f"  Overall Accuracy: {report['accuracy']:.4f}")
            logger.info(f"  Macro Avg F1: {report['macro avg']['f1-score']:.4f}")
            logger.info(f"  Weighted Avg F1: {report['weighted avg']['f1-score']:.4f}")

            if confidence_scores:
                avg_confidence = np.mean(confidence_scores)
                logger.info(f"  Average Confidence: {avg_confidence:.4f}")

        except Exception as e:
            logger.error(f"Error generating report: {e}")

    def predict(self, text: str, threshold: float = 0.8) -> List[Dict]:
        """Predict entities in text"""
        if not self.model:
            raise ValueError("Model not trained yet!")

        tokens = text.split()
        if not tokens:
            return []

        encoding = self.tokenizer(
            tokens,
            is_split_into_words=True,
            truncation=True,
            padding='max_length',
            max_length=512,
            return_tensors='pt'
        )

        if 'token_type_ids' in encoding:
            del encoding['token_type_ids']

        word_ids = encoding.word_ids()
        encoding = {k: v.to(self.model.device) for k, v in encoding.items()}

        self.model.eval()
        with torch.no_grad():
            outputs = self.model(**encoding)

            if self.model.use_crf:
                predictions = self.model.decode_predictions(outputs['logits'], encoding['attention_mask'])
                probs = torch.nn.functional.softmax(outputs['logits'], dim=-1)
            else:
                probs = torch.nn.functional.softmax(outputs['logits'], dim=-1)
                predictions = torch.argmax(probs, dim=-1)

        entities = []
        current_entity = None

        if self.model.use_crf:
            pred_labels = predictions[0] if isinstance(predictions[0], list) else predictions
        else:
            pred_labels = predictions[0]
            probs = probs[0]

        for i, word_id in enumerate(word_ids):
            if word_id is None or word_id >= len(tokens) or i >= len(pred_labels):
                continue

            if self.model.use_crf:
                predicted_label_id = pred_labels[i]
                confidence = 0.9
            else:
                predicted_label_id = pred_labels[i].item()
                confidence = torch.max(probs[i]).item()

            predicted_label = self.processor.id_to_label[predicted_label_id]

            if predicted_label != 'O' and confidence > threshold:
                if predicted_label.startswith('B-'):
                    if current_entity:
                        entities.append(current_entity)

                    if word_id < len(tokens):
                        current_entity = {
                            'start_pos': word_id,
                            'end_pos': word_id + 1,
                            'entity_type': predicted_label[2:],
                            'entity_value': tokens[word_id],
                            'confidence': confidence,
                            'start_token': word_id,
                            'end_token': word_id
                        }
                elif predicted_label.startswith('I-') and current_entity:
                    entity_type = predicted_label[2:]
                    if entity_type == current_entity['entity_type'] and word_id < len(tokens):
                        current_entity['entity_value'] += ' ' + tokens[word_id]
                        current_entity['end_token'] = word_id
                        current_entity['confidence'] = min(current_entity['confidence'], confidence)

        if current_entity:
            entities.append(current_entity)

        return entities


class EnhancedIOCModelInference:
    """Enhanced production inference class"""

    def __init__(self, model_dir: str):
        self.model_dir = model_dir
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)

        # Load mappings
        with open(f"{model_dir}/label_mappings.pkl", 'rb') as f:
            mappings = pickle.load(f)
            self.label_to_id = mappings['label_to_id']
            self.id_to_label = mappings['id_to_label']
            self.entity_mapping = mappings.get('entity_mapping', {})
            self.model_name = mappings.get('model_name', 'microsoft/deberta-v3-base')
            self.use_crf = mappings.get('use_crf', False)
            self.loss_type = mappings.get('loss_type', 'focal')
            self.class_weights = torch.FloatTensor(mappings.get('class_weights', []))

        # Load model
        self.model = EnhancedIOCExtractionModel(
            model_name=self.model_name,
            num_labels=len(self.label_to_id),
            use_crf=self.use_crf,
            class_weights=self.class_weights,
            loss_type=self.loss_type
        )

        # Load weights
        model_path = os.path.join(model_dir, "pytorch_model.bin")
        if os.path.exists(model_path):
            self.model.load_state_dict(torch.load(model_path, map_location='cpu'))
        else:
            model_path = os.path.join(model_dir, "model.safetensors")
            if os.path.exists(model_path):
                from safetensors.torch import load_file
                self.model.load_state_dict(load_file(model_path))
            else:
                raise FileNotFoundError(f"No model weights found in {model_dir}")

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.model.eval()

    def extract_iocs(self, text: str, confidence_threshold: float = 0.75) -> Dict:
        """Extract IOCs with categorization"""
        tokens = text.split()
        if not tokens:
            return {'technical_iocs': [], 'semantic_entities': [], 'text': text}

        encoding = self.tokenizer(
            tokens,
            is_split_into_words=True,
            truncation=True,
            padding='max_length',
            max_length=512,
            return_tensors='pt'
        )

        if 'token_type_ids' in encoding:
            del encoding['token_type_ids']

        word_ids = encoding.word_ids()
        encoding = {k: v.to(self.device) for k, v in encoding.items()}

        with torch.no_grad():
            outputs = self.model(**encoding)

            if self.model.use_crf:
                predictions = self.model.decode_predictions(outputs['logits'], encoding['attention_mask'])
                probs = torch.nn.functional.softmax(outputs['logits'], dim=-1)
            else:
                probs = torch.nn.functional.softmax(outputs['logits'], dim=-1)
                predictions = torch.argmax(probs, dim=-1)

        # Extract entities
        entities = []
        current_entity = None

        if self.model.use_crf:
            pred_labels = predictions[0] if isinstance(predictions[0], list) else predictions
        else:
            pred_labels = predictions[0]
            probs = probs[0]

        # Create word_id to token position mapping
        word_id_to_token_pos = {}
        for i, word_id in enumerate(word_ids):
            if word_id is not None:
                if word_id not in word_id_to_token_pos:
                    word_id_to_token_pos[word_id] = []
                word_id_to_token_pos[word_id].append(i)

        processed_word_ids = set()

        for i, word_id in enumerate(word_ids):
            if word_id is None or word_id >= len(tokens) or word_id in processed_word_ids:
                continue

            processed_word_ids.add(word_id)
            token_positions = word_id_to_token_pos[word_id]
            first_pos = token_positions[0]

            if first_pos >= len(pred_labels):
                continue

            if self.model.use_crf:
                predicted_label_id = pred_labels[first_pos]
                confidence = 0.9
            else:
                predicted_label_id = pred_labels[first_pos].item()
                confidence = torch.max(probs[first_pos]).item()

            predicted_label = self.id_to_label[predicted_label_id]

            if predicted_label != 'O' and confidence > confidence_threshold:
                if predicted_label.startswith('B-'):
                    if current_entity:
                        entities.append(current_entity)

                    entity_type = predicted_label[2:]
                    current_entity = {
                        'entity_type': entity_type,
                        'entity_value': tokens[word_id],
                        'start_pos': len(' '.join(tokens[:word_id])) + (1 if word_id > 0 else 0),
                        'end_pos': len(' '.join(tokens[:word_id + 1])) + (1 if word_id > 0 else 0),
                        'confidence': confidence,
                        'category': self._get_entity_category(entity_type)
                    }
                elif predicted_label.startswith('I-') and current_entity:
                    entity_type = predicted_label[2:]
                    if entity_type == current_entity['entity_type']:
                        current_entity['entity_value'] += ' ' + tokens[word_id]
                        current_entity['end_pos'] = len(' '.join(tokens[:word_id + 1])) + (1 if word_id > 0 else 0)
                        current_entity['confidence'] = min(current_entity['confidence'], confidence)

        if current_entity:
            entities.append(current_entity)

        # Separate by category
        technical_iocs = [e for e in entities if e['category'] == 'technical_ioc']
        semantic_entities = [e for e in entities if e['category'] == 'semantic_entity']

        return {
            'technical_iocs': technical_iocs,
            'semantic_entities': semantic_entities,
            'text': text,
            'total_entities': len(entities)
        }

    def _get_entity_category(self, entity_type: str) -> str:
        """Categorize entity types"""
        technical_types = {'IP', 'DOMAIN', 'URL', 'FILE', 'EMAIL', 'VULNERABILITY'}
        if entity_type in technical_types:
            return 'technical_ioc'
        else:
            return 'semantic_entity'




In [7]:
def main():
    """Main training function"""

    dataset_path = "balanced_ioc_dataset.json"

    if not os.path.exists(dataset_path):
        logger.error(f"Dataset file {dataset_path} not found!")
        return

    logger.info("=" * 80)
    logger.info("Training Enhanced IOC Extraction Model with New Dataset Format")
    logger.info("=" * 80)

    try:
        trainer = EnhancedIOCModelTrainer(model_name="microsoft/deberta-v3-base")

        # Train with CRF and Focal Loss (best for imbalanced data)
        trainer.train(
            dataset_path=dataset_path,
            output_dir="enhanced_ioc_model_v2",
            test_size=0.2,
            batch_size=4,
            num_epochs=10,  # Changed from 6 to 10
            learning_rate=3e-5,  # Changed from 2e-5 to 3e-5
            use_crf=True,
            loss_type='focal'
        )

        # Test the trained model
        logger.info("\n" + "=" * 80)
        logger.info("Testing Model Inference")
        logger.info("=" * 80)

        inferencer = EnhancedIOCModelInference("enhanced_ioc_model_v2")

        # Test with sample text from new dataset format
        test_text = """
        This analysis explores the infrastructure of Laundry Bear, a Russian state-sponsored APT group active since April 2024,
        targeting NATO countries and Ukraine. The investigation reveals connections to IP addresses 154.216.18.83 and 104.36.83.170.
        Key domains include aficors.com, aoc-gov.us, and app-v4-mybos.com. The malware uses CVE-2024-38196 for privilege escalation.
        The attack leverages Windows components and targets Microsoft systems with malicious executables.
        """

        results = inferencer.extract_iocs(test_text, confidence_threshold=0.6)

        print(f"\n{'='*80}")
        print(f"EXTRACTION RESULTS")
        print(f"{'='*80}")
        print(f"Total entities extracted: {results['total_entities']}")
        print(f"Technical IOCs: {len(results['technical_iocs'])}")
        print(f"Semantic Entities: {len(results['semantic_entities'])}")

        if results['technical_iocs']:
            print(f"\n{'Technical IOCs':^80}")
            print("-" * 80)
            for ioc in results['technical_iocs']:
                print(f"  Type: {ioc['entity_type']:<15} | Value: {ioc['entity_value']:<40} | Conf: {ioc['confidence']:.3f}")

        if results['semantic_entities']:
            print(f"\n{'Semantic Entities':^80}")
            print("-" * 80)
            for entity in results['semantic_entities']:
                print(f"  Type: {entity['entity_type']:<15} | Value: {entity['entity_value']:<40} | Conf: {entity['confidence']:.3f}")

        print(f"\n{'='*80}")
        logger.info("Training and testing completed successfully!")

    except Exception as e:
        logger.error(f"Process failed: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Epoch,Training Loss,Validation Loss
1,1.1842,0.280259
2,0.5793,0.180486
3,0.3836,0.107141
4,0.2307,0.078622
5,0.1658,0.059735
6,0.1182,0.064203
7,0.0974,0.056653
8,0.0706,0.054517
9,0.0609,0.065021
10,0.0549,0.069033



Classification Report:
                 precision    recall  f1-score   support

       B-DEVICE       0.84      0.76      0.80        99
       B-DOMAIN       0.95      0.81      0.88       156
         B-FILE       0.00      0.00      0.00         1
     B-FUNCTION       0.98      0.97      0.98       343
           B-IP       0.80      0.80      0.80         5
      B-MALWARE       0.94      0.88      0.91       352
        B-OTHER       0.97      0.77      0.86       515
     B-PLATFORM       0.98      0.95      0.97       130
     B-SOFTWARE       0.72      0.41      0.52        32
  B-THREATACTOR       0.90      0.76      0.82        37
         B-TYPE       0.99      0.84      0.91       582
          B-URL       0.98      0.76      0.85       275
       B-VENDOR       0.94      0.86      0.90       334
      B-VERSION       0.85      0.50      0.63       114
B-VULNERABILITY       0.89      0.84      0.86       118
       I-DEVICE       0.00      0.00      0.00        18
     I

In [8]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [9]:
import shutil

source = "enhanced_ioc_model_v2"
destination = "/content/drive/MyDrive/enhanced_ioc_model_v2"

shutil.copytree(source, destination)
print(f"Model copied to: {destination}")

Model copied to: /content/drive/MyDrive/enhanced_ioc_model_v2
