In [None]:
import gc
import math
import os
import random

import matplotlib.pyplot as plt
import numpy as np
from numpy import arange

import torch
import torch.nn as nn
import torchmetrics
from torchmetrics import *
from torchsummary import summary
from sklearn.metrics import *

import torch.nn.functional as F
import torch.optim as optim
import torch.utils.checkpoint as cp
import torch.utils.data as ud
from torch.nn.utils import clip_grad as cg
from torch.optim.lr_scheduler import StepLR, LambdaLR
from torch.utils.data import DataLoader, Dataset, Sampler, RandomSampler
from torch.utils.data.sampler import SubsetRandomSampler, WeightedRandomSampler
from torchvision import transforms
from torchvision.datasets import ImageFolder
import torchvision.transforms
import separableconv.nn

import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import NeptuneLogger
from neptune.new.types import File

In [None]:
# Check whether it's possible to use GPU/CUDA and report the library versions in use.
if torch.cuda.is_available():
    # Configure the caching allocator first, before any other CUDA call in this cell.
    # NOTE(review): presumably the allocator reads this variable once, when it is first
    # used, so it should be set as early as possible - confirm against the PyTorch docs.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "garbage_collection_threshold:0.9,max_split_size_mb:128"
    device = torch.device("cuda")
    print("Using GPU:", torch.cuda.get_device_name(0))
    # Drop unreachable Python objects first so any CUDA blocks they held can be released.
    gc.collect()
    torch.cuda.empty_cache()
else:
    device = torch.device("cpu")
    print("Using CPU")
print(torch.version.cuda)
# `torch.version` / `torchvision.version` are modules; `__version__` is the version string.
print(torch.__version__)
print(torchvision.__version__)

In [None]:
# Hyperparameters shared by the data module, the optimizer/scheduler and the trainer.
params = dict(
    # DataLoader batch sizes, one per split
    train_batch_size=16,
    val_batch_size=16,
    test_batch_size=16,
    # SGD settings
    lr=0.005,
    momentum=0.9,
    weight_decay=5e-4,
    # Training length and MultiStepLR settings (milestone epochs, decay factor)
    max_epochs=200,
    schedule=[50, 100, 150],
    gamma=0.2,
)

In [None]:
# Root folders of the three ImageFolder splits (one sub-folder per class is expected
# by torchvision's ImageFolder).
train_dir, val_dir, test_dir = (
    'E:/WOW_04BPP/train',
    'E:/WOW_04BPP/val',
    'E:/WOW_04BPP/test',
)

class Data(pl.LightningDataModule):
    """LightningDataModule serving the train/val/test image folders as grayscale tensors.

    Datasets are built in `setup` (not `prepare_data`), because Lightning calls
    `prepare_data` on a single process only and advises against assigning state there.
    Every split is read through a seeded `RandomSampler`, and the last incomplete
    batch of each split is dropped.
    """

    def __init__(self):
        super().__init__()
        # Convert every image to a single channel, then to a float tensor in [0, 1].
        self.transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.ToTensor()
        ])
        self.num_workers = 0
        # Seeds the global torch RNG as a side effect of constructing the module.
        torch.manual_seed(42)

    def prepare_data(self):
        """Nothing to download or pre-process; the datasets are created in `setup`."""

    def _make_loader(self, dataset, batch_size):
        """Build a DataLoader over `dataset` that iterates in a seeded random order."""
        # RandomSampler only needs the length of its data source, so the dataset itself
        # can be passed instead of an explicit index list.
        sampler = RandomSampler(dataset, replacement=False, num_samples=None,
                                generator=torch.Generator().manual_seed(42))
        # drop_last=True discards the final incomplete batch, so every batch has the
        # same size and up to batch_size - 1 samples per split are skipped each epoch.
        return DataLoader(dataset, batch_size=batch_size, sampler=sampler,
                          num_workers=self.num_workers, pin_memory=True, drop_last=True)

    def setup(self, stage: str):
        """Create the datasets and loaders required by `stage`.

        Args:
            stage: Lightning stage name. "fit" builds the train and validation loaders,
                "validate" builds the validation loader only, "test" builds the test
                loader. Other stages build nothing.
        """
        if stage == "fit":
            self.train = ImageFolder(root=train_dir, transform=self.transform)
            self.train_set = self._make_loader(self.train, params["train_batch_size"])

        if stage in ("fit", "validate"):
            self.val = ImageFolder(root=val_dir, transform=self.transform)
            self.val_set = self._make_loader(self.val, params["val_batch_size"])

        if stage == "test":
            self.test = ImageFolder(root=test_dir, transform=self.transform)
            self.test_set = self._make_loader(self.test, params["test_batch_size"])

    def train_dataloader(self):
        """Return the training loader built by `setup("fit")`."""
        return self.train_set

    def val_dataloader(self):
        """Return the validation loader built by `setup("fit")` or `setup("validate")`."""
        return self.val_set

    def test_dataloader(self):
        """Return the test loader built by `setup("test")`."""
        return self.test_set


