In [None]:
import os, sys, gdal
%matplotlib inline
import matplotlib.pylab as plt
import matplotlib.patches as patches
from skimage import exposure
import numpy as np
import pandas as pd

In [None]:
# Select the project dataset and time series data


In [None]:
# West Africa - Biomass Site
# Set the SAR_DATAPATH environment variable to use a local copy of the data
# instead of the default path below.
datapath = os.environ.get('SAR_DATAPATH', '/Volumes/Salo/Salo_ai/SAR_CH3/wa_v2/BIOsS1/')
# All files of this scene share one name stem; only polarization/suffix differ.
scene_stem = 'S32631X398020Y1315440sS1_A'
datefile = scene_stem + '_vv_0001_mtfil.dates'
imagefile_like = scene_stem + '_vv_0001_mtfil.vrt'
imagefile_cross = scene_stem + '_vh_0001_mtfil.vrt'

In [None]:
# Work from the data directory so the relative file names configured above resolve.
data_dir = datapath
os.chdir(data_dir)

In [None]:
def CreateGeoTiff(Name, Array, DataType, NDV, bandnames = None, ref_image = None,
                 GeoT = None, Projection = None):
    """Write a (bands, rows, cols) array to a GeoTIFF file.

    Parameters
    ----------
    Name : str
        Output file path.
    Array : numpy.ndarray
        2-D (rows, cols) or 3-D (bands, rows, cols) array. A 2-D array is
        written as a single band. The caller's array is not modified.
    DataType : int
        GDAL data type code, e.g. ``gdal.GDT_UInt16``.
    NDV : number
        No-data value. NaNs in a floating-point ``Array`` are written as
        ``NDV``, and ``NDV`` is registered as every band's no-data value.
    bandnames : sequence of str, optional
        One description per band; defaults to 'Band 1', 'Band 2', ...
    ref_image : str, optional
        Path of an image whose geotransform and projection are copied.
        Takes precedence over ``GeoT`` / ``Projection``.
    GeoT : sequence of 6 floats, optional
        GDAL geotransform; required when ``ref_image`` is not given.
    Projection : str, optional
        Projection WKT; required when ``ref_image`` is not given.

    Returns
    -------
    str
        ``Name``.

    Raises
    ------
    RuntimeWarning
        If neither ``ref_image`` nor both ``GeoT`` and ``Projection`` are given
        (exception type kept for existing callers).
    RuntimeError
        If the number of band names does not match the number of bands, or
        GDAL cannot open the reference image / create the output.
    """
    # A 2-D image is promoted to a one-band 3-D stack.
    if Array.ndim == 2:
        Array = Array[np.newaxis, ...]
    if ref_image is None and (GeoT is None or Projection is None):
        raise RuntimeWarning("ref_image or settings required.")
    nbands = Array.shape[0]
    if bandnames is not None:
        if len(bandnames) != nbands:
            raise RuntimeError('Need {} bandnames. {} given'
                               .format(nbands, len(bandnames)))
    else:
        bandnames = ['Band {}'.format(i + 1) for i in range(nbands)]
    if ref_image is not None:
        refimg = gdal.Open(ref_image)
        if refimg is None:
            raise RuntimeError('Cannot open reference image {}'.format(ref_image))
        GeoT = refimg.GetGeoTransform()
        Projection = refimg.GetProjection()
        refimg = None  # release the GDAL handle
    if np.issubdtype(Array.dtype, np.inexact):
        nan_mask = np.isnan(Array)
        if nan_mask.any():
            # Work on a copy so the caller's array keeps its NaNs.
            Array = Array.copy()
            Array[nan_mask] = NDV
    driver = gdal.GetDriverByName('GTiff')
    DataSet = driver.Create(Name, Array.shape[2], Array.shape[1], nbands, DataType)
    if DataSet is None:
        raise RuntimeError('GDAL could not create {}'.format(Name))
    DataSet.SetGeoTransform(GeoT)
    DataSet.SetProjection(Projection)
    for i, image in enumerate(Array, 1):
        band = DataSet.GetRasterBand(i)
        band.WriteArray(image)
        band.SetNoDataValue(NDV)
        # The description belongs to the band, not to the whole dataset.
        band.SetDescription(bandnames[i - 1])
    DataSet.FlushCache()
    DataSet = None  # dereferencing closes the file in the GDAL Python bindings
    return Name

