In [53]:
# %autosave 60
%matplotlib inline
import pyift.pyift as ift
import matplotlib.pyplot as plt
plt.rcParams['axes.grid'] = False
import torch
from flim.experiments import utils, LIDSDataset, ToTensor
from flim.models.lcn import LCNCreator
import cv2 as cv
import torch.nn.functional as F
from torchvision    import transforms, models
from torch.utils.data import Dataset, DataLoader
from torch import nn
import torchvision
from os import listdir
from os.path import isfile, join
import random
import numpy as np
import random
import gc
from torchsummary import summary
from functools import partial
from tqdm.notebook import tqdm
import skimage.measure
import math

device = torch.device('cuda')

# Loss definition

In [54]:
class TripletLoss(torch.nn.Module):
    """
    Triplet loss function.
    """

    def __init__(self, margin=2.0):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative, pos_prob, neg_prob):

        distance_1 = F.pairwise_distance(anchor, positive, keepdim = True)
        distance_2 = F.pairwise_distance(anchor, negative, keepdim = True)

        triplet_loss = (F.relu(torch.mul(distance_1, pos_prob) - torch.mul(distance_2, neg_prob) + self.margin)).mean()
        acc = (distance_1 < distance_2).sum() * 1.0 / distance_1.size()[0]

        return triplet_loss, acc
    
class TripletLossWithPairMining(torch.nn.Module):
    """
    Triplet loss function.
    """

    def __init__(self, margin=2.0):
        super(TripletLossWithPairMining, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative, pos_prob, neg_prob):

        distance_1 = F.pairwise_distance(anchor, positive, keepdim = True)
        distance_2 = F.pairwise_distance(anchor, negative, keepdim = True)
        
        # semi-hard mining
        mask = (distance_2 < (distance_1 + self.margin)) * (distance_2 > (distance_1))
        
        # hard mining
        #mask = (distance_2 < (distance_1))
        
        triplet_loss = (F.relu(torch.mul(distance_1, pos_prob) - 
                               torch.mul(distance_2, neg_prob) + self.margin) * mask).mean()
        
        
        acc = (distance_1 < distance_2).sum() * 1.0 / distance_1.size()[0]

        return triplet_loss, acc

# Dataset definition

In [55]:
debug = False
def set_debug_mode(mode):
    global debug
    debug = mode

def dprint(msg):
    if debug: print(msg)

def valid_voxel(img, v):
    return 0 <= v.x <= (img.xsize - 1) and 0 <= v.y <= (img.ysize - 1)

def find_valid_neighbour(img, anchor, same_label, adj_radius):
    A = ift.Circular(adj_radius)
    p = img.GetCoordIndex(anchor[0], anchor[1], 0)
    u = img.GetVoxel(p)
    for i in random.sample(range(1, A.n), A.n - 1):
        v = ift.GetAdjacentVoxel(A, u, i)
        if valid_voxel(img, v):
            q = img.GetVoxelIndex(v)
            if (img[p] == img[q]) == same_label: # XNOR
                return [v.x, v.y], ((v.x - u.x) ** 2 + (v.y - u.y) ** 2) ** 0.5
    dprint([u.x, u.y])
    return [u.x, u.y], 1

def get_prob_from_pixel(pmap, coord):
    return pmap[coord[1], coord[0]]


