In [None]:
## This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

## Importing the libraries

In [11]:
import torchvision.transforms.functional as tf
import pandas as pd
import os
import numpy as np
from skimage.util import view_as_windows
from skimage import io
import PIL
from glob import glob
import matplotlib.pyplot as plt
import warnings
import cv2
from skimage.metrics import structural_similarity

warnings.filterwarnings('ignore')

### Patch Extraction Utility Functions

In [23]:
def check_and_reshape(image, input_mask):
    """
    Gets an image reshapes it and returns it with its mask.
    :param image: The image
    :param input_mask: The mask of the image
    :returns: the image and its mask
    """
    try:
        mask_x, mask_y = input_mask.shape
        mask = np.empty((mask_x, mask_y, 3))
        mask[:, :, 0] = input_mask
        mask[:, :, 1] = input_mask
        mask[:, :, 2] = input_mask
    except ValueError:
        mask = input_mask
    if image.shape == mask.shape:
        return image, mask
    elif image.shape[0] == mask.shape[1] and image.shape[1] == mask.shape[0]:
        mask = np.reshape(mask, (image.shape[0], image.shape[1], mask.shape[2]))
        return image, mask
    else:
        return image,mask 

In [13]:
def extract_all_patches(image, window_shape, stride, num_of_patches, rotations, output_path, im_name, rep_num, mode):
    """
    Extracts all the patches from an image.
    :param image: The image
    :param window_shape: The shape of the window (for example (128,128,3) in the CASIA2 dataset)
    :param stride: The stride of the patch extraction
    :param num_of_patches: The amount of patches to be extracted per image
    :param rotations: The amount of rotations divided equally in 360 degrees
    :param output_path: The output path where the patches will be saved
    :param im_name: The name of the image
    :param rep_num: The amount of repetitions
    :param mode: If we account rotations 'rot' or nor 'no_rot'
    """
    non_tampered_windows = view_as_windows(image, window_shape, step=stride)
    non_tampered_patches = []
    for m in range(non_tampered_windows.shape[0]):
        for n in range(non_tampered_windows.shape[1]):
            non_tampered_patches += [non_tampered_windows[m][n][0]]
    # select random some patches, rotate and save them
    save_patches(non_tampered_patches, num_of_patches, mode, rotations, output_path, im_name, rep_num,
                 patch_type='authentic')

In [14]:
def delete_prev_images(dir_name):
    """
    Deletes all the file in a directory.
    :param dir_name: Directory name
    """
    for the_file in os.listdir(dir_name):
        file_path = os.path.join(dir_name, the_file)
        try:
            if os.path.isfile(file_path):
                os.unlink(file_path)
        except Exception as e:
            print(e)

In [15]:
def create_dirs(output_path):
    """
    Creates the directories to the output path.
    :param output_path: The output path
    """
    if not os.path.exists(output_path):
        os.makedirs(output_path)
        os.makedirs(output_path + '/authentic')
        os.makedirs(output_path + '/tampered')
    else:
        if os.path.exists(output_path + '/authentic'):
            delete_prev_images(output_path + '/authentic')
        else:
            os.makedirs(output_path + '/authentic')
        if os.path.exists(output_path + '/tampered'):
            delete_prev_images(output_path + '/tampered')
        else:
            os.makedirs(output_path + '/tampered')

In [16]:
def save_patches(patches, num_of_patches, mode, rotations, output_path, im_name, rep_num, patch_type):
    """
    Saves all the extracted patches to the output path.
    :param patches: The extracted patches
    :param num_of_patches: The amount of patches to be extracted per image
    :param mode: If we account rotations 'rot' or nor 'no_rot'
    :param rotations: The amount of rotations divided equally in 360 degrees
    :param output_path: The output path where the patches will be saved
    :param im_name: The name of the image
    :param rep_num: The amount of repetitions
    :param patch_type: The mask of the image
    """
    inds = np.random.choice(len(patches), num_of_patches, replace=False)
    if mode == 'rot':
        for i, ind in enumerate(inds):
            image = patches[ind][0] if patch_type == 'tampered' else patches[ind]
            for angle in rotations:
                im_rt = tf.rotate(PIL.Image.fromarray(np.uint8(image)), angle=angle,
                                  resample=PIL.Image.BILINEAR)
                im_rt.save(output_path + '/{0}/{1}_{2}_{3}_{4}.png'.format(patch_type, im_name, i, angle, rep_num))
    else:
        for i, ind in enumerate(inds):
            image = patches[ind][0] if patch_type == 'tampered' else patches[ind]
            io.imsave(output_path + '/{0}/{1}_{2}_{3}.png'.format(patch_type, im_name, i, rep_num), image)

