In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Dataset, ConcatDataset
from opacus import PrivacyEngine
import numpy as np
import os

# Define data transformations for data augmentation and normalization
train_transforms = [
        transforms.Resize(size=(180,180)),
        transforms.ColorJitter(brightness=0.5,contrast=0.5,saturation=0.5,hue=0.5)
        # transforms.RandomRotation(degrees=60),
        #transforms.RandomGrayscale(),
        #transforms.RandomHorizontalFlip(),
    ]
grayscale_transforms = train_transforms.copy()
grayscale_transforms.append(transforms.Grayscale(num_output_channels=3))
train_transforms_end = [
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]
data_transforms = {
    'train': transforms.Compose(train_transforms+train_transforms_end),
    'val': transforms.Compose([
        transforms.Resize(size=(180,180)),
        #transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}


In [2]:
# Define the data directory
ROOT_DIR = os.path.abspath(os.curdir)
print(ROOT_DIR)
dataset_name='dataset3'
data_dir = os.path.join(ROOT_DIR, dataset_name)
print(data_dir)

concatdatasets = []
concatdatasets_val = []
# Create data loaders
#og_image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), transforms.Compose(train_transforms_end)) for x in ['train']}
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
tr_grayscale_dataset = {x: datasets.ImageFolder(os.path.join(data_dir, x), transforms.Compose(grayscale_transforms+train_transforms_end)) for x in ['train','val']}
grayscale_transforms.append(transforms.RandomHorizontalFlip(p=1))
tr_grayscale_flipped_dataset = {x: datasets.ImageFolder(os.path.join(data_dir, x), transforms.Compose(grayscale_transforms+train_transforms_end)) for x in ['train','val']}

#concatdatasets.append(og_image_datasets['train'])
concatdatasets.append(image_datasets['train'])
concatdatasets.append(tr_grayscale_dataset['train'])
concatdatasets.append(tr_grayscale_flipped_dataset['train'])

concatdatasets_val.append(image_datasets['val'])
concatdatasets_val.append(tr_grayscale_dataset['val'])
concatdatasets_val.append(tr_grayscale_flipped_dataset['val'])

r_times = 5;
rotate_transf = train_transforms
tr_rotate = []
print(rotate_transf)
for i in range(5):
    rotate_transf = train_transforms.copy()
    for j in range(i):
         rotate_transf.append(transforms.RandomRotation(degrees=(60,60)))
    concatdatasets.append(datasets.ImageFolder(os.path.join(data_dir, 'train'), transforms.Compose(rotate_transf+train_transforms_end)))
    concatdatasets_val.append(datasets.ImageFolder(os.path.join(data_dir, 'val'), transforms.Compose(rotate_transf+train_transforms_end)))

image_datasets['train'] = ConcatDataset(concatdatasets)
image_datasets['val'] = ConcatDataset(concatdatasets_val)
print(len(image_datasets['train']))
print(len(image_datasets['val']))

/home/lemawul/PyTorch
/home/lemawul/PyTorch/dataset3
[Resize(size=(180, 180), interpolation=bilinear, max_size=None, antialias=True), ColorJitter(brightness=(0.5, 1.5), contrast=(0.5, 1.5), saturation=(0.5, 1.5), hue=(-0.5, 0.5))]
2768
288


In [8]:
# Ensure smaller batch sizes to avoid OOM
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=32, shuffle=True, num_workers=4) for x in ['train', 'val']}

dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
print(dataset_sizes)

class_names = image_datasets['train'].datasets[0].classes
print(class_names)


{'train': 2768, 'val': 288}
['Circle', 'Cross', 'Goat', 'Person', 'Spiral', 'Stag', 'Zigzag']


In [9]:
def validate_model(model, val_loader):
    import torch
    model.eval()  # Set model to evaluation mode
    criterion = torch.nn.CrossEntropyLoss()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for X_batch, y_batch in val_loader:
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            val_loss += loss.item()

            # Compute accuracy
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == y_batch).sum().item()
            total += y_batch.size(0)

    val_loss /= len(val_loader)
    accuracy = correct / total
    print(f"Validation Loss: {val_loss:.4f}, Accuracy: {accuracy:.4f}")
    return val_loss, accuracy

In [11]:
import os

# Directory containing saved teacher models
teacher_dir='resnet18'
TEACHER_MODELS_DIR = os.path.join(ROOT_DIR, teacher_dir)  # Replace with the actual path

# Function to load a teacher model
def load_teacher_model(model_path, num_classes=7):
    model = models.resnet18(pretrained=True)
    num_features = model.fc.in_features
    #num_features = model.classifier.in_features
    model.fc = nn.Linear(num_features, num_classes)
    model.load_state_dict(torch.load(model_path))  # Load model parameters
    model.eval()  # Set to evaluation mode
    return model

# Load all teacher models in the directory
teacher_ensemble = []
for filename in os.listdir(TEACHER_MODELS_DIR):
    if filename.endswith(".pt") or filename.endswith(".pth"):  # Check for model files
        model_path = os.path.join(TEACHER_MODELS_DIR, filename)
        teacher_ensemble.append(load_teacher_model(model_path, len(class_names)))

