In [1]:
import torch
import torch.optim as optim
from torch.nn import CrossEntropyLoss
from torchvision.datasets.folder import DatasetFolder
import torchvision.transforms as transforms 
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.models import resnet34, ResNet34_Weights
from torchvision.models import resnet18, ResNet18_Weights
from torchvision.models import inception_v3, Inception_V3_Weights
from torchvision.models import googlenet, GoogLeNet_Weights
from torchvision.models import vit_b_16, ViT_B_16_Weights
from torchvision.models import densenet121, DenseNet121_Weights
from torchvision.models import resnext50_32x4d, ResNeXt50_32X4D_Weights
from skimage.io import imread
from sklearn import metrics
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import cv2
import numpy as np
import random
from PIL import Image
import sklearn
from flopth import flopth
import os
from sklearn.model_selection import train_test_split

In [2]:
# Training hyperparameters used by the optimiser and the training loop.
epochs = 20
lr = 0.01
momentum = 0.5

# Seed every random number generator this notebook imports so that a
# "Restart Kernel -> Run All" repeats the same shuffling and augmentation.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Root directory handed to the dataset class as its `root_dir`.
DATA_DIR = './'

class MyDataset(Dataset):
    """Two-class image dataset: original pictures versus photoshopped ones.

    Files are collected from ``<root_dir>/originals`` (label 0) and
    ``<root_dir>/photoshops`` (label 1), split into a train and a test
    partition, and one of the two partitions is exposed through the
    ``Dataset`` interface.

    The split is seeded, so a train instance and a test instance built with
    the same ``root_dir``, ``test_size`` and ``random_state`` always see
    disjoint partitions of the same split.

    Args:
        root_dir: Directory that contains the two class folders.
        test_size: Forwarded to ``train_test_split`` as ``test_size``.
        transform: Optional callable applied to each RGB ``PIL.Image``.
        test: When true expose the test partition, otherwise the train one.
        random_state: Seed forwarded to ``train_test_split``. Use the same
            value for the train and the test instance.
        max_train: Upper bound on the number of training samples kept
            (``None`` keeps all of them).
        max_test: Upper bound on the number of test samples kept
            (``None`` keeps all of them).
    """

    def __init__(self, root_dir, test_size=0.2, transform=None, test = False,
                 random_state=42, max_train=1600, max_test=400):
        #Collect all samples from our dataset
        #Note that our dataset contains 5000 true and 5000 photoshopped images
        self.root_dir = root_dir
        self.transform = transform

        self.class_folders = ['originals', 'photoshops']
        all_samples = []

        for class_label, class_folder in enumerate(self.class_folders):
            class_path = os.path.join(self.root_dir, class_folder)
            # os.listdir returns entries in arbitrary order; sorting makes the
            # sample list (and therefore the seeded split) deterministic.
            for file_name in sorted(os.listdir(class_path)):
                all_samples.append((os.path.join(class_path, file_name), class_label))

        # Seeded random split: every instance created with the same arguments
        # gets the same partition, so train and test never overlap.
        train_samples, test_samples = train_test_split(
            all_samples, test_size=test_size, random_state=random_state
        )

        #Subsample data as needed for experiments
        train_samples = train_samples[:max_train]
        test_samples = test_samples[:max_test]

        # Keep only the requested partition and remember which one it is.
        self.train = not test
        self.samples = test_samples if test else train_samples

    def __len__(self):
        """Return the number of samples in the exposed partition."""
        return len(self.samples)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(image, class_label)``.

        The image is opened with PIL, converted to RGB and, when a transform
        was supplied, passed through it.
        """
        file_path, class_label = self.samples[index]
        # The context manager closes the underlying file as soon as the
        # RGB copy has been produced.
        with Image.open(file_path) as raw_image:
            image = raw_image.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, class_label


In [3]:
#transform and split data
IMAGE_SIZE = (224, 224) #299, 299 for inception
normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

# Training pipeline: resize, convert to tensor, random augmentation, normalise.
my_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.RandomVerticalFlip(0.5),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomRotation(15),
    normalize,
])

# Evaluation pipeline: same preprocessing but WITHOUT the random augmentations,
# so the reported test metrics are deterministic for a given model.
test_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    normalize,
])

dataset = MyDataset(DATA_DIR, test_size=0.2, transform=my_transform, test = False)
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

test_dataset = MyDataset(DATA_DIR, test_size=0.2, transform=test_transform, test = True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Sanity check: number of samples behind each loader.
print(len(train_loader.dataset))
print(len(test_loader.dataset))

1600
400


In [4]:
#Call model
arch = "resnext"

# Supported architecture names mapped to (constructor, default pretrained weights).
MODEL_BUILDERS = {
    "inceptionv3": (inception_v3, Inception_V3_Weights.DEFAULT),
    "resnet18": (resnet18, ResNet18_Weights.DEFAULT),
    "resnet34": (resnet34, ResNet34_Weights.DEFAULT),
    "resnet50": (resnet50, ResNet50_Weights.DEFAULT),
    "googlenet": (googlenet, GoogLeNet_Weights.DEFAULT),
    "vit": (vit_b_16, ViT_B_16_Weights.DEFAULT),
    "dense": (densenet121, DenseNet121_Weights.DEFAULT),
    "resnext": (resnext50_32x4d, ResNeXt50_32X4D_Weights.DEFAULT),
}

# Fail loudly on a typo instead of leaving `model` undefined (or silently
# reusing a model left over from a previous run of this cell).
if arch not in MODEL_BUILDERS:
    raise ValueError(
        "Unknown arch %r; expected one of %s" % (arch, sorted(MODEL_BUILDERS))
    )

build_model, pretrained_weights = MODEL_BUILDERS[arch]
# NOTE(review): the pretrained classification head is used as-is; presumably it
# emits more than the two classes in this dataset -- confirm, and consider
# replacing the head with a 2-way layer.
model = build_model(weights=pretrained_weights).to(device)

#Get parameters
flops, params = flopth(model, in_size=((3, 224, 224),))
print(flops, params)
criterion = CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(),lr=lr,momentum=momentum)
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.1)


Downloading: "https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth" to C:\Users\Nebbocaj/.cache\torch\hub\checkpoints\resnext50_32x4d-1a0047aa.pth
100%|█████████████████████████████████████████████████████████████████████████████| 95.8M/95.8M [01:14<00:00, 1.35MB/s]


4.26651G 25.0289M


In [None]:
# Training accuracy of every epoch, in order.
all_acc = []
for epoch_num in range(epochs):
    # ---- Training pass ----
    model.train()
    epoch_total_loss = 0
    train_predictions = []
    train_labels = []
    for batch_num, (inp, target) in enumerate(train_loader):
        if batch_num % 10 == 0:
            print("EPOCH", epoch_num, "Batch number", batch_num)
        # Store plain Python ints rather than 0-dim tensors.
        train_labels += target.tolist()
        optimizer.zero_grad()
        if arch == "inceptionv3":
            # In training mode this model returns a pair; only the first
            # element is used for the loss and the predictions.
            output, _ = model(inp.to(device))
        else:
            output = model(inp.to(device))
        batch_loss = criterion(output, target.to(device))
        train_predictions += output.argmax(dim=1).tolist()
        epoch_total_loss += batch_loss.item()
        batch_loss.backward()
        optimizer.step()
    train_accuracy = metrics.accuracy_score(train_labels, train_predictions)
    all_acc.append(train_accuracy)
    print("Train Loss = %0.4f" % (epoch_total_loss / max(len(train_loader), 1)))
    print("Train Accuracy = %0.2f" % (train_accuracy))

    # ---- Evaluation pass ----
    model.eval()
    labels = []
    predictions = []
    # No gradients are needed for evaluation; skipping them saves memory and time.
    with torch.no_grad():
        for inp, target in test_loader:
            labels += target.tolist()
            predictions += model(inp.to(device)).argmax(dim=1).tolist()
    accuracy = metrics.accuracy_score(labels, predictions)
    print("Test Accuracy = %0.2f" % (accuracy))
    confusion = metrics.confusion_matrix(labels, predictions)
    print(confusion)

    # F1 and recall are best-effort: report why they were skipped instead of
    # silently swallowing every possible exception.
    try:
        f1 = metrics.f1_score(labels, predictions)
        recall = metrics.recall_score(labels, predictions)
    except ValueError as err:
        print("Skipping F1/recall:", err)
    else:
        print(f1)
        print("Recall = %0.2f" % (recall))
print(all_acc)

EPOCH 0 Batch number 0
