In [None]:
import os
import sys
import random
import timeit
import wandb

import numpy as np
import pandas as pd
from sklearn.metrics import precision_recall_fscore_support
from tqdm import tqdm

import torch
from torch import nn
from torch import optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.io import read_image

# Locate the repository's `src` directory from the notebook's working directory
# so that the project-local `models.*` modules imported below can be found.
# normpath + os.sep splits on the native separator, so this works on both
# Windows and POSIX (a hard-coded split('\\') only works on Windows). Starting
# from cwd itself (rather than its parent) also allows launching from `src/`.
current_dir = os.path.normpath(os.getcwd()).split(os.sep)
if "src" not in current_dir:
    raise RuntimeError(
        f"Could not find a 'src' directory in the working directory path {os.getcwd()!r}; "
        "start the notebook from inside the repository's src/ tree."
    )

# Path up to and including the first 'src' component, joined with '/'
# (accepted on Windows too); on POSIX the leading '' keeps it absolute.
target_folder = "/".join(current_dir[:current_dir.index("src") + 1])
sys.path.append(os.path.abspath(target_folder))

from models.model_definition import PatchEmbedding, ViT
from models.dataset_definition import UnswNb15

In [None]:
RANDOM_SEED = 42
BATCH_SIZE = 512
EPOCHS = 5
LEARNING_RATE = 1e-3
PATCH_SIZE = 8
HEIGHT = 32
WIDTH = 64
IN_CHANNELS = 3
NUM_HEADS = 8
DROPOUT = 0.1
ADAM_WEIGHT_DECAY = 0
ADAM_BETAS = (0.9, 0.999)
ACTIVATION="gelu"
NUM_ENCODERS = 8
EMBED_DIM = (PATCH_SIZE ** 2) * IN_CHANNELS # (8**2)*3=192
NUM_PATCHES = (HEIGHT // PATCH_SIZE) * (WIDTH // PATCH_SIZE) # 4*8=32
NUM_CLASSES = 2

# NUM_PATCHES uses floor division, so an image size that is not a multiple of
# PATCH_SIZE would silently drop the remainder; fail fast instead.
assert HEIGHT % PATCH_SIZE == 0 and WIDTH % PATCH_SIZE == 0, (
    f"HEIGHT={HEIGHT} and WIDTH={WIDTH} must both be divisible by PATCH_SIZE={PATCH_SIZE}"
)

# Seed every RNG in play so runs are repeatable; the cuDNN flags select
# deterministic algorithms and disable autotuning, trading speed for repeatability.
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed_all(RANDOM_SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

# Checkpoints are written to this folder later in the notebook; create it up
# front because torch.save does not create missing directories.
save_folder = target_folder + "/models/saved/"
os.makedirs(save_folder, exist_ok=True)
print(save_folder)

In [None]:
# Smoke-test the patch embedding on a random batch. It gets its own name so that
# re-running this cell cannot clobber the ViT bound to `model`, which is the
# network the training cell optimizes.
patch_embedding = PatchEmbedding(EMBED_DIM, PATCH_SIZE, NUM_PATCHES, DROPOUT, IN_CHANNELS).to(device)
dummy_images = torch.randn(BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH).to(device)
with torch.no_grad():  # shape check only; no autograd graph needed
    print(patch_embedding(dummy_images).shape)

In [None]:
# Build the ViT that the training cell optimizes, and check that a random batch
# maps to one logit vector per sample.
model = ViT(NUM_PATCHES, NUM_CLASSES, PATCH_SIZE, EMBED_DIM, NUM_ENCODERS, NUM_HEADS, DROPOUT, ACTIVATION, IN_CHANNELS).to(device)
dummy_images = torch.randn(BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH).to(device)
with torch.no_grad():  # shape check only; no autograd graph needed
    logits = model(dummy_images)
print(logits.shape) # BATCH_SIZE X NUM_CLASSES
assert logits.shape == (BATCH_SIZE, NUM_CLASSES), f"unexpected ViT output shape {tuple(logits.shape)}"

In [None]:
# Display the ViT's module tree (its repr) as this cell's output.
model

In [None]:
# Training/validation pool with binary labels; the mapping file presumably
# selects which images are used (see UnswNb15 to confirm).
dataset = UnswNb15(mapping_file_name="unswnb15_img_selection.csv", binary=True)
n_samples = len(dataset)
print(n_samples)

In [None]:
# Show the dataset's class list (an attribute provided by UnswNb15).
# NOTE(review): presumably its order matches the one-hot label columns used
# below (and the confusion-matrix display labels) — confirm.
dataset.classes_list

In [None]:
# 80/20 train/validation split. An explicitly seeded generator makes the split
# reproducible regardless of how much global RNG state earlier cells (model
# construction, dummy batches) consumed.
train_split = int(0.8 * len(dataset))
split_generator = torch.Generator().manual_seed(RANDOM_SEED)
train, val = random_split(dataset, [train_split, len(dataset) - train_split], generator=split_generator)

train_dataloader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
# A fixed validation order keeps validation metrics deterministic across epochs.
val_dataloader = DataLoader(val, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Report the number of samples in each split: training first, then validation.
for subset in (train, val):
    print(len(subset))

In [None]:
# Peek at one training batch to confirm tensor shapes. Distinct names avoid
# confusion with the per-epoch `train_labels` accumulator in the training cell.
sample_train_features, sample_train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {sample_train_features.size()}")
print(f"Labels batch shape: {sample_train_labels.size()}")

In [None]:
# Peek at one validation batch to confirm tensor shapes. Distinct names avoid
# confusion with the per-epoch `val_labels` accumulator in the training cell.
sample_val_features, sample_val_labels = next(iter(val_dataloader))
print(f"Feature batch shape: {sample_val_features.size()}")
print(f"Labels batch shape: {sample_val_labels.size()}")

In [None]:
def precision_recall_f1(predictions, labels, average="macro"):
    """Compute precision, recall and F1 of class predictions against one-hot labels.

    Args:
        predictions: iterable of predicted class indices (ints, or 0-d tensors
            such as the elements of an ``argmax`` result).
        labels: iterable of one-hot label rows (e.g. a 2-D tensor or a list of
            1-D tensors); the true class is the position of the 1.0 entry.
        average: averaging mode forwarded to sklearn's
            ``precision_recall_fscore_support`` (default ``"macro"``).

    Returns:
        Tuple ``(precision, recall, f1)`` as returned by sklearn for ``average``.

    Raises:
        ValueError: if a label row has no 1.0 entry, or if ``predictions`` and
            ``labels`` differ in length (``zip`` would silently truncate).
    """
    y_pred = [int(p) for p in predictions]
    y_true = [list(row).index(1.0) for row in labels]
    if len(y_pred) != len(y_true):
        raise ValueError(
            f"got {len(y_pred)} predictions but {len(y_true)} labels; they must match"
        )

    # zero_division=0 gives the same score sklearn's default falls back to for a
    # class that is never predicted, without an UndefinedMetricWarning each epoch.
    p, r, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average=average, zero_division=0
    )
    return p, r, f1

# Sanity check on a hand-worked example: true classes [0, 1, 1, 0, 0] vs
# predictions [0, 1, 0, 0, 1]. Class 0 scores 2/3 and class 1 scores 1/2 on
# precision, recall and F1 alike, so every macro average is 7/12.
predictions = torch.Tensor(np.array([0, 1, 0, 0, 1]))
labels = torch.Tensor(np.array([[1, 0], [0, 1], [0, 1], [1, 0], [1, 0]]))
p, r, f1 = precision_recall_f1(predictions, labels)
print(f"Precision: {p}")
print(f"Recall: {r}")
print(f"F1 score: {f1}")
assert np.allclose([p, r, f1], 7 / 12), "precision_recall_f1 disagrees with the hand-computed macro scores"

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), betas=ADAM_BETAS, lr=LEARNING_RATE, weight_decay=ADAM_WEIGHT_DECAY)