data_module = Data()

In [None]:
class ABS(nn.Module):
    """Parameter-free layer that returns the element-wise absolute value of its input."""

    def forward(self, x):
        """Return |x| with the same shape and dtype as `x`."""
        return x.abs()
    
class TLU(nn.Module):
    """Truncated linear unit: clips activations to the range [-threshold, threshold]."""

    def __init__(self, threshold):
        super().__init__()
        # Stored as a plain attribute, so it is neither trained nor saved in state_dict.
        self.threshold = threshold

    def forward(self, input):
        """Return `input` with every element clamped to +/- `self.threshold`."""
        limit = self.threshold
        return input.clamp(min=-limit, max=limit)
    
class ScaleLayer(nn.Module):
    """Learnable per-channel affine transform: `x * gamma + beta`.

    `gamma` starts at one and `beta` at zero, so a fresh layer is the identity.
    """

    def __init__(self, num_features):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        """Scale and shift `x` along dimension 1; both parameters broadcast as (1, C, 1, 1)."""
        scale = self.gamma[None, :, None, None]
        shift = self.beta[None, :, None, None]
        return x * scale + shift

def plot_roc(y, preds, pred_scores, is_test):
    """Plot one-vs-rest ROC curves for both classes and optionally log the figure.

    Args:
        y: 1-D ground-truth labels with values 0 or 1.
        preds: 2-D per-class scores; only column 1 is read here.
        pred_scores: 1-D score for class 1, matching `y` element for element.
        is_test: when true, the figure is also appended to the "test_ROC" series of
            the module-level `neptune_logger`.
    """
    # For the class-0 curve, flip both labels and scores so class 0 becomes "positive".
    fpr_0, tpr_0, _ = roc_curve(1 - y, 1 - pred_scores)
    roc_auc_0 = auc(fpr_0, tpr_0)

    fpr_1, tpr_1, _ = roc_curve(y, preds[:, 1])
    roc_auc_1 = auc(fpr_1, tpr_1)

    fig, ax = plt.subplots()
    try:
        ax.plot(fpr_0, tpr_0, label='Class 0: Cover (AUC = {:.2f})'.format(roc_auc_0))
        ax.plot(fpr_1, tpr_1, label='Class 1: Stego (AUC = {:.2f})'.format(roc_auc_1))
        ax.plot([0, 1], [0, 1], linestyle='--', label='Random (AUC = 0.5)')
        ax.set_title("Receiver Operating Characteristics")
        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")
        ax.legend()

        # Hand the figure to Neptune while it is still open; showing or closing it
        # first would leave the upload working on an already-released figure.
        if is_test:
            neptune_logger.experiment["test_ROC"].append(File.as_image(fig))

        plt.show()
    finally:
        # Always release the figure, even when plotting or uploading raised.
        plt.close(fig)

