In [2]:
import os
import random
from typing import List, Optional, Union

import rasterio
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch.nn.functional as F
import torch
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, random_split
from segmentation_models_pytorch import utils
import segmentation_models_pytorch as smp
import cv2
from tqdm import tqdm
from rasterio.windows import Window

  from .autonotebook import tqdm as notebook_tqdm


# Визуализация

In [3]:
# Reproducibility: seed every random source the notebook touches (Python,
# NumPy, PyTorch CPU and all CUDA devices) from one value.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Ask cuDNN for deterministic kernels and disable its autotuner, which may
# pick different algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [4]:
# --- Classes and mask colours ---------------------------------------------
CLASSES = [
    "background",
    "water",
]
# Colour triplets, presumably one per entry of CLASSES. Confirm the channel
# order before using with cv2, which works in BGR.
PALLETE = [
    [0, 0, 0],
    [0, 0, 255],
]

# --- Model ----------------------------------------------------------------
ENCODER = 'resnet18'
ENCODER_WEIGHTS = 'imagenet'
ACTIVATION = 'sigmoid'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Optimisation ---------------------------------------------------------
EPOCHS = 5
BATCH_SIZE = 32
INIT_LR = 5e-4
# The LR is divided by LR_DECREASE_COEF once every LR_DECREASE_STEP epochs.
# NOTE: with EPOCHS = 5 < LR_DECREASE_STEP = 15 the decay never triggers.
LR_DECREASE_STEP = 15
LR_DECREASE_COEF = 2

# --- Tiling / input -------------------------------------------------------
SIZE = 256      # tile edge length in pixels
CHANELS = 10    # number of input bands (the models below use in_channels=10)
OVERLAP = 0     # pixels shared by neighbouring tiles

loss = utils.losses.DiceLoss()

In [5]:
def get_tiles_with_overlap(image_width: int, image_height: int,
                           tile_size: int, overlap: int) -> List[Window]:
    """Build a row-major grid of tile windows covering a raster.

    Windows start every ``tile_size - overlap`` pixels along both axes. Each
    window is clipped to the raster bounds, so tiles on the right/bottom edge
    may be smaller than ``tile_size``.

    Args:
        image_width: Raster width in pixels.
        image_height: Raster height in pixels.
        tile_size: Edge length of a full tile in pixels.
        overlap: Pixels shared by neighbouring tiles.

    Returns:
        Windows ordered row by row (y outer, x inner).

    Raises:
        ValueError: If ``tile_size`` is not positive or ``overlap`` is not
            smaller than ``tile_size``. A non-positive step would otherwise make
            ``range`` raise an opaque error (step 0) or silently yield no tiles
            (negative step).
    """
    if tile_size <= 0:
        raise ValueError(f"tile_size must be positive, got {tile_size}")
    step_size = tile_size - overlap
    if step_size <= 0:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than tile_size ({tile_size})"
        )

    # The clipping rectangle is the same for every tile, so build it once.
    raster_bounds = Window(0, 0, image_width, image_height)
    return [
        Window(x, y, tile_size, tile_size).intersection(raster_bounds)
        for y in range(0, image_height, step_size)
        for x in range(0, image_width, step_size)
    ]

def save_tile(src_dataset: rasterio.io.DatasetReader, window: Window,
              output_folder: str, tile_index: int,
              image_id: Union[int, str]) -> None:
    """Write one window of an open raster to ``<output_folder>/tile_{image_id}_{tile_index}.tif``.

    The tile reuses the source dataset's profile. The driver is forced to
    GTiff, and the size and affine transform are replaced by the window's, so
    the tile stays positioned in the source's coordinate system.

    Args:
        src_dataset: Open raster to read from.
        window: Pixel window to extract (may be clipped at the raster edge).
        output_folder: Existing directory to write into.
        tile_index: Position of the tile in the grid; used in the file name.
        image_id: Source-image identifier used in the file name. Callers in
            this notebook pass strings such as ``'6_1'``.
    """
    tile_data = src_dataset.read(window=window)
    # Size the output from the array actually read, so the header always
    # matches the data handed to dst.write().
    tile_height, tile_width = tile_data.shape[-2:]

    profile = src_dataset.profile
    profile.update({
        'driver': 'GTiff',
        'height': tile_height,
        'width': tile_width,
        'transform': src_dataset.window_transform(window),
    })

    output_filename = os.path.join(output_folder, f"tile_{image_id}_{tile_index}.tif")
    with rasterio.open(output_filename, 'w', **profile) as dst:
        dst.write(tile_data)
        
