In [1]:
# Install required packages
!pip install torch torch-geometric geoopt tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import kneighbors_graph
from scipy.sparse import coo_matrix
from tqdm import tqdm
import geoopt
from geoopt import PoincareBall
from torch.utils.data import DataLoader
from torch_geometric.data import Data
import os

# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Check for CUDA availability and set the device accordingly. One of the two options below must be commented out
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#device = 'cpu'
print(f'Using device: {device}')

# Load the CSV file
file_path = '/content/drive/MyDrive/CIC_data.csv'  # Update this with the actual path to your CSV file in Google Drive
data = pd.read_csv(file_path, low_memory=False)

# Data preprocessing steps
data.columns = data.columns.str.strip()
data['Label'] = data['Label'].str.strip()

# Verify unique labels and their distribution
unique_labels = data['Label'].unique()
print(f"Unique labels in the dataset: {unique_labels}")
label_counts = data['Label'].value_counts()
print("Label distribution in the dataset:")
print(label_counts)

# Sample a smaller fraction of the rows (e.g., 20%)
data_sampled = data.sample(frac=0.10, random_state=42)

# Convert all columns to numeric, coerce errors to NaN
data_numeric = data_sampled.apply(pd.to_numeric, errors='coerce')

# Fill NaN values with the mean of each column
data_filled = data_numeric.fillna(data_numeric.mean())

# Handle infinite and very large values
data_filled = data_filled.replace([np.inf, -np.inf], np.nan)
data_filled = data_filled.fillna(data_filled.max())

# Check again for any remaining NaNs and fill them
if data_filled.isnull().values.any():
    data_filled = data_filled.fillna(0)

# Extract labels
labels = data_sampled['Label']
data_filled = data_filled.drop(columns=['Label'])

# Normalize the data
scaler = StandardScaler()
data_scaled = scaler.fit_transform(data_filled)

# Convert to PyTorch tensors
node_features = torch.tensor(data_scaled, dtype=torch.float32).to(device)
print(f"Input tensor shape: {node_features.shape}")

# UHG Operations
def uhg_quadrance(a, b, eps=1e-9):
    """Compute UHG quadrance between two points."""
    dot_product = torch.sum(a * b, dim=-1)
    return 1 - (dot_product ** 2) / ((torch.sum(a ** 2, dim=-1) - a[:, -1] ** 2 + eps) * (torch.sum(b ** 2, dim=-1) - b[:, -1] ** 2 + eps))

def uhg_spread(L, M, eps=1e-9):
    """Compute UHG spread between two lines."""
    dot_product = torch.sum(L * M, dim=-1)
    return 1 - (dot_product ** 2) / ((torch.sum(L ** 2, dim=-1) - L[:, -1] ** 2 + eps) * (torch.sum(M ** 2, dim=-1) - M[:, -1] ** 2 + eps))

# Transform the node features to UHG space
def to_uhg_space(x):
    """Transform Euclidean coordinates to UHG space."""
    return torch.cat([x, torch.ones(x.shape[0], 1, device=x.device)], dim=-1)


node_features_uhg = to_uhg_space(node_features)
print(f"Node features transformed to UHG space: {node_features_uhg.shape}")

# Create a k-nearest neighbors graph
k = 2  # Set k value to 2
knn_graph = kneighbors_graph(data_scaled, k, mode='connectivity', include_self=False)

# Convert knn_graph to COO format
knn_graph_coo = coo_matrix(knn_graph)

# Create edge index
edge_index_np = np.array([knn_graph_coo.row, knn_graph_coo.col])
edge_index = torch.from_numpy(edge_index_np).long().to(device)
print(f"Edge index shape: {edge_index.shape}")

# Convert labels to numeric
label_mapping = {label: idx for idx, label in enumerate(unique_labels)}
labels_numeric = labels.map(label_mapping).values
labels_tensor = torch.tensor(labels_numeric, dtype=torch.long).to(device)

# Create 70/15/15 train/val/test split
total_samples = node_features_uhg.size(0)
train_size = int(0.7 * total_samples)
val_size = int(0.15 * total_samples)

indices = torch.randperm(total_samples)
train_indices = indices[:train_size]
val_indices = indices[train_size:train_size+val_size]
test_indices = indices[train_size+val_size:]

train_mask = torch.zeros(total_samples, dtype=torch.bool)
val_mask = torch.zeros(total_samples, dtype=torch.bool)
test_mask = torch.zeros(total_samples, dtype=torch.bool)

