In [None]:
import os
import numpy as np
# import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
from sklearn.metrics import jaccard_similarity_score
from sklearn.utils.class_weight import compute_sample_weight
from source import *

# Run on the default CUDA device when one is visible, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def center_crop(x, height, width):
    """Return the central ``height`` x ``width`` window of ``x``.

    The crop is applied to the last two axes, so ``x`` may be a 2-D image
    (H, W) or carry extra leading axes such as (C, H, W) or (N, C, H, W); the
    leading axes are left unchanged.  Works for anything supporting ``.shape``
    and NumPy-style ``...`` slicing (NumPy arrays, torch tensors).

    Args:
        x: array-like with at least two dimensions.
        height: number of rows to keep.
        width: number of columns to keep.

    Returns:
        A slice of ``x`` with shape ``(..., height, width)``.  When the size
        difference is odd, the extra row/column is dropped from the
        bottom/right edge.

    Raises:
        ValueError: if ``x`` has fewer than two dimensions, or the requested
            window is larger than ``x`` (negative offsets would otherwise
            silently produce a wrongly-sized slice).
    """
    if len(x.shape) < 2:
        raise ValueError("center_crop expects at least 2 dimensions, got shape %s"
                         % (tuple(x.shape),))
    i_h, i_w = x.shape[-2], x.shape[-1]
    if height > i_h or width > i_w:
        raise ValueError("cannot crop %dx%d window from %dx%d input"
                         % (height, width, i_h, i_w))
    p_h, p_w = (i_h - height) // 2, (i_w - width) // 2
    return x[..., p_h:p_h + height, p_w:p_w + width]

In [None]:
import torchvision.transforms.functional as TF
import random

class Ds(Dataset):
    '''Paired image / mask dataset read from ``t-<idx>.tif`` and ``l-<idx>.tif``.

    Args:
      root_dir (string): Path to directory with images and labels
      in_dir (string): Relative path to images (wrt root_dir)
      label_dir (string): Relative path to labels (wrt root_dir)
      transform (callable, optional): applied to the PIL image (after RGB conversion)
      target_transform (callable, optional): applied to the PIL label
      transforms (bool): if True, apply the random joint augmentation in
        ``common_transform`` (training split)
      test (bool): selects the test-split length when ``transforms`` is False

    Each item is ``(image, label)``: ``image`` is reflect-padded by 90 px on every
    side and converted with ``TF.to_tensor``; ``label`` is a ``long`` tensor that
    is 1 where the mask is > 0 and 0 elsewhere.
    '''

    def __init__(self, root_dir, in_dir, label_dir, transform=None, target_transform=None, transforms=False, test=False):
        self.root_dir = root_dir
        self.in_dir = in_dir
        self.label_dir = label_dir
        # The ``transforms`` parameter shadows the torchvision.transforms module
        # only inside __init__; the other methods still see the module.
        self.transforms = transforms
        self.transform = transform
        self.target_transform = target_transform
        self.test = test

    def __len__(self):
        # Hard-coded split sizes (train / val / test).  These must read the
        # *instance* flags: the bare name ``transforms`` is the torchvision module
        # (always truthy), which made every split report 23 items.
        # NOTE(review): labels are loaded even when test=True, so a length of 30
        # needs l-0..l-29.tif in label_dir -- confirm against the data on disk.
        if self.transforms:
            return 23
        if not self.test:
            return 7
        return 30

    @staticmethod
    def _pad_and_tensorize(image, mask):
        '''Reflect-pad the image by 90 px per side, then convert image and mask to tensors.'''
        image = transforms.Pad(90, padding_mode='reflect')(image)
        return TF.to_tensor(image), TF.to_tensor(mask)

    def common_transform(self, image, mask):
        '''Apply one random affine + random flips identically to image and mask, then pad/convert.'''
        # Random affine.  ``translate`` holds non-negative fractions: the shift is
        # sampled from [-t * size, t * size].  A negative entry (the old
        # [-0.01, 0.01]) is rejected by newer torchvision; (0.01, 0.01) samples
        # the same range.
        # NOTE(review): newer torchvision renamed ``fillcolor`` to ``fill``.
        ret = torchvision.transforms.RandomAffine.get_params((-0.1, 0.1), (0.01, 0.01), None, None, image.size)
        image = TF.affine(image, *ret, fillcolor=0)
        mask = TF.affine(mask, *ret, fillcolor=0)

        # Random horizontal flipping
        if random.random() > 0.5:
            image = TF.hflip(image)
            mask = TF.hflip(mask)

        # Random vertical flipping
        if random.random() > 0.5:
            image = TF.vflip(image)
            mask = TF.vflip(mask)

        return self._pad_and_tensorize(image, mask)

    def __getitem__(self, idx):
        img_loc = os.path.join(self.root_dir, self.in_dir, "t-" + str(idx) + ".tif")
        label_loc = os.path.join(self.root_dir, self.label_dir, "l-" + str(idx) + ".tif")

        # Image.open is lazy and keeps the file open; load the pixels inside a
        # ``with`` block so the handle is closed immediately (important with
        # several DataLoader workers).
        with Image.open(img_loc) as img_file:
            image = img_file.convert("RGB")
        with Image.open(label_loc) as label_file:
            label = label_file.copy()

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)

        if self.transforms:
            image, label = self.common_transform(image, label)
        else:
            image, label = self._pad_and_tensorize(image, label)

        # Binarize the mask: any non-zero pixel becomes class 1.
        label = (label > 0).long()
        return image, label

