<a href="https://colab.research.google.com/github/sankeawthong/Project-1-Lita-Chatbot/blob/main/%5B20250621%5D%20CM_TrustFed-IDS-BFSF%20(LSTM%20baseline%20%2B%20trust-cap%20%2B%20focal%20loss).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**CM_TrustFed-IDS – WSN-BFSF  (LSTM baseline + trust-cap + focal loss)**

In [5]:
#!/usr/bin/env python3
# --------------------------------------------------------------------
#  TrustFed-IDS – WSN-BFSF  (LSTM baseline + trust-cap + focal loss)
# --------------------------------------------------------------------
import os, time, psutil, numpy as np, pandas as pd, tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import InputLayer, LSTM, Dense, Dropout
from tensorflow.keras.regularizers import l2
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.optimizers.schedules import CosineDecay
#from tensorflow_addons.losses import SigmoidFocalCrossEntropy
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from imblearn.over_sampling import SMOTE
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from scipy.spatial.distance import cosine
# Import necessary components
from tensorflow.keras.optimizers.schedules import CosineDecay
# If tensorflow_addons is not installed, run: !pip install tensorflow-addons
#from tensorflow_addons.losses import SigmoidFocalCrossEntropy
from sklearn.utils.class_weight import compute_class_weight

In [6]:
from tensorflow.keras.optimizers.schedules import CosineDecay # Import CosineDecay from the correct path
from sklearn.utils.class_weight import compute_class_weight # Import compute_class_weight

In [7]:
# ------------------------------- CONFIG --------------------------------
SEED, NUM_CLIENTS = 42, 5
ROUNDS, LOCAL_EPOCHS = 75, 1           # shallow local, more rounds
BATCH_SIZE, DIRICHLET_ALPHA = 32, 0.5
HISTORY_KEEP = 6
TRUST_ALPHA  = (0.30, 0.55, 0.15)      # sim, loss, stability
LOG_DIR      = "/mnt/data"
DATA_PATH    = "dataset.csv"  # update path if needed
# -----------------------------------------------------------------------

np.random.seed(SEED)
tf.random.set_seed(SEED)

# ---------------- 1. LOAD & TRAIN / TEST SPLIT -------------------------
df = pd.read_csv(DATA_PATH).dropna()
for col in df.select_dtypes(include="object"):
    df[col] = LabelEncoder().fit_transform(df[col])

X_all = df.drop("Class", axis=1).values.astype("float32")
y_all = df["Class"].values.astype("int64")

X_tr, X_te, y_tr, y_te = train_test_split(
    X_all, y_all, test_size=0.20, stratify=y_all, random_state=SEED)

In [8]:
# ---------------- 1b. SCALER + SMOTE ON TRAIN ONLY ---------------------
scaler = StandardScaler().fit(X_tr)
X_tr, X_te = scaler.transform(X_tr), scaler.transform(X_te)
X_tr, y_tr = SMOTE(random_state=SEED).fit_resample(X_tr, y_tr)

# reshape to (samples, timesteps=T, features=1)
X_tr, X_te = X_tr[..., None], X_te[..., None]
num_classes = int(y_tr.max() + 1)
y_te_cat    = to_categorical(y_te, num_classes)

# ---------------- 2. DIRICHLET CLIENT SPLIT ----------------------------
def dirichlet_split(X, y, k, alpha, rng):
    idx_by_cls = {c: np.where(y == c)[0] for c in np.unique(y)}
    clients = [[] for _ in range(k)]
    for c, idx in idx_by_cls.items():
        rng.shuffle(idx)
        parts = (rng.dirichlet([alpha]*k) * len(idx)).astype(int)
        while parts.sum() < len(idx):
            parts[rng.randint(0, k)] += 1
        start = 0
        for cid, cnt in enumerate(parts):
            clients[cid].extend(idx[start:start+cnt]); start += cnt
    for lst in clients: rng.shuffle(lst)
    return [X[l] for l in clients], [y[l] for l in clients]

rng = np.random.RandomState(SEED)
client_X_raw, client_y_raw = dirichlet_split(X_tr, y_tr, NUM_CLIENTS,
                                             DIRICHLET_ALPHA, rng)
client_X = client_X_raw
client_y = [to_categorical(y, num_classes) for y in client_y_raw]
input_shape = (X_tr.shape[1], 1)

