In [1]:
%matplotlib inline

import numpy as np
import h5py

import torch
from torch.nn import functional as F
from torch.autograd import Variable

##############################
from models import unet
from datasets import datasets
from utils import manager as mgr
from utils import img_utils
##############################

# experiment whose trained weights are evaluated in this notebook
exp_name = 'unet_df_030_001'
WEIGHTS_PATH = '/home/philipp/Data/weights/{}/'.format(exp_name)

# device the model (and its inputs) are placed on
device = "cpu"
print(device)

# model I/O: 7 stacked input bands (4 ortho bands + DSM + DTM + slope), 4 output classes
nr_channels = 7
nr_classes = 4

cpu


In [2]:
## load model with weights

In [3]:
model = unet.UNET(in_channels=nr_channels, out_channels=nr_classes)

try:
    mgr.load_weights(model, WEIGHTS_PATH + 'weights-19-0.217-0.821.pth')
    print("weights loaded")
except Exception as err:
    # Falling back to a random initialisation keeps the notebook running, but the
    # predictions are then meaningless -- so report *why* loading failed.
    # (A bare `except:` would also have swallowed KeyboardInterrupt / SystemExit.)
    model.apply(mgr.weights_init)
    print("no weights found: {!r}".format(err))

model.to(device)
# inference only: BatchNorm must use its stored running statistics instead of
# per-batch statistics (eval() returns the model, so the cell still displays it)
model.eval()

loading weights '/home/philipp/Data/weights/unet_df_030_001/weights-19-0.217-0.821.pth'
loaded weights (lastEpoch 18, loss 0.21727018058300018, error 0.8208388090133667)
weights loaded