In [None]:
%%time

# Inputs are resized to 392x392 (padded to 572 inside Ds); masks to 388x388,
# matching the size of the network's output.
x_transform = transforms.Compose([
    transforms.Grayscale(1),
    transforms.Resize((392, 392)),
])
y_transform = transforms.Compose([
    transforms.Grayscale(1),
    transforms.Resize((388, 388)),
])

trainset = Ds("../em_stack/", "train", "labels", x_transform, y_transform, transforms=True)
valset = Ds("../em_stack/", "val", "val-labels", x_transform, y_transform, transforms=False, test=False)
testset = Ds("../em_stack/", "test", "labels", x_transform, y_transform, transforms=False, test=True)

trainloader, valloader, testloader = (
    torch.utils.data.DataLoader(split, batch_size=3, shuffle=True, num_workers=8)
    for split in (trainset, valset, testset)
)

# Sanity check: shapes and one image/mask pair from the first training batch.
dataiter = iter(trainloader)
i, l = next(dataiter)
print(i.shape)
print(l.shape)
plt.subplot(1, 2, 1)
plt.imshow(i.numpy()[0][0], cmap='gray')
plt.subplot(1, 2, 2)
plt.imshow(l.numpy()[0][0], cmap='gray')

In [None]:
# class Net(nn.Module):
#     def __init__(self, in_channels=1):
#         super(Net, self).__init__()
#         self.relu = nn.ReLU(inplace=True)
#         self.conv1 = nn.Conv2d(in_channels, 64, 3, 1, 0)
#         self.conv2 = nn.Conv2d(64, 64, 3, 1, 0)
#         self.pool = nn.MaxPool2d(2, 2)
#         self.conv3 = nn.Conv2d(64, 128, 3, 1, 0)
#         self.conv4 = nn.Conv2d(128, 128, 3, 1, 0)
#         self.conv5 = nn.Conv2d(128, 256, 3, 1, 0)
#         self.conv6 = nn.Conv2d(256, 256, 3, 1, 0)
#         self.conv7 = nn.Conv2d(256, 512, 3, 1, 0)
#         self.conv8 = nn.Conv2d(512, 512, 3, 1, 0)
#         self.conv9 = nn.Conv2d(512, 1024, 3, 1, 0)
#         self.conv10 = nn.Conv2d(1024, 1024, 3, 1, 0)
#         self.upconv1 = nn.ConvTranspose2d(1024, 512, 2, 2)
#         self.conv11 = nn.Conv2d(1024, 512, 3, 1, 0)
#         self.conv12 = nn.Conv2d(512, 512, 3, 1, 0)
#         self.upconv2 = nn.ConvTranspose2d(512, 256, 2, 2)
#         self.conv13 = nn.Conv2d(512, 256, 3, 1, 0)
#         self.conv14 = nn.Conv2d(256, 256, 3, 1, 0)
#         self.upconv3 = nn.ConvTranspose2d(256, 128, 2, 2)
#         self.conv15 = nn.Conv2d(256, 128, 3, 1, 0)
#         self.conv16 = nn.Conv2d(128, 128, 3, 1, 0)
#         self.upconv4 = nn.ConvTranspose2d(128, 64, 2, 2)
#         self.conv17 = nn.Conv2d(128, 64, 3, 1, 0)
#         self.conv18 = nn.Conv2d(64, 64, 3, 1, 0)
#         self.conv19 = nn.Conv2d(64, 2, 1, 1, 0)

