In [None]:
from fastai.vision.all import *
from fastai.torch_core import TensorMask, TensorImage

import rasterio as rio

import timm
import torch
import fastai
from rasterio.enums import Resampling
import torch.multiprocessing
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

In [None]:
# Download the training data from this shared Google Drive folder before running the notebook:
# https://drive.google.com/drive/folders/1Z60g3SxBiSTEZzuHDPOUuNL7I_Mg_dGq?usp=share_link

In [None]:
# Record the torch and fastai versions and fastai's default device for this run.
print(torch.__version__)
print(fastai.__version__)
print(default_device())

In [None]:
# Run configuration. "model_type" names the timm backbone and is part of the
# saved model name.
# NOTE(review): "fp_16" and "lr" are not read by any visible cell — the learner
# is always built with .to_fp16() and fine_tune() is called without a learning
# rate. Confirm whether these keys are meant to drive those calls.
model_details = {"model_type": "regnety_006", "fp_16": True, "lr": 1e-3}

In [None]:
# Keep fastai's default device in a variable and display it.
device = default_device()
device

In [None]:
model_version = "1.17"  # version tag embedded in the saved file names
model_name = f"{model_details['model_type']}_v{model_version}_model"  # used for the checkpoint and the pickled model
temp_file_name = model_name + "_temp"  # not referenced by the visible cells
model_name, temp_file_name

In [None]:
# Layout of the stacked input channels: `time_steps` groups of
# `bands_per_timestep` consecutive bands (this is how Switcheroo indexes them).
time_steps = 6
bands_per_timestep = 2  # provisional; recomputed from the band list in a later cell

In [None]:
# Band indexes handed to rasterio's `read` when loading an image: every second
# band from 2 through 24 (12 bands in total); the odd-numbered bands are left
# out. Replace the range with an explicit list to pick a different subset.
limited_band_read_list = list(range(2, 25, 2))

In [None]:
# Derive the bands-per-time-step count from the selected band list so the two
# settings cannot drift apart. Fail fast when the bands do not divide evenly:
# a truncated quotient would make Switcheroo index fewer channels than the
# images actually contain.
if len(limited_band_read_list) % time_steps != 0:
    raise ValueError(
        f"{len(limited_band_read_list)} selected bands cannot be split evenly "
        f"across {time_steps} time steps"
    )
bands_per_timestep = len(limited_band_read_list) // time_steps
bands_per_timestep

In [None]:
# Locations of the training images and their label rasters.
# NOTE(review): `path` is an absolute path to a local mount — change it to
# wherever the downloaded training data lives on your machine.
path = Path("/media/nick/SNEAKERNET/training data")
label_path = path / "labels_2_3_4_8_V3"
images_path = path / "images_2_3_4_8_V3"
print(label_path.exists(), images_path.exists())  # reports whether each folder exists

In [None]:
def label_func(label_path, file_path):
    """Return the label file for an image: same file name, inside `label_path`.

    Args:
        label_path: Directory that holds the label rasters.
        file_path: Path of the image; only its `.name` is used.

    Returns:
        `label_path / file_path.name`.
    """
    return label_path / file_path.name

In [None]:
def get_image_files_custom(source, p=False):
    """List the `.tif` files directly inside `source`, skipping dot-files.

    Args:
        source: Directory (a `Path`) to search; sub-directories are not visited.
        p: Unused; kept so existing callers keep working.

    Returns:
        A list of paths whose names end in `.tif` and do not start with ".".
        No other filtering is applied.
    """
    return [tif_file for tif_file in source.glob("[!.]*.tif")]

In [None]:
# Collect every training image path once and show how many were found.
f_names = get_image_files_custom(images_path)
len(f_names)

In [None]:
# Show the first discovered image path.
f_names[0]

In [None]:
# Hold-out set: every image whose file name contains "Validation".
validation_paths = [img_file for img_file in f_names if "Validation" in img_file.name]
len(validation_paths)

In [None]:
def is_valid_file(x, validation_paths):
    """Tell whether `x` belongs to the validation set.

    Args:
        x: Item to look up (an image path when used by the splitter).
        validation_paths: Container of validation items.

    Returns:
        True when `x` is contained in `validation_paths`, otherwise False.
    """
    in_validation = x in validation_paths
    return in_validation

In [None]:
# Scratch check: shape of a 12-channel, 128 x 128 tensor.
torch.zeros(12, 128, 128).shape

In [None]:
class Switcheroo(RandTransform):
    "Shuffle whole time steps along dim 1, leaving the band order inside each time step untouched"

    # split_idx = 0: fastai applies the transform to the training split only.
    split_idx, order = 0, 2

    def __init__(self, p=1, bands_per_timestep=2, time_steps=3):
        super().__init__(p=p)
        self.bands_per_timestep = bands_per_timestep
        self.time_steps = time_steps

    def encodes(self, x: (TensorImage)):
        # One random ordering of the time steps per call; it is applied along
        # dim 1, so every item along dim 0 gets the same ordering.
        shuffled_steps = torch.randperm(self.time_steps).tolist()
        channel_order = []
        for step in shuffled_steps:
            first_band = step * self.bands_per_timestep
            channel_order.extend(range(first_band, first_band + self.bands_per_timestep))
        return x[:, channel_order]

