In [1]:
import os
import random
import timeit
import wandb

import numpy as np
import pandas as pd
from sklearn.metrics import precision_recall_fscore_support
from tqdm import tqdm

import torch
from torch import nn
from torch import optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.io import read_image

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


In [2]:
RANDOM_SEED = 42
BATCH_SIZE = 512
EPOCHS = 30
LEARNING_RATE = 1e-3
PATCH_SIZE = 2
HEIGHT = 8
WIDTH = 8
IN_CHANNELS = 3
NUM_HEADS = 12
DROPOUT = 0.1
ADAM_WEIGHT_DECAY = 0
ADAM_BETAS = (0.9, 0.999)
ACTIVATION="gelu"
NUM_ENCODERS = 12
EMBED_DIM = (PATCH_SIZE ** 2) * IN_CHANNELS
NUM_PATCHES = (HEIGHT // PATCH_SIZE) * (WIDTH // PATCH_SIZE)
NUM_CLASSES = 10

random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed_all(RANDOM_SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
device

'cuda'

In [4]:
class PatchEmbedding(nn.Module):
    def __init__(self, embed_dim, patch_size, num_patches, dropout, in_channels):
        super().__init__()
        self.patcher = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=embed_dim,
                kernel_size=patch_size,
                stride=patch_size,
            ),
            nn.Flatten(2))

        self.cls_token = nn.Parameter(torch.randn(size=(1, in_channels, embed_dim)), requires_grad=True)
        self.position_embeddings = nn.Parameter(torch.randn(size=(1, num_patches+in_channels, embed_dim)), requires_grad=True)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):        
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)

        x = self.patcher(x).permute(0, 2, 1)
        x = torch.cat([cls_token, x], dim=1)
        x = self.position_embeddings + x
        x = self.dropout(x)
        return x

In [5]:
model = PatchEmbedding(EMBED_DIM, PATCH_SIZE, NUM_PATCHES, DROPOUT, IN_CHANNELS).to(device)
x = torch.randn(BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH).to(device)
print(model(x).shape)

torch.Size([512, 19, 12])


In [6]:
class ViT(nn.Module):
    def __init__(self, num_patches, num_classes, patch_size, embed_dim, num_encoders, num_heads, dropout, activation, in_channels):
        super().__init__()
        self.embeddings_block = PatchEmbedding(embed_dim, patch_size, num_patches, dropout, in_channels)
        
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dropout=dropout, activation=activation, batch_first=True, norm_first=True)
        self.encoder_blocks = nn.TransformerEncoder(encoder_layer, num_layers=num_encoders)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(normalized_shape=embed_dim),
            nn.Linear(in_features=embed_dim, out_features=num_classes)
        )

    def forward(self, x):
        x = self.embeddings_block(x)
        x = self.encoder_blocks(x)
        x = self.mlp_head(x[:, 0, :])  # Apply MLP on the CLS token only
        return x

In [7]:
model = ViT(NUM_PATCHES, NUM_CLASSES, PATCH_SIZE, EMBED_DIM, NUM_ENCODERS, NUM_HEADS, DROPOUT, ACTIVATION, IN_CHANNELS).to(device)
x = torch.randn(BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH).to(device)
print(model(x).shape) # BATCH_SIZE X NUM_CLASSES

torch.Size([512, 10])




In [8]:
model

