## Importing the libraries

In [1]:
## This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/casia-20-image-tampering-detection-dataset/CASIA2/CASIA 2 Groundtruth/Tp_S_NND_S_N_cha00076_cha00076_10189_gt.png
/kaggle/input/casia-20-image-tampering-detection-dataset/CASIA2/CASIA 2 Groundtruth/Tp_D_NNN_M_N_art00099_cha00050_11760_gt.png
/kaggle/input/casia-20-image-tampering-detection-dataset/CASIA2/CASIA 2 Groundtruth/Tp_S_CNN_S_N_pla10127_pla10127_12137_gt.png
/kaggle/input/casia-20-image-tampering-detection-dataset/CASIA2/CASIA 2 Groundtruth/Tp_S_NRN_S_N_ind00038_ind00038_10913_gt.png
/kaggle/input/casia-20-image-tampering-detection-dataset/CASIA2/CASIA 2 Groundtruth/Tp_D_CRN_M_N_ani00100_ani10101_10089_gt.png
/kaggle/input/casia-20-image-tampering-detection-dataset/CASIA2/CASIA 2 Groundtruth/Tp_D_NRN_S_N_art00010_art00014_11842_gt.png
/kaggle/input/casia-20-image-tampering-detection-dataset/CASIA2/CASIA 2 Groundtruth/Tp_S_NNN_S_N_cha00002_cha00002_00837_gt.png
/kaggle/input/casia-20-image-tampering-detection-dataset/CASIA2/CASIA 2 Groundtruth/Tp_S_CRN_M_N_sec0009

In [2]:
import torchvision.transforms.functional as tf
import pandas as pd
import os
import numpy as np
from skimage.util import view_as_windows
from skimage import io
import PIL
from glob import glob
import matplotlib.pyplot as plt
import warnings
import cv2
from skimage.metrics import structural_similarity

warnings.filterwarnings('ignore')

### Patch Extraction Utility Functions

In [3]:
def check_and_reshape(image, input_mask):
    """
    Gets an image reshapes it and returns it with its mask.
    :param image: The image
    :param input_mask: The mask of the image
    :returns: the image and its mask
    """
    try:
        mask_x, mask_y = input_mask.shape
        mask = np.empty((mask_x, mask_y, 3))
        mask[:, :, 0] = input_mask
        mask[:, :, 1] = input_mask
        mask[:, :, 2] = input_mask
    except ValueError:
        mask = input_mask
    if image.shape == mask.shape:
        return image, mask
    elif image.shape[0] == mask.shape[1] and image.shape[1] == mask.shape[0]:
        mask = np.reshape(mask, (image.shape[0], image.shape[1], mask.shape[2]))
        return image, mask
    else:
        return image,mask 

In [4]:
def extract_all_patches(image, window_shape, stride, num_of_patches, rotations, output_path, im_name, rep_num, mode):
    """
    Extracts all the patches from an image.
    :param image: The image
    :param window_shape: The shape of the window (for example (128,128,3) in the CASIA2 dataset)
    :param stride: The stride of the patch extraction
    :param num_of_patches: The amount of patches to be extracted per image
    :param rotations: The amount of rotations divided equally in 360 degrees
    :param output_path: The output path where the patches will be saved
    :param im_name: The name of the image
    :param rep_num: The amount of repetitions
    :param mode: If we account rotations 'rot' or nor 'no_rot'
    """
    non_tampered_windows = view_as_windows(image, window_shape, step=stride)
    non_tampered_patches = []
    for m in range(non_tampered_windows.shape[0]):
        for n in range(non_tampered_windows.shape[1]):
            non_tampered_patches += [non_tampered_windows[m][n][0]]
    # select random some patches, rotate and save them
    save_patches(non_tampered_patches, num_of_patches, mode, rotations, output_path, im_name, rep_num,
                 patch_type='authentic')

In [5]:
def delete_prev_images(dir_name):
    """
    Deletes all the file in a directory.
    :param dir_name: Directory name
    """
    for the_file in os.listdir(dir_name):
        file_path = os.path.join(dir_name, the_file)
        try:
            if os.path.isfile(file_path):
                os.unlink(file_path)
        except Exception as e:
            print(e)

