In [1]:
import numpy as np
import cv2
import os
import random
import glob
from tqdm import tqdm 
import natsort
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

import matplotlib.pylab as plt
import warnings
warnings.filterwarnings("ignore")

In [3]:
import rioxarray as xr
import rasterio
from rasterio.plot import show
from early_stopping import EarlyStopping
from utils import DiceLoss

ModuleNotFoundError: No module named 'utils'

In [5]:
import torch 
from torchvision import transforms

from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

import segmentation_models_pytorch as smp

In [6]:
CFG = {
    'RESIZE_WIDTH':32,
    'RESIZE_HEIGHT':32,
    'EPOCHS':50,
    'LEARNING_RATE':3e-4,
    'BATCH_SIZE':8,
    'SEED':41,
}

In [7]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(CFG['SEED']) # Seed 고정

In [8]:
# device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device = torch.device('cpu')

In [9]:
img_dir = './train/s2_image/'
mask_dir = './train/mask/'
img_pattern = '*.tif'
mask_pattern = '*.tif'

# Glob and sort image and mask paths
train_img_paths = natsort.natsorted(glob.glob(img_dir + img_pattern))
train_mask_paths = natsort.natsorted(glob.glob(mask_dir + mask_pattern))

# Create a DataFrame
total_df = pd.DataFrame({'img_path': train_img_paths, 'mask_path': train_mask_paths})

In [10]:
total_df = total_df.iloc[:1000]

In [11]:
X_train, X_valid= train_test_split(total_df, test_size=0.2, random_state=CFG['SEED'])
print(X_train.shape, X_valid.shape)

(800, 2) (200, 2)


In [12]:
X_train = X_train.reset_index(drop=True)
X_valid = X_valid.reset_index(drop=True)

In [13]:
class CustomDataset(Dataset):
    def __init__(self, df, train_mode=True, transforms=None):
        self.image = df['img_path']
        self.mask = df['mask_path']

        self.train_mode = train_mode
        self.transforms = transforms
        self.img_list = self.feature_func(self.image)
        self.mask_list = self.mask_func(self.mask)
   
    def feature_func(self, df_path):
        img_list=[]
        for img_dir in tqdm(df_path):
            with rasterio.open(img_dir) as image:
                image_array = image.read() #(12,24,33)
                image_array = np.transpose(image_array, (1,2,0)) #(32,32,12)
                image_array = cv2.resize(image_array, (CFG['RESIZE_WIDTH'], CFG['RESIZE_HEIGHT']))   #(12,32,32)
                image_array = np.transpose(image_array, (2,0,1)) #(24,33,12)
                tr_img = torch.FloatTensor(image_array)
                img_list.append(tr_img)
        return np.array(img_list)
    
    def mask_func(self, df_path):
        mask_list=[]
        for img_dir in tqdm(df_path):
            with rasterio.open(img_dir) as image:
                image_array = image.read() #(1,24,24)
                image_array = np.transpose(image_array, (1,2,0)) #(32,32,12)
                image_array = cv2.resize(image_array, (CFG['RESIZE_WIDTH'], CFG['RESIZE_HEIGHT']))   #(12,32,32)
                image_array = np.reshape(image_array, (1, image_array.shape[0],image_array.shape[1])) #(1,32,32)
                mask = torch.FloatTensor(image_array)

                mask_list.append(mask)
        return np.array(mask_list)

    def __getitem__(self, index):
        X_feature = self.img_list[index]

        if self.transforms is not None:
            X_feature = self.transforms(X_feature)
        
        if self.train_mode:
            label = self.mask_list[index]
            return X_feature, label
        else:
            return X_feature

    def __len__(self):
        return len(self.image)

In [14]:
train_dataset = CustomDataset(X_train, train_mode=True)
train_loader = DataLoader(train_dataset, batch_size = CFG['BATCH_SIZE'], shuffle=True)

vali_dataset = CustomDataset(X_valid, train_mode=True)
vali_loader = DataLoader(vali_dataset, batch_size = CFG['BATCH_SIZE'], shuffle=False)

100%|██████████| 800/800 [00:05<00:00, 151.66it/s]
100%|██████████| 800/800 [00:05<00:00, 158.09it/s]
100%|██████████| 200/200 [00:01<00:00, 150.13it/s]
100%|██████████| 200/200 [00:01<00:00, 165.09it/s]


In [15]:
# next(iter(vali_loader))[0].shape

In [16]:
vali_dataset[0][0]

