 **==================== MODEL DEFINITIONS ====================**

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets
import time
import json
import pickle
import os
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np

# **CoverageLayer**

In [None]:
class CoverageLayer(nn.Module):
    """Pass-through layer that remembers the last tensor that flowed through it.

    ``forward`` returns its input unchanged and stores a detached copy, so the
    stored activations can be inspected afterwards without affecting autograd.
    """

    def __init__(self):
        super(CoverageLayer, self).__init__()
        # Non-persistent on purpose: the recorded activations are per-batch
        # scratch data whose shape follows the last input. Keeping them out of
        # state_dict() avoids bloated checkpoints and avoids "unexpected key"
        # failures when a checkpoint is loaded into a freshly built model
        # (whose buffer is still None).
        self.register_buffer('activations', None, persistent=False)
        # Number of equal-width sections used by compute_section_distribution.
        self.num_sections = 10

    def forward(self, x):
        """Record a detached copy of ``x`` and return ``x`` itself."""
        self.activations = x.detach()
        return x

    def get_neuron_outputs(self):
        """Return the activations recorded by the most recent forward pass (or None)."""
        return self.activations

    def compute_section_distribution(self, min_val, max_val):
        """Histogram the recorded activations into ``num_sections`` sections.

        Args:
            min_val: Lower bound(s); must broadcast against the recorded activations.
            max_val: Upper bound(s); must broadcast against the recorded activations.

        Returns:
            1-D float tensor of length ``num_sections`` holding the fraction of
            recorded activation elements that fall into each section.

        Raises:
            RuntimeError: If no forward pass has been recorded yet.
        """
        activations = self.activations
        if activations is None:
            raise RuntimeError(
                "CoverageLayer has no recorded activations; run a forward pass first"
            )

        # Section width; a zero width (min == max) is replaced by a tiny value
        # so the division below is defined.
        delta = (max_val - min_val) / (self.num_sections - 1)
        delta = torch.where(delta == 0, torch.tensor(1e-10, device=delta.device), delta)

        # Values outside [min_val, max_val] are clamped into the first/last section.
        section_indices = torch.clamp(((activations - min_val) / delta).long(), 0, self.num_sections - 1)
        section_counts = torch.bincount(section_indices.view(-1), minlength=self.num_sections)
        section_distributions = section_counts.float() / activations.numel()
        return section_distributions

# **VGG19WithCoverage**

In [None]:
class VGG19WithCoverage(nn.Module):
    """torchvision VGG19 re-wired so that CoverageLayer taps record activations.

    The individual modules of ``models.vgg19(pretrained=True)`` are re-exposed
    as named attributes (``conv1_1`` ... ``pool5``, ``fc1`` ... ``fc3``) and 18
    ``CoverageLayer`` instances are interleaved in ``forward``: one after every
    convolution stage (after the pooling step for the last convolution of a
    block) and one after each of the first two fully connected layers.
    """

    # (attribute name, index into vgg19().features), listed in registration order.
    _FEATURE_SLOTS = (
        ('conv1_1', 0), ('conv1_2', 2), ('pool1', 4),
        ('conv2_1', 5), ('conv2_2', 7), ('pool2', 9),
        ('conv3_1', 10), ('conv3_2', 12), ('conv3_3', 14), ('conv3_4', 16), ('pool3', 18),
        ('conv4_1', 19), ('conv4_2', 21), ('conv4_3', 23), ('conv4_4', 25), ('pool4', 27),
        ('conv5_1', 28), ('conv5_2', 30), ('conv5_3', 32), ('conv5_4', 34), ('pool5', 36),
    )

    # (attribute name, index into vgg19().classifier).
    _CLASSIFIER_SLOTS = (('fc1', 0), ('fc2', 3), ('fc3', 6))

    # Per block: the convolution attributes in order, then the pooling attribute.
    _STAGES = (
        (('conv1_1', 'conv1_2'), 'pool1'),
        (('conv2_1', 'conv2_2'), 'pool2'),
        (('conv3_1', 'conv3_2', 'conv3_3', 'conv3_4'), 'pool3'),
        (('conv4_1', 'conv4_2', 'conv4_3', 'conv4_4'), 'pool4'),
        (('conv5_1', 'conv5_2', 'conv5_3', 'conv5_4'), 'pool5'),
    )

    def __init__(self):
        super().__init__()
        backbone = models.vgg19(pretrained=True)

        # Feature layers
        for attr_name, position in self._FEATURE_SLOTS:
            setattr(self, attr_name, backbone.features[position])

        # Classifier layers
        for attr_name, position in self._CLASSIFIER_SLOTS:
            setattr(self, attr_name, backbone.classifier[position])

        # Coverage layers
        self.coverage_layers = nn.ModuleList([CoverageLayer() for _ in range(18)])
        self.avgpool = backbone.avgpool

    def forward(self, x):
        """Run the network on ``x`` while every coverage layer records its input."""
        # Coverage layers are consumed strictly in index order 0..17.
        taps = iter(self.coverage_layers)

        for conv_names, pool_name in self._STAGES:
            *leading, closing = conv_names
            # Every convolution but the last is tapped right after its ReLU.
            for conv_name in leading:
                x = F.relu(getattr(self, conv_name)(x))
                x = next(taps)(x)
            # The last convolution of a block is tapped after pooling.
            x = F.relu(getattr(self, closing)(x))
            x = getattr(self, pool_name)(x)
            x = next(taps)(x)

        # Classifier
        x = torch.flatten(self.avgpool(x), 1)
        for dense in (self.fc1, self.fc2):
            x = F.relu(dense(x))
            x = next(taps)(x)
            x = F.dropout(x, p=0.5, training=self.training)

        return self.fc3(x)

