In [None]:
import os
import random
import wandb

import numpy as np
import pandas as pd
from sklearn.metrics import precision_recall_fscore_support
from tqdm import tqdm

import torch
from torch import nn
from torch import optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.io import read_image

In [None]:
# Notebook-wide configuration: hyper-parameters, RNG seeding and device selection.
# Seed passed to every random number generator seeded below.
RANDOM_SEED = 42

# Input params
HEIGHT = 64
WIDTH = 128
IN_CHANNELS = 3
NUM_CLASSES = 10

# Embedding params
PATCH_SIZE = 16
EMBED_DIM = (PATCH_SIZE ** 2) * IN_CHANNELS  # 16 * 16 * 3 = 768
NUM_PATCHES = (HEIGHT // PATCH_SIZE) * (WIDTH // PATCH_SIZE)  # 4 * 8 = 32
DROPOUT = 0.01

# Model params
NUM_HEADS = 8
NUM_ENCODERS = 4
ACTIVATION="gelu"

# Optim params
ADAM_BETAS = (0.9, 0.999)
ADAM_WEIGHT_DECAY = 0
LEARNING_RATE = 1e-2

# Train loop params
EPOCHS = 30
BATCH_SIZE = 512

# Seed Python, NumPy and PyTorch (CPU and CUDA) with the same value.
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed_all(RANDOM_SEED)
# Ask cuDNN for deterministic kernels and turn off its autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Use the GPU when one is available; the bare name below is shown as the cell output.
device = "cuda" if torch.cuda.is_available() else "cpu"
device

In [None]:
class PatchEmbedding(nn.Module):
    """Turn an image batch into a sequence of patch embeddings plus a class token.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size`` patches
    by a strided convolution, each patch is projected to ``embed_dim`` features,
    a single learnable class token is prepended, learnable position embeddings
    are added and dropout is applied.

    Args:
        embed_dim: number of features per token.
        patch_size: side length of a square patch (conv kernel size and stride).
        num_patches: number of patches the patcher produces per image.
        dropout: dropout probability applied to the final token sequence.
        in_channels: number of channels of the input images.

    Note:
        The sequence holds exactly one class token regardless of ``in_channels``.
        Earlier revisions sized the class token by ``in_channels``, so state
        dicts saved by them have differently shaped ``cls_token`` and
        ``position_embeddings`` tensors and cannot be loaded here.
    """

    def __init__(self, embed_dim, patch_size, num_patches, dropout, in_channels):
        super().__init__()
        self.patcher = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=embed_dim,
                kernel_size=patch_size,
                stride=patch_size,
            ),
            # Collapse the spatial grid into one patch axis: (B, embed_dim, num_patches).
            nn.Flatten(2))

        # One class token per sequence; its count is unrelated to the channel count.
        self.cls_token = nn.Parameter(torch.randn(size=(1, 1, embed_dim)), requires_grad=True)
        # One position embedding for the class token plus one per patch.
        self.position_embeddings = nn.Parameter(torch.randn(size=(1, num_patches + 1, embed_dim)), requires_grad=True)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Embed ``x`` of shape (B, in_channels, H, W) into (B, num_patches + 1, embed_dim)."""
        # Share the single learnable token across the batch without copying it.
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)

        # (B, embed_dim, num_patches) -> (B, num_patches, embed_dim)
        x = self.patcher(x).permute(0, 2, 1)
        x = torch.cat([cls_token, x], dim=1)
        x = self.position_embeddings + x
        x = self.dropout(x)
        return x

In [None]:
# Smoke test: push one random batch through the embedding block and print the output shape.
model = PatchEmbedding(
    embed_dim=EMBED_DIM,
    patch_size=PATCH_SIZE,
    num_patches=NUM_PATCHES,
    dropout=DROPOUT,
    in_channels=IN_CHANNELS,
).to(device)
x = torch.randn((BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH)).to(device)
print(model(x).shape)

In [None]:
class ViT(nn.Module):
    """Vision Transformer classifier: patch embedding, encoder stack, linear head.

    Args:
        num_patches: number of patches per image, forwarded to the embedding block.
        num_classes: number of output logits.
        patch_size: side length of a square patch.
        embed_dim: token feature size (``d_model`` of the encoder).
        num_encoders: number of stacked encoder layers.
        num_heads: attention heads per encoder layer.
        dropout: dropout probability for the embedding block and encoder layers.
        activation: activation name passed to ``nn.TransformerEncoderLayer``.
        in_channels: number of channels of the input images.
    """

    def __init__(self, num_patches, num_classes, patch_size, embed_dim, num_encoders, num_heads, dropout, activation, in_channels):
        super().__init__()
        self.embeddings_block = PatchEmbedding(
            embed_dim=embed_dim,
            patch_size=patch_size,
            num_patches=num_patches,
            dropout=dropout,
            in_channels=in_channels,
        )

        # Pre-norm, batch-first encoder layer used as the template for the stack.
        layer_template = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dropout=dropout,
            activation=activation,
            batch_first=True,
            norm_first=True,
        )
        self.encoder_blocks = nn.TransformerEncoder(layer_template, num_layers=num_encoders)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for an image batch."""
        tokens = self.embeddings_block(x)
        encoded = self.encoder_blocks(tokens)
        # Only the token at sequence position 0 (the class token) feeds the head.
        first_token = encoded[:, 0, :]
        return self.mlp_head(first_token)

