In [None]:
##################################Calculate quantiles for dataset
from tqdm.notebook import tqdm
from dataset.datasets import s2stats, senti
import torch
import numpy

def my_quant(tens, hi=0.98, lo=0.02):
    """Return the ``lo`` and ``hi`` quantiles of ``tens`` (midpoint interpolation).

    Sort-based equivalent of ``torch.quantile(..., interpolation='midpoint')``.
    For a level ``q`` the fractional position ``q * (numel - 1)`` in the sorted
    values is computed. If it falls strictly between two elements, their mean
    is returned; otherwise the element at that position is returned.

    Args:
        tens: tensor of values. It is flattened, so any shape is accepted.
        hi: upper quantile level in [0, 1].
        lo: lower quantile level in [0, 1].

    Returns:
        Tuple ``(q_lo, q_hi)`` of Python numbers. Note that the LOW quantile
        comes first.

    Raises:
        ValueError: if ``tens`` has no elements.
    """
    flat = tens.reshape(-1)
    if flat.numel() == 0:
        raise ValueError("my_quant() requires a non-empty tensor")

    sorted_vals, _ = torch.sort(flat)
    last = flat.numel() - 1

    def _midpoint_quantile(level):
        pos = level * last
        below = int(pos)
        if below < pos:
            # Position lies between two samples: average the two neighbours.
            return ((sorted_vals[below] + sorted_vals[below + 1]) / 2).item()
        return sorted_vals[below].item()

    return (_midpoint_quantile(lo), _midpoint_quantile(hi))

# torch.quantile was tried on a tensor of this size (256*256*2638 elements).
# NOTE(review): presumably it rejects inputs that large, which would be why the
# sort-based my_quant is used below -- confirm.
#test=torch.rand(256*256*2638)
#torchquant=torch.quantile(test,q=torch.tensor([0.02,0.98]),interpolation='midpoint')

dset= s2stats(root_dir='processed-data/dsen_2_256_split/timeperiod1/train/')

N_BANDS = 10  # number of bands (dim 0 of each image), as hard-coded originally

quanthi=[]
quantlo=[]
for band in tqdm(range(N_BANDS)):
    # Collect this band from every image, then concatenate once. Growing the
    # tensor with torch.cat inside the loop re-copied all pixels gathered so far
    # for every image, which is quadratic in dataset size. clone() copies only
    # this band, so the list does not keep whole multi-band images alive.
    band_values = torch.cat(
        [img[band, :, :].reshape(-1).clone() for img in tqdm(dset)], 0)

    # my_quant returns (low quantile, high quantile), in that order.
    band_lo, band_hi = my_quant(band_values)
    quanthi.append(band_hi)
    quantlo.append(band_lo)
    del band_values  # free before gathering the next band

print("q_hi =",quanthi)
print('q_lo =',quantlo)
# Output of an earlier run. That run appended my_quant's (low, high) result to
# the wrong lists, so the values are listed here under their correct names:
#q_hi = [2102.0, 1714.0, 1398.0, 4716.0, 2433.3343505859375, 3686.039794921875, 4502.04052734375, 4839.02197265625, 3796.3406982421875, 2994.882568359375]
#q_lo = [100.0, 158.0, 108.0, 75.0, 103.54399871826172, 76.29261779785156, 83.46126174926758, 68.2025260925293, 50.748931884765625, 37.05616760253906]


In [None]:
####### calculate std and mean for dataset

from tqdm.notebook import tqdm
from torch.utils.data import DataLoader
from dataset.datasets import s2stats
import torch
import numpy as np
dset= s2stats(root_dir='processed-data/dsen_2_256_split/timeperiod1/train/')

loader = DataLoader(dset,
                    batch_size=1,
                    num_workers=0,
                    shuffle=False)

# Accumulate the per-band sum and sum of squares over ALL pixels in the dataset.
# Averaging per-image variances (the previous approach) omits the spread between
# per-image means, so it underestimates the dataset std.
# float64 keeps the squared sums from overflowing or losing precision.
band_sum = 0.
band_sumsq = 0.
n_pixels = 0   # pixels per band that were accumulated
n_skipped = 0  # images dropped because they contain inf/NaN

for images in tqdm(loader):
    # [batch, band, H, W] -> [batch, band, H*W]; the last batch can be smaller
    flat = images.reshape(images.size(0), images.size(1), -1).double()

    # Drop whole images that contain any non-finite value.
    keep = torch.isfinite(flat).flatten(start_dim=1).all(dim=1)
    n_skipped += int((~keep).sum())
    flat = flat[keep]
    if flat.size(0) == 0:
        continue

    band_sum += flat.sum(dim=(0, 2))
    band_sumsq += (flat * flat).sum(dim=(0, 2))
    n_pixels += flat.size(0) * flat.size(2)