# **==================== DATA LOGGING ====================**

In [None]:
class TrainingDataLogger:
    """Collects per-epoch metrics in a nested dict and writes them to disk.

    ``self.data`` has three top-level keys: ``metadata`` (start/end timestamps
    and epoch count), ``epochs`` (one entry per ``add_epoch_data`` call) and
    ``config`` (left empty here for the caller to fill in).
    """

    def __init__(self):
        started_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        metadata = {
            'start_time': started_at,
            'end_time': None,
            'total_epochs': 0
        }
        self.data = {'metadata': metadata, 'epochs': [], 'config': {}}

    def add_epoch_data(self, epoch, train_metrics, test_metrics, coverage_stats):
        """Append one epoch record.

        Args:
            epoch: Epoch identifier stored verbatim.
            train_metrics: Training metrics stored verbatim under ``'train'``.
            test_metrics: Dict with ``'trusted'`` and ``'untrusted'`` entries,
                each carrying a ``'coverage_stats'`` list.
            coverage_stats: Accepted but not read; the per-layer summary is
                built from ``test_metrics`` instead.
        """
        trusted = test_metrics['trusted']
        untrusted = test_metrics['untrusted']

        # Per-layer summary, keyed 'layer_1', 'layer_2', ... by list position.
        per_layer = {
            f'layer_{position + 1}': {
                'trusted_coverage': stats['coverage_percentage'],
                'untrusted_coverage': untrusted['coverage_stats'][position]['coverage_percentage'],
                'distribution_similarity': stats['distribution_similarity']
            }
            for position, stats in enumerate(trusted['coverage_stats'])
        }

        self.data['epochs'].append({
            'epoch': epoch,
            'train': train_metrics,
            'test': {
                'trusted': trusted,
                'untrusted': untrusted
            },
            'coverage': per_layer
        })

    def finalize(self, total_epochs):
        """Stamp the end time and record how many epochs were run."""
        metadata = self.data['metadata']
        metadata['end_time'] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        metadata['total_epochs'] = total_epochs

    def save(self, filename):
        """Write ``self.data`` to ``<filename>.json`` and ``<filename>.pkl``."""
        json_filename = f"{filename}.json"
        pkl_filename = f"{filename}.pkl"

        # Human-readable copy.
        with open(json_filename, 'w') as json_file:
            json.dump(self.data, json_file, indent=4)

        # Pickled copy of the same dict.
        with open(pkl_filename, 'wb') as pkl_file:
            pickle.dump(self.data, pkl_file)

        print(f"Saved training data to {json_filename} and {pkl_filename}")

    @staticmethod
    def load(filename):
        """Return the dict stored in ``<filename>.pkl``.

        NOTE(review): unpickling executes arbitrary code from the file; only
        load files this project wrote itself.
        """
        with open(f"{filename}.pkl", 'rb') as pkl_file:
            return pickle.load(pkl_file)


**==================== TRAINING COMPONENTS ====================**

# **SignatureGenerator**

In [None]:
class SignatureGenerator:
    """Builds per-coverage-layer activation signatures from a trusted loader.

    A signature is a dict with the element-wise ``'min'`` and ``'max'`` of the
    recorded activations (reduced over dimension 0 of every batch) and, for
    ``method == 'mrc'``, a ``'distributions'`` histogram over the layer's
    sections.
    """

    def __init__(self, model, coverage_layer_indices, method):
        """
        Args:
            model: Module exposing ``coverage_layers``; each entry must provide
                ``get_neuron_outputs()`` and ``compute_section_distribution()``.
            coverage_layer_indices: Indices into ``model.coverage_layers`` to profile.
            method: ``'mrc'`` additionally computes section distributions; any
                other value yields min/max only.
        """
        self.model = model
        self.coverage_layer_indices = coverage_layer_indices
        self.method = method

    def aggregate_signatures(self, trusted_loader, device):
        """Profile ``trusted_loader`` and return one signature per coverage layer.

        Args:
            trusted_loader: Re-iterable of ``(images, labels)`` batches; labels
                are ignored. For ``'mrc'`` it is iterated twice.
            device: Device the image batches are moved to before the forward pass.

        Returns:
            List with one slot per entry of ``model.coverage_layers``; slots not
            listed in ``coverage_layer_indices`` stay ``None``.

        Raises:
            ValueError: If the loader yields no batches (or, for ``'mrc'``, yields
                none on the second pass).
        """
        self.model.eval()
        signatures = [None] * len(self.model.coverage_layers)

        with torch.no_grad():
            # Pass 1: running min/max. Signatures are created lazily from the
            # first batch so no batch is pushed through the model twice and no
            # extra loader iterator is created.
            saw_batch = False
            for images, _ in trusted_loader:
                saw_batch = True
                self.model(images.to(device))

                for i in self.coverage_layer_indices:
                    activations = self.model.coverage_layers[i].get_neuron_outputs()
                    current_min = activations.min(dim=0)[0]
                    current_max = activations.max(dim=0)[0]

                    entry = signatures[i]
                    if entry is None:
                        signatures[i] = {
                            'min': current_min,
                            'max': current_max,
                            'distributions': None
                        }
                    else:
                        entry['min'] = torch.min(entry['min'], current_min)
                        entry['max'] = torch.max(entry['max'], current_max)

            if not saw_batch:
                raise ValueError("trusted_loader yielded no batches; cannot build signatures")

            if self.method == 'mrc':
                self._aggregate_distributions(trusted_loader, device, signatures)

        return signatures

    def _aggregate_distributions(self, trusted_loader, device, signatures):
        """Fill ``'distributions'`` with the section histogram over ALL batches.

        The final min/max are only known after the first pass, so the histogram
        needs its own pass. Each batch's distribution is weighted by its element
        count, which makes the result the histogram of the whole loader rather
        than of whichever batch happened to be processed last.
        """
        weighted_sums = {}
        element_totals = {}

        for images, _ in trusted_loader:
            self.model(images.to(device))

            for i in self.coverage_layer_indices:
                layer = self.model.coverage_layers[i]
                batch_elements = layer.get_neuron_outputs().numel()
                batch_dist = layer.compute_section_distribution(
                    signatures[i]['min'],
                    signatures[i]['max']
                )
                weighted = batch_dist * batch_elements
                weighted_sums[i] = weighted_sums[i] + weighted if i in weighted_sums else weighted
                element_totals[i] = element_totals.get(i, 0) + batch_elements

        for i in self.coverage_layer_indices:
            if i not in weighted_sums:
                raise ValueError(
                    "trusted_loader yielded no batches on the second pass; "
                    "it must be re-iterable for method 'mrc'"
                )
            signatures[i]['distributions'] = weighted_sums[i] / element_totals[i]

