In [None]:
#### Calculate quantiles for dataset
from tqdm.notebook import tqdm
from dataset.datasets import s2stats, sentinel
import torch
import numpy as np
import os


def my_quant(tens, hi=0.98, lo=0.02):
    """Return the ``(lo, hi)`` quantiles of all values in ``tens``.

    Matches ``torch.quantile(..., interpolation='midpoint')``. The rank
    ``q * (N - 1)`` is computed. When it falls between two positions of the
    sorted values, their average is returned. Only a single full sort is
    needed, which avoids ``torch.quantile``: that function rejects very large
    inputs.

    Args:
        tens: tensor of any shape (it is flattened) or array-like of numbers.
        hi: upper quantile level in [0, 1].
        lo: lower quantile level in [0, 1].

    Returns:
        Tuple ``(lo_value, hi_value)`` of Python numbers.

    Raises:
        ValueError: if ``tens`` is empty or a level lies outside [0, 1].
    """

    def get_indices(q):
        # An integral rank selects that single position. A fractional rank
        # selects the two neighbouring positions (q >= 0, so int() is floor()).
        lower = int(q)
        return [lower, lower + 1] if lower < q else [lower]

    def calc_quant(ind, sort):
        if len(ind) > 1:
            value = (sort[ind[0]] + sort[ind[1]]) / 2
        else:
            value = sort[ind[0]]
        # .item() works for CPU, CUDA and grad-tracking tensors alike,
        # unlike .numpy().
        return value.item()

    for level in (hi, lo):
        if not 0 <= level <= 1:
            raise ValueError(f"quantile level must be in [0, 1], got {level}")

    # Flatten so multi-dimensional input is treated as one population.
    # Otherwise torch.sort would sort only along the last dimension.
    flat = torch.as_tensor(tens).reshape(-1)
    if flat.numel() == 0:
        raise ValueError("my_quant() requires a non-empty tensor")

    sort, _ = torch.sort(flat)
    n = flat.shape[0] - 1  # largest valid index into `sort`
    ind_hi = get_indices(hi * n)
    ind_lo = get_indices(lo * n)
    return (calc_quant(ind_lo, sort), calc_quant(ind_hi, sort))

  



N_BANDS = 10  # number of bands per image; images are indexed img[band, :, :]

quanthi = []
quantlo = []
median_bands = []
dset = s2stats(root_dir='processed-data/dsen_2_256_new_split/timeperiod1/train/')

for band in tqdm(range(N_BANDS)):
    # Gather this band from every image, then concatenate once. Growing a
    # tensor with torch.cat inside the loop re-copies it on every image
    # (quadratic). .clone() keeps only the band, so the list does not hold
    # whole images alive.
    # NOTE(review): this assumes dset yields (image, other) pairs. The mean/std
    # cell iterates the same s2stats dataset as bare tensors; confirm which
    # one is right.
    band_pixels = [img[band, :, :].reshape(-1).clone() for img, _ in tqdm(dset)]
    if not band_pixels:
        raise ValueError('dataset is empty: no pixels to compute band statistics from')
    conc = torch.cat(band_pixels, 0)

    q = my_quant(conc)  # (lower, upper) quantiles, 0.02 / 0.98 by default
    quantlo.append(q[0])
    quanthi.append(q[1])
    median_bands.append(conc.median().item())

print("q_hi =", quanthi)
print('q_lo =', quantlo)
print('median =', median_bands)





In [None]:
# Calculate classCounts (count how many pixels of each class).

from dataset.utils import classCount,pNormalize
from torch.utils.data import DataLoader
import torch
from dataset.datasets import sentinel
# Per-band upper / lower quantile values used by pNormalize to rescale inputs.
# NOTE(review): these look like the q_hi / q_lo printed by the quantile cell
# (timeperiod1/train, levels 0.98 / 0.02). Re-run it and update them if the data changes.
# The datasets built below take root_dir pointing at the folder holding the
# timeperiod folder(s) with 'test', 'train' and 'val' subfolders. Per the
# earlier note, rgb=True selects 3 channels (default False).
q_hi = torch.tensor([
    2102.0, 1716.0, 1398.0, 4732.0,
    2434.42919921875, 3701.759765625, 4519.2177734375,
    4857.7734375, 3799.80322265625, 3008.8935546875,
])
q_lo = torch.tensor([
    102.0, 159.0, 107.0, 77.0,
    106.98081970214844, 79.00384521484375, 86.18966674804688,
    70.40167236328125, 50.571197509765625, 36.95356750488281,
])
norm = pNormalize(
    maxPer=q_hi,
    minPer=q_lo,
)