def split_image(image_path: str, output_folder: str, mask_path: Optional[str] = None,
                tile_size: int = 512, overlap: int = 20,
                image_id: Union[int, str] = 20) -> None:
    """Cut a raster (and optionally its mask) into tiles written as GeoTIFFs.

    Image tiles go to ``<output_folder>/images``. When ``mask_path`` is given,
    mask tiles go to ``<output_folder>/masks``. Both are named
    ``tile_{image_id}_{index}.tif``, so an image tile and its mask share a
    file name. The tile grid is derived from the image size and reused for
    the mask.

    Args:
        image_path: Raster to split.
        output_folder: Root folder for the ``images`` / ``masks`` subfolders.
        mask_path: Optional mask raster on the same pixel grid as the image.
        tile_size: Tile edge length in pixels.
        overlap: Pixels shared by neighbouring tiles.
        image_id: Identifier embedded in tile names (callers pass strings
            such as ``'6_1'``).

    Raises:
        ValueError: If the mask's width/height differ from the image's. The
            image-derived windows would otherwise cut misaligned mask tiles.
    """
    import contextlib  # local import keeps this cell self-contained

    with contextlib.ExitStack() as stack:
        src_image = stack.enter_context(rasterio.open(image_path))
        src_mask = stack.enter_context(rasterio.open(mask_path)) if mask_path else None

        if src_mask is not None and (
            (src_mask.width, src_mask.height) != (src_image.width, src_image.height)
        ):
            raise ValueError(
                f"Mask {mask_path!r} is {src_mask.width}x{src_mask.height} but image "
                f"{image_path!r} is {src_image.width}x{src_image.height}"
            )

        images_folder = os.path.join(output_folder, 'images')
        os.makedirs(images_folder, exist_ok=True)
        masks_folder = None
        if src_mask is not None:
            masks_folder = os.path.join(output_folder, 'masks')
            os.makedirs(masks_folder, exist_ok=True)

        tiles = get_tiles_with_overlap(src_image.width, src_image.height, tile_size, overlap)

        # Wrap the list itself (not the enumerate object) so tqdm knows the
        # total and can show progress / ETA.
        for idx, window in enumerate(tqdm(tiles)):
            save_tile(src_image, window, images_folder, idx, image_id)
            if src_mask is not None:
                save_tile(src_mask, window, masks_folder, idx, image_id)

In [7]:

def split_N(image_size: int = SIZE,
            overlap: int = OVERLAP,
            data_list: Optional[List[str]] = None,
            root: str = 'train') -> None:
    """Tile every listed image + mask pair into ``<root>_split_<image_size>/``.

    Args:
        image_size: Tile edge length in pixels.
        overlap: Pixels shared by neighbouring tiles; defaults to the OVERLAP
            config value.
        data_list: Ids to tile, each read from ``<root>/images/<id>.tif`` and
            ``<root>/masks/<id>.tif``. ``None`` uses the eight ids listed below.
        root: Folder holding the ``images`` / ``masks`` subfolders.
    """
    if data_list is None:
        data_list = ['1', '2', '4', '5', '6_1', '6_2', '9_1', '9_2']

    output_folder = f'{root}_split_{image_size}/'
    for image_id in data_list:
        split_image(
            image_path=f'{root}/images/{image_id}.tif',
            mask_path=f'{root}/masks/{image_id}.tif',
            output_folder=output_folder,
            tile_size=image_size,
            overlap=overlap,
            image_id=image_id,
        )


split_N()

# Считывание данных

In [8]:
def image_padding(image, target_size=SIZE):
    """Reflect-pad the last two (row, column) axes of ``image`` up to ``target_size``.

    Padding is added on the bottom/right only. Any leading axes (e.g. the band
    axis of a ``(bands, rows, cols)`` raster read) are left untouched, so 2-D,
    3-D and higher-rank arrays are all accepted. Axes already at least
    ``target_size`` long are neither padded nor cropped.

    Args:
        image: Array with at least two dimensions.
        target_size: Minimum length of each of the last two axes.

    Returns:
        The padded array (a new array from ``np.pad``).
    """
    height, width = image.shape[-2:]
    pad_height = max(0, target_size - height)
    pad_width = max(0, target_size - width)
    # No padding on leading axes; bottom/right padding on the spatial ones.
    pad_spec = [(0, 0)] * (image.ndim - 2) + [(0, pad_height), (0, pad_width)]
    return np.pad(image, pad_spec, mode='reflect')



