In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torchvision import transforms
import os
import random
import numpy as np
from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Root folder containing the four test-sample splits. Set the
# OULU_NPU_TEST_SAMPLES environment variable to point elsewhere instead of
# editing absolute paths; the default is the original location.
DATA_ROOT = os.environ.get(
    'OULU_NPU_TEST_SAMPLES',
    '/home/taha/Taha26/All_Experiments/FID/OULU_NPU/test_samples',
)

data_path_real_attack = os.path.join(DATA_ROOT, 'real_attack_360')
data_path_real_bonafide = os.path.join(DATA_ROOT, 'real_bonafide_360')
data_path_fake_bonafide = os.path.join(DATA_ROOT, 'fake_bonafide_360')
data_path_fake_attack = os.path.join(DATA_ROOT, 'fake_attack_360')

In [3]:
def load_samples_image(image_path, transform):
    """Read one image file with Pillow, convert it to RGB and apply ``transform``.

    Args:
        image_path: Path to an image file that Pillow can open.
        transform: Callable applied to the RGB ``PIL.Image`` (in this notebook
            a ``torchvision.transforms.Compose``).

    Returns:
        Whatever ``transform`` returns for the RGB image.
    """
    # Use a context manager so the file handle is closed deterministically
    # rather than whenever the image object is collected; this matters when
    # many DataLoader workers open files concurrently.
    with Image.open(image_path) as image:
        # convert() reads the pixel data and returns a new image, so the
        # result remains usable after the source file is closed.
        rgb_image = image.convert('RGB')
    return transform(rgb_image)

class ImageDataset(Dataset):
    """Map-style dataset over the image files directly inside one directory.

    Each item is the image at that index, loaded via ``load_samples_image``,
    resized to ``image_size`` x ``image_size`` and converted with
    ``transforms.ToTensor()``. No labels are returned.

    Args:
        data_path: Directory holding the images (subdirectories are not searched).
        image_size: Side length passed to ``transforms.Resize``; defaults to 256.
        extensions: File-name suffixes treated as images. Matching is
            case-insensitive so files such as ``IMG.JPG`` are not silently skipped.
    """

    def __init__(self, data_path, image_size=256, extensions=('.png', '.jpg', '.jpeg')):
        self.data_path = data_path
        suffixes = tuple(ext.lower() for ext in extensions)
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        self.image_files = sorted(
            name for name in os.listdir(data_path)
            if name.lower().endswith(suffixes)
        )
        self.data_length = len(self.image_files)
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Number of matching image files found at construction time."""
        return self.data_length

    def __getitem__(self, idx):
        """Load and transform the ``idx``-th image (in sorted file-name order)."""
        path = os.path.join(self.data_path, self.image_files[idx])
        return load_samples_image(path, self.transform)

In [4]:
# One ImageDataset per split; the unpacking order must match the path order.
real_attack, real_bonafide, fake_bonafide, fake_attack = (
    ImageDataset(split_dir)
    for split_dir in (
        data_path_real_attack,
        data_path_real_bonafide,
        data_path_fake_bonafide,
        data_path_fake_attack,
    )
)

In [5]:
# Report how many image files each split contributed.
for caption, split in (
    ("Total No of Real Attack Samples:", real_attack),
    ("Total No of Real Bonafide Samples:", real_bonafide),
    ("Total No of Fake Bonafide Samples:", fake_bonafide),
    ("Total No of Fake Attack Samples:", fake_attack),
):
    print(caption, len(split))

Total No of Real Attack Samples: 360
Total No of Real Bonafide Samples: 360
Total No of Fake Bonafide Samples: 360
Total No of Fake Attack Samples: 360


In [6]:
# Shared loader settings, named once instead of repeated per loader.
BATCH_SIZE = 32
NUM_WORKERS = 8

# shuffle=False keeps batches in the dataset's own index order, so the arrays
# assembled from these loaders are reproducible run to run.
first_dataloader = DataLoader(real_attack, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
second_dataloader = DataLoader(real_bonafide, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
third_dataloader = DataLoader(fake_bonafide, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
fourth_dataloader = DataLoader(fake_attack, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

In [None]:
# for images in first_dataloader:
#     print(f"Images Shape: {images.shape}")

In [None]:
# for images in second_dataloader:
#     print(f"Images Shape: {images.shape}")

In [7]:
# Distribution 1 (real attack): gather every batch, then concatenate once.
# Concatenating inside the loop re-copied the growing array on every batch.
batches_1 = [img.detach().numpy() for img in first_dataloader]
# Stays None for an empty loader, matching the previous behavior.
all_images_1 = np.concatenate(batches_1, axis=0) if batches_1 else None

In [8]:
# Distribution 2 (real bonafide): gather every batch, then concatenate once.
# Concatenating inside the loop re-copied the growing array on every batch.
batches_2 = [img2.detach().numpy() for img2 in second_dataloader]
# Stays None for an empty loader, matching the previous behavior.
all_images_2 = np.concatenate(batches_2, axis=0) if batches_2 else None

In [9]:
# Distribution 3 (fake bonafide): gather every batch, then concatenate once.
# Concatenating inside the loop re-copied the growing array on every batch.
batches_3 = [img3.detach().numpy() for img3 in third_dataloader]
# Stays None for an empty loader, matching the previous behavior.
all_images_3 = np.concatenate(batches_3, axis=0) if batches_3 else None

In [10]:
# Distribution 4 (fake attack): gather every batch, then concatenate once.
# Concatenating inside the loop re-copied the growing array on every batch.
batches_4 = [img4.detach().numpy() for img4 in fourth_dataloader]
# Stays None for an empty loader, matching the previous behavior.
all_images_4 = np.concatenate(batches_4, axis=0) if batches_4 else None

In [11]:
# Shapes of the stacked arrays, in split order:
# real attack, real bonafide, fake bonafide, fake attack.
for stacked in (all_images_1, all_images_2, all_images_3, all_images_4):
    print(stacked.shape)

(360, 3, 256, 256)
(360, 3, 256, 256)
(360, 3, 256, 256)
(360, 3, 256, 256)


In [12]:
# Collapse each stacked array to 1-D so every split becomes one long vector of values.
flat_all_images_1, flat_all_images_2, flat_all_images_3, flat_all_images_4 = (
    stacked.reshape(-1)
    for stacked in (all_images_1, all_images_2, all_images_3, all_images_4)
)

In [13]:
# 1-D lengths, in order: Real Attack, Real Bonafide, Fake Bonafide, Fake Attack.
for flat in (flat_all_images_1, flat_all_images_2, flat_all_images_3, flat_all_images_4):
    print(flat.shape)

(70778880,)
(70778880,)
(70778880,)
(70778880,)


In [24]:
# Score two flattened splits with OVL from the external `gmdm` package.
# NOTE(review): OVL's definition (presumably a distribution-overlap
# coefficient) is not visible here. Its binning, value range and whether it
# is symmetric in its arguments are unconfirmed; check before interpreting.
# The "(26,)" lines in this cell's output are not printed by this code, so they
# appear to come from gmdm itself (on import or inside OVL).
# TODO: move this import into the top imports cell.
from gmdm import OVL
# Arguments: Fake Attack (split 4) first, Fake Bonafide (split 3) second.
overlap = OVL(flat_all_images_4, flat_all_images_3)

(26,)
(26,)
(26,)
(26,)


In [25]:
# Show the score returned by OVL.
print(str(overlap))

0.8587579232675058
