In [1]:
import torch
from pathlib import Path

# Image files that make up the dataset.
# `Path.glob` returns a one-shot generator: once something has iterated it,
# any later use (for example re-running the dataset cell on its own) would see
# no files at all. Materialising it into a list makes the cell's result
# reusable. `list` keeps the order in which glob yields the paths.
DATA = list(Path('/mnt/wsl/PHYSICALDRIVE1/data/unsplash').glob('*.jpg'))
# Directory handed to HistogramDataset as `cache_path`; created if missing.
CACHE_PATH = Path('/mnt/wsl/PHYSICALDRIVE1/data/cache2')
CACHE_PATH.mkdir(exist_ok=True, parents=True)

# Hyperparameters (all referenced by later cells).
BINS = 32              # passed to HistogramDataset as `bin_count`
NUM_EPOCHS = 10        # number of passes of the training loop
BATCH_SIZE = 32        # batch size of the training DataLoader
LEARNING_RATE = 0.005  # initial learning rate given to Adam
SCHEDULER_GAMMA = 0.2  # `gamma` of the StepLR scheduler
EDIT_COUNT = 25        # passed to HistogramDataset as `edit_count`
LOSS_DAMPING = 2       # passed to ProgressivePoolingLoss as `damping`

# Per-epoch checkpoints are written here (relative to the working directory).
MODELS_PATH = Path('models')
MODELS_PATH.mkdir(exist_ok=True, parents=True)

# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
f'Using device `{device}`'

'Using device `cuda:0`'

In [2]:
from torch.utils.data import DataLoader, random_split
from editor.training import HistogramDataset

# Build the full dataset from the image listing and the configured
# histogram / edit settings.
dataset = HistogramDataset(
    DATA,
    edit_count=EDIT_COUNT,
    bin_count=BINS,
    delete_corrupt_images=False,
    cache_path=CACHE_PATH,
)

# 90 % / 10 % train/test split. The generator is seeded so that the split is
# the same on every run over the same dataset ordering.
total_size = len(dataset)
train_size = int(0.9 * total_size)
test_size = total_size - train_size
train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(42),
)

# Training batches are shuffled; the test loader yields single samples in a
# fixed order.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=32,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=32,
)

f'Loaded {len(train_dataset)} training images and {len(test_dataset)} test images'

'Loaded 561982 training images and 62443 test images'

In [3]:
import torch.nn as nn
import torch.nn.functional as F

