**KAN: prepare dataset**

In [22]:
import matplotlib.pyplot as plt
import torch
import numpy as np
import pandas as pd
import torch.nn.functional as F
from sklearn.preprocessing import StandardScaler

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

X_columns = [
    'Header_Length', 'Protocol Type', 'Duration', 'Rate', 'Srate', 
    # 'Drate',
    # 'fin_flag_number', 'syn_flag_number', 'rst_flag_number', 'psh_flag_number',
    # 'ack_flag_number', 'ece_flag_number', 'cwr_flag_number', 'ack_count',
    # 'syn_count', 'fin_count', 'rst_count', 'HTTP', 'HTTPS', 'DNS', 'Telnet',
    # 'SMTP', 'SSH', 'IRC', 'TCP', 'UDP', 'DHCP', 'ARP', 'ICMP', 'IGMP', 
    'IPv','LLC', 
    'Tot sum', 'Min', 'Max', 'AVG', 'Std', 'Tot size', 'IAT', 'Number',
    'Magnitue', 'Radius', 'Covariance',
    'Variance', 'Weight'
]

Y_columns = ['label_L1']

label_L1_mapping = {"MQTT": 0, "Benign": 1, "Recon": 2, "ARP_Spoofing": 3}
label_L2_mapping = {"MQTT-DDoS-Connect_Flood": 4, "MQTT-DDoS-Publish_Flood": 5, 
                    "MQTT-DoS-Connect_Flood": 6, "MQTT-DoS-Publish_Flood": 7,
                    "MQTT-Malformed_Data": 8, "benign": 9, 
                    "Recon-OS_Scan": 10, "Recon-Port_Scan": 11,
                    "arp_spoofing": 12}


# Read the CSV file
df = pd.read_csv('/home/zyang44/Github/baseline_cicIOT/CIC_IoMT/19classes/filtered_train_l_4_11.csv')
df['label_L1'] = df['label_L1'].map(label_L1_mapping)
df['label_L2'] = df['label_L2'].map(label_L1_mapping)

# Shuffle the dataframe before splitting into training and test sets
df = df.sample(frac=1, random_state=42)
# 90% as training set and 10% as test set
train_size = int(len(df) * 0.9)
train_df, test_df = df.iloc[:train_size, :], df.iloc[train_size:, :]

scaler = StandardScaler()
train_X_scaled = scaler.fit_transform(train_df[X_columns])
test_X_scaled = scaler.transform(test_df[X_columns])
train_y = train_df[Y_columns].values.ravel()
test_y = test_df[Y_columns].values.ravel()

# take Y_columns as the label, and transfering to one-hot coded
dataset = {
    'train_input': torch.tensor(train_X_scaled, dtype=torch.float32, device=device),
    'train_label': F.one_hot(torch.tensor(train_y, dtype=torch.long, device=device), num_classes=4),
    'test_input': torch.tensor(test_X_scaled, dtype=torch.float32, device=device),
    'test_label': F.one_hot(torch.tensor(test_y, dtype=torch.long, device=device), num_classes=4)
}
print("Data prepared.",
      f"Train set: {dataset['train_input'].shape, dataset['train_label'].shape}",
      f"Test set: {dataset['test_input'].shape, dataset['test_label'].shape}", sep="\n")

cuda:0
Data prepared.
Train set: (torch.Size([89918, 20]), torch.Size([89918, 4]))
Test set: (torch.Size([9991, 20]), torch.Size([9991, 4]))


**Build myKAN**

A wrapper function, to only get the logits output by the last layer.

In [23]:
import torch.nn as nn