# **ConfidenceLoss**

In [None]:
class ConfidenceLoss(nn.Module):
    """Penalty measuring how far current activations stray from stored signatures.

    For every configured coverage layer the loss adds the fraction of
    activations outside the signature's ``[min, max]`` range (``'src'`` and
    ``'mrc'``) and, for ``'mrc'``, the MSE between the current section
    histogram and the signature's histogram. Any other ``method`` contributes
    nothing, so the result stays 0.

    NOTE(review): the value is built from comparisons, ``.long()`` and
    ``bincount`` on activations that ``CoverageLayer.forward`` stores detached,
    so it does not appear to carry gradient back to the model - confirm whether
    it is meant as a monitored quantity only.
    """

    def __init__(self, signatures, coverage_layer_indices, method):
        super().__init__()
        self.signatures = signatures
        self.coverage_layer_indices = coverage_layer_indices
        self.method = method

        # Cache the per-layer reference tensors, aligned with coverage_layer_indices.
        self.cached_min_vals = [signatures[i]['min'] for i in coverage_layer_indices]
        self.cached_max_vals = [signatures[i]['max'] for i in coverage_layer_indices]
        self.cached_distributions = [
            signatures[i]['distributions'] if method == 'mrc' else None
            for i in coverage_layer_indices
        ]

    @staticmethod
    def _out_of_range_fraction(activations, lower, upper):
        """Fraction of activation elements below ``lower`` or above ``upper``."""
        violations = (activations < lower) | (activations > upper)
        return violations.float().mean()

    @staticmethod
    def _section_histogram(activations, lower, upper, num_sections):
        """Normalised histogram of activations over equal-width sections."""
        width = (upper - lower) / (num_sections - 1)
        # Guard the division when a neuron's min equals its max.
        width = torch.where(width == 0, torch.tensor(1e-10, device=width.device), width)

        bucket = torch.clamp(((activations - lower) / width).long(), 0, num_sections - 1)
        histogram = torch.bincount(bucket.view(-1), minlength=num_sections).float()
        histogram /= activations.numel()
        return histogram

    def forward(self, model):
        """Return the summed penalty for the activations currently stored in ``model``.

        Args:
            model: Module with parameters and a ``coverage_layers`` sequence whose
                entries hold the activations of the most recent forward pass.

        Returns:
            0-dim tensor on the device of the model's first parameter.
        """
        total_loss = torch.tensor(0.0, device=next(model.parameters()).device)

        for slot, layer_idx in enumerate(self.coverage_layer_indices):
            tap = model.coverage_layers[layer_idx]
            activations = tap.get_neuron_outputs()

            lower = self.cached_min_vals[slot].to(activations.device)
            upper = self.cached_max_vals[slot].to(activations.device)

            if self.method not in ('src', 'mrc'):
                continue

            penalty = self._out_of_range_fraction(activations, lower, upper)
            if self.method == 'mrc':
                observed = self._section_histogram(activations, lower, upper, tap.num_sections)
                reference = self.cached_distributions[slot].to(activations.device)
                penalty = penalty + F.mse_loss(observed, reference)

            total_loss += penalty

        return total_loss

**==================== TRAINING AND TESTING ====================**

# **train**

