In [None]:
# in Colab, rasterio needs to be installed
# !pip install rasterio

import random
import pathlib
import numpy as np
import matplotlib.pyplot as plt
import sys
import os
import torch

# download and extract the dataset sample
# if not pathlib.Path('../SEN12MS_sample').is_dir():
#     !gdown --id 1GKHIPhhfjutCbb3LhJ0tgjxDvIuDO7tr
#     !tar -zxf SEN12MS_sample.tgz
#     !rm SEN12MS_sample.tgz
sys.path.insert(1,"../")
import utils.sen12ms_dataLoader as sen12ms
sys.path.remove("../")


In [None]:
# helper functions

def scale(data):
    """Min-max normalise ``data`` to (approximately) the [0, 1] range.

    Args:
        data: array-like object exposing ``.min()``, ``.max()`` and
            element-wise arithmetic (e.g. a tensor or an ndarray).

    Returns:
        A ``(scaled, min_, max_)`` tuple: the rescaled data together with
        the minimum and maximum of the input, which ``descale`` needs to
        undo the transformation.
    """
    lowest, highest = data.min(), data.max()
    # The 1e-8 term keeps the division finite when the input is constant.
    value_range = highest - lowest + 1e-8
    return (data - lowest) / value_range, lowest, highest

def scale_batch(data):
    """Min-max normalise channel 0 of every sample in ``data``, in place.

    Args:
        data: batch indexed as ``data[i, 0]``; it is modified in place.

    Returns:
        A ``(data, mins, maxs)`` tuple where ``mins[i]`` / ``maxs[i]`` are
        the minimum / maximum of ``data[i, 0]`` before it was rescaled.
    """
    per_sample_min, per_sample_max = [], []
    for idx in range(len(data)):
        normalised, lowest, highest = scale(data[idx, 0])
        # Write the rescaled plane back into the caller's batch.
        data[idx, 0] = normalised
        per_sample_min.append(lowest)
        per_sample_max.append(highest)

    return data, per_sample_min, per_sample_max
        
def descale(data, min_, max_):
    """Undo ``scale``: map normalised values back to the ``[min_, max_]`` range.

    Args:
        data: values previously normalised with ``scale``.
        min_: minimum of the original data.
        max_: maximum of the original data.

    Returns:
        ``data`` mapped back to the original value range.
    """
    # Same 1e-8 offset as in ``scale`` so the two functions are exact inverses.
    value_range = max_ - min_ + 1e-8
    return min_ + data * value_range

def descale_batch(data, mins, maxs):
    """Restore the original value range of channel 0 for every sample, in place.

    Args:
        data: batch indexed as ``data[i, 0]``; it is modified in place.
        mins: per-sample minima, as returned by ``scale_batch``.
        maxs: per-sample maxima, as returned by ``scale_batch``.

    Returns:
        The same ``data`` object with channel 0 of each sample descaled.
    """
    for idx in range(len(data)):
        lowest, highest = mins[idx], maxs[idx]
        data[idx, 0] = descale(data[idx, 0], lowest, highest)
    return data


In [None]:
# Open the SEN12MS sample and enumerate every patch it contains.
dataset = sen12ms.SEN12MSDataset(base_dir='../SEN12MS_sample') # Change path
# Iterate over all seasons of the dataset.
seasons = sen12ms.Seasons.ALL
# Flat list of (season, scene_id, patch_id) triples, one entry per patch.
patch_unique_ids = []
for season in seasons.value:
    # Iterated below as a mapping {scene_id: patch_ids} for this season.
    season_ids = dataset.get_season_ids(season=season)
    for scene_id, patch_ids in season_ids.items():
        patch_unique_ids.extend(
            (season, scene_id, patch_id) for patch_id in patch_ids
        )

In [None]:
# ----------------------------------------
# load model
# ----------------------------------------
# Build the UNetRes denoiser, load its pre-trained weights and freeze it
# for inference.
from models.network_unet import UNetRes as net

model_pool = 'model_zoo'             # fixed
model_name = 'drunet_gray'  # set denoiser model, 'drunet_gray' | 'drunet_color'
model_path = os.path.join(model_pool, model_name+'.pth')

n_channels = 3 if 'color' in model_name else 1   # 3 for color image, 1 for grayscale
task_current = 'dn'                  # 'dn' for denoising

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.cuda.empty_cache()