print(f"Loaded {len(teacher_ensemble)} teacher models.")


Loaded 3 teacher models.


In [12]:
for idx, teacher in enumerate(teacher_ensemble):
    print(f"Validating Teacher {idx + 1}...")
    validate_model(teacher, dataloaders['val'])

Validating Teacher 1...
Validation Loss: 1.6871, Accuracy: 0.6875
Validating Teacher 2...
Validation Loss: 2.4529, Accuracy: 0.6250
Validating Teacher 3...
Validation Loss: 2.6742, Accuracy: 0.5694


In [13]:
teacher_logits = []  # To store teacher predictions

from tqdm import tqdm  # For progress bar

def generate_teacher_logits(teacher_model, train_loader):
    teacher_model.eval()  # Set to evaluation mode to prevent training operations (e.g., dropout)

    running_loss = 0.0

    # Add a tqdm progress bar for batch progress
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc="Generating Logits")

    for batch_idx, (X_batch, y_batch) in progress_bar:
        # No training, no optimizer
        with torch.no_grad():  # Disable gradient computation
            logits = teacher_model(X_batch)  # Generate logits by passing input through the teacher model

        # Optionally, apply softmax to get probabilities (for distillation)
        softmax = torch.nn.Softmax(dim=1)
        teacher_probs = softmax(logits)

        # Save logits to the global teacher_logits list
        teacher_logits.append(logits.cpu().numpy())  # Convert to numpy and append (for later use in distillation or analysis)

        # Optionally compute and track the loss (monitoring purposes only)
        criterion = nn.CrossEntropyLoss()
        loss = criterion(logits, y_batch)
        running_loss += loss.item()

        # Update tqdm progress bar description with batch loss (optional)
        progress_bar.set_postfix({"Batch Loss": loss.item()})

    # If you tracked loss, you can print average loss
    avg_loss = running_loss / len(train_loader)
    print(f"Average Loss: {avg_loss:.4f}")


# Example of calling the function for each teacher in the ensemble
for teacher in teacher_ensemble:
    # We don't need to retrain the teacher, just generate logits
    teacher.eval()  # Set the teacher model to evaluation mode
    generate_teacher_logits(teacher, dataloaders['train'])

# After running this, teacher_logits will contain logits from all batches across all teachers in the ensemble.
print(f"Generated Logits for {len(teacher_logits)} batches.")


Generating Logits: 100%|███████████████████████████████████████████████| 87/87 [00:20<00:00,  4.26it/s, Batch Loss=2.19]

Average Loss: 2.1338



Generating Logits: 100%|████████████████████████████████████████████| 87/87 [00:20<00:00,  4.26it/s, Batch Loss=1.01e-6]

Average Loss: 0.0056



Generating Logits: 100%|███████████████████████████████████████████████| 87/87 [00:20<00:00,  4.25it/s, Batch Loss=1.09]

Average Loss: 2.7878
Generated Logits for 261 batches.





In [14]:
# Aggregate noisy logits with smaller batches and CPU memory optimization
def aggregate_teacher_logits(teacher_ensemble, data_loader, noise_scale=0.1, class_num=7):
    logits_list = []

    for X_batch, _ in data_loader:
        batch_logits = torch.zeros(len(X_batch), class_num).cpu()  # Move to CPU
        for teacher in teacher_ensemble:
            # Ensure no gradients are computed and move to CPU
            with torch.no_grad():
                logits = teacher(X_batch.cpu())  # Move inputs to CPU if they're on GPU
                batch_logits += logits
        batch_logits /= len(teacher_ensemble)  # Average logits
        noisy_logits = batch_logits + torch.normal(0, noise_scale, batch_logits.shape).cpu()  # Add DP noise
        logits_list.append(noisy_logits)

        # Optional: clear cache after each batch to free up memory
        torch.cuda.empty_cache()

    return torch.cat(logits_list)

# Generate noisy logits from the teacher ensemble
teacher_logits_noisy = aggregate_teacher_logits(teacher_ensemble, dataloaders['train'], noise_scale=0.1, class_num=len(class_names))

# Check size of noisy logits
print(f"Generated Noisy Logits of size: {teacher_logits_noisy.shape}")


Generated Noisy Logits of size: torch.Size([2768, 7])


In [15]:
student_model = models.resnet18(pretrained=True)
num_features = student_model.fc.in_features
student_model.fc = nn.Linear(num_features, len(class_names))


In [16]:
import torch
import torch.nn as nn

def replace_batchnorm_with_groupnorm(model):
    for name, module in model.named_children():
        if isinstance(module, nn.BatchNorm2d):  # BatchNorm2d example
            # Create GroupNorm layer with appropriate num_groups (typically batch size or smaller)
            num_features = module.num_features
            num_groups = 32  # You can adjust this value as needed
            group_norm = nn.GroupNorm(num_groups, num_features)
            setattr(model, name, group_norm)
        else:
            replace_batchnorm_with_groupnorm(module)
    return model

