In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from PIL import Image
import os
from dotenv import load_dotenv
import torch
from torchsummary import summary

In [None]:
class FakeRealDataset(Dataset):
    """Binary image dataset built from a directory of real frames and one of fake frames.

    Every file found directly under ``real_dir`` is labelled 0 and every file
    found directly under ``fake_dir`` is labelled 1.

    Args:
        real_dir: Directory whose files are the "real" samples (label 0).
        fake_dir: Directory whose files are the "fake" samples (label 1).
        transform: Optional callable applied to the loaded RGB ``PIL.Image``
            before it is returned.

    Raises:
        OSError: If either directory cannot be listed (propagated from
            ``os.listdir``).
    """

    def __init__(self, real_dir, fake_dir, transform=None):
        self.real_images = self._list_files(real_dir)
        self.fake_images = self._list_files(fake_dir)
        self.images = self.real_images + self.fake_images
        self.labels = [0] * len(self.real_images) + [1] * len(self.fake_images)
        self.transform = transform

    @staticmethod
    def _list_files(directory):
        """Return the sorted paths of the regular, non-hidden files in ``directory``."""
        # os.listdir returns entries in arbitrary order; sorting makes index ->
        # file stable between runs.
        names = sorted(os.listdir(directory))
        candidates = (
            os.path.join(directory, name)
            for name in names
            # Hidden entries (e.g. ".ipynb_checkpoints", ".DS_Store") are not samples.
            if not name.startswith(".")
        )
        # Sub-directories cannot be opened as images, so keep regular files only.
        return [path for path in candidates if os.path.isfile(path)]

    def __len__(self):
        """Return the total number of real plus fake samples."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is always converted to 3-channel RGB so grayscale, palette
        and RGBA files yield the same channel count as ordinary colour images.
        """
        img_path = self.images[idx]
        label = self.labels[idx]

        # The context manager closes the underlying file as soon as the pixel
        # data has been copied into the converted image.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label

# Preprocessing pipeline handed to FakeRealDataset: resize to 224x224, convert
# the image to a tensor, then normalise each of the three channels.
# NOTE(review): the mean/std values match the commonly used ImageNet channel
# statistics — confirm that is the intended normalisation for this data.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [None]:
# Locations of the extracted frames, one directory per class.
# NOTE(review): these are absolute, machine-specific paths; the notebook only
# runs unchanged on a host that has this exact directory layout.
real_frames_dir = "/teamspace/studios/this_studio/deepfake/dataset/images/real"
fake_frames_dir = "/teamspace/studios/this_studio/deepfake/dataset/images/fake"
dataset = FakeRealDataset(real_frames_dir, fake_frames_dir, transform=transform)
# 70% / 15% split sizes; int() truncates, so the test split takes whatever
# remains and the three sizes always sum to len(dataset).
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size

In [None]:
# Seed the split so the index permutation (and therefore the membership of the
# train/val/test subsets) is the same on every kernel restart. Without this a
# checkpoint saved in one session could later be "tested" on samples it was
# trained on in that earlier session.
SPLIT_SEED = 42
BATCH_SIZE = 32

split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=split_generator,
)

# Only the training loader shuffles; validation and test batches keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

In [None]:
class CustomCNN(nn.Module):
    """Small CNN: three conv/ReLU/max-pool stages and a two-layer classifier head.

    Every pooling stage halves height and width, so the fixed ``fc1`` input
    width of ``128 * 28 * 28`` corresponds to 224x224 inputs (224 / 2**3 = 28).
    The forward pass returns two raw scores per sample; no softmax is applied.
    """

    def __init__(self):
        super().__init__()

        # 3x3 kernels with stride 1 and padding 1 keep the spatial size unchanged,
        # so only the pooling layer shrinks the feature maps.
        conv_options = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(3, 32, **conv_options)
        self.conv2 = nn.Conv2d(32, 64, **conv_options)
        self.conv3 = nn.Conv2d(64, 128, **conv_options)

        # Classifier head on the flattened feature map.
        self.fc1 = nn.Linear(128 * 28 * 28, 512)
        self.fc2 = nn.Linear(512, 2)

        # Parameter-free layers, reused at every stage.
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Map a batch of images to a ``(batch, 2)`` tensor of class scores."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.maxpool(self.relu(conv(x)))

        # Collapse channels and spatial dimensions into one feature vector per sample.
        flat = x.view(x.size(0), -1)
        hidden = self.dropout(self.relu(self.fc1(flat)))
        return self.fc2(hidden)

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Build the network and move its parameters to the selected device in one step.
model = CustomCNN().to(device)
print(device)

In [None]:
# Print torchsummary's report for the model, given a single input of shape
# (3, 224, 224), i.e. channels x height x width without the batch dimension.
summary(model, (3, 224, 224))  

In [None]:
# Cross-entropy over the model's two output scores; targets are the integer
# class labels produced by the dataset.
criterion = nn.CrossEntropyLoss()
# Adam over every model parameter with a fixed learning rate of 1e-4.
optimizer = optim.Adam(model.parameters(), lr=0.0001)

In [None]:
from tqdm import tqdm
import torch
from comet_ml import Experiment