In [17]:
def find_tampered_patches(image, im_name, mask, window_shape, stride, dataset, patches_per_image):
    """
    Gets an image reshapes it and returns it with its mask.
    :param image: The image
    :param im_name: The name of the image
    :param mask: The mask of the image
    :param window_shape: The shape of the window (for example (128,128,3) in the CASIA2 dataset)
    :param stride: The stride of the patch extraction
    :param dataset: The name of the dataset
    :param patches_per_image: The amount of patches to be extracted per image
    :returns: the tampered patches and their amount
    """
    # extract patches from images and masks
    patches = view_as_windows(image, window_shape, step=stride)

    if dataset == 'casia2':
        mask_patches = view_as_windows(mask, window_shape, step=stride)
    elif dataset == 'nc16':
        mask_patches = view_as_windows(mask, (128, 128), step=stride)
    else:
        raise NotSupportedDataset('The datasets supported are casia2 and nc16')

    tampered_patches = []
    # find tampered patches
    for m in range(patches.shape[0]):
        for n in range(patches.shape[1]):
            im = patches[m][n][0]
            ma = mask_patches[m][n][0]
            num_zeros = (ma == 0).sum()
            num_ones = (ma == 255).sum()
            total = num_ones + num_zeros
            if dataset == 'casia2':
                if num_zeros <= 0.99 * total:
                    tampered_patches += [(im, ma)]
            elif dataset == 'nc16':
                if 0.80 * total >= num_ones >= 0.20 * total:
                    tampered_patches += [(im, ma)]

    # if patches are less than the given number then take the minimum possible
    num_of_patches = patches_per_image
    if len(tampered_patches) < num_of_patches:
        print("Number of tampered patches for image {} is only {}".format(im_name, len(tampered_patches)))
        num_of_patches = len(tampered_patches)

    return tampered_patches, num_of_patches

## Patch Extractor Class for CASIA 2 Dataset

In [18]:
class PatchExtractorCASIA:
    """
    Patch extraction class
    """

    def __init__(self, input_path, output_path, patches_per_image=4, rotations=8, stride=8, mode='no_rot'):
        """
        Initialize class
        :param patches_per_image: Number of samples to extract for each image
        :param rotations: Number of rotations to perform
        :param stride: Stride size to be used
        """
        self.patches_per_image = patches_per_image
        self.stride = stride
        rots = [0, 90, 180, 270]
        self.rotations = rots[:rotations]
        self.mode = mode
        self.input_path = input_path
        self.output_path = output_path

        # define the indices of the image names and read the authentic images
        self.background_index = [13, 21]
        au_index = [3, 6, 7, 12]
        au_pic_list = glob(self.input_path + os.sep + 'Au' + os.sep + '*')
        self.au_pic_dict = {
            au_pic.split(os.sep)[-1][au_index[0]:au_index[1]] + au_pic.split(os.sep)[-1][au_index[2]:au_index[3]]:
                au_pic for au_pic
            in au_pic_list}

    def extract_authentic_patches(self, sp_pic, num_of_patches, rep_num):
        """
        Extracts and saves the patches from the authentic image
        :param sp_pic: Name of tampered image
        :param num_of_patches: Number of patches to be extracted
        :param rep_num: Number of repetitions being done(just for the patch name)
        """
        sp_name = sp_pic.split('/')[-1][self.background_index[0]:self.background_index[1]]
        if sp_name in self.au_pic_dict.keys():
            au_name = self.au_pic_dict[sp_name].split(os.sep)[-1].split('.')[0]
            # define window size
            window_shape = (128, 128, 3)
            au_pic = self.au_pic_dict[sp_name]
            au_image = plt.imread(au_pic)
            # extract all patches
            extract_all_patches(au_image, window_shape, self.stride, num_of_patches, self.rotations, self.output_path,
                                au_name, rep_num, self.mode)

    def extract_patches(self):
        """
        Main function which extracts all patches
        :return:
        """
        # uncomment to extract masks
