In [None]:
import torch
import torch.nn as nn
from torch.utils.data import random_split
from time import perf_counter
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install colour-demosaicing

In [None]:
! rm -rf Restorer

In [None]:
! git clone https://github.com/rymuelle/Restorer.git

In [None]:
! pip install Restorer/[losses]

In [None]:
!pip install optuna

In [None]:
!pip install kornia

In [None]:
from Restorer.Restorer import Restorer, AddPixelShuffle
from Restorer.utils import numpy_pixel_shuffle, numpy_pixel_unshuffle

In [None]:
from colour_demosaicing import mosaicing_CFA_Bayer, demosaicing_CFA_Bayer_Malvar2004
from torchvision.transforms import ToTensor
import optuna
import gc

In [None]:
# Google Drive folders for the Onsight project (requires the Drive mount above).
checkpoint_path, model_path = (
    '/content/drive/MyDrive/Onsight/Checkpoints',
    '/content/drive/MyDrive/Onsight/Models',
)

In [None]:
# Dataset

In [None]:
from RawRefinery.train.pretraining import Flickr8kDataset

In [None]:
# Pretraining dataset from RawRefinery. Each item unpacks into
# (rggb_tensor, image_tensor, target_image_tensor, conditioning_tensor),
# which is how the cells below consume it.
dataset = Flickr8kDataset()

In [None]:
# Preview one sample: the dataset's `image_tensor` next to the target at two scalings.
rggb_tensor, image_tensor, target_image_tensor, conditioning_tensor = dataset[30]

fig, axes = plt.subplots(2, 2, figsize=(10, 10))
preview_panels = [
    (image_tensor, "image_tensor"),
    # Dividing by 1.2 pulls values that would otherwise clip at imshow's 1.0 limit back into range.
    (target_image_tensor / 1.2, "target_image_tensor / 1.2"),
    (target_image_tensor, "target_image_tensor"),
]
for ax, (panel, title) in zip(axes.flat, preview_panels):
    ax.imshow(panel.permute(1, 2, 0))  # CHW -> HWC for matplotlib
    ax.set_title(title)
axes[1, 1].axis("off")  # only three panels are shown

conditioning_tensor


In [None]:
# Set up for Optuna

In [None]:
# Pretrained weights that load_model() starts from.
# NOTE(review): file is "gain_model_v4" while the study is "grain_tune" — confirm the name is intended.
weight_file_path = '{}/gain_model_v4.pt'.format(checkpoint_path)


from Restorer.Restorer import Restorer, AddPixelShuffle


def make_model(width=58, base_blocks=2, dec_blocks=2, late_blocks=1, vit_blocks=0, expand_dims=2,
               cond_input=4, cond_output=32, drop_path=0.0, drop_path_increment=0.05):
    """Build a 4-level Restorer wrapped in AddPixelShuffle.

    Args:
        width: base channel width passed to Restorer.
        base_blocks: encoder blocks at the first two levels.
        dec_blocks: decoder blocks at every level.
        late_blocks: extra blocks added at the two deepest encoder levels
            (and twice over in the middle stage).
        vit_blocks: ViT blocks at the two deepest levels; these replace the
            same number of regular encoder blocks there (floored at 0).
        expand_dims: channel expansion factor forwarded to Restorer.
        cond_input / cond_output: conditioning vector sizes forwarded to Restorer.
        drop_path / drop_path_increment: stochastic-depth settings forwarded to Restorer.

    Returns:
        ``AddPixelShuffle(Restorer(...))``. Restorer takes 4 input channels and
        emits ``3 * 2**2`` channels, presumably pixel-shuffled to RGB at 2x
        resolution by the wrapper (TODO confirm).
    """
    total_late_block = max(base_blocks + late_blocks - vit_blocks, 0)
    enc_blks = [base_blocks, base_blocks, total_late_block, total_late_block]
    dec_blks = [dec_blocks] * 4
    vit_blks = [0, 0, vit_blocks, vit_blocks]
    middle_blk_num = base_blocks + late_blocks * 2

    restorer = Restorer(in_channels=4, out_channels=3 * 2 ** 2, width=width, middle_blk_num=middle_blk_num,
                        enc_blk_nums=enc_blks, vit_blk_nums=vit_blks, dec_blk_nums=dec_blks,
                        cond_input=cond_input, cond_output=cond_output, expand_dims=expand_dims,
                        drop_path=drop_path, drop_path_increment=drop_path_increment)
    return AddPixelShuffle(restorer)
