In [13]:
import os
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

class PhotoDataset(Dataset):
    """Dataset that yields one digital image and one film image per index.

    Image files in each folder are filtered by extension (case-insensitive),
    restricted to regular files, and sorted by name. Item ``idx`` combines the
    ``idx``-th digital file with the ``idx``-th film file. The pairing is purely
    positional; presumably the two folders are unpaired (CycleGAN layout), so no
    correspondence between the two images should be assumed.

    Args:
        digital_path: directory holding the digital photos.
        film_path: directory holding the film photos.
        transform: optional callable applied to each RGB PIL image of both domains.
    """

    VALID_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.gif', '.webp')

    def __init__(self, digital_path, film_path, transform=None):
        self.digital_path = digital_path
        self.film_path = film_path
        self.transform = transform
        self.digital_images = self._list_images(digital_path)
        self.film_images = self._list_images(film_path)
        # Match the shorter folder; surplus files in the longer one are never used.
        self.length = min(len(self.digital_images), len(self.film_images))

    @classmethod
    def _list_images(cls, folder):
        """Return the sorted names of regular files in ``folder`` with an image extension."""
        return [
            name
            for name in sorted(os.listdir(folder))
            if name.lower().endswith(cls.VALID_EXTENSIONS)
            and os.path.isfile(os.path.join(folder, name))
        ]

    @staticmethod
    def _load_rgb(path):
        """Open ``path``, convert it to RGB and release the file handle before returning."""
        with Image.open(path) as img:
            # convert() materialises a new image, so it stays valid after the file is closed
            return img.convert("RGB")

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Resolve negative indices against the dataset length (not each list's own
        # length) so both domains refer to the same position; reject anything else
        # before touching the file system.
        if idx < 0:
            idx += self.length
        if not 0 <= idx < self.length:
            raise IndexError(f"index out of range for dataset of length {self.length}")
        digital_img = self._load_rgb(os.path.join(self.digital_path, self.digital_images[idx]))
        film_img = self._load_rgb(os.path.join(self.film_path, self.film_images[idx]))
        if self.transform:
            digital_img = self.transform(digital_img)
            film_img = self.transform(film_img)
        return {"digital": digital_img, "film": film_img}

# Root of the dataset: trainA holds the digital photos, trainB the film photos.
# Set the FILM_DATA_ROOT environment variable to point elsewhere instead of editing this cell.
DATA_ROOT = os.environ.get(
    "FILM_DATA_ROOT",
    "/Users/yahyarahhawi/Developer/Film/pytorch-CycleGAN-and-pix2pix/datasets/film_transfer",
)

# Resize to 256x256, convert to a [0, 1] tensor, then map each channel to [-1, 1]
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

dataset = PhotoDataset(
    os.path.join(DATA_ROOT, "trainA"),
    os.path.join(DATA_ROOT, "trainB"),
    transform=transform,
)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

In [18]:
import torch
import torch.nn as nn
from torchvision.models import vgg19

class PerceptualLoss(nn.Module):
    """Weighted sum of MSE distances between VGG19 feature maps of two tensors.

    Both arguments of ``forward`` are pushed through the frozen ``features``
    stack of an ImageNet-pretrained torchvision VGG19. At every layer whose index
    (compared as a string) is listed in ``layers``, the MSE between the two
    activations is added, scaled by the matching entry of ``weights``. Layers
    deeper than the deepest requested index are skipped because they cannot
    contribute to the loss.

    Args:
        layers: indices into ``vgg19().features``. The default ``['3', '8', '17']``
            selects the ReLU outputs relu1_2, relu2_2 and relu3_4.
        weights: one multiplier per entry of ``layers`` (extra entries are ignored).

    Raises:
        ValueError: if ``weights`` has fewer entries than ``layers``.

    NOTE(review): inputs reach VGG as-is; no ImageNet mean/std normalization
    happens inside this module.
    NOTE(review): this class is defined twice in the notebook and the later
    definition shadows this one; keep the copies identical or drop one.
    """

    def __init__(self, layers=None, weights=None):
        super(PerceptualLoss, self).__init__()
        # Default taps: relu1_2, relu2_2, relu3_4 of VGG19 ``features``
        self.layers = layers or ['3', '8', '17']
        self.weights = weights or [1.0, 0.5, 0.2]
        # Fail fast, before the (slow) VGG load, instead of an IndexError mid-forward
        if len(self.weights) < len(self.layers):
            raise ValueError(
                f"need one weight per layer: got {len(self.layers)} layers "
                f"and {len(self.weights)} weights"
            )

        try:
            # torchvision >= 0.13: ``pretrained=`` is deprecated; IMAGENET1K_V1 is
            # the checkpoint that ``pretrained=True`` resolves to.
            from torchvision.models import VGG19_Weights
            backbone = vgg19(weights=VGG19_Weights.IMAGENET1K_V1)
        except ImportError:
            backbone = vgg19(pretrained=True)
        self.vgg = backbone.features.eval()
        for param in self.vgg.parameters():
            param.requires_grad = False

    def forward(self, input, target):
        """Return the weighted feature-space MSE between ``input`` and ``target``."""
        # layer index (as str) -> weight; first occurrence wins, like list.index did
        layer_weights = {}
        for name, weight in zip(self.layers, self.weights):
            layer_weights.setdefault(str(name), weight)
        # Deepest requested layer; running VGG beyond it is wasted work
        deepest = max((int(key) for key in layer_weights if key.isdigit()), default=-1)

        loss = 0
        input_features = input
        target_features = target
        for i, layer in enumerate(self.vgg):
            if i > deepest:
                break
            input_features = layer(input_features)
            target_features = layer(target_features)
            if str(i) in layer_weights:
                loss += layer_weights[str(i)] * nn.functional.mse_loss(input_features, target_features)

        return loss

In [14]:
import torch
import torch.nn as nn

class ResNetBlock(nn.Module):
    """Residual block computing ``x + F(x)``, where F = conv3x3 -> BN -> ReLU -> conv3x3 -> BN.

    Both convolutions keep the channel count and (with padding=1) the spatial
    size, so the skip connection is a plain element-wise addition.

    Args:
        channels: number of input and output channels.
    """

    def __init__(self, channels):
        super().__init__()
        # The attribute name and the layer order define the state_dict keys
        # (conv_block.0.weight, ...); keep them stable so saved checkpoints load.
        branch = [
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
        ]
        self.conv_block = nn.Sequential(*branch)

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual

class ResNetGenerator(nn.Module):
    """ResNet encoder/decoder; in this notebook it is used to predict the score from a noisy image.

    Pipeline: 7x7 conv stem -> two stride-2 convs (channels doubled each time) ->
    ``num_blocks`` ResNetBlocks at ``4 * base_channels`` -> two stride-2 transposed
    convs back to ``base_channels`` -> 7x7 conv + Tanh, so outputs lie in [-1, 1].
    Height and width should be divisible by 4 for the output size to equal the
    input size.

    Args:
        input_nc: channels of the input tensor.
        output_nc: channels of the output tensor.
        num_blocks: number of residual blocks in the bottleneck.
        base_channels: width of the stem convolution.
    """

    def __init__(self, input_nc, output_nc, num_blocks=6, base_channels=64):
        super().__init__()
        # Attribute names and the layer order inside each Sequential fix the
        # state_dict keys; keep them unchanged so existing checkpoints still load.
        self.initial = nn.Sequential(
            nn.Conv2d(input_nc, base_channels, kernel_size=7, padding=3),
            nn.BatchNorm2d(base_channels),
            nn.ReLU(inplace=True),
        )

        # Encoder: base -> 2*base -> 4*base, halving resolution each step
        down_layers = []
        for mult in (1, 2):
            in_ch, out_ch = base_channels * mult, base_channels * mult * 2
            down_layers += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        self.downsampling = nn.Sequential(*down_layers)

        bottleneck = base_channels * 4
        self.res_blocks = nn.Sequential(*(ResNetBlock(bottleneck) for _ in range(num_blocks)))

        # Decoder: 4*base -> 2*base -> base, doubling resolution each step
        up_layers = []
        for mult in (4, 2):
            in_ch, out_ch = base_channels * mult, base_channels * mult // 2
            up_layers += [
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        self.upsampling = nn.Sequential(*up_layers)

        self.output_layer = nn.Sequential(
            nn.Conv2d(base_channels, output_nc, kernel_size=7, padding=3),
            nn.Tanh(),
        )

    def forward(self, x):
        for stage in (self.initial, self.downsampling, self.res_blocks, self.upsampling, self.output_layer):
            x = stage(x)
        return x

In [15]:
class SDE:
    """Variance-preserving SDE with a linear noise schedule.

    dx = -1/2 * beta(t) * x dt + sqrt(beta(t)) dw,  t in [0, T],
    where beta(t) rises linearly from ``beta_start`` at t=0 to ``beta_end`` at t=T.

    Args:
        beta_start: beta(0).
        beta_end: beta(T).
        T: final time of the process.
    """

    def __init__(self, beta_start=0.1, beta_end=20.0, T=1.0):
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.T = T

    def beta(self, t):
        """Noise rate beta(t); accepts Python numbers or tensors (element-wise)."""
        return self.beta_start + t * (self.beta_end - self.beta_start) / self.T

    def drift(self, x, t):
        """Drift f(x, t) = -1/2 * beta(t) * x."""
        return -0.5 * self.beta(t) * x

    def diffusion(self, t):
        """Diffusion coefficient g(t) = sqrt(beta(t)), always returned as a tensor.

        ``torch.sqrt`` only accepts tensors, so a Python-number ``t`` is wrapped
        with ``torch.as_tensor`` (a no-op for tensors, which keep device/dtype).
        """
        return torch.sqrt(torch.as_tensor(self.beta(t)))

    def marginal_prob(self, x0, t):
        """Mean and std of x(t) given x(0) = ``x0`` for this SDE.

        With B(t) = integral_0^t beta(s) ds = beta_start*t + (beta_end-beta_start)*t^2/(2T):
        mean = exp(-B(t)/2) * x0 and std = sqrt(1 - exp(-B(t))).
        ``t`` must broadcast against ``x0`` (e.g. shape (N, 1, 1, 1) for images).
        """
        t = torch.as_tensor(t, dtype=x0.dtype, device=x0.device)
        integral = self.beta_start * t + 0.5 * (self.beta_end - self.beta_start) * t ** 2 / self.T
        mean = torch.exp(-0.5 * integral) * x0
        std = torch.sqrt(1.0 - torch.exp(-integral))
        return mean, std

In [20]:
import os
import torch
import torch.nn as nn
from torchvision.models import vgg19
from torchvision.transforms import Normalize
from tqdm import tqdm

# Set up device: Apple-silicon GPU (MPS) when available, otherwise CPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Seed torch's global RNG so weight init, noise sampling and DataLoader shuffling
# repeat across runs (some MPS kernels may still be nondeterministic).
SEED = 0
torch.manual_seed(SEED)

# Initialize the model
model = ResNetGenerator(input_nc=3, output_nc=3, num_blocks=9).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Define the Perceptual Loss class
class PerceptualLoss(nn.Module):
    """Weighted sum of MSE distances between VGG19 feature maps of two tensors.

    Both arguments of ``forward`` are pushed through the frozen ``features``
    stack of an ImageNet-pretrained torchvision VGG19. At every layer whose index
    (compared as a string) is listed in ``layers``, the MSE between the two
    activations is added, scaled by the matching entry of ``weights``. Layers
    deeper than the deepest requested index are skipped because they cannot
    contribute to the loss.

    Args:
        layers: indices into ``vgg19().features``. The default ``['3', '8', '17']``
            selects the ReLU outputs relu1_2, relu2_2 and relu3_4.
        weights: one multiplier per entry of ``layers`` (extra entries are ignored).

    Raises:
        ValueError: if ``weights`` has fewer entries than ``layers``.

    NOTE(review): inputs reach VGG as-is; no ImageNet mean/std normalization
    happens inside this module.
    NOTE(review): this class is defined twice in the notebook; this later
    definition shadows the earlier one, so keep the copies identical or drop one.
    """

    def __init__(self, layers=None, weights=None):
        super(PerceptualLoss, self).__init__()
        # Default taps: relu1_2, relu2_2, relu3_4 of VGG19 ``features``
        self.layers = layers or ['3', '8', '17']
        self.weights = weights or [1.0, 0.5, 0.2]  # Layer importance weights
        # Fail fast, before the (slow) VGG load, instead of an IndexError mid-forward
        if len(self.weights) < len(self.layers):
            raise ValueError(
                f"need one weight per layer: got {len(self.layers)} layers "
                f"and {len(self.weights)} weights"
            )

        try:
            # torchvision >= 0.13: ``pretrained=`` is deprecated; IMAGENET1K_V1 is
            # the checkpoint that ``pretrained=True`` resolves to.
            from torchvision.models import VGG19_Weights
            backbone = vgg19(weights=VGG19_Weights.IMAGENET1K_V1)
        except ImportError:
            backbone = vgg19(pretrained=True)
        self.vgg = backbone.features.eval()
        for param in self.vgg.parameters():
            param.requires_grad = False

    def forward(self, input, target):
        """Return the weighted feature-space MSE between ``input`` and ``target``."""
        # layer index (as str) -> weight; first occurrence wins, like list.index did
        layer_weights = {}
        for name, weight in zip(self.layers, self.weights):
            layer_weights.setdefault(str(name), weight)
        # Deepest requested layer; running VGG beyond it is wasted work
        deepest = max((int(key) for key in layer_weights if key.isdigit()), default=-1)

        loss = 0
        input_features = input
        target_features = target
        for i, layer in enumerate(self.vgg):
            if i > deepest:
                break
            input_features = layer(input_features)
            target_features = layer(target_features)
            if str(i) in layer_weights:
                loss += layer_weights[str(i)] * nn.functional.mse_loss(input_features, target_features)

        return loss

# Initialize perceptual loss
criterion = PerceptualLoss().to(device)

# ImageNet mean/std for VGG inputs.
# NOTE(review): `normalize` is never applied below. The criterion receives the raw
# model output and the raw Gaussian noise, neither of which is an ImageNet-style image.
normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# SDE instance (only its diffusion coefficient is used below)
sde = SDE()

# Create directory to save checkpoints
checkpoint_dir = "SD_checkpoint"
os.makedirs(checkpoint_dir, exist_ok=True)

epochs = 100
num_batches = len(dataloader)
if num_batches == 0:
    # Would otherwise surface as a ZeroDivisionError when averaging the epoch loss
    raise ValueError("dataloader yields no batches; check the dataset folders")

# NOTE(review): open questions about the objective, to confirm with the author:
#   - the "film" half of each batch is never used, so no film style is learned here;
#   - x + sqrt(beta(t)) * z is not the marginal of a variance-preserving SDE with
#     this drift (that would also shrink x by exp(-1/2 * integral of beta));
#   - the model never receives `t`, and its Tanh output cannot reach noise values
#     outside [-1, 1];
#   - the usual denoising objective is a plain MSE to `noise`, not VGG features of noise.
for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    for data in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
        digital = data["digital"].to(device)

        # Add noise to digital images (forward SDE). randn_like already allocates
        # on digital's device, so no extra .to(device) is needed.
        noise = torch.randn_like(digital)
        # One random time per sample, broadcast over channels and pixels
        t = torch.rand(digital.size(0), 1, 1, 1, device=device)
        noisy_digital = digital + sde.diffusion(t) * noise

        # Predict the score
        predicted_score = model(noisy_digital)
        assert predicted_score.requires_grad, "Model output does not have requires_grad=True"

        # Compute perceptual loss
        loss = criterion(predicted_score, noise)
        assert loss.grad_fn, "Loss does not have grad_fn, check the computation graph."

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / num_batches
    print(f"Epoch {epoch+1}: Loss = {avg_loss:.4f}")

    # Save model weights only (a state_dict), the format the inference cell loads
    checkpoint_path = os.path.join(checkpoint_dir, f"model_epoch_{epoch+1}.pth")
    torch.save(model.state_dict(), checkpoint_path)
    print(f"Model saved at {checkpoint_path}")

Epoch 1/100: 100%|██████████| 56/56 [01:04<00:00,  1.14s/it]


Epoch 1: Loss = 9.2605
Model saved at SD_checkpoint/model_epoch_1.pth


Epoch 2/100: 100%|██████████| 56/56 [01:07<00:00,  1.21s/it]


Epoch 2: Loss = 4.9142
Model saved at SD_checkpoint/model_epoch_2.pth


Epoch 3/100: 100%|██████████| 56/56 [01:06<00:00,  1.19s/it]


Epoch 3: Loss = 3.7223
Model saved at SD_checkpoint/model_epoch_3.pth


Epoch 4/100: 100%|██████████| 56/56 [01:06<00:00,  1.18s/it]


Epoch 4: Loss = 3.1638
Model saved at SD_checkpoint/model_epoch_4.pth


Epoch 5/100:  16%|█▌        | 9/56 [00:11<00:58,  1.24s/it]


KeyboardInterrupt: 

In [22]:
import torch
from torchvision.transforms import ToTensor, ToPILImage
from PIL import Image
import os

from torchvision.transforms import Compose, Normalize

# Path to the saved model checkpoint
checkpoint_path = "/Users/yahyarahhawi/Developer/Film/SD_checkpoint/model_epoch_4.pth"

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Rebuild the architecture with the same hyper-parameters used for training
model = ResNetGenerator(input_nc=3, output_nc=3, num_blocks=9)
# weights_only=True restricts unpickling to tensors/containers (no arbitrary code),
# which is all a state_dict needs; it also silences torch.load's FutureWarning.
state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()  # BatchNorm uses running statistics

# Load and preprocess the input image
input_image_path = "/Users/yahyarahhawi/Developer/Film/pytorch-CycleGAN-and-pix2pix/datasets/film_transfer/trainA/434148921_1215316256115590_8824329964826277346_n.jpg"  # Replace with your input image path
output_image_path = "sdetest.jpg"  # Replace with where you want to save the output

# Match the training value range: ToTensor -> [0, 1], Normalize(0.5, 0.5) -> [-1, 1].
# No Resize here: the fully convolutional model runs at native resolution
# (height/width divisible by 4 keep the output the same size as the input).
# Named `preprocess` so the training cell's global `transform` is not overwritten.
preprocess = Compose([
    ToTensor(),
    Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

with Image.open(input_image_path) as img:
    input_image = img.convert("RGB")
input_tensor = preprocess(input_image).unsqueeze(0).to(device)  # Add batch dimension

# Perform inference
# NOTE(review): the model was trained to predict the added noise, so this output is a
# noise/score estimate rather than a film-styled photo; confirm this is what is intended.
with torch.no_grad():
    output_tensor = model(input_tensor)

# Map the Tanh range [-1, 1] back to [0, 1]; clamp guards against numerical overshoot
output_tensor = ((output_tensor.squeeze(0).cpu() + 1) / 2).clamp(0, 1)
output_image = ToPILImage()(output_tensor)

# Save the output image
output_image.save(output_image_path)
print(f"Output image saved at {output_image_path}")

  model.load_state_dict(torch.load(checkpoint_path, map_location=device))


Output image saved at sdetest.jpg