def mask_padding(mask, target_size=SIZE):
    """Reflect-pad a 2-D mask on the bottom/right so each side reaches ``target_size``.

    Sides that are already at least ``target_size`` long are left as they are.
    """
    rows, cols = mask.shape
    extra_rows = target_size - rows if rows < target_size else 0
    extra_cols = target_size - cols if cols < target_size else 0
    return np.pad(mask, [(0, extra_rows), (0, extra_cols)], mode='reflect')

def get_data_list(img_path, delete: Optional[bool] = True,
                  extension: Optional[str] = '.tif'):
    """Return the sorted, de-duplicated file stems of the rasters in ``img_path``.

    Only regular files directly inside ``img_path`` are considered, since the
    dataset loads each id as ``<img_path>/<id>.tif`` and cannot reach files
    in subfolders.

    Args:
        img_path: Directory to list.
        delete: When true, skip hidden entries (names starting with ``'.'``).
            The previous implementation instead always dropped the first sorted
            name. That was presumably meant to discard a hidden file such as
            ``.DS_Store``, whose stem ``''`` sorts first, but it silently
            discarded a real tile when no such file was present.
        extension: Keep only files with this extension (case-insensitive).
            ``None`` keeps every file. Filtering also stops sidecars like
            ``x.tif.aux.xml`` from producing duplicate ids.

    Returns:
        1-D object-dtype numpy array of stems (file names without extension),
        sorted lexicographically. Empty if nothing matches.
    """
    stems = set()
    for entry in os.listdir(img_path):
        if delete and entry.startswith('.'):
            continue
        if not os.path.isfile(os.path.join(img_path, entry)):
            continue
        stem, ext = os.path.splitext(entry)
        if extension is not None and ext.lower() != extension.lower():
            continue
        stems.add(stem)
    return np.array(sorted(stems), dtype=object)