#     def center_crop(self, x, height, width):
#         _, _, i_h, i_w = x.shape
#         p_h, p_w = (i_h-height)//2, (i_w-width)//2
#         return x[:, :, p_h:p_h+height, p_w:p_w+width]

#     def forward(self, x):
#         x = self.conv1(x)
#         x = self.relu(x)
#         x = self.conv2(x)
#         x = self.relu(x)
#         x1 = self.center_crop(x, 392, 392)
#         x = self.pool(x)
#         x = self.conv3(x)
#         x = self.relu(x)
#         x = self.conv4(x)
#         x = self.relu(x)
#         x2 = self.center_crop(x, 200, 200)
#         x = self.pool(x)
#         x = self.conv5(x)
#         x = self.relu(x)
#         x = self.conv6(x)
#         x = self.relu(x)
#         x3 = self.center_crop(x, 104, 104)
#         x = self.pool(x)
#         x = self.conv7(x)
#         x = self.relu(x)
#         x = self.conv8(x)
#         x = self.relu(x)
#         x4 = self.center_crop(x, 56, 56)
#         x = self.pool(x)
#         x = self.conv9(x)
#         x = self.relu(x)
#         x = self.conv10(x)
#         x = self.relu(x)
#         x = self.upconv1(x)
#         x = torch.cat((x4, x), dim=1)
#         x = self.conv11(x)
#         x = self.relu(x)
#         x = self.conv12(x)
#         x = self.relu(x)
#         x = self.upconv2(x)
#         x = torch.cat((x3, x), dim=1)
#         x = self.conv13(x)
#         x = self.relu(x)
#         x = self.conv14(x)
#         x = self.relu(x)
#         x = self.upconv3(x)
#         x = torch.cat((x2, x), dim=1)
#         x = self.conv15(x)
#         x = self.relu(x)
#         x = self.conv16(x)
#         x = self.relu(x)
#         x = self.upconv4(x)
#         x = torch.cat((x1, x), dim=1)
#         x = self.conv17(x)
#         x = self.relu(x)
#         x = self.conv18(x)
#         x = self.relu(x)
#         x = self.conv19(x)
#         return x

In [None]:
class downBlock(nn.Module):
    '''
    Two stacked unpadded 3x3 convolutions, each followed by a ReLU:
    Conv2d --> ReLU --> Conv2d --> ReLU

    With stride 1 and no padding, each convolution shrinks H and W by 2.
    '''
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute names and the Sequential layout (index 0 = conv, 1 = ReLU)
        # are kept so state dicts keep keys such as 'conv1.0.weight'.
        self.conv1 = self._conv_relu(in_channels, out_channels)
        self.conv2 = self._conv_relu(out_channels, out_channels)

    @staticmethod
    def _conv_relu(c_in, c_out):
        '''3x3, stride-1, padding-0 convolution followed by ReLU.'''
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.conv2(self.conv1(x))
    