ViT(
  (embeddings_block): PatchEmbedding(
    (patcher): Sequential(
      (0): Conv2d(3, 12, kernel_size=(2, 2), stride=(2, 2))
      (1): Flatten(start_dim=2, end_dim=-1)
    )
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder_blocks): TransformerEncoder(
    (layers): ModuleList(
      (0-11): 12 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=12, out_features=12, bias=True)
        )
        (linear1): Linear(in_features=12, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=12, bias=True)
        (norm1): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (mlp_head): Sequential(
    (0): LayerNorm((12,), eps=1e-05

In [9]:
class UNSW_NB15(Dataset):
    BASE_PATH = "C:\VScode_Projects\DP\datasets\\UNSW_NB15"
    MAPPING_FILE = "\\unswnb15_img_flow.csv"
    index: int
    batch_size: int
    classes_count: int
    classes_list: list
    
    def __init__(self, shuffle: bool = False):        
        self.mapping = pd.read_csv(self.BASE_PATH+self.MAPPING_FILE)
        self.mapping = pd.get_dummies(self.mapping, columns=['label'])
        
        if shuffle:
            self.mapping = self.mapping.sample(frac=1) # shuffle
            
        self.classes_list = [label.split("_")[1] for label in self.mapping.columns[1:]]
        
        self.mapping = self.mapping.to_numpy()
        
        self.classes_count = len(self.mapping[0]) - 1
        
        self.transform = transforms.Compose([transforms.ToTensor()]) 
        
    def __len__(self):
        return len(self.mapping)
    
    def __getitem__(self, idx):
        img_name = self.mapping[idx, 0]
        img_path = os.path.join(self.BASE_PATH + "\image_flow", img_name)
        img = read_image(img_path)
        
        label = [1 if label_class is True else 0 for label_class in self.mapping[idx, 1:]]
        label = np.array(label)
        
        return img, label
    
    def translate_encoded_label(self, encoded_label):
        return self.classes_list[list(encoded_label).index(1)]

In [10]:
dataset = UNSW_NB15()
print(len(dataset))

162745


In [11]:
train_split = int(0.9 * len(dataset))
val_split = int(0.8 * len(dataset))
train, test = random_split(dataset, [train_split, len(dataset) - train_split])
train, val = random_split(train, [val_split, len(train) - val_split])


train_dataloader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test, batch_size=BATCH_SIZE, shuffle=True)

In [12]:
print(len(train))
print(len(val))
print(len(test))

130196
16274
16275


In [13]:
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")

Feature batch shape: torch.Size([512, 3, 8, 8])
Labels batch shape: torch.Size([512, 10])


In [14]:
val_features, val_labels = next(iter(val_dataloader))
print(f"Feature batch shape: {val_features.size()}")
print(f"Labels batch shape: {val_labels.size()}")

Feature batch shape: torch.Size([512, 3, 8, 8])
Labels batch shape: torch.Size([512, 10])


In [15]:
test_features, test_labels = next(iter(test_dataloader))
print(f"Feature batch shape: {test_features.size()}")
print(f"Labels batch shape: {test_labels.size()}")

Feature batch shape: torch.Size([512, 3, 8, 8])
Labels batch shape: torch.Size([512, 10])


In [16]:
def precision_recall_f1(predictions, labels):
    y_true = []
    y_pred = []
    for x,y in zip(predictions, labels):
        y_pred.append(x)
        y_true.append(list(y).index(1.0))
        
    p, r, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="macro")
    return p, r, f1

predictions = torch.Tensor(np.array([0, 1, 0, 0, 2]))
labels = torch.Tensor(np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 0, 1]]))
p, r, f1 = precision_recall_f1(predictions, labels)
print(f"Precision: {p}")
print(f"Recall: {r}")
print(f"F1 score: {f1}")

Precision: 0.8888888888888888
Recall: 0.8333333333333334
F1 score: 0.8222222222222223


In [17]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), betas=ADAM_BETAS, lr=LEARNING_RATE, weight_decay=ADAM_WEIGHT_DECAY)

run = wandb.init(
    project = "DP",
    config={
        "learning_rate": LEARNING_RATE,
        "architecture": "ViT",
        "dataset": "UNSW-NB15-flow",
        "epochs": EPOCHS,
    }
)

