In [1]:
from loss_traces.data_processing.data_processing import (
    get_no_shuffle_train_loader,
    get_num_classes,
    get_trainset,
    get_testset,
    prepare_transform,
    prepare_loaders,)
from loss_traces.models.model import load_model
import random
import torch
import os
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np

In [2]:
os.chdir("../")

In [3]:

MODEL_DIR = "trained_models"
config = {
    "dataset": "CIFAR10",
    "arch": "wrn28-2",
    "batchsize": 16,
    "num_workers": 4,
    "augment": True,
    "device": "cuda",
}


def _initialize_model_and_data(config):
    attack_loaders = [
        get_no_shuffle_train_loader(
            config["dataset"],
            config["arch"],
            config["batchsize"],
            config["num_workers"],
        )
    ]

    if config["augment"]:
        attack_loaders.append(
            get_no_shuffle_train_loader(
                config["dataset"],
                config["arch"],
                config["batchsize"],
                config["num_workers"],
                mirror_all=True,
            )
        )

    model = load_model(config["arch"], get_num_classes(config["dataset"])).to(
        config["device"]
    )

    return model, attack_loaders


In [4]:
## load student and teacher models
teacher_exp_id = "wrn28-2_CIFAR_5_l0"
student_exp_id = "wrn28-2_CIFAR_5_l3"

teacher, dataloader = _initialize_model_and_data(config)
teacher_saves = torch.load(f"{MODEL_DIR}/models/{teacher_exp_id}/target", weights_only=False)
teacher.load_state_dict(teacher_saves["model_state_dict"])

student, _ = _initialize_model_and_data(config)
saves = torch.load(f"{MODEL_DIR}/models/{student_exp_id}/target", weights_only=False)
# student.load_state_dict(saves["model_state_dict"])


In [None]:
def confidence_mask_logits(logits, masking_strength=1.0, temperature=1.0):
    """
    Apply confidence masking: preserve rank order but randomize confidence values
    
    Args:
        logits: original model logits [batch_size, num_classes]
        masking_strength: how much to randomize (0=no masking, 1=full random)
        temperature: temperature for softmax conversion
    
    Returns:
        masked_logits: logits with preserved order but masked confidence
    """
    batch_size, num_classes = logits.shape
    device = logits.device
    
    # Get the ranking (order) of original predictions
    _, original_rankings = torch.sort(logits, dim=1, descending=True)
    
    # Generate random confidence values for each sample
    if masking_strength > 0:
        # Random values between 0 and masking_strength
        random_confidences = torch.rand(batch_size, num_classes, device=device) * masking_strength
        
        # Create new logits that preserve ranking but use random confidences
        masked_logits = torch.zeros_like(logits)
        
        for batch_idx in range(batch_size):
            # Sort the random confidences in descending order
            sorted_random, _ = torch.sort(random_confidences[batch_idx], descending=True)
            
            # Assign these sorted random values according to original ranking
            for rank_idx, class_idx in enumerate(original_rankings[batch_idx]):
                masked_logits[batch_idx, class_idx] = sorted_random[rank_idx]
        
        # Apply temperature scaling
        masked_logits = masked_logits / temperature
        
    else:
        # No masking, return original logits
        masked_logits = logits / temperature
    
    return masked_logits


class ConfidenceMaskedKnowledgeDistillation(nn.Module):
    def __init__(self, temperature=3.0, alpha=0.7, masking_strength=0.5):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.masking_strength = masking_strength
        self.ce_loss = nn.CrossEntropyLoss()
        
    def forward(self, student_logits, teacher_logits, targets):
        ce_loss = self.ce_loss(student_logits, targets)
        
        # Apply confidence masking to teacher logits
        masked_teacher_logits = confidence_mask_logits(
            teacher_logits, 
            masking_strength=self.masking_strength,
            temperature=self.temperature
        )
        
        # Knowledge distillation 
        soft_targets = F.softmax(masked_teacher_logits / self.temperature, dim=1)
        soft_predictions = F.log_softmax(student_logits / self.temperature, dim=1)
        
        kd_loss = F.kl_div(soft_predictions, soft_targets, reduction='batchmean') * (self.temperature ** 2)
        
        # Combined loss
        total_loss = (1 - self.alpha) * ce_loss + self.alpha * kd_loss
        
        return total_loss, ce_loss, kd_loss