In [None]:
# ==================== TRAINING AND TESTING ====================
def train(model, device, train_loader, optimizer, epoch, confidence_loss_fn, signatures, coverage_layer_indices, confidence_threshold=0.8):
    """Run one training epoch and return its averaged metrics.

    Each batch is optimised on ``cross_entropy + confidence_loss_fn(model)``.
    Progress is printed every 100 batches and once more at the end.

    Args:
        model: Network to train; put into training mode here.
        device: Device the batches are moved to.
        train_loader: Loader of ``(data, target)`` batches; must support ``len()``
            and expose ``.dataset`` (used in the progress line).
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch number, used only in the printed messages.
        confidence_loss_fn: Callable taking the model and returning a scalar tensor.
        signatures: Not used by this function.
        coverage_layer_indices: Not used by this function.
        confidence_threshold: Softmax confidence at or above which a prediction
            counts as "high confidence".

    Returns:
        Dict with ``loss``, ``conf_loss``, ``ce_loss`` (means over batches) and
        ``accuracy``, ``high_conf_accuracy`` (percentages).
    """
    model.train()

    loss_sum = 0.0
    conf_loss_sum = 0.0
    ce_loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    n_confident = 0
    n_confident_correct = 0
    started = time.time()

    for step, (data, target) in enumerate(train_loader):
        data = data.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        output = model(data)
        ce_loss = F.cross_entropy(output, target)
        conf_loss = confidence_loss_fn(model)
        loss = ce_loss + conf_loss
        loss.backward()
        optimizer.step()

        # Bookkeeping only; nothing below should be tracked by autograd.
        with torch.no_grad():
            hits = output.argmax(dim=1).eq(target)
            n_correct += hits.sum().item()
            n_seen += len(data)

            top_prob = F.softmax(output, dim=1).max(dim=1)[0]
            confident = top_prob >= confidence_threshold
            n_confident_correct += (hits & confident).sum().item()
            n_confident += confident.sum().item()

            loss_sum += loss.item()
            conf_loss_sum += conf_loss.item()
            ce_loss_sum += ce_loss.item()

        if step % 100 != 0:
            continue

        elapsed = time.time() - started
        running_acc = 100. * n_correct / n_seen
        # max(1, ...) keeps the ratio defined when nothing was high-confidence yet.
        running_hc_acc = 100. * n_confident_correct / max(1, n_confident)

        print(f'Train Epoch: {epoch} [{step * len(data)}/{len(train_loader.dataset)}] '
              f'Loss: {loss.item():.4f}, CE: {ce_loss.item():.4f}, Conf Loss: {conf_loss.item():.4f}, '
              f'Accuracy: {running_acc:.2f}%, High Conf Acc: {running_hc_acc:.2f}%, Time: {elapsed:.2f}s')

    n_batches = len(train_loader)
    summary = {
        'loss': loss_sum / n_batches,
        'conf_loss': conf_loss_sum / n_batches,
        'ce_loss': ce_loss_sum / n_batches,
        'accuracy': 100. * n_correct / n_seen,
        'high_conf_accuracy': 100. * n_confident_correct / max(1, n_confident)
    }

    print(f'Train Epoch {epoch} Summary: Loss: {summary["loss"]:.4f}, CE: {summary["ce_loss"]:.4f}, '
          f'Conf Loss: {summary["conf_loss"]:.4f}, Accuracy: {summary["accuracy"]:.2f}%, High Conf Acc: {summary["high_conf_accuracy"]:.2f}%')

    return summary

