In [None]:
import os
import cv2

from dataset_process import dataset_to_df, search_df

import numpy as np
import pandas as pd

from tqdm import tqdm

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F

# The cross-entropy loss penalizes the model more when it is more confident in the incorrect class
from torch.nn import CrossEntropyLoss

# Adam is an optimization algorithm that can be used instead of the classical SGD procedure
# to update network weights iteratively based on the training data.
from torch.optim import Adam, lr_scheduler
from torch.utils.data import Dataset, DataLoader, random_split

from torchvision import transforms
from torchvision.io import read_image

from torchinfo import summary

from einops import rearrange

from typing import Optional, Tuple

# import timm


In [None]:
# Set SEED to an integer (e.g. 0) to seed NumPy's and PyTorch's global RNGs for reproducible runs.
# Leaving it as None keeps the RNGs unseeded, so results vary from run to run.
SEED = None
if SEED is not None:
    np.random.seed(SEED)
    torch.manual_seed(SEED)

In [None]:
print(f"torch version: {torch.__version__}")

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Torch is using device:', device)
if device.type == 'cuda':
    # get_device_name only works with a CUDA device, so skip it on CPU-only machines.
    print(f"GPU Card: {torch.cuda.get_device_name(device)}")

cpu_count = os.cpu_count()
print("CPU Count:", cpu_count)
# os.cpu_count() returns None when the count cannot be determined; fall back to 0 workers
# (i.e. load data in the main process) rather than passing None on as a worker count.
NUM_WORKERS = cpu_count or 0

#### Load the "CIFAK" dataset and split it into training, validation, and test sets

In [None]:
# Dataset root plus its sub-folders (relative to the root). paths_classes is index-aligned with
# relative_paths: entry i is the class name used for sub-folder i.
path = '../../data/CIFAK'
relative_paths = [
    "/train/REAL",
    "/train/FAKE",
    "/test/REAL",
    "/test/FAKE",
]
paths_classes = ["REAL", "FAKE"] * 2

# Alternative dataset layout:
# path = '../../data/meso_data'
# relative_paths = ["/Real", "/DeepFake"]
# paths_classes = ['REAL', "FAKE"]

In [None]:
# Split fractions passed to dataset_to_df, presumably for the train / validation / test sets.
train_frac, val_frac, test_frac = 0.8, 0.19, 0.01

df_all, df_train, df_val, df_test, classes_stats = dataset_to_df(
    path, relative_paths, paths_classes,
    train_frac, val_frac, test_frac,
)

classes_stats

In [None]:
# Show the first and the last row of the full dataset as a quick sanity check.
print("First and Last Elements in the Whole dataset")
first_and_last = [0, -1]
df_all.iloc[first_and_last]

#### Determine the number of slices (patches) of the image

In [None]:
# Number of slices along each image side; the image becomes an
# Img_horizontal_slices x Img_horizontal_slices grid of patches.
Img_horizontal_slices = 4

# Number of images fed to the DataLoader per batch.
images_batch = 16

In [None]:
# Read one sample image to learn its (channels, height, width) shape.
# NOTE(review): assumes every image in the dataset has this same shape.
img_shape = read_image(df_all.iloc[0, 0]).size()
img_channels, img_height, img_width = img_shape

# Each patch is a square slice_width x slice_width tile whose side comes from the image height.
slice_width = img_height // Img_horizontal_slices

# SliceImage cuts BOTH axes into slice_width tiles, so both dimensions must be exact multiples.
# Fail here with a clear message instead of deep inside the DataLoader.
if slice_width == 0 or img_height % slice_width or img_width % slice_width:
    raise ValueError(
        f"Image of shape {tuple(img_shape)} cannot be cut into square {slice_width}-pixel "
        f"slices with Img_horizontal_slices={Img_horizontal_slices}")

# Count the grid from both dimensions so non-square images are reported correctly
# (for square images this equals Img_horizontal_slices**2).
slice_rows = img_height // slice_width
slice_cols = img_width // slice_width
total_img_slices = slice_rows * slice_cols

print(f"slice_width: {slice_width} pixels")
print("")
print("Image shape: ", img_shape)
print(
    f"Image will be divided into: {slice_rows} x {slice_cols} = {total_img_slices} slices each with shape {(img_channels, slice_width, slice_width)}")
print(
    f"Target Shape of the final flattened image: {total_img_slices} x {img_channels*slice_width**2} ")
print("")
print(f"Feed ({images_batch}) Images to the Dataloader")

In [None]:
class SliceImage:
    """Transform that cuts a (c, H, W) image into square patches and flattens each patch.

    The output has one row per patch, ordered row-major over the patch grid; each row holds
    that patch's c*h*w values.

    Args:
        slice_width: side length (pixels) of each square patch. H and W must both be
            multiples of it, otherwise ``rearrange`` fails.
    """

    def __init__(self, slice_width):
        self.slice_width = slice_width

    def slice(self, img):
        """Split ``img`` of shape (c, row*h, col*w) into a (row, col, c, h, w) grid of patches.

        h = w = ``self.slice_width``; row and col are inferred from the image size.
        """
        # c:        color channels
        # h, w:     patch height / width in pixels (both self.slice_width)
        # row, col: number of patches along the vertical / horizontal axis
        # Use the instance's slice_width, not a same-named global, so the constructor argument
        # actually controls the patch size.
        return rearrange(
            img, 'c (row h) (col w) -> row col c h w',
            h=self.slice_width, w=self.slice_width)

    def __call__(self, img):
        """Return ``img`` as a (row*col, c*h*w) matrix: one flattened patch per row."""
        patches = self.slice(img)
        return rearrange(patches, 'row col c h w -> (row col) (c h w)')

In [None]:
# Define the transforms on the input data (x) tensor
data_transform = transforms.Compose([
    transforms.ToPILImage(),
    # transforms.Resize((64, 64)),
    # transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    SliceImage(slice_width=slice_width)
])

In [None]:
# Custom Dataset Class
class Images_Dataset(Dataset):
    """Map-style dataset that loads images listed in an annotations DataFrame.

    Each item is an ``(img, img_t, labels)`` tuple:
      * ``img``    - the raw image returned by the image reader,
      * ``img_t``  - ``transform(img)`` when a transform is given, otherwise ``img`` itself,
      * ``labels`` - the value in column 4 of the row, as a float64 scalar tensor.

    Args:
        annotations_df: DataFrame whose column 0 holds the image path and column 4 the label
            (positional columns; presumably the layout produced by ``dataset_to_df`` - TODO confirm).
        transform: optional callable applied to the raw image.
        image_reader: optional callable mapping a path to an image; defaults to torchvision's
            ``read_image``. Handy for alternative decoders or for testing without image files.
    """

    def __init__(self, annotations_df, transform=None, image_reader=None):
        self.annotation = annotations_df
        self.transform = transform
        self.image_reader = read_image if image_reader is None else image_reader

    def __getitem__(self, index):
        img = self.image_reader(self.annotation.iloc[index, 0])
        # NOTE(review): labels are float64. CrossEntropyLoss expects int64 class indices, and
        # models usually run in float32 - confirm against the training loop before changing.
        labels = torch.tensor(self.annotation.iloc[index, 4], dtype=torch.float64)

        # Compare against None (not truthiness) and fall back to the raw image, so img_t is
        # always bound even when no transform was supplied.
        img_t = self.transform(img) if self.transform is not None else img

        return img, img_t, labels

    def __len__(self):
        # One sample per annotation row.
        return len(self.annotation)