In [None]:
def dualpol2rgb(like, cross, sartype = 'amp', ndv = 0):
    """Build a Like / Cross / Like-to-Cross-ratio stack from dual-pol SAR bands.

    Parameters
    ----------
    like, cross : numpy.ndarray
        Co-polarized and cross-polarized images of the same shape, in the
        units given by ``sartype``.
    sartype : {'amp', 'dB', 'pwr'}
        Units of the inputs. For the ratio computation 'amp' is converted to
        power as ``amp**2 * 10**-8.3`` and 'dB' as ``10**(dB / 10)``.
    ndv : number
        No-data value of ``cross`` (may be NaN). Those pixels, and pixels
        where the cross power is 0, get ``ndv`` in the ratio band.

    Returns
    -------
    rgb : numpy.ndarray
        (rows, cols, 3) stack of the unmodified ``like``, ``cross`` and the
        like/cross ratio. The ratio is ``sqrt(l/c)/10`` for 'amp',
        ``10*log10(l/c)`` for 'dB' and ``l/c`` for 'pwr'.
    bandnames : tuple of str
        ('Like', 'Cross', 'Ratio').
    sartype : str
        ``sartype`` passed through unchanged.

    Raises
    ------
    RuntimeError
        For an unknown ``sartype``.
    """
    if sartype not in ('amp', 'dB', 'pwr'):
        raise RuntimeError('invalid type {}'.format(sartype))
    CF = np.power(10., -8.3)
    if np.isnan(ndv):
        mask = np.isnan(cross)
    else:
        mask = np.equal(cross, ndv)

    l = np.ma.array(like, mask = mask, dtype = np.float32)
    c = np.ma.array(cross, mask = mask, dtype = np.float32)

    # Bring both channels to linear power ('pwr' already is).
    if sartype == 'amp':
        l = np.ma.power(l, 2.) * CF
        c = np.ma.power(c, 2.) * CF
    elif sartype == 'dB':
        l = np.ma.power(10., l / 10.)
        c = np.ma.power(10., c / 10.)

    # Like/cross ratio expressed in the input's units; masked-array division
    # masks pixels where c == 0.
    if sartype == 'amp':
        ratio = np.ma.sqrt(l / c) / 10
        ratio[np.isinf(ratio.data)] = 0.00001
    elif sartype == 'dB':
        ratio = 10. * np.ma.log10(l / c)
    else:
        ratio = l / c

    ratio = ratio.filled(ndv)

    rgb = np.dstack((like, cross, ratio))

    bandnames = ('Like', 'Cross', 'Ratio')
    return rgb, bandnames, sartype


def any2amp(raster,sartype = 'amp', ndv = 0):
    """Convert a SAR image to clipped 16-bit amplitude values.

    Parameters
    ----------
    raster : numpy.ndarray
        Input image in the units given by ``sartype``. It is not modified.
    sartype : {'amp', 'dB', 'pwr'}
        Units of ``raster``. Power is converted with
        ``amp = sqrt(pwr / 10**-8.3)`` and dB with ``amp = 10**((dB + 83) / 20)``.
    ndv : number
        No-data value of ``raster`` (may be NaN); those pixels become 0.

    Returns
    -------
    numpy.ndarray
        uint16 array: valid pixels clipped to [1, 65535], no-data and NaN
        pixels set to 0.

    Raises
    ------
    RuntimeError
        For an unknown ``sartype``.
    """
    CF = np.power(10., -8.3)
    if np.isnan(ndv):
        mask = np.isnan(raster)
    else:
        mask = np.equal(raster, ndv)

    if sartype == 'pwr':
        amp = np.sqrt(raster / CF)
    elif sartype == 'dB':
        amp = np.power(10., (raster + 83) / 20.)
    elif sartype == 'amp':
        # Float copy: callers pass views such as rgb[:, :, i], which must not
        # be clipped in place.
        amp = np.array(raster, dtype = np.float64)
    else:
        raise RuntimeError('invalid type {}'.format(sartype))

    # 0 is reserved for no-data, so valid pixels are clipped to [1, 65535].
    amp = np.clip(amp, 1, 65535)
    # NaN (e.g. sqrt of negative power) has no defined uint16 value.
    amp[mask | np.isnan(amp)] = 0
    return amp.astype(np.uint16)

In [None]:
# Get the date indices via pandas: one acquisition date per line; blank lines
# (e.g. a trailing newline at end of file) are skipped.
with open(datefile) as date_fh:
    dates = [line.strip() for line in date_fh if line.strip()]
tindex = pd.DatetimeIndex(dates)
print('Band and dates for', imagefile_like)
# Band j of the image stack corresponds to the j-th date; print five per row.
for j, stamp in enumerate(tindex, 1):
    print('{:4d} {}'.format(j, stamp.date()), end = ' ')
    if j % 5 == 0:
        print()

In [None]:
#  PICK A BAND NUMBER
# 1-based GDAL band index (used with GetRasterBand); band N is the N-th date listed above.
bandnbr = 1

In [None]:
img_like = gdal.Open(imagefile_like)
img_cross = gdal.Open(imagefile_cross)
# Without gdal.UseExceptions(), gdal.Open returns None instead of raising.
if img_like is None:
    raise FileNotFoundError('GDAL could not open {}'.format(imagefile_like))
if img_cross is None:
    raise FileNotFoundError('GDAL could not open {}'.format(imagefile_cross))