class DeepDynamicTreesDataset(Dataset):
    def __init__(self, path, nitems, csv="files.csv", radius_adj=5, ext_img="png", ext_label="png",
                 use_prob_map=True, prob_tresh=0.97):
        dprint("Creating Dataset...")
        self.path = path
        self.csv = csv
        self.nitems = nitems
        self.ext_img = ext_img
        self.ext_label = ext_label
        self.radius_adj = radius_adj
        self.count_epochs = 0
        self.use_prob_map = use_prob_map
        self.prob_tresh = prob_tresh

        self.imgs = {"name": [],
                     "orig": [],
                     "tensors": [],
                     "labels": [],
                     "markers_file": {"fg_seeds": [], "bg_seeds": []},
                     "test_markers" : [],
                     "prob_map": {'fg': [], 'bg': []}}

        file_csv = open(f"{path}/{csv}","r")
        self.imgs["name"] = [img.strip() for img in file_csv]

        dprint("Loading images...")
        self.read_images()
        self.read_images_as_features()
        dprint("Loading labels...")
        self.read_labels()
        dprint("Loading markers from files...")
        self.read_markers_file()

        self.items = []

        self.next_epoch(None)

    def calculate_probability_map(self, model):
        dprint("Calculating probability map...")
        self.imgs['prob_map']['fg'].clear()
        self.imgs['prob_map']['bg'].clear()

        if model is not None:
            device = next(model.parameters()).device
        for i in range(len(self.imgs['name'])):
            imag = self.imgs['orig'][i]
            seed = self.imgs['test_markers'][i]
            mimg = None
            if model is not None:
                x = self.imgs['tensors'][i].to(device)
                _feat = model(x)
                mimg = ift.CreateMImageFromNumPy(np.ascontiguousarray(_feat.squeeze(0).permute(1,2,0).detach().cpu().numpy()))
            else:
                mimg = ift.ImageToMImage(imag, ift.LABNorm_CSPACE)
            wmap = ift.DTRootWeightsMap(mimg, ift.Circular(1.0), seed, 0, False).AsNumPy()
            wmap = np.squeeze(wmap, 0)
            wmap = np.squeeze(wmap, 2)
            labl = ift.DynTreeRoot(mimg, ift.Circular(1.0), seed, 0, 0, None, 0)
            labl = labl.AsNumPy()

            _max = np.max(wmap)
            pmap = (_max - wmap) / (_max + wmap)
            pmap = np.where(pmap > self.prob_tresh, pmap, 0)
            pmfg = np.where(labl != 0, pmap, 0)
            pmbg = np.where(labl == 0, pmap, 0)

            self.imgs['prob_map']['fg'].append(pmfg)
            self.imgs['prob_map']['bg'].append(pmbg)
            self.set_pseudo_markers(i)
            

    def read_images(self):
        for name in self.imgs["name"]:
            orig_path = f"{self.path}/orig/{name}.{self.ext_img}"
            orig = ift.ReadImageByExt(orig_path)
            self.imgs["orig"].append(orig)

    def read_images_as_features(self):
        for name in self.imgs["name"]:
            orig_path = f"{self.path}/orig/{name}.{self.ext_img}"
            img = cv.imread(orig_path)
            img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
            img = transforms.Compose([
                transforms.ToPILImage(),
                transforms.ToTensor(),
                transforms.Normalize([0.0, 0.0, 0.0], [1.0, 1.0, 1.0])
            ])(img)
            img_tensor = img.unsqueeze(0)
            self.imgs["tensors"].append(img_tensor)


    def read_labels(self):
        for name in self.imgs["name"]:
            gt_path = f"{self.path}/labels/{name}.{self.ext_label}"
            gt = cv.imread(gt_path, 0)
            gt = ift.CreateImageFromNumPy(np.int32(gt), False)
            self.imgs["labels"].append(gt)

    def read_markers_file(self):
        for i, name in enumerate(self.imgs["name"]):
            markers_path = f"{self.path}/markers/{name}.txt"
            img = self.imgs["orig"][i]
            seeds = ift.ReadSeeds(img, markers_path)
            sdict = seeds.AsDict()
            fg_seeds = [[img.GetVoxel(int(k)).x, img.GetVoxel(int(k)).y] for k, v in sdict.items() if v != 0]
            bg_seeds = [[img.GetVoxel(int(k)).x, img.GetVoxel(int(k)).y] for k, v in sdict.items() if v == 0]
            self.imgs["markers_file"]["fg_seeds"].append(fg_seeds)
            self.imgs["markers_file"]["bg_seeds"].append(bg_seeds)
            self.imgs["test_markers"].append(seeds)

    def set_pseudo_markers(self, i):
        fg, bg = self.imgs['prob_map']['fg'][i], self.imgs['prob_map']['bg'][i]
        bg = np.int32(np.where(bg != 0, 1, 0))
        bg = ift.CreateImageFromNumPy(bg, False)
        fg = np.int32(np.where(fg != 0, 2, 0))
        fg = ift.CreateImageFromNumPy(fg, False)
        fg_seeds = np.int32(ift.MaskImageToSet(fg, None)[0].AsNumPy())
        bg_seeds = np.int32(ift.MaskImageToSet(bg, None)[0].AsNumPy())
        fg_seeds = [[fg.GetVoxel(int(s)).x, fg.GetVoxel(int(s)).y] for s in fg_seeds]
        bg_seeds = [[bg.GetVoxel(int(s)).x, bg.GetVoxel(int(s)).y] for s in bg_seeds]
        self.imgs["markers_file"]["fg_seeds"][i] = fg_seeds
        self.imgs["markers_file"]["bg_seeds"][i] = bg_seeds


    def next_epoch(self, model):
        dprint("Loading next epoch...")
        if self.use_prob_map:
            self.calculate_probability_map(model)
        self.items.clear()
        self.count_epochs += 1

        files_len = len(self.imgs["name"])

        bg_set = self.imgs["markers_file"]["bg_seeds"]
        fg_set = self.imgs["markers_file"]["fg_seeds"]

        nitems_per_file = int(self.nitems / files_len)

        # Triplets mining
        for ix in range(0, files_len):
            for _ in range(0, nitems_per_file, 2):
                # foreground anchor sampling
                anchor = random.choice(fg_set[ix])
                positive, _ = find_valid_neighbour(self.imgs["labels"][ix], anchor, True, self.radius_adj)
                negative = random.choice(bg_set[ix])

                if not self.use_prob_map:
                    neg_prob = pos_prob = 1
                else:
                    pos_prob = get_prob_from_pixel(self.imgs['prob_map']['fg'][ix], positive)+0.0001
                    neg_prob = get_prob_from_pixel(self.imgs['prob_map']['bg'][ix], negative)+0.0001
                self.items.append([ix, anchor, positive, negative, pos_prob, neg_prob])

                # background anchor sampling
                anchor   = random.choice(bg_set[ix])
                positive, _ = find_valid_neighbour(self.imgs["labels"][ix], anchor, True, self.radius_adj)
                negative = random.choice(fg_set[ix])

                if not self.use_prob_map:
                    neg_prob = pos_prob = 1
                else:
                    pos_prob = get_prob_from_pixel(self.imgs['prob_map']['bg'][ix], positive)+0.0001
                    neg_prob = get_prob_from_pixel(self.imgs['prob_map']['fg'][ix], negative)+0.0001
                self.items.append([ix, anchor, positive, negative, pos_prob, neg_prob])

    def __getitem__(self, ix):
        ix, anchor, positive, negative, pos_prob, neg_prob = self.items[ix]
        return ix, anchor, positive, negative, pos_prob, neg_prob

    def __len__(self):
        return len(self.items)

