In [5]:
from pipelines.p_utils import read_raster
import matplotlib.pyplot as plt
import numpy as np
from glob import glob

In [63]:
def size_filter_func(image_path):
    shape = read_raster(image_path)[0].shape
    area = shape[0]*shape[1]
    if area > 1e6:
        return True
    else:
        return False
    
size_filter = lambda images: list(filter(size_filter_func, images))

def get_images(bands, year="*", month="*"):
    all_images = {
        band: glob(f"../data_source/WHOLE_LONDON_DATASET/*LONDON--{year}-{month}-*--*.{band}.tif") for band in bands
    }
    return {k: size_filter(v) for k, v in all_images.items()}

def get_image_middle(image):
    middle0 = image.shape[0]//2
    middle1 = image.shape[1]//2
    return middle0, middle1

def stack_image_middle(*images):
    max_size = np.array([i.shape for i in images]).max(axis=0)
    ref0, ref1 = max_size[0]//2, max_size[1]//2
    image_stack = []
    for image in images:
        image0, image1 = get_image_middle(image)
        pin0, pin1 = ref0 - image0, ref1 - image1
        canvas = np.zeros(max_size)
        canvas[pin0:pin0+image.shape[0],pin1:pin1+image.shape[1]] = image
        image_stack.append(canvas)
    return np.dstack(image_stack)

def sample_image(image, tl0, tl1, sample_width=16):
    return image[t0:t0+sample_width,t1:t1+sample_width]


image_tilt_ref = tilt_dict = {
    "bad": [
        "2020-04-22--10-58-07",
        "2020-06-25--10-58-18",
        "2020-08-12--10-58-33"
    ],
    "good": [
        "2014-04-15--10-52-21",
        "2014-07-04--10-52-12",
        "2020-04-15--10-51-59",
        "2020-06-02--10-51-54",
        "2020-07-20--10-52-17",
        "2020-08-05--10-52-21",
    ]
}


def tilt_correction(image_file_names, image_dict, image_ref = image_tilt_ref):
    
    for band, fns in image_file_names.items():
        for j, fn in enumerate(fns):
            if any([d in fn for d in image_ref["bad"]]):
                image_dict[band][j] = rotate(image_dict[band][j], -5)
            elif any([d in fn for d in image_ref["good"]]):
                pass
            else:
                raise ValueError(f"{fn} not catalogued")
                
    return image_dict


def construct_image_dict_from_file_dict(image_files, bands):
    images = {band: [read_raster(m)[0] for m in filenames] for band, filenames in image_files.items()}
    images = tilt_correction(image_files, images)
    stacked_images = {band: stack_image_middle(*imgs) for band, imgs in images.items()}
    return stacked_images


def construct_image_dict(bands, year="*", month="*"):
    image_files = get_images(bands, year, month)
    stacked_images = construct_image_dict_from_file_dict(image_files, bands)
    return stacked_images, image_files

In [48]:
from tqdm import tqdm
from scipy.ndimage import rotate

def is_inside_border(sample):
    if not any(np.isnan(sample.reshape(-1))) and (0 not in sample.reshape(-1)):
        return True
    else:
        return False
    

def locate_samples(num_samples, check_band, sample_size, image_dict):
    check_image = image_dict[check_band]
    samples = []
    for n in tqdm(range(num_samples)):
        while True:
            idx0 = np.random.randint(0, check_image.shape[0]-sample_size)
            idx1 = np.random.randint(0, check_image.shape[1]-sample_size)
            sample = check_image[idx0:idx0+sample_size,idx1:idx1+sample_size,:]
            if is_inside_border(sample):
                samples.append([idx0, idx1, sample_size])
                break
    return samples

In [49]:
from torch.utils.data import Dataset

