In [4]:
import os
import sys
import json
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
import glob 
from typing import Optional


In [5]:

def load_batch(batch_file: str, batch_size: int = 32, device: torch.device = torch.device('cpu')) -> Optional[DataLoader]:
    """Load a batch file, ensure labels are set, and return a DataLoader."""
    if not os.path.exists(batch_file):
        print(f"Batch file not found: {batch_file}")
        return None

    batch = torch.load(batch_file)
    processed = []

    for graph in batch:
        try:
            # Assign binary label (malware = 1, benign = 0) based on 'family'
            family = getattr(graph, 'family', None)
            graph.y = torch.tensor(1 if family else 0, dtype=torch.long).to(device)

            # Move graph attributes to the target device
            graph.x = graph.x.to(device)
            graph.edge_index = graph.edge_index.to(device)
            graph.edge_attr = graph.edge_attr.to(device)
            processed.append(graph)
        except Exception as e:
            print(f"Error processing graph: {str(e)}")
            continue

    if not processed:
        return None
    return DataLoader(processed, batch_size=batch_size, shuffle=True)


In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

base_dir = '/data/saranyav/gcn_new/bodmas_batches_new'
batch_files = {split: sorted(glob.glob(os.path.join(base_dir, split, 'batch_*.pt'))) 
                for split in ['train', 'val', 'test']}
print(f"Found {len(batch_files['train'])} training batches and {len(batch_files['val'])} validation batches.")

# Use load_batch instead of directly loading
train_loader = load_batch(batch_files['train'][0], batch_size=32, device=device)
val_loader = load_batch(batch_files['val'][0], batch_size=32, device=device)
test_loader = load_batch(batch_files['test'][0], batch_size=32, device=device)

Using device: cuda
Found 352 training batches and 76 validation batches.


  batch = torch.load(batch_file)
  batch = torch.load(batch_file)


In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import DataLoader
from collections import defaultdict
import numpy as np
from tqdm import tqdm


# GNN model definition
class MalwareGNN(nn.Module):
    def __init__(self, num_features, hidden_dim=64):
        super().__init__()
        # GCN layers
        self.conv1 = GCNConv(num_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)
        # Output for binary classification
        self.classifier = nn.Linear(hidden_dim, 2)  # 2 classes: benign and malware

    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        # GCN layers with ReLU activation
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = F.relu(self.conv3(x, edge_index))
        # Global mean pooling
        x = global_mean_pool(x, batch)
        # Classification output
        return self.classifier(x)
    
def inspect_data_distribution(loader):
    """Analyze label distribution and feature statistics in the data"""
    label_counts = defaultdict(int)
    feature_stats = defaultdict(list)
    graph_sizes = []
    edge_counts = []
    
    print("\nInspecting data distribution...")
    for batch in loader:
        # Check labels
        for label in batch.y.cpu().numpy():
            label_counts[int(label)] += 1
            
        # Check features
        features = batch.x.cpu().numpy()
        feature_stats['mean'].append(np.mean(features))
        feature_stats['std'].append(np.std(features))
        feature_stats['min'].append(np.min(features))
        feature_stats['max'].append(np.max(features))
        
        # Check graph structure
        graphs = batch.to_data_list()
        for graph in graphs:
            graph_sizes.append(graph.x.shape[0])  # number of nodes
            edge_counts.append(graph.edge_index.shape[1])  # number of edges
    
    print("\nLabel Distribution:")
    total_samples = sum(label_counts.values())
    for label, count in label_counts.items():
        print(f"Label {label}: {count} samples ({100 * count / total_samples:.2f}%)")
    
    print("\nFeature Statistics:")
    for stat, values in feature_stats.items():
        print(f"{stat.capitalize()}: {np.mean(values):.4f}")
    
    print("\nGraph Structure Statistics:")
    print(f"Average nodes per graph: {np.mean(graph_sizes):.2f}")
    print(f"Average edges per graph: {np.mean(edge_counts):.2f}")
    return label_counts, feature_stats

def check_model_predictions(model, loader, device):
    """Analyze model's prediction patterns"""
    model.eval()
    all_preds = []
    all_labels = []
    all_confidences = []
    
    print("\nAnalyzing model predictions...")
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            out = model(batch)
            probs = F.softmax(out, dim=1)
            preds = out.argmax(dim=1)
            
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(batch.y.cpu().numpy())
            all_confidences.extend(probs.max(dim=1)[0].cpu().numpy())
    
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    all_confidences = np.array(all_confidences)
    
    print("\nPrediction Statistics:")
    print(f"Average confidence: {np.mean(all_confidences):.4f}")
    print(f"Min confidence: {np.min(all_confidences):.4f}")
    print(f"Max confidence: {np.max(all_confidences):.4f}")
    
    print("\nPrediction Distribution:")
    unique_preds, pred_counts = np.unique(all_preds, return_counts=True)
    for pred, count in zip(unique_preds, pred_counts):
        print(f"Class {pred}: {count} predictions ({100 * count / len(all_preds):.2f}%)")
    
    return all_preds, all_labels, all_confidences