In [None]:
# Build the classifier that the later cells train, then sanity-check its output shape.
model = ViT(
    num_patches=NUM_PATCHES,
    num_classes=NUM_CLASSES,
    patch_size=PATCH_SIZE,
    embed_dim=EMBED_DIM,
    num_encoders=NUM_ENCODERS,
    num_heads=NUM_HEADS,
    dropout=DROPOUT,
    activation=ACTIVATION,
    in_channels=IN_CHANNELS,
).to(device)
x = torch.randn((BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH)).to(device)
print(model(x).shape)  # BATCH_SIZE X NUM_CLASSES

In [None]:
# Show the repr of `model` as the cell output.
model

In [None]:
class UNSW_NB15(Dataset):
    """Image dataset backed by a CSV that maps image file names to class labels.

    The CSV at ``BASE_PATH + MAPPING_FILE`` is read with pandas and its ``label``
    column is one-hot encoded. The first column is used as the image file name,
    resolved inside ``BASE_PATH + FOLDER``. Every remaining column is treated as
    one class indicator.

    Args:
        shuffle: when true, the mapping rows are shuffled once at construction.
    """

    # Raw strings keep the Windows backslashes literal. The values are the same
    # as before but no longer rely on invalid escape sequences such as "\V" or "\i".
    BASE_PATH = r"C:\VScode Projects\FIIT_MASTERS\DP\datasets\UNSW_NB15"
    MAPPING_FILE = r"\unswnb15_img_serialized_5_non_shuffled.csv"
    FOLDER = r"\image_serialized_5_non_shuffled"
    index: int
    batch_size: int
    classes_count: int
    classes_list: list

    def __init__(self, shuffle: bool = False):
        mapping_frame = pd.read_csv(self.BASE_PATH + self.MAPPING_FILE)
        mapping_frame = pd.get_dummies(mapping_frame, columns=['label'])

        if shuffle:
            mapping_frame = mapping_frame.sample(frac=1)  # shuffle

        # Dummy columns are named "label_<class>". Split only on the first
        # underscore so class names that contain underscores stay intact.
        self.classes_list = [column.split("_", 1)[1] for column in mapping_frame.columns[1:]]

        self.mapping = mapping_frame.to_numpy()

        # Every column except the file name is a class indicator. Using the
        # array shape also works when the mapping has no rows.
        self.classes_count = self.mapping.shape[1] - 1

        # Kept for compatibility; __getitem__ does not apply it.
        self.transform = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        """Number of rows in the mapping."""
        return len(self.mapping)

    def _encode_label(self, idx):
        """Return the one-hot label of row ``idx`` as an integer NumPy array.

        Uses truthiness rather than identity with ``True`` so the encoding is
        correct whether the dummy columns hold Python bools, NumPy bools or
        0/1 integers.
        """
        return np.array([1 if bool(flag) else 0 for flag in self.mapping[idx, 1:]])

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        The image is whatever ``read_image`` returns for the file; the label is
        a one-hot integer array of length ``classes_count``.
        """
        img_name = self.mapping[idx, 0]
        img_path = os.path.join(self.BASE_PATH + self.FOLDER, img_name)
        img = read_image(img_path)

        label = self._encode_label(idx)

        return img, label

    def translate_encoded_label(self, encoded_label):
        """Map a one-hot label back to its class name.

        Raises:
            ValueError: if ``encoded_label`` contains no entry equal to 1.
        """
        return self.classes_list[list(encoded_label).index(1)]

In [None]:
# Load the mapping in file order (no shuffling) and report how many samples it holds.
dataset = UNSW_NB15(shuffle=False)
print(len(dataset))

In [None]:
# Two-stage random split: first hold out the test part, then carve the
# validation part out of the remainder (about 80% / 10% / 10% overall).
train_split = int(0.9 * len(dataset))  # samples kept for train + validation
val_split = int(0.8 * len(dataset))    # samples kept for the final training subset
train_val, test = random_split(dataset, [train_split, len(dataset) - train_split])
train, val = random_split(train_val, [val_split, len(train_val) - val_split])


# Only the training loader is shuffled. Evaluation metrics do not depend on
# batch order, and a fixed order keeps validation/test batches reproducible.
train_dataloader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(test, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Sizes of the train, validation and test subsets, one per line.
print(len(train), len(val), len(test), sep="\n")

In [None]:
# Pull a single batch from the training loader to check tensor shapes.
train_batch_iterator = iter(train_dataloader)
train_features, train_labels = next(train_batch_iterator)
print(f"Feature batch shape: {train_features.shape}")
print(f"Labels batch shape: {train_labels.shape}")

In [None]:
# Pull a single batch from the validation loader to check tensor shapes.
val_batch_iterator = iter(val_dataloader)
val_features, val_labels = next(val_batch_iterator)
print(f"Feature batch shape: {val_features.shape}")
print(f"Labels batch shape: {val_labels.shape}")

In [None]:
# Pull a single batch from the test loader to check tensor shapes.
test_batch_iterator = iter(test_dataloader)
test_features, test_labels = next(test_batch_iterator)
print(f"Feature batch shape: {test_features.shape}")
print(f"Labels batch shape: {test_labels.shape}")

In [None]:
def precision_recall_f1(predictions, labels):
    """Macro-averaged precision, recall and F1 for one-hot encoded targets.

    Args:
        predictions: iterable of predicted class indices.
        labels: iterable of one-hot rows; the position of the first entry equal
            to 1.0 is taken as the true class index.

    Returns:
        Tuple ``(precision, recall, f1)`` from
        ``precision_recall_fscore_support`` with ``average="macro"``.

    Raises:
        ValueError: if a label row contains no entry equal to 1.0.
    """
    # Pair up front so both lists stop at the shorter of the two inputs.
    paired = list(zip(predictions, labels))
    y_pred = [predicted for predicted, _ in paired]
    y_true = [list(one_hot).index(1.0) for _, one_hot in paired]

    precision, recall, f1_score, _support = precision_recall_fscore_support(y_true, y_pred, average="macro")
    return precision, recall, f1_score

# Worked example: five samples over three classes, four of them predicted correctly.
predictions = torch.tensor([0, 1, 0, 0, 2], dtype=torch.float32)
labels = torch.tensor(
    [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 0, 1]],
    dtype=torch.float32,
)
p, r, f1 = precision_recall_f1(predictions, labels)
print(f"Precision: {p}")
print(f"Recall: {r}")
print(f"F1 score: {f1}")

In [None]:
from typing import Any, Callable, Concatenate, Optional, ParamSpec, TypeVar, Dict
from functools import wraps

CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))

P = ParamSpec("P")
A = TypeVar("A")
B = TypeVar("B")

def wandb_init(config_: Dict[str, str]):
    """Decorator factory that runs the wrapped function inside a wandb run.

    Args:
        config_: run configuration passed to ``wandb.init``. Its
            ``'project_name'`` entry is used as the project; when ``config_``
            is ``None`` the project is left as ``None``.

    Returns:
        A decorator. The decorated function starts a run, calls the original
        function, always closes the run and returns the original result.
    """
    def wandb_init_(func: Callable[Concatenate[A, P], B]):
        @wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> Optional[Any]:
            wandb.init(project=config_['project_name'] if config_ is not None else None, config=config_, dir=CURRENT_DIR)
            try:
                result = func(*args, **kwargs)
            except BaseException:
                # Close the run even on failure or interruption so it is not
                # left open, and record a non-zero exit code for it.
                wandb.finish(exit_code=1)
                raise
            wandb.finish()

            return result

        return wrapper

    return wandb_init_


In [None]:
# Loss and optimizer for the current `model`, followed by the wandb run that
# the training and artifact cells log to.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    betas=ADAM_BETAS,
    weight_decay=ADAM_WEIGHT_DECAY,
)

run = wandb.init(
    project="DP",
    config=dict(
        learning_rate=LEARNING_RATE,
        architecture="ViT",
        dataset="UNSW-NB15-payload",
        epochs=EPOCHS,
    ),
)

def fit():
    """Train the module-level ``model`` for ``EPOCHS`` epochs and log metrics.

    Uses the module-level ``model``, ``criterion``, ``optimizer``,
    ``train_dataloader``, ``val_dataloader`` and ``device``. After each epoch it
    prints and sends to ``wandb.log`` the train/validation loss and accuracy,
    plus macro precision, recall and F1. Those three are computed on the
    *training* predictions, not the validation ones.

    Raises:
        ValueError: if a dataloader yields no batches, since no mean loss can
            be computed for that epoch.
    """

    def run_epoch(dataloader, training):
        """Run one pass over ``dataloader``.

        When ``training`` is true the model is in train mode and the optimizer
        steps after every batch. Otherwise the model is in eval mode and
        gradients are disabled.

        Returns:
            ``(mean batch loss, list of label rows, list of predicted indices)``.
        """
        model.train(training)
        epoch_labels = []
        epoch_preds = []
        running_loss = 0
        num_batches = 0
        with torch.set_grad_enabled(training):
            for img, label in tqdm(dataloader, position=0, leave=True):
                img = img.float().to(device)
                # One-hot rows are passed to the loss as float targets.
                label = label.float().to(device)
                y_pred = model(img)
                y_pred_label = torch.argmax(y_pred, dim=1)

                epoch_labels.extend(label.cpu().detach())
                epoch_preds.extend(y_pred_label.cpu().detach())

                loss = criterion(y_pred, label)

                if training:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                running_loss += loss.item()
                num_batches += 1

        if num_batches == 0:
            # Previously this surfaced as an UnboundLocalError on the loop index.
            raise ValueError("dataloader yielded no batches; cannot compute a mean loss")
        return running_loss / num_batches, epoch_labels, epoch_preds

    def accuracy(preds, labels):
        """Fraction of predictions equal to the index of the 1.0 entry in their label row."""
        hits = sum(1 for x, y in zip(preds, labels) if x == list(y).index(1.0))
        return hits / len(labels)

    for epoch in tqdm(range(EPOCHS), position=0, leave=True):
        train_loss, train_labels, train_preds = run_epoch(train_dataloader, training=True)
        val_loss, val_labels, val_preds = run_epoch(val_dataloader, training=False)

        print("-"*30)
        print(f"Train Loss EPOCH {epoch+1}: {train_loss:.4f}")
        print(f"Valid Loss EPOCH {epoch+1}: {val_loss:.4f}")
        train_accuracy = accuracy(train_preds, train_labels)
        print(f"Train Accuracy EPOCH {epoch+1}: {train_accuracy:.4f}")
        val_accuracy = accuracy(val_preds, val_labels)
        print(f"Valid Accuracy EPOCH {epoch+1}: {val_accuracy:.4f}")
        # Macro metrics on the training predictions of this epoch.
        precision, recall, f1score = precision_recall_f1(train_preds, train_labels)
        print(f"Precision: {precision}, Recall: {recall}, F1 score: {f1score}")
        print("-"*30)
        wandb.log(
            {
                "epoch": epoch,  # zero-based, unlike the one-based printed epoch
                "train_acc": train_accuracy,
                "train_loss": train_loss,
                "val_acc": val_accuracy,
                "val_loss": val_loss,
                "precision": precision,
                "recall": recall,
                "f1 score": f1score
            }
        )

In [None]:
# Save as artifact for version control.
# NOTE(review): no call to fit() is visible before this cell. Confirm training
# has run, otherwise the weights saved here are the untrained initial ones.
MODEL_PATH = 'saved/model_test_2'
# torch.save does not create missing directories, so make sure the target exists.
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
torch.save(model.state_dict(), MODEL_PATH)
artifact = wandb.Artifact('model_test_2', type='model')
artifact.add_file(MODEL_PATH)
run.log_artifact(artifact)
run.finish()