In [None]:
import glob
import os

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets
from torchvision.io import read_image
import cv2
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import timm
torch.manual_seed(42)
# Fall back to CPU so the notebook still runs (slowly) on machines without a CUDA GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


## Dataset (images extracted from Celeb-DF v2 videos) and DataLoader creation

In [None]:
class CustomDataset(Dataset):
    """Binary real-vs-synthesis image dataset read from per-class subfolders.

    ``imgs_path`` must contain one subfolder per class, named after a key of
    ``class_map`` (``Celeb-real_images`` -> 0, ``Celeb-synthesis_images`` -> 1).
    Every ``*.jpg`` directly inside such a subfolder becomes one sample.
    Entries under ``imgs_path`` with any other name are ignored.

    Args:
        transform: optional callable applied to the tensor returned by
            ``torchvision.io.read_image``. When falsy, the raw tensor is returned.
        imgs_path: root folder holding the class subfolders.

    Each item is ``(image, label)``, where ``label`` is a 1-element tensor
    holding the integer class id.
    """

    def __init__(self, transform=None, imgs_path='C:/792/Celeb-DF-v2_images/'):
        self.imgs_path = imgs_path
        self.transform = transform
        self.class_map = {"Celeb-real_images": 0, "Celeb-synthesis_images": 1}
        self.data = []
        # sorted() makes sample order (and therefore the seeded random_split)
        # independent of the filesystem's directory-listing order.
        for class_path in sorted(glob.glob(os.path.join(self.imgs_path, "*"))):
            # basename handles the platform's separator(s); splitting on "\\"
            # only worked on Windows.
            class_name = os.path.basename(class_path)
            if class_name not in self.class_map:
                # Unknown folders/files would otherwise crash later with KeyError.
                continue
            for img_path in sorted(glob.glob(os.path.join(class_path, "*.jpg"))):
                self.data.append([img_path, class_name])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path, class_name = self.data[idx]
        img = read_image(img_path)
        class_id = torch.tensor([self.class_map[class_name]])
        # Without a transform, return the raw image instead of hitting an
        # unbound local variable.
        img_tensor = self.transform(img) if self.transform else img
        return img_tensor, class_id

In [None]:
# Preprocessing applied to every sample. Train and val share it, since both
# splits come from the same CustomDataset instance.
# Pipeline: image tensor -> PIL image -> 128x128 -> float tensor in [0, 1].
train_transforms = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
    ]
)


## Train/validation split (80/20)

In [None]:
dataset = CustomDataset(transform=train_transforms)
# An empty glob (e.g. a wrong root path) would otherwise surface much later as
# a confusing DataLoader/sampler error.
assert len(dataset) > 0, f"No images found under {dataset.imgs_path!r}"
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
# A dedicated seeded generator makes this cell idempotent: re-running it
# reproduces the same split instead of drawing from the (advancing) global RNG,
# which could mix already-trained-on images into the validation set.
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42)
)

In [None]:
n_images = len(dataset)  # number of samples indexed by the dataset
n_images

In [None]:
BATCH_SIZE = 128
train_data_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation metrics do not depend on sample order. Shuffling here would only
# consume global RNG draws and perturb the training shuffle sequence.
val_data_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
len(train_data_loader)

## Xception model (pretrained, loaded via timm)

In [None]:
# Pretrained Xception from timm with a single-output head (num_classes=1):
# one score per image for the real-vs-synthesis decision, moved to `device`.
model = timm.create_model('xception', pretrained=True, num_classes=1).to(device)


## Optimizer and Loss Function


In [None]:
# Adam over all model parameters (lr = 1e-4). BCEWithLogitsLoss applies the
# sigmoid itself, so it must be given raw model outputs (logits).
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.BCEWithLogitsLoss()

## Training and validation loop

In [None]:
NUM_EPOCHS = 20


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over `loader` and return (mean per-sample loss, accuracy).

    With an `optimizer`, the model is put in train mode and updated after
    every batch. Without one, it is put in eval mode with gradients disabled.
    The model is expected to return raw logits shaped like `labels`, which is
    what `criterion` (BCEWithLogitsLoss) consumes. Returns (nan, nan) for an
    empty loader.
    """
    training = optimizer is not None
    model.train(training)
    total_loss, correct, total = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs = inputs.to(device).float()
            labels = labels.to(device).float()
            logits = model(inputs)
            # BCEWithLogitsLoss applies the sigmoid internally. Passing
            # sigmoid(outputs) here would squash them twice.
            loss = criterion(logits, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            batch_size = labels.size(0)
            # Weight by batch size so a short final batch does not skew the mean.
            total_loss += loss.item() * batch_size
            # sigmoid(z) > 0.5  <=>  z > 0, so threshold the logits directly.
            predicted = (logits > 0).float()
            correct += (predicted == labels).sum().item()
            total += batch_size
    if total == 0:
        return float('nan'), float('nan')
    return total_loss / total, correct / total


train_losses, train_accs, val_losses, val_accs = [], [], [], []

for epoch in range(NUM_EPOCHS):
    print("=================")
    print(f"Epoch: {epoch+1}")

    train_loss, train_acc = _run_epoch(model, train_data_loader, criterion, device, optimizer)
    train_losses.append(train_loss)
    train_accs.append(train_acc)
    print(f"Train Loss: {train_loss} | Train Accuracy: {train_acc}")

    val_loss, val_acc = _run_epoch(model, val_data_loader, criterion, device)
    val_losses.append(val_loss)
    val_accs.append(val_acc)
    print(f"Val Loss: {val_loss} | Val Accuracy: {val_acc}")

In [None]:


def plot_curves(train_losses, train_accuracies, val_losses, val_accuracies):
    """Plot per-epoch loss and accuracy curves for the train and validation splits.

    Draws two stacked panels (loss on top, accuracy below), each with a
    'Train' and a 'Val' line, then shows the figure.
    """
    fig, (loss_ax, acc_ax) = plt.subplots(nrows=2, ncols=1, figsize=(8, 8))

    panels = (
        (loss_ax, train_losses, val_losses, 'Loss', 'Losses vs Epochs'),
        (acc_ax, train_accuracies, val_accuracies, 'Accuracy', 'Accuracy vs Epochs'),
    )
    for ax, train_series, val_series, y_label, panel_title in panels:
        ax.plot(train_series, label='Train')
        ax.plot(val_series, label='Val')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.set_title(panel_title)
        ax.legend()

    plt.show()
  

In [None]:
# Persist the trained model together with its per-epoch training history.
# NOTE(review): the 'model' entry pickles the whole nn.Module object. Loading it
# requires the same timm/torch code, and recent PyTorch versions may need
# torch.load(..., weights_only=False). 'model_state_dict' alone restores the weights.
checkpoint = {
    'model': model,
    'model_state_dict': model.state_dict(),
    'train_losses': train_losses,
    'train_accs': train_accs,
    'val_losses': val_losses,
    'val_accs': val_accs,
}
torch.save(checkpoint, 'CelebDF.pt')
print('model saved')
