https://www.geocorner.net/post/artificial-intelligence-for-geospatial-analysis-with-pytorch-s-torchgeo-part-1

In [16]:
import os
import tempfile
from urllib.parse import urlparse
import math

import matplotlib.pyplot as plt
# import planetary_computer
# import pystac

%matplotlib inline
plt.rcParams["figure.figsize"] = (12, 12)

import torch
from torch.utils.data import DataLoader

In [2]:
from torchgeo.datasets import RasterDataset, stack_samples, unbind_samples
from torchgeo.datasets.utils import download_url
from torchgeo.samplers import RandomGeoSampler

In [3]:
# checking both installations
import rasterio as rio
import torchgeo

from pathlib import  Path
import xarray as xr
import matplotlib.pyplot as plt

## Prep data

### Get paths to the training images and masks

In [4]:
# Location of the training chips. The original local path stays the default so the
# notebook behaves as before; set IGUIDE_TRAINING_SAMPLES to point somewhere else
# without editing the code.
root = Path(os.environ.get(
    'IGUIDE_TRAINING_SAMPLES',
    r'C:\Users\wenxinyang\Desktop\Projects\iguide\training_samples',
))
# Fail early, naming the offending path, if the directory is missing.
assert root.exists(), f'training samples directory not found: {root}'

In [20]:
# Remove any image chips whose first band contains NaN values.
# NOTE: this deletes files from disk, so the cell is destructive; running it a second
# time on the same folder finds nothing left to remove.
images_dir = os.path.join(root, 'images')
n = 0
for filename in os.listdir(images_dir):
    # Only rasters are inspected; other directory entries are left alone.
    if not filename.endswith('.tif'):
        continue
    path_file = os.path.join(images_dir, filename)
    # The context manager closes the dataset even when read() raises, so no handle
    # is left open on the file that is about to be deleted.
    with rio.open(path_file) as file:
        # max() over the band is NaN as soon as a single pixel is NaN.
        test_value = file.read(1).max()
    if math.isnan(test_value):
        print(filename)
        n = n + 1
        os.remove(path_file)

print(n)

000000000039.tif
000000000133.tif
000000000344.tif
3


In [21]:
# File names (not full paths) of the '.tif' entries currently in the images folder.
img_ids = [x for x in os.listdir(os.path.join(root, 'images')) if x.endswith('.tif')]
len(img_ids)

415

In [22]:
# remove unmatched files
for filename in os.listdir(os.path.join(root, 'annotations')):
    if filename not in img_ids:
        print(filename)
        os.remove(os.path.join(root, 'annotations', filename))

000000000039.tif
000000000133.tif
000000000344.tif


In [25]:
# Full paths of every '.tif' image chip and every '.tif' annotation mask.
train_imgs = list((root/'images').glob('*.tif'))
train_masks = list((root/'annotations').glob('*.tif'))

In [26]:
# see how many chips we have 
len(train_masks)

415

In [27]:
len(train_imgs) == len(train_masks)

True

In [28]:
# Images and masks are paired by file name, so both lists are sorted in place to
# keep matching entries at the same index.
for paths in (train_imgs, train_masks):
    paths.sort()

In [29]:
path_imgs = (root/'images').as_posix()
# it is important to use a projected coordinate system
# (EPSG:26918 is NAD83 / UTM zone 18N; `res` is given in the units of that CRS)
train_ds = RasterDataset(path_imgs, res=30, crs = 'epsg:26918')

In [30]:
print(train_ds)

RasterDataset Dataset
    type: GeoDataset
    bbox: BoundingBox(minx=-1547075.146921655, maxx=-1437998.5057360458, miny=3426936.12106045, maxy=3526592.039209696, mint=0.0, maxt=9.223372036854776e+18)
    size: 415


In [31]:
# Random sampler yielding 10 bounding boxes per pass. The box printed further down
# spans 7650 CRS units per side, i.e. size (255) multiplied by the dataset res (30).
sampler = RandomGeoSampler(train_ds, size=(255, 255), length=10)

In [32]:
# sampler.areas

In [33]:
bbox = next(iter(sampler))

In [34]:
bbox

BoundingBox(minx=-1503342.7936045222, maxx=-1495692.7936045222, miny=3451224.8441685433, maxy=3458874.8441685433, mint=0.0, maxt=9.223372036854776e+18)

In [35]:
def scale(item: dict):
    """Divide the sample's 'image' entry by 10000 and return the same dict.

    The dict is updated in place; every other key is left untouched.
    """
    item.update(image=item['image'] / 10000)
    return item

