In [2]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

In [3]:
### Dataset class

In [4]:
class FaceDataset(Dataset):
    def __init__(self, img_dir, labels_dict, transform=None, max_images=100):
        self.img_dir = img_dir
        self.labels_dict = labels_dict
        self.transform = transform
        
        # Get all image names from labels_dict keys
        self.img_names = list(labels_dict.keys())
        
        # Keep only first max_images (default 100)
        self.img_names = self.img_names[:max_images]

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        img_name = self.img_names[idx]
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        label = self.labels_dict[img_name]
        return image, torch.tensor(label, dtype=torch.long)

In [5]:
### Mixture of Experts Model

In [6]:
class ExpertCNN(nn.Module):
    def __init__(self):
        super(ExpertCNN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))
        )
        self.fc = nn.Linear(32, 2)

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

In [7]:
class GatingNetwork(nn.Module):
    def __init__(self, num_experts):
        super(GatingNetwork, self).__init__()
        self.gate = nn.Sequential(
            nn.Conv2d(3, 8, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(8, 16, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))
        )
        self.fc = nn.Linear(16, num_experts)

    def forward(self, x):
        x = self.gate(x)
        x = x.view(x.size(0), -1)
        return F.softmax(self.fc(x), dim=1)

In [8]:
class MixtureOfExperts(nn.Module):
    def __init__(self, num_experts):
        super(MixtureOfExperts, self).__init__()
        self.experts = nn.ModuleList([ExpertCNN() for _ in range(num_experts)])
        self.gating = GatingNetwork(num_experts)

    def forward(self, x):
        gate_weights = self.gating(x)  # shape: [batch_size, num_experts]
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=2)  # shape: [batch, classes, experts]
        out = torch.bmm(expert_outputs, gate_weights.unsqueeze(2)).squeeze(2)  # shape: [batch, classes]
        return out

In [9]:
### Generating Label Dict

In [None]:
def create_labels_dict(dataset_path):
    labels_dict = {}
    for label_type in ['real', 'fake']:
        label = 0 if label_type == 'real' else 1
        folder_path = os.path.join(dataset_path, label_type)
        for fname in os.listdir(folder_path):
            if fname.endswith(".jpg") or fname.endswith(".png"):
                labels_dict[os.path.join(label_type, fname)] = label
    return labels_dict

In [11]:
### Training Script

In [12]:
def train_model(img_dir, num_epochs=5, batch_size=8, num_experts=3):
    labels_dict = create_labels_dict(img_dir)
    transform = transforms.Compose([
        transforms.Resize((1024, 1024)),
        transforms.ToTensor()
    ])

    dataset = FaceDataset(img_dir, labels_dict, transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model = MixtureOfExperts(num_experts=num_experts)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(dataloader):.4f}")

    print("Training completed.")
    return model

In [13]:
### Run Training

In [None]:
if __name__ == "__main__":
    img_dir = r"C:\Users\aasth\Downloads\dataset2\Final Dataset"
    train_model(img_dir)

Epoch 1/5, Loss: 0.6576
Epoch 2/5, Loss: 0.6365