#[FrozenTrial(number=0, state=1, values=[0.0010226645435831695, 2463.5001086660195],  params={'width': 80, 'expand_dims': 2, 'base_blocks': 1, 'dec_blks': 2, 'late_blocks': 0}, user_attrs={}, system_attrs={'fixed_params': {'width': 80, 'expand_dims': 2, 'base_blocks': 1, 'dec_blks': 2, 'late_blocks': 0}}, intermediate_values={}, distributions={'width': IntDistribution(high=90, log=False, low=30, step=1), 'expand_dims': IntDistribution(high=4, log=False, low=1, step=1), 'base_blocks': IntDistribution(high=3, log=False, low=1, step=1), 'dec_blks': IntDistribution(high=3, log=False, low=1, step=1), 'late_blocks': IntDistribution(high=5, log=False, low=0, step=1)}, trial_id=1, value=None),

def load_model(device, weight_path=None):
    """Build the fixed-architecture restorer and load pretrained weights into it.

    Checkpoint entries whose key contains ``'conditioning_gen'`` are dropped, so
    that sub-module keeps its fresh initialisation. Loading is non-strict; any
    other missing or unexpected keys are printed instead of silently ignored.

    Args:
        device: device (str or torch.device) the returned model is placed on.
        weight_path: checkpoint file to read; defaults to the notebook-level
            ``weight_file_path``.

    Returns:
        The model from ``make_model`` with the checkpoint weights loaded, on ``device``.
    """
    # Earlier Optuna results, kept for reference:
    #[FrozenTrial(number=4,  params={'width': 33, 'expand_dims': 2, 'base_blocks': 1, 'dec_blks': 2, 'late_blocks': 0}, user_attrs={}, system_attrs={'NSGAIISampler:generation': 0}, intermediate_values={}, distributions={'width': IntDistribution(high=90, log=False, low=30, step=1), 'expand_dims': IntDistribution(high=4, log=False, low=1, step=1), 'base_blocks': IntDistribution(high=3, log=False, low=1, step=1), 'dec_blks': IntDistribution(high=3, log=False, low=1, step=1), 'late_blocks': IntDistribution(high=5, log=False, low=0, step=1)}, trial_id=5, value=None),
    #FrozenTrial(number=12, state=1, values=[0.0012007935130358552, 574.8597779699994],params={'width': 90, 'expand_dims': 4, 'base_blocks': 1, 'dec_blks': 1, 'late_blocks': 0}, user_attrs={}, system_attrs={'NSGAIISampler:generation': 0}, intermediate_values={}, distributions={'width': IntDistribution(high=90, log=False, low=30, step=1), 'expand_dims': IntDistribution(high=4, log=False, low=1, step=1), 'base_blocks': IntDistribution(high=3, log=False, low=1, step=1), 'dec_blks': IntDistribution(high=3, log=False, low=1, step=1), 'late_blocks': IntDistribution(high=5, log=False, low=0, step=1)}, trial_id=13, value=None),
    if weight_path is None:
        weight_path = weight_file_path

    # Architecture must match the checkpoint being loaded.
    width = 58
    expand_dims = 2
    base_blocks = 2
    dec_blocks = 2
    late_blocks = 1

    model = make_model(width=width, base_blocks=base_blocks, late_blocks=late_blocks,
                       expand_dims=expand_dims, dec_blocks=dec_blocks).to(device)

    # Read the checkpoint once, on CPU; load_state_dict copies into the device-resident params.
    checkpoint = torch.load(weight_path, map_location=torch.device('cpu'))
    kept = {key: value for key, value in checkpoint.items() if 'conditioning_gen' not in key}
    result = model.load_state_dict(kept, strict=False)

    # conditioning_gen keys are missing on purpose; surface anything else.
    missing = [key for key in result.missing_keys if 'conditioning_gen' not in key]
    unexpected = list(result.unexpected_keys)
    if missing or unexpected:
        print(f"load_model: {len(missing)} missing keys {missing[:5]}, "
              f"{len(unexpected)} unexpected keys {unexpected[:5]}")
    return model