In [None]:
class BatchRot90(RandTransform):
    "Rotate image and mask by 0, 90, 180, or 270 degrees"
    # split_idx = 0: fastai applies the transform to the training split only.
    split_idx, order = 0, 2

    def __init__(self, p=1):
        super().__init__(p=p)
        # Number of quarter turns to apply. Initialised under the same name
        # that `before_call` and `encodes` use, so `encodes` is safe to call
        # even before the first `before_call`.
        self.rot = 0

    def before_call(self, b, split_idx):
        # Drawn once per call, ahead of `encodes`, so the image and the mask of
        # a batch are rotated by the same amount. With probability 1 - p the
        # batch is left unrotated.
        if random.random() < self.p:
            self.rot = random.choice([0, 1, 2, 3])
        else:
            self.rot = 0

    def encodes(self, x: (TensorImage, TensorMask)):
        # Rotate in the plane of the last two dims (height, width).
        return x.rot90(self.rot, [-2, -1])

In [None]:
def open_mask(img_path, img_size):
    """Load band 1 of a raster as a `TensorMask`.

    Args:
        img_path: Path of the raster file to open.
        img_size: Side length; the band is read into an
            `img_size` x `img_size` array using nearest resampling.

    Returns:
        A `TensorMask` wrapping the band that was read.
    """
    target_shape = (img_size, img_size)
    with rio.open(img_path) as dataset:
        band = dataset.read(1, out_shape=target_shape, resampling=Resampling.nearest)
    mask_tensor = torch.from_numpy(band)
    return TensorMask(mask_tensor)

In [None]:
# Display the device selected earlier.
device

In [None]:
def open_sar(img_path, img_size):
    """Read the SAR raster that shares `img_path`'s file name, as a NumPy array.

    NOTE(review): `sar_path` is a module-level directory that no visible cell
    defines — it must exist before this function is called.

    Args:
        img_path: Path of the optical image; only its `.name` is used to find
            the SAR file inside `sar_path`.
        img_size: Side length passed to rasterio as `out_shape`.

    Returns:
        The bands read from the SAR file (nearest resampling), or a
        single-band `img_size` x `img_size` array of float32 zeros when the
        file does not exist. Both cases return a NumPy array; previously the
        fallback was a `torch.Tensor` while the normal path returned NumPy.
    """
    sar_img_path = sar_path / img_path.name
    if not sar_img_path.exists():
        # Same zeros as before, converted so both branches return one type.
        return torch.zeros(1, img_size, img_size).numpy()
    with rio.open(sar_img_path) as src:
        return src.read(
            out_shape=(img_size, img_size),
            resampling=Resampling.nearest,
        )

In [None]:
# In-memory cache of loaded image tensors; filled and consulted by `open_img`.
image_cache = {}

In [None]:
def open_img(img_path, img_size):
    """Load the selected bands of an image as a float16 `TensorImage` on the GPU.

    Results are memoised in `image_cache`. The cache key includes `img_size`:
    keyed by path alone, a request for a second size returned the tensor
    cached for the first one (e.g. 128 px tensors from the statistics loop
    being handed to a 256 px DataBlock).

    Args:
        img_path: Path of the raster file to open.
        img_size: Side length; bands are read into `img_size` x `img_size`.

    Returns:
        A CUDA `TensorImage` holding the bands listed in
        `limited_band_read_list`, stored as float16.
    """
    cache_key = (img_path, img_size)
    if cache_key in image_cache:
        return image_cache[cache_key]
    with rio.open(img_path) as src:
        raw_bands = src.read(limited_band_read_list, out_shape=(img_size, img_size))

    # `astype("float16")` already yields half precision, so no further cast is
    # needed before moving the tensor to the GPU.
    tensor_img = TensorImage(torch.from_numpy(raw_bands.astype("float16"))).cuda()
    # Every distinct (path, size) pair stays resident on the GPU once loaded.
    image_cache[cache_key] = tensor_img
    return tensor_img

In [None]:
# Smoke test: load the first image at 128 px and show the tensor shape.
open_img(f_names[0], 128).shape

In [None]:
# Display the derived number of bands per time step.
bands_per_timestep

In [None]:
# Per-channel normalisation statistics. Each image is loaded at 128 px, scaled
# by 1/32767, and reduced to one mean and one std per channel; the per-image
# values are then averaged over all images.
all_means, all_stds = [], []
for img_file in tqdm(f_names):
    scaled_image = open_img(img_file, 128) / 32767
    all_means.append(scaled_image.mean((1, 2)).tolist())
    all_stds.append(scaled_image.std((1, 2)).tolist())