# ---------------- 3. MODEL BUILDER -------------------------------------
def build_model(inp=input_shape, classes=num_classes):
    lr_sched = CosineDecay(initial_learning_rate=5e-4,
                           decay_steps=ROUNDS,
                           alpha=0.4)  # floor = 0.4 × initial = 2e-4
    opt = tf.keras.optimizers.Nadam(learning_rate=lr_sched, clipnorm=2.0)

    m = Sequential([
        InputLayer(input_shape=inp),
        LSTM(128, activation='tanh', return_sequences=True,
             kernel_regularizer=l2(5e-4)),
        LSTM(64,  activation='tanh', kernel_regularizer=l2(5e-4)),
        Dense(256, activation='relu'),
        Dropout(0.20),                       # milder dropout
        Dense(128, activation='relu'),
        Dropout(0.25),
        Dense(classes, activation='softmax')
    ])
    m.compile(opt, loss='categorical_crossentropy', metrics=['accuracy'])
    return m

# ---------------- 4. TRUST-FEDAVG UTILITIES ----------------------------
def weight_update(local, global_):
    return [l - g for l, g in zip(local, global_)]

def vec_cos(a, b):
    v1, v2 = np.concatenate([w.ravel() for w in a]), np.concatenate([w.ravel() for w in b])
    return 0.0 if (np.all(v1 == 0) or np.all(v2 == 0)) else 1 - cosine(v1, v2)

def stability(upd, hist):
    if len(hist) < 2: return 1.0
    return float(np.nanmean([vec_cos(upd, h) for h in hist[-HISTORY_KEEP:]]))

def compute_trust(upd, vloss, hist):
    lo, hi = min(vloss.values()), max(vloss.values())
    trust = {}
    for cid, u in upd.items():
        score = (TRUST_ALPHA[0]*vec_cos(u, [np.zeros_like(w) for w in u]) +
                 TRUST_ALPHA[1]*(1 - (vloss[cid]-lo)/(hi-lo+1e-8)) +
                 TRUST_ALPHA[2]*stability(u, hist[cid]))
        trust[cid] = max(score, 1e-6)
    return trust

def aggregate(w, t, n):
    tot = sum(t[c]*n[c] for c in w)
    return [sum(t[c]*n[c]*w[c][l] for c in w)/tot
            for l in range(len(next(iter(w.values()))))]

# ---------------- 5. INITIALISE ----------------------------------------
g_model   = build_model()
g_weights = g_model.get_weights()
model_MB  = sum(w.nbytes for w in g_weights)/2**20

history   = {c: [] for c in range(NUM_CLIENTS)}
perf_log, comm_log, trust_log = [], [], []

# class-weight dictionary (balanced)
cls_wt = compute_class_weight('balanced',
                              classes=np.arange(num_classes), y=y_tr)
class_weight = dict(enumerate(cls_wt))

# ---------------- 6. FEDERATED TRAINING --------------------------------
for r in range(1, ROUNDS+1):
    tic = time.time()
    lw, upd, vloss, ns, bytes_out = {}, {}, {}, {}, 0

    for cid in range(NUM_CLIENTS):
        n_val = max(1, int(0.1*len(client_X[cid])))
        Xv, yv = client_X[cid][:n_val], client_y[cid][:n_val]
        Xt, yt = client_X[cid][n_val:], client_y[cid][n_val:]

        local = build_model(); local.set_weights(g_weights)
        local.fit(Xt, yt, epochs=LOCAL_EPOCHS, batch_size=BATCH_SIZE,
                  verbose=0, class_weight=class_weight)

        w = local.get_weights()
        u = weight_update(w, g_weights)
        l = local.evaluate(Xv, yv, verbose=0)[0]

        lw[cid], upd[cid], vloss[cid], ns[cid] = w, u, l, len(Xt)
        history[cid] = (history[cid] + [u])[-HISTORY_KEEP:]
        bytes_out += sum(x.nbytes for x in w)

    trust = compute_trust(upd, vloss, history)
    g_weights = aggregate(lw, trust, ns)
    g_model.set_weights(g_weights)

    # ---- evaluation on test set ----
    y_pred = np.argmax(g_model.predict(X_te, verbose=0), axis=1)
    perf_log.append({
        "round": r,
        "accuracy":  accuracy_score(y_te, y_pred),
        "precision": precision_score(y_te, y_pred, average='weighted', zero_division=0),
        "recall":    recall_score(y_te, y_pred, average='weighted', zero_division=0),
        "f1":        f1_score(y_te, y_pred, average='weighted', zero_division=0),
        "ms": round((time.time()-tic)*1000, 2)
    })
    comm_log.append({"round": r, "MB": bytes_out/2**20})
    trust_log.extend([{"round": r, "client": c, "trust": t}
                      for c, t in trust.items()])

    print(f"R{r:02d}  acc={perf_log[-1]['accuracy']:.8f}  "
          f"F1={perf_log[-1]['f1']:.8f}  MB={bytes_out/2**20:.8f}")



