**Prepare dataset**

In [None]:
import os

import matplotlib.pyplot as plt
import torch
import numpy as np
import pandas as pd
import torch.nn.functional as F
from sklearn.preprocessing import StandardScaler
from utils import LogitsToPredicate, MLP, MultiKANModel, DataLoader, DataLoaderMulti
from kan import KAN

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

# Input features (18 active columns); commented-out names are features deliberately excluded.
X_columns = [
    'Header_Length', 'Protocol Type', 'Duration', 'Rate', 'Srate', 
    # 'Drate',
    # 'fin_flag_number', 'syn_flag_number', 'rst_flag_number', 'psh_flag_number',
    # 'ack_flag_number', 'ece_flag_number', 'cwr_flag_number', 'ack_count',
    # 'syn_count', 'fin_count', 'rst_count', 'HTTP', 'HTTPS', 'DNS', 'Telnet',
    # 'SMTP', 'SSH', 'IRC', 'TCP', 'UDP', 'DHCP', 'ARP', 'ICMP', 'IGMP', 
    'IPv','LLC', 
    'Tot sum', 'Min', 'Max', 'AVG', 'Std', 'Tot size', 'IAT', 'Number',
    'Magnitue', 'Radius', 'Covariance',
    # 'Variance', 'Weight'
]

Y_columns = ['label_L2']

label_L1_mapping = {"MQTT": 0, "Benign": 1} 
label_L2_mapping = {"MQTT-DDoS-Connect_Flood": 0, "MQTT-DDoS-Publish_Flood": 1, 
                    "MQTT-DoS-Connect_Flood": 2, "MQTT-DoS-Publish_Flood": 3,
                    "MQTT-Malformed_Data": 4, "Benign": 5} 
# Number of label_L2 classes, derived from the mapping so the one-hot width cannot drift from it.
NUM_CLASSES = len(label_L2_mapping)

# Training CSV location. Override with the CIC_IOMT_TRAIN_CSV environment variable
# rather than editing this machine-specific default.
DATA_PATH = os.environ.get(
    "CIC_IOMT_TRAIN_CSV",
    '/home/zyang44/Github/baseline_cicIOT/CIC_IoMT/19classes/filtered_train_l_2_6.csv',
)

# Read the CSV file
df = pd.read_csv(DATA_PATH)
df['label_L1'] = df['label_L1'].map(label_L1_mapping)
df['label_L2'] = df['label_L2'].map(label_L2_mapping)
# Series.map() turns any label missing from the mapping into NaN; that NaN would later be
# cast to a meaningless class index by torch.tensor(..., dtype=torch.long), so fail early.
assert df['label_L2'].notna().all(), "label_L2 contains values missing from label_L2_mapping"

# Shuffle the dataframe before splitting into training and test sets
df = df.sample(frac=1, random_state=42)
# 90% as training set and 10% as test set
train_size = int(len(df) * 0.9)
train_df, test_df = df.iloc[:train_size, :], df.iloc[train_size:, :]

# Fit the scaler on the training split only, then apply it to the test split (no leakage).
scaler = StandardScaler()
train_X_scaled = scaler.fit_transform(train_df[X_columns])
test_X_scaled = scaler.transform(test_df[X_columns])
print("Any NaN in train_X_scaled:", np.isnan(train_X_scaled).any())
print("Any Inf in train_X_scaled:", np.isinf(train_X_scaled).any())
print("Any NaN in test_X_scaled:", np.isnan(test_X_scaled).any())
print("Any Inf in test_X_scaled:", np.isinf(test_X_scaled).any())

train_y = train_df[Y_columns].values.ravel()
test_y = test_df[Y_columns].values.ravel()
print("Unique train_y values:", np.unique(train_y))
print("Unique test_y values:", np.unique(test_y))
# Use Y_columns as the label and convert it to one-hot encoding
dataset = {
    'train_input': torch.tensor(train_X_scaled, dtype=torch.float32, device=device),
    'train_label': F.one_hot(torch.tensor(train_y, dtype=torch.long, device=device), num_classes=NUM_CLASSES),
    'test_input': torch.tensor(test_X_scaled, dtype=torch.float32, device=device),
    'test_label': F.one_hot(torch.tensor(test_y, dtype=torch.long, device=device), num_classes=NUM_CLASSES)
}
print("Data prepared.",
      f"Train set: {dataset['train_input'].shape, dataset['train_label'].shape}",
      f"Test set: {dataset['test_input'].shape, dataset['test_label'].shape}", sep="\n")

cuda:0
Any NaN in test_X_scaled: False
Any Inf in test_X_scaled: False
Unique train_y values: [0 1 2 3 4 5]
Unique test_y values: [0 1 2 3 4 5]
Data prepared.
Train set: (torch.Size([35945, 18]), torch.Size([35945, 6]))
Test set: (torch.Size([3994, 18]), torch.Size([3994, 6]))