In [None]:
from Restorer.CombinedPerceptualLoss import CombinedPerceptualLoss


In [None]:
# Data-pipeline smoke test: split the dataset and pull a single training batch.
device = 'cpu'
val_split = 0.1
test_split = 0.1
batch_size = 4
val_size = int(len(dataset) * val_split)
test_size = int(len(dataset) * test_split)
train_size = len(dataset) - val_size - test_size

# Seeded generator so the split is reproducible across kernel restarts.
train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42),
)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

# One batch is enough to inspect shapes and values.
rggb_tensor, image_tensor, target_image_tensor, conditioning_tensor = next(iter(train_loader))
target_image_tensor = target_image_tensor.to(device).float()
conditioning_tensor

In [None]:
# Second target image of the smoke-test batch, reordered CHW -> HWC for matplotlib.
plt.imshow(torch.permute(target_image_tensor[1], (1, 2, 0)))

In [None]:
def _train_one_epoch(model, loader, loss_func, optimizer, clip, device):
    """Run one optimisation pass over `loader`; return the per-image mean loss (NaN if empty).

    Batches are unpacked as (rggb, image, target, conditioning); `image` is unused.
    Gradients are clipped to a total norm of `clip` before each optimizer step.
    """
    model.train()
    total_loss = 0.0
    n_images = 0
    for rggb_tensor, _image, target_image_tensor, conditioning_tensor in loader:
        rggb_tensor = rggb_tensor.to(device).float()
        target_image_tensor = target_image_tensor.to(device).float()
        conditioning_tensor = conditioning_tensor.to(device).float()

        pred = model(rggb_tensor, conditioning_tensor)
        loss = loss_func(pred, target_image_tensor)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        batch = pred.size(0)
        n_images += batch
        total_loss += loss.item() * batch
    return total_loss / n_images if n_images else float("nan")


def _evaluate(model, loader, loss_func, device):
    """Return the per-image mean of `loss_func` over `loader`, without gradients (NaN if empty)."""
    model.eval()
    total_loss = 0.0
    n_images = 0
    with torch.no_grad():
        for rggb_tensor, _image, target_image_tensor, conditioning_tensor in loader:
            rggb_tensor = rggb_tensor.to(device).float()
            target_image_tensor = target_image_tensor.to(device).float()
            conditioning_tensor = conditioning_tensor.to(device).float()

            pred = model(rggb_tensor, conditioning_tensor)
            loss = loss_func(pred, target_image_tensor)

            batch = pred.size(0)
            n_images += batch
            total_loss += loss.item() * batch
    return total_loss / n_images if n_images else float("nan")


