In [1]:
from pathlib import Path

import os

# Root folder of the dataset. HistogramDataset expects every entry below it to
# be a directory holding 'original.jpg' plus numbered edits ('0.jpg', ...).
# The default is the original author's local mount; set HISTOGRAM_DATA_PATH to
# run the notebook on another machine without editing this cell.
DATA_PATH = Path(os.environ.get('HISTOGRAM_DATA_PATH', '/mnt/wsl/PHYSICALDRIVE1/data/unsplash/edited'))
# Sorted so the entry order (and hence the seeded train/test split) is stable.
DATA = sorted(DATA_PATH.glob('*'))
BINS = 32             # value passed as `bins` to compute_histogram
NUM_EPOCHS = 10       # passes over the training split
BATCH_SIZE = 16       # training batch size
LEARNING_RATE = 0.001  # initial Adam learning rate (decayed by the scheduler)


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class HistogramRestorationNet(nn.Module):
    """3D convolutional encoder/decoder with additive skip connections.

    The input is a single-channel volume of shape (N, 1, D, H, W). Three
    stride-2 convolutions shrink each spatial axis and three stride-2
    transposed convolutions double it again, so spatial sizes that are
    multiples of 8 (e.g. 32) come back out at their original size. The final
    activation is a sigmoid, so every output value lies in [0, 1].
    """

    def __init__(self):
        super().__init__()

        # Shared hyper-parameters for the three flavours of 3x3x3 layer.
        keep_size = dict(kernel_size=3, stride=1, padding=1)
        halve_size = dict(kernel_size=3, stride=2, padding=1)
        double_size = dict(kernel_size=3, stride=2, padding=1, output_padding=1)

        # Contracting path: 1 -> 16 -> 32 -> 64 channels.
        self.enc_conv1 = nn.Conv3d(1, 16, **keep_size)
        self.enc_conv2 = nn.Conv3d(16, 32, **halve_size)
        self.enc_conv3 = nn.Conv3d(32, 64, **halve_size)

        # Narrowest point of the network: 64 -> 128 channels.
        self.bottleneck_conv = nn.Conv3d(64, 128, **halve_size)

        # Expanding path: 128 -> 64 -> 32 -> 16 channels, then project to 1.
        self.dec_conv1 = nn.ConvTranspose3d(128, 64, **double_size)
        self.dec_conv2 = nn.ConvTranspose3d(64, 32, **double_size)
        self.dec_conv3 = nn.ConvTranspose3d(32, 16, **double_size)
        self.final_conv = nn.Conv3d(16, 1, **keep_size)

        # 1x1x1 convolutions applied to encoder features before they are
        # added onto the decoder features of the same channel count.
        self.match_conv3_to_dec1 = nn.Conv3d(64, 64, kernel_size=1)
        self.match_conv2_to_dec2 = nn.Conv3d(32, 32, kernel_size=1)

    def forward(self, x):
        """Run the network on `x` and return a tensor squashed into [0, 1]."""
        # Encoder: keep the intermediate activations needed for the skips.
        feat_16 = F.relu(self.enc_conv1(x))
        feat_32 = F.relu(self.enc_conv2(feat_16))
        feat_64 = F.relu(self.enc_conv3(feat_32))

        latent = F.relu(self.bottleneck_conv(feat_64))

        # Decoder stage 1, merged with the projected 64-channel encoder map.
        up_64 = F.relu(self.dec_conv1(latent)) + self.match_conv3_to_dec1(feat_64)

        # Decoder stage 2, merged with the projected 32-channel encoder map.
        up_32 = F.relu(self.dec_conv2(up_64)) + self.match_conv2_to_dec2(feat_32)

        # Decoder stage 3 has no skip connection.
        up_16 = F.relu(self.dec_conv3(up_32))

        # Collapse back to a single channel and bound the values.
        return torch.sigmoid(self.final_conv(up_16))


