In [None]:
import torch
import torch.optim as optim
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
import torch.nn as nn
from torchvision.datasets.folder import DatasetFolder
import torchvision.transforms as transforms 
from torchvision import models
from skimage.io import imread
from sklearn import metrics
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import cv2
import numpy as np
import random
from PIL import Image
import sklearn
from flopth import flopth
import os
from sklearn.model_selection import train_test_split
from PIL import Image, ImageChops
from torchvision.models.resnet import ResNeXt50_32X4D_Weights
from data import data_preprocess

In [None]:
# Training hyperparameters.
epochs = 20
lr = 0.01
momentum = 0.5

# Seed every RNG the notebook relies on (Python, NumPy, PyTorch on all devices)
# so weight initialisation, shuffling and augmentation are repeatable across runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Root folder expected to hold originals/, photoshops/, originals_ela/, photoshops_ela/.
DATA_DIR = './'
    
class MyDataset(Dataset):
    """Paired RGB / ELA image dataset for original-vs-photoshopped classification.

    ``root_dir`` must contain the image folders ``originals`` (label 0) and
    ``photoshops`` (label 1), plus ``originals_ela`` and ``photoshops_ela``
    holding the matching ELA images. Files in an image folder are paired with
    files in its ``_ela`` folder by sorted filename order.

    Args:
        root_dir: Directory holding the four class folders.
        test_size: Held-out fraction (or count), forwarded to ``train_test_split``.
        transform: Optional callable applied to both PIL images of a sample.
        test: If True, this instance exposes the held-out split; otherwise the
            training split.
        random_state: Seed for the train/test shuffle. The train and test
            instances must share the same value so that their splits are
            complementary (disjoint).

    Raises:
        ValueError: If an image folder and its ``_ela`` folder contain a
            different number of files, because pairing by sorted order would
            then be misaligned.
    """

    def __init__(self, root_dir, test_size=0.2, transform=None, test=False, random_state=42):
        self.root_dir = root_dir
        self.transform = transform

        self.class_folders = ['originals', 'photoshops']
        self.class_folders_ela = ['originals_ela', 'photoshops_ela']
        self.samples = []

        # Collect every (image path, ELA path, label) triple; the label is the
        # folder's position in class_folders.
        folder_pairs = zip(self.class_folders, self.class_folders_ela)
        for class_label, (class_folder, class_folder_ela) in enumerate(folder_pairs):
            class_path = os.path.join(self.root_dir, class_folder)
            class_path_ela = os.path.join(self.root_dir, class_folder_ela)
            filenames = sorted(os.listdir(class_path))
            filenames_ela = sorted(os.listdir(class_path_ela))
            # zip() would silently drop the tail and shift every later pair if
            # one folder had an extra file, so refuse to pair mismatched folders.
            if len(filenames) != len(filenames_ela):
                raise ValueError(
                    f"{class_path} has {len(filenames)} files but "
                    f"{class_path_ela} has {len(filenames_ela)}; cannot pair them"
                )
            # NOTE(review): assumes sorted names line up one-to-one between the
            # image and ELA folders — confirm the ELA generation preserves names.
            for file_name, file_name_ela in zip(filenames, filenames_ela):
                self.samples.append((
                    os.path.join(class_path, file_name),
                    os.path.join(class_path_ela, file_name_ela),
                    class_label,
                ))

        # The train and test datasets are constructed as separate instances, so
        # the split must be seeded: otherwise each instance draws its own random
        # split and test images overlap with the training set.
        train_samples, test_samples = train_test_split(
            self.samples, test_size=test_size, random_state=random_state
        )

        self.train = not test
        self.samples = test_samples if test else train_samples

    def __len__(self):
        """Number of samples in this split."""
        return len(self.samples)

    def __getitem__(self, index):
        """Return ``(image, image_ela, class_label)`` for the sample at ``index``.

        Both images are loaded as RGB and passed through ``transform`` when one
        is set.
        """
        file_path, file_path_ela, class_label = self.samples[index]
        # `with` closes the file handle deterministically; convert() has already
        # loaded the pixel data into a new image by the time the file closes.
        with Image.open(file_path) as img:
            image = img.convert('RGB')
        with Image.open(file_path_ela) as img_ela:
            image_ela = img_ela.convert('RGB')

        if self.transform is not None:
            # NOTE(review): the transform is applied independently to each view,
            # so any random ops inside it get separate draws for image and
            # image_ela — the two views may not stay spatially aligned.
            image = self.transform(image)
            image_ela = self.transform(image_ela)

        return image, image_ela, class_label

In [None]:
# Transforms and train/test loaders.
# Random augmentation (flips, rotation) is applied to the training split only;
# the test split gets a deterministic transform so test metrics do not change
# from run to run because of augmentation noise.
my_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.RandomVerticalFlip(0.5),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomRotation(15),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