# Qualitative and Quantitative evaluation

In [56]:
def segment_with_dynamic(ix, dataset:DeepDynamicTreesDataset, features=None):
    mimg = None
    seeds = dataset.imgs["test_markers"][ix]
    if features is not None:
        mimg = ift.CreateMImageFromNumPy(np.ascontiguousarray(features.squeeze(0).permute(1,2,0).detach().cpu().numpy()))
    else:
        img = dataset.imgs["orig"][ix]
        mimg = ift.ImageToMImage(img, ift.LABNorm_CSPACE)

    A = ift.Circular(1.0)
    label = ift.DynTreeRoot(mimg, A, seeds, 0, 0, None, 0)
    
    gt = dataset.imgs["labels"][ix]
    dice = ift.DiceSimilarity(gt, label)
    assd = ift.ASSD(gt, label)

    return dice, assd, label


def segment(ix, dataset, features=None):
    dice, assd, label = segment_with_dynamic(ix, dataset, features)
    return dice, assd

# Training procedures

In [57]:
def train_batch(x, model, data, optimizer, criterion):
    ix, u, v, w, dp, dn = [t for t in data]
    n = len(ix)
    optimizer.zero_grad()

    reduced_features = model(x)

    anchors   = torch.stack([reduced_features[ix[i],:,int(u[1][i]),int(u[0][i])] for i in range(n)]).to(device)
    positives = torch.stack([reduced_features[ix[i],:,int(v[1][i]),int(v[0][i])] for i in range(n)]).to(device)
    negatives = torch.stack([reduced_features[ix[i],:,int(w[1][i]),int(w[0][i])] for i in range(n)]).to(device)
    dpv       = dp.to(device)
    dnv       = dn.to(device)


    loss, acc = criterion(anchors, positives, negatives, dpv, dnv)
    loss.backward(retain_graph=True)
    optimizer.step()
    return loss.item(), acc.item()