class MultiKANModel(nn.Module):
    def __init__(self, kan):
        """
        Wrap an already built MultKAN instance.
        Args:
            kan: a MultKAN model (which has attributes such as act_fun, symbolic_fun, node_bias, node_scale,
                 subnode_bias, subnode_scale, depth, width, mult_homo, mult_arity, input_id, symbolic_enabled, etc.)
        """
        super(MultiKANModel, self).__init__()
        self.kan = kan

    def forward(self, x, training=False, singularity_avoiding=False, y_th=10.):
        # Select input features according to input_id
        x = x[:, self.kan.input_id.long()]
        # Loop through each layer
        for l in range(self.kan.depth):
            # Get outputs from the numerical branch (KANLayer) of current layer
            x_numerical, preacts, postacts_numerical, postspline = self.kan.act_fun[l](x)
            # Get output from the symbolic branch if enabled
            if self.kan.symbolic_enabled:
                x_symbolic, postacts_symbolic = self.kan.symbolic_fun[l](x, singularity_avoiding=singularity_avoiding, y_th=y_th)
            else:
                x_symbolic = 0.
            # Sum the numerical and symbolic outputs
            x = x_numerical + x_symbolic

            # Subnode affine transformation
            x = self.kan.subnode_scale[l][None, :] * x + self.kan.subnode_bias[l][None, :]

            # Process multiplication nodes
            dim_sum = self.kan.width[l+1][0]
            dim_mult = self.kan.width[l+1][1]
            if dim_mult > 0:
                if self.kan.mult_homo:
                    for i in range(self.kan.mult_arity-1):
                        if i == 0:
                            x_mult = x[:, dim_sum::self.kan.mult_arity] * x[:, dim_sum+1::self.kan.mult_arity]
                        else:
                            x_mult = x_mult * x[:, dim_sum+i+1::self.kan.mult_arity]
                else:
                    for j in range(dim_mult):
                        acml_id = dim_sum + int(np.sum(self.kan.mult_arity[l+1][:j]))
                        for i in range(self.kan.mult_arity[l+1][j]-1):
                            if i == 0:
                                x_mult_j = x[:, [acml_id]] * x[:, [acml_id+1]]
                            else:
                                x_mult_j = x_mult_j * x[:, [acml_id+i+1]]
                        if j == 0:
                            x_mult = x_mult_j
                        else:
                            x_mult = torch.cat([x_mult, x_mult_j], dim=1)
                # Concatenate sum and mult parts
                x = torch.cat([x[:, :dim_sum], x_mult], dim=1)

            # Node affine transformation
            x = self.kan.node_scale[l][None, :] * x + self.kan.node_bias[l][None, :]

        # Final x corresponds to the logits output of the whole model
        return x

**LTN Setting**

In [24]:
import ltn
import ltn.fuzzy_ops

# define the connectives, quantifiers, and the SatAgg
Not = ltn.Connective(ltn.fuzzy_ops.NotStandard())
And = ltn.Connective(ltn.fuzzy_ops.AndProd())   # And = ltn.Connective(custom_fuzzy_ops.AndProd())
Or = ltn.Connective(ltn.fuzzy_ops.OrProbSum())
Forall = ltn.Quantifier(ltn.fuzzy_ops.AggregPMeanError(p=2), quantifier="f")
Exists = ltn.Quantifier(ltn.fuzzy_ops.AggregPMean(p=2), quantifier="e")
Implies = ltn.Connective(ltn.fuzzy_ops.ImpliesReichenbach())
SatAgg = ltn.fuzzy_ops.SatAgg()

# define ltn constants
l_MQTT = ltn.Constant(torch.tensor([1, 0, 0, 0]))
l_Benign = ltn.Constant(torch.tensor([0, 1, 0, 0]))
l_Recon = ltn.Constant(torch.tensor([0, 0, 1, 0]))
l_ARP_Spoofing = ltn.Constant(torch.tensor([0, 0, 0, 1]))

In [27]:
from utils import MLP, LogitsToPredicate, DataLoader, DataLoaderMulti
from kan import KAN

# Define the MLP predicate
mlp = MLP(layer_sizes=(20, 64, 4)).to(device)
P_mlp = ltn.Predicate(LogitsToPredicate(mlp))

# Define myKAN predicate
kan = KAN(width=[20, 10, 4], grid=5, k=3, seed=42, device=device)
mykan = MultiKANModel(kan)
P_kan = ltn.Predicate(LogitsToPredicate(mykan))

checkpoint directory created: ./model
saving model version 0.0


**KNN effect**

LTN(kan) v.s LTN(mlp)

In [29]:
# Define the DataLoader adapted to the LTN input format. 'data' is same, 'labels' is numeric (not one-hot)
train_loader = DataLoader(data=dataset['train_input'], labels=torch.tensor(train_y, dtype=torch.long, device=device), batch_size=len(dataset['train_input']))
test_loader = DataLoader(data=dataset['test_input'], labels=torch.tensor(test_y, dtype=torch.long, device=device), batch_size=len(dataset['test_input']))