In [3]:
from torch.utils.data import Dataset, DataLoader, random_split
from typing import Generator, Tuple, List
from editor.utils import compute_histogram
from PIL import Image
from tqdm import tqdm


class HistogramDataset(Dataset):
    """Pairs of (edited, original) image histograms read from photo folders.

    Each entry of `paths` is expected to be a directory containing
    'original.jpg' and the edits '0.jpg' ... '<expected_edit_count - 1>.jpg'.
    All usable (original, edit) path pairs are collected once at construction
    time; the histograms themselves are computed lazily in `__getitem__`.

    Args:
        paths: Candidate photo directories.
        expected_edit_count: Number of edited variants each directory must
            hold. Directories whose '*.jpg' count differs from
            `expected_edit_count + 1` are skipped entirely.
    """

    def __init__(self, paths: List[Path], expected_edit_count: int = 5):
        self._paths = paths
        self._expected_edit_count = expected_edit_count
        self._pairs = list(self._get_pairs())

    @staticmethod
    def _can_open(image_path: Path) -> bool:
        """Return True if PIL can open `image_path`; report and return False otherwise."""
        try:
            # The context manager releases the file handle straight away
            # instead of leaving it to garbage collection.
            with Image.open(image_path):
                return True
        except Exception:
            # Deliberately not a bare `except:` so that KeyboardInterrupt /
            # SystemExit still stop the (potentially very long) scan.
            print(f'Failed to open {image_path}')
            return False

    def _get_pairs(self) -> Generator[Tuple[Path, Path], None, None]:
        """Yield (original_path, edited_path) for every usable edit on disk."""
        for path in tqdm(self._paths):
            if len(list(path.glob('*.jpg'))) != self._expected_edit_count + 1:
                continue

            original_path = path / 'original.jpg'
            if not self._can_open(original_path):
                continue

            for i in range(self._expected_edit_count):
                edited_path = path / f'{i}.jpg'
                if not self._can_open(edited_path):
                    # Stop at the first unreadable edit of this photo; pairs
                    # already yielded for lower indices are kept.
                    break
                yield original_path, edited_path

    def __len__(self):
        """Number of (original, edit) pairs found during construction."""
        return len(self._pairs)

    def __getitem__(self, idx):
        """Return `(edited, original)` histograms as float tensors.

        Both histograms are computed with the notebook-level `BINS` setting and
        get a leading dimension of size 1 added via `unsqueeze(0)`.
        NOTE(review): presumably `compute_histogram` returns a 3D colour
        histogram, making each tensor (1, BINS, BINS, BINS) - confirm against
        editor.utils.
        """
        original, edited = self._pairs[idx]
        original_histogram = compute_histogram(original, bins=BINS, normalize=True)
        edited_histogram = compute_histogram(edited, bins=BINS, normalize=True)
        return (
            torch.tensor(edited_histogram, dtype=torch.float).unsqueeze(0),
            torch.tensor(original_histogram, dtype=torch.float).unsqueeze(0)
        )


import os

# Scan the photo folders once and build the list of usable image pairs.
dataset = HistogramDataset(DATA)
total_size = len(dataset)
train_size = int(0.8 * total_size)  # 80 % train / 20 % test
test_size = total_size - train_size
if train_size == 0:
    # Fail here with a readable message rather than later inside the
    # DataLoader sampler (an empty shuffled split raises an obscure error).
    raise RuntimeError(
        f'Found only {total_size} usable image pair(s) under {DATA_PATH}; '
        'check that the data directory exists and is populated.'
    )

# A fixed seed keeps the split identical between kernel restarts.
generator = torch.Generator().manual_seed(42)
train_dataset, test_dataset = random_split(dataset, [train_size, test_size], generator=generator)

# Never spawn more workers than the machine has CPUs (upper bound: the
# original setting of 32).
NUM_WORKERS = min(32, os.cpu_count() or 1)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
# batch_size=1 so the inspection cells below can look at one sample at a time.
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=NUM_WORKERS)