def get_set_classcounts(timeperiod=1,
                        root_dir='processed-data/dsen_2_256_new_split/',
                        batch_size=10,
                        num_workers=2):
    """Count pixels per class in the test, train and val splits of one time period.

    Each split is loaded as a ``sentinel`` dataset. The dataset is transformed
    with the module-level ``norm``, wrapped in a DataLoader and passed to
    ``classCount``.

    Args:
        timeperiod: time period forwarded to ``sentinel`` (1 or 2 in this notebook).
        root_dir: dataset root containing the time-period folder(s).
        batch_size: DataLoader batch size.
        num_workers: number of DataLoader worker processes.

    Returns:
        ``(test_counts, train_counts, val_counts)``: the first value returned
        by ``classCount`` for each split. The second value is discarded.
    """
    counts = []
    # Same order as the returned tuple.
    for split in ("test", "train", "val"):
        dataset = sentinel(root_dir=root_dir, img_transform=norm,
                           data=split, timeperiod=timeperiod)
        loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
        split_counts, _ = classCount(loader)
        counts.append(split_counts)
    return tuple(counts)

# Per-split class counts for both time periods; each call returns (test, train, val).
test_cc1, train_cc1, val_cc1 = get_set_classcounts(timeperiod=1)
test_cc2, train_cc2, val_cc2 = get_set_classcounts(timeperiod=2)

# classCounts[split][period] -> counts for that split and time period.
classCounts = {
    split: {'1': period1, '2': period2}
    for split, period1, period2 in (
        ('test', test_cc1, test_cc2),
        ('train', train_cc1, train_cc2),
        ('val', val_cc1, val_cc2),
    )
}

# Time period 1 only: classCounts1[split] -> counts.
classCounts1 = {split: per_period['1'] for split, per_period in classCounts.items()}

In [None]:
#### calculate std and mean for dataset
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader
from dataset.datasets import s2stats
import torch
import numpy as np
dset = s2stats(root_dir='processed-data/dsen_2_256_new_split/timeperiod1/train/')

loader = DataLoader(dset,
                    batch_size=1,
                    num_workers=0,
                    shuffle=False)
loader = tqdm(loader)

# Pooled per-channel statistics over every pixel in the split. Accumulating
# the pixel count, sum and sum of squares (in float64) gives the true dataset
# mean/std. Averaging per-image variances instead omits the variance of the
# per-image means and underestimates the dataset std.
channel_sum = 0.
channel_sqsum = 0.
npixels = 0
ninstance = 0
for images in loader:
    # NOTE(review): each loader item is treated as an image tensor of shape
    # [batch, band, H, W]. The quantile cell unpacks s2stats items as
    # (img, _) pairs; confirm which matches s2stats.
    batch_samples = images.size(0)  # the last batch can be smaller
    images = images.reshape(batch_samples, images.size(1), -1).double()

    # Skip batches (single images here, batch_size=1) that contain any inf/NaN
    # pixel rather than letting them poison the sums.
    if not torch.isfinite(images).all():
        continue

    channel_sum = channel_sum + images.sum(dim=(0, 2))
    channel_sqsum = channel_sqsum + images.square().sum(dim=(0, 2))
    npixels += batch_samples * images.size(2)
    ninstance += batch_samples

if npixels == 0:
    raise ValueError('no finite images found; cannot compute mean/std')

mean = channel_sum / npixels
# E[x^2] - E[x]^2, clamped at 0 to absorb tiny negative rounding errors.
var = (channel_sqsum / npixels - mean.square()).clamp_min(0)
std = torch.sqrt(var)

print('std = ', std.numpy())
print('mean =', mean.numpy())

In [None]:
## Examine weights 

import torch
from preprocess.classDict import class_dict
from dataset.stats import ce_weights,classCounts
counts = classCounts['train']

# Inverse-frequency weights: total count of classes 1..27 divided by
# (n_classes * per-class count).
n_samples = counts[1:28].sum()
n_classes = len(counts) - 3
weights = n_samples / (n_classes * counts)
# Classes with zero count would otherwise get inf (or NaN) weights.
weights[counts == 0] = 0

# Manual tweaks tried previously (disabled): weights[9] += 100,
# weights[18] += 10, weights[20] += 10, weights[23] += 10, weights[27] += 1

# Per class: index, stored ce_weights, freshly computed weights, class name.
for class_idx, class_info in enumerate(class_dict.values(), start=1):
    print(class_idx,
          round(ce_weights[class_idx].item(), 2), ' ',
          round(weights[class_idx].item(), 2), ' ',
          class_info['class_name'], '\n')


In [None]:
##  some functions
import torch