# filter the data by label_L1, data is a tensor(n*20), label_L1 is a tensor(n*4)
# based on the label_L1, we store the data into different ltn.Variable
def compute_sat_levels(loader, P):
	sat_level  = 0
	for data, labels in loader:
		x = ltn.Variable("x", data)
		x_MQTT = ltn.Variable("x_MQTT", data[labels == 0])
		x_Benign = ltn.Variable("x_Benign", data[labels == 1])
		x_Recon = ltn.Variable("x_Recon", data[labels == 2])
		x_ARP_Spoofing = ltn.Variable("x_ARP_Spoofing", data[labels == 3])
		
		sat_level = SatAgg(
			Forall(x_MQTT, P(x_MQTT, l_MQTT)),
			Forall(x_Benign, P(x_Benign, l_Benign)),
			Forall(x_Recon, P(x_Recon, l_Recon)),
			Forall(x_ARP_Spoofing, P(x_ARP_Spoofing, l_ARP_Spoofing))
		)
	return sat_level


def compute_accuracy(loader, model):
    total_correct = 0
    total_samples = 0
    for data, labels in loader:
        logits = model(data)
        preds = torch.argmax(logits, dim=1)
        total_correct += (preds == labels).sum()
        total_samples += labels.numel()
    return total_correct.float() / total_samples
    

optimizer_mlp = torch.optim.Adam(P_mlp.parameters(), lr=0.0015)
optimizer_kan = torch.optim.Adam(P_kan.parameters(), lr=0.0015)

for epoch in range(151):
	# Train the MLP
    optimizer_mlp.zero_grad()
    sat_mlp = compute_sat_levels(train_loader, P_mlp)
    loss = 1. - sat_mlp
    loss.backward()
    optimizer_mlp.step()
    train_loss_mlp  = loss.item()

    # Train the KAN
    optimizer_kan.zero_grad()
    sat_kan = compute_sat_levels(train_loader, P_kan)
    loss = 1. - sat_kan
    loss.backward()
    optimizer_kan.step()
    train_loss_kan = loss.item()
	
    # Test
    acc_mlp = compute_accuracy(test_loader, mlp)
    acc_kan = compute_accuracy(test_loader, kan)

    test_sat_mlp = compute_sat_levels(test_loader, P_mlp)
    test_sat_kan = compute_sat_levels(test_loader, P_kan)

    print(f"Epoch {epoch} | MLP (loss/acc/sat): {train_loss_mlp:.3f}/{acc_mlp:.3f}/{sat_mlp:.3f}({test_sat_mlp:.3f}) | KAN (loss/acc/sat): {train_loss_kan:.3f}/{acc_kan:.3f}/{sat_kan:.3f}({test_sat_kan:.3f})")



Epoch 0 | MLP (loss/acc/sat): 0.616/0.561/0.384(0.389) | KAN (loss/acc/sat): 0.694/0.675/0.306(0.308)
Epoch 1 | MLP (loss/acc/sat): 0.613/0.561/0.387(0.392) | KAN (loss/acc/sat): 0.692/0.677/0.308(0.310)
Epoch 2 | MLP (loss/acc/sat): 0.610/0.563/0.390(0.395) | KAN (loss/acc/sat): 0.690/0.680/0.310(0.312)
Epoch 3 | MLP (loss/acc/sat): 0.608/0.564/0.392(0.398) | KAN (loss/acc/sat): 0.689/0.680/0.311(0.314)
Epoch 4 | MLP (loss/acc/sat): 0.605/0.564/0.395(0.401) | KAN (loss/acc/sat): 0.687/0.677/0.313(0.316)
Epoch 5 | MLP (loss/acc/sat): 0.602/0.565/0.398(0.403) | KAN (loss/acc/sat): 0.685/0.676/0.315(0.317)
Epoch 6 | MLP (loss/acc/sat): 0.600/0.564/0.400(0.406) | KAN (loss/acc/sat): 0.683/0.676/0.317(0.319)
Epoch 7 | MLP (loss/acc/sat): 0.597/0.565/0.403(0.408) | KAN (loss/acc/sat): 0.682/0.675/0.318(0.321)
Epoch 8 | MLP (loss/acc/sat): 0.595/0.566/0.405(0.411) | KAN (loss/acc/sat): 0.680/0.675/0.320(0.323)
Epoch 9 | MLP (loss/acc/sat): 0.593/0.567/0.407(0.413) | KAN (loss/acc/sat): 0.678