In [None]:
%load_ext autoreload
%autoreload 2

import os
import os.path as osp
import pytorch_lightning as pl
import torch
from model.isnet import DISNet, GtEncoder

In [None]:
# Read the checkpoint stored at saved_model/pretrained/isnet.pth, mapping all
# tensors onto the CPU so the cell also runs on machines without a GPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints that come
# from a trusted source.
state_dict = torch.load('saved_model/pretrained/isnet.pth', map_location='cpu')

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Any

# Module-level loss criteria shared by the helper functions and the
# LightningModule defined further down in this cell.
mse_loss = nn.MSELoss(reduction="mean")  # mean squared error between two tensors
bce_w_loss = nn.BCEWithLogitsLoss()      # BCE that takes raw logits as input
bce_loss = nn.BCELoss()                  # BCE that takes probabilities in [0, 1] as input

def bce_loss_calc(gt_feature_maps, gt):
    """Sum the binary cross-entropy of every feature map against ``gt``.

    Args:
        gt_feature_maps: Iterable of tensors holding probabilities in [0, 1]
            (a requirement of binary cross-entropy on non-logit inputs); each
            must have the same shape as ``gt``.
        gt: Target tensor with values in [0, 1].

    Returns:
        The sum of the per-map, mean-reduced BCE losses as a scalar tensor,
        or the int ``0`` when ``gt_feature_maps`` is empty.
    """
    # F.binary_cross_entropy uses the same default reduction ("mean") as
    # nn.BCELoss(), so each term is the mean BCE of one map.
    return sum(F.binary_cross_entropy(output, gt) for output in gt_feature_maps)

def feature_sync(gt_outputs, u2net_outputs):
    """Sum the mean-squared error between paired feature tensors.

    The i-th tensor of ``gt_outputs`` is compared with the i-th tensor of
    ``u2net_outputs``.

    Args:
        gt_outputs: Sequence of feature tensors; its length determines how
            many pairs are compared.
        u2net_outputs: Indexable sequence of feature tensors with at least as
            many entries as ``gt_outputs``; each entry must match the shape of
            its counterpart.

    Returns:
        The sum of the per-pair, mean-reduced MSE losses as a scalar tensor,
        or the int ``0`` when ``gt_outputs`` is empty.

    Raises:
        IndexError: If ``u2net_outputs`` has fewer entries than ``gt_outputs``.
    """
    # Index (rather than zip) so that a too-short u2net_outputs fails loudly
    # instead of silently dropping trailing feature levels.
    return sum(
        F.mse_loss(gt_output, u2net_outputs[idx], reduction="mean")
        for idx, gt_output in enumerate(gt_outputs)
    )