def train_and_evaluate(model, train_loader, test_loader, criterion, optimizer, num_epochs, device,
                       checkpoint_path='best_model_cnn.pth'):
    """Train ``model`` and evaluate it after every epoch, logging to Comet.

    Each epoch runs one optimisation pass over ``train_loader`` followed by a
    gradient-free pass over ``test_loader``. Whenever the evaluation accuracy
    exceeds the best value seen so far, the model's ``state_dict`` is written
    to ``checkpoint_path`` and that same file is uploaded with
    ``experiment.log_model``.

    The Comet experiment is configured from the ``API_KEY``, ``PROJECT_NAME``
    and ``WORKSPACE`` environment variables.
    NOTE(review): no ``load_dotenv()`` call is visible in this notebook, so
    these variables must already be set in the process environment — confirm.

    Args:
        model: Network to train; must already live on ``device``.
        train_loader: Loader yielding ``(inputs, targets)`` training batches;
            its ``batch_size`` attribute is logged as a hyper-parameter.
        test_loader: Loader yielding ``(inputs, targets)`` evaluation batches.
            Because this split drives checkpoint selection, pass a validation
            split here if a separate held-out test set must stay untouched.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        optimizer: Optimiser bound to ``model``'s parameters.
        num_epochs: Number of train + evaluate cycles to run.
        device: Device every batch is moved to before the forward pass.
        checkpoint_path: Where the best ``state_dict`` is saved.

    Returns:
        None.

    Raises:
        Exception: Anything raised by the model, loss, optimiser or loaders is
            propagated after the Comet experiment has been ended.
    """
    experiment = Experiment(
        api_key=os.getenv("API_KEY"),
        project_name=os.getenv("PROJECT_NAME"),
        workspace=os.getenv("WORKSPACE")
    )

    # try/finally guarantees the experiment is ended even if training is
    # interrupted or fails.
    try:
        experiment.log_parameters({
            "learning_rate": optimizer.param_groups[0]['lr'],
            "batch_size": train_loader.batch_size,
            "num_epochs": num_epochs,
            "model": model.__class__.__name__
        })

        best_accuracy = 0.0
        global_step = 0  # counts training batches across all epochs
        eval_step = 0    # counts evaluation batches across all epochs

        for epoch in range(num_epochs):
            # ---- training pass ----
            model.train()
            train_loss = 0.0
            train_correct = 0
            train_total = 0
            train_batches = 0
            train_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]')

            for inputs, targets in train_pbar:
                inputs, targets = inputs.to(device), targets.to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

                # Read each scalar from the device once and reuse it for the
                # running totals and the per-step metrics.
                batch_loss = loss.item()
                _, predicted = outputs.max(1)
                batch_count = targets.size(0)
                batch_correct = predicted.eq(targets).sum().item()

                train_loss += batch_loss
                train_total += batch_count
                train_correct += batch_correct
                train_batches += 1

                experiment.log_metric("train_step_loss", batch_loss, step=global_step)
                experiment.log_metric("train_step_accuracy", 100. * batch_correct / batch_count,
                                      step=global_step)

                train_pbar.set_postfix({'Loss': train_loss / train_batches,
                                        'Acc': 100. * train_correct / train_total})
                global_step += 1

            # An empty loader would otherwise raise ZeroDivisionError here.
            avg_train_loss = train_loss / train_batches if train_batches else 0.0
            train_accuracy = 100. * train_correct / train_total if train_total else 0.0
            experiment.log_metric("train_epoch_loss", avg_train_loss, step=epoch)
            experiment.log_metric("train_epoch_accuracy", train_accuracy, step=epoch)

            # ---- evaluation pass ----
            model.eval()
            test_loss = 0.0
            test_correct = 0
            test_total = 0
            test_batches = 0
            with torch.no_grad():
                test_pbar = tqdm(test_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Test]')

                for inputs, targets in test_pbar:
                    inputs, targets = inputs.to(device), targets.to(device)
                    outputs = model(inputs)
                    loss = criterion(outputs, targets)

                    batch_loss = loss.item()
                    _, predicted = outputs.max(1)
                    batch_count = targets.size(0)
                    batch_correct = predicted.eq(targets).sum().item()

                    test_loss += batch_loss
                    test_total += batch_count
                    test_correct += batch_correct
                    test_batches += 1

                    # Evaluation batches get their own step counter; reusing
                    # the (frozen) training step would log every batch of the
                    # epoch at the same step.
                    experiment.log_metric("test_step_loss", batch_loss, step=eval_step)
                    experiment.log_metric("test_step_accuracy", 100. * batch_correct / batch_count,
                                          step=eval_step)
                    eval_step += 1

                    test_pbar.set_postfix({'Loss': test_loss / test_batches,
                                           'Acc': 100. * test_correct / test_total})

            avg_test_loss = test_loss / test_batches if test_batches else 0.0
            test_accuracy = 100. * test_correct / test_total if test_total else 0.0
            experiment.log_metric("test_epoch_loss", avg_test_loss, step=epoch)
            experiment.log_metric("test_epoch_accuracy", test_accuracy, step=epoch)

            print(f'Epoch {epoch+1}/{num_epochs}:')
            print(f'Train Loss: {avg_train_loss:.4f}, Train Acc: {train_accuracy:.2f}%')
            print(f'Test Loss: {avg_test_loss:.4f}, Test Acc: {test_accuracy:.2f}%')

            if test_accuracy > best_accuracy:
                best_accuracy = test_accuracy
                torch.save(model.state_dict(), checkpoint_path)
                print(f'Best model saved with accuracy: {best_accuracy:.2f}%')
                # Upload the file that was just written (previously a different,
                # never-created filename was passed here).
                experiment.log_model("best_model", checkpoint_path)

            print()
    finally:
        experiment.end()

# Launch training for a fixed number of epochs.
# NOTE(review): the test split is passed as the evaluation loader, so it also
# decides which checkpoint is kept as "best", while val_loader is never used
# in the visible cells — confirm whether val_loader was intended here.
num_epochs = 15
train_and_evaluate(model, train_loader, test_loader, criterion, optimizer, num_epochs, device)