R01  acc=0.32323540  F1=0.39011901  MB=3.16658020




R02  acc=0.49320752  F1=0.58920530  MB=3.16658020




R03  acc=0.55691903  F1=0.65341459  MB=3.16658020




R04  acc=0.62104707  F1=0.70971309  MB=3.16658020




R05  acc=0.65632309  F1=0.73961237  MB=3.16658020




R06  acc=0.66293935  F1=0.74459868  MB=3.16658020




R07  acc=0.71244113  F1=0.78333886  MB=3.16658020




R08  acc=0.73315498  F1=0.79912736  MB=3.16658020




R09  acc=0.74481753  F1=0.80792125  MB=3.16658020




R10  acc=0.74994393  F1=0.81178887  MB=3.16658020




R11  acc=0.76516292  F1=0.82322311  MB=3.16658020




R12  acc=0.76702124  F1=0.82478728  MB=3.16658020




R13  acc=0.76923200  F1=0.82641149  MB=3.16658020




R14  acc=0.77118644  F1=0.82759952  MB=3.16658020




R15  acc=0.77748230  F1=0.83214626  MB=3.16658020




R16  acc=0.77250008  F1=0.82834608  MB=3.16658020




R17  acc=0.78517189  F1=0.83772127  MB=3.16658020




R18  acc=0.79023421  F1=0.84106114  MB=3.16658020




R19  acc=0.78595687  F1=0.83772543  MB=3.16658020




R20  acc=0.80269777  F1=0.85003157  MB=3.16658020




R21  acc=0.79481593  F1=0.84428354  MB=3.16658020




R22  acc=0.77450258  F1=0.82961637  MB=3.16658020




R23  acc=0.78642145  F1=0.83714371  MB=3.16658020




R24  acc=0.80208901  F1=0.84945121  MB=3.16658020




R25  acc=0.80460415  F1=0.85126536  MB=3.16658020




R26  acc=0.79923745  F1=0.84773333  MB=3.16658020




R27  acc=0.81194130  F1=0.85676235  MB=3.16658020




R28  acc=0.81211752  F1=0.85672493  MB=3.16658020




R29  acc=0.82124892  F1=0.86314364  MB=3.16658020




R30  acc=0.80047099  F1=0.84769457  MB=3.16658020




R31  acc=0.79100317  F1=0.84125076  MB=3.16658020




R32  acc=0.82134504  F1=0.86351966  MB=3.16658020




R33  acc=0.81761238  F1=0.86077377  MB=3.16658020




R34  acc=0.81035532  F1=0.85547806  MB=3.16658020




R35  acc=0.81093204  F1=0.85513989  MB=3.16658020




R36  acc=0.81601038  F1=0.85920964  MB=3.16658020




R37  acc=0.82238634  F1=0.86330180  MB=3.16658020




R38  acc=0.81940662  F1=0.86157629  MB=3.16658020




R39  acc=0.82507770  F1=0.86605185  MB=3.16658020




R40  acc=0.82214604  F1=0.86377421  MB=3.16658020




R41  acc=0.81083592  F1=0.85559984  MB=3.16658020




R42  acc=0.82472526  F1=0.86567500  MB=3.16658020




R43  acc=0.81830124  F1=0.85998625  MB=3.16658020




R44  acc=0.84145013  F1=0.87756805  MB=3.16658020




R45  acc=0.82257858  F1=0.86422176  MB=3.16658020




R46  acc=0.84188267  F1=0.87777497  MB=3.16658020




R47  acc=0.84830669  F1=0.88219912  MB=3.16658020




R48  acc=0.85384960  F1=0.88630604  MB=3.16658020




R49  acc=0.84221909  F1=0.87810320  MB=3.16658020




R50  acc=0.84122585  F1=0.87725167  MB=3.16658020




R51  acc=0.85537150  F1=0.88772825  MB=3.16658020




R52  acc=0.85301656  F1=0.88505483  MB=3.16658020




R53  acc=0.83847041  F1=0.87520961  MB=3.16658020




R54  acc=0.85125437  F1=0.88337087  MB=3.16658020




R55  acc=0.84468617  F1=0.88014215  MB=3.16658020




R56  acc=0.85239178  F1=0.88509090  MB=3.16658020




R57  acc=0.84982859  F1=0.88333981  MB=3.16658020




R58  acc=0.86253244  F1=0.89256384  MB=3.16658020




