# Segment TIFF images

In [36]:
import math
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import sklearn.model_selection as msel
import torch
import torch.nn as nn
from skimage import io
from tqdm import tqdm

In [2]:
LAYER_DIR = Path("..") / "data" / "layers"

## Look at image properties

For each layer print the height, width, data type, minimum, and maximum values.

In [3]:
FA = io.imread(LAYER_DIR / "fa.tif")
FA.shape, FA.dtype, FA.min(), FA.max()

((41668, 19981), dtype('float32'), 0.0, 3.4e+38)

In [4]:
SLOPE = io.imread(LAYER_DIR / "slope.tif")
SLOPE.shape, SLOPE.dtype, SLOPE.min(), SLOPE.max()

((41668, 19981), dtype('float32'), -3.4028235e+38, 0.9276394)

In [5]:
WETNESS = io.imread(LAYER_DIR / "wetness.tif")
WETNESS.shape, WETNESS.dtype, WETNESS.min(), WETNESS.max()

((41668, 19981), dtype('float32'), -3.4028235e+38, 53.34468)

In [6]:
DEM = io.imread(LAYER_DIR / "dem.tif")
DEM.shape, DEM.dtype, DEM.min(), DEM.max()

((41668, 19981), dtype('float32'), -3.4028235e+38, 53.149834)

In [7]:
LARV = io.imread(LAYER_DIR / "larv_spot_50m_correct.tif")
LARV.shape, LARV.dtype, LARV.min(), LARV.max()

((41668, 19981), dtype('float32'), 0.0, 3.4e+38)

It appears that "Not a Number" (NaN) values area represented by the largest or smallest float32 value.

Let's compare with the true minimum and maximum values for float32. None of the used map values are anywhere near to these extrema which makes them easy to identify and filter.

In [8]:
np.finfo(np.float32).min, np.finfo(np.float32).max

(-3.4028235e+38, 3.4028235e+38)

I'll give myself some wiggle room for the constants that I'll use to detect NA values.

In [9]:
NA_LO = -3.0e38
NA_HI = 3.0e38

I'll going to want to exclude tiles that have a majority of NaN values. For now I'll set this threshold to 50%.

## How many tiles can we actually use?

I only have one large image (with 4 layers) for training, validation, and testing. The strategy is to pretend that I've got several images by slicing the large image into several smaller images.

I'll start with an arbitrary tile size of 512 x 512 pixels height & width.

In [10]:
TILE_SIZE = 512

In [11]:
ROWS, COLS = FA.shape

## Segment the images

Dataset distribution strategy:
1. Slice the images into 81 rows of tile sized data.
2. Randomly assign the rows to the three datasets.
3. Keep the sets the same between runs by pinning random state.

here is a triangle at the top of the images that has no targets. Should I include that? For now, "Yes."

Using a 60/20/20% (train/val/test) split there will be 16 testing and validation stripes and (81 - 32 =) 49 training stripes.

#### Put image rows (stripes) into datasets