train_mask[train_indices] = True
val_mask[val_indices] = True
test_mask[test_indices] = True

# Create the PyTorch Geometric data object
graph_data = Data(x=node_features_uhg, edge_index=edge_index, y=labels_tensor,
                  train_mask=train_mask, val_mask=val_mask, test_mask=test_mask).to(device)

print(f"Train size: {graph_data.train_mask.sum()}, Val size: {graph_data.val_mask.sum()}, Test size: {graph_data.test_mask.sum()}")

# Define the UHG Quadrance for prediction
def uhg_quadrance(a, b):
    """Compute UHG quadrance between two points."""
    dot_product = torch.sum(a * b, dim=-1)
    return 1 - (dot_product ** 2) / ((torch.sum(a ** 2, dim=-1) - a[:, -1] ** 2) * (torch.sum(b ** 2, dim=-1) - b[:, -1] ** 2))

# Define the UHG GraphSAGE Layer
class UHGGraphSAGELayer(nn.Module):
    def __init__(self, in_features, out_features):
        super(UHGGraphSAGELayer, self).__init__()
        self.weight_neigh = nn.Parameter(torch.Tensor(out_features, in_features))
        self.weight_self = nn.Parameter(torch.Tensor(out_features, in_features))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight_neigh)
        nn.init.xavier_uniform_(self.weight_self)

    def forward(self, x, edge_index):
        row, col = edge_index

        # Neighbor aggregation
        neigh_sum = torch.zeros_like(x)
        neigh_sum.index_add_(0, row, x[col])
        neigh_count = torch.zeros(x.size(0), device=x.device)
        neigh_count.index_add_(0, row, torch.ones_like(row, dtype=torch.float))
        neigh_count = torch.clamp(neigh_count.unsqueeze(1), min=1)
        neigh_features = neigh_sum / neigh_count

        # Apply linear transformations
        neigh_transformed = torch.matmul(neigh_features, self.weight_neigh.t())
        self_transformed = torch.matmul(x, self.weight_self.t())

        # Combine using UHG-inspired operation (simplified addition)
        combined = neigh_transformed + self_transformed

        return F.relu(combined)