# NOTE(review): this definition is repeated verbatim in the next notebook cell
# (under the "test" heading). When both cells run, the later definition replaces
# this one, so edits made only here have no effect.
def test(model, device, test_loader, untrusted_loader, confidence_loss_fn, signatures, coverage_layer_indices, confidence_threshold=0.8):
    """Evaluate the model on a trusted and an untrusted loader and compare them.

    Args:
        model: Network with a ``coverage_layers`` sequence; put into eval mode here.
        device: Device the batches and signature tensors are moved to.
        test_loader: Trusted loader; its targets are used for loss and accuracy.
        untrusted_loader: Untrusted loader; only confidence and coverage are measured.
        confidence_loss_fn: Callable taking the model and returning a scalar tensor.
        signatures: Per-layer dicts with ``'min'``, ``'max'`` and ``'distributions'``.
        coverage_layer_indices: Indices of the coverage layers to report on.
        confidence_threshold: Softmax confidence at or above which a prediction
            counts as "high confidence".

    Returns:
        ``{'trusted': ..., 'untrusted': ...}`` where each value is the metrics
        dict produced by ``evaluate_loader``.
    """
    model.eval()

    def evaluate_loader(loader, is_trusted=True):
        """Collect confidence and coverage metrics for one loader.

        Loss, accuracy and the confidence bins are only filled when
        ``is_trusted`` is true; otherwise they are reported as 0.
        """
        test_loss = confidence_loss = 0.0
        correct = high_confidence_correct = total_samples = high_confidence_total = 0
        confidence_scores = []
        # One accumulator per entry of coverage_layer_indices (same order).
        coverage_stats = [{
            'within_range': 0,
            'total': 0,
            'coverage_percentage': 0,
            'distribution_similarity': 0
        } for _ in coverage_layer_indices]

        confidence_bins = {
            '0.0-0.2': {'correct': 0, 'total': 0},
            '0.2-0.4': {'correct': 0, 'total': 0},
            '0.4-0.6': {'correct': 0, 'total': 0},
            '0.6-0.8': {'correct': 0, 'total': 0},
            '0.8-0.9': {'correct': 0, 'total': 0},
            '0.9-1.0': {'correct': 0, 'total': 0},
        }

        with torch.no_grad():
            for data, target in loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                # Confidence = highest softmax probability per sample.
                confidence = F.softmax(output, dim=1).max(dim=1)[0]
                confidence_scores.extend(confidence.cpu().tolist())

                if is_trusted:
                    # Summed (not averaged) so it can be divided by the dataset size later.
                    test_loss += F.cross_entropy(output, target, reduction='sum').item()
                    conf_loss = confidence_loss_fn(model)
                    confidence_loss += conf_loss.item()

                    pred = output.argmax(dim=1)
                    correct_mask = pred.eq(target)
                    correct += correct_mask.sum().item()
                    total_samples += len(data)

                    high_conf_mask = confidence >= confidence_threshold
                    high_confidence_correct += (correct_mask & high_conf_mask).sum().item()
                    high_confidence_total += high_conf_mask.sum().item()

                    # Update confidence bins
                    for conf, corr in zip(confidence, correct_mask):
                        conf_val = conf.item()
                        if conf_val < 0.2:
                            bin_key = '0.0-0.2'
                        elif conf_val < 0.4:
                            bin_key = '0.2-0.4'
                        elif conf_val < 0.6:
                            bin_key = '0.4-0.6'
                        elif conf_val < 0.8:
                            bin_key = '0.6-0.8'
                        elif conf_val < 0.9:
                            bin_key = '0.8-0.9'
                        else:
                            bin_key = '0.9-1.0'

                        confidence_bins[bin_key]['total'] += 1
                        if corr:
                            confidence_bins[bin_key]['correct'] += 1

                # Track coverage statistics
                for stat_idx, layer_idx in enumerate(coverage_layer_indices):
                    layer = model.coverage_layers[layer_idx]
                    # Activations recorded by the forward pass above.
                    activations = layer.get_neuron_outputs()

                    min_vals = signatures[layer_idx]['min'].to(device)
                    max_vals = signatures[layer_idx]['max'].to(device)

                    # Number of activation elements inside the signature range.
                    within_range = torch.logical_and(
                        activations >= min_vals,
                        activations <= max_vals
                    ).float().sum().item()

                    total_activations = activations.numel()
                    coverage_stats[stat_idx]['within_range'] += within_range
                    coverage_stats[stat_idx]['total'] += total_activations

                    if signatures[layer_idx]['distributions'] is not None:
                        num_sections = layer.num_sections
                        delta = (max_vals - min_vals) / (num_sections - 1)
                        # Avoid dividing by zero where min equals max.
                        delta = torch.where(delta == 0, torch.tensor(1e-10, device=delta.device), delta)

                        section_indices = torch.clamp(((activations - min_vals) / delta).long(), 0, num_sections - 1)
                        current_dist = torch.bincount(section_indices.view(-1), minlength=num_sections).float()
                        current_dist /= activations.numel()

                        signature_dist = signatures[layer_idx]['distributions'].to(device)
                        # Similarity is reported as 1 - MSE; summed per batch, averaged below.
                        dist_similarity = 1.0 - F.mse_loss(current_dist, signature_dist).item()
                        coverage_stats[stat_idx]['distribution_similarity'] += dist_similarity

        # Calculate final metrics
        if is_trusted:
            test_loss /= len(loader.dataset)
            confidence_loss /= len(loader)
            accuracy = 100. * correct / total_samples if total_samples > 0 else 0
            high_conf_accuracy = 100. * high_confidence_correct / max(1, high_confidence_total)
        else:
            test_loss = confidence_loss = accuracy = high_conf_accuracy = 0

        avg_confidence = sum(confidence_scores) / len(confidence_scores) if confidence_scores else 0

        # Calculate accuracy per confidence bin
        bin_metrics = {}
        for bin_key, values in confidence_bins.items():
            if values['total'] > 0:
                bin_acc = 100. * values['correct'] / values['total'] if is_trusted else 0
                bin_metrics[bin_key] = {
                    'accuracy': bin_acc,
                    'samples': values['total'],
                    'percentage': 100. * values['total'] / total_samples if total_samples > 0 else 0
                }
            else:
                bin_metrics[bin_key] = {'accuracy': 0, 'samples': 0, 'percentage': 0}

        # Calculate final coverage stats
        for stat_idx in range(len(coverage_layer_indices)):
            if coverage_stats[stat_idx]['total'] > 0:
                coverage_stats[stat_idx]['coverage_percentage'] = (
                    100.0 * coverage_stats[stat_idx]['within_range'] / coverage_stats[stat_idx]['total']
                )

                if signatures[coverage_layer_indices[stat_idx]]['distributions'] is not None:
                    # Turn the per-batch sum into a mean over batches.
                    coverage_stats[stat_idx]['distribution_similarity'] /= len(loader)

        return {
            'loss': test_loss,
            'conf_loss': confidence_loss,
            'accuracy': accuracy,
            'avg_confidence': avg_confidence,
            'high_conf_accuracy': high_conf_accuracy,
            'bin_metrics': bin_metrics,
            'coverage_stats': coverage_stats,
            'confidence_scores': confidence_scores
        }

    # Evaluate trusted data
    print("\nEvaluating on trusted data...")
    trusted_results = evaluate_loader(test_loader, is_trusted=True)

    # Evaluate untrusted data
    print("\nEvaluating on untrusted data...")
    untrusted_results = evaluate_loader(untrusted_loader, is_trusted=False)

    # Compare results
    print("\nConfidence Comparison:")
    print(f"Trusted data avg confidence: {trusted_results['avg_confidence']:.4f}")
    print(f"Untrusted data avg confidence: {untrusted_results['avg_confidence']:.4f}")

    print("\nCoverage Comparison (trusted vs untrusted):")
    for i, layer_idx in enumerate(coverage_layer_indices[:5]):  # Show first 5 layers
        trusted_cov = trusted_results['coverage_stats'][i]['coverage_percentage']
        untrusted_cov = untrusted_results['coverage_stats'][i]['coverage_percentage']
        print(f"  Layer {layer_idx+1}: {trusted_cov:.2f}% vs {untrusted_cov:.2f}% coverage")

    return {
        'trusted': trusted_results,
        'untrusted': untrusted_results
    }