In [6]:
def create_dirs(output_path):
    """
    Creates the directories to the output path.
    :param output_path: The output path
    """
    if not os.path.exists(output_path):
        os.makedirs(output_path)
        os.makedirs(output_path + '/authentic')
        os.makedirs(output_path + '/tampered')
    else:
        if os.path.exists(output_path + '/authentic'):
            delete_prev_images(output_path + '/authentic')
        else:
            os.makedirs(output_path + '/authentic')
        if os.path.exists(output_path + '/tampered'):
            delete_prev_images(output_path + '/tampered')
        else:
            os.makedirs(output_path + '/tampered')

In [7]:
def save_patches(patches, num_of_patches, mode, rotations, output_path, im_name, rep_num, patch_type):
    """
    Saves all the extracted patches to the output path.
    :param patches: The extracted patches
    :param num_of_patches: The amount of patches to be extracted per image
    :param mode: If we account rotations 'rot' or nor 'no_rot'
    :param rotations: The amount of rotations divided equally in 360 degrees
    :param output_path: The output path where the patches will be saved
    :param im_name: The name of the image
    :param rep_num: The amount of repetitions
    :param patch_type: The mask of the image
    """
    inds = np.random.choice(len(patches), num_of_patches, replace=False)
    if mode == 'rot':
        for i, ind in enumerate(inds):
            image = patches[ind][0] if patch_type == 'tampered' else patches[ind]
            for angle in rotations:
                im_rt = tf.rotate(PIL.Image.fromarray(np.uint8(image)), angle=angle,
                                  resample=PIL.Image.BILINEAR)
                im_rt.save(output_path + '/{0}/{1}_{2}_{3}_{4}.png'.format(patch_type, im_name, i, angle, rep_num))
    else:
        for i, ind in enumerate(inds):
            image = patches[ind][0] if patch_type == 'tampered' else patches[ind]
            io.imsave(output_path + '/{0}/{1}_{2}_{3}.png'.format(patch_type, im_name, i, rep_num), image)

In [8]:
def find_tampered_patches(image, im_name, mask, window_shape, stride, dataset, patches_per_image):
    """
    Gets an image reshapes it and returns it with its mask.
    :param image: The image
    :param im_name: The name of the image
    :param mask: The mask of the image
    :param window_shape: The shape of the window (for example (128,128,3) in the CASIA2 dataset)
    :param stride: The stride of the patch extraction
    :param dataset: The name of the dataset
    :param patches_per_image: The amount of patches to be extracted per image
    :returns: the tampered patches and their amount
    """
    # extract patches from images and masks
    patches = view_as_windows(image, window_shape, step=stride)

    if dataset == 'casia2':
        mask_patches = view_as_windows(mask, window_shape, step=stride)
    elif dataset == 'nc16':
        mask_patches = view_as_windows(mask, (128, 128), step=stride)
    else:
        raise NotSupportedDataset('The datasets supported are casia2 and nc16')

    tampered_patches = []
    # find tampered patches
    for m in range(patches.shape[0]):
        for n in range(patches.shape[1]):
            im = patches[m][n][0]
            ma = mask_patches[m][n][0]
            num_zeros = (ma == 0).sum()
            num_ones = (ma == 255).sum()
            total = num_ones + num_zeros
            if dataset == 'casia2':
                if num_zeros <= 0.99 * total:
                    tampered_patches += [(im, ma)]
            elif dataset == 'nc16':
                if 0.80 * total >= num_ones >= 0.20 * total:
                    tampered_patches += [(im, ma)]

    # if patches are less than the given number then take the minimum possible
    num_of_patches = patches_per_image
    if len(tampered_patches) < num_of_patches:
        print("Number of tampered patches for image {} is only {}".format(im_name, len(tampered_patches)))
        num_of_patches = len(tampered_patches)

    return tampered_patches, num_of_patches

## Patch Extractor Class for CASIA 2 Dataset