start = timeit.default_timer()
for epoch in tqdm(range(EPOCHS), position=0, leave=True):
    model.train()
    train_labels = []
    train_preds = []
    train_running_loss = 0
    for idx, (img, label) in enumerate(tqdm(train_dataloader, position=0, leave=True)):
        img = img.float().to(device)
        label = label.float().to(device)
        y_pred = model(img)
        y_pred_label = torch.argmax(y_pred, dim=1)

        train_labels.extend(label.cpu().detach())
        train_preds.extend(y_pred_label.cpu().detach())
        
        loss = criterion(y_pred, label)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_running_loss += loss.item()
    train_loss = train_running_loss / (idx + 1)

    model.eval()
    val_labels = []
    val_preds = []
    val_running_loss = 0
    with torch.no_grad():
        for idx, (img, label) in enumerate(tqdm(val_dataloader, position=0, leave=True)):
            img = img.float().to(device)
            label = label.float().to(device)         
            y_pred = model(img)
            y_pred_label = torch.argmax(y_pred, dim=1)
            
            val_labels.extend(label.cpu().detach())
            val_preds.extend(y_pred_label.cpu().detach())
            
            loss = criterion(y_pred, label)
            val_running_loss += loss.item()
    val_loss = val_running_loss / (idx + 1)

    print("-"*30)
    print(f"Train Loss EPOCH {epoch+1}: {train_loss:.4f}")
    print(f"Valid Loss EPOCH {epoch+1}: {val_loss:.4f}")
    train_accuracy = sum(1 for x,y in zip(train_preds, train_labels) if x == list(y).index(1.0)) / len(train_labels)
    print(f"Train Accuracy EPOCH {epoch+1}: {train_accuracy:.4f}")
    val_accuracy = sum(1 for x,y in zip(val_preds, val_labels) if x == list(y).index(1.0)) / len(val_labels)
    print(f"Valid Accuracy EPOCH {epoch+1}: {val_accuracy:.4f}")
    precision, recall, f1score = precision_recall_f1(train_preds, train_labels)
    print(f"Precision: {precision}, Recall: {recall}, F1 score: {f1score}")
    print("-"*30)
    wandb.log(
        {
            "epoch": epoch,
            "train_acc": train_accuracy,
            "train_loss": train_loss,
            "val_acc": val_accuracy,
            "val_loss": val_loss,
            "precision": precision,
            "recall": recall,
            "f1 score": f1score
        }
    )


