<a href="https://colab.research.google.com/github/rdelhibabu/SubDataBase-0.91s-Reproducibility/blob/main/cibo_vfl.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install torch botorch gpytorch scikit-learn

Collecting botorch
  Downloading botorch-0.16.1-py3-none-any.whl.metadata (9.9 kB)
Collecting gpytorch
  Downloading gpytorch-1.15.1-py3-none-any.whl.metadata (8.3 kB)
Collecting pyre_extensions (from botorch)
  Downloading pyre_extensions-0.0.32-py3-none-any.whl.metadata (4.0 kB)
Collecting linear_operator>=0.6 (from botorch)
  Downloading linear_operator-0.6-py3-none-any.whl.metadata (15 kB)
Collecting pyro-ppl>=1.8.4 (from botorch)
  Downloading pyro_ppl-1.9.1-py3-none-any.whl.metadata (7.8 kB)
Collecting jaxtyping (from gpytorch)
  Downloading jaxtyping-0.3.7-py3-none-any.whl.metadata (7.3 kB)
Collecting pyro-api>=0.1.1 (from pyro-ppl>=1.8.4->botorch)
  Downloading pyro_api-0.1.2-py3-none-any.whl.metadata (2.5 kB)
Collecting wadler-lindig>=0.1.3 (from jaxtyping->gpytorch)
  Downloading wadler_lindig-0.1.7-py3-none-any.whl.metadata (17 kB)
Collecting typing-inspect (from pyre_extensions->botorch)
  Downloading typing_inspect-0.9.0-py3-none-any.whl.metadata (1.5 kB)
Collecting mypy-e

In [4]:
import torch
import torch.nn as nn
import numpy as np
import math
from sklearn.datasets import load_digits
from sklearn.preprocessing import StandardScaler

# BoTorch & GPyTorch imports
from botorch.models import SingleTaskGP
from botorch.fit import fit_gpytorch_mll
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch.acquisition import UpperConfidenceBound
from botorch.optim import optimize_acqf
from gpytorch.kernels import ScaleKernel, MaternKernel

# ==========================================
# 1. VFL Simulation Environment (MAB-VFL Baseline Style)
# ==========================================

class BottomModel(nn.Module):
    """Client model: Projects raw features to an embedding."""
    def __init__(self, input_dim, embed_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 32),
            nn.ReLU(),
            nn.Linear(32, embed_dim),
            nn.ReLU() # Embeddings are typically activated
        )

    def forward(self, x):
        return self.net(x)

class TopModel(nn.Module):
    """Server model: Aggregates embeddings and predicts class."""
    def __init__(self, total_embed_dim, num_classes=10):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(total_embed_dim, 32),
            nn.ReLU(),
            nn.Linear(32, num_classes)
        )

    def forward(self, fused_embedding):
        return self.net(fused_embedding)

class VFLSystem:
    def __init__(self, n_clients, feature_dim, num_classes=10):
        self.n_clients = n_clients
        self.feature_dim = feature_dim
        # Distribute features evenly
        self.client_feat_dim = feature_dim // n_clients
        self.embed_dim = 8

        self.clients = nn.ModuleList([
            BottomModel(self.client_feat_dim, self.embed_dim)
            for _ in range(n_clients)
        ])
        self.server = TopModel(self.embed_dim * n_clients, num_classes)
        self.criterion = nn.CrossEntropyLoss()

    def forward_pass(self, x_parts):
        # Unsqueeze each embedding to add a batch dimension (1, embed_dim)
        # before concatenating along dim=1.
        embeddings = [client(x).unsqueeze(0) for client, x in zip(self.clients, x_parts)]
        fused = torch.cat(embeddings, dim=1)
        logits = self.server(fused)
        return logits

    def get_loss(self, x_parts, target_label):
        """Returns loss. (Higher loss = Successful Attack)"""
        logits = self.forward_pass(x_parts)
        # We target Untargeted Attack: Maximize CrossEntropy with True Label
        loss = self.criterion(logits, target_label.unsqueeze(0))
        return loss

# ==========================================
# 2. CIBO-VFL Attack Engine
# ==========================================