class upBlock(nn.Module):
    r'''
    x --> UpConv \
                --> cat --> Conv2d --> ReLU --> Conv2d --> ReLU
    skip -> Crop /

    ``upConv`` (2x2 transposed conv, stride 2) doubles the spatial size of ``x``
    and maps in_channels -> out_channels.  ``skip`` is center-cropped (or
    zero-padded when smaller) to that size, concatenated in front of the
    upsampled tensor on the channel axis, and passed through ``downBlock``;
    ``skip`` must therefore carry ``in_channels - out_channels`` channels.
    '''
    def __init__(self, in_channels, out_channels):
        super(upBlock, self).__init__()

        # Construction order kept (conv first, then upConv) so parameter
        # registration / initialization order is unchanged.
        self.conv = downBlock(in_channels, out_channels)
        self.upConv = nn.ConvTranspose2d(in_channels, out_channels, 2, 2)

    def forward(self, x, skip):
        up = self.upConv(x)
        # Per-axis size difference: negative => skip is larger and is cropped
        # (F.pad with negative amounts crops), positive => skip is zero-padded.
        # Splitting as (d // 2, d - d // 2) makes the result exactly match ``up``
        # even for odd differences, and H / W are handled independently so
        # non-square inputs work.
        diff_h = up.size(2) - skip.size(2)
        diff_w = up.size(3) - skip.size(3)
        skip = F.pad(skip, [diff_w // 2, diff_w - diff_w // 2,
                            diff_h // 2, diff_h - diff_h // 2])
        return self.conv(torch.cat([skip, up], 1))

In [None]:
class Net(nn.Module):
    '''U-Net style encoder/decoder.

    Encoder: four ``downBlock`` stages, each followed by 2x2 max-pooling, then a
    ``downBlock`` bottleneck (64 -> 128 -> 256 -> 512 -> 1024 channels).
    Decoder: four ``upBlock`` stages, each given the output of the matching
    encoder stage as its skip input, then a 1x1 convolution to ``out_channels``.
    Module names and registration order are what saved model/optimizer state
    dicts refer to, so they must not change.
    '''
    def __init__(self, in_channels=1, out_channels=2):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        width = (64, 128, 256, 512, 1024)

        # Contracting path.
        self.conv1 = downBlock(in_channels, width[0])
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = downBlock(width[0], width[1])
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = downBlock(width[1], width[2])
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv4 = downBlock(width[2], width[3])
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Bottleneck.
        self.conv5 = downBlock(width[3], width[4])

        # Expanding path.
        self.up1 = upBlock(width[4], width[3])
        self.up2 = upBlock(width[3], width[2])
        self.up3 = upBlock(width[2], width[1])
        self.up4 = upBlock(width[1], width[0])

        # Per-pixel classifier.
        self.conv6 = nn.Conv2d(width[0], out_channels, kernel_size=1)

    def forward(self, x):
        encoder = ((self.conv1, self.pool1), (self.conv2, self.pool2),
                   (self.conv3, self.pool3), (self.conv4, self.pool4))
        skips = []
        out = x
        for conv, pool in encoder:
            out = conv(out)
            skips.append(out)  # pre-pooling features feed the decoder
            out = pool(out)
        out = self.conv5(out)
        # Deepest skip first: up1 gets conv4's output, ..., up4 gets conv1's.
        for up, skip in zip((self.up1, self.up2, self.up3, self.up4), reversed(skips)):
            out = up(out, skip)
        return self.conv6(out)

In [None]:
class Loss(nn.Module):
    '''Segmentation loss selected by ``type``.

    Args:
        type: 'cross_entropy' (pixel-wise, optionally class-weighted) or
            'dice' (soft Dice on the class-1 / foreground probability).
        weight: optional per-class weight tensor passed to ``F.cross_entropy``
            (not used by the dice loss).

    Every forward call appends ``loss.item()`` to ``self.losses``.

    Raises:
        ValueError: from ``forward`` when ``type`` is not a known loss name.
    '''
    def __init__(self, type='cross_entropy', weight=None):
        super(Loss, self).__init__()
        self.type = type
        self.weight = weight
        self.losses = []
        self.dice_class_scores = []  # kept for compatibility; not populated here

    def cross_entropy2d(self, logits, target):
        '''Pixel-wise cross entropy.

        logits: (N, C, H, W) raw scores.  target: integer class ids with
        N*H*W elements (e.g. (N, 1, H, W) or (N, H, W)).
        '''
        c = logits.shape[1]
        # Move the class axis last and flatten so each pixel is one sample.
        logits = logits.permute(0, 2, 3, 1).reshape(-1, c)
        target = target.reshape(-1)
        return F.cross_entropy(logits, target, self.weight)

    def softDiceLoss(self, logits, targets):
        '''Soft Dice loss for a binary mask.

        loss = 1 - mean_n (2 * |P*T| + s) / (|P| + |T| + s), with s = 1.

        Single-channel logits are turned into probabilities with a sigmoid.
        Multi-channel logits (such as the 2-class network here) use the softmax
        probability of class 1, so the prediction has one value per target
        pixel; flattening all C channels could not be multiplied with an
        (N, 1, H, W) target.
        '''
        smooth = 1.
        num = targets.size(0)
        if logits.size(1) == 1:
            probs = torch.sigmoid(logits)
        else:
            probs = F.softmax(logits, dim=1)[:, 1]
        m1 = probs.reshape(num, -1).float()
        m2 = targets.reshape(num, -1).float()
        intersection = (m1 * m2).sum(1)
        # Smoothing is added once to numerator and denominator: an empty
        # prediction on an empty mask scores exactly 1 (loss 0).  The former
        # ``2 * (I + s)`` numerator scored 2 there and drove the loss negative.
        score = (2. * intersection + smooth) / (m1.sum(1) + m2.sum(1) + smooth)
        return 1 - score.sum() / num

    def forward(self, logits, target):
        if self.type == 'cross_entropy':
            loss = self.cross_entropy2d(logits, target)
        elif self.type == 'dice':
            loss = self.softDiceLoss(logits, target)
        else:
            raise ValueError("Unknown loss type %r; expected 'cross_entropy' or 'dice'" % (self.type,))
        # Flat history of scalar losses (previously built a nested [[..], x] chain).
        self.losses.append(loss.item())
        return loss

In [None]:
%%time

net = Net()

gpu_count = torch.cuda.device_count()
if gpu_count > 1:
    # NOTE(review): the message below is printed, but the DataParallel wrap is
    # disabled, so training still runs on a single device.
    print("Using", gpu_count, "GPUs")
#     net=nn.DataParallel(net)

net = net.to(device)
# Class-weighted cross entropy: class 0 weighted 6, class 1 weighted 1
# (presumably to counter class imbalance -- see the frequency check below).
class_weights = torch.Tensor((6, 1)).to(device)
criterion = Loss(type="cross_entropy", weight=class_weights)
optimizer = optim.Adam(net.parameters(), weight_decay=1e-4)

In [None]:
def softDiceLoss(logits, targets):
    '''Soft Dice loss for a binary mask.

    loss = 1 - mean_n (2 * |P*T| + s) / (|P| + |T| + s), with s = 1.

    Args:
        logits: (N, C, H, W) raw scores.  With C == 1 a sigmoid gives the
            foreground probability; with C > 1 the softmax probability of
            class 1 is used, so the prediction has one value per target pixel.
        targets: binary mask with one entry per pixel, e.g. (N, 1, H, W).

    Returns:
        Scalar tensor; 0 for a perfect match (including empty-on-empty).
    '''
    smooth = 1.
    num = targets.size(0)
    if logits.size(1) == 1:
        probs = torch.sigmoid(logits)
    else:
        probs = F.softmax(logits, dim=1)[:, 1]
    m1 = probs.reshape(num, -1).float()
    m2 = targets.reshape(num, -1).float()
    intersection = (m1 * m2).sum(1)
    # Smoothing added once to numerator and denominator keeps the score in
    # [0, 1]; the former ``2 * (I + s)`` numerator reached 2 on empty masks.
    score = (2. * intersection + smooth) / (m1.sum(1) + m2.sum(1) + smooth)
    return 1 - score.sum() / num

In [None]:
# Resume model and optimizer state from a saved checkpoint.  ``map_location``
# loads tensors onto the current ``device``, so a checkpoint written on a GPU
# machine also loads on a CPU-only kernel (torch.load otherwise tries to
# restore CUDA tensors onto CUDA).
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
checkpoint = torch.load('../em_stack/model59.pth', map_location=device)
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
print(epoch)

In [None]:
# Take one more training batch as a fixed preview sample (``n`` is reused by
# the training loop).  ``next(it)`` works on every PyTorch version; the
# DataLoader-iterator alias ``.next()`` was removed in newer releases.
n, l = next(dataiter)
n, l = n.to(device), l.to(device)
x = (l.reshape(-1) == 0).sum().item()  # number of class-0 pixels
y = (l.reshape(-1) == 1).sum().item()  # number of class-1 pixels
# Inverse class frequencies (total / count); guarded for a missing class.
print((x + y) / x if x else float("inf"))
print((x + y) / y if y else float("inf"))
# num_classes=2 keeps channel 1 present even when this mask has no class-1
# pixels (one_hot otherwise infers the class count from the largest label).
one_hot_mask = nn.functional.one_hot(l[0][0], num_classes=2)
print(one_hot_mask.shape)
plt.imshow(one_hot_mask[:, :, 1].detach().cpu().numpy())
plt.show()

In [None]:
# %%time

# Train for 60 epochs.  Every 10th epoch: print that epoch's mean batch loss,
# show the raw output maps of channels 0 and 1 for the preview batch ``n``,
# and save a checkpoint to ../em_stack/model<epoch>.pth.
# NOTE(review): ``epoch`` restarts at 0 even if a checkpoint was loaded above,
# so existing checkpoint files are overwritten -- confirm this is intended.
loss_arr = []

for epoch in range(60):
    running_loss = 0.0
    num_batches = 0
    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = net(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        loss_arr.append(batch_loss)
        num_batches += 1

    # Report / checkpoint once, after the epoch's last batch.  This used to sit
    # inside the batch loop, so it ran (and re-saved the checkpoint) on every
    # batch of those epochs and divided a single batch's loss by 10.
    if epoch % 10 == 9:
        print('[%d, %5d] loss: %.3f' %
              (epoch + 1, num_batches, running_loss / max(num_batches, 1)))
        with torch.no_grad():  # preview only: no autograd graph needed
            ops = net(n)
        fore = ops[0][0].cpu().numpy()
        back = ops[0][1].cpu().numpy()
        print(np.max(fore))
        print(np.min(fore))
        print(np.max(back))
        print(np.min(back))
        plt.subplot(1, 2, 1)
        plt.imshow(fore, cmap='gray')
        plt.subplot(1, 2, 2)
        plt.imshow(back, cmap='gray')
        plt.show()
        torch.save({
            'epoch': epoch,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            }, "../em_stack/model" + str(epoch) + ".pth")

plt.plot(loss_arr)
plt.show()

In [None]:
# Training curve: one point per optimisation step (batch).
plt.plot(loss_arr)
plt.title('Loss v/s batch')
plt.xlabel('Batch')
plt.ylabel('Cross Entropy Loss')
plt.savefig('CellLoss.png')

In [None]:
thresh = 2  # NOTE(review): only referenced by commented-out code in jaccard_loss

In [None]:
def jaccard_loss(labels, outputs, eps=1e-7):
    """Support-weighted mean Jaccard index (IoU) of the predicted segmentation.

    Despite the name this returns a *score* in [0, 1] (higher is better), not a
    loss, and it is not differentiable (it uses argmax).  It computes the same
    quantity as ``sklearn.metrics.jaccard_score(y_true, y_pred,
    average="weighted")``: for every class present in the labels or the
    predictions, IoU = |T_c & P_c| / |T_c | P_c|, averaged with weights equal to
    each class's pixel count in ``labels``.

    Args:
        labels: integer class-id tensor with one entry per pixel,
            e.g. shape [B, H, W] or [B, 1, H, W].
        outputs: tensor of shape [B, C, H, W] holding the raw model outputs;
            the predicted class is the argmax over dim 1.
        eps: lower bound on the union size, guarding the division.

    Returns:
        float: the weighted Jaccard score (0.0 for empty input).

    Raises:
        ValueError: if labels and predictions have different pixel counts.
    """
    preds = outputs.detach().argmax(dim=1).reshape(-1).cpu()
    truth = labels.detach().reshape(-1).cpu().to(preds.dtype)
    if truth.numel() != preds.numel():
        raise ValueError("labels have %d pixels but outputs predict %d"
                         % (truth.numel(), preds.numel()))
    if truth.numel() == 0:
        return 0.0

    weighted_sum = 0.0
    for cls in torch.unique(torch.cat([truth, preds])):
        in_truth = truth == cls
        in_pred = preds == cls
        intersection = (in_truth & in_pred).sum().item()
        union = (in_truth | in_pred).sum().item()
        support = in_truth.sum().item()  # weight: pixels of this class in labels
        weighted_sum += support * intersection / max(union, eps)
    # Supports over all classes sum to the number of labelled pixels.
    return weighted_sum / truth.numel()


In [None]:
# Mean per-batch weighted IoU over the training loader, plus the cropped input,
# the argmax prediction and the mask for every 6th image of each batch.
# NOTE(review): this iterates trainloader, so it measures fit rather than
# generalisation; valloader may be the intended loader -- confirm.
iou_loss = 0
num_batches = 0
with torch.no_grad():
    for data in trainloader:
        image, labels = data
        images = image.to(device)
        outputs = net(images)
        iou_loss += jaccard_loss(labels, outputs)
        num_batches += 1
        predicted = outputs.argmax(dim=1, keepdim=True)
        for i in range(0, outputs.shape[0], 6):
            panels = (
                center_crop(image[i][0], 388, 388),  # input cropped to the output grid
                predicted[i][0].detach().cpu().numpy(),
                labels.numpy()[i][0],
            )
            for panel in panels:
                plt.imshow(panel, cmap='gray')
                plt.show()
    iou_loss /= num_batches

print("iou_loss = ", iou_loss)
        

In [None]:
# Visualise test-set predictions: the input cropped to 388x388 next to the
# argmax prediction, for every 6th image of each batch.
with torch.no_grad():  # inference only; avoids building an autograd graph per batch
    for data in testloader:
        image, labels = data
        images = image.to(device)
        outputs = net(images)
        predicted = outputs.argmax(dim=1, keepdim=True)
        for i in range(0, outputs.shape[0], 6):
            plt.subplot(1, 2, 1)
            plt.imshow(center_crop(image[i][0], 388, 388), cmap='gray')
            plt.subplot(1, 2, 2)
            plt.imshow(predicted[i][0].cpu().numpy(), cmap='gray')
            plt.show()