class BaseSampleDatabase(Dataset):
    def __init__(
        self,
        input_dates,
        output_dates,
        agg_input,
        image_input,
        agg_output,
        sample_info,
        image_dict,
        filename_dict
    ):
        self.input_dates=input_dates
        self.output_dates=output_dates
        self.agg_input=agg_input
        self.image_input=image_input
        self.agg_output=agg_output
        self.sample_info=sample_info
        self.image_dict=image_dict
        self.filename_dict=filename_dict
        
        self.previously_notified_list = []
        self.seperate_image_dict()
        

    def __len__(self):
        return len(self.sample_info)
    
    @staticmethod
    def take_sample(image, sample, aggregate):
        sample_image = image[sample[0]:sample[0]+sample[-1],sample[1]:sample[1]+sample[-1],:]
        if aggregate:
            return np.mean(sample_image)
        else:
            return np.mean(sample, axis = -1)
        
    def seperate_image_dict(self):
        self.input_image_dict = {band: [] for band in self.image_dict.keys()}
        self.output_image_dict = {band: [] for band in self.image_dict.keys()}
        
        for band, stacked_image in self.image_dict.items():
            band_filenames = self.filename_dict[band]
            for j, fn in enumerate(band_filenames):
                fn = fn.split("LONDON--")[-1].split(".")[0]
                if fn in self.input_dates:
                    self.input_image_dict[band].append(stacked_image[:,:,j])
                elif fn in self.output_dates:
                    self.output_image_dict[band].append(stacked_image[:,:,j]) 
                else:
                    if fn not in self.previously_notified_list:
                        self.previously_notified_list.append(fn)
                        print(f"{fn} not in input/output list, not used!")
            self.input_image_dict[band] = np.dstack(self.input_image_dict[band])
            self.output_image_dict[band] = np.dstack(self.output_image_dict[band])
        
    def __getitem__(self, idx):
        sample_info = self.sample_info[idx]
        res = {"sample-coords": sample_info}
        
        for band in self.agg_input:
            res[f"{band}-in"] = self.take_sample(self.input_image_dict[band], sample_info, True)
        for band in self.agg_output:
            res[f"{band}-out"] = self.take_sample(self.output_image_dict[band], sample_info, True)
        for band in self.image_input:
            res[f"{band}-in"] = self.take_sample(self.input_image_dict[band], sample_info, False)

        return res

In [67]:
CHECK_BAND = "B10"
LST_BAND = "LST"
NDVI_BAND = "NDVI"
NDBI_BAND = "NDBI"
NDWI_BAND = "NDWI"
UI_BAND = "UI"

stacked_images, image_files = construct_image_dict(bands=all_bands, year="*", month="*")
sample_info = locate_samples(num_samples=500, check_band=CHECK_BAND, sample_size=16, image_dict=stacked_images)

dataset = BaseSampleDatabase(        
    input_dates=["2014-07-04--10-52-12","2014-04-15--10-52-21"],
    output_dates = ["2020-08-05--10-52-21","2020-07-20--10-52-17","2020-06-02--10-51-54","2020-04-22--10-58-07","2020-06-25--10-58-18","2020-08-12--10-58-33","2020-04-15--10-51-59"],
    agg_input=[NDVI_BAND, NDBI_BAND, NDWI_BAND, LST_BAND, UI_BAND],
    image_input=[],
    agg_output=[LST_BAND],
    sample_info=sample_info,
    image_dict=stacked_images,
    filename_dict=image_files
)

KeyboardInterrupt: 

In [51]:
class LearningOperation:
    
    def __init__(self, dataset, learner, input_vars, output_vars, **kwargs):
        
        self.dataset = dataset
        self.learner = learner
        self.input_vars = input_vars
        self.output_vars = output_vars
        self.__dict__.update(kwargs)
        
        unclaimed_vars = set(input_vars + output_vars) - set(dataset[0])
        if unclaimed_vars:
            raise ValueError(f"{unclaimed_vars} not in input or output vars")
    
    def generate_batch(self):
        pass
        
    def step(self):
        pass
    
    def predict(self, var_dict): 
        pass

In [61]:
from sklearn.linear_model import LinearRegression

class LinearRegressionOperation(LearningOperation):
    
    def __init__(self, dataset, input_vars, output_vars):
        learner = LinearRegression()
        super().__init__(dataset, learner, input_vars, output_vars)
        
    def generate_batch(self):
        dataset_dicts = [d for d in self.dataset]
        X_array = np.array([[d[iv] for iv in self.input_vars] for d in dataset_dicts])
        y_array = np.array([[d[ov] for ov in self.output_vars] for d in dataset_dicts])
        return X_array, y_array
        
    def step(self):
        self.learner.fit(*self.generate_batch())
        
    def train_score(self):
        return self.learner.score(*self.generate_batch())

In [62]:
lro = LinearRegressionOperation(
    dataset, input_vars = ["B2-in", "B3-in", "B4-in", "B10-in"], output_vars = ["B10-out"]
)

lro.step()
lro.train_score()

0.3343373312734619

In [None]:
old_image_tilt_ref = tilt_dict = {"bad": ["2018-05-19--10-57-32","2019-04-20--10-58-00","2017-06-17--10-58-17","2017-06-01--10-58-10","2018-07-06--10-57-40","2018-08-07--10-57-56","2016-08-17--10-58-39","2019-08-26--10-58-45","2019-07-25--10-58-34"],"good": ["2017-05-25--10-51-56","2018-06-29--10-51-26","2017-06-10--10-52-04","2018-07-15--10-51-33","2018-07-31--10-51-41","2016-04-20--10-51-56","2019-05-15--10-51-58","2017-08-13--10-52-26","2016-08-26--10-52-32"]}