# **test**

In [None]:
def test(model, device, test_loader, untrusted_loader, confidence_loss_fn, signatures, coverage_layer_indices, confidence_threshold=0.8):
    """Evaluate the model on a trusted and an untrusted loader and compare them.

    Args:
        model: Network with a ``coverage_layers`` sequence; put into eval mode here.
        device: Device the batches and signature tensors are moved to.
        test_loader: Trusted loader; its targets are used for loss and accuracy.
        untrusted_loader: Untrusted loader; only confidence and coverage are measured.
        confidence_loss_fn: Callable taking the model and returning a scalar tensor.
        signatures: Per-layer dicts with ``'min'``, ``'max'`` and ``'distributions'``.
        coverage_layer_indices: Indices of the coverage layers to report on.
        confidence_threshold: Softmax confidence at or above which a prediction
            counts as "high confidence".

    Returns:
        ``{'trusted': ..., 'untrusted': ...}`` where each value is the metrics
        dict produced by ``evaluate_loader``.
    """
    model.eval()

    # Exclusive upper edges of every confidence bin except the last one.
    bin_edges = (0.2, 0.4, 0.6, 0.8, 0.9)
    bin_keys = ('0.0-0.2', '0.2-0.4', '0.4-0.6', '0.6-0.8', '0.8-0.9', '0.9-1.0')

    def bin_key_for(conf_val):
        """Return the label of the confidence bin that ``conf_val`` falls into."""
        for edge, key in zip(bin_edges, bin_keys):
            if conf_val < edge:
                return key
        return bin_keys[-1]

    # The signature tensors do not change between batches, so move them to the
    # target device once instead of once per batch and layer.
    layer_refs = []
    for layer_idx in coverage_layer_indices:
        signature = signatures[layer_idx]
        reference_dist = signature['distributions']
        layer_refs.append((
            model.coverage_layers[layer_idx],
            signature['min'].to(device),
            signature['max'].to(device),
            None if reference_dist is None else reference_dist.to(device),
        ))

    def evaluate_loader(loader, is_trusted=True):
        """Collect confidence and coverage metrics for one loader.

        Loss, accuracy and the confidence bins are only filled when
        ``is_trusted`` is true; otherwise they are reported as 0.
        """
        test_loss = confidence_loss = 0.0
        correct = high_confidence_correct = total_samples = high_confidence_total = 0
        confidence_scores = []
        # One accumulator per entry of coverage_layer_indices (same order).
        coverage_stats = [{
            'within_range': 0,
            'total': 0,
            'coverage_percentage': 0,
            'distribution_similarity': 0
        } for _ in coverage_layer_indices]

        confidence_bins = {key: {'correct': 0, 'total': 0} for key in bin_keys}

        with torch.no_grad():
            for data, target in loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                # Confidence = highest softmax probability per sample.
                confidence = F.softmax(output, dim=1).max(dim=1)[0]
                # Single device-to-host transfer, reused for binning below.
                batch_confidence = confidence.cpu().tolist()
                confidence_scores.extend(batch_confidence)

                if is_trusted:
                    # Summed (not averaged) so it can be divided by the dataset size later.
                    test_loss += F.cross_entropy(output, target, reduction='sum').item()
                    conf_loss = confidence_loss_fn(model)
                    confidence_loss += conf_loss.item()

                    pred = output.argmax(dim=1)
                    correct_mask = pred.eq(target)
                    correct += correct_mask.sum().item()
                    total_samples += len(data)

                    high_conf_mask = confidence >= confidence_threshold
                    high_confidence_correct += (correct_mask & high_conf_mask).sum().item()
                    high_confidence_total += high_conf_mask.sum().item()

                    # Update confidence bins from host-side lists; this avoids
                    # one .item() synchronisation per sample.
                    for conf_val, is_correct in zip(batch_confidence, correct_mask.tolist()):
                        bin_entry = confidence_bins[bin_key_for(conf_val)]
                        bin_entry['total'] += 1
                        if is_correct:
                            bin_entry['correct'] += 1

                # Track coverage statistics
                for stat_idx, (layer, min_vals, max_vals, signature_dist) in enumerate(layer_refs):
                    # Activations recorded by the forward pass above.
                    activations = layer.get_neuron_outputs()

                    # Number of activation elements inside the signature range.
                    within_range = torch.logical_and(
                        activations >= min_vals,
                        activations <= max_vals
                    ).float().sum().item()

                    stats = coverage_stats[stat_idx]
                    stats['within_range'] += within_range
                    stats['total'] += activations.numel()

                    if signature_dist is not None:
                        num_sections = layer.num_sections
                        delta = (max_vals - min_vals) / (num_sections - 1)
                        # Avoid dividing by zero where min equals max.
                        delta = torch.where(delta == 0, torch.tensor(1e-10, device=delta.device), delta)

                        section_indices = torch.clamp(((activations - min_vals) / delta).long(), 0, num_sections - 1)
                        current_dist = torch.bincount(section_indices.view(-1), minlength=num_sections).float()
                        current_dist /= activations.numel()

                        # Similarity is reported as 1 - MSE; summed per batch, averaged below.
                        dist_similarity = 1.0 - F.mse_loss(current_dist, signature_dist).item()
                        stats['distribution_similarity'] += dist_similarity

        # Calculate final metrics
        if is_trusted:
            test_loss /= len(loader.dataset)
            confidence_loss /= len(loader)
            accuracy = 100. * correct / total_samples if total_samples > 0 else 0
            high_conf_accuracy = 100. * high_confidence_correct / max(1, high_confidence_total)
        else:
            test_loss = confidence_loss = accuracy = high_conf_accuracy = 0

        avg_confidence = sum(confidence_scores) / len(confidence_scores) if confidence_scores else 0

        # Calculate accuracy per confidence bin
        bin_metrics = {}
        for bin_key, values in confidence_bins.items():
            if values['total'] > 0:
                bin_acc = 100. * values['correct'] / values['total'] if is_trusted else 0
                bin_metrics[bin_key] = {
                    'accuracy': bin_acc,
                    'samples': values['total'],
                    'percentage': 100. * values['total'] / total_samples if total_samples > 0 else 0
                }
            else:
                bin_metrics[bin_key] = {'accuracy': 0, 'samples': 0, 'percentage': 0}

        # Calculate final coverage stats
        for stats, (_, _, _, signature_dist) in zip(coverage_stats, layer_refs):
            if stats['total'] > 0:
                stats['coverage_percentage'] = 100.0 * stats['within_range'] / stats['total']

                if signature_dist is not None:
                    # Turn the per-batch sum into a mean over batches.
                    stats['distribution_similarity'] /= len(loader)

        return {
            'loss': test_loss,
            'conf_loss': confidence_loss,
            'accuracy': accuracy,
            'avg_confidence': avg_confidence,
            'high_conf_accuracy': high_conf_accuracy,
            'bin_metrics': bin_metrics,
            'coverage_stats': coverage_stats,
            'confidence_scores': confidence_scores
        }

    # Evaluate trusted data
    print("\nEvaluating on trusted data...")
    trusted_results = evaluate_loader(test_loader, is_trusted=True)

    # Evaluate untrusted data
    print("\nEvaluating on untrusted data...")
    untrusted_results = evaluate_loader(untrusted_loader, is_trusted=False)

    # Compare results
    print("\nConfidence Comparison:")
    print(f"Trusted data avg confidence: {trusted_results['avg_confidence']:.4f}")
    print(f"Untrusted data avg confidence: {untrusted_results['avg_confidence']:.4f}")

    print("\nCoverage Comparison (trusted vs untrusted):")
    for i, layer_idx in enumerate(coverage_layer_indices[:5]):  # Show first 5 layers
        trusted_cov = trusted_results['coverage_stats'][i]['coverage_percentage']
        untrusted_cov = untrusted_results['coverage_stats'][i]['coverage_percentage']
        print(f"  Layer {layer_idx+1}: {trusted_cov:.2f}% vs {untrusted_cov:.2f}% coverage")

    return {
        'trusted': trusted_results,
        'untrusted': untrusted_results
    }