In [7]:
def compute_accuracy(loader, model):
    """Return the fraction of samples whose argmax prediction equals the label.

    Args:
        loader: iterable of ``(data, labels)`` batches, where ``labels`` holds integer
            class indices (not one-hot).
        model: callable mapping ``data`` to per-class scores, with classes along dim 1.

    Returns:
        0-d float tensor holding the accuracy in [0, 1].

    Raises:
        ValueError: if the loader yields no samples, since accuracy is undefined then.
    """
    total_correct = 0
    total_samples = 0
    # Evaluation only: avoid building an autograd graph through the model.
    with torch.no_grad():
        for data, labels in loader:
            preds = torch.argmax(model(data), dim=1)
            total_correct += (preds == labels).sum()
            total_samples += labels.numel()
    if total_samples == 0:
        raise ValueError("compute_accuracy: loader yielded no samples")
    return total_correct.float() / total_samples

# Loaders in the format the LTN code expects: `data` is the scaled feature tensor from
# `dataset`, while `labels` are integer class indices rather than one-hot vectors.
# batch_size is set to the full split size (presumably one batch per split — confirm in utils.DataLoader).
train_label_idx = torch.tensor(train_y, dtype=torch.long, device=device)
test_label_idx = torch.tensor(test_y, dtype=torch.long, device=device)

train_loader = DataLoader(
    data=dataset['train_input'],
    labels=train_label_idx,
    batch_size=len(dataset['train_input']),
)
test_loader = DataLoader(
    data=dataset['test_input'],
    labels=test_label_idx,
    batch_size=len(dataset['test_input']),
)

**MLP with CrossEntropyLoss**

In [None]:
# Baseline: MLP (18 inputs -> 10 hidden -> 6 classes) trained with standard cross-entropy.
# Seed before construction so the MLP's random initialisation is reproducible, matching
# the seed 42 used for the data shuffle and for the KAN.
torch.manual_seed(42)
mlp = MLP(layer_sizes=(18, 10, 6)).to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(mlp.parameters(), lr=0.001)
for epoch in range(401):
    # Full-batch training step on the whole training split.
    optimizer.zero_grad()
    logits = mlp(train_loader.data, training=True)
    loss = criterion(logits, train_loader.labels)
    loss.backward()
    optimizer.step()

    # Evaluate on the test split every 10 epochs; no gradients are needed here.
    if epoch % 10 == 0:
        with torch.no_grad():
            acc = compute_accuracy(test_loader, mlp)
        print(f"Epoch {epoch}, Loss: {loss.item()}, Test accuracy: {acc.item()}")

**KAN-LTN Setup**

1. Build myKAN and inject the hierarchy rules into the LTN (LTN interpretation).
2. Run auto-swap on the KAN to identify modular structures (KAN interpretation).

In [4]:
import ltn
import ltn.fuzzy_ops

# Fuzzy-logic semantics used to ground the LTN formulas.
fuzzy = ltn.fuzzy_ops

# Connectives (alternative And: ltn.Connective(custom_fuzzy_ops.AndProd()))
Not = ltn.Connective(fuzzy.NotStandard())
And = ltn.Connective(fuzzy.AndProd())
Or = ltn.Connective(fuzzy.OrProbSum())
Implies = ltn.Connective(fuzzy.ImpliesReichenbach())

# Quantifiers, both p-mean based with p=2, plus the aggregator over all axioms.
Forall = ltn.Quantifier(fuzzy.AggregPMeanError(p=2), quantifier="f")
Exists = ltn.Quantifier(fuzzy.AggregPMean(p=2), quantifier="e")
SatAgg = fuzzy.SatAgg()

# KAN over the 18 input features with 6 outputs (one per label_L2 class), wrapped into
# an LTN predicate P_kan(x, label) via MultiKANModel and LogitsToPredicate.
kan = KAN(width=[18, 10, 6], grid=5, k=3, seed=42, device=device)
mykan = MultiKANModel(kan)
P_kan = ltn.Predicate(LogitsToPredicate(mykan))

checkpoint directory created: ./model
saving model version 0.0


**LTN effect**

KAN trained with LTN vs. MLP trained with a standard loss (cross-entropy)

In [10]:
def _one_hot_label(class_idx, num_classes=6):
    """Build a 1-D integer tensor that is 1 at `class_idx` and 0 elsewhere."""
    return torch.tensor([1 if i == class_idx else 0 for i in range(num_classes)])


# LTN label constants; the position of the 1 is the class index from label_L2_mapping.
l_MQTT_DDoS_Connect_Flood = ltn.Constant(_one_hot_label(0))
l_MQTT_DDoS_Publish_Flood = ltn.Constant(_one_hot_label(1))
l_MQTT_DoS_Connect_Flood = ltn.Constant(_one_hot_label(2))
l_MQTT_DoS_Publish_Flood = ltn.Constant(_one_hot_label(3))
l_MQTT_Malformed_Data = ltn.Constant(_one_hot_label(4))
l_Benign = ltn.Constant(_one_hot_label(5))