dataset = MyDataset(DATA_DIR, test_size=0.2, transform=my_transform, test=False)
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

test_dataset = MyDataset(DATA_DIR, test_size=0.2, transform=eval_transform, test=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

print(len(train_loader.dataset))
print(len(test_loader.dataset))

In [None]:
class ELRes(nn.Module):
    """Two-stream ResNeXt-50 classifier over two image views (ELA and RGB).

    Each stream is a ``resnext50_32x4d`` loaded with its default pretrained
    weights. The final ``fc`` layer is replaced by ``nn.Identity`` so the stream
    emits its pooled feature vector. The two feature vectors are concatenated
    and fed to one linear layer that produces ``num_classes`` logits.

    NOTE(review): the training loop calls ``model(rgb_image, ela_image)``, so
    ``modelELA`` actually receives the RGB view. Both backbones start from the
    same weights, so this only affects naming — confirm the intended order.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.modelELA = models.resnext50_32x4d(weights=ResNeXt50_32X4D_Weights.DEFAULT)
        self.modelRGB = models.resnext50_32x4d(weights=ResNeXt50_32X4D_Weights.DEFAULT)
        # Read the feature widths from the backbones before their heads are
        # removed, instead of hard-coding their sum.
        self.hidden_size = self.modelELA.fc.in_features + self.modelRGB.fc.in_features
        self.modelELA.fc = nn.Identity()
        self.modelRGB.fc = nn.Identity()
        self.fc = nn.Linear(self.hidden_size, num_classes)

    def forward(self, x1, x2):
        """Return class logits for a batch.

        Args:
            x1: Batch routed through ``modelELA``.
            x2: Batch routed through ``modelRGB``.

        Returns:
            Logits of shape ``(batch, num_classes)``.
        """
        ela_features = self.modelELA(x1)
        rgb_features = self.modelRGB(x2)
        # Concatenate along the feature axis: (batch, hidden_size).
        combined_features = torch.cat([ela_features, rgb_features], dim=1)
        return self.fc(combined_features)

In [None]:
# Build the binary (original vs. photoshopped) two-stream model on the target
# device, with cross-entropy loss and SGD + momentum.
model = ELRes(num_classes=2)
model = model.to(device)
criterion = CrossEntropyLoss()
optimizer = optim.SGD(params=model.parameters(), lr=lr, momentum=momentum)
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.1)

In [None]:

# Train for `epochs` epochs; after each epoch, evaluate on the held-out split.
for epoch_num in range(epochs):
    # ---- training pass ----
    model.train()
    epoch_total_loss = 0.0
    train_predictions = []
    train_labels = []
    for batch_num, (inp1, inp2, target) in enumerate(train_loader):
        if batch_num % 10 == 0:
            print("EPOCH", epoch_num, "Batch number", batch_num)
        target = target.to(device)
        optimizer.zero_grad()
        output = model(inp1.to(device), inp2.to(device))
        batch_loss = criterion(output, target)
        batch_loss.backward()
        optimizer.step()
        # CrossEntropyLoss returns the batch *mean*; weight it by the batch size
        # so the epoch average below is a true per-sample mean (the last batch
        # may be smaller than the others).
        epoch_total_loss += batch_loss.item() * target.size(0)
        train_predictions += output.argmax(dim=1).tolist()
        # Plain ints (not 0-d tensors) so sklearn receives a clean label list.
        train_labels += target.tolist()
    avrg_loss = epoch_total_loss / len(train_loader.dataset)
    train_accuracy = metrics.accuracy_score(train_labels, train_predictions)
    print("Train Accuracy = %0.2f" % (train_accuracy))
    print("Epoch %d - loss=%0.4f" % (epoch_num, avrg_loss))
    #scheduler.step()

    # ---- evaluation pass ----
    model.eval()
    labels = []
    predictions = []
    # Inference only: no_grad skips building the autograd graph (less memory, faster).
    with torch.no_grad():
        for inp1, inp2, target in test_loader:
            batch_prediction = model(inp1.to(device), inp2.to(device)).argmax(dim=1)
            predictions += batch_prediction.tolist()
            labels += target.tolist()
    accuracy = metrics.accuracy_score(labels, predictions)
    print("Test Accuracy = %0.2f" % (accuracy))
    confusion = metrics.confusion_matrix(labels, predictions)
    print(confusion)

    # F1/recall use binary averaging; report (rather than hide) any failure so
    # training still continues but the problem is visible.
    try:
        f1_score = metrics.f1_score(labels, predictions)
        print("F1 = %0.4f" % f1_score)
        recall = metrics.recall_score(labels, predictions)
        print("Recall = %0.4f" % recall)
    except ValueError as exc:
        print("Could not compute F1/recall:", exc)