In [None]:
import os
from importlib import reload

import torch
from torch import nn
import cv2
import numpy as np 
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from prettytable import PrettyTable

import albumentations
from albumentations import augmentations
import albumentations.pytorch

In [None]:
# Load the combined dataset index from CSV. The cells below filter this frame
# on its 'root_dir', 'fold' and 'label' columns.
full_df = pd.read_csv('combo_all_FULL.csv')

In [None]:
"CMFD, NIST, COVERAGE, CASIA, IMD"

def get_individual_test_df(dataframe, val):
    """Select the fold-1 rows of one dataset and print its class balance.

    Args:
        dataframe: DataFrame with at least the columns 'root_dir', 'fold'
            and 'label'.
        val: Pattern looked up in 'root_dir' via ``Series.str.contains``
            (so it is interpreted as a regular expression by default).

    Returns:
        The selected rows as a numpy array (``DataFrame.values``).
    """
    # na=False: a missing 'root_dir' counts as "no match". Without it the
    # mask contains NaN for such rows and boolean indexing raises.
    in_dataset = dataframe['root_dir'].str.contains(val, na=False)
    test_df = dataframe[in_dataset]
    test_df = test_df[test_df["fold"].isin([1])]

    # The message labels rows with label 0 as "real" and label 1 as "fakes".
    n_real = int((test_df["label"] == 0).sum())
    n_fake = int((test_df["label"] == 1).sum())
    print("{}: real:{}, fakes:{}".format(val, n_real, n_fake))

    return test_df.values
    
# One fold-1 test array per dataset. Each result is bound to its own name;
# assigning all four to `casia_test` would keep only the last (NIST) result
# and silently discard the CASIA, IMD and COVERAGE rows.
casia_test = get_individual_test_df(full_df, "CASIA")
imd_test = get_individual_test_df(full_df, "IMD")
coverage_test = get_individual_test_df(full_df, "COVERAGE")
nist_test = get_individual_test_df(full_df, "NIST")

In [None]:
root_folder = "Image_Manipulation_Dataset"


def _imread_checked(path, flags):
    """Read ``path`` with ``cv2.imread`` and raise if nothing could be loaded."""
    loaded = cv2.imread(path, flags)
    if loaded is None:
        # cv2.imread reports failure by returning None instead of raising,
        # which would otherwise surface later as an opaque cvtColor/resize error.
        raise FileNotFoundError("Could not read image: {}".format(path))
    return loaded


def load_images(row):
    """Load the image, ELA image and mask referenced by one dataset row.

    Args:
        row: Six-item sequence unpacked as
            (image file, mask file, label, <ignored>, ELA file, root_dir).
            File names are resolved relative to ``root_folder/root_dir``.
            A non-string NaN mask entry means "no mask file".

    Returns:
        Tuple ``(image, ela_image, mask_image, label)``. ``image`` and
        ``ela_image`` are converted from BGR to RGB; all three arrays are
        resized to 256x256.

    Raises:
        FileNotFoundError: if the image, ELA image or mask file cannot be read.
    """
    image_patch, mask_patch, label, _, ela, root_dir = row

    #------------- Load image, Ela, Mask -------------------------
    image_path = os.path.join(root_folder, root_dir, image_patch)
    ela_path = os.path.join(root_folder, root_dir, ela)

    image = cv2.cvtColor(_imread_checked(image_path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
    ela_image = cv2.cvtColor(_imread_checked(ela_path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)

    if not isinstance(mask_patch, str) and np.isnan(mask_patch):
        # No mask file for this row: use an all-zero placeholder mask.
        mask_image = np.zeros((image.shape[0], image.shape[1])).astype('uint8')
    else:
        mask_path = os.path.join(root_folder, root_dir, mask_patch)
        mask_image = _imread_checked(mask_path, cv2.IMREAD_GRAYSCALE)

        # NIST masks are inverted after loading (presumably stored with the
        # opposite polarity - TODO confirm). Only masks read from disk are
        # inverted: inverting the all-zero placeholder would turn "no mask"
        # into an all-255 mask.
        if 'NIST' in root_dir:
            mask_image = 255 - mask_image

    image = augmentations.geometric.functional.resize(image, 256, 256, cv2.INTER_AREA)
    mask_image = augmentations.geometric.functional.resize(mask_image, 256, 256, cv2.INTER_AREA)
    ela_image = augmentations.geometric.functional.resize(ela_image, 256, 256, cv2.INTER_AREA)

    return image, ela_image, mask_image, label



def get_tensors(image, ela_image, mask_image):
    """Resize, normalize and convert an image / ELA / mask triple to tensors.

    Args:
        image: Image array passed to the pipeline as ``image``.
        ela_image: ELA array, registered as an extra ``image`` target so it
            goes through the same transforms as ``image``.
        mask_image: Mask array passed to the pipeline as ``mask``.

    Returns:
        Tuple ``(image_tensor, ela_tensor, mask_tensor)``. The image and ELA
        tensors get one leading dimension added; the mask is divided by 255.0
        and gets two leading dimensions added.
    """
    resize = augmentations.geometric.functional.resize
    side = 256

    #---------------- Reshape & Normalize -----------------------
    image = resize(image, side, side, cv2.INTER_AREA)
    mask_image = resize(mask_image, side, side, cv2.INTER_AREA)
    ela_image = resize(ela_image, side, side, cv2.INTER_AREA)

    # Per-channel statistics handed to albumentations.Normalize.
    channel_mean = [0.4535408213875562, 0.42862278450748387, 0.41780105499276865]
    channel_std = [0.2672804038612597, 0.2550410416463668, 0.29475415579144293]

    steps = [
        albumentations.Normalize(mean=channel_mean, std=channel_std, always_apply=True, p=1),
        albumentations.pytorch.transforms.ToTensorV2(),
    ]
    pipeline = albumentations.Compose(steps, additional_targets={'ela': 'image'})

    transformed = pipeline(image=image, mask=mask_image, ela=ela_image)

    image_tensor = transformed["image"].unsqueeze(0)
    ela_tensor = transformed["ela"].unsqueeze(0)
    # Scale the mask by 1/255, then add two leading dimensions.
    scaled_mask = transformed["mask"] / 255.0
    mask_tensor = scaled_mask.unsqueeze(0).unsqueeze(0)

    return image_tensor, ela_tensor, mask_tensor
    