class HistogramRestorationNet(nn.Module):
    """3D-convolutional network that maps a 3D histogram to a normalised 3D histogram.

    Input:  tensor of shape ``(batch, 1, bins, bins, bins)``.
    Output: tensor of shape ``(batch, bins, bins, bins)`` whose entries sum to 1
    per sample (see the note on normalisation in :meth:`forward`).

    The network applies three conv + batch-norm + ReLU + max-pool stages
    (16, 32 and 64 channels). The second and third stages each get a projected
    shortcut (1x1x1 conv + batch-norm + max-pool) added to their output. Two
    fully connected layers then expand the features back to ``bins ** 3`` values.

    Args:
        bins: Number of histogram bins per axis. Must be a positive multiple
            of 8, because the three pooling stages each halve every axis.
            Defaults to 32, which reproduces the original fixed-size network
            (identical layer shapes and ``state_dict`` keys).

    Raises:
        ValueError: If ``bins`` is not a positive multiple of 8.
    """

    # A per-sample sum whose magnitude is below this value is replaced by this
    # value before dividing, so a (near-)zero sum cannot produce NaN/inf.
    _NORMALIZATION_EPS = 1e-8

    def __init__(self, bins=32):
        super().__init__()
        if bins <= 0 or bins % 8 != 0:
            raise ValueError(f'bins must be a positive multiple of 8, got {bins}')
        self.bins = bins

        self.conv1 = nn.Conv3d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(16)
        self.conv2 = nn.Conv3d(16, 32, 3, padding=1)
        self.bn2 = nn.BatchNorm3d(32)
        self.conv3 = nn.Conv3d(32, 64, 3, padding=1)
        self.bn3 = nn.BatchNorm3d(64)

        # Shortcut branches: a 1x1x1 convolution matches the channel count and
        # a max-pool matches the spatial size of the main branch's output.
        self.res1 = nn.Sequential(
            nn.Conv3d(16, 32, 1, stride=1, padding=0),  # Match channels
            nn.BatchNorm3d(32),
            nn.MaxPool3d(2)  # Downsample to match size
        )
        self.res2 = nn.Sequential(
            nn.Conv3d(32, 64, 1, stride=1, padding=0),  # Match channels
            nn.BatchNorm3d(64),
            nn.MaxPool3d(2)  # Downsample to match size
        )

        # Three 2x poolings leave bins // 8 cells per axis (4 for bins == 32).
        reduced = bins // 8
        self.fc1 = nn.Linear(64 * reduced * reduced * reduced, 512)
        self.fc_bn1 = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(512, bins * bins * bins)
        self.apply(HistogramRestorationNet._init_weights_he)

    @staticmethod
    def _init_weights_he(m):
        """Kaiming-uniform (ReLU) weights and zero biases for Linear and Conv3d modules."""
        if isinstance(m, (nn.Linear, nn.Conv3d)):
            torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)

    def forward(self, x):
        """Run the network on ``x`` of shape ``(batch, 1, bins, bins, bins)``.

        Returns a tensor of shape ``(batch, bins, bins, bins)`` divided by its
        per-sample sum.
        """
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool3d(x, 2)

        # First shortcut: computed from the stage input, added to the stage output.
        res = self.res1(x)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool3d(x, 2)
        x = x + res  # out-of-place so autograd never sees a modified buffer

        # Second shortcut.
        res = self.res2(x)
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.max_pool3d(x, 2)
        x = x + res

        # Flatten for the fully connected layers.
        x = x.view(x.size(0), -1)

        x = F.relu(self.fc_bn1(self.fc1(x)))
        x = self.fc2(x)

        # Reshape back to the histogram shape and normalise each sample by its sum.
        x = x.view(-1, self.bins, self.bins, self.bins)
        total = x.sum(dim=(1, 2, 3), keepdim=True)
        # Only a sum that is (almost) exactly zero is altered; every other
        # case divides by the true sum exactly as before.
        eps = self._NORMALIZATION_EPS
        safe_total = total.masked_fill(total.abs() < eps, eps)
        # NOTE(review): fc2 is linear, so individual outputs (and the sum) can
        # be negative; the result sums to 1 but is not guaranteed to be a
        # non-negative distribution. Consider a softmax/ReLU here if that is
        # required - left unchanged to keep training behaviour identical.
        return x / safe_total
    
# Pull one batch from the training loader (presumably an (edited, original)
# pair, matching how the training loop unpacks batches - TODO confirm).
# NOTE(review): `edited` and `og` are not referenced by any other cell visible
# in this notebook; this looks like a leftover smoke test and costs a full
# DataLoader worker start-up.
edited, og = next(iter(train_dataloader))

In [4]:
import numpy as np
import matplotlib.pyplot as plt


def plot_histograms(original_histogram, edited_histogram, predicted_histogram):
    """Draw three 3D histograms side by side as 3D scatter plots.

    Each histogram cell becomes one marker positioned at its (x, y, z) bin
    index, coloured by that index scaled to the 0-1 range per axis, and sized
    proportionally to the cell's value.

    Args:
        original_histogram: CPU tensor that squeezes to three dimensions.
        edited_histogram: CPU tensor that squeezes to three dimensions.
        predicted_histogram: CPU tensor that squeezes to three dimensions.

    Returns:
        The matplotlib figure holding the three subplots (titled
        ``Tensor 1`` .. ``Tensor 3`` in argument order).
    """
    fig = plt.figure(figsize=(15, 5))
    tensors = [
        histogram.numpy().squeeze()
        for histogram in (original_histogram, edited_histogram, predicted_histogram)
    ]

    for i, tensor in enumerate(tensors, 1):
        ax = fig.add_subplot(1, 3, i, projection='3d')

        x, y, z = np.indices(tensor.shape)
        x = x.flatten()
        y = y.flatten()
        z = z.flatten()
        values = tensor.flatten()

        # Marker area must not be negative: matplotlib takes its square root
        # and emits "invalid value" warnings otherwise. Predicted histograms
        # can contain negative entries, so those are drawn with size 0.
        sizes = values * 5000
        sizes[sizes < 0] = 0

        # Largest bin index per axis (31 for a 32-bin histogram); at least 1
        # so a singleton axis does not divide by zero.
        max_x, max_y, max_z = (max(extent - 1, 1) for extent in tensor.shape)

        # RGB colour per marker, each channel scaled into [0, 1].
        colors = np.vstack((x / max_x, y / max_y, z / max_z)).T

        ax.scatter(x, y, z, c=colors, s=sizes, marker='o', alpha=0.5)

        ax.set_xlim([0, max_x])
        ax.set_ylim([0, max_y])
        ax.set_zlim([0, max_z])

        ax.set_title(f'Tensor {i}')
    return fig


