In [18]:
import os
import random
import timeit
import wandb

import numpy as np
import pandas as pd

import torch
from torch import nn
from torch import optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.io import read_image
from tqdm import tqdm

In [19]:
RANDOM_SEED = 42
BATCH_SIZE = 1024
EPOCHS = 10
LEARNING_RATE = 1e-3
PATCH_SIZE = 8
HEIGHT = 32
WIDTH = 64
IN_CHANNELS = 3
NUM_HEADS = 8
DROPOUT = 0.1
ADAM_WEIGHT_DECAY = 0
ADAM_BETAS = (0.9, 0.999)
ACTIVATION="gelu"
NUM_ENCODERS = 4
EMBED_DIM = (PATCH_SIZE ** 2) * IN_CHANNELS # (8**2)*3=192
NUM_PATCHES = (HEIGHT // PATCH_SIZE) * (WIDTH // PATCH_SIZE) # 4*8=32

random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed_all(RANDOM_SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = "cuda" if torch.cuda.is_available() else "cpu"

In [20]:
device

'cuda'

In [21]:
class PatchEmbedding(nn.Module):
    def __init__(self, embed_dim, patch_size, num_patches, dropout, in_channels):
        super().__init__()
        self.patcher = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=embed_dim,
                kernel_size=patch_size,
                stride=patch_size,
            ),
            nn.Flatten(2))

        self.cls_token = nn.Parameter(torch.randn(size=(1, in_channels, embed_dim)), requires_grad=True)
        self.position_embeddings = nn.Parameter(torch.randn(size=(1, num_patches+in_channels, embed_dim)), requires_grad=True)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):        
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)

        x = self.patcher(x).permute(0, 2, 1)
        x = torch.cat([cls_token, x], dim=1)
        x = self.position_embeddings + x
        x = self.dropout(x)
        return x

In [22]:
model = PatchEmbedding(EMBED_DIM, PATCH_SIZE, NUM_PATCHES, DROPOUT, IN_CHANNELS).to(device)
x = torch.randn(BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH).to(device)
print(model(x).shape)

torch.Size([1024, 35, 192])


In [23]:
class ViT(nn.Module):
    def __init__(self, num_patches, num_classes, patch_size, embed_dim, num_encoders, num_heads, dropout, activation, in_channels):
        super().__init__()
        self.embeddings_block = PatchEmbedding(embed_dim, patch_size, num_patches, dropout, in_channels)
        
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dropout=dropout, activation=activation, batch_first=True, norm_first=True)
        self.encoder_blocks = nn.TransformerEncoder(encoder_layer, num_layers=num_encoders)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(normalized_shape=embed_dim),
            nn.Linear(in_features=embed_dim, out_features=num_classes)
        )

    def forward(self, x):
        x = self.embeddings_block(x)
        x = self.encoder_blocks(x)
        x = self.mlp_head(x[:, 0, :])  # Apply MLP on the CLS token only
        return x

In [24]:
model = ViT(NUM_PATCHES, 15, PATCH_SIZE, EMBED_DIM, NUM_ENCODERS, NUM_HEADS, DROPOUT, ACTIVATION, IN_CHANNELS).to(device)
x = torch.randn(BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH).to(device)
print(model(x).shape) # BATCH_SIZE X NUM_CLASSES



torch.Size([1024, 15])


In [25]:
model

