In [23]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import random_split, DataLoader
import numpy as np
from scipy.linalg import sqrtm

# ----- 1. Define your transforms -----
train_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

val_test_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# ----- 2. Create a dataset from the evaluation directory -----
root_dir = "/Users/yahyarahhawi/Developer/Film/evaluation"
full_dataset = datasets.ImageFolder(root=root_dir, transform=train_transforms)

# Optional: check class-to-index mapping
print("Class to index mapping:")
print(full_dataset.class_to_idx)

# ----- 3. Split into train, val, test (70%/15%/15%) -----
dataset_size = len(full_dataset)
train_size = int(0.7 * dataset_size)
val_size = int(0.15 * dataset_size)
test_size = dataset_size - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    full_dataset, [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42)
)

# Overwrite val/test transforms
val_dataset.dataset.transform = val_test_transforms
test_dataset.dataset.transform = val_test_transforms

# ----- 4. Dataloaders -----
batch_size = 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# ----- 5. Use pretrained ResNet18 with dropout -----
class CustomResNet(nn.Module):
    def __init__(self):
        super(CustomResNet, self).__init__()
        self.base_model = models.resnet18(pretrained=True)
        num_features = self.base_model.fc.in_features
        
        # Custom classifier with Dropout to prevent overfitting
        self.base_model.fc = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(num_features, 2)  # Binary classification
        )

    def forward(self, x):
        return self.base_model(x)

model = CustomResNet()
device = torch.device("mps")
model.to(device)

# ----- 6. Define loss and optimizer with weight decay -----
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=5e-5, weight_decay=1e-5)

# ----- 7. Training loop with early stopping -----
epochs = 50
patience = 0
best_val_acc = 0.0
early_stop_threshold = 5

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * images.size(0)
        _, preds = torch.max(outputs, 1)
        correct += torch.sum(preds == labels).item()
        total += labels.size(0)
    
    train_loss = running_loss / total
    train_acc = correct / total

    # Validation
    model.eval()
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            val_correct += torch.sum(preds == labels).item()
            val_total += labels.size(0)
    val_acc = val_correct / val_total

    print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}")

    # Early stopping based on validation accuracy
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "best_model.pth")
        patience = 0  # Reset patience if validation improves
    else:
        patience += 1
        if patience >= early_stop_threshold:
            print("Early stopping triggered.")
            break



Class to index mapping:
{'real_film': 0, 'real_iphone': 1}
Epoch 1/50, Train Loss: 0.4789, Train Acc: 0.7570, Val Acc: 0.8497
Epoch 2/50, Train Loss: 0.1568, Train Acc: 0.9525, Val Acc: 0.8889
Epoch 3/50, Train Loss: 0.0706, Train Acc: 0.9832, Val Acc: 0.8954
Epoch 4/50, Train Loss: 0.0364, Train Acc: 0.9944, Val Acc: 0.9150
Epoch 5/50, Train Loss: 0.0179, Train Acc: 1.0000, Val Acc: 0.9346
Epoch 6/50, Train Loss: 0.0134, Train Acc: 1.0000, Val Acc: 0.9542
Epoch 7/50, Train Loss: 0.0071, Train Acc: 1.0000, Val Acc: 0.9412
Epoch 8/50, Train Loss: 0.0059, Train Acc: 1.0000, Val Acc: 0.9608
Epoch 9/50, Train Loss: 0.0149, Train Acc: 0.9958, Val Acc: 0.9477
Epoch 10/50, Train Loss: 0.0105, Train Acc: 0.9986, Val Acc: 0.9477
Epoch 11/50, Train Loss: 0.0090, Train Acc: 0.9986, Val Acc: 0.9150
Epoch 12/50, Train Loss: 0.0118, Train Acc: 0.9972, Val Acc: 0.9020
Epoch 13/50, Train Loss: 0.0252, Train Acc: 0.9930, Val Acc: 0.8824
Early stopping triggered.


In [24]:
import os
import torch
import torch.nn as nn
import numpy as np
from scipy.linalg import sqrtm
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models

# ================================
# 1. Device Setup
# ================================
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print("Using device:", device)


# ================================
# 2. Your Custom ResNet Model
#    (the one you trained for film vs iPhone)
# ================================
class CustomResNet(nn.Module):
    def __init__(self):
        super(CustomResNet, self).__init__()
        self.base_model = models.resnet18(pretrained=True)
        num_features = self.base_model.fc.in_features
        self.base_model.fc = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(num_features, 2)  # final layer: 2-class
        )

    def forward(self, x):
        return self.base_model(x)


# ================================
# 3. Feature Extractor
#    (Removes final classification head -> 512-d features)
# ================================
class ResNetFeatureExtractor(nn.Module):
    def __init__(self, ckpt_path):
        super().__init__()
        # Load your trained ResNet
        self.model = CustomResNet()
        # Load the weights you saved; adjust path if needed
        self.model.load_state_dict(torch.load(ckpt_path, map_location=device))
        self.model.eval()
        
        # Replace final classification layers with Identity,
        # so forward() returns the 512-d feature vector.
        self.model.base_model.fc = nn.Identity()

    def forward(self, x):
        return self.model(x)


