# Aerial Greenery Image Segmentation Model E2E

This notebook is used to explore the step by step process for the aerial greenery image segmentation model

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# Add the parent directory to the path
import sys  
sys.path.insert(1, '../')

In [3]:
import os
import numpy as np
import webdataset as wds
from dotenv import load_dotenv

import torch
from torch.profiler import ProfilerActivity, profile, record_function
from torch.utils.data import DataLoader
#from torchvision.transforms import functional as F
#from pytorch_lightning import Trainer

import model.transforms as T
from model.aer_mae import AerMae
from model.aegis import AeGIS

In [4]:
load_dotenv()

True

## Load Dataset

In [None]:
root = os.getenv('DATA_ROOT')

In [5]:
root = 'E:/USGS/'

In [6]:
train_loc = f'file:{root}' + '/NCIR/finetune/train-{000000..000013}.tar'

In [7]:
mu = [0.6577]
sigma = [0.1994]
lens = [44_641, 14_823, 15_112]

In [8]:
ds_train = (wds.WebDataset(train_loc, shardshuffle=True)
      .shuffle(10_000)
      .decode("pil")
      .to_tuple("jpg", "json")
      .map(T.AerMaeTransforms(is_train=True, mu=mu, sigma=sigma)))

In [31]:
train_loader = DataLoader(ds_train, batch_size=1)

In [None]:
ds_train = (wds.WebDataset(train_loc, resampled=True, shardshuffle=True, nodesplitter=wds.split_by_node)
    .shuffle(10_000)
    .decode("pil")
    .to_tuple("jpg", "json")
    .map(T.AerMaeTransforms(is_train=True, mu=mu, sigma=sigma))
    .batched(2))

ds_loader = (wds.WebLoader(ds_train, batch_size=None, num_workers=1)
    .unbatched()
    .shuffle(10_000)
    .batched(2)
    .with_length(10))

In [41]:
images, loc = next(iter(ds_loader))
images.shape, loc.shape

(torch.Size([2, 1, 224, 224]), torch.Size([2, 4]))

## Sample Data

In [10]:
images, loc = next(iter(train_loader))
images.shape, loc.shape

(torch.Size([1, 1, 224, 224]), torch.Size([1, 4]))

## Create MAE Model

In [11]:
mae = AerMae(img_size=(224, 224),
             patch_size=8,
             enc_geo_dim=256,
             enc_dim=512,
             dec_geo_dim=256,
             dec_dim=512,
             enc_layers=4,
             dec_layers=4,
             enc_heads=4,
             dec_heads=4,
             ff_mul=4,
             mask_pct=0.75)

## Create AeGIS Model

In [12]:
aegis = AeGIS(mae)
loss_fn = torch.nn.BCELoss()

## Walk through AeGIS steps

In [16]:
img, mask = mae(images, loc)
img.shape, mask.shape

(torch.Size([1, 784, 64]), torch.Size([1, 784]))

In [17]:
y_hat = aegis.conv(img)
y_hat.shape

torch.Size([1, 784, 64])

In [18]:
y_hat = aegis.sigmoid(y_hat)
y_hat.shape

torch.Size([1, 784, 64])

In [19]:
print(y_hat[0])

tensor([[0.4566, 0.3834, 0.4420,  ..., 0.6241, 0.5860, 0.2956],
        [0.4246, 0.4680, 0.5195,  ..., 0.4820, 0.3555, 0.3049],
        [0.3582, 0.3053, 0.5433,  ..., 0.3149, 0.3831, 0.4082],
        ...,
        [0.3946, 0.3985, 0.5449,  ..., 0.4905, 0.4146, 0.3393],
        [0.3890, 0.4122, 0.5573,  ..., 0.4822, 0.4106, 0.3374],
        [0.3774, 0.5008, 0.6071,  ..., 0.4235, 0.3324, 0.3808]],
       grad_fn=<SelectBackward0>)


## Full pass through

In [26]:
y_hat = aegis(images, loc)
y_hat.shape

torch.Size([1, 784, 64])

Calculate loss

In [27]:
# TODO: need actual segmentations
true = torch.randint(0, 2, y_hat.shape).float()
true.shape

torch.Size([1, 784, 64])

In [28]:
loss = loss_fn(true, y_hat)
loss

tensor(50.0598, grad_fn=<BinaryCrossEntropyBackward0>)

## Cleanup

In [30]:
ds_train.close()