ViT(
  (embeddings_block): PatchEmbedding(
    (patcher): Sequential(
      (0): Conv2d(3, 192, kernel_size=(8, 8), stride=(8, 8))
      (1): Flatten(start_dim=2, end_dim=-1)
    )
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder_blocks): TransformerEncoder(
    (layers): ModuleList(
      (0-3): 4 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=192, out_features=192, bias=True)
        )
        (linear1): Linear(in_features=192, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=192, bias=True)
        (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (mlp_head): Sequential(
    (0): LayerNorm((192,), eps

In [26]:
class CicIds2017(Dataset):
    BASE_PATH = "C:\VScode Projects\FIIT_MASTERS\DP\datasets\CIC-IDS-2017"
    MAPPING_FILE = "\cicids2017_img.csv"
    index: int
    batch_size: int
    classes_count: int
    classes_list: list
    
    def __init__(self, shuffle: bool = False):        
        self.mapping = pd.read_csv(self.BASE_PATH+self.MAPPING_FILE)
        self.mapping = pd.get_dummies(self.mapping, columns=['label'])
        
        if shuffle:
            self.mapping = self.mapping.sample(frac=1) # shuffle
            
        self.classes_list = [label.split("_")[1] for label in self.mapping.columns[1:]]
        
        self.mapping = self.mapping.to_numpy()
        
        self.classes_count = len(self.mapping[0]) - 1
        
        self.transform = transforms.Compose([transforms.ToTensor()]) 
        
    def __len__(self):
        return len(self.mapping)
    
    def __getitem__(self, idx):
        img_name = self.mapping[idx, 0]
        img_path = os.path.join(self.BASE_PATH + "\image", img_name)
        img = read_image(img_path)
        
        label = [1 if label_class is True else 0 for label_class in self.mapping[idx, 1:]]
        label = np.array(label)
        
        return img, label
    
    def translate_encoded_label(self, encoded_label):
        return self.classes_list[list(encoded_label).index(1)]

In [27]:
dataset = CicIds2017()
print(len(dataset))

763416


In [28]:
train_split = int(0.9 * len(dataset))
val_split = int(0.8 * len(dataset))
train, test = random_split(dataset, [train_split, len(dataset) - train_split])
train, val = random_split(train, [val_split, len(train) - val_split])


train_dataloader = DataLoader(train, batch_size=BATCH_SIZE)
val_dataloader = DataLoader(val, batch_size=BATCH_SIZE)
test_dataloader = DataLoader(test, batch_size=BATCH_SIZE)

In [29]:
print(len(train))
print(len(val))
print(len(test))

610732
76342
76342


In [30]:
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")

Feature batch shape: torch.Size([1024, 3, 32, 64])
Labels batch shape: torch.Size([1024, 15])


In [31]:
val_features, val_labels = next(iter(val_dataloader))
print(f"Feature batch shape: {val_features.size()}")
print(f"Labels batch shape: {val_labels.size()}")

Feature batch shape: torch.Size([1024, 3, 32, 64])
Labels batch shape: torch.Size([1024, 15])


In [32]:
test_features, test_labels = next(iter(test_dataloader))
print(f"Feature batch shape: {test_features.size()}")
print(f"Labels batch shape: {test_labels.size()}")

Feature batch shape: torch.Size([1024, 3, 32, 64])
Labels batch shape: torch.Size([1024, 15])


In [33]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), betas=ADAM_BETAS, lr=LEARNING_RATE, weight_decay=ADAM_WEIGHT_DECAY)

run = wandb.init(
    project = "DP",
    config={
        "learning_rate": LEARNING_RATE,
        "architecture": "ViT",
        "dataset": "CIC-IDS-2017-payload",
        "epochs": EPOCHS,
    }
)

start = timeit.default_timer()
for epoch in tqdm(range(EPOCHS), position=0, leave=True):
    model.train()
    train_labels = []
    train_preds = []
    train_running_loss = 0
    for idx, (img, label) in enumerate(tqdm(train_dataloader, position=0, leave=True)):
        img = img.float().to(device)
        label = label.float().to(device)
        y_pred = model(img)
        y_pred_label = torch.argmax(y_pred, dim=1)

        train_labels.extend(label.cpu().detach())
        train_preds.extend(y_pred_label.cpu().detach())
        
        loss = criterion(y_pred, label)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_running_loss += loss.item()
    train_loss = train_running_loss / (idx + 1)

    model.eval()
    val_labels = []
    val_preds = []
    val_running_loss = 0
    with torch.no_grad():
        for idx, (img, label) in enumerate(tqdm(val_dataloader, position=0, leave=True)):
            img = img.float().to(device)
            label = label.float().to(device)         
            y_pred = model(img)
            y_pred_label = torch.argmax(y_pred, dim=1)
            
            val_labels.extend(label.cpu().detach())
            val_preds.extend(y_pred_label.cpu().detach())
            
            loss = criterion(y_pred, label)
            val_running_loss += loss.item()
    val_loss = val_running_loss / (idx + 1)

    print("-"*30)
    print(f"Train Loss EPOCH {epoch+1}: {train_loss:.4f}")
    print(f"Valid Loss EPOCH {epoch+1}: {val_loss:.4f}")
    train_accuracy = sum(1 for x,y in zip(train_preds, train_labels) if x == list(y).index(1.0)) / len(train_labels)
    print(f"Train Accuracy EPOCH {epoch+1}: {train_accuracy:.4f}")
    val_accuracy = sum(1 for x,y in zip(val_preds, val_labels) if x == list(y).index(1.0)) / len(val_labels)
    print(f"Valid Accuracy EPOCH {epoch+1}: {val_accuracy:.4f}")
    print("-"*30)
    wandb.log({"epoch": epoch,"train_acc": train_accuracy, "train_loss": train_loss, "val_acc": val_accuracy, "val_loss": val_loss})


stop = timeit.default_timer()
print(f"Training Time: {stop-start:.2f}s")

100%|██████████| 597/597 [03:32<00:00,  2.81it/s]
100%|██████████| 75/75 [00:16<00:00,  4.48it/s]


------------------------------
Train Loss EPOCH 1: 0.7114
Valid Loss EPOCH 1: 0.6474
Train Accuracy EPOCH 1: 0.6840


 10%|█         | 1/10 [04:12<37:54, 252.68s/it]

Valid Accuracy EPOCH 1: 0.6995
------------------------------


100%|██████████| 597/597 [03:36<00:00,  2.76it/s]
100%|██████████| 75/75 [00:17<00:00,  4.34it/s]


------------------------------
Train Loss EPOCH 2: 0.6414
Valid Loss EPOCH 2: 0.6475
Train Accuracy EPOCH 2: 0.7035


 20%|██        | 2/10 [08:30<34:06, 255.82s/it]

Valid Accuracy EPOCH 2: 0.6979
------------------------------


100%|██████████| 597/597 [03:35<00:00,  2.77it/s]
100%|██████████| 75/75 [00:16<00:00,  4.46it/s]


------------------------------
Train Loss EPOCH 3: 0.6364
Valid Loss EPOCH 3: 0.6403
Train Accuracy EPOCH 3: 0.7056


 30%|███       | 3/10 [12:46<29:49, 255.67s/it]

Valid Accuracy EPOCH 3: 0.7034
------------------------------


100%|██████████| 597/597 [03:38<00:00,  2.74it/s]
100%|██████████| 75/75 [00:17<00:00,  4.36it/s]


------------------------------
Train Loss EPOCH 4: 0.6374
Valid Loss EPOCH 4: 0.6387
Train Accuracy EPOCH 4: 0.7059


 40%|████      | 4/10 [17:05<25:42, 257.17s/it]

Valid Accuracy EPOCH 4: 0.7036
------------------------------


100%|██████████| 597/597 [03:44<00:00,  2.66it/s]
100%|██████████| 75/75 [00:17<00:00,  4.39it/s]


------------------------------
Train Loss EPOCH 5: 0.6326
Valid Loss EPOCH 5: 0.6384
Train Accuracy EPOCH 5: 0.7076


 50%|█████     | 5/10 [21:30<21:38, 259.78s/it]

Valid Accuracy EPOCH 5: 0.7038
------------------------------


100%|██████████| 597/597 [03:46<00:00,  2.64it/s]
100%|██████████| 75/75 [00:18<00:00,  4.04it/s]


------------------------------
Train Loss EPOCH 6: 0.6319
Valid Loss EPOCH 6: 0.6398
Train Accuracy EPOCH 6: 0.7074


 60%|██████    | 6/10 [26:01<17:35, 263.79s/it]

Valid Accuracy EPOCH 6: 0.7036
------------------------------


100%|██████████| 597/597 [03:51<00:00,  2.58it/s]
100%|██████████| 75/75 [00:19<00:00,  3.81it/s]


------------------------------
Train Loss EPOCH 7: 0.6396
Valid Loss EPOCH 7: 0.6396
Train Accuracy EPOCH 7: 0.7055


 70%|███████   | 7/10 [30:36<13:22, 267.51s/it]

Valid Accuracy EPOCH 7: 0.7035
------------------------------


100%|██████████| 597/597 [03:38<00:00,  2.74it/s]
100%|██████████| 75/75 [00:17<00:00,  4.36it/s]


------------------------------
Train Loss EPOCH 8: 0.6324
Valid Loss EPOCH 8: 0.6371
Train Accuracy EPOCH 8: 0.7073


 80%|████████  | 8/10 [34:55<08:49, 264.66s/it]

Valid Accuracy EPOCH 8: 0.7046
------------------------------


100%|██████████| 597/597 [03:38<00:00,  2.73it/s]
100%|██████████| 75/75 [00:17<00:00,  4.38it/s]


------------------------------
Train Loss EPOCH 9: 0.6306
Valid Loss EPOCH 9: 0.6365
Train Accuracy EPOCH 9: 0.7080


 90%|█████████ | 9/10 [39:13<04:22, 262.77s/it]

Valid Accuracy EPOCH 9: 0.7044
------------------------------


100%|██████████| 597/597 [03:38<00:00,  2.73it/s]
100%|██████████| 75/75 [00:16<00:00,  4.42it/s]


------------------------------
Train Loss EPOCH 10: 0.6305
Valid Loss EPOCH 10: 0.6362
Train Accuracy EPOCH 10: 0.7079


100%|██████████| 10/10 [43:32<00:00, 261.26s/it]

Valid Accuracy EPOCH 10: 0.7043
------------------------------
Training Time: 2612.60s





In [34]:
# Save as artifact for version control.
torch.save(model.state_dict(), 'saved/model_test')
artifact = wandb.Artifact('model_test', type='model')
artifact.add_file('saved/model_test')
run.log_artifact(artifact)
run.finish()

0,1
epoch,▁▂▃▃▄▅▆▆▇█
train_acc,▁▇▇▇██▇███
train_loss,█▂▂▂▁▁▂▁▁▁
val_acc,▃▁▇▇▇▇▇███
val_loss,██▄▃▂▃▃▂▁▁

0,1
epoch,9.0
train_acc,0.70795
train_loss,0.63053
val_acc,0.70425
val_loss,0.63624