#         mask_path = 'masks'
#         if os.path.exists(mask_path) and os.path.isdir(mask_path):
#             if not os.listdir(mask_path):
#                 print("Extracting masks")
#                 extract_masks()
#                 print("Masks extracted")
#             else:
#                 print("Masks exist. Patch extraction begins...")
#         else:
#             os.makedirs(mask_path)
#             print("Extracting masks")
#             extract_masks()
#             print("Masks extracted")
#         #

        # create necessary directories
        create_dirs(self.output_path)

        # define window shape
        window_shape = (128, 128, 3)
        tp_dir = self.input_path+'/Tp/'
        rep_num = 0
        # run for all the tampered images
        for f in os.listdir(tp_dir):
            try:
                rep_num += 1
                image = io.imread(tp_dir + f)
                im_name = f.split(os.sep)[-1].split('.')[0]
                # read mask
                mask = io.imread(self.input_path + '/CASIA 2 Groundtruth/' + im_name + '_gt.png')
               
                image, mask = check_and_reshape(image, mask)

                # extract patches from images and masks
                tampered_patches, num_of_patches = find_tampered_patches(image, im_name, mask,
                                                                         window_shape, self.stride, 'casia2',
                                                                         self.patches_per_image)
                save_patches(tampered_patches, num_of_patches, self.mode, self.rotations, self.output_path, im_name,
                             rep_num, patch_type='tampered')
                self.extract_authentic_patches(tp_dir + f, num_of_patches, rep_num)
            except IOError as e:
                rep_num -= 1
                print(str(e))
            except IndexError:
                rep_num -= 1
                print('Mask and image have not the same dimensions')

## Patch Extraction Driver Code (CASIA Dataset)

In [24]:
pe = PatchExtractorCASIA(input_path='../input/casia-20-image-tampering-detection-dataset/CASIA2', output_path='patches_casia_with_rot',
                         patches_per_image=2, stride=128, rotations=4, mode='rot')
pe.extract_patches()

Number of tampered patches for image Tp_S_NRN_S_N_art00092_art00092_11809 is only 1
Number of tampered patches for image Tp_S_CRN_S_O_art00040_art00040_10463 is only 1
No such file: '/kaggle/input/casia-20-image-tampering-detection-dataset/CASIA2/CASIA 2 Groundtruth/Tp_S_NRN_M_N_cha10114_nat10114_12181_gt.png'
Number of tampered patches for image Tp_D_NRN_S_N_pla00010_pla00008_10935 is only 1
Number of tampered patches for image Tp_S_NNN_S_N_cha10204_cha10204_12353 is only 1
Number of tampered patches for image Tp_S_NNN_S_B_art20011_art20011_02493 is only 0
Number of tampered patches for image Tp_S_CRN_S_N_sec00055_sec00055_11234 is only 1
Number of tampered patches for image Tp_D_NRN_S_N_art00014_art00092_11810 is only 1
Number of tampered patches for image Tp_S_NND_S_N_cha00092_cha00092_00412 is only 1
Number of tampered patches for image Tp_S_NNN_S_N_sec20053_sec20053_01643 is only 1
Number of tampered patches for image Tp_D_CRD_S_N_ind00074_cha00049_00474 is only 1
Number of tamper

ValueError: Could not find a backend to open `../input/casia-20-image-tampering-detection-dataset/CASIA2/Tp/_list.txt`` with iomode `ri`.