In [9]:
class PatchExtractorCASIA:
    """
    Patch extraction class
    """

    def __init__(self, input_path, output_path, patches_per_image=4, rotations=8, stride=8, mode='no_rot'):
        """
        Initialize class
        :param patches_per_image: Number of samples to extract for each image
        :param rotations: Number of rotations to perform
        :param stride: Stride size to be used
        """
        self.patches_per_image = patches_per_image
        self.stride = stride
        rots = [0, 90, 180, 270]
        self.rotations = rots[:rotations]
        self.mode = mode
        self.input_path = input_path
        self.output_path = output_path

        # define the indices of the image names and read the authentic images
        self.background_index = [13, 21]
        au_index = [3, 6, 7, 12]
        au_pic_list = glob(self.input_path + os.sep + 'Au' + os.sep + '*')
        self.au_pic_dict = {
            au_pic.split(os.sep)[-1][au_index[0]:au_index[1]] + au_pic.split(os.sep)[-1][au_index[2]:au_index[3]]:
                au_pic for au_pic
            in au_pic_list}

    def extract_authentic_patches(self, sp_pic, num_of_patches, rep_num):
        """
        Extracts and saves the patches from the authentic image
        :param sp_pic: Name of tampered image
        :param num_of_patches: Number of patches to be extracted
        :param rep_num: Number of repetitions being done(just for the patch name)
        """
        sp_name = sp_pic.split('/')[-1][self.background_index[0]:self.background_index[1]]
        if sp_name in self.au_pic_dict.keys():
            au_name = self.au_pic_dict[sp_name].split(os.sep)[-1].split('.')[0]
            # define window size
            window_shape = (128, 128, 3)
            au_pic = self.au_pic_dict[sp_name]
            au_image = plt.imread(au_pic)
            # extract all patches
            extract_all_patches(au_image, window_shape, self.stride, num_of_patches, self.rotations, self.output_path,
                                au_name, rep_num, self.mode)

    def extract_patches(self):
        """
        Main function which extracts all patches
        :return:
        """
        # uncomment to extract masks
#         mask_path = 'masks'
#         if os.path.exists(mask_path) and os.path.isdir(mask_path):
#             if not os.listdir(mask_path):
#                 print("Extracting masks")
#                 extract_masks()
#                 print("Masks extracted")
#             else:
#                 print("Masks exist. Patch extraction begins...")
#         else:
#             os.makedirs(mask_path)
#             print("Extracting masks")
#             extract_masks()
#             print("Masks extracted")
#         #

        # create necessary directories
        create_dirs(self.output_path)

        # define window shape
        window_shape = (128, 128, 3)
        tp_dir = self.input_path+'/Tp/'
        rep_num = 0
        # run for all the tampered images
        for f in os.listdir(tp_dir):
            try:
                if f == 'Thumbs.db' or f == '_list.txt':
                    continue
                rep_num += 1
                image = io.imread(tp_dir + f)
                im_name = f.split(os.sep)[-1].split('.')[0]
                # read mask
                mask = io.imread(self.input_path + '/CASIA 2 Groundtruth/' + im_name + '_gt.png')
               
                image, mask = check_and_reshape(image, mask)

                # extract patches from images and masks
                tampered_patches, num_of_patches = find_tampered_patches(image, im_name, mask,
                                                                         window_shape, self.stride, 'casia2',
                                                                         self.patches_per_image)
                save_patches(tampered_patches, num_of_patches, self.mode, self.rotations, self.output_path, im_name,
                             rep_num, patch_type='tampered')
                self.extract_authentic_patches(tp_dir + f, num_of_patches, rep_num)
            except IOError as e:
                rep_num -= 1
                print(str(e))
            except IndexError:
                rep_num -= 1
                print('Mask and image have not the same dimensions')

## Patch Extraction Driver Code (CASIA Dataset)

In [10]:
pe = PatchExtractorCASIA(input_path='../input/casia-20-image-tampering-detection-dataset/CASIA2', output_path='patches_casia_with_rot',
                         patches_per_image=2, stride=128, rotations=4, mode='rot')
pe.extract_patches()

Number of tampered patches for image Tp_S_NRN_S_N_art00092_art00092_11809 is only 1
Number of tampered patches for image Tp_S_CRN_S_O_art00040_art00040_10463 is only 1
No such file: '/kaggle/input/casia-20-image-tampering-detection-dataset/CASIA2/CASIA 2 Groundtruth/Tp_S_NRN_M_N_cha10114_nat10114_12181_gt.png'
Number of tampered patches for image Tp_D_NRN_S_N_pla00010_pla00008_10935 is only 1
Number of tampered patches for image Tp_S_NNN_S_N_cha10204_cha10204_12353 is only 1
Number of tampered patches for image Tp_S_NNN_S_B_art20011_art20011_02493 is only 0
Number of tampered patches for image Tp_S_CRN_S_N_sec00055_sec00055_11234 is only 1
Number of tampered patches for image Tp_D_NRN_S_N_art00014_art00092_11810 is only 1
Number of tampered patches for image Tp_S_NND_S_N_cha00092_cha00092_00412 is only 1
Number of tampered patches for image Tp_S_NNN_S_N_sec20053_sec20053_01643 is only 1
Number of tampered patches for image Tp_D_CRD_S_N_ind00074_cha00049_00474 is only 1
Number of tamper