def train_with_confidence_masked_kd(teacher_model, student_model, train_loader, 
                                  optimizer, device, temperature=3.0, 
                                  alpha=0.7, masking_strength=0.5):
    """
    Train student with confidence-masked knowledge distillation
    """
    teacher_model.eval()
    student_model.train()
    
    kd_loss_fn = ConfidenceMaskedKnowledgeDistillation(
        temperature=temperature,
        alpha=alpha, 
        masking_strength=masking_strength
    )
    
    epoch_stats = {'total_loss': 0, 'ce_loss': 0, 'kd_loss': 0}
    
    for  data, targets, batch_idx in train_loader:
        data, targets = data.to(device), targets.to(device)
        
        optimizer.zero_grad()
        
        with torch.no_grad():
            teacher_logits = teacher_model(data)
        
        student_logits = student_model(data)
        
        # Calculate loss with confidence masking applied to teacher logits
        total_loss, ce_loss, kd_loss = kd_loss_fn(student_logits, teacher_logits, targets)
        
        total_loss.backward()
        optimizer.step()
        
        # Track metrics
        epoch_stats['total_loss'] += total_loss.item()
        epoch_stats['ce_loss'] += ce_loss.item()
        epoch_stats['kd_loss'] += kd_loss.item()
        
            
    for key in epoch_stats:
        epoch_stats[key] /= len(train_loader)
    
    return epoch_stats

def evaluate_masking_effect(model, data_loader, device):
    model.eval()
    
    with torch.no_grad():
        correct = 0
        total = 0
        
        with torch.no_grad():
            for inputs, targets, _indices in data_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs).squeeze(-1).squeeze(-1)

                _, predicted = torch.max(outputs.data, 1)
                total += targets.size(0)
                correct_idx = predicted.eq(targets.data).cpu()
                correct += correct_idx.sum()

        acc = 100. * float(correct) / float(total)
        print('test acc:', acc)
        return acc


In [8]:
from torch.utils.data import Subset, DataLoader
train_transform = prepare_transform(config['dataset'], config['arch'], config['augment'])
plain_transform = prepare_transform(config['dataset'], config['arch'])

train_superset = get_trainset(config['dataset'], train_transform)
testset = get_testset(config['dataset'], plain_transform)

num_classes = get_num_classes(config['dataset'])


trainset = Subset(train_superset, saves['trained_on_indices'])

train_loader = DataLoader(
    trainset,
    batch_size=config['batchsize'],
    shuffle=True,
    num_workers=config['num_workers'],
    pin_memory=True,
)
test_loader = DataLoader(
    testset,
    batch_size=config['batchsize'],
    shuffle=False,
    num_workers=config['num_workers'],
    pin_memory=True,
)

print(f"Train set size: {len(trainset)}")
print(f"Test set size: {len(testset)}")

Train set size: 21435
Test set size: 10000


In [7]:
def set_seed(seed=0):
    # set random seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

seed = 2546
set_seed(seed)

In [11]:
os.makedirs("trained_models/models/test_kd", exist_ok=True)

In [15]:
optimizer = torch.optim.Adam(student.parameters(), lr=0.001)


for epoch in range(100):
    print(f"\nEpoch {epoch+1}")
    
    # Train with different masking strengths
    stats = train_with_confidence_masked_kd(
        teacher_model=teacher,
        student_model=student, 
        train_loader=train_loader,
        optimizer=optimizer,
        device=config["device"],
        temperature=3.0,
        alpha=0.7,  # Balance between CE and KD
        masking_strength=0  # How much to randomize confidence
    )
    results = evaluate_masking_effect(student, test_loader, config["device"])
    print(f"Epoch {epoch+1} - Total Loss: {stats['total_loss']:.4f}, CE Loss: {stats['ce_loss']:.4f}, KD Loss: {stats['kd_loss']:.4f}")
    print(f"Epoch {epoch+1} - Test Accuracy: {results:.2f}%")
    if epoch%5 == 0:
        if results > acc:
            acc = results
            dict = {
                'model_state_dict': student.state_dict(),
                'trained_on_indices': saves['trained_on_indices'],
                'arch': config['arch'],
                'seed' : seed,
                'hyperparameters': saves['hyperparameters'],
                'dataset': config['dataset'],
                'temperature': 3.0,
                'masking_strength': 0.5,
                'alpha': 0.7,
                'test_acc': results
            }
            torch.save(dict, f"trained_models/models/test_kd/epoch_{epoch}_acc_{acc}.pt")
            print(f"Saved model with accuracy {acc} at epoch {epoch}")

    #     acc = results
    #     dict = {
    #         'model_state_dict': student.state_dict(),
    #         'trained_on_indices': teacher_saves['trained_on_indices'],
    #         'arch': config['arch'],
    #         'seed' : seed,
    #         'hyperparameters': saves['hyperparameters'],
    #         'dataset': config['dataset'],
    #         'temperature': 3.0,
    #         'masking_strength': 0.5,
    #         'alpha': 0.7,
    #         'test_acc': results
    #     }
    #     torch.save(dict, f"trained_models/models/test_kd/epoch_{epoch}_acc_{acc}.pt")
    #     print(f"Saved model with accuracy {acc} at epoch {epoch}")
    # else:
    #     print(f"Model did not improve at epoch {epoch}, current best accuracy: {acc}")
    #     break