run = wandb.init(
    project = "DP",
    config={
        "learning_rate": LEARNING_RATE,
        "architecture": "ViT",
        "dataset": "UNSW-NB15-payload-exp",
        "epochs": EPOCHS,
        # Remaining hyper-parameters, so a run can be reproduced from its W&B config.
        "batch_size": BATCH_SIZE,
        "patch_size": PATCH_SIZE,
        "embed_dim": EMBED_DIM,
        "num_encoders": NUM_ENCODERS,
        "num_heads": NUM_HEADS,
        "dropout": DROPOUT,
        "activation": ACTIVATION,
        "adam_betas": ADAM_BETAS,
        "weight_decay": ADAM_WEIGHT_DECAY,
        "seed": RANDOM_SEED,
    }
)


def _run_epoch(loader, train):
    """Run one pass of the global `model` over `loader`.

    With ``train=True`` the model is put in train mode and `optimizer` is
    stepped after every batch; otherwise the model is put in eval mode and
    gradients are disabled.

    Returns:
        ``(mean_loss, preds, labels)``. ``mean_loss`` is the per-sample mean of
        `criterion` over the whole loader, with each batch weighted by its size
        so that a short final batch is not over-weighted. ``preds`` is a 1-D CPU
        tensor of argmax class indices. ``labels`` is the concatenated CPU tensor
        of (float) one-hot targets.

    Raises:
        ValueError: if the loader yields no samples (the mean would be undefined).
    """
    model.train(train)  # train(False) is equivalent to eval()
    running_loss, n_seen = 0.0, 0
    all_preds, all_labels = [], []
    with torch.set_grad_enabled(train):
        for img, label in tqdm(loader, position=0, leave=True):
            img = img.float().to(device)
            label = label.float().to(device)
            logits = model(img)
            loss = criterion(logits, label)

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_size = img.size(0)
            running_loss += loss.item() * batch_size
            n_seen += batch_size
            # logits were computed before the optimizer step, so these are the
            # predictions the loss was measured on.
            all_preds.append(logits.argmax(dim=1).detach().cpu())
            all_labels.append(label.detach().cpu())
    if n_seen == 0:
        raise ValueError("DataLoader yielded no samples; cannot compute epoch metrics")
    return running_loss / n_seen, torch.cat(all_preds), torch.cat(all_labels)