#------------quant_batch----------------------
def quant_batch(imbatch, q=torch.tensor([0.02, 0.98])):
    """Per-band quantiles of a batch of images (midpoint interpolation).

    Args:
        imbatch: tensor of shape [batch, band, H, W]. Inputs that are not
            float32/float64 (e.g. integer or half) are promoted to float32,
            because torch.quantile only accepts float/double.
        q: quantile level(s) in [0, 1], as a float, sequence or tensor. It is
            converted to the input's dtype and device, because
            torch.quantile requires ``q`` to match the input dtype.

    Returns:
        Tensor of shape [len(q), band], or [band] for a scalar ``q``. With the
        default ``q``, row 0 holds the 0.02 and row 1 the 0.98 quantile of
        each band.

    Note:
        torch.quantile rejects inputs that are too large, which bounds
        batch*H*W per band. The original note capped the batch size at 256.
    """
    # [batch, band, H, W] -> [band, batch, H, W] -> [band, batch*H*W]
    per_band = imbatch.permute(1, 0, 2, 3).flatten(start_dim=1)
    if per_band.dtype not in (torch.float32, torch.float64):
        per_band = per_band.float()
    levels = torch.as_tensor(q, dtype=per_band.dtype, device=per_band.device)
    return torch.quantile(per_band, levels, dim=1, interpolation='midpoint')

#------example use case
# Define the dataset and dataloader:
# dset= s2stats(root_dir='processed-data/dsen_2_256_split/*/*/')
# dloader = DataLoader(dset, batch_size=(100), num_workers=0,pin_memory=True,shuffle=True)
#
# dataiter = iter(dloader)
# imgs = next(dataiter)   # use built-in next(); DataLoader iterators have no .next() in recent PyTorch
#  
# qbatch = quant_batch(imgs.float())
# qbatch.shape
# output: torch.Size([2,10]); e.g. index [0,1] = lower quantile of band 1 (0-based, i.e. the second band)


#----------mean_batch--------
def mean_batch(batch):
    """Per-band mean of a [batch, band, H, W] tensor, ignoring non-finite values.

    +inf, -inf and NaN are all excluded, so a single corrupt pixel does not
    turn the whole band's mean into inf/NaN. A band with no finite values
    yields NaN. Integer input is promoted to float32.

    Returns:
        Tensor of shape [band], in the (floating) dtype of the input.
    """
    if not batch.is_floating_point():
        batch = batch.float()
    # Map +/-inf to NaN so that nanmean skips them along with existing NaNs.
    finite_or_nan = batch.nan_to_num(nan=torch.nan, posinf=torch.nan, neginf=torch.nan)
    return torch.nanmean(finite_or_nan, dim=(0, 2, 3))
#-----------std_batch----------
def std_batch(batch):
    """Per-band population std (numpy default ddof=0) of a [batch, band, H, W] tensor.

    +inf, -inf and NaN are excluded from the computation via np.nanstd. A band
    with no finite values yields NaN (numpy also emits a RuntimeWarning).

    Returns:
        float64 tensor of shape [band].
    """
    # Local import: this cell does not import numpy itself, so the function
    # should not depend on a name leaked from another cell.
    import numpy as np

    # Detach and move to CPU first, because .numpy() fails on CUDA tensors
    # and on tensors that require grad.
    b = (batch.detach().cpu()
         .nan_to_num(nan=torch.nan, posinf=torch.nan, neginf=torch.nan)
         .numpy())
    std = np.nanstd(b, axis=(0, 2, 3), out=np.empty(batch.shape[1]))
    return torch.from_numpy(std)
       

In [None]:
#### check inf values

import torch
from dataset.datasets import s2stats, sentinel
from tqdm.notebook import tqdm
from torch.utils.data import Dataset
import glob
import os
import h5py
import numpy as np


# Training split of time period 1. The s2stats variant is kept for reference:
# stat = s2stats(root_dir='processed-data/dsen_2_256_new_split/timeperiod1/train/')
sent = sentinel(
    root_dir='processed-data/dsen_2_256_new_split/timeperiod1/train/'
)

def isinf(img):
    """Return a 0-d bool tensor that is True if any element of img is +inf or -inf."""
    inf_mask = torch.isinf(img)
    return inf_mask.any()

statinf = []  # not filled in this cell
# Dataset indices of every image that contains at least one inf value.
sentinf = [ind for ind, (img, _) in enumerate(tqdm(sent)) if isinf(img)]

print('inf img from sentinel:\n', len(sentinf))

In [None]:
# Examine optuna parameters 

from math import floor,log
# Hyperband settings, named after Optuna's HyperbandPruner arguments
# (paper notation in brackets).
red_factor = 4                # reduction_factor [eta]
min_resource = 2              # min_resource [r]
max_resource = 50             # max_resource; max_resource / min_resource = R
min_early_stopping_rate = 1   # min_early_stopping_rate [s]; not used below

R = max_resource / min_resource

# Hyperband: s_max = floor(log_eta(R)). There are s_max + 1 brackets (the same
# count Optuna's HyperbandPruner derives), and the total budget is
# B = (s_max + 1) * R. The previous expression, floor(ln(max_resource)) * R + 1,
# used the natural log of max_resource and misplaced the +1.
s_max = floor(log(R, red_factor))
n_brackets = s_max + 1
B = n_brackets * R