def objective(trial: optuna.Trial):
    """Optuna objective: fine-tune the pretrained restorer with a sampled learning rate.

    Only ``lr`` is sampled; everything else is fixed below. The model from
    ``load_model`` is trained with Adam plus cosine annealing and evaluated on the
    validation split every ``val_interval`` epochs. The weights with the lowest
    validation loss so far are saved to ``{checkpoint_path}/grain_tune_{trial.number}.pt``.

    Returns:
        The validation loss from the last validation pass (inf if none ran).

    Raises:
        optuna.exceptions.TrialPruned: when the study's pruner stops the trial.
    """
    import os  # the notebook's `import os` cell comes after this definition

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    best_model_path = os.path.join(checkpoint_path, f'grain_tune_{trial.number}.pt')
    print(best_model_path)
    best_val_loss = float("inf")
    avg_val_loss = float("inf")

    # Hyperparameters.
    # NOTE(review): width/expand_dims/*_blocks/drop_path* are only reported here;
    # the architecture actually built is the one hard-coded in load_model().
    width = 58
    expand_dims = 2
    base_blocks = 2
    dec_blocks = 2
    late_blocks = 1
    max_val_epoch = 5  # reported only; not used by the loop
    val_interval = 5   # validate every N epochs
    drop_path = 0.0
    drop_path_increment = 0.05
    l2_reg = 0
    lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
    clip = 8e-3

    batch_size = 2
    lr = batch_size * lr / 2  # linear LR scaling relative to a reference batch size of 2
    num_epochs = 101
    val_split = 0.2
    test_split = 0.6

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Running on device: {device}")

    # Split dataset into train / val / test (test is held out, unused here).
    val_size = int(len(dataset) * val_split)
    test_size = int(len(dataset) * test_split)
    train_size = len(dataset) - val_size - test_size
    torch.manual_seed(42)  # For reproducibility
    train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    model = load_model(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2_reg)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=lr*0.01)
    loss_func = CombinedPerceptualLoss(mse=1*0.5, ssim=0.2, vgg=0.0015)
    loss_func = loss_func.to(device)
    loss_func.loss_modules['vgg'] = loss_func.loss_modules['vgg'].to(device)
    start = perf_counter()

    print(f"\n Trial {trial.number} parameters:")
    print(f"  width: {width}")
    print(f"  expand_dims: {expand_dims}")
    print(f"  base_blocks: {base_blocks}")
    print(f"  late_blocks: {late_blocks}")
    print(f"  max_val_epoch: {max_val_epoch}")
    print(f"  dec_blocks: {dec_blocks}")
    print(f"  drop_path: {drop_path}")
    print(f"  drop_path_increment: {drop_path_increment}")
    print(f"  l2_reg: {l2_reg}")
    print(f"  lr: {lr:.2e}")
    print(f"  clip: {clip:.2e}")

    try:
        for epoch in range(num_epochs):
            print(epoch)
            train_loss = _train_one_epoch(model, train_loader, loss_func, optimizer, clip, device)
            scheduler.step()
            current_lr = optimizer.param_groups[0]['lr']
            print(f'total_loss: {train_loss:.2e} lr: {current_lr:.2e}')

            if epoch % val_interval == 0:
                # Evaluate on the held-out validation split (not the training data).
                avg_val_loss = _evaluate(model, val_loader, loss_func, device)
                print(f'{avg_val_loss:.2e}')
                trial.report(avg_val_loss, epoch)
                # Only overwrite the checkpoint when validation actually improves.
                if avg_val_loss < best_val_loss:
                    best_val_loss = avg_val_loss
                    torch.save(model.state_dict(), best_model_path)

                if epoch > 1 and trial.should_prune():
                    raise optuna.exceptions.TrialPruned()

        run_time = perf_counter() - start
    finally:
        # Release GPU memory even when the trial is pruned or fails.
        del model, optimizer, scheduler, loss_func
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    print(f"Trial {trial.number} finished in {run_time:.2f}s with val_loss: {avg_val_loss:.4e}")
    return avg_val_loss


In [None]:
import optuna
from optuna.storages import JournalStorage, JournalFileStorage  # Optional alternative to SQLite
import os


# Persistent SQLite study stored on Google Drive, so progress survives Colab
# restarts; an existing study with the same name is resumed rather than recreated.
study_name = "grain_tune"
storage_path = "sqlite:///" + f"{checkpoint_path}/{study_name}.db"

study = optuna.create_study(
    study_name=study_name,
    storage=storage_path,
    direction="minimize",
    load_if_exists=True,
)

In [None]:
# Seed the search with the known-good baseline LR. skip_if_exists avoids piling up
# duplicate baseline trials when this cell is re-run against the persistent study.
study.enqueue_trial({"lr": 7e-4}, skip_if_exists=True)


In [None]:
# Run 20 more trials of `objective` with no wall-clock limit.
study.optimize(func=objective, n_trials=20)

In [None]:

from plotly.io import show
# Objective value per trial, with the running best.
optuna.visualization.plot_optimization_history(study).show()