def _accuracy(preds, labels):
    """Fraction of samples whose predicted index equals the hot index of its one-hot label row."""
    return (preds == labels.argmax(dim=1)).float().mean().item()


start = timeit.default_timer()
for epoch in tqdm(range(EPOCHS), position=0, leave=True):
    train_loss, train_preds, train_labels = _run_epoch(train_dataloader, train=True)
    val_loss, val_preds, val_labels = _run_epoch(val_dataloader, train=False)

    train_accuracy = _accuracy(train_preds, train_labels)
    val_accuracy = _accuracy(val_preds, val_labels)
    # "precision"/"recall"/"f1 score" keep their original meaning (training set)
    # so existing W&B charts stay comparable; validation scores are logged alongside.
    precision, recall, f1score = precision_recall_f1(train_preds, train_labels)
    val_precision, val_recall, val_f1score = precision_recall_f1(val_preds, val_labels)

    print("-"*30)
    print(f"Train Loss EPOCH {epoch+1}: {train_loss:.4f}")
    print(f"Valid Loss EPOCH {epoch+1}: {val_loss:.4f}")
    print(f"Train Accuracy EPOCH {epoch+1}: {train_accuracy:.4f}")
    print(f"Valid Accuracy EPOCH {epoch+1}: {val_accuracy:.4f}")
    print(f"Train Precision: {precision}, Recall: {recall}, F1 score: {f1score}")
    print(f"Valid Precision: {val_precision}, Recall: {val_recall}, F1 score: {val_f1score}")
    print("-"*30)
    wandb.log(
        {
            "epoch": epoch,
            "train_acc": train_accuracy,
            "train_loss": train_loss,
            "val_acc": val_accuracy,
            "val_loss": val_loss,
            "precision": precision,
            "recall": recall,
            "f1 score": f1score,
            "val_precision": val_precision,
            "val_recall": val_recall,
            "val_f1 score": val_f1score,
        }
    )