In [58]:
criterion = TripletLoss(2.0)

def finish_epoch(dataset, model):
    dataset.next_epoch(model)
    dataloader = DataLoader(dataset, batch_size=256, shuffle=True)
    return dataloader

def improve_model_for_file(dataset, model, nepochs):

    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, weight_decay=0.0001)
    dataloader = finish_epoch(dataset, None)
    x = torch.cat(dataset.imgs["tensors"],0).to(device)
    
    for epoch in range(nepochs):
        N = len(dataloader)
        sum_batch_loss, sum_batch_acc = 0, 0
        for i, data in enumerate(dataloader):
            train_batch(x, model, data, optimizer, criterion)
        dataloader = finish_epoch(dataset, model)

# FLIM model

In [59]:
def convBlock(ni, no, kernel_size=3, dilation=1):
    return nn.Sequential(
        nn.Conv2d(ni, no, kernel_size=kernel_size, padding='same', bias=False,dilation=dilation),
        nn.BatchNorm2d(no),
        nn.ReLU(inplace=True)
    )

    
class ContrastiveFeatureReducerFLIM(nn.Module):
    def __init__(self):
        super(ContrastiveFeatureReducerFLIM, self).__init__()
        
        model_path = 'flim-models/model'
        architecture_name = 'arch.json'
        architecture = utils.load_architecture('%s/%s' % (model_path,architecture_name))
        self.encoder = utils.build_model(architecture, input_shape=[3])
        
        for param in self.encoder.parameters():
            param.requires_grad = False
        
        self.reducer = nn.Sequential(
            convBlock(163,128,1),
            nn.Conv2d(128, 16, kernel_size=1, padding='same', bias=False)
        )
        

    def forward(self, x):
        inter_feats = [x]
        y = self.encoder.features[0](x)
        y = self.encoder.features[1](y)
        inter_feats.append(y)
        y = self.encoder.features[2](y)
        y = self.encoder.features[3](y)
        y = self.encoder.features[4](y)
        inter_feats.append(y)
        y = self.encoder.features[5](y)
        y = self.encoder.features[6](y)
        y = self.encoder.features[7](y)
        inter_feats.append(y)
        
        y = torch.cat(inter_feats, 1)
        
        reduced = self.reducer(y)
        reduced_x = torch.cat((reduced,x), 1)
    
        return reduced_x

# U-Net Model

In [60]:
class Block(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3,padding='same')
        self.relu  = nn.ReLU()
        self.bs1   = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3,padding='same')
        self.bs2   = nn.BatchNorm2d(out_ch)
    
    def forward(self, x):
        return self.bs2(self.relu(self.conv2(self.bs1(self.relu(self.conv1(x))))))
    
    
class Encoder(nn.Module):
    def __init__(self, chs=(3,64,128,256,512,1024)):
        super().__init__()
        self.enc_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)])
        self.pool       = nn.MaxPool2d(2)

    def forward(self, x):
        ftrs = []
        for block in self.enc_blocks:
            x = block(x)
            ftrs.append(x)
            x = self.pool(x)
        return ftrs
    
