In [68]:
import torch
import numpy as np

import time

from dataset.datasets import s2stats
from model.models import UNET

from torchvision import transforms
from torch.utils.data import DataLoader


#------------quant_batch----------------------
# takes in tensor of shape [batchsize,band,H,W]
# permutes dimensions to [band,batchsize,H,W]
# flattens dimensions to [band,batchsize*H*W]
# calculates quantile(s) for each band on dim=1
# returns tensor of shape [len(q),band]

def quant_batch(imbatch, q=(0.02, 0.98)):
    """Per-band quantiles over a whole batch of images.

    Args:
        imbatch: 4-D tensor ``[batchsize, band, H, W]`` of a floating dtype
            (the usage example calls ``.float()`` on the raw images first).
        q: quantile level(s) in [0, 1] -- a float, a sequence, or a tensor.
            It is cast to ``imbatch``'s dtype and device, because
            ``torch.quantile`` requires ``q`` to match the input's dtype.

    Returns:
        Tensor of shape ``[len(q), band]`` (``[band]`` for a scalar ``q``).
        Entry ``[i, b]`` is quantile ``q[i]`` of all pixels of band ``b``
        in the batch, computed with ``'midpoint'`` interpolation.

    NOTE(review): an earlier note said "max batchsize is 256". Where that
    cap comes from is not visible here -- presumably a ``torch.quantile``
    input-size limit. Confirm it for your torch version and image size
    before using larger batches.
    """
    # [B, C, H, W] -> [C, B, H, W] -> [C, B*H*W]: one row of pixels per band.
    perm = imbatch.permute(1, 0, 2, 3).flatten(start_dim=1)
    q = torch.as_tensor(q, dtype=perm.dtype, device=perm.device)
    return torch.quantile(perm, q, dim=1, interpolation='midpoint')

  
#------example use case
# define dataset and dataloader:
# dset = s2stats(root_dir='processed-data/dsen_2_256_split/*/*/')
# dloader = DataLoader(dset, batch_size=100, num_workers=0, pin_memory=True, shuffle=True)
#
# dataiter = iter(dloader)
# imgs = next(dataiter)
#
# qbatch = quant_batch(imgs.float())
# qbatch.shape
# output: torch.Size([2, 10]), e.g. index [0, 1] = lower (2%) quantile of band index 1


#----------mean_batch--------
# takes in tensor of shape [batchsize,band,H,W]
# turns +/-inf into nan so non-finite pixels are excluded from the mean
# calculates mean for each band
# returns mean for each band: torch.Size([band])
def mean_batch(batch):
    """Per-band mean of a ``[batchsize, band, H, W]`` tensor, ignoring non-finite pixels.

    +/-inf is rewritten to NaN, and ``torch.nanmean`` skips NaN, so only
    finite pixels contribute. A band with no finite pixel yields NaN.

    Integer input is cast to float32 first, because ``torch.nanmean`` only
    accepts floating-point (or complex) tensors. No preallocated ``out``
    buffer is used, so the result keeps the input's dtype and device.

    Returns:
        1-D tensor of length ``batch.shape[1]``.
    """
    if not (batch.is_floating_point() or batch.is_complex()):
        batch = batch.float()
    batch = batch.nan_to_num(nan=torch.nan, posinf=torch.nan, neginf=torch.nan)
    return torch.nanmean(batch, dim=(0, 2, 3))
#-----------std_batch----------
# takes in tensor of shape [batchsize,band,H,W]
# turns +/-inf into nan so non-finite pixels are excluded from the std
# calculates population std (numpy default ddof=0) for each band
# returns std for each band: torch.Size([band]), dtype float64
def std_batch(batch):
    """Per-band population standard deviation, ignoring non-finite pixels.

    The reduction is done with numpy's ``np.nanstd`` (default ``ddof=0``)
    into a float64 buffer. The tensor is detached and copied to the CPU
    first, because ``Tensor.numpy()`` refuses tensors that require grad or
    live on an accelerator. The result is moved back to ``batch.device``.

    Returns:
        1-D float64 tensor of length ``batch.shape[1]``.
    """
    finite = batch.detach().nan_to_num(nan=torch.nan, posinf=torch.nan, neginf=torch.nan)
    arr = finite.cpu().numpy()
    std = np.nanstd(arr, axis=(0, 2, 3), out=np.empty(arr.shape[1]))
    return torch.from_numpy(std).to(batch.device)

# temporary debug helper -- remove once no longer needed
def p(*args):
    """Print all positional arguments, packed together as one tuple."""
    print(repr(args))
        

In [70]:
# Exploratory: per-band std/mean of the FIRST batch only (the loop breaks
# after one iteration). The cell displays that batch's per-band std.
dset = s2stats(root_dir='processed-data/old_dsen_2_256_split/*/*/')
dloader = DataLoader(dset, batch_size=20, pin_memory=True, shuffle=False)

# NOTE(review): `mean += batch_mean` extends this list with one 0-d tensor
# per band; it does not accumulate means across batches. Confirm the
# intended aggregation before removing the `break`.
mean = []
Nbatch = 0  # counts images (batch.shape[0]), not batches
for batch in dloader:
    batch_std = std_batch(batch)
    batch_mean = mean_batch(batch)
    mean += batch_mean
    Nbatch += batch.shape[0]
    break

batch_std.float()

tensor([427.1248, 337.7532, 305.9428, 944.7454, 437.4391, 711.1324, 889.7389,
        959.4146, 727.0137, 625.2451])