# Split each batch by its integer label_L2 class (0..5), wrap each class's rows in an
# ltn.Variable, and score how well P assigns every row to its own class constant.
def compute_sat_levels(loader, P):
    """Return the aggregated satisfaction of the per-class labelling axioms.

    For every class c the axiom ``Forall x_c: P(x_c, l_c)`` is grounded on the batch rows
    whose label equals c, and all axioms of a batch are combined with ``SatAgg``.
    When the loader yields several batches, the per-batch satisfaction levels are
    averaged, so every batch contributes.

    Args:
        loader: iterable of ``(data, labels)`` batches with integer class labels.
        P: LTN predicate taking ``(variable, label_constant)``.

    Returns:
        Scalar tensor with the mean satisfaction level (differentiable w.r.t. P's
        parameters), or 0 if the loader yields no usable batch.
    """
    # Position in this tuple is the class index from label_L2_mapping.
    class_axioms = (
        ("x_MQTT_DDoS_Connect_Flood", l_MQTT_DDoS_Connect_Flood),
        ("x_MQTT_DDoS_Publish_Flood", l_MQTT_DDoS_Publish_Flood),
        ("x_MQTT_DoS_Connect_Flood", l_MQTT_DoS_Connect_Flood),
        ("x_MQTT_DoS_Publish_Flood", l_MQTT_DoS_Publish_Flood),
        ("x_MQTT_Malformed_Data", l_MQTT_Malformed_Data),
        ("x_Benign", l_Benign),
    )

    batch_sat_levels = []
    for data, labels in loader:
        axioms = []
        for class_idx, (var_name, label_const) in enumerate(class_axioms):
            class_rows = data[labels == class_idx]
            # Nothing to quantify over for a class absent from this batch (an empty
            # Forall would presumably give NaN or error) -- skip it.
            if class_rows.shape[0] == 0:
                continue
            x_c = ltn.Variable(var_name, class_rows)
            axioms.append(Forall(x_c, P(x_c, label_const)))
        if axioms:
            batch_sat_levels.append(SatAgg(*axioms))

    if not batch_sat_levels:
        return 0
    return torch.stack(batch_sat_levels).mean()
    

# Train the KAN through the LTN predicate by maximising axiom satisfaction,
# i.e. minimising loss = 1 - SatAgg(...).
optimizer_kan = torch.optim.Adam(P_kan.parameters(), lr=0.001)

for epoch in range(401):
    # Train the KAN (full-batch step on the training split)
    optimizer_kan.zero_grad()
    sat_kan = compute_sat_levels(train_loader, P_kan)
    loss = 1. - sat_kan
    loss.backward()
    optimizer_kan.step()
    train_loss_kan = loss.item()

    # Test the KAN. Evaluation needs no gradients, so don't build autograd graphs.
    # NOTE(review): accuracy uses the raw `kan`, while training goes through
    # `mykan = MultiKANModel(kan)`; confirm both produce the same logits.
    with torch.no_grad():
        acc_kan = compute_accuracy(test_loader, kan)
        test_sat_kan = compute_sat_levels(test_loader, P_kan)
    print(f"Epoch {epoch} | KAN (loss/acc/sat): {train_loss_kan:.3f}/{acc_kan:.3f}/{sat_kan:.3f}({test_sat_kan:.3f})")


Epoch 0 | KAN (loss/acc/sat): 0.832/0.308/0.168(0.169)
Epoch 1 | KAN (loss/acc/sat): 0.831/0.319/0.169(0.170)
Epoch 2 | KAN (loss/acc/sat): 0.831/0.327/0.169(0.170)
Epoch 3 | KAN (loss/acc/sat): 0.830/0.339/0.170(0.171)
Epoch 4 | KAN (loss/acc/sat): 0.829/0.346/0.171(0.172)
Epoch 5 | KAN (loss/acc/sat): 0.829/0.358/0.171(0.172)
Epoch 6 | KAN (loss/acc/sat): 0.828/0.365/0.172(0.173)
Epoch 7 | KAN (loss/acc/sat): 0.827/0.372/0.173(0.174)
Epoch 8 | KAN (loss/acc/sat): 0.826/0.386/0.174(0.175)
Epoch 9 | KAN (loss/acc/sat): 0.826/0.395/0.174(0.175)
Epoch 10 | KAN (loss/acc/sat): 0.825/0.405/0.175(0.176)
Epoch 11 | KAN (loss/acc/sat): 0.824/0.414/0.176(0.177)
Epoch 12 | KAN (loss/acc/sat): 0.824/0.423/0.176(0.178)
Epoch 13 | KAN (loss/acc/sat): 0.823/0.432/0.177(0.178)
Epoch 14 | KAN (loss/acc/sat): 0.822/0.441/0.178(0.179)
Epoch 15 | KAN (loss/acc/sat): 0.821/0.447/0.179(0.180)
Epoch 16 | KAN (loss/acc/sat): 0.821/0.453/0.179(0.181)
Epoch 17 | KAN (loss/acc/sat): 0.820/0.457/0.180(0.182)
Ep

Key observations:
* [Convergence] The loss converges faster.
* [SAT] The overall satisfaction level is higher, meaning the model converges while satisfying the rules well.
* [Acc] Accuracy converges earlier and is slightly better at the end.