def check_graph_uniqueness(loader):
    """Check if there are duplicate graphs in the dataset"""
    print("\nChecking for duplicate graphs...")
    graph_hashes = set()
    duplicates = 0
    total = 0
    
    for batch in loader:
        graphs = batch.to_data_list()
        for graph in graphs:
            # Create a hash of the graph's key properties
            features_hash = hash(graph.x.cpu().numpy().tobytes())
            edges_hash = hash(graph.edge_index.cpu().numpy().tobytes())
            graph_hash = (features_hash, edges_hash)
            
            if graph_hash in graph_hashes:
                duplicates += 1
            graph_hashes.add(graph_hash)
            total += 1
    
    print(f"Found {duplicates} duplicate graphs out of {total} total graphs")
    print(f"Duplicate percentage: {100 * duplicates / total:.2f}%")
    return duplicates, total

def inspect_family_distribution(loader):
    """Analyze the distribution of malware families"""
    family_counts = defaultdict(int)
    print("\nAnalyzing malware family distribution...")
    
    for batch in loader:
        graphs = batch.to_data_list()
        for graph in graphs:
            if hasattr(graph, 'family'):
                family_counts[graph.family] += 1
    
    print("\nFamily Distribution:")
    total_samples = sum(family_counts.values())
    for family, count in sorted(family_counts.items()):
        print(f"Family {family}: {count} samples ({100 * count / total_samples:.2f}%)")
    
    return family_counts

# Example usage:
def run_diagnostics(model, train_loader, val_loader, device):
    print("=== Training Data Analysis ===")
    train_label_counts, train_feature_stats = inspect_data_distribution(train_loader)
    train_duplicates, train_total = check_graph_uniqueness(train_loader)
    train_family_dist = inspect_family_distribution(train_loader)
    train_preds, train_labels, train_conf = check_model_predictions(model, train_loader, device)
    
    print("\n=== Validation Data Analysis ===")
    val_label_counts, val_feature_stats = inspect_data_distribution(val_loader)
    val_duplicates, val_total = check_graph_uniqueness(val_loader)
    val_family_dist = inspect_family_distribution(val_loader)
    val_preds, val_labels, val_conf = check_model_predictions(model, val_loader, device)
    
    # Additional cross-dataset analysis
    print("\n=== Cross-Dataset Analysis ===")
    train_families = set(train_family_dist.keys())
    val_families = set(val_family_dist.keys())
    overlapping_families = train_families.intersection(val_families)
    print(f"Number of overlapping families: {len(overlapping_families)}")
    print(f"Families only in training: {len(train_families - val_families)}")
    print(f"Families only in validation: {len(val_families - train_families)}")

# To use this code in your notebook:

# First load your model and data loaders as in your original code
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loader = load_batch(batch_files['train'][0], batch_size=32, device=device)
val_loader = load_batch(batch_files['val'][0], batch_size=32, device=device)

# Get number of features from the first batch
first_batch = next(iter(train_loader))
num_features = first_batch.x.size(1)
model = MalwareGNN(num_features).to(device)

# Then run the diagnostics
run_diagnostics(model, train_loader, val_loader, device)


  batch = torch.load(batch_file)


=== Training Data Analysis ===

Inspecting data distribution...

Label Distribution:
Label 1: 100 samples (100.00%)

Feature Statistics:
Mean: 0.1829
Std: 1.2584
Min: 0.0000
Max: 995.2500

Graph Structure Statistics:
Average nodes per graph: 9191.79
Average edges per graph: 12438.68

Checking for duplicate graphs...
Found 16 duplicate graphs out of 100 total graphs
Duplicate percentage: 16.00%

Analyzing malware family distribution...

Family Distribution:
Family autoit: 6 samples (6.00%)
Family cambot: 1 samples (1.00%)
Family ceeinject: 3 samples (3.00%)
Family delfiles: 1 samples (1.00%)
Family dinwod: 1 samples (1.00%)
Family ditertag: 1 samples (1.00%)
Family dofoil: 2 samples (2.00%)
Family fearso: 1 samples (1.00%)
Family fuerboos: 5 samples (5.00%)
Family gandcrab: 2 samples (2.00%)
Family gepys: 1 samples (1.00%)
Family gupboot: 2 samples (2.00%)
Family hpgandcrab: 1 samples (1.00%)
Family juched: 1 samples (1.00%)
Family klez: 1 samples (1.00%)
Family mira: 3 samples (3.00%)