# UHG GraphSAGE Model
class UHGGraphSAGE(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout=0.2):
        super(UHGGraphSAGE, self).__init__()
        self.layers = nn.ModuleList()
        self.dropout = nn.Dropout(dropout)

        self.layers.append(UHGGraphSAGELayer(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.layers.append(UHGGraphSAGELayer(hidden_channels, hidden_channels))
        self.layers.append(UHGGraphSAGELayer(hidden_channels, out_channels))

    def forward(self, x, edge_index):
        for layer in self.layers[:-1]:
            x = self.dropout(F.relu(layer(x, edge_index)))
        x = self.layers[-1](x, edge_index)
        return x

# Initialize the model
in_channels = node_features_uhg.size(1)
hidden_channels = 128
out_channels = len(label_mapping)
num_layers = 2

# Define the loss
criterion = nn.CrossEntropyLoss()

# Create a simple DataLoader
batch_size = 16
accumulation_steps = 4

#Create Dataloader with smaller batch size
train_loader = DataLoader(range(graph_data.train_mask.sum()), batch_size=batch_size, shuffle=True)

# Training process with gradient accumulation
def train_with_accumulation(model, optimizer):
    model.train()
    total_loss = 0
    optimizer.zero_grad()  # Reset gradients
    for batch_idx, batch in enumerate(tqdm(train_loader, desc="Training")):
        batch = batch.to(device)

        # Get the features for the sampled nodes
        x = graph_data.x[graph_data.train_mask][batch]
        y = graph_data.y[graph_data.train_mask][batch]

        # Create a subgraph for the batch
        batch_node_ids = graph_data.train_mask.nonzero(as_tuple=True)[0][batch]
        edge_mask = torch.isin(graph_data.edge_index[0], batch_node_ids) & torch.isin(graph_data.edge_index[1], batch_node_ids)
        batch_edge_index = graph_data.edge_index[:, edge_mask]

        # Relabel nodes to have consecutive indices
        node_idx = torch.unique(batch_edge_index)
        idx_map = {int(idx): i for i, idx in enumerate(node_idx)}
        mapped_edge_index = torch.tensor([[idx_map[int(i)] for i in batch_edge_index[0]],
                                          [idx_map[int(i)] for i in batch_edge_index[1]]],
                                         dtype=torch.long,
                                         device=device)

        # Forward pass
        out = model(x, mapped_edge_index)
        loss = criterion(out, y) / accumulation_steps  # Scale loss by accumulation steps

        loss.backward()  # Backpropagate the loss

        # Accumulate gradients and update model weights every accumulation_steps batches
        if (batch_idx + 1) % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()  # Update weights
            optimizer.zero_grad()  # Reset gradients for next accumulation
            total_loss += loss.item() * accumulation_steps  # Accumulate loss

    return total_loss / len(train_loader)


# Evaluation function
@torch.no_grad()
def evaluate(model, mask):
    model.eval()
    node_indices = mask.nonzero(as_tuple=True)[0]
    sub_x = graph_data.x[node_indices]
    sub_y = graph_data.y[node_indices]
    edge_mask = torch.isin(graph_data.edge_index[0], node_indices) & torch.isin(graph_data.edge_index[1], node_indices)
    sub_edge_index = graph_data.edge_index[:, edge_mask]
    node_idx = torch.unique(sub_edge_index)
    idx_map = {int(idx): i for i, idx in enumerate(node_idx)}
    mapped_edge_index = torch.tensor([[idx_map[int(i)] for i in sub_edge_index[0]],
                                      [idx_map[int(i)] for i in sub_edge_index[1]]],
                                     dtype=torch.long,
                                     device=device)
    out = model(sub_x, mapped_edge_index)

    # For simplicity, using the model's output directly for classification
    pred = out.argmax(dim=1)  # Choose the class with the highest logit
    correct = (pred == sub_y).sum().item()
    accuracy = correct / len(node_indices)  # Calculate accuracy

    return accuracy  # Return the calculated accuracy


# Set the learning rate
best_lr = 0.01
print(f"Using learning rate: {best_lr}")

# Initialize the model
model = UHGGraphSAGE(in_channels=in_channels, hidden_channels=hidden_channels,
                     out_channels=out_channels, num_layers=num_layers).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=best_lr, weight_decay=1e-5)

# Learning rate scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=10)

# Training loop
num_epochs = 400
best_val_acc = 0
patience = 20
counter = 0
best_model_path = '/content/drive/MyDrive/best_uhg_graphsage_model.pth'

for epoch in range(1, num_epochs + 1):
    try:
        # Train the model for one epoch
        loss = train_with_accumulation(model, optimizer)


        # Evaluate the model on validation and test sets
        val_acc = evaluate(model, graph_data.val_mask)
        test_acc = evaluate(model, graph_data.test_mask)

        # Get current learning rate
        current_lr = optimizer.param_groups[0]['lr']

        # Adjust learning rate based on validation accuracy
        scheduler.step(val_acc)

        # Check if current model is the best so far
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            counter = 0
            torch.save(model.state_dict(), best_model_path)
        else:
            counter += 1

        # Print progress every 10 epochs
        if epoch % 10 == 0:
            print(f'Epoch: {epoch}, Loss: {loss:.4f}, Val Accuracy: {val_acc:.4f}, Test Accuracy: {test_acc:.4f}, Learning Rate: {current_lr:.6f}')

        if counter >= patience:
            print("Early stopping")
            break
    except RuntimeError as e:
        print(f"Error occurred in epoch {epoch}:")
        print(str(e))
        break

# After training, print the final learning rate
final_lr = optimizer.param_groups[0]['lr']
print(f"Final Learning Rate: {final_lr:.6f}")

# Load the best model and evaluate on the test set
if os.path.exists(best_model_path):
    model.load_state_dict(torch.load(best_model_path))
    final_test_acc = evaluate(model, graph_data.test_mask)
    print(f"Final Test Accuracy: {final_test_acc:.4f}")
else:
    print("No best model found. Training might not have completed successfully.")

Collecting torch-geometric
  Downloading torch_geometric-2.5.3-py3-none-any.whl.metadata (64 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m64.2/64.2 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting geoopt
  Downloading geoopt-0.5.0-py3-none-any.whl.metadata (6.7 kB)
Downloading torch_geometric-2.5.3-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m18.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading geoopt-0.5.0-py3-none-any.whl (90 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m90.1/90.1 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-geometric, geoopt
Successfully installed geoopt-0.5.0 torch-geometric-2.5.3
Mounted at /content/drive
Using device: cuda
Unique labels in the dataset: ['BENIGN' 'DDoS' 'PortScan' 'Bot' 'Infiltration'
 'Web Attack � Brute Force' 'Web Attack � XSS'
 'Web Attack � Sql Injection' 'FTP-Patator' 'SSH-Patator

Training: 100%|██████████| 12385/12385 [00:53<00:00, 232.87it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 247.69it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 245.92it/s]
Training: 100%|██████████| 12385/12385 [00:51<00:00, 242.52it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 246.98it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 244.68it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 245.12it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 242.89it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 245.76it/s]
Training: 100%|██████████| 12385/12385 [00:49<00:00, 248.98it/s]


Epoch: 10, Loss: 0.0285, Val Accuracy: 0.9507, Test Accuracy: 0.9517, Learning Rate: 0.010000


Training: 100%|██████████| 12385/12385 [00:50<00:00, 244.17it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 246.54it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 242.86it/s]
Training: 100%|██████████| 12385/12385 [00:49<00:00, 247.83it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 245.50it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 246.15it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 247.06it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 244.53it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 247.61it/s]
Training: 100%|██████████| 12385/12385 [00:51<00:00, 242.00it/s]


Epoch: 20, Loss: 0.0280, Val Accuracy: 0.9534, Test Accuracy: 0.9555, Learning Rate: 0.010000


Training: 100%|██████████| 12385/12385 [00:50<00:00, 246.58it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 246.10it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 245.31it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 246.98it/s]
Training: 100%|██████████| 12385/12385 [00:51<00:00, 242.55it/s]
Training: 100%|██████████| 12385/12385 [00:49<00:00, 248.92it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 244.69it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 246.39it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 246.32it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 244.48it/s]


Epoch: 30, Loss: 0.0201, Val Accuracy: 0.9612, Test Accuracy: 0.9624, Learning Rate: 0.005000


Training: 100%|██████████| 12385/12385 [00:49<00:00, 248.13it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 244.03it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 246.93it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 243.71it/s]
Training: 100%|██████████| 12385/12385 [00:49<00:00, 250.01it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 246.97it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 242.97it/s]
Training: 100%|██████████| 12385/12385 [00:49<00:00, 247.83it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 245.76it/s]
Training: 100%|██████████| 12385/12385 [00:49<00:00, 248.01it/s]


Epoch: 40, Loss: 0.0192, Val Accuracy: 0.9670, Test Accuracy: 0.9676, Learning Rate: 0.005000


Training: 100%|██████████| 12385/12385 [00:51<00:00, 240.98it/s]
Training: 100%|██████████| 12385/12385 [00:49<00:00, 248.31it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 247.43it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 244.53it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 247.24it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 245.22it/s]
Training: 100%|██████████| 12385/12385 [00:49<00:00, 250.09it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 245.04it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 246.90it/s]
Training: 100%|██████████| 12385/12385 [00:49<00:00, 248.60it/s]


Epoch: 50, Loss: 0.0172, Val Accuracy: 0.9704, Test Accuracy: 0.9719, Learning Rate: 0.002500


Training: 100%|██████████| 12385/12385 [00:50<00:00, 247.26it/s]
Training: 100%|██████████| 12385/12385 [00:49<00:00, 248.44it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 245.15it/s]
Training: 100%|██████████| 12385/12385 [00:49<00:00, 247.92it/s]
Training: 100%|██████████| 12385/12385 [00:49<00:00, 249.77it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 243.82it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 247.46it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 246.66it/s]
Training: 100%|██████████| 12385/12385 [00:49<00:00, 250.66it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 245.28it/s]


Epoch: 60, Loss: 0.0162, Val Accuracy: 0.9721, Test Accuracy: 0.9715, Learning Rate: 0.002500


Training: 100%|██████████| 12385/12385 [00:50<00:00, 247.30it/s]
Training: 100%|██████████| 12385/12385 [00:49<00:00, 248.62it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 246.59it/s]
Training: 100%|██████████| 12385/12385 [00:49<00:00, 248.33it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 245.05it/s]
Training: 100%|██████████| 12385/12385 [00:49<00:00, 249.07it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 245.73it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 245.01it/s]
Training: 100%|██████████| 12385/12385 [00:51<00:00, 242.36it/s]
Training: 100%|██████████| 12385/12385 [00:49<00:00, 247.79it/s]


Epoch: 70, Loss: 0.0154, Val Accuracy: 0.9689, Test Accuracy: 0.9704, Learning Rate: 0.001250


Training: 100%|██████████| 12385/12385 [00:49<00:00, 249.73it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 243.14it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 246.45it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 244.11it/s]
Training: 100%|██████████| 12385/12385 [00:49<00:00, 249.72it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 242.93it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 246.56it/s]
Training: 100%|██████████| 12385/12385 [00:49<00:00, 248.72it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 244.43it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 246.79it/s]


Epoch: 80, Loss: 0.0145, Val Accuracy: 0.9743, Test Accuracy: 0.9761, Learning Rate: 0.001250


Training: 100%|██████████| 12385/12385 [00:50<00:00, 243.62it/s]
Training: 100%|██████████| 12385/12385 [00:49<00:00, 247.87it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 243.73it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 247.22it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 246.06it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 245.69it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 247.55it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 243.96it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 246.57it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 243.92it/s]


Epoch: 90, Loss: 0.0142, Val Accuracy: 0.9776, Test Accuracy: 0.9777, Learning Rate: 0.000625


Training: 100%|██████████| 12385/12385 [00:50<00:00, 246.57it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 247.13it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 245.24it/s]
Training: 100%|██████████| 12385/12385 [00:49<00:00, 249.65it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 244.96it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 247.06it/s]
Training: 100%|██████████| 12385/12385 [00:51<00:00, 240.48it/s]
Training: 100%|██████████| 12385/12385 [00:49<00:00, 249.13it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 246.19it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 246.04it/s]


Epoch: 100, Loss: 0.0132, Val Accuracy: 0.9789, Test Accuracy: 0.9792, Learning Rate: 0.000625


Training: 100%|██████████| 12385/12385 [00:50<00:00, 247.39it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 246.04it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 247.58it/s]
Training: 100%|██████████| 12385/12385 [00:51<00:00, 240.37it/s]
Training: 100%|██████████| 12385/12385 [00:49<00:00, 247.81it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 246.31it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 247.54it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 247.05it/s]
Training: 100%|██████████| 12385/12385 [00:51<00:00, 241.97it/s]
Training: 100%|██████████| 12385/12385 [00:49<00:00, 248.99it/s]


Epoch: 110, Loss: 0.0125, Val Accuracy: 0.9813, Test Accuracy: 0.9820, Learning Rate: 0.000313


Training: 100%|██████████| 12385/12385 [00:50<00:00, 243.97it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 246.22it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 244.73it/s]
Training: 100%|██████████| 12385/12385 [00:49<00:00, 249.28it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 246.16it/s]
Training: 100%|██████████| 12385/12385 [00:51<00:00, 240.31it/s]
Training: 100%|██████████| 12385/12385 [00:49<00:00, 248.08it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 245.74it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 246.51it/s]
Training: 100%|██████████| 12385/12385 [00:51<00:00, 242.67it/s]


Epoch: 120, Loss: 0.0128, Val Accuracy: 0.9812, Test Accuracy: 0.9815, Learning Rate: 0.000313


Training: 100%|██████████| 12385/12385 [00:49<00:00, 248.31it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 247.03it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 244.92it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 247.43it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 246.12it/s]
Training: 100%|██████████| 12385/12385 [00:49<00:00, 248.07it/s]
Training: 100%|██████████| 12385/12385 [00:51<00:00, 240.72it/s]
Training: 100%|██████████| 12385/12385 [00:49<00:00, 248.22it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 247.35it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 247.51it/s]


Epoch: 130, Loss: 0.0119, Val Accuracy: 0.9817, Test Accuracy: 0.9823, Learning Rate: 0.000156


Training: 100%|██████████| 12385/12385 [00:50<00:00, 247.43it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 243.85it/s]
Training: 100%|██████████| 12385/12385 [00:49<00:00, 249.34it/s]
Training: 100%|██████████| 12385/12385 [00:51<00:00, 241.32it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 247.28it/s]
Training: 100%|██████████| 12385/12385 [00:52<00:00, 237.47it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 247.37it/s]
Training: 100%|██████████| 12385/12385 [00:50<00:00, 245.81it/s]


Early stopping
Final Learning Rate: 0.000156


  model.load_state_dict(torch.load(best_model_path))


Final Test Accuracy: 0.9836