In [37]:
# Two datasets on the same CRS and resolution: the image chips (passed through `scale`
# as their transform) and the annotation rasters (no transform).
# NOTE: `train_imgs` held the list of image paths until now; it is rebound here.
train_imgs = RasterDataset((root/'images').as_posix(), crs='epsg:26918', res=30, transforms=scale)
train_msks = RasterDataset((root/'annotations').as_posix(), crs='epsg:26918', res=30)

In [38]:
# Mark the annotation dataset as non-image data so its raster is returned under the
# 'mask' key (see the batch keys printed further down) instead of 'image'.
train_msks.is_image = False

In [39]:
# `&` combines the two datasets so one query returns both the image and the mask.
train_dset = train_imgs & train_msks

In [40]:
# NOTE: `sampler` was built from `train_ds` (images only), not from `train_dset`.
# With the sampler's length of 10 and batch_size=8, one pass yields two batches (8 + 2).
dataloader = DataLoader(train_dset, sampler=sampler, batch_size=8, collate_fn=stack_samples)

In [41]:
batch = next(iter(dataloader))
batch.keys()

dict_keys(['crs', 'bbox', 'image', 'mask'])

## Normalize data

In [42]:
import rasterio as rio

def calc_statistics(dset: RasterDataset):
    """
    Calculate per-band statistics (mean and std) over every file of the dataset.

    Each file is read in full, divided by 10000 and reduced to one mean and one std
    per band; those per-file values are then averaged over the number of files.

    Warning: This is an approximation. The correct value should take into account the
    mean for the whole dataset for computing individual stds.
    For correctness I suggest checking: http://notmatthancock.github.io/2017/03/23/simple-batch-stat-updates.html

    Args:
        dset: dataset whose rtree index lists the raster files to read.

    Returns:
        Tuple ``(mean, std)``: two vectors with one entry per band.

    Raises:
        ValueError: if the dataset index contains no files.
    """

    # To avoid loading the entire dataset in memory, we loop through each image.
    # The filenames are retrieved from the dataset's rtree index.
    files = [item.object for item in dset.index.intersection(dset.index.bounds, objects=True)]
    if not files:
        # Without this guard the final division fails with an opaque ZeroDivisionError.
        raise ValueError("dataset index contains no files; cannot compute statistics")

    # Resetting statistics
    accum_mean = 0
    accum_std = 0

    for file in files:
        # Close every raster as soon as it has been read, so handles do not
        # accumulate over the whole dataset.
        with rio.open(file) as src:
            img = src.read() / 10000  # type: ignore
        # Flatten the spatial axes once: one row of pixels per band.
        bands = img.reshape((img.shape[0], -1))
        accum_mean += bands.mean(axis=1)
        accum_std += bands.std(axis=1)

    # At the end we have 2 vectors of length n = number of bands;
    # average them over the number of images.
    return accum_mean / len(files), accum_std / len(files)

In [43]:
files = [item.object for item in train_imgs.index.intersection(train_imgs.index.bounds, objects = True)]

In [45]:
mean, std = calc_statistics(train_imgs)
print(mean, std)

[0.01187991 0.01294848 0.01124623 0.01377422] [0.00345638 0.00295712 0.00307989 0.00308207]


In [46]:
class MyNormalize(torch.nn.Module):
    """Per-channel standardisation of the ``'image'`` entry of a sample or batch dict.

    Only the first ``len(mean)`` channels are normalised; any further channels are
    left untouched. Both ``forward`` and ``revert`` write their result back into the
    tensor stored under ``inputs['image']`` and return the very same dict, i.e. the
    caller's tensor is modified in place.

    Args:
        mean: per-channel means.
        stdev: per-channel standard deviations.
    """

    def __init__(self, mean: list[float], stdev: list[float]):
        super().__init__()

        # Shape (C, 1, 1) so the statistics broadcast over the trailing (H, W) axes of
        # both a single (C, H, W) image and a (B, C, H, W) batch.
        # Registered as buffers so that ``.to(device)`` / ``.cuda()`` move them along
        # with the module; non-persistent so the state_dict stays empty as before.
        self.register_buffer("mean", torch.Tensor(mean)[:, None, None], persistent=False)
        self.register_buffer("std", torch.Tensor(stdev)[:, None, None], persistent=False)

    def forward(self, inputs: dict):
        """
        Normalise the leading channels of ``inputs['image']`` in place.
        Args:
            inputs (dict): Dictionary with the 'image' key
        Returns:
            The same dictionary object.
        """
        n_channels = len(self.mean)
        x = inputs["image"][..., :n_channels, :, :]

        # No batch/non-batch distinction is needed: the (C, 1, 1) statistics broadcast
        # against any number of leading axes.
        inputs["image"][..., :n_channels, :, :] = (x - self.mean) / self.std

        return inputs

    def revert(self, inputs: dict):
        """
        De-normalize the batch (inverse of ``forward``), in place.
        Args:
            inputs (dict): Dictionary with the 'image' key
        Returns:
            The same dictionary object.
        """
        n_channels = len(self.mean)
        x = inputs["image"][..., :n_channels, :, :]

        inputs["image"][..., :n_channels, :, :] = x * self.std + self.mean

        return inputs