# **==================== MAIN FUNCTION ====================**

# **Main**

In [None]:
def main():
    """Fine-tune VGG19WithCoverage on CIFAR-10 and log trusted/untrusted metrics.

    CIFAR-10 is the trusted data (train and test splits); the MNIST test split,
    converted to three channels, is the untrusted data. Activation signatures
    are built from the training loader before training and rebuilt every fifth
    epoch. Per-epoch metrics are collected by a TrainingDataLogger and written
    to timestamped ``training_results_*.json`` / ``.pkl`` files at the end.
    """
    # Configuration
    batch_size = 64
    epochs = 10
    learning_rate = 0.001
    momentum = 0.9
    confidence_threshold = 0.9
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Data transforms
    train_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Same pipeline as test_transform, plus a grayscale-to-3-channel step.
    mnist_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Load datasets
    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
    test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
    mnist_test = datasets.MNIST(root='./data', train=False, download=True, transform=mnist_transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    untrusted_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

    # Model setup
    model = VGG19WithCoverage().to(device)
    model.fc3 = nn.Linear(4096, 10).to(device)  # Modify for CIFAR-10

    # The optimizer is created after fc3 is replaced so it owns the new head's parameters.
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=1e-4)
    # scheduler.step() is called once per epoch below, so the one-cycle schedule
    # must span `epochs` steps. Sizing it as steps_per_epoch * epochs (per-batch
    # stepping) would leave the learning rate stuck at the very start of warm-up.
    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, total_steps=epochs)

    # Coverage layer configuration
    coverage_layer_indices = list(range(18))  # All 18 coverage layers

    # Initialize data logger
    logger = TrainingDataLogger()
    logger.data['config'] = {
        'batch_size': batch_size,
        'epochs': epochs,
        'learning_rate': learning_rate,
        'confidence_threshold': confidence_threshold,
        'model': 'VGG19WithCoverage'
    }

    # Initialize signature generator
    signature_generator = SignatureGenerator(model, coverage_layer_indices, method='mrc')
    signatures = signature_generator.aggregate_signatures(train_loader, device)

    # Initialize confidence loss
    confidence_loss_fn = ConfidenceLoss(signatures, coverage_layer_indices, method='mrc')

    # Main training loop
    start_time = time.time()
    for epoch in range(1, epochs + 1):
        print(f"\n{'='*20} Epoch {epoch}/{epochs} {'='*20}")

        # Update signatures every 5 epochs
        if epoch % 5 == 0:
            print("Updating activation signatures...")
            signatures = signature_generator.aggregate_signatures(train_loader, device)
            # The loss caches signature tensors at construction, so rebuild it too.
            confidence_loss_fn = ConfidenceLoss(signatures, coverage_layer_indices, method='mrc')

        # Training
        train_metrics = train(model, device, train_loader, optimizer, epoch,
                            confidence_loss_fn, signatures, coverage_layer_indices,
                            confidence_threshold)

        # Testing
        test_metrics = test(model, device, test_loader, untrusted_loader,
                          confidence_loss_fn, signatures, coverage_layer_indices,
                          confidence_threshold)

        # Store all data
        logger.add_epoch_data(epoch, train_metrics, test_metrics, test_metrics['trusted']['coverage_stats'])

        # Update learning rate (one scheduler step per epoch)
        scheduler.step()

        # Epoch summary
        print(f"\nEpoch {epoch} Summary:")
        print(f"Train - Loss: {train_metrics['loss']:.4f}, Accuracy: {train_metrics['accuracy']:.2f}%")
        print(f"Test - Loss: {test_metrics['trusted']['loss']:.4f}, Accuracy: {test_metrics['trusted']['accuracy']:.2f}%")
        print(f"Untrusted Data Avg Confidence: {test_metrics['untrusted']['avg_confidence']:.4f}")

    # Finalize and save data
    logger.finalize(epochs)
    save_filename = f"training_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    logger.save(save_filename)

    # Final summary (the timer started after the initial signature pass)
    total_time = time.time() - start_time
    print(f"\n{'='*20} Training Complete {'='*20}")
    print(f"Total training time: {total_time:.2f}s ({total_time/60:.2f}min)")
    print(f"Final test accuracy: {test_metrics['trusted']['accuracy']:.2f}%")
    print(f"Final test avg confidence: {test_metrics['trusted']['avg_confidence']:.4f}")
    print(f"Final high confidence accuracy: {test_metrics['trusted']['high_conf_accuracy']:.2f}%")
    print(f"Final untrusted data avg confidence: {test_metrics['untrusted']['avg_confidence']:.4f}")