class Net(pl.LightningModule):
    """LightningModule that trains a wrapped network with a summed BCE loss.

    The wrapped ``model`` is called on ``batch['image']`` and must return a
    pair ``(side_outputs, features)``. The loss is the BCE between the side
    outputs and ``batch['gt']``; when a ground-truth encoder has been attached
    via :meth:`load_gt_encoder`, the MSE between its features and the model's
    features is added.
    """

    def __init__(self, model, pretrained: str = None, lr: float = 0.001, epsilon: float = 1e-08, batch_size: int = 0) -> None:
        """Store hyper-parameters and optionally load pretrained weights.

        Args:
            model: Network to train; stored as ``self.net``.
            pretrained: Optional path to a state dict loaded into ``model``.
            lr: Learning rate for Adam.
            epsilon: ``eps`` term for Adam.
            batch_size: Stored on the instance; not used by the visible methods.
        """
        super().__init__()
        self.lr = lr
        self.epsilon = epsilon
        self.net = model
        self.gt_encoder = None
        self.batch_size = batch_size

        if pretrained:
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints that come from a trusted source.
            state_dict = torch.load(pretrained, map_location='cpu')
            self.net.load_state_dict(state_dict)
            print('----------------------------------------------------------------------------------------------------')
            print('pretrained loaded')
            print('----------------------------------------------------------------------------------------------------')

    def load_gt_encoder(self, gt_encoder, pretrained: str = None):
        """Attach a ground-truth encoder used for the feature-sync loss.

        Args:
            gt_encoder: Module called on ``batch['gt']``; must return a pair
                whose second element is the feature list.
            pretrained: Optional path to a state dict loaded into
                ``gt_encoder``. When falsy the encoder keeps its current
                weights.
        """
        self.gt_encoder = gt_encoder
        if pretrained:
            # NOTE(review): same trust caveat as in __init__ applies here.
            state_dict = torch.load(pretrained, map_location='cpu')
            self.gt_encoder.load_state_dict(state_dict)

        self.gt_encoder.eval()
        print('gt_encoder is loaded')
        print('----------------------------------------------------------------------------------------------------')

    def configure_optimizers(self):
        """Return an Adam optimizer over the wrapped network's parameters only."""
        optimizer = torch.optim.Adam(self.net.parameters(), lr=self.lr, betas=(0.9, 0.999), eps=self.epsilon, weight_decay=0)
        return optimizer

    def forward(self, x):
        """Run the wrapped network on ``x`` and return its raw output."""
        return self.net(x)

    @staticmethod
    def _side_output_loss(side_outputs, gt):
        """Sum ``bce_loss`` over the side outputs (a tensor or an iterable of tensors)."""
        # assumes the side outputs are probabilities with the same shape as
        # gt, as required by nn.BCELoss -- TODO confirm against the model.
        if torch.is_tensor(side_outputs):
            side_outputs = [side_outputs]
        return sum(bce_loss(side, gt) for side in side_outputs)

    def _common_step(self, batch, batch_idx, stage: str = 'train'):
        """Compute and log the loss shared by training and validation.

        Args:
            batch: Mapping with ``'image'`` and ``'gt'`` entries.
            batch_idx: Index of the batch (unused).
            stage: Prefix of the logged metric name, ``f"{stage}_loss"``.

        Returns:
            The total loss for the batch.
        """
        image, gt = batch['image'], batch['gt']
        im_side_outputs, im_features = self.net(image)
        loss = self._side_output_loss(im_side_outputs, gt)

        if self.gt_encoder is not None:
            # The encoder is a registered submodule, so switching this module
            # to train mode also switches the encoder; force eval mode again
            # before using it as a fixed target.
            self.gt_encoder.eval()
            # The optimizer only updates self.net, so no graph is needed for
            # the encoder's forward pass.
            with torch.no_grad():
                _, gt_features = self.gt_encoder(gt)
            loss = loss + feature_sync(gt_features, im_features)

        self.log(f"{stage}_loss", loss, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss, logged as ``train_loss``."""
        return self._common_step(batch, batch_idx, stage='train')

    def validation_step(self, batch, batch_idx):
        """Return the validation loss, logged as ``val_loss`` and kept on ``self.val_loss``."""
        loss = self._common_step(batch, batch_idx, stage='val')
        self.val_loss = loss
        return loss

    # Disabled checkpoint hook, kept for reference.
    # NOTE(review): `OD` is not imported in this notebook (presumably
    # collections.OrderedDict) -- add the import before re-enabling.
#     def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
#         state_dict = checkpoint['state_dict']
#         epoch = checkpoint['epoch']
#         PATH = f'epoch={epoch}-val_loss={self.val_loss}-batch_size={self.batch_size}.pth'

#         key_word = 'net.'
#         new_sd = OD()
#         for key, value in state_dict.items():
#             if key_word in key:
#                 key = key.replace(key_word, '')
#             new_sd[key] = value

#         # save .pth file separately
#         torch.save(new_sd, PATH)

    def predict_step(self, batch, batch_idx):
        """Run the wrapped network on ``batch['image']``; no ``'gt'`` entry is required."""
        return self.net(batch['image'])
        

# DISNet Train

In [None]:
# DISNet training setup: wrap the network in the LightningModule, starting
# from the checkpoint at `disnet_pretrained`, then attach a GT encoder.
disnet_pretrained = 'saved_model/pretrained/isnet.pth'
disnet = DISNet(3, 1)
gt_encoder = GtEncoder(1, 1)

net = Net(model=disnet, pretrained=disnet_pretrained)
# pretrained=None: load_gt_encoder skips weight loading, so the encoder keeps
# the weights it was constructed with.
net.load_gt_encoder(gt_encoder=gt_encoder, pretrained=None)

# GTEncoder Train

In [None]:
# GT-encoder training setup: the encoder itself is the wrapped model.
# This rebinds `gt_encoder` and `net` created by the DISNet cell above.
gt_encoder = GtEncoder(1, 1)
net = Net(model=gt_encoder)

# Dataset GtEncoder

In [None]:
%load_ext autoreload
%autoreload 2

from utils.gt_dataset import Dataset
import albumentations as A

input_size = 1280
batch_size = 8

# Geometric augmentations: resize to input_size, take a random 1024x1024
# crop, then random flips / 90-degree rotations.
mask_transform = A.Compose(
    [
        A.Resize(height=input_size, width=input_size),
        A.RandomCrop(height=1024, width=1024),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.8),
        A.RandomRotate90(p=0.8),
    ]
)

# Photometric augmentations; defined here but not passed to the dataset
# constructed in this cell.
image_transform = A.Compose(
    [
        A.CLAHE(p=0.8),
        A.RandomBrightnessContrast(p=0.8),
        A.RandomGamma(p=0.8),
    ]
)

# The GT-encoder dataset reads from the ground-truth directory only.
tr_ds = Dataset(
    transform=mask_transform,
    image_path='../data/DIS5K/DIS-TR/gt',
)
len(tr_ds)

In [None]:
A.Compose(mask_transform[1:])(image=im_arr)

In [None]:
import numpy as np
image_path = tr_ds.images[1200]
im = Image.open(image_path).convert('L').resize((1024,1024))
im_arr = np.array(im)
Image.fromarray(im_arr)

In [None]:
from PIL import Image
# Fetch the first sample and render its 'image' and 'gt' entries.
# presumably both entries are uint8 numpy arrays, since they are passed
# straight to Image.fromarray -- TODO confirm against the dataset class.
data = tr_ds[0]
display(Image.fromarray(data['image']))
# Last expression of the cell, so the notebook renders it as the cell output.
Image.fromarray(data['gt'])

# Dataset DISNet

In [None]:
%load_ext autoreload
%autoreload 2

from utils.isnet_dataset import Dataset
import albumentations as A

input_size = 1280
batch_size = 8

# Geometric augmentations, passed to the dataset as `gt_transform`: resize to
# input_size, take a random 1024x1024 crop, then random flips / rotations.
mask_transform = A.Compose(
    [
        A.Resize(height=input_size, width=input_size),
        A.RandomCrop(height=1024, width=1024),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.8),
        A.RandomRotate90(p=0.8),
    ]
)

# Photometric augmentations, passed to the dataset as `image_transform`.
image_transform = A.Compose(
    [
        A.CLAHE(p=0.8),
        A.RandomBrightnessContrast(p=0.8),
        A.RandomGamma(p=0.8),
    ]
)

# This rebinds `tr_ds` to the image/ground-truth dataset for DISNet training.
tr_ds = Dataset(
    gt_path='../data/DIS5K/DIS-TR/gt',
    image_path='../data/DIS5K/DIS-TR/im',
    gt_transform=mask_transform,
    image_transform=image_transform,
    load_on_mem=True,
)
len(tr_ds)

In [None]:
from PIL import Image
# Fetch the first sample of the DISNet dataset and render its 'image' and
# 'gt' entries.
# presumably both entries are uint8 numpy arrays, since they are passed
# straight to Image.fromarray -- TODO confirm against the dataset class.
data = tr_ds[0]
display(Image.fromarray(data['image']))
# Last expression of the cell, so the notebook renders it as the cell output.
Image.fromarray(data['gt'])