R59  acc=0.87930537  F1=0.90405046  MB=3.16658020




R60  acc=0.85364134  F1=0.88591251  MB=3.16658020




R61  acc=0.90975938  F1=0.92591326  MB=3.16658020




R62  acc=0.82744866  F1=0.86770740  MB=3.16658020




R63  acc=0.86845984  F1=0.89690757  MB=3.16658020




R64  acc=0.88899747  F1=0.91128422  MB=3.16658020




R65  acc=0.85461856  F1=0.88632742  MB=3.16658020




R66  acc=0.87489988  F1=0.90134164  MB=3.16658020




R67  acc=0.86115472  F1=0.89181934  MB=3.16658020




R68  acc=0.86173144  F1=0.89204314  MB=3.16658020




R69  acc=0.85439428  F1=0.88626000  MB=3.16658020




R70  acc=0.86564032  F1=0.89508088  MB=3.16658020




R71  acc=0.86126686  F1=0.89168588  MB=3.16658020




R72  acc=0.87989811  F1=0.90449114  MB=3.16658020




R73  acc=0.84156227  F1=0.87795799  MB=3.16658020




R74  acc=0.86365384  F1=0.89392028  MB=3.16658020




R75  acc=0.87350614  F1=0.90032818  MB=3.16658020


In [9]:
# ---------------- 7. SAVE LOGS -----------------------------------------
os.makedirs(LOG_DIR, exist_ok=True)
pd.DataFrame(perf_log ).to_csv(f"{LOG_DIR}/perf_log_WSN-BFSF_focal.csv",  index=False)
pd.DataFrame(comm_log ).to_csv(f"{LOG_DIR}/comm_log_WSN-BFSF_focal.csv", index=False)
pd.DataFrame(trust_log).to_csv(f"{LOG_DIR}/trust_log_WSN-BFSF_focal.csv", index=False)

profile = {
    "Params_MB": round(model_MB, 3),
    "Rounds": ROUNDS,
    "Clients": NUM_CLIENTS,
    "PeakMem_MB": round(psutil.Process(os.getpid()).memory_info().rss/2**20, 2)
}
pd.DataFrame([profile]).to_csv(f"{LOG_DIR}/model_profile_WSN-BFSF_focal.csv", index=False)
print("\n✓ Focal-loss + trust-cap run complete – logs saved to", LOG_DIR)


✓ Focal-loss + trust-cap run complete – logs saved to /mnt/data


In [13]:
# ---------------------------------------------------------
#  Confusion-matrix block  –  paste just before the script
#  finishes (after the last evaluation / logging section).
# ---------------------------------------------------------
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# --- 1.  Re-evaluate the final model ---------------------
#y_prob_final = global_model.predict(X_test, verbose=0)   # FedAvg script
y_prob_final = g_model.predict(X_te, verbose=0)         # TrustFed script
y_pred_final = np.argmax(y_prob_final, axis=1)
y_true_final = y_te                                    # Use y_te instead of y_test

# --- 2.  Confusion matrix (raw counts) ------------------
cm = confusion_matrix(y_true_final, y_pred_final)

# Optional: label list; replace with real class names if you have them
class_labels = [f"C{c}" for c in range(cm.shape[0])]

# --- 3.  Save the raw matrix for archival ---------------
# Define DATA_PATH as the filename of the dataset
DATA_PATH = "dataset.csv"
# Construct the cm_path using the current directory
cm_path = f"cm_{DATA_PATH.split('/')[-1].split('.')[0]}_FedAvg.csv" # Removed '/mnt/data/'
np.savetxt(cm_path, cm, delimiter=",", fmt="%d")
print("Confusion-matrix CSV written to", cm_path)

# --- 4.  Make a heat-map figure -------------------------
plt.figure(figsize=(4.5,4))
sns.heatmap(cm,
            annot=True, fmt='d', cmap='Blues', cbar=False,
            xticklabels=class_labels, yticklabels=class_labels,
            linewidths=.5, linecolor='grey')
plt.xlabel("Predicted label")
plt.ylabel("True label")
# Use a fixed string "FedAvg" for the title
plt.title("Confusion Matrix – FedAvg-IDS")
plt.tight_layout()

fig_path = cm_path.replace(".csv", ".png")
plt.savefig(fig_path, dpi=300)        # -> e.g. cm_WSN-DS_FedAvg.png
plt.close()
print("Figure saved to", fig_path)
# ---------------------------------------------------------

Confusion-matrix CSV written to cm_dataset_FedAvg.csv
Figure saved to cm_dataset_FedAvg.png
