In [5]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, ConcatDataset, Dataset
from torchvision import transforms
from tqdm import tqdm
from transformers import ViTModel
import cv2
import numpy as np
import dnnlib
import legacy

os.environ['TORCH_CUDA_ARCH_LIST'] = '8.0';'8.6'

haar_cascade_path = "D:/DeepFake/nawfal/haarcascade_frontalface_default.xml"
face_cascade = cv2.CascadeClassifier(haar_cascade_path)

class ViT_StyleGAN_Model(nn.Module):
    def __init__(self, num_classes):
        super(ViT_StyleGAN_Model, self).__init__()
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        self.fc = nn.Linear(768, num_classes)

    def forward(self, x):
        batch_size, c, h, w = x.size()
        x = x.view(batch_size, c, h, w)
        features = self.vit(pixel_values=x).last_hidden_state[:, 0, :]  # Extract the [CLS] token
        features = features.view(batch_size, -1)
        out = self.fc(features)  # No need for indexing here, as `features` is already 2D
        return out


def load_stylegan_models():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    network_pkl = "D:\\DeepFake\\stylegan2-ffhq-1024x1024.pkl"
    with dnnlib.util.open_url(network_pkl) as f:
        network = legacy.load_network_pkl(f)
        G = network['G_ema'].to(device)  # Generator
        D = network['D'].to(device)  # Discriminator
    return G, D

def generate_images_with_gan(G, num_images):
    device = next(G.parameters()).device  # Ensure to get the device from the generator's parameters
    z = torch.randn(num_images, G.z_dim).to(device)
    c = torch.zeros(num_images, G.c_dim).to(device)
    images = G(z, c, truncation_psi=0.7, noise_mode='const')

    # Resize images to 224x224
    resize_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToPILImage()
    ])

    images_resized = [resize_transform(image.cpu().detach()) for image in images]
    images_resized = torch.stack([transforms.ToTensor()(image).to(device) for image in images_resized])  # Move to GPU
    return images_resized

def evaluate_images_with_discriminator(D, images):
    device = next(D.parameters()).device  # Get the device from the discriminator's parameters
    images = images.to(device)
    with torch.no_grad():
        logits = D(images)
    return logits

class VideoDataset(Dataset):
    def __init__(self, video_folder, label, transform=None, max_frames=100):
        self.video_folder = video_folder
        self.label = label
        self.transform = transform
        self.max_frames = max_frames
        self.video_files = [os.path.join(video_folder, f) for f in os.listdir(video_folder) if f.endswith('.mp4')]

    def __len__(self):
        return len(self.video_files)

    def __getitem__(self, idx):
        try:
            video_path = self.video_files[idx]
            frames = self.load_video(video_path)
            if self.transform:
                frames = [self.transform(frame) for frame in frames]
            frames = torch.stack(frames)
            label = torch.tensor(self.label, dtype=torch.long)
            return frames, label
        except Exception as e:
            print(f"Error loading video {self.video_files[idx]}: {e}")
            return torch.zeros((self.max_frames, 3, 224, 224)), torch.tensor(self.label, dtype=torch.long)


    def load_video(self, path):
        cap = cv2.VideoCapture(path)
        frames = []

        while cap.isOpened():
            ret, frame = cap.read()

            if not ret:
                break

            gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            faces = face_cascade.detectMultiScale(gray_frame, scaleFactor=1.1, minNeighbors=5)

            if len(faces) > 0:
                x, y, w, h = faces[0]
                face = frame[y:y+h, x:x+w]
                face = cv2.resize(face, (224, 224))
                frames.append(face)

        cap.release()
        frames = self.pad_or_truncate_frames(frames)
        return frames

    def pad_or_truncate_frames(self, frames):
        if len(frames) < self.max_frames:
            while len(frames) < self.max_frames:
                frames.append(np.zeros((224, 224, 3), dtype=np.uint8))
        else:
            frames = frames[:self.max_frames]
        return frames

transform = transforms.Compose([
    transforms.ToTensor(),
])

VIDEO_DIR_REAL =r"D:/New folder/real"
VIDEO_DIR_MANIPULATED =r"D:/New folder/fake"
BATCH_SIZE = 4
NUM_CLASSES = 2
LEARNING_RATE = 1e-4
NUM_EPOCHS = 5

real_dataset = VideoDataset(video_folder=VIDEO_DIR_REAL, label=0, transform=transform)
manipulated_dataset = VideoDataset(video_folder=VIDEO_DIR_MANIPULATED, label=1, transform=transform)
full_dataset = ConcatDataset([real_dataset, manipulated_dataset])

def collate_fn(batch):
    frames, labels = zip(*batch)
    frames = torch.stack(frames)
    labels = torch.tensor(labels)
    return frames, labels

data_loader = DataLoader(full_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, collate_fn=collate_fn)

model = ViT_StyleGAN_Model(num_classes=NUM_CLASSES)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

G, D = load_stylegan_models()

def train_model(model, data_loader, num_epochs, save_dir="checkpoints"):
    model.train()
    os.makedirs(save_dir, exist_ok=True)  # Create directory to save checkpoints

    for epoch in range(num_epochs):
        running_loss = 0.0
        print(f"Epoch {epoch + 1}/{num_epochs} started...")

        for batch_idx, (frames, labels) in enumerate(tqdm(data_loader)):
            frames = frames.to(device)  # Move frames to GPU
            labels = labels.to(device)  # Move labels to GPU

            optimizer.zero_grad()

            # Generate and resize images with StyleGAN
            gan_frames = generate_images_with_gan(G, frames.size(0))
            gan_frames = gan_frames.to(device)  # Move GAN generated frames to GPU

            outputs = model(gan_frames)  # Model expects inputs on the same device
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * frames.size(0)

            if batch_idx % 10 == 0:
                print(f"Batch {batch_idx + 1}: Loss = {loss.item():.4f}")

        epoch_loss = running_loss / len(data_loader.dataset)
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}")

        # Save checkpoint after each epoch
        checkpoint_path = os.path.join(save_dir, f"model_epoch_{epoch + 1}.pth")
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Checkpoint saved: {checkpoint_path}")

    print("Training complete.")

# Call the modified train function

train_model(model, data_loader, num_epochs=NUM_EPOCHS)

model_save_path = "vit_stylegan_model.pth"
torch.save(model.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")

model = ViT_StyleGAN_Model(num_classes=NUM_CLASSES)
model.load_state_dict(torch.load(model_save_path, weights_only=True))
model.to(device)  # Ensure the model is on GPU
model.eval()
print("Model loaded from disk")

Epoch 1/5 started...


  2%|▏         | 1/60 [02:50<2:47:57, 170.81s/it]

Batch 1: Loss = 0.5984