class CIBOAttacker:
    def __init__(self, vfl_system, budget_T, latent_dim_per_client=2):
        self.vfl = vfl_system
        self.T = budget_T
        self.d_sub = latent_dim_per_client

        # Total latent dimension = (M clients * d_sub)
        # We structure latent space so Z = [z_client_1, z_client_2, ...]
        self.total_latent_dim = self.vfl.n_clients * self.d_sub

        # History
        self.train_x = [] # Latent vectors Z
        self.train_y = [] # Losses

        # GP Model
        self.gp = None

        # Client Importance Scores (Initialized Uniformly)
        self.beta = torch.ones(self.vfl.n_clients) / self.vfl.n_clients

    def upsample_projection(self, z_flat):
        """
        Maps latent vector Z (d) -> Perturbation Delta (D).
        We upsample each client's segment of Z separately.
        """
        z_reshaped = z_flat.view(self.vfl.n_clients, self.d_sub)
        deltas = []

        for i in range(self.vfl.n_clients):
            # Bilinear upsampling 1D (using interpolate)
            # z_c: [d_sub] -> [1, 1, d_sub] -> [1, 1, feat_dim]
            z_c = z_reshaped[i].view(1, 1, -1)
            delta_c = torch.nn.functional.interpolate(
                z_c,
                size=self.vfl.client_feat_dim,
                mode='linear',
                align_corners=False
            )
            deltas.append(delta_c.view(-1))

        return deltas # List of perturbations per client

    def update_surrogate_model(self):
        """Fits GP with ARD kernel to observation history."""
        if len(self.train_x) < 5: return # Not enough data yet

        X = torch.stack(self.train_x).double()
        Y = torch.stack(self.train_y).unsqueeze(-1).double()

        # Define GP with ARD (one lengthscale per latent dim)
        self.gp = SingleTaskGP(X, Y)
        self.gp.covar_module = ScaleKernel(
            MaternKernel(ard_num_dims=self.total_latent_dim)
        )

        mll = ExactMarginalLogLikelihood(self.gp.likelihood, self.gp)
        fit_gpytorch_mll(mll)

        # --- CIBO CORE: EXTRACT IMPORTANCE ---
        # Get lengthscales: shape (1, total_latent_dim)
        ls = self.gp.covar_module.base_kernel.lengthscale.detach().view(-1)

        # Aggregation: beta_client = sum(1/lengthscale) for that client's latent dims
        ls_reshaped = ls.view(self.vfl.n_clients, self.d_sub)
        inv_ls = 1.0 / (ls_reshaped + 1e-6) # Avoid div by zero
        self.beta = inv_ls.sum(dim=1)

        print(f"  [Info] Updated Client Importance: {self.beta.numpy().round(3)}")

    def select_clients(self):
        """Select Top-T clients based on beta scores."""
        # Add small noise to break ties/encourage exploration initially
        noisy_beta = self.beta + torch.randn_like(self.beta) * 0.01
        _, indices = torch.topk(noisy_beta, self.T)
        return indices.tolist()

    def run_attack(self, x_target, y_target, n_iters=30, update_freq=5):
        print(f"\n--- Starting CIBO-VFL Attack (T={self.T}) ---")
        x_parts_orig = torch.split(x_target, self.vfl.client_feat_dim)

        best_loss = -float('inf')

        # Initial Random Sampling (Warmup)
        # We perturb ALL clients randomly initially to gather sensitivity data
        z = torch.randn(self.total_latent_dim)

        for i in range(n_iters):
            # 1. Select Clients
            if i < 5:
                # Warmup: random selection or all
                active_clients = list(range(self.vfl.n_clients))
            else:
                active_clients = self.select_clients()

                # 2. Bayesian Optimization Step
                # Find z that maximizes acquisition function (UCB)
                if self.gp is not None:
                    UCB = UpperConfidenceBound(self.gp, beta=0.1)
                    # Optimize z in [-2, 2] bound
                    bounds = torch.stack([-2.0 * torch.ones(self.total_latent_dim),
                                           2.0 * torch.ones(self.total_latent_dim)])
                    candidate, _ = optimize_acqf(
                        UCB, bounds=bounds, q=1, num_restarts=5, raw_samples=20
                    )
                    z = candidate.squeeze()

            # 3. Construct Adversarial Example
            deltas = self.upsample_projection(z)
            x_adv = [t.clone() for t in x_parts_orig]

            # Apply perturbation ONLY to selected clients
            # (Note: In strict BO, masking inputs creates non-stationarity.
            #  Here we assume the GP learns the 'masked' effect as low sensitivity)
            for c_idx in active_clients:
                # Scale perturbation
                x_adv[c_idx] = x_adv[c_idx] + (deltas[c_idx] * 0.5)

            # 4. Query System
            loss = self.vfl.get_loss(x_adv, y_target)

            # 5. Record Data
            self.train_x.append(z)
            self.train_y.append(loss.detach()) # Maximize Loss

            if loss.item() > best_loss:
                best_loss = loss.item()

            print(f"Iter {i+1:02d} | Loss: {loss.item():.4f} | Active Clients: {active_clients}")

            # 6. Periodic Update of GP and Importance Scores
            if (i + 1) % update_freq == 0:
                self.update_surrogate_model()

        return best_loss