stop = timeit.default_timer()
print(f"Training Time: {stop-start:.2f}s")

In [None]:
# Save the trained weights and log them as a W&B artifact for version control.
checkpoint_name = 'model_unsw_payload_binary_exp'
checkpoint_path = save_folder + checkpoint_name
os.makedirs(save_folder, exist_ok=True)  # torch.save does not create missing directories
torch.save(model.state_dict(), checkpoint_path)
artifact = wandb.Artifact(checkpoint_name, type='model')
artifact.add_file(checkpoint_path)
run.log_artifact(artifact)
run.finish()

In [None]:
# Rebuild the ViT with the notebook's hyper-parameters and load a checkpoint from disk.
# NOTE(review): this loads 'model_unsw_payload_binary_v2', while the save cell
# above writes 'model_unsw_payload_binary_exp' — confirm which checkpoint is intended.
saved_model = ViT(NUM_PATCHES, NUM_CLASSES, PATCH_SIZE, EMBED_DIM, NUM_ENCODERS, NUM_HEADS, DROPOUT, ACTIVATION, IN_CHANNELS)
# map_location lets a checkpoint saved from a GPU run load on a CPU-only machine.
saved_model.load_state_dict(torch.load(save_folder+'model_unsw_payload_binary_v2', map_location=device))
saved_model.to(device)
saved_model.eval()

In [None]:
# Evaluation set with binary labels; per its file name, the mapping presumably
# selects "new attack" images (confirm in UnswNb15 / the CSV).
# NOTE(review): this rebinds `dataset`, so re-running the train/val split cell
# after this one would split the evaluation set instead.
dataset = UnswNb15(mapping_file_name="unswnb15_img_new_attacks.csv", binary=True)
n_samples = len(dataset)
print(n_samples)

In [None]:
# Evaluation order does not affect the reported metrics, so a fixed order is used
# (deterministic, and no permutation drawn from the global RNG).
test_dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Evaluate the checkpoint loaded into `saved_model` (not the in-memory `model`
# from training), so the reported numbers belong to that checkpoint.
saved_model.eval()  # idempotent; guarantees dropout is off even if the cell is re-run
test_labels = []
test_preds = []
with torch.no_grad():
    for img, label in tqdm(test_dataloader, position=0, leave=True):
        img = img.float().to(device)
        label = label.float().to(device)
        y_pred = saved_model(img)
        y_pred_label = torch.argmax(y_pred, dim=1)

        test_labels.extend(label.cpu().detach())
        test_preds.extend(y_pred_label.cpu().detach())

if not test_labels:
    raise ValueError("test_dataloader yielded no samples; nothing to evaluate")
# Each label is a one-hot row; its true class is the index of the 1.0 entry.
test_accuracy = sum(1 for x,y in zip(test_preds, test_labels) if x == list(y).index(1.0)) / len(test_labels)
print(f"Test Accuracy: {test_accuracy:.4f}")
t_precision, t_recall, t_f1score = precision_recall_f1(test_preds, test_labels)
print(f"Precision: {t_precision}, Recall: {t_recall}, F1 score: {t_f1score}")
print("-"*30)

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Confusion matrix for the evaluation set. labels= pins the matrix to every class
# index, so it stays NUM_CLASSES x NUM_CLASSES even when a class never appears in
# y_true or y_pred (otherwise it would shrink and mismatch the two display_labels).
# NOTE(review): display_labels assumes index 0 = "Anomaly", 1 = "Normal";
# confirm against dataset.classes_list.
y_true = [list(y).index(1.0) for y in test_labels]
cm = confusion_matrix(y_true, test_preds, labels=list(range(NUM_CLASSES)))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Anomaly", "Normal"])
fig, ax = plt.subplots()
disp.plot(cmap="Blues", ax=ax)
ax.set_title("Confusion Matrix")
plt.show()