In [1]:
import os
import random 
import time 
import shutil
from tqdm.notebook import tqdm

import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.io import read_image, ImageReadMode, write_png
import torchvision.transforms.v2 as v2

import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
import cv2
import py_sod_metrics

from model import *
from util import *

In [2]:
# ---- seeding ----
# A fixed seed is not used yet: a time-based seed is drawn and PRINTED, so a run
# (including the train/val split below) can be repeated by pasting the printed value
# into SEED — assuming `seedeverything` (from util) seeds every RNG used later.
SEED = int(time.time())
print(f'seed: {SEED}')
seedeverything(SEED)

# ---- settings ----
device = torch.device("cuda:4")
VAL_FRAC = 0.15                 # fraction of the training images held out for validation
IMG_SIZE = (240, 240)           # use (240, 428) for best performance unless you are debugging
TRAINING_PATH = './Training_dataset/'
TESTING_PATH = './Testing_dataset/'
PRIVATE_PATH = './Private_dataset/'

In [3]:
# ---- model, optimizer and learning-rate schedule ----
lr = 1e-3                       # initial Adam learning rate
max_epoch = 35                  # number of passes over the training set
model = UNet(IMG_SIZE)
optim = Adam(
    model.parameters(),
    lr=lr,
    weight_decay=0.0,           # 1e-5 was also tried
)
# multiply the learning rate by 0.1 every 15 epochs
scheduler = StepLR(optim, gamma=0.1, step_size=15)

In [4]:
## split training and validation data
# Image stems (file name without '.jpg'). Sorting first makes the shuffle depend only on
# the RNG seed, not on os.listdir order. Match the full '.jpg' suffix so the stem
# slice below is always exactly the extension.
names = sorted(f[:-len('.jpg')] for f in os.listdir(TRAINING_PATH + 'img') if f.endswith('.jpg'))
divider = int(len(names) * VAL_FRAC)    # number of validation images
random.shuffle(names)
val_names, training_names = names[:divider], names[divider:]
assert len(val_names) + len(training_names) == len(names)

## **define dataset**