# Get Dimensions (bands, lines, pixels)
print('Likepol ', img_like.RasterCount, img_like.RasterYSize, img_like.RasterXSize)
print('Crosspol', img_cross.RasterCount, img_cross.RasterYSize, img_cross.RasterXSize)
# Both polarizations are later read with the same window, so they must align.
if (img_like.RasterXSize, img_like.RasterYSize) != (img_cross.RasterXSize, img_cross.RasterYSize):
    raise ValueError('Like- and cross-pol images differ in size')

In [None]:
# Processing window (xoff, yoff, xsize, ysize) in pixels; None means the full scene.
subset = None
# subset = (3500, 1000, 500, 500) #(xoff, yoff, xsize, ysize)
full_scene = (0, 0, img_like.RasterXSize, img_like.RasterYSize)
if subset is None:
    subset = full_scene

raster = img_like.GetRasterBand(bandnbr).ReadAsArray()
fig, ax = plt.subplots(figsize = (8,8))
ax.set_title('Likepol full image {}'
            .format(tindex[bandnbr-1].date()))
ax.imshow(raster, cmap = 'gray', vmin = np.nanpercentile(raster, 5), vmax = np.nanpercentile(raster, 95))
ax.grid(color = 'blue')
ax.set_xlabel('Pixels')
ax.set_ylabel('Lines')
# Outline the subset as a rectangle, unless the whole scene is being processed
if tuple(subset) != full_scene:
    _ = ax.add_patch(patches.Rectangle((subset[0], subset[1]), subset[2], subset[3],
                                       fill = False, edgecolor = 'red', linewidth = 3))

In [None]:
# Read the chosen band of both polarizations over the same (xoff, yoff, xsize, ysize) window.
raster_like, raster_cross = (
    img.GetRasterBand(bandnbr).ReadAsArray(*subset) for img in (img_like, img_cross)
)

In [None]:
# Inputs are treated as amplitude with 0 as no-data (the function's defaults, spelled out).
rgb, bandnames, sartype = dualpol2rgb(like = raster_like, cross = raster_cross,
                                      sartype = 'amp', ndv = 0)

In [None]:
# For each band, apply histogram equalization; zero (no-data) pixels are
# excluded from the histogram via the mask.
rgb_float = rgb.astype('float32')
rgb_stretched = np.dstack([
    exposure.equalize_hist(rgb_float[:, :, k], mask = rgb_float[:, :, k] != 0).astype('float32')
    for k in range(rgb_float.shape[2])
])

In [None]:
# Summarize instead of dumping the whole array into the notebook output.
rgb_stretched.shape, rgb_stretched.dtype, float(np.nanmin(rgb_stretched)), float(np.nanmax(rgb_stretched))

In [None]:
hist_range = (0, 10000)  # value range shown in the histograms (input units)
channel_colors = ('red', 'green', 'blue')

fig, ax = plt.subplots(1, 2, figsize = (16, 8))
fig.suptitle('Sentinel-1 dual-pol backscatter image {} R:{} G:{} B:{}'
             .format(tindex[bandnbr-1].date(), bandnames[0], bandnames[1], bandnames[2]))
for k, color in enumerate(channel_colors):
    channel = rgb[:, :, k]
    # Leave out zero (no-data) and non-finite pixels so they don't distort the histogram.
    valid = channel[np.isfinite(channel) & (channel != 0)]
    ax[0].hist(valid, histtype = 'step', color = color,
               bins = 100, range = hist_range, label = bandnames[k])
ax[0].set_title('Histograms')
ax[0].set_xlabel('Pixel value')
ax[0].set_ylabel('Count')
ax[0].legend()
ax[1].imshow(rgb_stretched)
ax[1].set_title('Histogram Equalized')
_ = ax[1].axis('off')

In [None]:
proj = img_like.GetProjection()
gt = img_like.GetGeoTransform()

# Map the subset's upper-left pixel (subset[0], subset[1]) through the full
# affine geotransform, including the rotation terms gt[2] and gt[4]:
#   X = gt[0] + px*gt[1] + line*gt[2];  Y = gt[3] + px*gt[4] + line*gt[5]
subset_xoff = gt[0] + subset[0] * gt[1] + subset[1] * gt[2]
subset_yoff = gt[3] + subset[0] * gt[4] + subset[1] * gt[5]
# Pixel size and rotation are unchanged; only the origin moves.
geotrans = (subset_xoff, gt[1], gt[2], subset_yoff, gt[4], gt[5])
geotrans

In [None]:
# Scale each layer to 16-bit amplitude, interpreting it in the same units
# (sartype) that dualpol2rgb was called with.
outbands = [any2amp(rgb[:, :, i], sartype = sartype) for i in range(rgb.shape[2])]

imagename = imagefile_like.replace('_vv_', '_lcr_').replace('.vrt', '_{}.tif'.format(dates[bandnbr-1].rstrip()))
# Reuse the band names returned by dualpol2rgb rather than repeating them.
bandnames = list(bandnames)
Array = np.array(outbands)
CreateGeoTiff(imagename, Array, gdal.GDT_UInt16, 0, bandnames, GeoT = geotrans, Projection = proj)