# plot_histograms(original_histogram, edited_histogram, edited_histogram)

In [5]:
from torch.utils.tensorboard import SummaryWriter
from torch.optim import Adam
from tqdm.notebook import tqdm
from torch.nn.utils import clip_grad_norm_
from editor.training import ProgressivePoolingLoss
# from geomloss import SamplesLoss 
# import numpy as np

# Training cell: sets up TensorBoard logging, the model, optimizer, LR
# scheduler and loss, then trains for NUM_EPOCHS epochs, writing one
# checkpoint per epoch into MODELS_PATH.

# SummaryWriter() with no arguments uses its default log directory.
writer = SummaryWriter()
model = HistogramRestorationNet().train().to(device)
# Trace the model once on the first element of a training batch so the
# network graph is available in TensorBoard.
writer.add_graph(model, next(iter(train_dataloader))[0].to(device))
# Record the run's hyperparameters as scalars (no step is given).
writer.add_scalar("Bins", BINS)
writer.add_scalar("Batch size", BATCH_SIZE)
writer.add_scalar("Learning rate", LEARNING_RATE)
writer.add_scalar("Scheduler gamma", SCHEDULER_GAMMA)
writer.add_scalar("Loss damping", LOSS_DAMPING)

optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
# step_size=1 together with the single scheduler.step() per epoch below means
# the learning rate is multiplied by SCHEDULER_GAMMA after every epoch.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=SCHEDULER_GAMMA)
loss_function = ProgressivePoolingLoss(target_sizes=[16, 32], damping=LOSS_DAMPING).to(device)
# Alternative losses that were tried (kept for reference):
# loss_function = torch.nn.KLDivLoss(reduction='batchmean')
# loss_function = SamplesLoss(backend='online')

# Path of the most recently written checkpoint; read by the cell that reloads
# the model. Stays None if no checkpoint gets written.
last_model_path = None
try:
    for epoch in range(NUM_EPOCHS):
        # Running SUM of the per-batch losses for this epoch (not a mean).
        epoch_loss = 0
        writer.add_scalar("Actual learning rate", scheduler.get_last_lr()[0], epoch)
        for batch_id, (edited_histogram, original_histogram) in enumerate(
            tqdm(train_dataloader, desc=f'Epoch {epoch}', unit='batch')
        ):
            edited_histogram = edited_histogram.to(device)
            original_histogram = original_histogram.to(device)

            optimizer.zero_grad()
            predicted_original = model(edited_histogram)
            # Earlier experiments (clamping, point-cloud / KL based losses), kept for reference:
            # predicted_original = torch.clamp(predicted_original, 0.0000000000000000000000001, 1)
            # histogram_points = torch.from_numpy(np.array(np.meshgrid(np.arange(BINS), np.arange(BINS), np.arange(BINS))).T.reshape(-1,3)).to(device)
            # original_weights = original_histogram[:, histogram_points[:, 0], histogram_points[:, 1], histogram_points[:, 2]].float()
            # predicted_weights = predicted_original[:, histogram_points[:, 0], histogram_points[:, 1], histogram_points[:, 2]].float()
            # histogram_points = histogram_points.unsqueeze(0).repeat(BATCH_SIZE, 1, 1)
            # loss = loss_function(original_weights, histogram_points.float(), predicted_weights, histogram_points.float())
            # loss = loss_function(torch.log(predicted_original.unsqueeze(1)), original_histogram)
            # The model output has no channel axis; unsqueeze(1) adds one
            # before the prediction is compared with the target.
            loss = loss_function(predicted_original.unsqueeze(1), original_histogram)
            epoch_loss += loss.item()
            # Global step = batches seen so far across all epochs.
            writer.add_scalar("Loss/train/batch", loss, epoch * len(train_dataloader) + batch_id)
            # Flushed after every batch so the curve can be followed live.
            writer.flush()

            # loss = loss.sum()
            loss.backward()
            # clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

        writer.add_scalar("Loss/train/epoch", epoch_loss, epoch)
        # Qualitative check once per epoch: run the model in eval mode on the
        # first sample of the test loader and log a figure of the original,
        # edited and predicted histograms.
        # NOTE(review): the figure tag says "train" although the sample comes
        # from test_dataloader.
        # NOTE(review): this relies on the three-argument plot_histograms
        # defined in this notebook; a later cell imports a different
        # `plot_histograms` from editor.ploting that is called with a single
        # list, so re-running this cell after that import would hit the other
        # function.
        with torch.no_grad():
            model.eval()
            loader = iter(test_dataloader)
            edited_histogram, original_histogram = next(loader)
            edited_histogram = edited_histogram.to(device)
            original_histogram = original_histogram.to(device)
            predicted_original = model(edited_histogram)
            writer.add_figure("Histograms/train/original", plot_histograms(
                original_histogram.cpu()[0], edited_histogram.cpu()[0], predicted_original.cpu()[0]
            ), epoch)
            model.train()
        # Checkpoint the weights for this epoch, then decay the learning rate.
        last_model_path = MODELS_PATH / f'model-{epoch}.pth'
        torch.save(model.state_dict(), last_model_path)
        scheduler.step()