## Training the CNN model

### Filters

In [11]:
from typing import Dict

import numpy as np
from torch import Tensor, stack


def get_filters():
    """
    Function that return the required high pass SRM filters for the first convolutional layer of our implementation
    :return: A pytorch Tensor containing the 30x3x5x5 filter tensor with type
    [number_of_filters, input_channels, height, width]
    """

    filters: Dict[str, Tensor] = {}

    # 1st Order
    filters["1O1"] = Tensor(np.array([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, -1, 1, 0], [0, 0, 0, 0, 0],
                                      [0, 0, 0, 0, 0]]))
    filters["1O2"] = Tensor(np.rot90(filters["1O1"]).copy())
    filters["1O3"] = Tensor(np.rot90(filters["1O2"]).copy())
    filters["1O4"] = Tensor(np.rot90(filters["1O3"]).copy())
    filters["1O5"] = Tensor(np.array([[0, 0, 0, 0, 0], [0, 0, 0, 1, 0], [0, 0, -1, 0, 0], [0, 0, 0, 0, 0],
                                      [0, 0, 0, 0, 0]]))
    filters["1O6"] = Tensor(np.rot90(filters["1O5"]).copy())
    filters["1O7"] = Tensor(np.rot90(filters["1O6"]).copy())
    filters["1O8"] = Tensor(np.rot90(filters["1O7"]).copy())
    # 2nd Order
    filters["2O1"] = Tensor(np.array([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, -2, 1, 0], [0, 0, 0, 0, 0],
                                      [0, 0, 0, 0, 0]]))
    filters["2O2"] = Tensor(np.rot90(filters["2O1"]).copy())
    filters["2O3"] = Tensor(np.array([[0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, -2, 0, 0], [0, 0, 0, 1, 0],
                                      [0, 0, 0, 0, 0]]))
    filters["2O4"] = Tensor(np.rot90(filters["2O3"]).copy())
    # 3rd Order
    filters["3O1"] = Tensor(np.array([[0, 0, 0, 0, 0], [0, 0, 1, 0, 0], [0, 1, -3, 1, 0], [0, 0, 0, 0, 0],
                                      [0, 0, 0, 0, 0]]))
    filters["3O2"] = Tensor(np.rot90(filters["3O1"]).copy())
    filters["3O3"] = Tensor(np.rot90(filters["3O2"]).copy())
    filters["3O4"] = Tensor(np.rot90(filters["3O3"]).copy())
    filters["3O5"] = Tensor(np.array([[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, -3, 0, 0], [0, 1, 0, 0, 0],
                                      [0, 0, 0, 0, 0]]))
    filters["3O6"] = Tensor(np.rot90(filters["3O5"]).copy())
    filters["3O7"] = Tensor(np.rot90(filters["3O6"]).copy())
    filters["3O8"] = Tensor(np.rot90(filters["3O7"]).copy())
    # 3x3 SQUARE
    filters["3x3S"] = Tensor(np.array([[0, 0, 0, 0, 0], [0, -1, 2, -1, 0], [0, 2, -4, 2, 0], [0, -1, 2, -1, 0],
                                       [0, 0, 0, 0, 0]]))
    # 3x3 EDGE
    filters["3x3E1"] = Tensor(np.array([[0, 0, 0, 0, 0], [0, -1, 2, -1, 0], [0, 2, -4, 2, 0], [0, 0, 0, 0, 0],
                                        [0, 0, 0, 0, 0]]))
    filters["3x3E2"] = Tensor(np.rot90(filters["3x3E1"]).copy())
    filters["3x3E3"] = Tensor(np.rot90(filters["3x3E2"]).copy())
    filters["3x3E4"] = Tensor(np.rot90(filters["3x3E3"]).copy())
    # 5X5 EDGE
    filters["5x5E1"] = Tensor(np.array([[-1, 2, -2, 2, -1], [2, -6, 8, -6, 2], [-2, 8, -12, 8, -2],
                                        [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]))
    filters["5x5E2"] = Tensor(np.rot90(filters["5x5E1"]).copy())
    filters["5x5E3"] = Tensor(np.rot90(filters["5x5E2"]).copy())
    filters["5x5E4"] = Tensor(np.rot90(filters["5x5E3"]).copy())
    # 5x5 SQUARE
    filters["5x5S"] = Tensor(np.array([[-1, 2, -2, 2, -1], [2, -6, 8, -6, 2], [-2, 8, -12, 8, -2],
                                       [2, -6, 8, -6, 2], [-1, 2, -2, 2, -1]]))

    return vectorize_filters(filters)


def vectorize_filters(filters: dict):
    """
    Function that takes as input the 30x5x5 different SRM high pass filters and creates the 30x3x5x5 tensor with the
    following permutations 𝑾𝑗 = [𝑊3𝑘−2 𝑊3𝑘−1 𝑊3𝑘] where 𝑘 = ((𝑗 − 1) mod 10) + 1 and (𝑗 = 1, ⋅ ⋅ ⋅ , 30).
    :arg filters: The 30 SRM high pass filters
    :return: Returns the 30x3x5x5 filter tensor of the type [number_of_filters, input_channels, height, width]
    """
    tensor_list = []

    w = list(filters.values())

    for i in range(1, 31):
        tmp = []

        k = ((i - 1) % 10) + 1

        tmp.append(w[3 * k - 3])
        tmp.append(w[3 * k - 2])
        tmp.append(w[3 * k - 1])

        tensor_list.append(stack(tmp))

    return stack(tensor_list)


### CNN Class

In [12]:
import torch.nn.functional as f
import torch.nn as nn


class CNN(nn.Module):
    """
    The convolutional neural network (CNN) class
    """
    def __init__(self):
        """
        Initialization of all the layers in the network.
        """
        super(CNN, self).__init__()

        self.conv0 = nn.Conv2d(3, 3, kernel_size=5, stride=1, padding=0)
        nn.init.xavier_uniform_(self.conv0.weight)

        self.conv1 = nn.Conv2d(3, 30, kernel_size=5, stride=2, padding=0)
        self.conv1.weight = nn.Parameter(get_filters())

        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        self.conv2 = nn.Conv2d(30, 16, kernel_size=3, stride=1, padding=0)
        nn.init.xavier_uniform_(self.conv2.weight)

        self.conv3 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=0)
        nn.init.xavier_uniform_(self.conv3.weight)

        self.conv4 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=0)
        nn.init.xavier_uniform_(self.conv4.weight)

        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        self.conv5 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=0)
        nn.init.xavier_uniform_(self.conv5.weight)

        self.conv6 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=0)
        nn.init.xavier_uniform_(self.conv6.weight)

        self.conv7 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=0)
        nn.init.xavier_uniform_(self.conv7.weight)

        self.conv8 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=0)
        nn.init.xavier_uniform_(self.conv8.weight)

        self.fc = nn.Linear(16 * 5 * 5, 2)

        self.drop1 = nn.Dropout(p=0.5)  # used only for the NC dataset

    def forward(self, x):
        """
        The forward step of the network that consumes an image patch and either uses a fully connected layer in the
        training phase with a softmax or just returns the feature map after the final convolutional layer.
        :returns: Either the output of the softmax during training or the 400-D feature representation at testing
        """
        x = f.relu(self.conv0(x))
        x = f.relu(self.conv1(x))
        lrn = nn.LocalResponseNorm(3)
        x = lrn(x)
        x = self.pool1(x)
        x = f.relu(self.conv2(x))
        x = f.relu(self.conv3(x))
        x = f.relu(self.conv4(x))
        x = f.relu(self.conv5(x))
        x = lrn(x)
        x = self.pool2(x)
        x = f.relu(self.conv6(x))
        x = f.relu(self.conv7(x))
        x = f.relu(self.conv8(x))
        x = x.view(-1, 16 * 5 * 5)

        # In the training phase we also need the fully connected layer with softmax
        if self.training:
            # x = self.drop1(x) # used only for the NC dataset
            x = f.relu(self.fc(x))
            x = f.softmax(x, dim=1)

        return x