Epoch 1
test acc: 83.53
Epoch 1 - Total Loss: 0.7916, CE Loss: 0.2274, KD Loss: 1.0334
Epoch 1 - Test Accuracy: 83.53%
Saved model with accuracy 83.53 at epoch 0

Epoch 2
test acc: 83.85
Epoch 2 - Total Loss: 0.7318, CE Loss: 0.1962, KD Loss: 0.9613
Epoch 2 - Test Accuracy: 83.85%

Epoch 3
test acc: 83.11
Epoch 3 - Total Loss: 0.6924, CE Loss: 0.1872, KD Loss: 0.9089
Epoch 3 - Test Accuracy: 83.11%

Epoch 4
test acc: 82.61
Epoch 4 - Total Loss: 0.6787, CE Loss: 0.1781, KD Loss: 0.8932
Epoch 4 - Test Accuracy: 82.61%

Epoch 5
test acc: 84.49
Epoch 5 - Total Loss: 0.6564, CE Loss: 0.1749, KD Loss: 0.8628
Epoch 5 - Test Accuracy: 84.49%

Epoch 6
test acc: 83.73
Epoch 6 - Total Loss: 0.6133, CE Loss: 0.1546, KD Loss: 0.8098
Epoch 6 - Test Accuracy: 83.73%
Saved model with accuracy 83.73 at epoch 5

Epoch 7
test acc: 84.92
Epoch 7 - Total Loss: 0.5958, CE Loss: 0.1504, KD Loss: 0.7867
Epoch 7 - Test Accuracy: 84.92%

Epoch 8
test acc: 84.27
Epoch 8 - Total Loss: 0.5912, CE Loss: 0.1508, KD

In [17]:


dict = {
    'model_state_dict': student.state_dict(),
    'trained_on_indices': teacher_saves['trained_on_indices'],
    'arch': config['arch'],
    'seed' : seed,
    'hyperparameters': saves['hyperparameters'],
    'dataset': config['dataset'],
    'temperature': 3.0,
    'masking_strength': 0,
    'alpha': 0.7,
    'test_acc': results
}
path = f"{MODEL_DIR}/models/kd/target"
if not os.path.exists(os.path.dirname(path)):
    os.makedirs(os.path.dirname(path))

torch.save(dict, path)
print(f"Model saved to {path}")

Model saved to trained_models/models/kd/target


In [20]:
## recompute lira
import subprocess
subprocess.run([ "python", "-m", "src.loss_traces.run_attack_pipeline",
    "--exp_id", "kd",
    "--target", "target",
    "--arch", "wrn28-2",
    "--dataset", "CIFAR10",
    "--lira-only",
    "--layer", "0",
])

Initialized Attack Pipeline for experiment: kd
Architecture: wrn28-2, Dataset: CIFAR10
Shadow models: 64, GPU: CPU
Model directory: trained_models/models/kd
Storage directory: trained_models

STARTING: Running LiRA attack on kd
Command: python3 -m loss_traces.attacks --exp_id kd --attack LiRA --arch wrn28-2 --dataset CIFAR10 --gpu  --target_id target --n_shadows 32 --layer 0 --layer_folder None
agument=True, arch=wrn28-2, dataset=CIFAR10
Dataset size: 50000
No intermediate results found at trained_models/scaled_logits_intermediate/kd.pt. Computing intermediate results...
Computing all metrics...
Computed metrics for shadow 0
Computed metrics for shadow 1
Computed metrics for shadow 2
Computed metrics for shadow 3
Computed metrics for shadow 4
Computed metrics for shadow 5
Computed metrics for shadow 6
Computed metrics for shadow 7
Computed metrics for shadow 8
Computed metrics for shadow 9
Computed metrics for shadow 10
Computed metrics for shadow 11
Computed metrics for shadow 12
Comp

CompletedProcess(args=['python', '-m', 'src.loss_traces.run_attack_pipeline', '--exp_id', 'kd', '--target', 'target', '--arch', 'wrn28-2', '--dataset', 'CIFAR10', '--lira-only', '--layer', '0'], returncode=0)