def plot_confusion_matrix(confusion_mat, is_test):
    """Plot a row-normalised confusion matrix and optionally log the figure.

    Args:
        confusion_mat: square array of raw counts, rows = true label, columns =
            predicted label.
        is_test: when true, the figure is also appended to the "test_CM" series of
            the module-level `neptune_logger`.
    """
    classes = ['Cover', 'Stego']

    # Normalise each row by its total. A row with no samples (a class absent from the
    # ground truth) stays all-zero instead of becoming NaN from a division by zero.
    confusion_mat = np.asarray(confusion_mat, dtype=float)
    row_totals = confusion_mat.sum(axis=1, keepdims=True)
    confusion_mat = np.divide(confusion_mat, row_totals,
                              out=np.zeros_like(confusion_mat), where=row_totals != 0)

    fig, ax = plt.subplots()
    try:
        image = ax.imshow(confusion_mat, interpolation='nearest', cmap=plt.cm.Blues)
        ax.set_title("Normalized Confusion Matrix")
        fig.colorbar(image, ax=ax)
        tick_marks = np.arange(len(classes))
        ax.set_xticks(tick_marks)
        ax.set_xticklabels(classes)
        ax.set_yticks(tick_marks)
        ax.set_yticklabels(classes)

        # Annotate every cell; switch to white text on the dark (high-value) cells.
        for i in range(confusion_mat.shape[0]):
            for j in range(confusion_mat.shape[1]):
                ax.text(j, i, format(confusion_mat[i, j], '.2f'),
                        horizontalalignment="center",
                        color="white" if confusion_mat[i, j] > 0.5 else "black")

        ax.set_xlabel("Predicted Label")
        ax.set_ylabel("True Label")
        fig.tight_layout()

        # Hand the figure to Neptune while it is still open.
        if is_test:
            neptune_logger.experiment["test_CM"].append(File.as_image(fig))

        plt.show()
    finally:
        # Always release the figure, even when plotting or uploading raised.
        plt.close(fig)

def Tanh3(x):
    """Return the hyperbolic tangent of tensor `x` scaled to the range (-3, 3)."""
    return torch.tanh(x) * 3
    

In [None]:
# Fixed SRM filter bank, rescaled by 1/12, and its bias initialised to ones.
# The 30 bias entries match the 30 channels expected by the first convolution.
srm_weights = np.load('SRM_filters.npy')*(1/12)
biasSRM = np.ones(30)

# Reorder the axes with permute(3, 2, 0, 1) for F.conv2d, which expects
# (out_channels, in_channels, kH, kW).
# NOTE(review): assumes the .npy array is stored as (kH, kW, in_ch, out_ch) - confirm.
# Cast both tensors to float32 so they match the float image tensors, and place them
# on the device selected earlier in the notebook instead of requiring CUDA.
srm_weights_tensor = torch.from_numpy(srm_weights).permute(3,2,0,1).float().to(device)
biasSRM_tensor = torch.from_numpy(biasSRM).float().to(device)