In [5]:
class ImageDataset(Dataset):
    '''
    Dataset of (image, label) pairs for binary segmentation, read from TRAINING_PATH.

    Each item is a float32 RGB image tensor (resized to IMG_SIZE, clipped, normalized) and a
    float32 grayscale label tensor (scaled by ToDtype, resized to IMG_SIZE, clipped). When
    `augmentation` is True, random transforms are applied on every access.
    '''
    def __init__(self, names,augmentation = False, preload = True, to_gpu = False):
        '''
        args:
            names: list of image stems. The image is read from
                TRAINING_PATH + 'img/<name>.jpg' (RGB) and the label from
                TRAINING_PATH + 'label_img/<name>.png' (grayscale).
            augmentation: if False, treat this as a validation dataset (no random
                transforms). If True, treat it as a training dataset (image-only noise
                plus a random horizontal flip shared by image and label).
            preload: read and preprocess every image/label once, here, and keep them in
                stacked tensors. This reduces data-loading time by about 20%. torch.stack
                requires all images (and all labels) to share one size.
            to_gpu: only used when preload is True. Moves the preloaded tensors to the
                global `device`. This significantly reduces loading and transfer time,
                but needs enough GPU memory or an out-of-memory error may occur.
        '''
        self.names = names
        self.aug = augmentation
        self.preload = preload
        self.to_gpu = to_gpu
        # Image preprocessing: uint8 -> float32 scaled to [0, 1], resize to IMG_SIZE,
        # Clip() (from util; presumably clamps antialiasing overshoot back into [0, 1] — TODO
        # confirm), then per-channel normalization.
        self.img_transform = v2.Compose([
            v2.ToDtype(torch.float32, scale=True),
            v2.Resize(IMG_SIZE, antialias=True),
            Clip(),
            v2.Normalize((0.519,0.535,0.442),(0.196, 0.175, 0.207)),  # per-channel (mean), (std); ImageNet alternative: (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
        ])
        
        # Label preprocessing: same scaling / resize / clip as the image, but no normalization.
        self.label_transform = v2.Compose([
            v2.ToDtype(torch.float32, scale=True),
            v2.Resize(IMG_SIZE, antialias=True),
            Clip(),
            #Binarize(0.5), # use this if using Dice_Loss
        ])
        
        if self.preload:
            # Read every file once and apply the deterministic transforms here, so that
            # __getitem__ only has to index (and, for training, augment).
            self.img_buffer = [read_image(TRAINING_PATH + 'img/' + name + '.jpg', mode = ImageReadMode.RGB) for name in names]
            self.label_buffer = [read_image(TRAINING_PATH + 'label_img/' + name+'.png', mode = ImageReadMode.GRAY) for name in names]
            self.img_buffer = torch.stack(self.img_buffer, dim=0)
            self.label_buffer = torch.stack(self.label_buffer, dim=0)
            if self.to_gpu:
                # Moved before transforming, so the resize / normalize below run on `device`.
                self.img_buffer, self.label_buffer = self.img_buffer.to(device), self.label_buffer.to(device)
            self.img_buffer = self.img_transform(self.img_buffer)
            self.label_buffer = self.label_transform(self.label_buffer)

        
        # apply flip, rotate...here
        # Geometric augmentation: applied to both image and label with the same random draw
        # (see __getitem__).
        self.aug_transform = v2.Compose([
            v2.RandomHorizontalFlip(),
        ])
        # Photometric augmentation: applied to the image only. add_gaussian_noise comes from
        # util (presumably additive noise with std 0.05 — TODO confirm).
        self.img_transform_1 = v2.Compose([
            #v2.RandomApply(torch.nn.ModuleList([]), p=0.2),
            #RandomChannelSwap(0.1),
            #v2.RandomGrayscale(0.1),
            add_gaussian_noise(0.05),
            #v2.GaussianBlur(5),
        ])
        
    def __len__(self):
        return len(self.names)
    @torch.no_grad()
    def __getitem__(self, index):
        '''Return (img, label) for sample `index`, randomly augmented when self.aug is True.'''
        name = self.names[index]
        #preprocess pipeline
        #img:img_transform -> img_transform_1 -> (aug_transform)
        #label:label_transform -> aug_transform
        if self.preload:
            # Clone, so that any in-place operation in the augmentations cannot modify the
            # preloaded buffer.
            img = self.img_buffer[index].clone()
            label = self.label_buffer[index].clone()
        else:
            img = read_image(TRAINING_PATH + 'img/' + name + '.jpg', mode = ImageReadMode.RGB)
            label = read_image(TRAINING_PATH + 'label_img/' + name+'.png', mode = ImageReadMode.GRAY)
            img = self.img_transform(img)
            label = self.label_transform(label)
        if self.aug:
            img = self.img_transform_1(img)
            # Snapshot torch's default (CPU) RNG state, flip the image, then restore the state
            # before flipping the label, so both receive the same random decision. This relies
            # on aug_transform drawing its randomness from that generator.
            rng = torch.random.get_rng_state()
            img = self.aug_transform(img)
            torch.random.set_rng_state(rng)
            label = self.aug_transform(label)
        return img, label
    # Disabled debugging helper, kept as a string literal: checks that restoring the RNG
    # state makes aug_transform reproduce the same output.
    '''
    def _try_rng(self): # try if augmentation is ok
        for i in range(10):
            img = read_image(TRAINING_PATH + 'img/' + random.choice(self.names) + '.jpg', mode = ImageReadMode.RGB)
            rng = torch.random.get_rng_state()
            tran1 = self.aug_transform(img)
            torch.random.set_rng_state(rng)
            tran2 = self.aug_transform(img)
            if not (tran1==tran2).all().item():
                print('something went wrong!!')
                return
        print('augmentation checked.')
        return
    '''

In [6]:
# Prepare datasets and loaders. Both datasets are preloaded onto `device`, so the loaders
# keep the default num_workers=0 (CUDA tensors generally cannot be used in forked workers).
BATCH_SIZE = 8

start_time = time.time()
train_dataset = ImageDataset(training_names, augmentation = True, to_gpu = True)
print(f'training set prepared in {time.time() - start_time:.1f}s')
start_time = time.time()
val_dataset = ImageDataset(val_names, augmentation = False, to_gpu = True)
print(f'validation set prepared in {time.time() - start_time:.1f}s')