In [14]:
import pandas as pd
from pathlib import Path
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def analyze_metadata_pipeline():
    """Analyze each step of the metadata processing pipeline to find where benign samples might be lost."""
    
    # 1. First check the primary metadata
    logger.info("Reading primary metadata...")
    primary_df = pd.read_csv('bodmas_metadata_cleaned.csv')
    logger.info(f"\nPrimary metadata shape: {primary_df.shape}")
    logger.info("\nColumns in primary metadata:")
    logger.info(primary_df.columns.tolist())
    
    if 'family' in primary_df.columns:
        logger.info("\nFamily distribution in primary metadata:")
        logger.info(primary_df['family'].value_counts(dropna=False))
        logger.info(f"\nNull families: {primary_df['family'].isnull().sum()}")
    
    # 2. Check malware types metadata
    logger.info("\nReading malware types metadata...")
    malware_types_df = pd.read_csv('bodmas_malware_category.csv')
    logger.info(f"\nMalware types metadata shape: {malware_types_df.shape}")
    logger.info("\nColumns in malware types metadata:")
    logger.info(malware_types_df.columns.tolist())
    
    # 3. Check available files
    data_dir = Path('cfg_analysis_results/cfg_analysis_results')
    available_files = list(data_dir.glob('*.json.gz'))
    available_shas = [f.stem.replace('.json', '') for f in available_files]
    logger.info(f"\nFound {len(available_files)} files in data directory")
    
    # 4. Simulate the merge and filtering process
    logger.info("\nSimulating metadata processing pipeline...")
    
    # Extract filename without extension for joining
    primary_df['filename'] = primary_df['sha'].apply(lambda x: x)
    malware_types_df['filename'] = malware_types_df['sha256'].apply(lambda x: Path(x).stem)
    
    # First merge
    merged_df = pd.merge(
        primary_df,
        malware_types_df[['filename', 'category']],
        on='filename',
        how='left'
    )
    logger.info(f"\nAfter initial merge shape: {merged_df.shape}")
    
    # Check family distribution after merge
    if 'family' in merged_df.columns:
        logger.info("\nFamily distribution after merge:")
        logger.info(merged_df['family'].value_counts(dropna=False))
    
    # Filter to available files
    filtered_df = merged_df[merged_df['filename'].isin(available_shas)]
    logger.info(f"\nAfter filtering to available files shape: {filtered_df.shape}")
    
    if 'family' in filtered_df.columns:
        logger.info("\nFamily distribution after filtering:")
        logger.info(filtered_df['family'].value_counts(dropna=False))
    
    return primary_df, malware_types_df, merged_df, filtered_df

def check_file_contents(sha_list, data_dir='cfg_analysis_results/cfg_analysis_results'):
    """Check the actual contents of a few files to verify their structure."""
    import gzip
    import json
    
    logger.info("\nChecking sample file contents...")
    for sha in sha_list[:5]:  # Check first 5 files
        filepath = Path(data_dir) / f"{sha}.json.gz"
        if filepath.exists():
            try:
                with gzip.open(filepath, 'rt') as f:
                    data = json.load(f)
                    logger.info(f"\nFile: {sha}")
                    logger.info(f"Keys in file: {list(data.keys())}")
                    if 'graph_structure' in data:
                        logger.info(f"Graph structure keys: {list(data['graph_structure'].keys())}")
            except Exception as e:
                logger.error(f"Error reading file {sha}: {str(e)}")

# To use:

primary_df, malware_types_df, merged_df, filtered_df = analyze_metadata_pipeline()

# Check some sample files
sample_shas = filtered_df['sha'].head().tolist()
check_file_contents(sample_shas)


INFO:__main__:Reading primary metadata...


FileNotFoundError: [Errno 2] No such file or directory: 'bodmas_metadata_cleaned.csv'

In [2]:
# inspect /bodmas_batches/train/batch_0139.pt 

file = '/data/saranyav/gcn_new/bodmas_batches/train/batch_0147.pt'