class Decoder(nn.Module):
    def __init__(self, chs=(1024, 512, 256, 128, 64)):
        super().__init__()
        self.chs         = chs
        self.upconvs    = nn.ModuleList([nn.ConvTranspose2d(chs[i], chs[i+1], 2, 2) for i in range(len(chs)-1)])
        self.dec_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)]) 
        
    def forward(self, x, encoder_features):
        for i in range(len(self.chs)-1):
            x        = self.upconvs[i](x)
            enc_ftrs = self.crop(encoder_features[i], x)
            x        = torch.cat([x, enc_ftrs], dim=1)
            x        = self.dec_blocks[i](x)
        return x
    
    def crop(self, enc_ftrs, x):
        _, _, H, W = x.shape
        enc_ftrs   = torchvision.transforms.CenterCrop([H, W])(enc_ftrs)
        return enc_ftrs
    
class UNet(nn.Module):
    def __init__(self, enc_chs=(3,64,128,256,512,1024), dec_chs=(1024, 512, 256, 128, 64), num_feats_out=1, skipconnect_x=False):
        super().__init__()
        self.encoder     = Encoder(enc_chs)
        self.decoder     = Decoder(dec_chs)
        self.head        = nn.Conv2d(dec_chs[-1], num_feats_out, 1)
        self.skipconnect_x = skipconnect_x

    def forward(self, x, cat_x=False):
        out_sz   = (x.shape[2], x.shape[3])
        
        enc_ftrs = self.encoder(x)
        out      = self.decoder(enc_ftrs[::-1][0], enc_ftrs[::-1][1:])
        out      = self.head(out)
        out      = nn.functional.interpolate(out, out_sz)
        if self.skipconnect_x:
            out  = torch.cat((x,out), 1)
        return out

# Aux

In [61]:
def create_tmp_dataset(dataset_path, file, ext):
    csv_tmp = f"{dataset_path}/tmp.csv"
    with open(csv_tmp, 'w') as f:
        f.write(file)
    
    dataset = DeepDynamicTreesDataset(dataset_path, 2048, csv="tmp.csv",
                                      ext_img=ext[0], ext_label=ext[1],
                                      prob_tresh=0.97, radius_adj=10)
    
    return dataset

In [62]:
def init_model_flim():
    return ContrastiveFeatureReducerFLIM().to(device)

def init_model_unet():
    return UNet(num_feats_out=16, skipconnect_x=True).to(device)


def step_flim(dataset):
    model = init_model_flim()
    
    improve_model_for_file(dataset, model, 5)
    
    x = dataset.imgs["tensors"][0].to(device)
    features = model(x)
    dice, assd = segment(0, dataset, features)
    
    return dice, assd

def step_unet(dataset):
    model = init_model_unet()
    
    improve_model_for_file(dataset, model, 5)
    
    x = dataset.imgs["tensors"][0].to(device)
    features = model(x)
    dice, assd = segment(0, dataset, features)
    
    return dice, assd
    
configs = {'U-Net': step_unet,
          'FLIM': step_flim}

# Evaluation

In [64]:
import csv
from datetime import datetime
now = datetime.now()

datasets = [('datasets/GRABCUT',['ppm', 'pgm']), 
            ('datasets/gulshan', ['png', 'png']), 
            ('datasets/andrade',['ppm', 'pgm'])]

with open(f"EXP_{now.strftime('%Y_%m_%d_%H_%M_%S')}.csv", 'w') as file:
    file.write('model; dataset; mean dice; std dice; mean assd; std assd\n')
    
    for dataset_path, ext in datasets:
        dataset_csv = dataset_path + '/files-all.csv'
        file_csv = open(dataset_csv,"r")
        imgs = [img.strip() for img in file_csv]

        for desc, model_eval in configs.items():
            dices = []
            assds = []

            for i in tqdm(range(len(imgs))):
                dataset = create_tmp_dataset(dataset_path, imgs[i], ext)
                dice, assd = model_eval(dataset)
                print(dice,assd)
                dices.append(dice)
                assds.append(assd)

            file.write(f'{desc}; {dataset_path}; {np.mean(dices)}; {np.std(dices)}; {np.mean(assds)}; {np.std(assds)}\n')

A Jupyter Widget

0.6803915922833286 15.243093344498268
0.9807565476103237 1.2480453412931762
0.8882307460456826 5.012147172657184
0.8495402536735358 3.8100067201075674
0.8989894058487526 9.637029423423705


KeyboardInterrupt: 