### Script to train the CNN model

In [13]:
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
from torch.optim.lr_scheduler import StepLR
from torch.autograd import Variable
import time
import numpy as np


def create_loss_and_optimizer(net, learning_rate=0.01):
    """
    Creates the loss function and optimizer of the network.
    :param net: The network object
    :param learning_rate: The initial learning rate
    :returns: The loss function and the optimizer
    """
    loss = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.99, weight_decay=5 * 1e-4)
    return loss, optimizer


def train_net(net, train_set, n_epochs, learning_rate, batch_size):
    """
    Training of the CNN
    :param net: The CNN object
    :param train_set: The training part of the dataset
    :param n_epochs: The number of epochs of the experiment
    :param learning_rate: The initial learning rate
    :param batch_size: The batch size of the SGD
    :returns: The epoch loss (vector) and the epoch accuracy (vector)
    """
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, pin_memory=True)
    criterion, optimizer = create_loss_and_optimizer(net, learning_rate)
    scheduler = StepLR(optimizer, step_size=10, gamma=0.9)
    n_batches = len(train_loader)
    epoch_loss = []
    epoch_accuracy = []

    for epoch in range(n_epochs):

        total_running_loss = 0.0
        print_every = n_batches // 5
        training_start_time = time.time()
        c = 0
        total_predicted = []
        total_labels = []

        for i, (inputs, labels) in enumerate(train_loader):
            # get the inputs
            if torch.cuda.is_available():
                inputs = Variable(inputs.cuda())
                labels = Variable(labels.cuda().long())
            else:
                inputs = Variable(inputs)
                labels = Variable(labels)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            _, predicted = torch.max(outputs.data, 1)

            total_labels.extend(labels)
            total_predicted.extend(predicted)

            if (i + 1) % (print_every + 1) == 0:
                total_running_loss += loss.item()
                c += 1

        epoch_predictions = (np.array(total_predicted) == np.array(total_labels)).sum().item()
        print('---------- Epoch %d Loss: %.3f Accuracy: %.3f Time: %.3f----------' % (
            epoch + 1, total_running_loss / c, epoch_predictions / len(total_predicted),
            time.time() - training_start_time))
        epoch_accuracy.append(epoch_predictions / len(total_predicted))
        epoch_loss.append(total_running_loss / c)
        scheduler.step()

    print('Finished Training')

    return epoch_loss, epoch_accuracy