# Example of replacing BatchNorm with GroupNorm in your student model
student_model = replace_batchnorm_with_groupnorm(student_model)

In [17]:
print(dataloaders['train'].batch_size)  # This should give you a valid integer batch size
print(dataloaders['val'].batch_size) 

32
32


In [18]:
def train_student_with_validation(student_model, teacher_logits, train_loader, val_loader, epsilon, delta):
    import torch.nn.functional as F
    from opacus import PrivacyEngine
    from torch.optim import SGD
    import torch.nn as nn
    from torch.utils.tensorboard import SummaryWriter

    # Initialize TensorBoard writer
    writer = SummaryWriter()

    # Extract and store batch size before DP is applied
    batch_size = train_loader.batch_size
    if batch_size is None:
        raise ValueError("Batch size is not defined in the DataLoader. Please ensure it is set during initialization.")

    student_optimizer = SGD(student_model.parameters(), lr=0.01)
    criterion = nn.KLDivLoss(reduction="batchmean")  # KL Divergence for distillation loss
    privacy_engine = PrivacyEngine()

    # Add differential privacy to the student model
    student_model, student_optimizer, train_loader = privacy_engine.make_private(
        module=student_model,
        optimizer=student_optimizer,
        data_loader=train_loader,
        noise_multiplier=1.0,  # Set noise multiplier for DP
        max_grad_norm=1.0,    # Clip gradients
    )

    # Training loop
    for epoch in range(100):  # Train for 100 epochs
        student_model.train()  # Set the model to training mode
        epoch_loss = 0.0

        for i, (X_batch, _) in enumerate(train_loader):
            student_optimizer.zero_grad()

            # Get teacher's logits for this batch
            target_logits = teacher_logits[:X_batch.size(0)]

            # Forward pass through the student
            student_outputs = student_model(X_batch)

            # Compute distillation loss
            loss = criterion(
                F.log_softmax(student_outputs, dim=1),
                F.softmax(target_logits, dim=1)
            )
            loss.backward()
            student_optimizer.step()

            # Accumulate epoch loss
            epoch_loss += loss.item()

        # Calculate average loss for the epoch
        epoch_loss /= len(train_loader)
        print(f"Epoch {epoch + 1} Training Loss: {epoch_loss:.4f}")

        # Log training loss to TensorBoard
        writer.add_scalar('Loss/Training', epoch_loss, epoch + 1)

        # Validate the student model after each epoch
        print(f"Validating Student Model at Epoch {epoch + 1}...")
        val_loss, val_accuracy = validate_model(student_model, val_loader)

        # Log validation loss and accuracy to TensorBoard
        writer.add_scalar('Loss/Validation', val_loss, epoch + 1)
        writer.add_scalar('Accuracy/Validation', val_accuracy, epoch + 1)

    # Privacy accounting
    epsilon_value = privacy_engine.get_epsilon(delta=delta)
    print(f"Model trained with (ε = {epsilon_value:.2f}, δ = {delta}) differential privacy")

    # Log privacy parameters to TensorBoard
    writer.add_text('Privacy', f"ε = {epsilon_value:.2f}, δ = {delta}")

    # Close the TensorBoard writer
    writer.close()

In [19]:
print(dataloaders['train'].batch_size)
train_student_with_validation(student_model, teacher_logits_noisy, dataloaders['train'], dataloaders['val'], epsilon=3.0, delta=1e-5)

32




Epoch 1 Training Loss: 1.7281
Validating Student Model at Epoch 1...
Validation Loss: 1.8533, Accuracy: 0.3056
Epoch 2 Training Loss: 1.7204
Validating Student Model at Epoch 2...
Validation Loss: 1.8614, Accuracy: 0.2917
Epoch 3 Training Loss: 1.7110
Validating Student Model at Epoch 3...
Validation Loss: 1.8617, Accuracy: 0.3021
Epoch 4 Training Loss: 1.7125
Validating Student Model at Epoch 4...
Validation Loss: 1.8581, Accuracy: 0.3056
Epoch 5 Training Loss: 1.6954
Validating Student Model at Epoch 5...
Validation Loss: 1.8652, Accuracy: 0.2986
Epoch 6 Training Loss: 1.7078
Validating Student Model at Epoch 6...
Validation Loss: 1.8605, Accuracy: 0.3021
Epoch 7 Training Loss: 1.7009
Validating Student Model at Epoch 7...
Validation Loss: 1.8644, Accuracy: 0.3021
Epoch 8 Training Loss: 1.7026
Validating Student Model at Epoch 8...
Validation Loss: 1.8582, Accuracy: 0.2986
Epoch 9 Training Loss: 1.6928
Validating Student Model at Epoch 9...
Validation Loss: 1.8700, Accuracy: 0.2882
E

In [20]:
torch.save(student_model.state_dict(), f"student_model_opacus_logged_{teacher_dir}_100.pth")