UNET(
  (conv1): Sequential(
    (0): Conv2d(7, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (conv3): Sequential(
    (0): Conv2d(128, 256,

In [4]:
## load dataset

In [5]:
class ForestDataset(torch.utils.data.Dataset):

    '''Characterizes a dataset for PyTorch.

    Reads the 'ortho', 'dsm', 'dtm' and 'slope' arrays of one HDF5 file and
    returns every sample as one standardised float32 tensor whose channels are
    stacked in that order: the ortho bands, then DSM, DTM and slope.
    '''

    def __init__(self, path):
        '''Open the HDF5 file at `path` read-only and set the normalisation statistics.

        The file stays open until close() is called.
        '''
        # open dataset
        self.dset = h5py.File(path, 'r')
        self.ortho = self.dset['ortho']
        self.dsm = self.dset['dsm']
        self.dtm = self.dset['dtm']
        self.slope = self.dset['slope']

        # number of samples: read from the file (shortest of the four arrays)
        # instead of a hard-coded count that only matched one particular file
        self.dataset_size = min(len(arr) for arr in (self.ortho, self.dsm, self.dtm, self.slope))

        ## TODO:
        # make means and stds load from hdf5
        # per-band statistics of the 4 ortho bands (first three are shown as RGB in show_item)
        self.means_tams = np.array([56.12055784563426, 62.130400134006976, 53.03228547781888, 119.50916281232037], dtype='float32')
        self.stds_tams = np.array([30.37628560708646, 30.152693706272483, 23.13718651792004, 49.301477498205074], dtype='float32')

        self.means_dsm = np.array([13.45]).astype(np.float32)
        self.stds_dsm = np.array([10.386633252098674]).astype(np.float32)

        self.means_dtm = np.array([1446.0]).astype(np.float32)
        self.stds_dtm = np.array([271.05322202384195]).astype(np.float32)

        self.means_slope = np.array([22.39]).astype(np.float32)
        self.stds_slope = np.array([11.69830556896441]).astype(np.float32)


    def __len__(self):
        '''Denotes the total number of samples'''
        return self.dataset_size


    @staticmethod
    def _standardise(raw, means, stds):
        '''Turn one (H, W, C) raster into a (C, H, W) float32 tensor standardised per channel.'''
        channels_first = torch.tensor(raw, dtype=torch.float32).permute(2, 0, 1)
        mean = torch.as_tensor(means, dtype=torch.float32).view(-1, 1, 1)
        std = torch.as_tensor(stds, dtype=torch.float32).view(-1, 1, 1)
        return (channels_first - mean) / std


    def __getitem__(self, index):
        '''Return sample `index` as a float32 tensor of shape (C, H, W).

        Channels are concatenated as ortho bands, DSM, DTM, slope; each group is
        standardised with the statistics set in __init__ (7 channels when
        DSM/DTM/slope each have a single band).
        '''
        X_ortho = self._standardise(self.ortho[index], self.means_tams, self.stds_tams)
        X_dsm = self._standardise(self.dsm[index], self.means_dsm, self.stds_dsm)
        X_dtm = self._standardise(self.dtm[index], self.means_dtm, self.stds_dtm)
        X_slope = self._standardise(self.slope[index], self.means_slope, self.stds_slope)

        return torch.cat((X_ortho, X_dsm, X_dtm, X_slope), 0)


    def close(self):
        ''' closes the hdf5 file'''
        self.dset.close()


    def show_item(self, index):
        '''Plot RGB, CIR, DSM, DTM and slope of sample `index` in one figure.'''
        # `%matplotlib inline` only selects the backend and does not bind `plt`,
        # so without this import the method raised NameError
        import matplotlib.pyplot as plt

        fig = plt.figure(figsize=(20, 20))

        ortho = np.array(self.ortho[index])   # read the ortho bands once
        dic_data = {
            'RGB': [ortho[:, :, :3], [0.1, 0.3, 0.5, 0.7]],
            # rolling the bands by one moves the 4th band to the front (CIR composite)
            'CIR': [np.roll(ortho, 1, axis=2)[:, :, :3], [0.1, 0.3, 0.5, 0.7]],
            'DSM': [np.array(self.dsm[index].astype('f')), [10, 20, 30]],
            'DTM': [np.array(self.dtm[index].astype('f')), [10, 20, 30]],
            'Slope': [np.array(self.slope[index].astype('f')), [10, 20, 30]],
        }

        for i, key in enumerate(dic_data):
            ax = fig.add_subplot(2, 3, i + 1)
            plt.imshow(dic_data[key][0])
            ax.set_title(key)
            plt.colorbar(ticks=dic_data[key][1], orientation='horizontal')
            plt.axis('off')

In [6]:
## open the 512x512 prediction dataset
path_dataset = "/media/philipp/DATA/dataset/dataset_512_df_prediction.h5"

# the HDF5 file stays open for the rest of the notebook
dataset = ForestDataset(path=path_dataset)

In [7]:
# sanity check of the tile -> quarter-slot mapping used below:
# tile i owns the four consecutive slots 4*i .. 4*i + 3
for i, j in enumerate(range(0, 20 * 4, 4)):
    for quarter in range(4):
        print(i, j + quarter)

0 0
0 1
0 2
0 3
1 4
1 5
1 6
1 7
2 8
2 9
2 10
2 11
3 12
3 13
3 14
3 15
4 16
4 17
4 18
4 19
5 20
5 21
5 22
5 23
6 24
6 25
6 26
6 27
7 28
7 29
7 30
7 31
8 32
8 33
8 34
8 35
9 36
9 37
9 38
9 39
10 40
10 41
10 42
10 43
11 44
11 45
11 46
11 47
12 48
12 49
12 50
12 51
13 52
13 53
13 54
13 55
14 56
14 57
14 58
14 59
15 60
15 61
15 62
15 63
16 64
16 65
16 66
16 67
17 68
17 69
17 70
17 71
18 72
18 73
18 74
18 75
19 76
19 77
19 78
19 79


In [9]:
## prepare input data
# cut every 512x512 tile into 4 quarters: n x 7x512x512 -> 4n x 7x256x256
# slot order inside each group of 4: top-left, bottom-left, top-right, bottom-right
# (the merge cell below stitches the predictions back in this same order)

start = 0
end = 40

n = end - start
store = torch.zeros((n * 4, nr_channels, 256, 256), dtype=torch.float32)

for i in range(n):
    # read + normalise each tile from HDF5 once (not once per quarter), and
    # honour `start` -- previously dataset[i] was read regardless of it
    sample = dataset[start + i]
    j = 4 * i
    store[j] = sample[:, :256, :256]
    store[j + 1] = sample[:, 256:, :256]
    store[j + 2] = sample[:, :256, 256:]
    store[j + 3] = sample[:, 256:, 256:]

In [10]:
# get predictions
# eval(): BatchNorm must use its running statistics; in train mode it would
# normalise with the statistics of the current batch and update its buffers
model.eval()

# run in chunks to bound peak memory (a single 160-tile forward pass needs
# several GB of activations); in eval mode the chunk size should not change the
# result, assuming the model has no batch-dependent layers
BATCH_SIZE = 16
with torch.no_grad():
    output = torch.cat([model(batch.to(device)).cpu() for batch in store.split(BATCH_SIZE)])
p = mgr.get_predictions(output).numpy()

In [11]:
# one 256x256 class map per input quarter
assert p.shape == (4 * n, 256, 256), p.shape
p.shape

(160, 256, 256)

In [12]:
## prepare output data
# stitch the predicted quarters back together: 4n x 256x256 -> n x 512x512
# quarters of tile i are p[4i:4i+4] in the order top-left, bottom-left, top-right, bottom-right

# after this reshape the axes are (tile, column half, row half, y, x)
quarters = p[:n * 4].reshape(n, 2, 2, 256, 256)
# reorder to (tile, row half, y, column half, x) so rows/cols flatten to 512 each
pred = quarters.transpose(0, 2, 3, 1, 4).reshape(n, 512, 512).astype(np.int8)

In [13]:
# one stitched 512x512 class map per tile
assert pred.shape == (n, 512, 512), pred.shape
pred.shape

(40, 512, 512)

In [12]:
## de-normalise the NIR band of the first tiles (saved as nir.npy for visual checks)
NIR_BAND = 3   # channel 3 of a sample = 4th ortho band (the band moved first for the CIR view)
# use the exact statistics the dataset normalised with instead of copied magic numbers
nir_mean = float(dataset.means_tams[NIR_BAND])
nir_std = float(dataset.stds_tams[NIR_BAND])
# 256-px dataset layout: the 4 quarters of tile i are stored at i, i+10000, i+20000, i+30000
QUARTER_OFFSET = 10000
N_NIR_TILES = 20


def load_nir_tile(idx):
    """Return the de-normalised 512x512 NIR band of tile `idx` as a float64 array.

    The dataset opened above yields full 512x512 samples, which are used as is
    (the previous version of this cell assumed 256x256 samples and failed with a
    broadcast error on it). For a 256-px dataset the tile is still stitched from
    four quarters stored QUARTER_OFFSET apart, in the order top-left,
    bottom-left, top-right, bottom-right.
    """
    band = dataset[idx][NIR_BAND].numpy()
    if band.shape == (512, 512):
        tile = band.astype(np.float64)
    else:
        tile = np.empty((512, 512))
        tile[:256, :256] = band
        tile[256:, :256] = dataset[idx + 1 * QUARTER_OFFSET][NIR_BAND].numpy()
        tile[:256, 256:] = dataset[idx + 2 * QUARTER_OFFSET][NIR_BAND].numpy()
        tile[256:, 256:] = dataset[idx + 3 * QUARTER_OFFSET][NIR_BAND].numpy()
    return nir_std * tile + nir_mean


nirs = [load_nir_tile(i) for i in range(N_NIR_TILES)]

In [13]:
# stack the list of equally sized NIR tiles into one (tiles, rows, cols) array
nirs = np.asarray(nirs)

In [15]:
# persist the de-normalised NIR tiles next to the notebook
np.save(file='nir.npy', arr=nirs)

In [14]:
# persist the stitched class maps; reloaded by the GeoTIFF export below
np.save(file='pred.npy', arr=pred)

In [None]:
## convert to geotiff

In [1]:
import os
import numpy as np

from osgeo import gdal
from osgeo import gdal_array
from osgeo import osr

In [2]:
def array2raster(newRasterfn, dataset, array, dtype):
    """
    save GTiff file from numpy.array, reusing the georeferencing of `dataset`
    input:
        newRasterfn: save file name
        dataset : original tif file (opened GDAL dataset); its geotransform and
                  projection are copied to the new file
        array : numpy.array, (rows, cols) or (rows, cols, bands)
        dtype: GDAL data type name, e.g. "Byte" or "Float32"
    raises:
        ValueError: if `dtype` is not a GDAL data type name
        IOError: if the GTiff driver cannot create `newRasterfn`
    NOTE(review): assumes `array` lies on the same pixel grid as `dataset`
    (same origin and pixel size) -- confirm for the prediction tiles.
    """
    rows, cols = array.shape[0], array.shape[1]

    # any GDAL type name is accepted; unknown names used to be reported with a
    # print and then produced a GDT_Unknown raster
    GDT_dtype = gdal.GetDataTypeByName(dtype)
    if GDT_dtype == gdal.GDT_Unknown:
        raise ValueError("Not supported data type: {}".format(dtype))

    # set number of band.
    band_num = 1 if array.ndim == 2 else array.shape[2]

    driver = gdal.GetDriverByName('GTiff')
    outRaster = driver.Create(newRasterfn, cols, rows, band_num, GDT_dtype)
    if outRaster is None:
        # e.g. the target directory does not exist
        raise IOError("could not create {}".format(newRasterfn))

    # copy the full geotransform (origin, pixel size and rotation terms)
    outRaster.SetGeoTransform(dataset.GetGeoTransform())

    # setting srs from input tif file.
    outRasterSRS = osr.SpatialReference(wkt=dataset.GetProjection())
    outRaster.SetProjection(outRasterSRS.ExportToWkt())

    # write every band; a 2-D array is the single band itself
    for band_idx in range(band_num):
        band_data = array if band_num == 1 else array[:, :, band_idx]
        outRaster.GetRasterBand(band_idx + 1).WriteArray(band_data)

    # flush all bands (not only the last one) and close the file
    outRaster.FlushCache()
    outRaster = None

In [3]:
# stitched class maps written by the prediction part of this notebook
pred = np.load(file='pred.npy')

In [4]:
# source ortho tiles whose georeferencing is copied onto the predictions;
# NOTE(review): assumes the k-th ID is the tile stored at index k of the
# prediction dataset -- verify against how the HDF5 file was built
ORTHO_DIR = '/media/philipp/DATA/2018_tamsweg/ortho/'
TILE_IDS = [
    121889, 121890,
    122277, 122278, 122279, 122280,
    122647, 122648, 122649, 122650, 122651, 122652, 122653,
    123009, 123010, 123011, 123012, 123013, 123014, 123015,
]
path_input = [ORTHO_DIR + 'tile_ortho_{}.tif'.format(tile_id) for tile_id in TILE_IDS]

In [5]:
# the export below writes pred[i] for path_input[i]: every tile needs a prediction
assert len(path_input) <= len(pred), (len(path_input), len(pred))
len(path_input)

20

In [6]:
OUT_DIR = 'foto'
# the GTiff driver cannot create files inside a missing directory
os.makedirs(OUT_DIR, exist_ok=True)

for i, path_in in enumerate(path_input):
    path_out = os.path.join(OUT_DIR, 'pred_{}.tif'.format(i))
    src = gdal.Open(path_in, gdal.GA_ReadOnly)
    if src is None:
        # without exceptions enabled gdal.Open signals failure by returning None
        raise IOError('could not open {}'.format(path_in))
    array2raster(newRasterfn=path_out, dataset=src, array=pred[i], dtype='Byte')
    src = None   # release the GDAL handle before opening the next tile