In [None]:
import random
import timeit
import wandb

import numpy as np
from tqdm import tqdm

import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

from model_definition import ViT
from dataset_definition import CicIds2017
from model_utils import precision_recall_f1

In [None]:
# ---- Run configuration --------------------------------------------------
# When True, a later cell restores model/optimizer state, the epoch counter and
# the W&B run id from "saved/<MODEL_NAME>" instead of starting from scratch.
CHECKPOINT = True
# File name of the checkpoint under "saved/" and "best/"; also the W&B artifact name.
MODEL_NAME = "model_cic_payload_bin_serial_bi_dir_exp_4"

# Number of epochs to run in this session (added on top of the checkpoint's epoch).
EPOCHS = 1
RANDOM_SEED = 42
BATCH_SIZE = 64
LEARNING_RATE = 1e-5
PATCH_SIZE = 8
HEIGHT = 128
WIDTH = 128
IN_CHANNELS = 3
NUM_HEADS = 16
DROPOUT = 0.1
ADAM_WEIGHT_DECAY = 0
ADAM_BETAS = (0.9, 0.999)
ACTIVATION="gelu"
NUM_ENCODERS = 24
EMBED_DIM = (PATCH_SIZE ** 2) * IN_CHANNELS # (8**2)*3=192
NUM_PATCHES = (HEIGHT // PATCH_SIZE) * (WIDTH // PATCH_SIZE) # (128//8)*(128//8) = 16*16 = 256
NUM_CLASSES = 2

# Seed every RNG in use (Python, NumPy, PyTorch CPU and CUDA) for repeatable runs.
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed_all(RANDOM_SEED)
# Request deterministic cuDNN kernels and turn off cuDNN benchmarking.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Use the GPU when one is available; the bare expression shows the choice as cell output.
device = "cuda" if torch.cuda.is_available() else "cpu"
device

In [None]:
# Build the ViT from the configuration constants and move it to the selected device.
model = ViT(NUM_PATCHES, NUM_CLASSES, PATCH_SIZE, EMBED_DIM, NUM_ENCODERS, NUM_HEADS, DROPOUT, ACTIVATION, IN_CHANNELS).to(device)

# Smoke test: push one random batch through the network and print the output shape.
# The pass runs under no_grad because its result is discarded; there is no reason
# to record an autograd graph (and keep its activations) just for a shape check.
x = torch.randn(BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH).to(device)
with torch.no_grad():
    print(model(x).shape) # BATCH_SIZE X NUM_CLASSES

In [None]:
# Count the learnable scalars: every parameter tensor that has requires_grad set.
trainable_params = sum(
    param.numel()
    for param in model.parameters()
    if param.requires_grad
)
# Cell output: the module itself together with the parameter count.
model, trainable_params

In [None]:
# Load the training/validation dataset.
# The paths use raw strings: "\c" and "\i" are not valid escape sequences, so the
# previous plain literals triggered an invalid-escape warning on newer Python
# versions. The raw strings have exactly the same runtime value (backslash kept).
dataset = CicIds2017(
    mapping_file_name=r"clean\cicids2017_img_bi_dir_selection.csv",
    image_folder_name=r"clean\image_bi_dir",
    binary=True,
    hdf5=True,
)
# Number of samples and number of classes reported by the dataset.
print(len(dataset), len(dataset.classes_list))

In [None]:
# 90/10 random split. Note that val_split holds the size of the *training* part.
# The split draws from the global torch RNG, so it depends on the seed set earlier
# and on the cells having run in order.
val_split = int(0.9 * len(dataset))
split_lengths = [val_split, len(dataset) - val_split]
train, val = random_split(dataset, split_lengths)
print(len(train), len(val), sep="\n")

# Both loaders reshuffle on every pass over the data.
loader_options = {"batch_size": BATCH_SIZE, "shuffle": True}
train_dataloader = DataLoader(train, **loader_options)
val_dataloader = DataLoader(val, **loader_options)

In [None]:
# Loss function and optimizer for the model built above.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), betas=ADAM_BETAS, lr=LEARNING_RATE, weight_decay=ADAM_WEIGHT_DECAY)

if CHECKPOINT:
    # Resume: restore weights, optimizer state, the epoch counter and the W&B run id.
    # map_location remaps the stored tensors onto the current device, so a checkpoint
    # written on a GPU machine can still be loaded on a CPU-only one.
    # NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
    checkpoint = torch.load("saved/" + MODEL_NAME, map_location=device)
    # Index the required entries so a malformed checkpoint fails here with a clear
    # KeyError rather than later with an unrelated error on a None value.
    starting_epoch = checkpoint["epoch"]
    run_id = checkpoint.get("run_id")
    model.load_state_dict(checkpoint["model_state"])
    optimizer.load_state_dict(checkpoint["optimizer_state"])
else:
    # Fresh run: start counting epochs at zero and create a new W&B run id.
    starting_epoch = 0
    run_id = wandb.util.generate_id()

In [None]:
# Metadata stored with the W&B run.
wandb_config = {
    "learning_rate": LEARNING_RATE,
    "architecture": "ViT",
    "dataset": "CIC-IDS-2017-payload-exp-new",
    "epochs": EPOCHS,
}

# run_id comes from the checkpoint when resuming, otherwise it was freshly generated;
# presumably resume="allow" lets a resumed session keep logging to the same run.
run = wandb.init(project="DP", config=wandb_config, id=run_id, resume="allow")

In [None]:
# Train for EPOCHS epochs starting at starting_epoch. After every epoch the model is
# validated, metrics are printed, a checkpoint is written and the metrics go to W&B.
start = timeit.default_timer()
for epoch in tqdm(range(starting_epoch, starting_epoch+EPOCHS), position=0, leave=True):
    # ---- training pass ----
    model.train()
    train_labels = []
    train_preds = []
    train_running_loss = 0
    for idx, (img, label) in enumerate(tqdm(train_dataloader, position=0, leave=True)):
        img = img.float().to(device)
        # Labels are passed to the loss as float tensors; the accuracy code below
        # treats each one as a vector containing a single 1.0 entry.
        label = label.float().to(device)
        y_pred = model(img)
        y_pred_label = torch.argmax(y_pred, dim=1)

        # Keep per-sample targets/predictions on the CPU for the epoch-level metrics.
        train_labels.extend(label.cpu().detach())
        train_preds.extend(y_pred_label.cpu().detach())

        loss = criterion(y_pred, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_running_loss += loss.item()

    # Mean of the per-batch losses (idx is the index of the last batch).
    train_loss = train_running_loss / (idx + 1)

    # ---- validation pass ----
    model.eval()
    val_labels = []
    val_preds = []
    val_running_loss = 0
    with torch.no_grad():
        for idx, (img, label) in enumerate(tqdm(val_dataloader, position=0, leave=True)):
            img = img.float().to(device)
            label = label.float().to(device)
            y_pred = model(img)
            y_pred_label = torch.argmax(y_pred, dim=1)

            val_labels.extend(label.cpu().detach())
            val_preds.extend(y_pred_label.cpu().detach())

            loss = criterion(y_pred, label)
            val_running_loss += loss.item()
    val_loss = val_running_loss / (idx + 1)

    # ---- epoch report ----
    print("-"*30)
    print(f"Train Loss EPOCH {epoch+1}: {train_loss:.4f}")
    print(f"Valid Loss EPOCH {epoch+1}: {val_loss:.4f}")
    # A prediction is correct when it equals the position of the 1.0 entry in its label.
    train_accuracy = sum(1 for x,y in zip(train_preds, train_labels) if x == list(y).index(1.0)) / len(train_labels)
    print(f"Train Accuracy EPOCH {epoch+1}: {train_accuracy:.4f}")
    val_accuracy = sum(1 for x,y in zip(val_preds, val_labels) if x == list(y).index(1.0)) / len(val_labels)
    print(f"Valid Accuracy EPOCH {epoch+1}: {val_accuracy:.4f}")
    # NOTE(review): precision/recall/F1 are computed on the *training* predictions;
    # confirm this is intended rather than the validation predictions.
    precision, recall, f1score = precision_recall_f1(train_preds, train_labels)
    print(f"Precision: {precision}, Recall: {recall}, F1 score: {f1score}")
    print("-"*30)

    # ---- checkpoint ----
    # Store the number of epochs completed so far (epoch is 0-based). Writing
    # starting_epoch+EPOCHS here would make a run interrupted mid-session look
    # fully trained, and a resume would then skip the missing epochs.
    torch.save(
        {
            "epoch": epoch + 1,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "run_id": run_id,
        },
        "saved/" + MODEL_NAME
    )

    # ---- experiment tracking ----
    wandb.log(
        {
            "epoch": epoch,
            "train_acc": train_accuracy,
            "train_loss": train_loss,
            "val_acc": val_accuracy,
            "val_loss": val_loss,
            "precision": precision,
            "recall": recall,
            "f1 score": f1score
        }
    )

stop = timeit.default_timer()
print(f"Training Time: {stop-start:.2f}s")

In [None]:
# Score the in-memory model on the validation split.
# Switch to eval mode explicitly so this cell does not depend on whichever mode an
# earlier cell left the model in (e.g. when it is run without the training loop).
model.eval()
val_labels = []
val_preds = []
with torch.no_grad():
    for idx, (img, label) in enumerate(tqdm(val_dataloader, position=0, leave=True)):
        img = img.float().to(device)
        label = label.float().to(device)
        y_pred = model(img)
        y_pred_label = torch.argmax(y_pred, dim=1)

        # Collect per-sample targets/predictions on the CPU for the metrics below.
        val_labels.extend(label.cpu().detach())
        val_preds.extend(y_pred_label.cpu().detach())

# A prediction is correct when it equals the position of the 1.0 entry in its label.
test_accuracy = sum(1 for x,y in zip(val_preds, val_labels) if x == list(y).index(1.0)) / len(val_labels)
print(f"Val Accuracy: {test_accuracy:.4f}")
t_precision, t_recall, t_f1score = precision_recall_f1(val_preds, val_labels)
print(f"Precision: {t_precision}, Recall: {t_recall}, F1 score: {t_f1score}")
print("-"*30)

In [None]:
# Attach the checkpoint file under "saved/" to the W&B run as a model artifact,
# then close the run.
checkpoint_file = "saved/" + MODEL_NAME
artifact = wandb.Artifact(MODEL_NAME, type="model")
artifact.add_file(checkpoint_file)
run.log_artifact(artifact)
run.finish()

In [None]:
# Confusion matrix for the most recent validation predictions.
# True class = position of the 1.0 entry in each label; predictions become plain ints.
true_classes = [list(y).index(1.0) for y in val_labels]
pred_classes = [int(p) for p in val_preds]
# labels=[0, 1] pins the matrix to 2x2 even when one class is missing from both
# lists; otherwise the two display labels below would no longer match its shape.
cm = confusion_matrix(true_classes, pred_classes, labels=[0, 1])
# Assumes class index 0 is "Anomaly" and 1 is "Normal" - confirm against dataset.classes_list.
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Anomaly", "Normal"])
fig, ax = plt.subplots(figsize=(8, 8))
disp.plot(cmap="Blues", ax=ax)
plt.title("Confusion Matrix")
plt.show()

In [None]:
# Rebuild the architecture and load the weights stored under "best/".
saved_model = ViT(NUM_PATCHES, NUM_CLASSES, PATCH_SIZE, EMBED_DIM, NUM_ENCODERS, NUM_HEADS, DROPOUT, ACTIVATION, IN_CHANNELS)
# map_location remaps the stored tensors onto the current device, so a checkpoint
# written on a GPU machine can still be loaded on a CPU-only one.
# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
checkpoint = torch.load("best/" + MODEL_NAME, map_location=device)
saved_model.load_state_dict(checkpoint.get("model_state"))
# Alternative kept for reference (presumably for a checkpoint stored as a bare state_dict):
# saved_model.load_state_dict(torch.load("best/" + "model_cic_payload_bin_serial_bi_dir_exp_2"))
saved_model.to(device)
# Inference mode for the evaluation cells that follow; the return value is the cell output.
saved_model.eval()

In [None]:
# Re-score the validation split, this time with the weights loaded into saved_model.
val_labels = []
val_preds = []
with torch.no_grad():
    for batch_images, batch_targets in tqdm(val_dataloader, position=0, leave=True):
        batch_images = batch_images.float().to(device)
        batch_targets = batch_targets.float().to(device)
        logits = saved_model(batch_images)
        predicted_classes = torch.argmax(logits, dim=1)

        # Per-sample tensors go back to the CPU for the metric computations.
        val_labels.extend(batch_targets.cpu().detach())
        val_preds.extend(predicted_classes.cpu().detach())

# A prediction is a hit when it equals the position of the 1.0 entry in its label.
hit_count = 0
for predicted, target in zip(val_preds, val_labels):
    if predicted == list(target).index(1.0):
        hit_count += 1
test_accuracy = hit_count / len(val_labels)
print(f"Val Accuracy: {test_accuracy:.4f}")

t_precision, t_recall, t_f1score = precision_recall_f1(val_preds, val_labels)
print(f"Precision: {t_precision}, Recall: {t_recall}, F1 score: {t_f1score}")
print("-"*30)

In [None]:
# Confusion matrix for the saved model's validation predictions.
# True class = position of the 1.0 entry in each label; predictions become plain ints.
true_classes = [list(y).index(1.0) for y in val_labels]
pred_classes = [int(p) for p in val_preds]
# labels=[0, 1] pins the matrix to 2x2 even when one class is missing from both
# lists; otherwise the two display labels below would no longer match its shape.
cm = confusion_matrix(true_classes, pred_classes, labels=[0, 1])
# Assumes class index 0 is "Anomaly" and 1 is "Normal" - confirm against dataset.classes_list.
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Anomaly", "Normal"])
disp.plot(cmap="Blues")
plt.title("Confusion Matrix")
plt.show()

In [None]:
# Rebuild the architecture and load the latest training checkpoint from "saved/".
saved_model = ViT(NUM_PATCHES, NUM_CLASSES, PATCH_SIZE, EMBED_DIM, NUM_ENCODERS, NUM_HEADS, DROPOUT, ACTIVATION, IN_CHANNELS)
# map_location remaps the stored tensors onto the current device, so a checkpoint
# written on a GPU machine can still be loaded on a CPU-only one.
# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
checkpoint = torch.load("saved/" + MODEL_NAME, map_location=device)
saved_model.load_state_dict(checkpoint.get("model_state"))
# Alternative kept for reference (presumably for a checkpoint stored as a bare state_dict):
# saved_model.load_state_dict(torch.load("best/" + "model_cic_payload_bin_serial_bi_dir_exp_2"))
saved_model.to(device)
# Inference mode for the evaluation cells that follow; the return value is the cell output.
saved_model.eval()

In [None]:
# Load the separate test dataset from the "new_attacks" mapping file.
# The paths use raw strings: "\c" and "\i" are not valid escape sequences, so the
# previous plain literals triggered an invalid-escape warning on newer Python
# versions. The raw strings have exactly the same runtime value (backslash kept).
test_set = CicIds2017(
    binary=True,
    image_folder_name=r"clean\image_bi_dir",
    mapping_file_name=r"clean\cicids2017_img_bi_dir_new_attacks.csv",
    hdf5=True,
)
test_dataloader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Number of samples in test_set (shown as the cell output).
len(test_set)

In [None]:
# Evaluate saved_model on the test loader and report accuracy, precision, recall and F1.
test_labels = []
test_preds = []
with torch.no_grad():
    for images, targets in tqdm(test_dataloader, position=0, leave=True):
        images = images.float().to(device)
        targets = targets.float().to(device)
        scores = saved_model(images)
        guesses = torch.argmax(scores, dim=1)

        # Per-sample tensors go back to the CPU for the metric computations.
        test_labels.extend(targets.cpu().detach())
        test_preds.extend(guesses.cpu().detach())

# A prediction is a match when it equals the position of the 1.0 entry in its label.
matches = 0
for guess, target in zip(test_preds, test_labels):
    if guess == list(target).index(1.0):
        matches += 1
test_accuracy = matches / len(test_labels)
print(f"Test Accuracy: {test_accuracy:.4f}")

t_precision, t_recall, t_f1score = precision_recall_f1(test_preds, test_labels)
print(f"Precision: {t_precision}, Recall: {t_recall}, F1 score: {t_f1score}")
print("-"*30)

In [None]:
# Confusion matrix for the test-set predictions.
# (plt, confusion_matrix and ConfusionMatrixDisplay are already imported in the
# notebook's first cell, so they are not re-imported here.)
# True class = position of the 1.0 entry in each label; predictions become plain ints.
true_classes = [list(y).index(1.0) for y in test_labels]
pred_classes = [int(p) for p in test_preds]
# labels=[0, 1] pins the matrix to 2x2 even when one class is missing from both
# lists; otherwise the two display labels below would no longer match its shape.
cm = confusion_matrix(true_classes, pred_classes, labels=[0, 1])
# Assumes class index 0 is "Anomaly" and 1 is "Normal" - confirm against the dataset's class list.
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Anomaly", "Normal"])
disp.plot(cmap="Blues")
plt.title("Confusion Matrix")
plt.show()