tensor([[[ 655.5820,  657.3223,  658.5923,  ...,  779.1923,  785.4283,
           793.2841],
         [ 655.6202,  657.4747,  658.8280,  ...,  779.2706,  785.5397,
           793.3956],
         [ 656.3389,  657.8773,  659.0000,  ...,  779.4662,  785.6343,
           793.4771],
         ...,
         [ 367.5000,  367.5000,  367.5000,  ...,  682.0000,  677.9036,
           672.2901],
         [ 367.5000,  367.5000,  367.5000,  ...,  682.0000,  677.7332,
           671.8860],
         [ 367.5000,  367.5000,  367.5000,  ...,  682.0000,  677.4995,
           671.3322]],

        [[1102.5396, 1045.2802,  959.6609,  ..., 1194.6150, 1116.6873,
          1021.5367],
         [1071.8906, 1036.2933,  966.7753,  ..., 1067.0142, 1024.7092,
           965.8724],
         [1033.1443, 1027.3279,  977.2184,  ...,  908.0965,  898.8860,
           875.2604],
         ...,
         [ 215.0388,  237.1836,  268.0695,  ...,  936.6506,  749.6335,
           555.6659],
         [ 217.4013,  237.2947,  266.933

In [17]:
vali_dataset[0][0].shape, vali_dataset[0][1].shape

(torch.Size([12, 32, 32]), torch.Size([1, 32, 32]))

In [18]:
seg_model = smp.FPN(
    encoder_name="resnet18",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    # use `imagenet` pre-trained weights for encoder initialization
    # encoder_weights="imagenet",
    # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    in_channels=12,
    # model output channels (number of classes in your dataset)
    classes=1,
)

In [19]:
device

device(type='cpu')

In [20]:
# criterion = nn.BCELoss().to(device)
criterion = DiceLoss().to(device)
optimizer = torch.optim.SGD(params=seg_model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3,threshold_mode='abs',min_lr=1e-8, verbose=True)

In [79]:
def validation(model, criterion, val_loader, device):
    model.eval()
    val_loss = []
    val_f1_list = []

    with torch.no_grad():
        for img, label in tqdm(iter(val_loader)):
            img, label = img.float().to(device), label.to(device)

            logit = model(img)
            logit = logit/255
            loss = criterion(logit, label)
            
            logit = logit.detach().cpu().numpy()
            label = label.detach().cpu().numpy()
            seg_target = label[0,:,:].flatten()
            threshold = 0.5
            binary_prediction = (logit[0,0,:,:] > threshold).astype(np.uint8)
            binary_prediction = binary_prediction.flatten()
            val_f1 = f1_score(seg_target, binary_prediction)

            val_loss.append(loss.item())
            val_f1_list.append(val_f1)

        _val_loss = np.mean(val_loss)
        _val_f1 = np.mean(val_f1_list)
        
    return _val_loss, _val_f1

In [80]:
from tqdm.auto import tqdm

def train(model, optimizer, train_loader, scheduler, device): 
    model.to(device)
    total_train_loss, total_valid_loss = [],[]
    total_train_f1, total_valid_f1 = [],[]

    patience = 5
    early_stopping = EarlyStopping(patience = patience, path = f'./weights/best.pt', verbose = True)

    for epoch in range(1, CFG['EPOCHS']):  # Adjusted to include the last epoch
        model.train()
        train_loss = []
        train_f1_list = []
        for img, label in tqdm(iter(train_loader)):
            img, label = img.float().to(device), label.to(device)
            optimizer.zero_grad()
            logit = model(img)
            logit = logit/255
            loss = criterion(logit, label)  
            loss.backward()
            optimizer.step()
            
            logit = logit.detach().cpu().numpy()
            label = label.detach().cpu().numpy()
            seg_target = label[0,:,:].flatten()
            threshold = 0.5
            binary_prediction = (logit[0,0,:,:] > threshold).astype(np.uint8)
            binary_prediction = binary_prediction.flatten()
            train_f1 = f1_score(seg_target, binary_prediction)

            train_loss.append(loss.item())
            train_f1_list.append(train_f1)

        _train_loss = np.mean(train_loss)
        _train_f1 = np.mean(train_f1_list)

        _val_loss, _val_f1 = validation(model, criterion, vali_loader, device)
        total_train_loss.append(_train_loss)
        total_valid_loss.append(_val_loss)
        total_train_f1.append(_train_f1)

        print(f'Epoch [{epoch}], Train Loss : [{_train_loss:.5f}] Train F1 : [{_train_f1:.5f}] Val Loss : [{_val_loss:.5f}] Val F1 : [{_val_f1:.5f}]]')
        if scheduler is not None:
            scheduler.step(_val_loss)

        early_stopping(_val_loss, model)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    return total_train_loss, total_valid_loss

In [81]:
train(seg_model, optimizer, train_loader, scheduler, device)

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

Epoch [1], Train Loss : [0.94179] Train F1 : [0.00000] Val Loss : [0.94573] Val F1 : [0.00000]]
Validation loss decreased (inf --> 0.945726).  Saving model ...


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

Epoch [2], Train Loss : [0.94230] Train F1 : [0.00000] Val Loss : [0.94497] Val F1 : [0.00000]]
Validation loss decreased (0.945726 --> 0.944967).  Saving model ...


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

Epoch [3], Train Loss : [0.94070] Train F1 : [0.00000] Val Loss : [0.94314] Val F1 : [0.00000]]
Epoch 00013: reducing learning rate of group 0 to 1.2500e-04.
Validation loss decreased (0.944967 --> 0.943135).  Saving model ...


  0%|          | 0/100 [00:00<?, ?it/s]

KeyboardInterrupt: 