In [12]:
ALL_ROWS = list(range(ROWS // TILE_SIZE))

TRAIN_INDEXES, others = msel.train_test_split(
    ALL_ROWS, train_size=0.61, random_state=4486
)
VAL_INDEXES, TEST_INDEXES = msel.train_test_split(
    others, test_size=0.5, random_state=9241
)

TRAIN_INDEXES = sorted(TRAIN_INDEXES)
VAL_INDEXES = sorted(VAL_INDEXES)
TEST_INDEXES = sorted(TEST_INDEXES)

len(ALL_ROWS), len(TRAIN_INDEXES), len(VAL_INDEXES), len(TEST_INDEXES)

(81, 49, 16, 16)

#### Calculate validation and testing dataset tiles

I need to select a static set of tiles for the validation and testing datasets. I'll handle training data separately below.

In [13]:
NA_LIMIT = 0.5

A class for storing tiles.

In [14]:
@dataclass
class Tile:
    row: int  # Top
    col: int  # Left

How many pixels are NA values.

In [15]:
def has_data(row, col):
    tile = FA[row : row + TILE_SIZE, col : col + TILE_SIZE]
    flag = ((tile > NA_LO) & (tile < NA_HI)).any()
    return flag

Only choose tiles that have enough data tiles

In [16]:
def val_test_tiles(indexes):
    tiles = []
    has_blank = False
    for i in indexes:
        row = i * TILE_SIZE
        for col in range(0, COLS, TILE_SIZE):
            flag = has_data(row, col)
            if flag:
                tiles.append(Tile(row, col))
            elif not has_blank:
                tiles.append(Tile(row, col))
                has_blank = True
    return tiles

In [17]:
VAL_TILES = val_test_tiles(VAL_INDEXES)
print(len(VAL_TILES))

233


In [18]:
TEST_TILES = val_test_tiles(TEST_INDEXES)
print(len(TEST_TILES))

226


#### Organize the training dataset tiles

Group the training rows (stripes) and find all of the possible tiles. Note that there will be further augmentations during training.

In [19]:
def join_train_stipes():
    row_beg = TRAIN_INDEXES[0]
    row_end = row_beg + 1

    stripes = []

    for i in TRAIN_INDEXES[1:]:
        if i == row_end:
            row_end = i + 1
        else:
            stripes.append((row_beg, row_end))
            row_beg = i
            row_end = i + 1

    stripes.append((row_beg, i + 1))
    return stripes


stripes = join_train_stipes()

for i in TRAIN_INDEXES:
    print(i, end=" ")
print()

for i in stripes:
    print(i, end=" ")

0 3 4 5 6 9 10 13 14 16 17 19 21 22 23 26 28 30 31 32 36 37 38 41 42 43 44 45 48 49 52 53 54 57 59 60 62 63 65 66 67 68 69 73 74 75 76 78 79 
(0, 1) (3, 7) (9, 11) (13, 15) (16, 18) (19, 20) (21, 24) (26, 27) (28, 29) (30, 33) (36, 39) (41, 46) (48, 50) (52, 55) (57, 58) (59, 61) (62, 64) (65, 70) (73, 77) (78, 80) 

In [39]:
SKIP = 8


def train_tiles(stripes):
    tiles = []
    count = 0
    has_blank = False
    for beg, end in tqdm(stripes, position=0):
        top = beg * TILE_SIZE
        bot = end * TILE_SIZE
        for row in range(top, bot, SKIP):
            for col in range(0, COLS, SKIP):
                flag = has_data(row, col)
                if flag:
                    tiles.append(Tile(row, col))
                elif not has_blank:
                    tiles.append(Tile(row, col))
                    has_blank = True
    return tiles


TRAIN_TILES = train_tiles(stripes)
print(len(TRAIN_TILES))

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [22:56<00:00, 68.82s/it]

2893665





2.8 million is a wee bit of overkill for now, I'll dial that back in the next round. Data is going to be correlated so the huge number is meaningless and I should expect overfitting. We'll see.

## Build a data class etc.

I'll copy code over from other projects & tweak it for this project.

## A U-Net

In [None]:
class UNet(nn.Module):
    def __init__(
        self,
        in_channels: int = 4,
        out_channels: int = 1,
        features: int = 64,
    ):
        super().__init__()

        self.input = self.double_conv(in_channels, features)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.encoder1 = self.block(features, features * 2)
        self.encoder2 = self.block(features * 2, features * 4)
        self.encoder3 = self.block(features * 4, features * 8)
        self.encoder4 = self.block(features * 8, features * 16)

        self.bottleneck = nn.conv2d(features * 16, features * 16)

        self.unpool4 = nn.ConvTranspose2d(
            features * 16, features * 8, kernel_size=2, stride=2
        )
        self.decoder4 = self.block((features * 8) * 2, features * 8)

        self.unpool3 = nn.ConvTranspose2d(
            features * 8, features * 4, kernel_size=2, stride=2
        )
        self.decoder3 = self.block((features * 4) * 2, features * 4)

        self.unpool2 = nn.ConvTranspose2d(
            features * 4, features * 2, kernel_size=2, stride=2
        )
        self.decoder2 = self.block((features * 2) * 2, features * 2)

        self.unpool1 = nn.ConvTranspose2d(
            features * 2, features, kernel_size=2, stride=2
        )
        self.decoder1 = self.block(features * 2, features)

        self.output = nn.Conv2d(features, out_channels)

    def forward(self, x):
        x = self.input(x)
        enc1 = self.pool(self.encoder1(x))
        enc2 = self.pool(self.encoder2(enc1))
        enc3 = self.pool(self.encoder3(enc2))
        enc4 = self.pool(self.encoder4(enc3))

        x = self.bottleneck(enc4)

        x = self.unpool4(x)
        x = self.decoder4(torch.cat(x, enc4, dim=1))

        x = self.unpool3(x)
        x = self.decoder3(torch.cat(x, enc3, dim=1))

        x = self.unpool2(x)
        x = self.decoder2(torch.cat(x, enc2, dim=1))

        x = self.unpool1(bottleneck)
        x = self.decoder1(torch.cat(x, enc1, dim=1))

        x = self.output(x)
        return x

    def block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )