In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
from tqdm import tqdm
import torchinfo
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import random_split
import torchvision
import cv2 as cv
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from glob import glob
from torchvision.io import read_image

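# Model

Stacked hourglass network (building blocks, heatmap loss and the full multi-stage model) used to predict heatmaps.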

In [3]:
import torch
from torch import nn
import matplotlib.pyplot as plt

Pool = nn.MaxPool2d

def batchnorm(x):
    """Normalise ``x`` per channel using only its own batch statistics.

    Numerically equivalent to applying a freshly initialised ``nn.BatchNorm2d``
    in training mode (weight 1, bias 0, eps 1e-5), as the previous version did.
    It is computed functionally instead: no throwaway module is allocated on
    every call, and the op runs on whatever device ``x`` is on. The module-based
    version created its parameters on the CPU and therefore failed for CUDA
    inputs.

    Args:
        x: tensor whose dimension 1 is the channel axis (e.g. N, C, H, W).

    Returns:
        Tensor of the same shape as ``x``.
    """
    # running_mean/var=None with training=True => normalise with batch stats only.
    return F.batch_norm(x, None, None, training=True, eps=1e-5)

class Conv(nn.Module):
    """Same-padded 2-D convolution, optionally followed by BatchNorm and ReLU.

    Args:
        inp_dim: expected number of input channels (checked in ``forward``).
        out_dim: number of output channels.
        kernel_size: square kernel size; padding ``(kernel_size - 1) // 2``
            keeps the spatial size for odd kernels at stride 1.
        stride: convolution stride.
        bn: if True, apply ``nn.BatchNorm2d(out_dim)`` after the convolution.
        relu: if True, apply ``nn.ReLU`` as the last stage.
    """

    def __init__(self, inp_dim, out_dim, kernel_size=3, stride=1, bn=False, relu=True):
        super(Conv, self).__init__()
        self.inp_dim = inp_dim
        same_pad = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(inp_dim, out_dim, kernel_size, stride,
                              padding=same_pad, bias=True)
        # Submodule names (conv/relu/bn) appear in state_dict keys; keep them stable.
        self.relu = nn.ReLU() if relu else None
        self.bn = nn.BatchNorm2d(out_dim) if bn else None

    def forward(self, x):
        channels = x.size()[1]
        assert channels == self.inp_dim, f"{channels} {self.inp_dim}"
        out = self.conv(x)
        # BatchNorm first, then ReLU; disabled stages are stored as None.
        for stage in (self.bn, self.relu):
            if stage is not None:
                out = stage(out)
        return out
    
class Residual(nn.Module):
    """Pre-activation bottleneck residual block: three BN -> ReLU -> conv stages plus a shortcut.

    The bottleneck width is ``int(out_dim / 2)`` (1x1 reduce, 3x3, 1x1 expand).
    ``skip_layer`` (a 1x1 conv) projects the input when ``inp_dim != out_dim``.
    It is constructed unconditionally, even when unused, so it always appears
    in the state_dict.
    """

    def __init__(self, inp_dim, out_dim):
        super(Residual, self).__init__()
        mid_dim = int(out_dim / 2)
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm2d(inp_dim)
        self.conv1 = Conv(inp_dim, mid_dim, 1, relu=False)
        self.bn2 = nn.BatchNorm2d(mid_dim)
        self.conv2 = Conv(mid_dim, mid_dim, 3, relu=False)
        self.bn3 = nn.BatchNorm2d(mid_dim)
        self.conv3 = Conv(mid_dim, out_dim, 1, relu=False)
        self.skip_layer = Conv(inp_dim, out_dim, 1, relu=False)
        self.need_skip = inp_dim != out_dim

    def forward(self, x):
        shortcut = self.skip_layer(x) if self.need_skip else x
        h = x
        for norm, conv in ((self.bn1, self.conv1),
                           (self.bn2, self.conv2),
                           (self.bn3, self.conv3)):
            h = conv(self.relu(norm(h)))
        h += shortcut
        return h

class Hourglass(nn.Module):
    """Recursive hourglass block.

    Output = full-resolution residual branch (``up1``) + lower branch
    (2x2 max-pool -> residual -> inner hourglass or residual -> residual ->
    nearest 2x upsample).

    Args:
        n: recursion depth; at ``n == 1`` the innermost stage is a plain ``Residual``.
        f: channels at this level's input and output.
        bn: only forwarded to inner hourglasses.
        increase: extra channels used inside this level's lower branch
            (inner levels are built at width ``f + increase`` with no further increase).
    """

    def __init__(self, n, f, bn=None, increase=0):
        super(Hourglass, self).__init__()
        inner_dim = f + increase
        self.up1 = Residual(f, f)
        self.pool1 = Pool(2, 2)
        self.low1 = Residual(f, inner_dim)
        self.n = n
        # Recurse one level deeper, or bottom out with a residual block.
        self.low2 = Hourglass(n - 1, inner_dim, bn=bn) if n > 1 else Residual(inner_dim, inner_dim)
        self.low3 = Residual(inner_dim, f)
        self.up2 = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x):
        skip = self.up1(x)
        low = self.low3(self.low2(self.low1(self.pool1(x))))
        # Pool halves and upsample doubles, so H and W must be divisible by 2**n
        # for the two branches to have matching shapes.
        return skip + self.up2(low)


class HeatmapLoss(torch.nn.Module):
    """Mean-squared error between every stacked heatmap prediction and a single target.

    Dimension 1 of ``pred`` indexes the stacks (as produced by
    ``torch.stack(..., 1)`` in ``StackedHourglass.forward``). Each slice
    ``pred[:, i]`` is compared with ``gt`` (with broadcasting), and the
    per-stack mean squared errors are averaged. Because every slice has the
    same element count, this equals the overall mean. The stack-wise form is
    kept so that ``gt`` may have any shape that broadcasts against one slice.
    """

    def __init__(self):
        super(HeatmapLoss, self).__init__()

    def forward(self, pred, gt):
        """Return the scalar loss as a 0-dim tensor on ``pred``'s device.

        Args:
            pred: tensor of shape (batch, nstack, ...); dim 1 indexes the stacks.
            gt: target broadcastable against ``pred[:, i]``.

        Returns:
            0-dim tensor; NaN when ``pred`` has no stacks (mean of an empty set).
        """
        nstack = pred.shape[1]
        if nstack == 0:
            # torch.stack([]) would raise; keep the previous NaN result.
            return pred.new_full((), float('nan'))
        # Accumulate on pred's own device/dtype. The previous CPU float32 buffer
        # forced a device copy per stack (and a CPU loss) when training on GPU.
        per_stack = [((pred[:, i] - gt) ** 2).mean() for i in range(nstack)]
        return torch.stack(per_stack).mean()

class UnFlatten(nn.Module):
    """Reshape flat feature vectors into (N, channels, height, width) maps.

    The target shape defaults to (256, 4, 4), the previously hard-coded value,
    and can be overridden via the constructor. The batch dimension is inferred.
    """

    def __init__(self, channels=256, height=4, width=4):
        super(UnFlatten, self).__init__()
        self.channels = channels
        self.height = height
        self.width = width

    def forward(self, input):
        # ``reshape`` (unlike ``view``) also accepts non-contiguous inputs.
        return input.reshape(-1, self.channels, self.height, self.width)

class Merge(nn.Module):
    """Linear 1x1 convolution (no BatchNorm, no ReLU) mapping ``x_dim`` to ``y_dim`` channels.

    Used between hourglass stages to project features or predictions back to
    the trunk width before they are added to the next stage's input.
    """

    def __init__(self, x_dim, y_dim):
        super(Merge, self).__init__()
        # The attribute name ``conv`` is part of the state_dict keys; keep it.
        self.conv = Conv(x_dim, y_dim, kernel_size=1, bn=False, relu=False)

    def forward(self, x):
        projected = self.conv(x)
        return projected
    
class StackedHourglass(nn.Module):
    """Stack of ``nstack`` hourglass stages, each emitting its own heatmap prediction.

    Args:
        nstack: number of hourglass stages.
        inp_dim: channel width of the trunk between stages.
        oup_dim: number of heatmap channels predicted by each stage.
        in_channels: channels of the input image.
        bn: forwarded to ``Hourglass`` (which only passes it down).
        increase: extra channels inside the outermost level of each hourglass.
        **kwargs: accepted and ignored.

    ``forward`` returns the predictions of every stage, stacked along dim 1.
    For a (batch, in_channels, H, W) input the result has shape
    (batch, nstack, oup_dim, H/4, W/4): the stem downsamples by 4 with a
    stride-2 conv followed by a 2x2 max-pool.
    """

    def __init__(self, nstack, inp_dim, oup_dim, in_channels=3, bn=False, increase=0, **kwargs):
        super(StackedHourglass, self).__init__()
        self.nstack = nstack
        # Stem: stride-2 7x7 conv + 2x2 max-pool => 1/4 input resolution.
        self.pre = nn.Sequential(
            Conv(in_channels, 64, 7, 2, bn=True, relu=True),
            Residual(64, 128),
            Pool(2, 2),
            Residual(128, 128),
            Residual(128, inp_dim),
        )
        # Each hourglass sits in a one-element Sequential, giving state_dict keys
        # ``hgs.<i>.0.*``. Removing the wrapper would break loading existing checkpoints.
        self.hgs = nn.ModuleList([
            nn.Sequential(Hourglass(4, inp_dim, bn, increase))
            for _ in range(nstack)
        ])
        self.features = nn.ModuleList([
            nn.Sequential(
                Residual(inp_dim, inp_dim),
                Conv(inp_dim, inp_dim, 1, bn=True, relu=True),
            )
            for _ in range(nstack)
        ])
        # Per-stage linear 1x1 heads producing oup_dim heatmaps.
        self.outs = nn.ModuleList([Conv(inp_dim, oup_dim, 1, relu=False, bn=False) for _ in range(nstack)])
        # Projections that feed stage i's features/predictions into stage i+1.
        self.merge_features = nn.ModuleList([Merge(inp_dim, inp_dim) for _ in range(nstack - 1)])
        self.merge_preds = nn.ModuleList([Merge(oup_dim, inp_dim) for _ in range(nstack - 1)])
        self.heatmapLoss = HeatmapLoss()

    def forward(self, imgs):
        x = self.pre(imgs)
        combined_hm_preds = []
        for i in range(self.nstack):
            hg = self.hgs[i](x)
            feature = self.features[i](hg)
            preds = self.outs[i](feature)
            combined_hm_preds.append(preds)
            if i < self.nstack - 1:
                # Add this stage's projected predictions and features to the trunk.
                x = x + self.merge_preds[i](preds) + self.merge_features[i](feature)
        return torch.stack(combined_hm_preds, 1)


In [4]:
# The checkpoint is a pickle-based state_dict: only load files you trust
# (consider weights_only=True on PyTorch >= 1.13).
# map_location='cpu' lets it load without the GPU it was saved from; the model
# is moved to its target device after load_state_dict.
hourglass_weights = torch.load('best_model_151.pt', map_location='cpu')

In [5]:
# 8 hourglass stages, 256-channel trunk, 1 heatmap channel per stage.
hourglass = StackedHourglass(nstack=8, inp_dim=256, oup_dim=1)

In [6]:
hourglass.load_state_dict(hourglass_weights)
# Fall back to the CPU when no GPU is available. eval() makes BatchNorm use its
# stored running statistics instead of per-batch statistics, and stops it from
# updating them during inference.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
hourglass = hourglass.to(device).eval()

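# Center detection

Run the trained network on every training image and save the intensity-weighted centroid of the final stage's heatmap (x, y normalised to [0, 1]) to `centers_new.csv`.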

In [43]:
# For each training image, run the stacked hourglass and record the
# intensity-weighted centroid of the last stage's heatmap, normalised to [0, 1].
files = sorted(glob('ISIC2018_Task1-2_Training_Input/*.jpg'))  # sorted => reproducible row order
transform = torchvision.transforms.Resize((256, 256), antialias=True)

# Inference mode: BatchNorm must use its running statistics, not the statistics
# of a single image, and must not update them.
hourglass.eval()
model_device = next(hourglass.parameters()).device


def heatmap_centroid(heatmap):
    """Return the (x, y) intensity-weighted centroid of ``heatmap`` as 0-dim tensors.

    ``heatmap`` is indexed as (..., rows, cols). The first row/column maps to
    0 and the last to 1. Activations are used as-is; negative values are not
    clipped.
    NOTE(review): a heatmap summing to ~0 gives an unstable/inf centroid.
    """
    n_rows, n_cols = heatmap.shape[-2:]
    ys = torch.linspace(0, 1, n_rows).reshape(n_rows, 1)
    xs = torch.linspace(0, 1, n_cols).reshape(1, n_cols)
    total = heatmap.sum()
    return (heatmap * xs).sum() / total, (heatmap * ys).sum() / total


# Header (including its spaces and trailing comma) kept byte-identical to the
# previous output so existing readers of centers_new.csv keep working.
csv_rows = ['file, x, y,\n']
with torch.no_grad():
    for file in tqdm(files):
        # Force 3 channels: grayscale JPEGs would otherwise fail the model's input-channel check.
        img = torchvision.io.read_image(file, mode=torchvision.io.ImageReadMode.RGB)
        netinput = transform(img).unsqueeze(0).to(model_device) / 255.
        output = hourglass(netinput)
        # output: (1, nstack, oup_dim, h, w); use the final stage's prediction.
        last_heatmap = output[0, -1].cpu()
        cx, cy = heatmap_centroid(last_heatmap)
        csv_rows.append(f'{file},{cx.item()},{cy.item()},\n')
csv_data = ''.join(csv_rows)

100%|██████████████████████████████████████████████████████████████████████████████| 2594/2594 [15:48<00:00,  2.74it/s]


In [44]:
# Persist the centroids. Leaving the `with` block flushes and closes the file;
# an explicit encoding makes the output independent of the platform default.
with open('centers_new.csv', 'w', encoding='utf-8') as f:
    f.write(csv_data)