# Example Network

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt

# Generate synthetic dataset (two interleaving half-moons; binary labels)
X, y = make_moons(n_samples=1000, noise=0.2, random_state=0)

# Split BEFORE scaling: the scaler's mean/std must come from the training set only.
# Fitting on all of X would leak test-set statistics into preprocessing and bias
# the reported test accuracy.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)
scaler = StandardScaler().fit(X_train)
X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test)

# Convert to tensors; labels get a trailing dimension of size 1 so they match the
# model's (N, 1) output, as required by BCELoss.
X_train = torch.tensor(X_train, dtype=torch.float32)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1)
y_test = torch.tensor(y_test, dtype=torch.float32).unsqueeze(1)

# Define simple NN
class SimpleNN(nn.Module):
    """Small MLP binary classifier: Linear -> ReLU -> Linear -> Sigmoid.

    Args:
        in_features: number of input features per sample (default 2).
        hidden: width of the single hidden layer (default 8).

    For an input of shape (N, in_features) the forward pass returns a tensor of
    shape (N, 1) with values in (0, 1) (the final layer is a Sigmoid). The output
    is a probability rather than a logit, so it pairs with ``nn.BCELoss``.
    """

    def __init__(self, in_features: int = 2, hidden: int = 8):
        super().__init__()
        # Keep all layers in one Sequential named `net` so parameter names remain
        # `net.0.weight`, `net.0.bias`, `net.2.weight`, `net.2.bias`; the attack
        # results table identifies parameters by these names.
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return per-sample probabilities for input batch `x`."""
        return self.net(x)

# Seed torch's RNG so weight initialisation (and therefore the trained model and
# the reported accuracy) is reproducible across Restart & Run All.
torch.manual_seed(0)

N_EPOCHS = 300
LEARNING_RATE = 0.01

model = SimpleNN()
criterion = nn.BCELoss()  # the model ends in Sigmoid, so it already outputs probabilities
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Training loop: full-batch updates on the entire training set each epoch.
model.train()  # set once; nothing inside the loop switches the mode back
for _ in range(N_EPOCHS):
    optimizer.zero_grad()
    output = model(X_train)
    loss = criterion(output, y_train)
    loss.backward()
    optimizer.step()

# Evaluation: threshold the predicted probabilities at 0.5 and compare to labels.
model.eval()
with torch.no_grad():
    preds = (model(X_test) > 0.5).float()
    accuracy = preds.eq(y_test).float().mean().item()
print(f'Test Accuracy: {accuracy:.2f}')


Test Accuracy: 0.90


# Define Attack Setup

In [2]:
import torch
from framework.attack import attack

# Pick the GPU when available. Kept as a plain string, matching the type of the
# device argument previously passed to `attack`.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on", device)

# Put the trained model on the same device the attack places its inputs on.
# NOTE(review): assumes `attack` does not move the model itself - confirm in framework.attack.
model = model.to(device)

# NOTE(review): the meaning of the positional `0` argument is not visible here - confirm
# in framework.attack and consider passing it by keyword.
# NOTE(review): the UserWarning in the output comes from `attack` re-wrapping X_test
# (already a tensor) with torch.tensor(...); harmless, but worth fixing inside attack.
res = attack(model, X_test, y_test, 0, device)

res

Running on cpu


  X_tensor = torch.tensor(X, dtype=torch.float32, device=device)


Unnamed: 0,IDX,ACC,NAME,ORIG_VAL,FLIP_VAL
0,"(0, 0)",0.82,net.0.weight,-1.443894,1.443894
1,"(0, 1)",0.903333,net.0.weight,-0.640605,0.640605
2,"(1, 0)",0.896667,net.0.weight,0.091025,-0.091025
3,"(1, 1)",0.896667,net.0.weight,-0.05742,0.05742
4,"(2, 0)",0.893333,net.0.weight,-0.508874,0.508874
5,"(2, 1)",0.893333,net.0.weight,1.124696,-1.124696
6,"(3, 0)",0.693333,net.0.weight,-2.170932,2.170932
7,"(3, 1)",0.896667,net.0.weight,0.019511,-0.019511
8,"(4, 0)",0.896667,net.0.weight,0.626632,-0.626632
9,"(4, 1)",0.903333,net.0.weight,-0.53183,0.53183