Failed to open /mnt/wsl/PHYSICALDRIVE1/data/unsplash/edited/0LnXtS8DUZI/3.jpg
Failed to open /mnt/wsl/PHYSICALDRIVE1/data/unsplash/edited/0OqCRbwWu6g/1.jpg


In [4]:
import torch
from torch.optim import Adam
from tqdm import tqdm
import matplotlib.pyplot as plt
from IPython.display import clear_output

# Use the first GPU when present, otherwise fall back to the CPU so the cell
# still runs on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = HistogramRestorationNet().to(device)
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
# Halve the learning rate after every epoch.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
loss_function = torch.nn.SmoothL1Loss()

# Per-batch training loss over the whole run (plotted after every epoch).
losses = []

for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0
    for edited_histogram, original_histogram in tqdm(train_dataloader):
        edited_histogram = edited_histogram.to(device)
        original_histogram = original_histogram.to(device)

        optimizer.zero_grad()

        # The network learns to map the edited histogram back to the original.
        predicted_original = model(edited_histogram)
        loss = loss_function(predicted_original, original_histogram)

        loss.backward()
        optimizer.step()

        # .item() forces a device sync, so read the scalar only once.
        batch_loss = loss.item()
        total_loss += batch_loss
        losses.append(batch_loss)

    scheduler.step()

    # Replace the previous epoch's output with a fresh summary and plot.
    clear_output(wait=True)
    mean_loss = total_loss / max(len(train_dataloader), 1)
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS} - mean loss: {mean_loss:.6f}")
    print(f"LR: {scheduler.get_last_lr()}")
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(losses, label='Training Loss')
    ax.set_xlabel('Batch Number')
    ax.set_ylabel('Loss')
    ax.set_yscale('log')
    ax.set_title(f'Loss per Batch - Epoch {epoch + 1}')
    ax.legend()
    plt.show()
    # Release the figure so one figure per epoch does not pile up in memory.
    plt.close(fig)

  2%|▏         | 150/6243 [00:07<03:05, 32.93it/s]

In [None]:
# Persist only the learned parameters (the state dict), not the module object;
# the next cell rebuilds the network and loads these weights back in.
trained_weights = model.state_dict()
torch.save(trained_weights, 'model_state_dict.pth')

In [None]:
# Use the first GPU when present, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = HistogramRestorationNet().to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint written
# on a GPU machine can still be loaded on a CPU-only one.
# NOTE(review): torch.load unpickles the file - only load checkpoints you
# trust (consider weights_only=True on PyTorch versions that support it).
model.load_state_dict(torch.load('model_state_dict.pth', map_location=device))
# Switch to inference mode; as the last expression this also displays the model.
model.eval()

HistogramRestorationNet(
  (enc_conv1): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (enc_conv2): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
  (enc_conv3): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
  (bottleneck_conv): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
  (dec_conv1): ConvTranspose3d(128, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), output_padding=(1, 1, 1))
  (dec_conv2): ConvTranspose3d(64, 32, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), output_padding=(1, 1, 1))
  (dec_conv3): ConvTranspose3d(32, 16, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), output_padding=(1, 1, 1))
  (final_conv): Conv3d(16, 1, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
)

In [None]:
# Create one iterator over the test DataLoader. It lives in its own cell so the
# cell that calls next(loader) can be re-run to step through successive samples.
loader = iter(test_dataloader)

In [None]:
from editor.ploting import plot_histograms


# Pull the next test sample (the test loader uses batch_size=1).
edited_histogram, original_histogram = next(loader)
edited_histogram = edited_histogram.to(device)
original_histogram = original_histogram.to(device)
# Inference only: disable autograd so no computation graph is built or kept.
with torch.no_grad():
    predicted_original = model(edited_histogram)
# Plot ground truth, network input and network output, in that order.
# No .detach() is needed because the prediction was made under no_grad.
plot_histograms([
    original_histogram.cpu().numpy().squeeze(),
    edited_histogram.cpu().numpy().squeeze(),
    predicted_original.cpu().numpy().squeeze()
])