except KeyboardInterrupt:
    # Manual interruption: keep whatever has been learned so far.
    print('Interrupted, saving last model')
    last_model_path = MODELS_PATH / f'model-final.pth'
    torch.save(model.state_dict(), last_model_path)
finally:
    # Always close the writer, whether training finished, was interrupted or failed.
    writer.close()

Epoch 0:   0%|          | 0/17562 [00:00<?, ?batch/s]

  scale = np.sqrt(self._sizes) * dpi / 72.0 * self._factor


Epoch 1:   0%|          | 0/17562 [00:00<?, ?batch/s]

Epoch 2:   0%|          | 0/17562 [00:00<?, ?batch/s]

Epoch 3:   0%|          | 0/17562 [00:00<?, ?batch/s]

Epoch 4:   0%|          | 0/17562 [00:00<?, ?batch/s]

Epoch 5:   0%|          | 0/17562 [00:00<?, ?batch/s]

Epoch 6:   0%|          | 0/17562 [00:00<?, ?batch/s]

Epoch 7:   0%|          | 0/17562 [00:00<?, ?batch/s]

Epoch 8:   0%|          | 0/17562 [00:00<?, ?batch/s]

Epoch 9:   0%|          | 0/17562 [00:00<?, ?batch/s]

In [6]:
# Rebuild the network and restore the weights of the last checkpoint written
# by the training cell.
model = HistogramRestorationNet().to(device)
# map_location remaps the stored tensors onto the current `device`, so a
# checkpoint saved from a GPU run can also be loaded in a CPU-only session.
model.load_state_dict(torch.load(last_model_path, map_location=device))
# Evaluation mode: batch-norm layers use their running statistics.
model.eval()
# Iterator over single test samples; consumed one item at a time by the
# inspection cell that follows.
loader = iter(test_dataloader)

In [12]:
from editor.ploting import plot_histograms


# Take the next (edited, original) sample from the test iterator and move it
# to the model's device.
edited_histogram, original_histogram = next(loader)
edited_histogram = edited_histogram.to(device)
original_histogram = original_histogram.to(device)
# Inference only: disable autograd so no graph is built for this forward
# pass, matching the evaluation step inside the training cell.
with torch.no_grad():
    predicted_original = model(edited_histogram)
# Show original, edited and predicted histograms (list-based plot_histograms
# imported from editor.ploting at the top of this cell).
plot_histograms([
    original_histogram.cpu().numpy().squeeze(),
    edited_histogram.cpu().numpy().squeeze(),
    predicted_original.cpu().detach().numpy().squeeze()
])


In [8]:
import torch.nn.functional as F

# Inspect the shape of one target sample as yielded by the test loader.
original_histogram.shape

torch.Size([1, 1, 32, 32, 32])

In [9]:

# Pooling kernel size versus resulting bins per axis for a 32-bin histogram:
#   kernel 2 -> 16 bins
#   kernel 4 ->  8 bins
# Average-pool the original histogram with kernel 4, renormalise it so that
# it sums to 1 again, and plot it next to the full-resolution histogram.
a = F.avg_pool3d(original_histogram, kernel_size=4)
a = a / a.sum()
plot_histograms([
    original_histogram.cpu().numpy().squeeze(),
    a.cpu().numpy().squeeze(),
])