In [49]:
normalize = MyNormalize(mean=mean, stdev=std)
# MyNormalize writes into batch['image'] in place and returns the same dict, so
# `norm_batch` and `batch` refer to one and the same object from here on.
norm_batch = normalize(batch)
# plot_batch(norm_batch)

# revert() also works in place: after this line `norm_batch` holds de-normalised
# values again, despite its name.
batch = normalize.revert(norm_batch)
# plot_batch(batch)

### Let's skip adding spectral index for now

In [56]:
from torchgeo.transforms import AppendNDVI

# Transform that appends an NDVI channel; it is only constructed here, not applied.
# NOTE(review): assumes channel 0 is red and channel 3 is NIR in these chips — confirm
# against the band order of the imagery.
ndvi_transform = AppendNDVI(index_red=0, index_nir=3)
# print(transformed_batch['image'].shape, transformed_batch['mask'].shape)

In [71]:
print(norm_batch['image'].shape, norm_batch['mask'].shape)

torch.Size([8, 4, 255, 255]) torch.Size([8, 1, 255, 255])


In [70]:
# Open one chip and close it again right away; nothing is read from it in this cell.
img0_path = os.path.join(root, 'images', '000000000017.tif')
img0 = rio.open(img0_path)
img0.close()

In [58]:
type(ndvi_transform)

torchgeo.transforms.indices.AppendNDVI

In [72]:
# transformed_batch = ndvi_transform(batch)

## Segmentation model

In [73]:
from torchvision.models.segmentation import deeplabv3_resnet50
# Two-class DeepLabV3. `weights=None` only skips pretrained weights for the full
# segmentation model; the download log below shows the ResNet-50 backbone checkpoint
# is still fetched (presumably via the `weights_backbone` default — confirm in the
# torchvision docs).
model = deeplabv3_resnet50(weights=None, num_classes=2)

model

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to C:\Users\wenxinyang/.cache\torch\hub\checkpoints\resnet50-0676ba61.pth
100%|█████████████████████████████████████████████████████████████████████████████| 97.8M/97.8M [00:01<00:00, 96.9MB/s]