stop = timeit.default_timer()
print(f"Training Time: {stop-start:.2f}s")

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mvikioza[0m. Use [1m`wandb login --relogin`[0m to force relogin


100%|██████████| 255/255 [00:39<00:00,  6.45it/s]
100%|██████████| 32/32 [00:02<00:00, 11.66it/s]


------------------------------
Train Loss EPOCH 1: 1.3536
Valid Loss EPOCH 1: 0.8448
Train Accuracy EPOCH 1: 0.5837
Valid Accuracy EPOCH 1: 0.7140


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  3%|▎         | 1/30 [00:53<25:40, 53.13s/it]

Precision: 0.26549606714372304, Recall: 0.17115378245138274, F1 score: 0.1780462269747568
------------------------------


100%|██████████| 255/255 [00:23<00:00, 11.00it/s]
100%|██████████| 32/32 [00:01<00:00, 16.47it/s]


------------------------------
Train Loss EPOCH 2: 0.8146
Valid Loss EPOCH 2: 0.6916
Train Accuracy EPOCH 2: 0.7032
Valid Accuracy EPOCH 2: 0.7366


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  7%|▋         | 2/30 [01:30<20:21, 43.63s/it]

Precision: 0.3610168424030138, Recall: 0.33063505351050626, F1 score: 0.3288593072307982
------------------------------


100%|██████████| 255/255 [00:23<00:00, 10.85it/s]
100%|██████████| 32/32 [00:02<00:00, 13.83it/s]


------------------------------
Train Loss EPOCH 3: 0.7291
Valid Loss EPOCH 3: 0.6530
Train Accuracy EPOCH 3: 0.7213
Valid Accuracy EPOCH 3: 0.7504


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
 10%|█         | 3/30 [02:07<18:22, 40.84s/it]

Precision: 0.37902916126027114, Recall: 0.3570156421137215, F1 score: 0.35676750443274274
------------------------------


100%|██████████| 255/255 [00:24<00:00, 10.51it/s]
100%|██████████| 32/32 [00:02<00:00, 15.59it/s]


------------------------------
Train Loss EPOCH 4: 0.6919
Valid Loss EPOCH 4: 0.6291
Train Accuracy EPOCH 4: 0.7322
Valid Accuracy EPOCH 4: 0.7535


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
 13%|█▎        | 4/30 [02:45<17:13, 39.76s/it]

Precision: 0.40544321446151066, Recall: 0.36240932190830505, F1 score: 0.3629911195134409
------------------------------


100%|██████████| 255/255 [00:23<00:00, 11.08it/s]
100%|██████████| 32/32 [00:02<00:00, 15.18it/s]


------------------------------
Train Loss EPOCH 5: 0.6710
Valid Loss EPOCH 5: 0.6127
Train Accuracy EPOCH 5: 0.7387
Valid Accuracy EPOCH 5: 0.7596


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
 17%|█▋        | 5/30 [03:21<15:59, 38.38s/it]

Precision: 0.41533370865411623, Recall: 0.36998823208522447, F1 score: 0.3735032798047636
------------------------------


100%|██████████| 255/255 [00:21<00:00, 11.60it/s]
100%|██████████| 32/32 [00:01<00:00, 17.19it/s]


------------------------------
Train Loss EPOCH 6: 0.6512
Valid Loss EPOCH 6: 0.6013
Train Accuracy EPOCH 6: 0.7447
Valid Accuracy EPOCH 6: 0.7629


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
 20%|██        | 6/30 [03:56<14:51, 37.16s/it]

Precision: 0.4235544665690121, Recall: 0.37721449731558837, F1 score: 0.3839342069696967
------------------------------


100%|██████████| 255/255 [00:22<00:00, 11.56it/s]
100%|██████████| 32/32 [00:02<00:00, 15.86it/s]


------------------------------
Train Loss EPOCH 7: 0.6381
Valid Loss EPOCH 7: 0.6128
Train Accuracy EPOCH 7: 0.7486
Valid Accuracy EPOCH 7: 0.7569


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
 23%|██▎       | 7/30 [04:31<13:59, 36.50s/it]

Precision: 0.42315638680491174, Recall: 0.3759657458998917, F1 score: 0.384781796206315
------------------------------


100%|██████████| 255/255 [00:22<00:00, 11.43it/s]
100%|██████████| 32/32 [00:01<00:00, 17.38it/s]


------------------------------
Train Loss EPOCH 8: 0.6264
Valid Loss EPOCH 8: 0.5930
Train Accuracy EPOCH 8: 0.7527
Valid Accuracy EPOCH 8: 0.7610


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
 27%|██▋       | 8/30 [05:06<13:11, 36.00s/it]

Precision: 0.43388566908451576, Recall: 0.3877786651037438, F1 score: 0.3977719391957807
------------------------------


100%|██████████| 255/255 [00:22<00:00, 11.48it/s]
100%|██████████| 32/32 [00:02<00:00, 15.97it/s]


------------------------------
Train Loss EPOCH 9: 0.6198
Valid Loss EPOCH 9: 0.5936
Train Accuracy EPOCH 9: 0.7543
Valid Accuracy EPOCH 9: 0.7640


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
 30%|███       | 9/30 [05:41<12:29, 35.70s/it]

Precision: 0.4393736696031333, Recall: 0.38923762653892263, F1 score: 0.401339012146523
------------------------------


100%|██████████| 255/255 [00:22<00:00, 11.47it/s]
100%|██████████| 32/32 [00:01<00:00, 16.91it/s]


------------------------------
Train Loss EPOCH 10: 0.6103
Valid Loss EPOCH 10: 0.5769
Train Accuracy EPOCH 10: 0.7571
Valid Accuracy EPOCH 10: 0.7694


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
 33%|███▎      | 10/30 [06:16<11:49, 35.47s/it]

Precision: 0.44451714148421206, Recall: 0.40050789630154426, F1 score: 0.41351795568929345
------------------------------


100%|██████████| 255/255 [00:23<00:00, 10.89it/s]
100%|██████████| 32/32 [00:02<00:00, 15.60it/s]


------------------------------
Train Loss EPOCH 11: 0.6068
Valid Loss EPOCH 11: 0.5681
Train Accuracy EPOCH 11: 0.7576
Valid Accuracy EPOCH 11: 0.7720


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
 37%|███▋      | 11/30 [06:52<11:19, 35.76s/it]

Precision: 0.4404499545132022, Recall: 0.3986344093861282, F1 score: 0.41045611343831334
------------------------------


100%|██████████| 255/255 [00:22<00:00, 11.22it/s]
100%|██████████| 32/32 [00:01<00:00, 17.15it/s]


------------------------------
Train Loss EPOCH 12: 0.6002
Valid Loss EPOCH 12: 0.5638
Train Accuracy EPOCH 12: 0.7600
Valid Accuracy EPOCH 12: 0.7741


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
 40%|████      | 12/30 [07:28<10:41, 35.66s/it]

Precision: 0.4495500483113588, Recall: 0.4009777319940132, F1 score: 0.4156222149895517
------------------------------


100%|██████████| 255/255 [00:22<00:00, 11.50it/s]
100%|██████████| 32/32 [00:01<00:00, 17.31it/s]


------------------------------
Train Loss EPOCH 13: 0.5949
Valid Loss EPOCH 13: 0.5692
Train Accuracy EPOCH 13: 0.7607
Valid Accuracy EPOCH 13: 0.7718


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
 43%|████▎     | 13/30 [08:03<10:03, 35.49s/it]

Precision: 0.44810225963843014, Recall: 0.40407645211049326, F1 score: 0.415667959069187
------------------------------


100%|██████████| 255/255 [00:22<00:00, 11.33it/s]
100%|██████████| 32/32 [00:01<00:00, 16.95it/s]


------------------------------
Train Loss EPOCH 14: 0.5896
Valid Loss EPOCH 14: 0.5496
Train Accuracy EPOCH 14: 0.7617
Valid Accuracy EPOCH 14: 0.7777


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
 47%|████▋     | 14/30 [08:38<09:27, 35.47s/it]

Precision: 0.45218237813382833, Recall: 0.4096720252467153, F1 score: 0.4209915952307458
------------------------------


100%|██████████| 255/255 [00:22<00:00, 11.41it/s]
100%|██████████| 32/32 [00:01<00:00, 17.36it/s]


------------------------------
Train Loss EPOCH 15: 0.5825
Valid Loss EPOCH 15: 0.5471
Train Accuracy EPOCH 15: 0.7645
Valid Accuracy EPOCH 15: 0.7757


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
 50%|█████     | 15/30 [09:13<08:50, 35.34s/it]

Precision: 0.45484759146265263, Recall: 0.41404069866162924, F1 score: 0.4248445102879582
------------------------------


100%|██████████| 255/255 [00:22<00:00, 11.50it/s]
100%|██████████| 32/32 [00:01<00:00, 17.50it/s]


------------------------------
Train Loss EPOCH 16: 0.5769
Valid Loss EPOCH 16: 0.5482
Train Accuracy EPOCH 16: 0.7659
Valid Accuracy EPOCH 16: 0.7769


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
 53%|█████▎    | 16/30 [09:49<08:14, 35.31s/it]

Precision: 0.4594319138185085, Recall: 0.41630726095225706, F1 score: 0.4288114596613067
------------------------------


100%|██████████| 255/255 [00:22<00:00, 11.56it/s]
100%|██████████| 32/32 [00:02<00:00, 15.30it/s]


------------------------------
Train Loss EPOCH 17: 0.5747
Valid Loss EPOCH 17: 0.5356
Train Accuracy EPOCH 17: 0.7669
Valid Accuracy EPOCH 17: 0.7784


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
 57%|█████▋    | 17/30 [10:24<07:37, 35.22s/it]

Precision: 0.4591668477737648, Recall: 0.41496904449069405, F1 score: 0.428242218502927
------------------------------


100%|██████████| 255/255 [00:22<00:00, 11.50it/s]
100%|██████████| 32/32 [00:01<00:00, 17.19it/s]


------------------------------
Train Loss EPOCH 18: 0.5693
Valid Loss EPOCH 18: 0.5294
Train Accuracy EPOCH 18: 0.7688
Valid Accuracy EPOCH 18: 0.7835


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
 60%|██████    | 18/30 [10:59<07:01, 35.11s/it]

Precision: 0.4641959864685036, Recall: 0.4225584638305232, F1 score: 0.43632361091167926
------------------------------


100%|██████████| 255/255 [00:22<00:00, 11.36it/s]
100%|██████████| 32/32 [00:01<00:00, 16.22it/s]


------------------------------
Train Loss EPOCH 19: 0.5647
Valid Loss EPOCH 19: 0.5379
Train Accuracy EPOCH 19: 0.7702
Valid Accuracy EPOCH 19: 0.7795


 63%|██████▎   | 19/30 [11:34<06:26, 35.16s/it]

Precision: 0.4704291570341268, Recall: 0.424792451210375, F1 score: 0.4379185213500053
------------------------------


100%|██████████| 255/255 [00:22<00:00, 11.49it/s]
100%|██████████| 32/32 [00:01<00:00, 17.33it/s]


------------------------------
Train Loss EPOCH 20: 0.5636
Valid Loss EPOCH 20: 0.5270
Train Accuracy EPOCH 20: 0.7694
Valid Accuracy EPOCH 20: 0.7846


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
 67%|██████▋   | 20/30 [12:09<05:50, 35.06s/it]

Precision: 0.46639955735232935, Recall: 0.42407465152172935, F1 score: 0.438162197727483
------------------------------


100%|██████████| 255/255 [00:22<00:00, 11.48it/s]
100%|██████████| 32/32 [00:01<00:00, 16.91it/s]


------------------------------
Train Loss EPOCH 21: 0.5609
Valid Loss EPOCH 21: 0.5249
Train Accuracy EPOCH 21: 0.7709
Valid Accuracy EPOCH 21: 0.7824


 70%|███████   | 21/30 [12:44<05:16, 35.15s/it]

Precision: 0.5668549921920057, Recall: 0.4232225394233312, F1 score: 0.4369268107477765
------------------------------


100%|██████████| 255/255 [00:22<00:00, 11.43it/s]
100%|██████████| 32/32 [00:01<00:00, 17.18it/s]


------------------------------
Train Loss EPOCH 22: 0.5557
Valid Loss EPOCH 22: 0.5242
Train Accuracy EPOCH 22: 0.7722
Valid Accuracy EPOCH 22: 0.7809


 73%|███████▎  | 22/30 [13:19<04:40, 35.11s/it]

Precision: 0.47581798434907696, Recall: 0.4286562715244863, F1 score: 0.4437910005575648
------------------------------


100%|██████████| 255/255 [00:22<00:00, 11.50it/s]
100%|██████████| 32/32 [00:01<00:00, 17.12it/s]


------------------------------
Train Loss EPOCH 23: 0.5525
Valid Loss EPOCH 23: 0.5248
Train Accuracy EPOCH 23: 0.7738
Valid Accuracy EPOCH 23: 0.7850


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
 77%|███████▋  | 23/30 [13:54<04:05, 35.08s/it]

Precision: 0.47367248848371046, Recall: 0.42946449468035963, F1 score: 0.4443030695264585
------------------------------


100%|██████████| 255/255 [00:22<00:00, 11.52it/s]
100%|██████████| 32/32 [00:01<00:00, 16.94it/s]


------------------------------
Train Loss EPOCH 24: 0.5514
Valid Loss EPOCH 24: 0.5148
Train Accuracy EPOCH 24: 0.7739
Valid Accuracy EPOCH 24: 0.7866


 80%|████████  | 24/30 [14:29<03:30, 35.07s/it]

Precision: 0.4735129135654496, Recall: 0.4278631887989902, F1 score: 0.441813196213655
------------------------------


100%|██████████| 255/255 [00:22<00:00, 11.41it/s]
100%|██████████| 32/32 [00:01<00:00, 17.12it/s]


------------------------------
Train Loss EPOCH 25: 0.5487
Valid Loss EPOCH 25: 0.5100
Train Accuracy EPOCH 25: 0.7745
Valid Accuracy EPOCH 25: 0.7870


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
 83%|████████▎ | 25/30 [15:04<02:55, 35.08s/it]

Precision: 0.47648654208230407, Recall: 0.43304390629377554, F1 score: 0.44741402926283813
------------------------------


100%|██████████| 255/255 [00:22<00:00, 11.45it/s]
100%|██████████| 32/32 [00:01<00:00, 17.04it/s]


------------------------------
Train Loss EPOCH 26: 0.5450
Valid Loss EPOCH 26: 0.5084
Train Accuracy EPOCH 26: 0.7757
Valid Accuracy EPOCH 26: 0.7894


 87%|████████▋ | 26/30 [15:39<02:20, 35.09s/it]

Precision: 0.5271715272270254, Recall: 0.43281772372833666, F1 score: 0.4471503456420377
------------------------------


100%|██████████| 255/255 [00:22<00:00, 11.52it/s]
100%|██████████| 32/32 [00:01<00:00, 16.03it/s]


------------------------------
Train Loss EPOCH 27: 0.5445
Valid Loss EPOCH 27: 0.5124
Train Accuracy EPOCH 27: 0.7763
Valid Accuracy EPOCH 27: 0.7869


 90%|█████████ | 27/30 [16:14<01:45, 35.06s/it]

Precision: 0.5808299738104965, Recall: 0.44160828319234646, F1 score: 0.456158889477163
------------------------------


100%|██████████| 255/255 [00:22<00:00, 11.43it/s]
100%|██████████| 32/32 [00:01<00:00, 17.00it/s]


------------------------------
Train Loss EPOCH 28: 0.5436
Valid Loss EPOCH 28: 0.5115
Train Accuracy EPOCH 28: 0.7762
Valid Accuracy EPOCH 28: 0.7854


 93%|█████████▎| 28/30 [16:49<01:10, 35.08s/it]

Precision: 0.5801717625864583, Recall: 0.43725621930164105, F1 score: 0.45284197500361667
------------------------------


100%|██████████| 255/255 [00:22<00:00, 11.28it/s]
100%|██████████| 32/32 [00:01<00:00, 16.99it/s]


------------------------------
Train Loss EPOCH 29: 0.5397
Valid Loss EPOCH 29: 0.5083
Train Accuracy EPOCH 29: 0.7778
Valid Accuracy EPOCH 29: 0.7908


 97%|█████████▋| 29/30 [17:25<00:35, 35.15s/it]

Precision: 0.5273652216346045, Recall: 0.4357132237674962, F1 score: 0.45072873543744424
------------------------------


100%|██████████| 255/255 [00:22<00:00, 11.48it/s]
100%|██████████| 32/32 [00:01<00:00, 17.20it/s]


------------------------------
Train Loss EPOCH 30: 0.5392
Valid Loss EPOCH 30: 0.5116
Train Accuracy EPOCH 30: 0.7775
Valid Accuracy EPOCH 30: 0.7880


100%|██████████| 30/30 [18:00<00:00, 36.01s/it]

Precision: 0.5305376497334698, Recall: 0.44097556277423366, F1 score: 0.4562312193363166
------------------------------
Training Time: 1080.24s





In [18]:
# Save as artifact for version control.
torch.save(model.state_dict(), '../saved/model_unsw_flow')
artifact = wandb.Artifact('model_unsw_flow', type='model')
artifact.add_file('../saved/model_unsw_flow')
run.log_artifact(artifact)
run.finish()

0,1
epoch,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇███
f1 score,▁▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇████████████
precision,▁▃▄▄▄▅▄▅▅▅▅▅▅▅▅▅▅▅▆▅█▆▆▆▆▇██▇▇
recall,▁▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇█████████████
train_acc,▁▅▆▆▇▇▇▇▇▇▇▇▇▇████████████████
train_loss,█▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_acc,▁▃▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇███████
val_loss,█▅▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,29.0
f1 score,0.45623
precision,0.53054
recall,0.44098
train_acc,0.77749
train_loss,0.53924
val_acc,0.78801
val_loss,0.51163


In [20]:
saved_model = ViT(NUM_PATCHES, NUM_CLASSES, PATCH_SIZE, EMBED_DIM, NUM_ENCODERS, NUM_HEADS, DROPOUT, ACTIVATION, IN_CHANNELS)
saved_model.load_state_dict(torch.load('../saved/model_unsw_flow'))
saved_model.to(device)
saved_model.eval()






ViT(
  (embeddings_block): PatchEmbedding(
    (patcher): Sequential(
      (0): Conv2d(3, 12, kernel_size=(2, 2), stride=(2, 2))
      (1): Flatten(start_dim=2, end_dim=-1)
    )
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder_blocks): TransformerEncoder(
    (layers): ModuleList(
      (0-11): 12 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=12, out_features=12, bias=True)
        )
        (linear1): Linear(in_features=12, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=12, bias=True)
        (norm1): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((12,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (mlp_head): Sequential(
    (0): LayerNorm((12,), eps=1e-05

In [23]:
test_labels = []
test_preds = []
with torch.no_grad():
    for idx, (img, label) in enumerate(tqdm(test_dataloader, position=0, leave=True)):
        img = img.float().to(device)
        label = label.float().to(device)         
        y_pred = saved_model(img)
        y_pred_label = torch.argmax(y_pred, dim=1)
        
        test_labels.extend(label.cpu().detach())
        test_preds.extend(y_pred_label.cpu().detach())

test_accuracy = sum(1 for x,y in zip(test_preds, test_labels) if x == list(y).index(1.0)) / len(test_labels)
print(f"Test Accuracy: {test_accuracy:.4f}")
t_precision, t_recall, t_f1score = precision_recall_f1(test_preds, test_labels)
print(f"Precision: {t_precision}, Recall: {t_recall}, F1 score: {t_f1score}")
print("-"*30)

100%|██████████| 32/32 [00:02<00:00, 14.51it/s]


Test Accuracy: 0.7889
Precision: 0.5055095899687999, Recall: 0.4676226628361956, F1 score: 0.4670443553685423
------------------------------


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