### Define CNN architecture ###
class Yedroudj(pl.LightningModule):
    """Two-class CNN (class 0 / class 1, plotted as Cover / Stego) as a LightningModule.

    The network applies a fixed bank of SRM filters with a trainable bias, five
    convolutional stages and two fully connected layers, and returns softmax
    probabilities. Training, validation and test steps use binary cross-entropy
    against one-hot targets. Validation and test predictions are buffered per epoch
    to draw ROC curves and confusion matrices.
    """

    def __init__(self):
        super().__init__()
        # Per-epoch buffers of labels / softmax outputs, filled by the step hooks and
        # emptied by the matching epoch-end hooks.
        self.val_ytrue = []
        self.val_ypred = []
        self.test_ytrue = []
        self.test_ypred = []
        self.tlu_3 = TLU(3.0)
        self.tlu_1 = TLU(1.0)
        self.abs = ABS()
        self.relu = nn.ReLU()

        # clone(): nn.Parameter shares storage with the tensor it wraps. Without a
        # copy, every instance (including the one built by load_from_checkpoint) would
        # alias the same module-level tensors, so training or loading one model would
        # silently rewrite the bias of the others.
        self.srm_weights = torch.nn.Parameter(srm_weights_tensor.clone(), requires_grad=False)
        self.biasSRM = torch.nn.Parameter(biasSRM_tensor.clone(), requires_grad=True)

        self.bn = nn.BatchNorm2d(num_features=30)

        self.conv1 = nn.Conv2d(in_channels=30, out_channels=30, kernel_size=5, stride=1, padding="same")
        self.bn1 = nn.BatchNorm2d(num_features=30)

        self.conv2 = nn.Conv2d(in_channels=30, out_channels=30, kernel_size=5, stride=1, padding="same")
        self.bn2 = nn.BatchNorm2d(num_features=30)

        self.conv3 = nn.Conv2d(in_channels=30, out_channels=32, kernel_size=3, stride=1, padding="same")
        self.bn3 = nn.BatchNorm2d(num_features=32)

        self.conv4 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding="same")
        self.bn4 = nn.BatchNorm2d(num_features=64)

        self.conv5 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding="same")
        self.bn5 = nn.BatchNorm2d(num_features=128)

        self.fc1 = torch.nn.Linear(128, 256)
        self.fc2 = torch.nn.Linear(256, 2)

    def init_weights(self, module):
        """Initialise one sub-module; intended for use with `nn.Module.apply`.

        Trainable Conv2d weights get Kaiming-normal initialisation; Linear layers get
        N(0, 0.01) weights and zero bias.

        NOTE(review): nothing in the visible notebook calls `self.apply(self.init_weights)`,
        so the layers keep PyTorch's default initialisation - confirm whether that is
        intended before wiring this in, since it changes training behaviour.
        """
        if isinstance(module, nn.Conv2d):
            if module.weight.requires_grad:
                nn.init.kaiming_normal_(module.weight.data, mode='fan_in', nonlinearity='relu')

        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight.data, mean=0, std=0.01)
            nn.init.constant_(module.bias.data, val=0)

    def forward(self, x):
        """Return class probabilities for a batch of images.

        Args:
            x: image batch. Presumably (N, 1, H, W) grayscale tensors, given the
                Grayscale transform of the data module - confirm against the SRM
                filter bank's input channels.

        Returns:
            Tensor of shape (N, 2) whose rows are softmax probabilities.
        """
        # Fixed SRM filtering (weights frozen, bias trainable).
        x = F.conv2d(x, self.srm_weights, self.biasSRM, stride=1, padding="same")
        x = self.relu(x)
        x = self.bn(x)

        x = self.conv1(x)
        x = self.abs(x)
        x = self.bn1(x)
        x = self.tlu_3(x)

        # Functional pooling avoids constructing a new nn.AvgPool2d on every call;
        # the defaults (ceil_mode, count_include_pad) are the same as the module's.
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.tlu_1(x)
        x = F.avg_pool2d(x, kernel_size=5, stride=2, padding=2)

        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)
        x = F.avg_pool2d(x, kernel_size=5, stride=2, padding=2)

        x = self.conv4(x)
        x = self.bn4(x)
        x = self.relu(x)
        x = F.avg_pool2d(x, kernel_size=5, stride=2, padding=2)

        x = self.conv5(x)
        x = self.bn5(x)
        x = self.relu(x)

        # Global average pooling collapses each of the 128 feature maps to one value.
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = x.view(-1, 128)

        x = self.fc1(x)
        x = self.relu(x)

        x = self.fc2(x)
        x = F.softmax(x, dim=1)

        return x

    def configure_optimizers(self):
        """Return SGD with momentum and a MultiStepLR schedule, both driven by `params`."""
        optimizer = torch.optim.SGD(
            self.parameters(), lr=params["lr"], momentum=params["momentum"], weight_decay=params["weight_decay"]
        )
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=params["schedule"], gamma=params["gamma"])

        return [optimizer], [scheduler]

    def _forward_loss(self, batch):
        """Run the model on `batch` and return `(loss, probabilities, labels)`."""
        x, y = batch
        pred = self(x)

        # The model outputs probabilities, so BCE is taken against one-hot targets.
        target = F.one_hot(y, num_classes=2).float()
        loss = F.binary_cross_entropy(pred.float(), target)
        return loss, pred, y

    @staticmethod
    def _batch_accuracy(pred, y):
        """Fraction of rows whose arg-max class equals the label, computed on-device."""
        return (pred.argmax(dim=1) == y).float().mean()

    def training_step(self, batch, batch_idx):
        """Compute the loss for one training batch and log epoch-level loss/accuracy."""
        loss, pred, y = self._forward_loss(batch)
        self.log("train/loss", loss, on_step=False, on_epoch=True)
        self.log("train/acc", self._batch_accuracy(pred, y), on_step=False, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss/accuracy and buffer the outputs for the epoch-end plots."""
        loss, pred, y = self._forward_loss(batch)
        self.log("val/loss", loss, on_step=False, on_epoch=True)
        self.log("val/acc", self._batch_accuracy(pred, y), on_step=False, on_epoch=True)

        self.val_ytrue.append(y.detach())
        self.val_ypred.append(pred.detach())

        return loss

    @staticmethod
    def _drain(label_buffer, pred_buffer):
        """Concatenate and empty the per-batch buffers.

        Returns:
            `(labels, probabilities)` as CPU tensors of shape (N,) and (N, 2), or
            `(None, None)` when nothing was buffered. Concatenation (rather than
            stacking) also works when the batches differ in size.
        """
        if not label_buffer:
            return None, None
        y = torch.cat(label_buffer).cpu().reshape(-1)
        preds = torch.cat(pred_buffer).cpu().reshape(-1, 2)
        # Empty the buffers straight away so a failure while plotting cannot leak
        # this epoch's outputs into the next one.
        label_buffer.clear()
        pred_buffer.clear()
        return y, preds

    @staticmethod
    def _plot_diagnostics(y, preds, is_test):
        """Draw the ROC curves and the confusion matrix for one evaluation epoch."""
        pred_scores = preds[:, 1]

        plot_roc(y, preds, pred_scores, is_test)

        cm_pred = (pred_scores.numpy() >= 0.5).astype(int)
        # labels=[0, 1] keeps the matrix 2x2 even when a class is missing from the epoch.
        cm = confusion_matrix(y.numpy(), cm_pred, labels=[0, 1])
        plot_confusion_matrix(cm, is_test)

    def on_validation_epoch_end(self):
        """Plot ROC curves and the confusion matrix for the finished validation epoch."""
        y, preds = self._drain(self.val_ytrue, self.val_ypred)
        if y is None:
            return
        self._plot_diagnostics(y, preds, False)

    def test_step(self, batch, batch_idx):
        """Log the test loss and buffer the outputs for the epoch-end metrics."""
        loss, pred, y = self._forward_loss(batch)
        self.log("test/loss", loss, on_step=False, on_epoch=True)

        self.test_ytrue.append(y.detach())
        self.test_ypred.append(pred.detach())

        return loss

    def on_test_epoch_end(self):
        """Log accuracy, precision, recall and F1 for the test set and plot diagnostics.

        Precision, recall and F1 are computed for class 1 (the class the plots label
        'Stego') from its probability thresholded at 0.5. With `task='multiclass'`
        and the default micro averaging, all three would equal the accuracy.
        """
        y, preds = self._drain(self.test_ytrue, self.test_ypred)
        if y is None:
            return

        pred_scores = preds[:, 1]

        acc_metric = torchmetrics.Accuracy(task='multiclass', num_classes=2)
        self.log("test/acc", acc_metric(preds, y))

        prec_metric = torchmetrics.Precision(task='binary')
        self.log("test/prec", prec_metric(pred_scores, y))

        recall_metric = torchmetrics.Recall(task='binary')
        self.log("test/recall", recall_metric(pred_scores, y))

        f1_metric = torchmetrics.F1Score(task='binary')
        self.log("test/f1", f1_metric(pred_scores, y))

        self._plot_diagnostics(y, preds, True)
    

In [None]:
# Create NeptuneLogger
neptune_logger = NeptuneLogger(
    # Use your own Neptune.ai API key. It is read from the NEPTUNE_API_TOKEN environment
    # variable so the secret never has to be pasted into the notebook.
    api_key=os.environ.get("NEPTUNE_API_TOKEN", ""),
    # Replace with your own "workspace/project" name. The keyword may only appear once;
    # repeating it is a SyntaxError.
    project="remmarty/Yedroudj-WOW",
    prefix="experiment",
    tags=["BossBase_8/1/1, srm filters"],
    log_model_checkpoints=True,
    capture_hardware_metrics=False,
    capture_stdout=False,
    source_files="yedroudj_2.ipynb"
)

# Store this run's checkpoints in a folder named after the Neptune run id.
run_id = neptune_logger.run["sys/id"].fetch()
folder_name = run_id
log_path = "E:/YEDROUDJ_WOW_CHECKPOINTS"

path = os.path.join(log_path, folder_name)
os.makedirs(path, exist_ok=True)

# Create learning rate logger
lr_logger = LearningRateMonitor(logging_interval="epoch")

# Create model checkpointing object: ranks checkpoints by validation accuracy
# (higher is better), keeps up to 100 of them plus the most recent one.
model_checkpoint = ModelCheckpoint(
    dirpath=path,
    mode="max",
    monitor="val/acc",
    save_weights_only=False,
    save_top_k=100,
    save_last=True,
    every_n_epochs=1,
)

# Initialize a trainer and pass neptune_logger
trainer = pl.Trainer(
    default_root_dir=path,
    logger=neptune_logger,
    callbacks=[lr_logger, model_checkpoint],
    log_every_n_steps=150,
    accelerator="cuda",
    devices=1,
    max_epochs=params["max_epochs"],
    val_check_interval=1.0,
    enable_progress_bar=True,
    # Skip the pre-training sanity validation pass.
    num_sanity_val_steps=0,
)

In [None]:
# Instantiate the network; its SRM filter bank and bias are taken from the
# module-level srm_weights_tensor / biasSRM_tensor defined above.
model = Yedroudj()

In [None]:
# Log model summary (max_depth=-1 is passed so nested sub-modules are included)
neptune_logger.log_model_summary(model=model, max_depth=-1)

# Log hyperparameters (the `params` dict defined near the top of the notebook)
neptune_logger.log_hyperparams(params=params)

In [None]:
# Train: runs up to params["max_epochs"] epochs with the trainer configured above;
# the ModelCheckpoint callback tracks the best checkpoint by "val/acc".
trainer.fit(model, datamodule=data_module)

In [None]:
# Test: evaluate the checkpoint with the best monitored validation accuracy.
best_model_path = model_checkpoint.best_model_path
print(best_model_path)
# best_model_path is an empty string until the callback has saved a checkpoint;
# fail with a clear message instead of passing "" to load_from_checkpoint.
if not best_model_path:
    raise RuntimeError(
        "No best checkpoint was recorded by model_checkpoint; run trainer.fit(...) before testing."
    )
best_model = Yedroudj.load_from_checkpoint(best_model_path)
trainer.test(best_model, datamodule=data_module)

neptune_logger.experiment.stop()

# Collect unreachable Python objects first so the CUDA blocks they held can then be
# returned to the driver by empty_cache().
gc.collect()
torch.cuda.empty_cache()