## 0 Masking

In [8]:
optimizer = torch.optim.Adam(student.parameters(), lr=0.001)

for epoch in range(10):
    print(f"\nEpoch {epoch+1}")
    
    # Train with different masking strengths
    stats = train_with_confidence_masked_kd(
        teacher_model=teacher,
        student_model=student, 
        train_loader=train_loader,
        optimizer=optimizer,
        device=config["device"],
        temperature=3.0,
        alpha=0.7,  
        masking_strength=0  
    )
    
    print(f"Epoch stats: {stats}")

# Evaluate the effect of masking on final model
print("\nEvaluating masking effect on student model:")
results = evaluate_masking_effect(student, test_loader, config["device"])



Epoch 1
Epoch stats: {'total_loss': 2.6427344292764547, 'ce_loss': 0.9679769947189392, 'kd_loss': 3.3604876547384475}

Epoch 2
Epoch stats: {'total_loss': 1.9661887448259598, 'ce_loss': 0.7216846397705495, 'kd_loss': 2.499547674308125}

Epoch 3
Epoch stats: {'total_loss': 1.7237788391128535, 'ce_loss': 0.6312054941205455, 'kd_loss': 2.1920245812248895}

Epoch 4
Epoch stats: {'total_loss': 1.549623906555194, 'ce_loss': 0.5790883142594069, 'kd_loss': 1.9655677521373702}

Epoch 5
Epoch stats: {'total_loss': 1.4271465303115332, 'ce_loss': 0.5295255353301287, 'kd_loss': 1.811841260601295}

Epoch 6
Epoch stats: {'total_loss': 1.3152549459967793, 'ce_loss': 0.4842796057506786, 'kd_loss': 1.6713872516643369}

Epoch 7
Epoch stats: {'total_loss': 1.215445734641526, 'ce_loss': 0.44351075715144056, 'kd_loss': 1.5462750258235236}

Epoch 8
Epoch stats: {'total_loss': 1.139014896269952, 'ce_loss': 0.4168308498516622, 'kd_loss': 1.4485223602958772}

Epoch 9
Epoch stats: {'total_loss': 1.0707332525147

In [9]:


dict = {
    'model_state_dict': student.state_dict(),
    'trained_on_indices': teacher_saves['trained_on_indices'],
    'arch': config['arch'],
    'seed' : seed,
    'hyperparameters': saves['hyperparameters'],
    'dataset': config['dataset'],
    'temperature': 3.0,
    'masking_strength': 0.5,
    'alpha': 0.7,
    'test_acc': results
}
path = f"{MODEL_DIR}/models/kd_0_mask/target"
if not os.path.exists(os.path.dirname(path)):
    os.makedirs(os.path.dirname(path))

torch.save(dict, path)
print(f"Model saved to {path}")

Model saved to trained_models/models/kd_0_mask/target


In [10]:
import subprocess
subprocess.run([ "python", "-m", "src.loss_traces.run_attack_pipeline",
    "--exp_id", "kd_0_mask",
    "--target", "target",
    "--arch", "wrn28-2",
    "--dataset", "CIFAR10",
    "--lira-only",
    "--layer", "0",
])

Initialized Attack Pipeline for experiment: kd_0_mask
Architecture: wrn28-2, Dataset: CIFAR10
Shadow models: 64, GPU: CPU
Model directory: trained_models/models/kd_0_mask
Storage directory: trained_models

STARTING: Running LiRA attack on kd_0_mask
Command: python3 -m loss_traces.attacks --exp_id kd_0_mask --attack LiRA --arch wrn28-2 --dataset CIFAR10 --gpu  --target_id target --n_shadows 32 --layer 0 --layer_folder None
agument=True, arch=wrn28-2, dataset=CIFAR10
Dataset size: 50000
No intermediate results found at trained_models/scaled_logits_intermediate/kd_0_mask.pt. Computing intermediate results...
Computing all metrics...
Computed metrics for shadow 0
Computed metrics for shadow 1
Computed metrics for shadow 2
Computed metrics for shadow 3
Computed metrics for shadow 4
Computed metrics for shadow 5
Computed metrics for shadow 6
Computed metrics for shadow 7
Computed metrics for shadow 8
Computed metrics for shadow 9
Computed metrics for shadow 10
Computed metrics for shadow 11


CompletedProcess(args=['python', '-m', 'src.loss_traces.run_attack_pipeline', '--exp_id', 'kd_0_mask', '--target', 'target', '--arch', 'wrn28-2', '--dataset', 'CIFAR10', '--lira-only', '--layer', '0'], returncode=0)