In [9]:
class WaterDataset(Dataset):
    """Segmentation dataset pairing multi-band image tiles with their masks.

    For each name in ``file_names`` the image is read from
    ``<img_path>/<name>.tif`` and the mask from ``<mask_path>/<name>.tif``.
    ``os.path.join`` is used, so the folders work with or without a trailing
    slash.

    Args:
        img_path: Folder with the image tiles.
        mask_path: Folder with the mask tiles (same file names).
        file_names: Tile ids without extension, e.g. from ``get_data_list``.
    """

    def __init__(self, img_path, mask_path, file_names):
        self.img_path = img_path
        self.mask_path = mask_path
        self.file_names = file_names

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` tensors for the ``idx``-th tile.

        image: float32 tensor of all bands, reflect-padded by ``image_padding``.
        mask: int64 tensor built from the mask's first band, reflect-padded by
            ``mask_padding``, with a leading axis of length 1 added.
        """
        name = self.file_names[idx]

        with rasterio.open(os.path.join(self.img_path, name + '.tif')) as src:
            image = src.read()
        # One cast is enough: pad, then convert to float32 for the network.
        image = torch.from_numpy(image_padding(image).astype(np.float32))

        with rasterio.open(os.path.join(self.mask_path, name + '.tif')) as src:
            mask = src.read(1)
        # Cast via uint8 (values outside 0..255 would wrap), then to int64.
        mask = torch.from_numpy(mask_padding(mask).astype(np.uint8)).long().unsqueeze(0)

        return image, mask

In [10]:
# Index every tile written by split_N and pair each image tile with its mask.
tiles_dir = f'train_split_{SIZE}'
data_list = get_data_list(tiles_dir + '/images/')
ds = WaterDataset(
    file_names=data_list,
    img_path=tiles_dir + '/images/',
    mask_path=tiles_dir + '/masks/',
)
dl = DataLoader(ds)  # library defaults: batch_size=1, no shuffling


## Обучение

In [11]:
# Train each candidate architecture on one shared train/val split and keep, per
# model, the checkpoint with the best validation IoU (pickled model + TorchScript).
models_to_test = ['Linknet', 'FPN', 'UnetPlusPlus', 'DeepLabV3']

# torch.save / torch.jit.save fail when the target directory does not exist.
os.makedirs('models', exist_ok=True)

# Learning curves of every model; loss_logs / metric_logs hold only the current one.
history = {}

for model_name in models_to_test:

    data_list = get_data_list(f'train_split_{SIZE}/images/')
    ds = WaterDataset(
        img_path=f'train_split_{SIZE}/images/',
        mask_path=f'train_split_{SIZE}/masks/',
        file_names=data_list,
    )
    val_size = int(len(ds) * 0.2)
    train_size = len(ds) - val_size

    # Split with a freshly seeded private generator so every model gets the SAME
    # train/val partition. The global torch RNG is advanced by each model's
    # weight init, loader shuffling and the random trace input, so splitting
    # with it validated each architecture on different tiles.
    split_generator = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset = random_split(
        ds, [train_size, val_size], generator=split_generator
    )
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    model = getattr(smp, model_name)(
        encoder_name=ENCODER,
        encoder_weights=ENCODER_WEIGHTS,
        classes=1,
        activation=ACTIVATION,
        in_channels=CHANELS,
    )

    metrics = [
        utils.metrics.Fscore(),
        utils.metrics.IoU(),
    ]

    loss = utils.losses.DiceLoss()
    optimizer = torch.optim.Adam([
        dict(params=model.parameters(), lr=INIT_LR),
    ])

    train_epoch = utils.train.TrainEpoch(
        model,
        loss=loss,
        metrics=metrics,
        optimizer=optimizer,
        device=DEVICE,
        verbose=True,
    )

    valid_epoch = utils.train.ValidEpoch(
        model,
        loss=loss,
        metrics=metrics,
        device=DEVICE,
        verbose=True,
    )
    max_score = 0

    loss_logs = {"train": [], "val": []}
    metric_logs = {"train": [], "val": []}
    history[model_name] = {"loss": loss_logs, "iou": metric_logs}

    print(f"Обучение {model_name}")
    for i in range(0, EPOCHS):

        print('\nEpoch: {}'.format(i))
        # Read epoch logs by key; positional unpacking of .values() silently
        # mislabels numbers if the metric list or its order changes.
        train_logs = train_epoch.run(train_loader)
        train_loss = train_logs[loss.__name__]
        train_metric = train_logs['fscore']
        train_metric_IOU = train_logs['iou_score']

        loss_logs["train"].append(train_loss)
        metric_logs["train"].append(train_metric_IOU)

        valid_logs = valid_epoch.run(val_loader)
        val_loss = valid_logs[loss.__name__]
        val_metric = valid_logs['fscore']
        val_metric_IOU = valid_logs['iou_score']

        loss_logs["val"].append(val_loss)
        metric_logs["val"].append(val_metric_IOU)

        if max_score < val_metric_IOU:
            max_score = val_metric_IOU
            torch.save(model, f'models/{model_name}.pth')

            # Also export a TorchScript version traced on a random batch.
            trace_image = torch.randn(BATCH_SIZE, CHANELS, SIZE, SIZE)
            traced_model = torch.jit.trace(model, trace_image.to(DEVICE))

            torch.jit.save(traced_model, f'models/{model_name}.pt')
            print('Model saved!')

        print("LR:", optimizer.param_groups[0]['lr'])
        if i > 0 and i % LR_DECREASE_STEP == 0:
            print('Decrease decoder learning rate')
            optimizer.param_groups[0]['lr'] /= LR_DECREASE_COEF

Обучение Linknet

Epoch: 0
train: 100%|██████████| 107/107 [03:05<00:00,  1.73s/it, dice_loss - 0.7765, fscore - 0.53, iou_score - 0.4211]  
valid: 100%|██████████| 27/27 [00:45<00:00,  1.68s/it, dice_loss - 0.6497, fscore - 0.7684, iou_score - 0.6483]


  if h % output_stride != 0 or w % output_stride != 0:


Model saved!
LR: 0.0005

Epoch: 1
train: 100%|██████████| 107/107 [02:14<00:00,  1.26s/it, dice_loss - 0.5345, fscore - 0.7457, iou_score - 0.6408]
valid: 100%|██████████| 27/27 [00:32<00:00,  1.20s/it, dice_loss - 0.4258, fscore - 0.8257, iou_score - 0.7185]
Model saved!
LR: 0.0005

Epoch: 2
train: 100%|██████████| 107/107 [02:14<00:00,  1.26s/it, dice_loss - 0.3528, fscore - 0.786, iou_score - 0.6864] 
valid: 100%|██████████| 27/27 [00:32<00:00,  1.22s/it, dice_loss - 0.29, fscore - 0.8203, iou_score - 0.7115]  
LR: 0.0005

Epoch: 3
train: 100%|██████████| 107/107 [02:11<00:00,  1.23s/it, dice_loss - 0.2822, fscore - 0.7853, iou_score - 0.6823]
valid: 100%|██████████| 27/27 [00:32<00:00,  1.21s/it, dice_loss - 0.2338, fscore - 0.8239, iou_score - 0.7229]
Model saved!
LR: 0.0005

Epoch: 4
train: 100%|██████████| 107/107 [02:12<00:00,  1.24s/it, dice_loss - 0.2158, fscore - 0.8273, iou_score - 0.7251]
valid: 100%|██████████| 27/27 [00:32<00:00,  1.19s/it, dice_loss - 0.195, fscore - 0.