### train_net.py -> Driver code for training the CNN model

In [14]:
import torch
import pandas as pd
import torchvision.transforms as transforms
from torchvision import datasets


torch.manual_seed(0)

DATA_DIR = "patches_casia_with_rot"  # the directory of the augmented patches
transform = transforms.Compose([transforms.ToTensor()])

data = datasets.ImageFolder(root=DATA_DIR, transform=transform)  # Fetch data

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if str(device) == "cuda:0":
    print("cuda enabled")
    cnn = CNN().cuda()
else:
    print("no cuda")
    cnn = CNN()

epoch_loss, epoch_accuracy = train_net(cnn, data, n_epochs=40, learning_rate=0.0001, batch_size=128)

pd.DataFrame(epoch_loss).to_csv('SRM_loss.csv')
pd.DataFrame(epoch_accuracy).to_csv('SRM_accuracy.csv')

torch.save(cnn.state_dict(), 'Cnn.pt')


no cuda
---------- Epoch 1 Loss: 0.693 Accuracy: 0.504 Time: 867.756----------
---------- Epoch 2 Loss: 0.694 Accuracy: 0.527 Time: 881.413----------
---------- Epoch 3 Loss: 0.693 Accuracy: 0.525 Time: 898.228----------
---------- Epoch 4 Loss: 0.693 Accuracy: 0.530 Time: 872.893----------
---------- Epoch 5 Loss: 0.693 Accuracy: 0.531 Time: 876.871----------
---------- Epoch 6 Loss: 0.691 Accuracy: 0.531 Time: 896.824----------
---------- Epoch 7 Loss: 0.694 Accuracy: 0.530 Time: 868.527----------
---------- Epoch 8 Loss: 0.689 Accuracy: 0.531 Time: 873.244----------
---------- Epoch 9 Loss: 0.690 Accuracy: 0.531 Time: 911.680----------
---------- Epoch 10 Loss: 0.690 Accuracy: 0.530 Time: 875.200----------
---------- Epoch 11 Loss: 0.692 Accuracy: 0.531 Time: 878.297----------
---------- Epoch 12 Loss: 0.690 Accuracy: 0.532 Time: 903.949----------
---------- Epoch 13 Loss: 0.688 Accuracy: 0.532 Time: 880.082----------
---------- Epoch 14 Loss: 0.689 Accuracy: 0.532 Time: 900.387----