# ==========================================
# 3. Main Execution
# ==========================================

if __name__ == "__main__":
    # A. Setup Data
    data = load_digits()
    X = torch.tensor(data.data, dtype=torch.float32)
    y = torch.tensor(data.target, dtype=torch.long)
    X = (X - X.mean()) / X.std() # Normalize

    # B. Setup VFL System (4 Clients, 16 features each)
    vfl_env = VFLSystem(n_clients=4, feature_dim=64)

    # (Optional) Pre-train VFL model briefly so attack is meaningful
    # For demo, we assume random weights or just run as is.
    print("VFL System Initialized.")

    # C. Run Attack on a single sample
    target_idx = 0
    attacker = CIBOAttacker(vfl_env, budget_T=2, latent_dim_per_client=4)

    final_loss = attacker.run_attack(X[target_idx], y[target_idx], n_iters=25)
    print(f"\nAttack Complete. Best Loss Achieved: {final_loss:.4f}")

VFL System Initialized.

--- Starting CIBO-VFL Attack (T=2) ---
Iter 01 | Loss: 2.4335 | Active Clients: [0, 1, 2, 3]
Iter 02 | Loss: 2.4335 | Active Clients: [0, 1, 2, 3]
Iter 03 | Loss: 2.4335 | Active Clients: [0, 1, 2, 3]
Iter 04 | Loss: 2.4335 | Active Clients: [0, 1, 2, 3]
Iter 05 | Loss: 2.4335 | Active Clients: [0, 1, 2, 3]


  check_min_max_scaling(
  check_standardization(Y=train_Y, raise_on_fail=raise_on_fail)


  [Info] Updated Client Importance: [5.771 5.771 5.771 5.771]
Iter 06 | Loss: 2.4119 | Active Clients: [1, 0]
Iter 07 | Loss: 2.4115 | Active Clients: [1, 3]
Iter 08 | Loss: 2.4327 | Active Clients: [3, 2]
Iter 09 | Loss: 2.3952 | Active Clients: [2, 3]
Iter 10 | Loss: 2.4096 | Active Clients: [0, 3]


  check_min_max_scaling(


  [Info] Updated Client Importance: [5.771 5.771 5.771 5.771]
Iter 11 | Loss: 2.4212 | Active Clients: [3, 2]
Iter 12 | Loss: 2.4260 | Active Clients: [0, 1]
Iter 13 | Loss: 2.4211 | Active Clients: [1, 3]
Iter 14 | Loss: 2.4318 | Active Clients: [3, 0]
Iter 15 | Loss: 2.4304 | Active Clients: [0, 2]


  check_min_max_scaling(


  [Info] Updated Client Importance: [5.771 5.771 5.771 5.771]
Iter 16 | Loss: 2.4447 | Active Clients: [3, 2]
Iter 17 | Loss: 2.4203 | Active Clients: [2, 1]
Iter 18 | Loss: 2.4297 | Active Clients: [1, 0]
Iter 19 | Loss: 2.4332 | Active Clients: [2, 0]
Iter 20 | Loss: 2.4254 | Active Clients: [1, 2]
  [Info] Updated Client Importance: [5.77  5.769 5.77  5.77 ]
Iter 21 | Loss: 2.4467 | Active Clients: [3, 2]


  check_min_max_scaling(


Iter 22 | Loss: 2.4178 | Active Clients: [2, 0]
Iter 23 | Loss: 2.4228 | Active Clients: [3, 0]
Iter 24 | Loss: 2.4438 | Active Clients: [3, 1]
Iter 25 | Loss: 2.4426 | Active Clients: [0, 2]
  [Info] Updated Client Importance: [5.771 5.771 5.77  5.771]

Attack Complete. Best Loss Achieved: 2.4467


  check_min_max_scaling(