# in_nc is n_channels + 1: the extra input channel carries the noise-level map
# that is concatenated to the image before the forward pass.
model = net(in_nc=n_channels+1, out_nc=n_channels, nc=[64, 128, 256, 512], nb=4, act_mode='R', downsample_mode="strideconv", upsample_mode="convtranspose")
# map_location='cpu' lets a checkpoint that was saved from CUDA tensors load on
# a CPU-only machine; the model is moved to the selected device afterwards.
model.load_state_dict(torch.load(model_path, map_location='cpu'), strict=True)
model.eval()
# Inference only: no parameter needs gradients.
for param in model.parameters():
    param.requires_grad = False
model = model.to(device)
print('Model path: {:s}'.format(model_path))
number_parameters = sum(param.numel() for param in model.parameters())
print('Params number: {}'.format(number_parameters))


In [None]:
# s1, bb = dataset.get_patch(patch_unique_ids[0][0], patch_unique_ids[0][1], patch_unique_ids[2][2], sen12ms.S1Bands.ALL)
# with_original = np.concatenate((s1,s1))
# dataset.save_patch((with_original,bb),patch_unique_ids[0][0], patch_unique_ids[0][1], patch_unique_ids[2][2], sen12ms.S1Bands.ALLD)


In [None]:
# Read one patch back using the ALLD band selection.
# NOTE(review): season and scene are taken from entry 0 but the patch id from
# entry 2 of patch_unique_ids — confirm this mix is intended.
# NOTE(review): presumably the ALLD files are the ones written by save_patch in
# the processing loop; on a fresh run they may not exist yet — TODO confirm.
probe_season, probe_scene = patch_unique_ids[0][0], patch_unique_ids[0][1]
probe_patch_id = patch_unique_ids[2][2]
ds1, dbb = dataset.get_patch(probe_season, probe_scene, probe_patch_id, sen12ms.S1Bands.ALLD)
# ds1.shape
# dbb
# ds1 == s1
# print(bb,dbb) # <-- loses bounds

In [None]:
from tqdm import tqdm

# Noise level passed to the denoiser; it is divided by 255 below and used both
# as the std of the synthetic noise and as the value of the noise-map channel.
noise_level_model = 70.

# Seed the generator so the synthetic noise (and therefore the saved patches)
# is reproducible across runs.
torch.manual_seed(0)

# For every patch: load the S1 bands, denoise each band independently and save
# the original bands followed by the denoised ones under the ALLD selection.
for patch in tqdm(patch_unique_ids):
    #     get patch
    s1, bb = dataset.get_patch(patch[0], patch[1], patch[2], sen12ms.S1Bands.ALL)
    #     convert to tensor, adding a leading batch axis of size 1
    din = torch.tensor(s1[None], device=device, dtype=torch.float32)
    #     split channels so every band becomes its own single-channel sample
    #     [bs, n_bands, H, W] -> [bs*n_bands, 1, H, W]
    n_bands, height, width = din.shape[1:]
    din = din.reshape(-1, 1, height, width)
    #     min-max scaling per band; keep the extrema to undo it afterwards
    din, mins, maxs = scale_batch(din)
    #     add random normal noise
    noise_sigma = noise_level_model / 255.
    din += torch.randn_like(din) * noise_sigma
    #     add the constant noise-level map as a second input channel
    #     [bs*n_bands, 1, H, W] -> [bs*n_bands, 2, H, W]
    noise_map = torch.full((din.shape[0], 1, height, width), noise_sigma,
                           device=device, dtype=torch.float32)
    din = torch.cat((din, noise_map), dim=1)

    #     denoise; no autograd bookkeeping is needed for inference
    with torch.no_grad():
        denoised = model(din)

    #     return to the original value range
    denoised = descale_batch(denoised, mins, maxs)
    #     return to the original layout: [bs*n_bands, 1, H, W] -> [bs, n_bands, H, W]
    denoised = denoised.reshape(-1, n_bands, *denoised.shape[2:]).cpu().numpy()
    #     concatenate the original bands with the denoised ones; index the
    #     single batch entry explicitly instead of squeezing every size-1 axis
    with_original = np.concatenate((s1, denoised[0]))
    #     save data
    dataset.save_patch((with_original, bb), patch[0], patch[1], patch[2], sen12ms.S1Bands.ALLD)