# ================================
# 4. Single-Folder Dataset
#    (No subfolders required)
# ================================
class SingleImageFolder(Dataset):
    """
    Reads all images from a single folder (no subfolders).
    Returns (image, 0) because we don't need real labels for FID.
    """
    def __init__(self, root, transform=None):
        super().__init__()
        self.root = root
        self.transform = transform
        
        valid_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tiff')
        self.image_paths = [
            os.path.join(root, fname)
            for fname in os.listdir(root)
            if fname.lower().endswith(valid_exts)
        ]
        self.image_paths.sort()

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, 0


# ================================
# 5. ResNet Transforms
#    (224x224, standard ImageNet normalization)
# ================================
resnet_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


# ================================
# 6. Compute Mean/Cov in ResNet Space
# ================================
def compute_statistics_of_folder(folder_path, model, transform, batch_size=32):
    dataset = SingleImageFolder(root=folder_path, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    features = []
    model.eval()
    with torch.no_grad():
        for batch_imgs, _ in dataloader:
            batch_imgs = batch_imgs.to(device)
            preds = model(batch_imgs)  # shape: [batch_size, 512]
            features.append(preds.cpu().numpy())

    features = np.concatenate(features, axis=0)
    mu = np.mean(features, axis=0)
    sigma = np.cov(features, rowvar=False)
    return mu, sigma


# ================================
# 7. "FID-like" Distance
# ================================
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2):
    diff = mu1 - mu2
    diff_squared = diff.dot(diff)
    
    covmean, _ = sqrtm(sigma1.dot(sigma2), disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid_value = diff_squared + np.trace(sigma1 + sigma2 - 2 * covmean)
    return fid_value

def compute_fid_resnet(folder1, folder2, model, transform, batch_size=32):
    mu1, sigma1 = compute_statistics_of_folder(folder1, model, transform, batch_size)
    mu2, sigma2 = compute_statistics_of_folder(folder2, model, transform, batch_size)
    fid_val = calculate_frechet_distance(mu1, sigma1, mu2, sigma2)
    return fid_val


# ================================
# 8. Instantiate the Model
#    (Point to your trained weights)
# ================================
ckpt_path = "/Users/yahyarahhawi/Developer/Film/best_model.pth"  # Adjust if needed
resnet_feature_extractor = ResNetFeatureExtractor(ckpt_path=ckpt_path).to(device)
resnet_feature_extractor.eval()


# ================================
# 9. Compare Various Folders
# ================================
# Modify these paths to match your folder structure
path_cyclegan    = "/Users/yahyarahhawi/Developer/Film/real_vs_fake/fake_cinestill_cyclegan"
path_real_film         = "/Users/yahyarahhawi/Developer/Film/evaluation/real_film"
path_real_iphone       = "/Users/yahyarahhawi/Developer/Film/real_vs_fake/real_iphone"
path_diffusion = "/Users/yahyarahhawi/Developer/Film/real_vs_fake/fake_cinestill_diffusion"
path_validation_real_film = "/Users/yahyarahhawi/Developer/Film/real_vs_fake/validation_real_film"

# 1) Fake Cinestill vs Real Film
fid_fake_vs_film = compute_fid_resnet(
    folder1=path_cyclegan,
    folder2=path_real_film,
    model=resnet_feature_extractor,
    transform=resnet_transforms,
    batch_size=32
)
print("ResNet-based FID (cyclegan film vs Real Film):", fid_fake_vs_film)

# 2) iPhone vs Real Film
fid_iphone_vs_film = compute_fid_resnet(
    folder1=path_real_iphone,
    folder2=path_real_film,
    model=resnet_feature_extractor,
    transform=resnet_transforms,
    batch_size=32
)
print("ResNet-based FID (iPhone vs Real Film):", fid_iphone_vs_film)

# 3) Diffusion vs Real Film
fid_diffusion_vs_film = compute_fid_resnet(
    folder1=path_diffusion,
    folder2=path_real_film,
    model=resnet_feature_extractor,
    transform=resnet_transforms,
    batch_size=32
)
print("ResNet-based FID (Diffusion film vs Real Film):", fid_diffusion_vs_film)

# 3) validation real film vs Real Film
fid_validation_vs_film = compute_fid_resnet(
    folder1=path_validation_real_film,
    folder2=path_real_film,
    model=resnet_feature_extractor,
    transform=resnet_transforms,
    batch_size=32
)
print("ResNet-based FID (validation real film vs Real Film):", fid_validation_vs_film)

Using device: mps


  self.model.load_state_dict(torch.load(ckpt_path, map_location=device))


ResNet-based FID (cyclegan film vs Real Film): 463.0740723128418
ResNet-based FID (iPhone vs Real Film): 789.1210473913691
ResNet-based FID (Diffusion film vs Real Film): 299.6404946109028
ResNet-based FID (validation real film vs Real Film): 346.456256941291