DeepLabV3(
  (backbone): IntermediateLayerGetter(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Se

In [74]:
backbone = model.get_submodule('backbone')

In [75]:
# Replacement stem convolution: same geometry as the backbone's `conv1` printed above
# (7x7 kernel, stride 2, padding 3, no bias) but taking 4 input channels instead of 3.
conv = torch.nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)

In [76]:
# Replace the backbone's 3-channel `conv1` with the 4-channel layer built above.
# The new layer is freshly constructed, so its weights are newly initialised.
backbone.register_module('conv1', conv)

In [78]:
pred = model(torch.randn(3, 4, 255, 255))
pred['out'].shape

torch.Size([3, 2, 255, 255])

## Train loop

In [82]:
from typing import Optional, Callable

In [94]:
def train_loop(
    epochs: int, 
    train_dl: DataLoader, 
    val_dl: Optional[DataLoader], 
    model: torch.nn.Module, 
    loss_fn: Callable, 
    optimizer: torch.optim.Optimizer, 
    acc_fns: Optional[list]=None, 
    batch_tfms: Optional[Callable]=None
):
    """Train ``model`` for ``epochs`` epochs, optionally validating after each one.

    Args:
        epochs: number of passes over ``train_dl``.
        train_dl: iterable of batch dicts with 'image' and 'mask' entries; must
            support ``len()``.
        val_dl: optional validation batches with the same structure.
        model: module whose forward pass returns a dict with an 'out' entry.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: optimizer already bound to ``model``'s parameters.
        acc_fns: optional list of metrics ``fn(pred, y)``. They are evaluated only
            when ``val_dl`` is given as well; otherwise they are ignored.
        batch_tfms: optional callable applied to every batch dict before use.

    Returns:
        None. One summary line is printed per epoch. The model is moved to CUDA when
        available (CPU otherwise) and is left in whichever mode the last epoch ended in.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Module.to() moves the parameters in place, so the optimizer created by the
    # caller keeps pointing at the right tensors.
    cuda_model = model.to(device)

    for epoch in range(epochs):
        # (Re-)enter training mode every epoch: the validation pass below switches
        # the model to eval mode.
        cuda_model.train()
        accum_loss = 0
        for batch in train_dl:

            if batch_tfms is not None:
                batch = batch_tfms(batch)

            # Same float32 cast as in the validation pass, so both passes accept
            # the same input dtypes.
            X = batch['image'].type(torch.float32).to(device)
            y = batch['mask'].type(torch.long).to(device)
            pred = cuda_model(X)['out']
            loss = loss_fn(pred, y)

            # BackProp
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # update the accum loss (running mean over the batches of this epoch)
            accum_loss += float(loss) / len(train_dl)

        # Testing against the validation dataset
        if acc_fns is not None and val_dl is not None:
            # reset the accuracies metrics
            acc = [0.] * len(acc_fns)

            # Eval mode keeps BatchNorm statistics frozen and disables Dropout, so
            # validation neither depends on nor alters training-time behaviour.
            cuda_model.eval()
            with torch.no_grad():
                for batch in val_dl:

                    if batch_tfms is not None:
                        batch = batch_tfms(batch)

                    X = batch['image'].type(torch.float32).to(device)
                    y = batch['mask'].type(torch.long).to(device)

                    pred = cuda_model(X)['out']

                    for i, acc_fn in enumerate(acc_fns):
                        acc[i] = float(acc[i] + acc_fn(pred, y)/len(val_dl))

            # at the end of the epoch, print the errors, etc.
            print(f'Epoch {epoch}: Train Loss={accum_loss:.5f} - Accs={[round(a, 3) for a in acc]}')
        else:
            print(f'Epoch {epoch}: Train Loss={accum_loss:.5f}')

## Loss and accuracy functions

In [85]:
from sklearn.metrics import jaccard_score

def oa(pred, y):
    """Overall accuracy: fraction of pixels whose arg-max class equals the label.

    Args:
        pred: class scores with the class axis at dim 1.
        y: labels, either with a singleton axis at dim 1 (as many axes as ``pred``)
            or already without it.

    Returns:
        Scalar tensor with the accuracy.
    """
    # Drop only the singleton channel axis. A bare squeeze() would also remove a
    # batch or spatial axis of size 1, after which the comparison below broadcasts
    # and counts pixel pairs that do not belong together.
    labels = y.squeeze(1) if y.ndim == pred.ndim else y
    predicted = pred.argmax(dim=1)
    return torch.count_nonzero(labels == predicted) / torch.numel(labels)

def iou(pred, y):
    """Jaccard score between the arg-max class map of ``pred`` and the labels ``y``.

    Both inputs are moved to the CPU, converted to NumPy and flattened before being
    handed to ``jaccard_score`` (with ``zero_division=1.``).
    """
    labels = y.cpu().numpy().squeeze().reshape(-1)
    predicted = pred.argmax(dim=1).detach().cpu().numpy().reshape(-1)
    return jaccard_score(labels, predicted, zero_division=1.)

def loss(p, t):
    """Cross-entropy between class scores ``p`` and integer labels ``t``.

    Args:
        p: class scores with the class axis at dim 1.
        t: labels, either with a singleton axis at dim 1 (as many axes as ``p``)
            or already without it.

    Returns:
        Scalar loss tensor.
    """
    # Remove only the channel axis: a bare squeeze() would also drop the batch axis
    # of a single-sample batch and cross_entropy would then reject the target shape.
    target = t.squeeze(1) if t.ndim == p.ndim else t
    return torch.nn.functional.cross_entropy(p, target)

## Training

In [86]:
# Adam over all model parameters (including the replaced 4-channel conv1).
# NOTE(review): Adam's `weight_decay` is applied as an L2 penalty on the gradients;
# AdamW would be the decoupled variant — confirm which one is intended.
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.01)

In [93]:
# Two epochs on the training loader only. Because no validation loader is passed
# (val_dl=None), train_loop skips the metric pass and `acc_fns` has no effect —
# which is why the log below shows only the training loss.
train_loop(2, dataloader, None, model, loss, optimizer, 
           acc_fns=[oa, iou], batch_tfms=None)

Epoch 0: Train Loss=0.73250
Epoch 1: Train Loss=0.70910
