In [2]:
# Work from the project folder on the mounted Drive so the relative "dataset/..." paths used below resolve.
%cd /content/drive/MyDrive/FC_4_mine

/content/drive/MyDrive/FC_4_mine


In [3]:
import math
import os
import torch
import numpy as np
import scipy.io
import time
import re
import pandas as pd
import cv2

import torch.utils.data as data
import torchvision.transforms.functional as F
import torchvision.models as models

from torch import nn, Tensor
from torch.nn.functional import normalize as norma
from torch.nn.functional import interpolate
from torch.utils.data import DataLoader

## Utils

In [4]:
def normalize(image):
    """Clip ``image`` to [0, 65535] and rescale it linearly to [0, 1]."""
    full_scale = 65535.0
    clipped = np.clip(image, 0.0, full_scale)
    return clipped * (1.0 / full_scale)

def bgr_to_rgb(image):
    """Reverse the order of the last axis of an (H, W, C) array, turning BGR into RGB (and vice versa)."""
    reverse = slice(None, None, -1)
    return image[:, :, reverse]
def linear_to_nonlinear(image):
    """Apply a 1/2.2 gamma curve to ``image``.

    NumPy arrays and torch Tensors are raised element-wise to the power 1/2.2 and returned
    as the same kind of object. Anything else (e.g. a PIL image) is converted with
    ``F.to_tensor``, gamma-encoded, and converted back to an RGB PIL image.
    """
    gamma = 1.0 / 2.2
    if isinstance(image, np.ndarray):
        return np.power(image, gamma)
    if isinstance(image, Tensor):
        return torch.pow(image, gamma)
    as_tensor = F.to_tensor(image)
    return F.to_pil_image(torch.pow(as_tensor, gamma).squeeze(), mode="RGB")

def hwc_to_chw(image):
    """Move the last axis of a 3-D array to the front: (H, W, C) -> (C, H, W)."""
    channels_first = (2, 0, 1)
    return image.transpose(channels_first)

def scale(image):
    """Min-max rescale ``image`` to [0, 1].

    Works for NumPy arrays and torch Tensors. A constant input (max == min) would otherwise
    divide 0 by 0 and produce NaNs; it is returned as all zeros instead.
    """
    shifted = image - image.min()
    peak = shifted.max()
    if peak == 0:
        return shifted
    return shifted / peak

def rescale(image, size):
    """Resize a batched image tensor to spatial ``size`` with bilinear interpolation."""
    return interpolate(image, size=size, mode="bilinear")

def correct(image, illuminant):
    """Remove an illuminant colour cast from ``image`` and return the result as an RGB PIL image.

    @param image: image to correct (anything ``F.to_tensor`` accepts, e.g. a PIL image)
    @param illuminant: RGB illuminant estimate tensor, shape (3,) or (batch, 3) such as the
        output of ``ModelFC4.predict``; the work is done on the device this tensor lives on
    @return: the corrected image, rescaled so its brightest value is 1 and passed through
        ``linear_to_nonlinear``
    """
    # Visualisation only: no gradients needed, and follow the estimate's device instead of a global.
    illuminant = illuminant.detach()
    img = F.to_tensor(image).to(illuminant.device)

    # Divide each channel by the matching illuminant component. The sqrt(3) factor makes a
    # unit-norm grey illuminant (1, 1, 1) / sqrt(3) a correction of exactly 1 per channel.
    correction = illuminant.reshape(-1, 3, 1, 1) * math.sqrt(3.0)
    corrected_img = img / (correction + 1e-10)

    # Rescale every image in the batch so its maximum over channels and pixels is 1.
    max_img = corrected_img.reshape(corrected_img.shape[0], -1).max(dim=1)[0] + 1e-10
    normalized_img = corrected_img / max_img.view(-1, 1, 1, 1)

    return F.to_pil_image(linear_to_nonlinear(normalized_img).squeeze(), mode="RGB")

def percentile(errors, procents):
    """Return the ``procents`` quantile of ``errors``, where ``procents`` is a fraction in [0, 1]."""
    as_percent = 100 * procents
    return np.percentile(errors, as_percent)

def compute_metrics(errors):
    """Summary statistics of a list of (angular) errors.

    @param errors: iterable of numbers (not modified)
    @return: dict with keys, in this order,
        "mean", "median", "trimean" (0.25 * (Q1 + 2 * median + Q3)),
        "bst25" (mean of the lowest 25% of errors, at least one error),
        "wst25" (mean of the highest 25% of errors),
        "wst5" (95th percentile of the errors).
        For an empty input every value is NaN instead of raising or warning.
    """
    keys = ("mean", "median", "trimean", "bst25", "wst25", "wst5")
    ordered = np.sort(np.asarray(errors, dtype=np.float64))
    count = ordered.size
    if count == 0:
        return {key: float("nan") for key in keys}

    q25, q50, q75, q95 = np.percentile(ordered, [25, 50, 75, 95])
    return {
        "mean": np.mean(ordered),
        "median": q50,
        "trimean": 0.25 * (q25 + 2 * q50 + q75),
        # With fewer than 4 errors int(0.25 * count) is 0; keep at least the best one (avoids mean of []).
        "bst25": np.mean(ordered[:max(1, int(0.25 * count))]),
        "wst25": np.mean(ordered[int(0.75 * count):]),
        "wst5": q95,
    }

def print_metrics(current_metrics):
    """Print the statistics produced by ``compute_metrics``, one labelled line per metric."""
    rows = (
        (" Mean ......... : {:.4f} ", "mean"),
        (" Median ....... : {:.4f} ", "median"),
        (" Trimean ...... : {:.4f} ", "trimean"),
        (" Best 25% ..... : {:.4f} ", "bst25"),
        (" Worst 25% .... : {:.4f} ", "wst25"),
        (" Worst 5% ..... : {:.4f} ", "wst5"),
    )
    for template, key in rows:
        print(template.format(current_metrics[key]))

def normalize(image, max_int=65535.0):
    """Clip ``image`` to [0, max_int] and rescale it linearly to [0, 1].

    NOTE(review): this re-definition shadows the identical ``normalize`` defined at the top of
    this cell; only one of the two should be kept.

    @param image: array-like of raw intensities
    @param max_int: full-scale value mapped to 1.0; defaults to 65535 (the largest 16-bit value)
    @return: float array with values in [0, 1]
    """
    return np.clip(image, 0.0, max_int) * (1.0 / max_int)

## Dataset class

In [5]:
class ColorCheckerDataset(data.Dataset):
    """One cross-validation fold of the pre-processed data stored under ``dataset/``.

    Paths are relative to the current working directory. Each item is
    ``(img, illuminant, file_name)``: ``img`` is a float32 channels-first tensor (BGR->RGB,
    normalised and gamma-encoded; halved in size for the test split) and ``illuminant`` is the
    stored label as a float32 tensor.
    """

    def __init__(self, train = True, folds_num = 1):
        """
        @param train: use the training split ("tr_split") if True, otherwise the test split ("te_split")
        @param folds_num: zero-based index of the fold inside dataset/folds.mat
        """
        self.__train = train

        path_to_folds = os.path.join("dataset", "folds.mat")
        path_to_metadata = os.path.join("dataset", "metadata.txt")
        self.__path_to_data = os.path.join("dataset", "preprocessed", "numpy_data")
        self.__path_to_label = os.path.join("dataset", "preprocessed", "numpy_labels")

        folds = scipy.io.loadmat(path_to_folds)
        img_idx = folds["tr_split" if self.__train else "te_split"][0][folds_num][0]

        # Read the metadata with a context manager so the file handle is closed.
        with open(path_to_metadata, 'r') as metadata_file:
            metadata = metadata_file.readlines()
        # Fold indices are 1-based (MATLAB); int() also accepts indices stored as doubles.
        self.__fold_data = [metadata[int(i) - 1] for i in img_idx]

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(img, illuminant, file_name)``."""
        # The second space-separated field of a metadata line is the sample's file name.
        file_name = self.__fold_data[index].strip().split(' ')[1]
        img = np.array(np.load(os.path.join(self.__path_to_data, file_name + '.npy')), dtype='float32')
        illuminant = np.array(np.load(os.path.join(self.__path_to_label, file_name + '.npy')), dtype='float32')

        if not self.__train:
            # Evaluation images are downscaled by half in both dimensions.
            img = cv2.resize(img, (0, 0), fx=0.5, fy=0.5)

        img = hwc_to_chw(linear_to_nonlinear(bgr_to_rgb(normalize(img))))

        # Copy so each tensor owns its memory, independent of the NumPy intermediates.
        img = torch.from_numpy(img.copy())
        illuminant = torch.from_numpy(illuminant.copy())

        if not self.__train:
            img = img.type(torch.FloatTensor)

        return img, illuminant, file_name

    def __len__(self):
        """Number of samples in this fold."""
        return len(self.__fold_data)

## ModelFC4 class

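Network module (pretrained AlexNet convolutional backbone followed by the FC4 confidence-weighted pooling head):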

In [6]:
class FC4(torch.nn.Module):
    """FC4 illuminant estimator: pretrained AlexNet conv layers plus a per-patch (RGB, confidence) head."""

    def __init__(self):
        super().__init__()

        # Backbone: the layers of the first child of torchvision's AlexNet (presumably its
        # convolutional `features` stack), with pretrained weights.
        # NOTE(review): newer torchvision releases deprecate `pretrained=True` in favour of
        # `weights=...`; confirm against the installed version.
        alexnet = models.alexnet(pretrained=True)
        self.backbone = nn.Sequential(*list(alexnet.children())[0])

        # Head: 4 non-negative output channels per spatial location
        # (3 for an RGB estimate, 1 for its confidence; see forward()).
        self.final_convs = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=1, ceil_mode=True),
            nn.Conv2d(256, 64, kernel_size=6, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Conv2d(64, 4, kernel_size=1, stride=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, image):
        """
        Estimate an RGB colour for the illuminant of the input image(s).
        @param image: batch of images; presumably (batch, 3, H, W) RGB — TODO confirm
        @return: tuple (pred, rgb, confidence):
            pred       - pooled illuminant estimate, L2-normalised along dim 1, shape (batch, 3)
            rgb        - per-patch colour estimates, L2-normalised along dim 1, shape (batch, 3, h, w)
            confidence - per-patch non-negative confidence weights, shape (batch, 1, h, w)
        """

        image = self.backbone(image)
        out = self.final_convs(image)

        # Per-patch color estimates (first 3 channels)
        rgb = norma(out[:, :3, :, :], dim=1)

        # Confidence (last channel)
        confidence = out[:, 3:4, :, :]

        # Confidence-weighted pooling over both spatial dimensions, then re-normalisation
        pred = norma(torch.sum(torch.sum(rgb * confidence, 2), 2), dim=1)

        return pred, rgb, confidence

FC4 model wrapper (angular-error loss, optimiser setup, train/eval mode switching):

In [7]:
class ModelFC4:
    """Wraps an ``FC4`` network with its angular-error loss, optimiser and train/eval mode switching."""

    def __init__(self, device="cuda:0"):
        """
        @param device: device the network is moved to; defaults to "cuda:0"
        """
        self._device = device
        self._optimizer = None  # created by set_optimizer()
        self._network = FC4().to(self._device)

    def predict(self, image, return_steps=False):
        """
        Performs inference on the input image using the FC4 method.
        @param image: batch of images for which an illuminant colour has to be estimated
            (must already be on the model's device)
        @param return_steps: whether to also return the per-patch estimates and confidence weights
        @return: the pooled colour estimate as a Tensor. If ``return_steps`` is True, the tuple
            ``(pred, rgb, confidence)`` with the per-patch colour estimates and confidence weights
            (used for visualizations)
        """
        pred, rgb, confidence = self._network(image)
        if return_steps:
            return pred, rgb, confidence
        return pred

    def optimize(self, image, true):
        """
        Performs one optimisation step on a batch.
        @param image: input batch, on the model's device
        @param true: ground-truth illuminants for the batch
        @return: the batch loss (mean angular error in degrees) as a Python float
        @raise RuntimeError: if set_optimizer() has not been called yet
        """
        if self._optimizer is None:
            raise RuntimeError("No optimizer set: call set_optimizer() before optimize()")
        self._optimizer.zero_grad()
        pred = self.predict(image)
        loss = self.get_loss(pred, true)
        loss.backward()
        self._optimizer.step()
        return loss.item()

    def get_loss(self, pred, true, safe_v = 0.999999):
        """
        Mean angular error, in degrees, between predicted and ground-truth illuminants.
        Both inputs are L2-normalised along dim 1. The cosine is clamped to [-safe_v, safe_v]
        because the derivative of acos is unbounded at +/-1.
        """
        dot = torch.clamp(torch.sum(norma(pred, dim=1) * norma(true, dim=1), dim=1), -safe_v, safe_v)
        angle = torch.acos(dot) * (180 / math.pi)
        return torch.mean(angle).to(self._device)

    def train_mode(self):
        """Switch the network to training mode (dropout active)."""
        self._network = self._network.train()

    def evaluation_mode(self):
        """Switch the network to evaluation mode (dropout disabled)."""
        self._network = self._network.eval()

    def set_optimizer(self, learning_rate: float, optimizer_type: str = "adam"):
        """
        Creates the optimiser over all network parameters.
        @param learning_rate: learning rate
        @param optimizer_type: "adam" or "rmsprop"; any other value raises KeyError
        """
        optimizers_map = {"adam": torch.optim.Adam, "rmsprop": torch.optim.RMSprop}
        self._optimizer = optimizers_map[optimizer_type](self._network.parameters(), lr=learning_rate)

## Device and random seed

In [8]:
# Reproducibility settings and the device used by the training and inference cells.
seed = 0
DEVICE = "cuda:0"
torch.manual_seed(seed)
np.random.seed(seed)
torch.backends.cudnn.benchmark = False
# benchmark=False alone does not make cuDNN deterministic; also request deterministic algorithms.
torch.backends.cudnn.deterministic = True

## Train

In [9]:
# Train FC4 on one cross-validation fold, validating every VAL_EVERY epochs.
# One row per epoch is logged to `log_data`; train_loss / val_loss are averaged over batches,
# and validation columns are NaN on epochs where no validation ran.
fold_num, epochs, batch_size, lr = 0, 560, 1, 0.0003
VAL_EVERY = 5   # validate on epochs 0, VAL_EVERY, 2 * VAL_EVERY, ...
LOG_EVERY = 5   # print a progress line every LOG_EVERY batches
# More workers than CPUs triggers a DataLoader warning (see output below); cap at the CPU count.
NUM_WORKERS = min(20, os.cpu_count() or 1)
LOG_COLUMNS = ["train_loss", "val_loss", "mean", "median", "trimean", "bst25", "wst25", "wst5"]

model = ModelFC4()
model.set_optimizer(lr)

training_set = ColorCheckerDataset(train=True, folds_num=fold_num)
training_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True,
                             num_workers=NUM_WORKERS, drop_last=True)

test_set = ColorCheckerDataset(train=False, folds_num=fold_num)
# Keep every test sample (drop_last=False) so the metrics cover the whole fold.
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False,
                         num_workers=NUM_WORKERS, drop_last=False)


def _train_one_epoch(fc4_model, loader, epoch, epochs):
    """One optimisation pass over `loader`; returns (mean batch loss, elapsed seconds)."""
    fc4_model.train_mode()
    total_loss, n_batches = 0.0, 0
    start = time.time()
    for i, (image, label, _) in enumerate(loader):
        image, label = image.to(DEVICE), label.to(DEVICE)
        loss = fc4_model.optimize(image, label)
        total_loss += loss
        n_batches += 1
        if i % LOG_EVERY == 0:
            print("[ Epoch: {}/{} - Batch: {} ] | [ Train loss: {:.4f} ]".format(epoch, epochs, i, loss))
    return total_loss / max(n_batches, 1), time.time() - start


def _validate(fc4_model, loader, epoch, epochs):
    """Evaluate on `loader` without gradients; returns (per-batch angular errors, elapsed seconds)."""
    fc4_model.evaluation_mode()
    batch_errors = []
    start = time.time()
    print("\n--------------------------------------------------------------")
    print("\t\t\t Validation")
    print("--------------------------------------------------------------\n")
    with torch.no_grad():
        for i, (image, label, _) in enumerate(loader):
            image, label = image.to(DEVICE), label.to(DEVICE)
            # Compute the loss once and reuse it for the log line and the metrics.
            loss = fc4_model.get_loss(fc4_model.predict(image), label).item()
            batch_errors.append(loss)
            if i % LOG_EVERY == 0:
                print("[ Epoch: {}/{} - Batch: {}] | Val loss: {:.4f} ]".format(epoch, epochs, i, loss))
    print("\n--------------------------------------------------------------\n")
    return batch_errors, time.time() - start


log_records = []
train_loss, val_loss = 0, 0
errors = []
try:
    for epoch in range(epochs):
        train_loss, train_time = _train_one_epoch(model, training_loader, epoch, epochs)

        validated = epoch % VAL_EVERY == 0
        if validated:
            errors, val_time = _validate(model, test_loader, epoch, epochs)
        if validated and errors:
            val_loss = float(np.mean(errors))
            metrics = compute_metrics(errors)
        else:
            # No fresh validation results this epoch: log NaN instead of repeating stale metrics.
            val_loss = float("nan")
            metrics = {key: float("nan") for key in LOG_COLUMNS[2:]}
        metrics['train_loss'] = train_loss
        metrics['val_loss'] = val_loss
        log_records.append(metrics)

        print("\n********************************************************************")
        print(" Train Time ... : {:.4f}".format(train_time))
        print(" Train Loss ... : {:.4f}".format(train_loss))
        if validated:
            print("....................................................................")
            print(" Val Time ..... : {:.4f}".format(val_time))
            print(" Val Loss ..... : {:.4f}".format(val_loss))
            print("....................................................................")
            print_metrics(metrics)
        print("********************************************************************\n")
finally:
    # Build the log in one go, and also when training is interrupted (e.g. KeyboardInterrupt).
    log_data = pd.DataFrame(log_records, columns=LOG_COLUMNS)


Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth


  0%|          | 0.00/233M [00:00<?, ?B/s]

  cpuset_checked))


[ Epoch: 0/560 - Batch: 0 ] | [ Train loss: 12.8614 ]
[ Epoch: 0/560 - Batch: 5 ] | [ Train loss: 37.5569 ]
[ Epoch: 0/560 - Batch: 10 ] | [ Train loss: 32.4413 ]
[ Epoch: 0/560 - Batch: 15 ] | [ Train loss: 21.3909 ]
[ Epoch: 0/560 - Batch: 20 ] | [ Train loss: 36.8739 ]
[ Epoch: 0/560 - Batch: 25 ] | [ Train loss: 27.8238 ]
[ Epoch: 0/560 - Batch: 30 ] | [ Train loss: 31.9994 ]
[ Epoch: 0/560 - Batch: 35 ] | [ Train loss: 17.4985 ]
[ Epoch: 0/560 - Batch: 40 ] | [ Train loss: 3.4261 ]
[ Epoch: 0/560 - Batch: 45 ] | [ Train loss: 3.4889 ]
[ Epoch: 0/560 - Batch: 50 ] | [ Train loss: 3.9407 ]
[ Epoch: 0/560 - Batch: 55 ] | [ Train loss: 6.3619 ]
[ Epoch: 0/560 - Batch: 60 ] | [ Train loss: 4.1966 ]
[ Epoch: 0/560 - Batch: 65 ] | [ Train loss: 17.0431 ]
[ Epoch: 0/560 - Batch: 70 ] | [ Train loss: 1.2527 ]
[ Epoch: 0/560 - Batch: 75 ] | [ Train loss: 0.3369 ]
[ Epoch: 0/560 - Batch: 80 ] | [ Train loss: 1.4583 ]
[ Epoch: 0/560 - Batch: 85 ] | [ Train loss: 7.3370 ]
[ Epoch: 0/560 - Batc

KeyboardInterrupt: ignored

In [None]:
# Column-wise minimum of every logged quantity across all logged epochs.
log_data.min(axis=0)

In [None]:
# Validation mean angular error per epoch. Non-numeric / NaN rows (epochs without validation) are
# dropped so the remaining points form a line. The x-axis is the log_data index: one row per epoch.
mean_per_epoch = pd.to_numeric(log_data['mean'], errors='coerce').dropna()
# pandas plotting returns a matplotlib Axes, so no separate `plt` import is needed.
ax = mean_per_epoch.plot(figsize=(15, 10), label='Mean')
ax.set_xlabel('Epochs')
ax.set_ylabel('Mean')
ax.set_title('Validation mean angular error')
ax.legend();

In [None]:
from torchvision.utils import save_image
from torchvision.transforms import transforms

In [None]:
# Estimate the illuminant of the first test image; save the input and its corrected version.
model.evaluation_mode()  # disable dropout for inference
# A DataLoader is iterable but not an iterator: wrap it in iter() before calling next().
image, label, file_name = next(iter(test_loader))
save_image(image, 'image_ippi.png')
with torch.no_grad():
    # The batch comes off the loader on the CPU; move it to DEVICE, where the model's network lives.
    out = model.predict(image.to(DEVICE))
original = transforms.ToPILImage()(image.squeeze()).convert("RGB")
# NOTE(review): the dataset already applies linear_to_nonlinear to `image`, and correct() applies it
# again, so the saved result is gamma-encoded twice — confirm whether a linear input was intended.
ans = correct(original, out)
conv_t = transforms.ToTensor()
save_image(conv_t(ans), 'image_ans_ippi.png')

In [None]:
%%shell
# Export the notebook at the given path to HTML with nbconvert.
jupyter nbconvert --to html "/content/Model_(1) (1).ipynb"