In [21]:
def check_file_existence():
    # Read metadata
    print("Reading metadata file...")
    df = pd.read_csv('/data/saranyav/gcn_new/bodmas_metadata_cleaned.csv')
    
    # Check a single file first
    first_sha = df['sha'].iloc[0]
    first_path = f'/data/datasets/bodmas_exes/refanged_exes/{first_sha}_refang.exe'
    print(f"\nChecking first file:")
    print(f"SHA: {first_sha}")
    print(f"Full path: {first_path}")
    print(f"Exists: {os.path.exists(first_path)}")
    
    # List some files in the directory
    exe_dir = '/data/datasets/bodmas_exes/refanged_exes/'
    if os.path.exists(exe_dir):
        print("\nFirst few files in directory:")
        for file in os.listdir(exe_dir)[:5]:
            print(file)

In [15]:
batch

NameError: name 'batch' is not defined

In [14]:
def prepare_group_mappings(json_path):
    """Convert JSON group mappings to use integer group IDs."""
    with open(json_path, 'r') as f:
        raw_data = json.load(f)
    
    # Convert string group IDs to integers
    family_to_group = {}
    for group_id_str, families in raw_data.items():
        group_id = int(group_id_str)
        for family in families:
            family_to_group[family] = group_id
            
    return family_to_group

# Example usage
json_path = '/data/saranyav/gcn_new/behavioral_analysis/behavioral_groups.json'
family_to_group = prepare_group_mappings(json_path)
file_path = "/data/saranyav/gcn_new/bodmas_batches/train/batch_0147.pt"
batch = load_batch(file_path, family_to_group, batch_size=32)

  batch_data = torch.load(batch_file)


In [16]:
family_to_group