assert n_pixels > 0, "no finite images found in the dataset"

mean = band_sum / n_pixels
# Population variance E[x^2] - E[x]^2; clamp away tiny negative round-off.
var = (band_sumsq / n_pixels - mean * mean).clamp(min=0)
std = torch.sqrt(var)

print('skipped images =', n_skipped)
print('std = ',std.numpy())
print('mean =',mean.numpy())

In [8]:
##some functions
import torch

#------------quant_batch----------------------
def quant_batch(imbatch,q=torch.tensor([0.02,0.98])):
    """Per-band midpoint quantiles of an image batch.

    Args:
        imbatch: floating-point tensor of shape [batchsize, band, H, W].
        q: quantile levels in [0, 1], given as a 1-D tensor, a sequence, or a
            scalar. It is converted to ``imbatch``'s dtype and device, because
            torch.quantile requires them to match. Without this conversion, the
            float32 default failed on float64 input.

    Returns:
        For a 1-D ``q``, a tensor of shape [len(q), band]. For example,
        ``result[0, 1]`` is the first (lower) quantile of band index 1.

    Note:
        The original note said the maximum batchsize is 256. Presumably this
        is due to torch.quantile's input-size limit for 256x256 images -- TODO confirm.
    """
    # [B, C, H, W] -> [C, B, H, W] -> [C, B*H*W]: one row of pixels per band
    per_band = imbatch.permute(1, 0, 2, 3).flatten(start_dim=1)
    levels = torch.as_tensor(q, dtype=per_band.dtype, device=per_band.device)
    return torch.quantile(per_band, levels, dim=1, interpolation='midpoint')

#------example use case
# Define a dataset and dataloader:
# dset = s2stats(root_dir='processed-data/dsen_2_256_split/*/*/')
# dloader = DataLoader(dset, batch_size=100, num_workers=0, pin_memory=True, shuffle=True)
#
# dataiter = iter(dloader)
# imgs = next(dataiter)
#
# qbatch = quant_batch(imgs.float())
# qbatch.shape
# output: torch.Size([2, 10]); e.g. qbatch[0, 1] is the lower (2%) quantile of band index 1


#----------mean_batch--------
def mean_batch(batch):
    """Per-band mean of a batch, ignoring NaN and +/-inf values.

    Args:
        batch: floating-point tensor of shape [batchsize, band, H, W].

    Returns:
        Tensor of shape [band] with each band's mean over the batch and spatial
        dimensions, in ``batch``'s dtype. A band whose values are all NaN/inf
        yields NaN.
    """
    # Map +/-inf to NaN so nanmean skips them. Previously the cleaned tensor was
    # computed but the raw `batch` was averaged, so a single inf poisoned the mean.
    cleaned = batch.nan_to_num(nan=torch.nan, posinf=torch.nan, neginf=torch.nan)
    return torch.nanmean(cleaned, dim=(0, 2, 3))
#-----------std_batch----------
def std_batch(batch):
    """Per-band population standard deviation, ignoring NaN and +/-inf values.

    Args:
        batch: floating-point tensor of shape [batchsize, band, H, W]. It may
            live on any device and may require grad.

    Returns:
        float64 tensor of shape [band] with each band's std over the batch and
        spatial dimensions (ddof=0, numpy's nanstd default).
    """
    # Imported here so this cell does not depend on `np` from another cell.
    import numpy as np

    # inf -> NaN so nanstd ignores them; numpy needs a detached CPU array.
    cleaned = batch.nan_to_num(nan=torch.nan, posinf=torch.nan, neginf=torch.nan)
    arr = cleaned.detach().cpu().numpy()
    # Accumulate in float64 (the result dtype was already float64).
    std = np.nanstd(arr, axis=(0, 2, 3), dtype=np.float64)
    return torch.from_numpy(std)
       

# Per-band 98% (q_hi) and 2% (q_lo) quantiles of the timeperiod1 training split,
# as produced by the quantile cell at the top of this notebook. The values first
# pasted here had the two lists swapped: q_hi was below q_lo in every band, which
# is impossible for a 98% vs. 2% quantile. They are assigned to their correct
# names here.
q_hi = [2102.0, 1714.0, 1398.0, 4716.0, 2433.3343505859375, 3686.039794921875, 4502.04052734375, 4839.02197265625, 3796.3406982421875, 2994.882568359375]
q_lo = [100.0, 158.0, 108.0, 75.0, 103.54399871826172, 76.29261779785156, 83.46126174926758, 68.2025260925293, 50.748931884765625, 37.05616760253906]
assert all(hi >= lo for hi, lo in zip(q_hi, q_lo)), "q_hi must be >= q_lo in every band"