train_loader = DataLoader(dataset = train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# The validation loss is averaged over the whole set, so shuffling gains nothing; a fixed
# order makes evaluation reproducible.
val_loader = DataLoader(dataset = val_dataset, batch_size=BATCH_SIZE, shuffle=False)

22.420174598693848
3.5902302265167236


# **training**

In [7]:
# Training: one pass over train_loader, then one over val_loader, per epoch.
model.to(device)
for epoch in range(max_epoch):
    ## training
    model.train()
    total_loss = 0.0
    loss_cnt = 0
    # Wall-clock split of the epoch: data loading / host->device transfer / forward+backward.
    load_time, trans_time, comp_time = 0.0, 0.0, 0.0
    start_time = time.time()
    for x, y in tqdm(train_loader, leave = False, desc = 'training'):
        mid_time = time.time()
        x, y = x.to(device), y.to(device)
        end_time = time.time()

        y_hat = model(x)
        # binary_cross_entropy expects probabilities, so this assumes the model ends in a
        # sigmoid (TODO confirm in model.UNet). Alternative tried:
        # (F.binary_cross_entropy(y_hat, y) + dl_criterion(y_hat, y)) / 2
        loss = F.binary_cross_entropy(y_hat, y)
        optim.zero_grad()
        loss.backward()
        optim.step()

        # .item() turns the loss into a Python float. Accumulating the tensor itself would
        # chain every step's autograd graph onto total_loss for the whole epoch. It also
        # synchronizes with the GPU, so comp_time measures real compute, not kernel launches.
        total_loss += loss.item() * len(x)
        loss_cnt += len(x)

        comp_time += time.time() - end_time
        trans_time += end_time - mid_time
        load_time += mid_time - start_time
        start_time = time.time()

    train_loss = total_loss / loss_cnt

    ## validation
    model.eval()
    total_loss = 0.0
    loss_cnt = 0
    with torch.no_grad():
        for x, y in tqdm(val_loader, leave = False, desc = 'validating'):
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            total_loss += F.binary_cross_entropy(y_hat, y).item() * len(x)
            loss_cnt += len(x)
    val_loss = total_loss / loss_cnt
    print(f'epoch{epoch+1:3d}: train loss {train_loss:.4f}\tval loss {val_loss:.4f}','\t|\t'
          f'cost: ld {load_time:.1f}s, trans {trans_time:.1f}s, comp {comp_time:.1f}s')

    scheduler.step()

training:   0%|          | 0/459 [00:00<?, ?it/s]

  return F.conv2d(input, weight, bias, self.stride,


validating:   0%|          | 0/81 [00:00<?, ?it/s]

epoch  1: train loss 0.1025	val loss 0.0480 	|	cost: ld 1.1s, trans 0.0s, comp 39.1s


training:   0%|          | 0/459 [00:00<?, ?it/s]

validating:   0%|          | 0/81 [00:00<?, ?it/s]

epoch  2: train loss 0.0477	val loss 0.0463 	|	cost: ld 1.1s, trans 0.0s, comp 38.7s


training:   0%|          | 0/459 [00:00<?, ?it/s]

validating:   0%|          | 0/81 [00:00<?, ?it/s]

epoch  3: train loss 0.0454	val loss 0.0437 	|	cost: ld 1.0s, trans 0.0s, comp 38.7s


training:   0%|          | 0/459 [00:00<?, ?it/s]

validating:   0%|          | 0/81 [00:00<?, ?it/s]

epoch  4: train loss 0.0440	val loss 0.0429 	|	cost: ld 1.0s, trans 0.0s, comp 38.7s


training:   0%|          | 0/459 [00:00<?, ?it/s]

validating:   0%|          | 0/81 [00:00<?, ?it/s]

epoch  5: train loss 0.0432	val loss 0.0418 	|	cost: ld 1.0s, trans 0.0s, comp 38.7s


training:   0%|          | 0/459 [00:00<?, ?it/s]

validating:   0%|          | 0/81 [00:00<?, ?it/s]

epoch  6: train loss 0.0424	val loss 0.0420 	|	cost: ld 1.1s, trans 0.0s, comp 38.7s


training:   0%|          | 0/459 [00:00<?, ?it/s]

validating:   0%|          | 0/81 [00:00<?, ?it/s]

epoch  7: train loss 0.0420	val loss 0.0410 	|	cost: ld 1.1s, trans 0.0s, comp 38.7s


training:   0%|          | 0/459 [00:00<?, ?it/s]

validating:   0%|          | 0/81 [00:00<?, ?it/s]

epoch  8: train loss 0.0411	val loss 0.0402 	|	cost: ld 1.0s, trans 0.0s, comp 38.7s


training:   0%|          | 0/459 [00:00<?, ?it/s]

validating:   0%|          | 0/81 [00:00<?, ?it/s]

epoch  9: train loss 0.0408	val loss 0.0400 	|	cost: ld 1.1s, trans 0.0s, comp 38.7s


training:   0%|          | 0/459 [00:00<?, ?it/s]

validating:   0%|          | 0/81 [00:00<?, ?it/s]

epoch 10: train loss 0.0404	val loss 0.0387 	|	cost: ld 1.1s, trans 0.0s, comp 38.7s


training:   0%|          | 0/459 [00:00<?, ?it/s]

validating:   0%|          | 0/81 [00:00<?, ?it/s]

epoch 11: train loss 0.0400	val loss 0.0393 	|	cost: ld 1.0s, trans 0.0s, comp 38.7s


training:   0%|          | 0/459 [00:00<?, ?it/s]

validating:   0%|          | 0/81 [00:00<?, ?it/s]

epoch 12: train loss 0.0397	val loss 0.0378 	|	cost: ld 1.1s, trans 0.0s, comp 38.7s


training:   0%|          | 0/459 [00:00<?, ?it/s]

validating:   0%|          | 0/81 [00:00<?, ?it/s]

epoch 13: train loss 0.0391	val loss 0.0376 	|	cost: ld 1.1s, trans 0.0s, comp 38.7s


training:   0%|          | 0/459 [00:00<?, ?it/s]

validating:   0%|          | 0/81 [00:00<?, ?it/s]

epoch 14: train loss 0.0389	val loss 0.0371 	|	cost: ld 1.0s, trans 0.0s, comp 38.7s


training:   0%|          | 0/459 [00:00<?, ?it/s]

validating:   0%|          | 0/81 [00:00<?, ?it/s]

epoch 15: train loss 0.0386	val loss 0.0370 	|	cost: ld 1.0s, trans 0.0s, comp 38.7s


training:   0%|          | 0/459 [00:00<?, ?it/s]

validating:   0%|          | 0/81 [00:00<?, ?it/s]

epoch 16: train loss 0.0367	val loss 0.0358 	|	cost: ld 1.0s, trans 0.0s, comp 38.7s


training:   0%|          | 0/459 [00:00<?, ?it/s]

validating:   0%|          | 0/81 [00:00<?, ?it/s]

epoch 17: train loss 0.0362	val loss 0.0354 	|	cost: ld 1.1s, trans 0.0s, comp 38.7s


training:   0%|          | 0/459 [00:00<?, ?it/s]

validating:   0%|          | 0/81 [00:00<?, ?it/s]

epoch 18: train loss 0.0361	val loss 0.0354 	|	cost: ld 1.1s, trans 0.0s, comp 38.6s


training:   0%|          | 0/459 [00:00<?, ?it/s]

validating:   0%|          | 0/81 [00:00<?, ?it/s]

epoch 19: train loss 0.0359	val loss 0.0353 	|	cost: ld 1.1s, trans 0.0s, comp 38.7s


training:   0%|          | 0/459 [00:00<?, ?it/s]

validating:   0%|          | 0/81 [00:00<?, ?it/s]

epoch 20: train loss 0.0358	val loss 0.0351 	|	cost: ld 1.1s, trans 0.0s, comp 38.7s


training:   0%|          | 0/459 [00:00<?, ?it/s]

validating:   0%|          | 0/81 [00:00<?, ?it/s]

epoch 21: train loss 0.0356	val loss 0.0351 	|	cost: ld 1.1s, trans 0.0s, comp 38.7s


training:   0%|          | 0/459 [00:00<?, ?it/s]

validating:   0%|          | 0/81 [00:00<?, ?it/s]

epoch 22: train loss 0.0355	val loss 0.0351 	|	cost: ld 1.1s, trans 0.0s, comp 38.7s


training:   0%|          | 0/459 [00:00<?, ?it/s]

validating:   0%|          | 0/81 [00:00<?, ?it/s]

epoch 23: train loss 0.0354	val loss 0.0351 	|	cost: ld 1.1s, trans 0.0s, comp 38.6s


training:   0%|          | 0/459 [00:00<?, ?it/s]

validating:   0%|          | 0/81 [00:00<?, ?it/s]

epoch 24: train loss 0.0353	val loss 0.0350 	|	cost: ld 1.1s, trans 0.0s, comp 38.7s


training:   0%|          | 0/459 [00:00<?, ?it/s]

validating:   0%|          | 0/81 [00:00<?, ?it/s]

epoch 25: train loss 0.0352	val loss 0.0348 	|	cost: ld 1.1s, trans 0.0s, comp 38.7s


training:   0%|          | 0/459 [00:00<?, ?it/s]

validating:   0%|          | 0/81 [00:00<?, ?it/s]

epoch 26: train loss 0.0352	val loss 0.0351 	|	cost: ld 1.1s, trans 0.0s, comp 38.7s


training:   0%|          | 0/459 [00:00<?, ?it/s]

validating:   0%|          | 0/81 [00:00<?, ?it/s]

epoch 27: train loss 0.0351	val loss 0.0347 	|	cost: ld 1.1s, trans 0.0s, comp 38.7s


training:   0%|          | 0/459 [00:00<?, ?it/s]

validating:   0%|          | 0/81 [00:00<?, ?it/s]

epoch 28: train loss 0.0350	val loss 0.0345 	|	cost: ld 1.1s, trans 0.0s, comp 38.7s


training:   0%|          | 0/459 [00:00<?, ?it/s]

validating:   0%|          | 0/81 [00:00<?, ?it/s]

epoch 29: train loss 0.0350	val loss 0.0345 	|	cost: ld 1.0s, trans 0.0s, comp 38.7s


training:   0%|          | 0/459 [00:00<?, ?it/s]

validating:   0%|          | 0/81 [00:00<?, ?it/s]

epoch 30: train loss 0.0349	val loss 0.0345 	|	cost: ld 1.1s, trans 0.0s, comp 38.7s


training:   0%|          | 0/459 [00:00<?, ?it/s]

validating:   0%|          | 0/81 [00:00<?, ?it/s]

epoch 31: train loss 0.0345	val loss 0.0344 	|	cost: ld 1.1s, trans 0.0s, comp 38.7s


training:   0%|          | 0/459 [00:00<?, ?it/s]

validating:   0%|          | 0/81 [00:00<?, ?it/s]

epoch 32: train loss 0.0344	val loss 0.0343 	|	cost: ld 1.0s, trans 0.0s, comp 38.7s


training:   0%|          | 0/459 [00:00<?, ?it/s]

validating:   0%|          | 0/81 [00:00<?, ?it/s]

epoch 33: train loss 0.0344	val loss 0.0345 	|	cost: ld 1.0s, trans 0.0s, comp 38.7s


training:   0%|          | 0/459 [00:00<?, ?it/s]

validating:   0%|          | 0/81 [00:00<?, ?it/s]

epoch 34: train loss 0.0344	val loss 0.0343 	|	cost: ld 1.0s, trans 0.0s, comp 38.7s


training:   0%|          | 0/459 [00:00<?, ?it/s]

validating:   0%|          | 0/81 [00:00<?, ?it/s]

epoch 35: train loss 0.0344	val loss 0.0344 	|	cost: ld 1.0s, trans 0.0s, comp 38.7s


# **validation**

In [8]:
#use validation set to evaluate
def inference(folder, names, model, transform, output_folder, binarize = True, device = 'cuda:0',
              output_size = (240, 428)):
    '''
    Run `model` on '<name>.jpg' images from `folder` and write the predicted masks as
    '<name>.png' into `output_folder`.

    args:
        folder: directory containing the input '.jpg' images.
        names: image stems to process; if None, every '*.jpg' file in `folder` is used.
        model: segmentation model. Its output is treated as a per-pixel probability
            (thresholded at 0.5 when binarize is True) — presumably sigmoid output, TODO confirm.
        transform: preprocessing applied to each uint8 image tensor before the forward pass.
        output_folder: directory for the '.png' masks (created if missing).
        binarize: if True, write {0, 255} masks; otherwise write the probability map scaled
            to 0..255.
        device: device the model and inputs are moved to.
        output_size: (H, W) the prediction is bilinearly resized to before saving. The
            default (240, 428) is the size previously hard-coded here.
    '''
    os.makedirs(output_folder, exist_ok=True)
    if names is None:
        # match the full '.jpg' suffix so the slice strips exactly the extension
        names = sorted(f[:-len('.jpg')] for f in os.listdir(folder) if f.endswith('.jpg'))
    saved_cnt = 0
    model.to(device)
    model.eval()
    with torch.no_grad():
        for name in tqdm(names, leave = False, desc = 'inferencing'):
            img = read_image(os.path.join(folder, name + '.jpg'), mode = ImageReadMode.RGB)
            img = transform(img).unsqueeze(0).to(device)
            predict = nn.functional.interpolate(model(img), size=tuple(output_size), mode='bilinear', align_corners=False)
            if binarize:
                predict = (predict > 0.5) * 255
            else:
                # Round and clamp explicitly; a bare float->uint8 cast truncates
                # (e.g. 254.9 -> 254).
                predict = (predict * 255).round().clamp(0, 255)
            predict = predict.squeeze(0).to('cpu', dtype = torch.uint8)
            write_png(predict, os.path.join(output_folder, name + '.png'))
            saved_cnt += 1
    print(f'{saved_cnt} files saved to {output_folder}')



# Evaluate on the validation split: write binarized predictions to a scratch folder, then
# score them against the ground-truth labels with py_sod_metrics' F-measure (beta=0.3).
VAL_PRED_FOLDER = './tmp/'
inference(TRAINING_PATH+'img/', val_names, model, train_dataset.img_transform, VAL_PRED_FOLDER, binarize = True, device = device)


FMv2 = py_sod_metrics.FmeasureV2(
    metric_handlers={
        "fm": py_sod_metrics.FmeasureHandler(with_dynamic=True, with_adaptive=False, beta=0.3),
    }
)

for name in val_names:
    label = cv2.imread(TRAINING_PATH+'label_img/'+name+'.png', cv2.IMREAD_GRAYSCALE)
    predict = cv2.imread(VAL_PRED_FOLDER+name+'.png', cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None instead of raising on a missing or unreadable file.
    if label is None or predict is None:
        raise FileNotFoundError(f'could not read label or prediction for {name!r}')
    # inference() resizes predictions to a fixed size; fail clearly if labels differ.
    if predict.shape != label.shape:
        raise ValueError(f'{name}: prediction shape {predict.shape} != label shape {label.shape}')
    FMv2.step(pred=predict, gt=label)

fmv2 = FMv2.get_results()
print("mean F score: ",fmv2["fm"]["dynamic"].mean())

inferencing:   0%|          | 0/648 [00:00<?, ?it/s]

648 files saved to ./tmp/
mean F score:  0.7797104532135186


# **submitting**

In [9]:
# Write predictions for the public and private test sets into one timestamped folder,
# zip it for submission, and save the trained weights.
# NOTE(review): both test sets share one output folder, so identically named images in
# the two sets would overwrite each other — confirm the names are disjoint.

# To load a saved model instead of using the one trained above, use the next two lines.
# Use the TorchScript file (the one WITHOUT the trailing '_' in its name):
#model = torch.jit.load(PATH)
#model.eval()

t = time.strftime('%m%d-%H%M')          # local time, e.g. '0604-1936'
print('submit-'+t)
output_folder = './submit/submit-'+t+'/'
inference(TESTING_PATH, None, model, train_dataset.img_transform, output_folder, binarize = True, device = device)
inference(PRIVATE_PATH, None, model, train_dataset.img_transform, output_folder, binarize = True, device = device)
shutil.make_archive('submit/submit-'+t, 'zip', output_folder)

# '<t>_.pt' holds the state_dict (re-load into UNet); '<t>.pt' is standalone TorchScript.
# torch.save does not create missing directories, so make sure ./weights exists.
os.makedirs('./weights', exist_ok=True)
torch.save(model.state_dict(), './weights/'+t+'_.pt')
torch.jit.script(model).save('./weights/'+t+'.pt')

submit-0604-1936


inferencing:   0%|          | 0/720 [00:00<?, ?it/s]

720 files saved to ./submit/submit-0604-1936/


inferencing:   0%|          | 0/720 [00:00<?, ?it/s]

720 files saved to ./submit/submit-0604-1936/


# **calculate std and mean**

In [10]:
# Collect each image channel separately for the statistics below. Every sample is fetched
# ONCE: indexing the dataset separately per channel re-ran the random training
# augmentation for each channel (so the three channels came from different
# noised/flipped copies) and tripled the work.
# Note: training samples are still augmented here (noise + flip), and both datasets
# return already-normalized images.
tc0, tc1, tc2 = [],[],[]
vc0, vc1, vc2 = [],[],[]

for i in range(len(train_dataset)):
    sample_img, _ = train_dataset[i]
    tc0.append(sample_img[0,:,:])
    tc1.append(sample_img[1,:,:])
    tc2.append(sample_img[2,:,:])

for i in range(len(val_dataset)):
    sample_img, _ = val_dataset[i]
    vc0.append(sample_img[0,:,:])
    vc1.append(sample_img[1,:,:])
    vc2.append(sample_img[2,:,:])

In [11]:
# Stack the per-channel slices into (N, H, W) tensors and report (std, mean) per channel.
# NOTE: the datasets already apply Normalize((0.519, 0.535, 0.442), (0.196, 0.175, 0.207)),
# so std close to 1 and mean close to 0 confirm those constants rather than re-derive them.
ttc0, ttc1, ttc2 = [torch.stack(chan, dim=0) for chan in (tc0, tc1, tc2)]
tvc0, tvc1, tvc2 = [torch.stack(chan, dim=0) for chan in (vc0, vc1, vc2)]


def _report_std_mean(header, stats):
    """Print `header`, then each (std, mean) result yielded by `stats` (may be lazy)."""
    print(header)
    for std_mean in stats:
        print(std_mean)


_train_channels = (ttc0, ttc1, ttc2)
_val_channels = (tvc0, tvc1, tvc2)
_report_std_mean('----std_mean in training set----',
                 (torch.std_mean(chan) for chan in _train_channels))
_report_std_mean('----std_mean in validation set----',
                 (torch.std_mean(chan) for chan in _val_channels))
# Lazily concatenated, so only one combined channel tensor exists at a time.
_report_std_mean('----std_mean in all dataset(training + validating)----',
                 (torch.std_mean(torch.cat((tr, va))) for tr, va in zip(_train_channels, _val_channels)))

----std_mean in training set----
(tensor(1.0015, device='cuda:4'), tensor(0.0019, device='cuda:4'))
(tensor(0.9965, device='cuda:4'), tensor(0.0008, device='cuda:4'))
(tensor(0.9975, device='cuda:4'), tensor(0.0022, device='cuda:4'))
----std_mean in validation set----
(tensor(1.0114, device='cuda:4'), tensor(0.0042, device='cuda:4'))
(tensor(1.0142, device='cuda:4'), tensor(-0.0071, device='cuda:4'))
(tensor(1.0100, device='cuda:4'), tensor(0.0011, device='cuda:4'))
----std_mean in all dataset(training + validating)----
(tensor(1.0030, device='cuda:4'), tensor(0.0023, device='cuda:4'))
(tensor(0.9992, device='cuda:4'), tensor(-0.0004, device='cuda:4'))
(tensor(0.9994, device='cuda:4'), tensor(0.0021, device='cuda:4'))