{'dapato': 2,
 'coinminer': 2,
 'grandcrab': 2,
 'upatre': 2,
 'unruy': 2,
 'plite': 2,
 'autorun': 2,
 'urelas': 2,
 'small': 2,
 'padodor': 2,
 'fuery': 2,
 'trickbot': 2,
 'wacatac': 2,
 'skeeyah': 2,
 'lunam': 2,
 'gc13003b': 2,
 'vflooder': 2,
 'gupboot': 2,
 'omaneat': 2,
 'fuerboos': 2,
 'skeeeyah': 2,
 'qqpass': 2,
 'ceeinject': 2,
 'autoit': 2,
 'gepys': 2,
 'tescrypt': 2,
 'blocker': 2,
 'vbclone': 2,
 'nanocore': 2,
 'dorv': 2,
 'tinba': 2,
 'fukru': 2,
 'glupteba': 2,
 'shifu': 2,
 'delf': 2,
 'ditertag': 2,
 'aenjaris': 2,
 'bladabindi': 2,
 'pkeylog': 2,
 'nanobot': 2,
 'emelent': 2,
 'neconyd': 2,
 'injector': 2,
 'kirts': 2,
 'bho': 2,
 'dynamer': 2,
 'smokeloader': 2,
 'tiggre': 2,
 'trickybot': 2,
 'revetrat': 2,
 'nitol': 2,
 'banker': 2,
 'systex': 2,
 'noancooe': 2,
 'mocrt': 2,
 'sillyfdc': 2,
 'shipup': 2,
 'sakurel': 2,
 'simbot': 2,
 'ursnif': 2,
 'autinject': 2,
 'plugx': 2,
 'occamy': 2,
 'quasarrat': 2,
 'Unknown': 2,
 'vbinject': 2,
 'ircbot': 2,
 'cryptomi

In [6]:
# # check if each item in this file has edge_attr
# # Example content:
# # Data(x=[6420, 14], edge_index=[2, 6655], edge_attr=[2072, 1], num_nodes=6420, sha='24a686ae67f97fbe479bb4c49f1ee2b6c7bd5bd7c0b6d59767dd14d36da14ca2', timestamp='2019-08-31 22:39:01 UTC', family='vflooder', malware_type='trojan')

# print("\nChecking for edge_attr in each item:")

# if isinstance(data, dict):
#     for key, value in data.items():
#         if hasattr(value, 'edge_attr'):
#             print(f"Key: {key}, edge_attr shape: {value.edge_attr.shape}")
#         else:
#             print(f"Key: {key}, edge_attr not found")


Checking for edge_attr in each item:


In [10]:
batch_data = torch.load('/data/saranyav/gcn_new/bodmas_batches/train/batch_0157.pt')
# for i, graph in enumerate(batch_data):
#     print(f"Graph {i}:")
#     print(f"x shape: {getattr(graph, 'x', None)}")
#     print(f"edge_index shape: {getattr(graph, 'edge_index', None)}")
#     print(f"edge_attr shape: {getattr(graph, 'edge_attr', None)}")


# Inspect the batch content
for i, graph in enumerate(batch):
    print(f"Graph {i}:")
    print(f"  x shape: {graph.x.shape}")
    print(f"  edge_index shape: {graph.edge_index.shape}")
    print(f"  edge_attr shape: {graph.edge_attr.shape if graph.edge_attr is not None else 'None'}")
    print(f"  Family: {getattr(graph, 'family', None)}")
    print(f"  y: {getattr(graph, 'y', None)}")
    print(f"  Is y None? {'Yes' if getattr(graph, 'y', None) is None else 'No'}")
    print("-" * 50)

Graph 0:
  x shape: torch.Size([62829, 14])
  edge_index shape: torch.Size([2, 91463])
  edge_attr shape: torch.Size([30827, 1])
  Family: agent
  y: None
  Is y None? Yes
--------------------------------------------------
Graph 1:
  x shape: torch.Size([3262, 14])
  edge_index shape: torch.Size([2, 4614])
  edge_attr shape: torch.Size([2435, 1])
  Family: small
  y: None
  Is y None? Yes
--------------------------------------------------
Graph 2:
  x shape: torch.Size([1585, 14])
  edge_index shape: torch.Size([2, 1411])
  edge_attr shape: torch.Size([321, 1])
  Family: upatre
  y: None
  Is y None? Yes
--------------------------------------------------
Graph 3:
  x shape: torch.Size([3262, 14])
  edge_index shape: torch.Size([2, 4614])
  edge_attr shape: torch.Size([2435, 1])
  Family: small
  y: None
  Is y None? Yes
--------------------------------------------------
Graph 4:
  x shape: torch.Size([330, 14])
  edge_index shape: torch.Size([2, 437])
  edge_attr shape: torch.Size([256

  batch_data = torch.load('/data/saranyav/gcn_new/bodmas_batches/train/batch_0157.pt')


In [10]:
import torch



# Function to check for family='none'
def check_family_none(data):
    if isinstance(data, dict):
        for key, value in data.items():
            if isinstance(value, dict) or isinstance(value, list):
                # Recursively inspect nested dictionaries or lists
                result = check_family_none(value)
                if result:
                    return result
            elif hasattr(value, 'family') and getattr(value, 'family', None) == 'none':
                return True
    elif isinstance(data, list):
        for item in data:
            if isinstance(item, dict) or isinstance(item, list):
                # Recursively inspect nested lists or dictionaries
                result = check_family_none(item)
                if result:
                    return result
            elif hasattr(item, 'family') and getattr(item, 'family', None) == 'none':
                return True
    elif hasattr(data, 'family') and getattr(data, 'family', None) == 'none':
        return True
    return False

# Check for family='none'
found = check_family_none(data)
if found:
    print("A key or attribute with family='none' exists in the data.")
else:
    print("No key or attribute with family='none' was found in the data.")

No key or attribute with family='none' was found in the data.


In [8]:
import glob
from pathlib import Path

batch_files = glob.glob('/data/saranyav/gcn_new/bodmas_batches/train/batch_*.pt')

for batch_file in batch_files:
    #print(f"Inspecting {batch_file}...")
    batch = torch.load(batch_file)

    for i, graph in enumerate(batch):
        #print(f"Graph {i}: edge_attr exists? {'edge_attr' in graph}")
        if not hasattr(graph, 'edge_attr') or graph.edge_attr is None:
            print(f"Graph {i} is missing edge_attr. Assigning default value.")

  batch = torch.load(batch_file)


In [None]:
import os
import torch

# Directory containing the .pt files
directory =  "/data/saranyav/gcn_new/bodmas_batches/train"


# Function to check if family='none' exists in a data object
def has_family_none(data):
    if isinstance(data, dict):
        for key, value in data.items():
            if isinstance(value, (dict, list)):
                if has_family_none(value):
                    return True
            elif hasattr(value, 'family') and getattr(value, 'family', None) == 'none':
                return True
    elif isinstance(data, list):
        for item in data:
            if isinstance(item, (dict, list)):
                if has_family_none(item):
                    return True
            elif hasattr(item, 'family') and getattr(item, 'family', None) == 'none':
                return True
    elif hasattr(data, 'family') and getattr(data, 'family', None) == 'none':
        return True
    return False

# Iterate through all .pt files in the directory
for root, _, files in os.walk(directory):
    for file in files:
        if file.endswith(".pt"):
            file_path = os.path.join(root, file)
            try:
                # Load the .pt file
                data = torch.load(file_path)
                # Check for family='none'
                if has_family_none(data):
                    print(f"Found family='none' in file: {file_path}")
            except Exception as e:
                print(f"Could not process file {file_path}: {e}")