In [None]:
# Run the full training/evaluation pipeline only when executed as a script
# (or as a notebook cell), not when this code is imported.
if __name__ == "__main__":
    main()

Using device: cuda


100%|██████████| 170M/170M [00:03<00:00, 43.4MB/s]
100%|██████████| 9.91M/9.91M [00:00<00:00, 17.6MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 478kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.36MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 10.9MB/s]
Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:03<00:00, 172MB/s]



Train Epoch: 1 [0/50000] Loss: 2.3256, CE: 2.3251, Conf Loss: 0.0005, Accuracy: 10.94%, High Conf Acc: 0.00%, Time: 2.78s
Train Epoch: 1 [6400/50000] Loss: 0.6306, CE: 0.6043, Conf Loss: 0.0263, Accuracy: 60.32%, High Conf Acc: 94.80%, Time: 190.09s
Train Epoch: 1 [12800/50000] Loss: 0.5326, CE: 0.5106, Conf Loss: 0.0221, Accuracy: 70.59%, High Conf Acc: 95.84%, Time: 377.57s
Train Epoch: 1 [19200/50000] Loss: 0.5479, CE: 0.5238, Conf Loss: 0.0241, Accuracy: 75.37%, High Conf Acc: 96.53%, Time: 564.87s
Train Epoch: 1 [25600/50000] Loss: 0.3856, CE: 0.3664, Conf Loss: 0.0192, Accuracy: 77.72%, High Conf Acc: 96.83%, Time: 752.31s
Train Epoch: 1 [32000/50000] Loss: 0.4714, CE: 0.4548, Conf Loss: 0.0166, Accuracy: 79.43%, High Conf Acc: 97.08%, Time: 940.01s
Train Epoch: 1 [38400/50000] Loss: 0.2430, CE: 0.2232, Conf Loss: 0.0198, Accuracy: 80.83%, High Conf Acc: 97.20%, Time: 1127.46s
Train Epoch: 1 [44800/50000] Loss: 0.3485, CE: 0.3267, Conf Loss: 0.0219, Accuracy: 81.75%, High Conf A