all_means = np.array(all_means).mean(0)
all_stds = np.array(all_stds).mean(0)

In [None]:
# Display the per-channel means.
all_means

In [None]:
# Display the per-channel standard deviations.
all_stds

In [None]:
def build_dblock(img_size):
    """Assemble the `DataBlock` that pairs each image with its label at `img_size`.

    Reads these module-level names: `label_path`, `validation_paths`,
    `bands_per_timestep`, `time_steps`, `all_means` and `all_stds`.

    Args:
        img_size: Side length used when loading images and labels, and as the
            output size of the augmentation transforms.

    Returns:
        A `DataBlock` whose items come from `get_image_files_custom`, whose
        labels are looked up by file name in `label_path`, and whose validation
        split is the set of paths in `validation_paths`.
    """
    load_image = partial(open_img, img_size=img_size)
    load_label = partial(open_mask, img_size=img_size)
    find_label = partial(label_func, label_path)

    def in_validation_set(item):
        # `validation_paths` is looked up when the splitter runs, not here.
        return is_valid_file(item, validation_paths)

    augmentations = aug_transforms(
        flip_vert=True,
        max_rotate=0,
        max_zoom=0.2,
        max_lighting=0.2,
        max_warp=0,
        p_affine=0,
        p_lighting=0.2,
        size=img_size,
    )
    # Order matters: scale to float first, augment, then normalise last.
    batch_pipeline = [
        IntToFloatTensor(32767, 1),
        *augmentations,
        BatchRot90(),
        Switcheroo(bands_per_timestep=bands_per_timestep, time_steps=time_steps),
        Normalize.from_stats(mean=all_means, std=all_stds),
    ]

    return DataBlock(
        blocks=(TransformBlock(load_image), TransformBlock(load_label)),
        get_items=get_image_files_custom,
        get_y=find_label,
        splitter=FuncSplitter(in_validation_set),
        batch_tfms=batch_pipeline,
    )

In [None]:
# Build the DataLoaders at 256 px with batch size 16. num_workers=0 keeps
# loading in the main process — presumably because open_img returns cached
# CUDA tensors; confirm before raising it.
dl = build_dblock(256).dataloaders(images_path, bs=16, num_workers=0)

In [None]:
# Pull one batch to inspect what the pipeline produces.
ob = dl.one_batch()

In [None]:
# Shape of the input tensor in that batch.
ob[0].shape

In [None]:
# Backbone factory handed to unet_learner: a pretrained timm model of the
# configured type whose `in_chans` equals dim 1 of an input batch from `dl`.
timm_model = partial(
    timm.create_model,
    model_details["model_type"],
    pretrained=True,
    in_chans=dl.one_batch()[0].shape[1],  # draws another batch just to read its channel count
)

In [None]:
# U-Net learner on the timm backbone: one output channel, flattened MSE loss,
# converted to mixed precision with to_fp16().
learner = unet_learner(
    dl, timm_model, pretrained=True, loss_func=MSELossFlat(), n_out=1
).to_fp16()

In [None]:
# Training callbacks: checkpoint the model (with optimizer state) under
# `model_name` while monitoring valid_loss, and plot the losses during training.
cbs = [
    SaveModelCallback(monitor="valid_loss", fname=model_name, with_opt=True),
    ShowGraphCallback(),
]

In [None]:
# Train with fine_tune: 50 epochs in the frozen phase, then 200 more epochs.
# NOTE(review): no learning rate is passed, so fine_tune's default is used and
# model_details["lr"] has no effect — confirm whether that is intended.
learner.fine_tune(
    freeze_epochs=50,
    epochs=200,
    cbs=cbs,
)

In [None]:
# Reload the checkpoint that SaveModelCallback wrote under `model_name`.
learner.load(model_name)

In [None]:
# Predict one validation image and show three panels side by side: channel 2 of
# p[0], p[1] thresholded at 0, and p[1] as-is.
# NOTE(review): p[0] is taken to be the input and p[1] the prediction — confirm
# against what Learner.predict returns when with_input=True.
img_numb = 1
p = learner.predict(validation_paths[img_numb], with_input=True)

fig, axes = plt.subplots(1, 3, figsize=(15, 7))  # 1 row, 3 columns
prediction = p[1].numpy()[0]
panels = (p[0].numpy()[2], prediction > 0, prediction)
for ax, panel in zip(axes, panels):
    ax.imshow(panel)
    ax.axis("off")
plt.tight_layout()

In [None]:
# Pickle the learner's model object to "<model_name>.pkl". The context manager
# closes (and therefore flushes) the file even if pickling raises; the bare
# open() used before left the handle to be closed by garbage collection.
with open(f"{model_name}.pkl", "wb") as model_file:
    pickle.dump(learner.model, model_file)

In [None]:
# Show the file name the pickled model was written to.
f"{model_name}.pkl"