In [None]:
import random
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import numpy as np
import torch.utils.data
import cv2
import torchvision.models.segmentation
import torch
import os
import patchify
from sklearn.datasets import load_sample_image
from sklearn.feature_extraction import image as skimg
import imgaug as ia
import imgaug.augmenters as iaa
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.models.detection.mask_rcnn import MaskRCNN_ResNet50_FPN_V2_Weights
import torch.optim
from torch import nn
from torch.utils.data import Dataset, DataLoader
import sys
sys.path.insert(1, '/home/prakharug/AFO')
sys.path.insert(1, '/home/prakharug/AFO/pycoco')
from pycoco.engine import train_one_epoch, evaluate
import miou_eval as mi

In [None]:
# Pick the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

In [None]:
def combine_dims(a, start=0, count=2):
    """Merge ``count`` consecutive axes of array ``a`` into a single axis.

    The merged axis sits at position ``start``; e.g. shape (2, 3, 4) with
    ``start=0, count=2`` becomes (6, 4).
    """
    leading = a.shape[:start]
    trailing = a.shape[start + count:]
    return np.reshape(a, (*leading, -1, *trailing))


def collate_fn(batch):
    """Transpose a list of samples into one tuple per field.

    ``[(a1, b1), (a2, b2)]`` becomes ``((a1, a2), (b1, b2))``; passed to
    ``DataLoader`` as ``collate_fn`` so samples are grouped, not stacked.
    """
    per_field = zip(*batch)
    return tuple(per_field)

In [None]:
class MMCellDataset(Dataset):
    """Cell-image dataset returning a 9-plane image and foreground/background masks.

    Each item pairs one image file in ``root_dir`` with every ground-truth
    instance mask in ``../<datano>/ground_truths/`` whose file name starts with
    ``<image stem>_``. ``datano`` is the second ``/``-separated component of
    ``root_dir`` (``"dataset4"`` for ``"../dataset4/images/"``). All paths are
    relative to the current working directory.
    """

    def __init__(self, root_dir, tester=False, transform=None):
        """
        Args:
            root_dir (string): Directory with all the images. File names are
                appended to it directly, so it must end with ``/``.
            tester (bool): If True, an image with no non-empty instance mask
                yields the sentinel ``("Problem", "Hao gai", "Gais")`` instead
                of being skipped.
            transform (callable, optional): Stored on the instance but not
                applied anywhere in this class.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.img = []  # unused; kept so the attribute still exists
        self.tester = tester
        self.datano = root_dir.split('/')[1]
        print(self.datano)
        # Sorted so the index -> file mapping is reproducible (os.listdir order is arbitrary).
        self.nml = sorted(os.listdir(self.root_dir))

    def __len__(self):
        # Count the names captured at construction so len() always matches the
        # indices __getitem__ can serve, even if the directory changes later.
        return len(self.nml)

    def __getitem__(self, idx):
        """Return ``(img, sem_mask, neg_mask)`` for item ``idx``.

        ``img`` is a float32 (9, H, W) tensor in [0, 1] holding the BGR, HSV
        and LAB planes of the image. ``sem_mask`` / ``neg_mask`` are int32
        (3, H, W) tensors that are 1 on pixels covered / not covered by any
        instance mask.

        If the image has no non-empty instance mask, the following images
        (wrapping around) are tried instead; in ``tester`` mode the sentinel
        ``("Problem", "Hao gai", "Gais")`` is returned.

        Raises:
            IndexError: if ``idx`` is out of range.
            RuntimeError: if no image in the dataset has a non-empty mask.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        im_name = self.nml[idx]  # out-of-range idx still raises IndexError
        n = len(self.nml)
        # Walk forward (wrapping) past images without masks. This replaces a
        # self-recursive retry that could exhaust the recursion limit.
        for step in range(1, n + 1):
            sample = self._build_sample(im_name)
            if sample is not None:
                return sample
            if self.tester:
                return "Problem", "Hao gai", "Gais"
            im_name = self.nml[(idx + step) % n]
        raise RuntimeError("no image under " + self.root_dir + " has a non-empty ground-truth mask")

    @staticmethod
    def _is_ground_truth_for(gt_name, stem):
        """Return True if ``gt_name`` is an instance mask of the image with stem ``stem``."""
        # Same test as comparing the first len(stem) chars and then checking for
        # "_", but without an IndexError when gt_name is exactly ``stem``.
        return gt_name.startswith(stem + "_")

    def _load_stack(self, im_name):
        """Read ``im_name`` and its instance masks as one (3 + K, H, W) array.

        Planes 0-2 are the B, G, R channels; planes 3.. are the K grayscale
        instance masks (K may be 0).

        Raises:
            FileNotFoundError: if the image or a mask cannot be decoded.
        """
        image_path = self.root_dir + im_name
        image = cv2.imread(image_path, 1)
        if image is None:
            raise FileNotFoundError("cannot read image " + image_path)
        planes = list(image.transpose(2, 0, 1))
        # splitext handles any extension length (the old [:-4] assumed 3 chars).
        stem = os.path.splitext(im_name)[0]
        gt_dir = "../" + self.datano + "/ground_truths/"
        for gt_name in os.listdir(gt_dir):
            if self._is_ground_truth_for(gt_name, stem):
                mask = cv2.imread(gt_dir + gt_name, 0)
                if mask is None:
                    raise FileNotFoundError("cannot read mask " + gt_dir + gt_name)
                planes.append(mask)
        return np.array(planes)

    def _build_sample(self, im_name):
        """Return ``(img, sem_mask, neg_mask)`` for ``im_name``, or None if it has no non-empty mask."""
        stack = self._load_stack(im_name)
        bgr = stack[0:3]
        instances = stack[3:]
        # Foreground = pixels covered by any instance mask.
        union = np.any(instances != 0, axis=0)
        if not union.any():
            return None
        bgr_hwc = bgr.transpose(1, 2, 0)
        hsv = cv2.cvtColor(bgr_hwc, cv2.COLOR_BGR2HSV).transpose(2, 0, 1)
        lab = cv2.cvtColor(bgr_hwc, cv2.COLOR_BGR2LAB).transpose(2, 0, 1)
        features = np.concatenate((bgr, hsv, lab), 0)
        img = torch.as_tensor(features, dtype=torch.float32) / 255
        # Each mask is repeated 3x so it lines up with a 3-channel model output.
        sem = union.astype(np.int32)
        neg = (~union).astype(np.int32)
        sem_mask = torch.as_tensor(np.stack([sem, sem, sem]), dtype=torch.int32)
        neg_mask = torch.as_tensor(np.stack([neg, neg, neg]), dtype=torch.int32)
        return img, sem_mask, neg_mask
# Datasets: `fintest` runs in tester mode, so images without masks yield a
# sentinel tuple instead of being skipped.
train = MMCellDataset("../dataset4/images/")
test = MMCellDataset("../dataset4/test/")
fintest = MMCellDataset("../test/images/", True)

# Settings shared by all loaders.
LOADER_KWARGS = dict(batch_size=1, collate_fn=collate_fn, num_workers=10)
trainloader = DataLoader(train, shuffle=True, **LOADER_KWARGS)
# Evaluation loaders are not shuffled so their iteration order is reproducible.
testloader = DataLoader(test, shuffle=False, **LOADER_KWARGS)
fintestloader = DataLoader(fintest, shuffle=False, **LOADER_KWARGS)

In [None]:
class yesu(nn.Module):
    """Learned per-pixel colour transform: 9 input planes -> 3 output planes in (0, 1).

    The network is a stack of 1x1 convolutions, so each output pixel depends
    only on the same input pixel. In training mode ``forward`` also returns a
    loss that pushes the mean output colour of the foreground (``sem_mask``)
    and of the background (``neg_mask``) towards orthogonality.
    """

    def __init__(self) -> None:
        super().__init__()
        self.channel_trans = nn.Sequential(
                nn.Conv2d(9, 128,(1,1),stride=1,padding='same',padding_mode='replicate',bias=False),
                nn.ReLU(),
                nn.Conv2d(128,128,(1,1),stride=1,padding='same',padding_mode='replicate',bias=False),
                nn.ReLU(),
                nn.BatchNorm2d(128),
                nn.Conv2d(128,3,(1,1),stride=1,padding='same',padding_mode='replicate',bias=False),
                nn.BatchNorm2d(3),
                nn.Sigmoid()
            )

    def forward(self, images, sem_mask=None, neg_mask=None):
        """Apply the colour transform to each image.

        Args:
            images: sequence of (9, H, W) float tensors; it is not modified.
            sem_mask: training only - per-image masks that are multiplied with
                the (3, H, W) output; non-zero on foreground pixels.
            neg_mask: training only - same layout, non-zero on background pixels.

        Returns:
            In training mode ``(outputs, loss)``, otherwise ``outputs``.
            ``outputs`` is a list of (3, H, W) tensors; ``loss`` is a shape-(1,)
            tensor summing, over images, the absolute cosine similarity between
            the mask-weighted mean foreground and background colours.
        """
        outputs = []
        # Allocate the loss on the inputs' device instead of the notebook-global `device`.
        loss_device = images[0].device if len(images) > 0 else None
        mean_loss = torch.zeros(1, device=loss_device)
        for ind, image in enumerate(images):
            out = self.channel_trans(image.unsqueeze(0))[0]
            outputs.append(out)
            if self.training:
                sem_val = self._masked_mean(out, sem_mask[ind])
                neg_val = self._masked_mean(out, neg_mask[ind])
                # cosine_similarity's eps keeps the loss finite when a region is empty
                # (the hand-written ratio produced 0/0 = NaN there).
                cos = nn.functional.cosine_similarity(sem_val, neg_val, dim=0)
                mean_loss = mean_loss + torch.abs(cos)
        if self.training:
            return outputs, mean_loss
        return outputs

    @staticmethod
    def _masked_mean(image, mask):
        """Mask-weighted per-channel mean of ``image``; 0 where the mask is empty."""
        count = mask.sum((1, 2)).clamp_min(1)
        return torch.mul(image, mask).sum((1, 2)) / count


In [None]:
def lmbda(epoch):
    """LR multiplier handed to MultiplicativeLR: a constant 0.1 per epoch.

    NOTE(review): MultiplicativeLR applies this factor cumulatively, so after
    k scheduler steps the learning rate is 1e-2 * 0.1**k - confirm this very
    fast decay is intended.
    """
    return 0.1


# Build the colour-transform network on the selected device.
# To resume from a checkpoint: model.load_state_dict(torch.load('./yesu_0.torch'))
model = yesu().to(device)
model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lmbda)

In [7]:
# Train the colour transform for 30 epochs, checkpointing after every epoch.
for epoch in range(30):
    model.train()
    for i, (images, targets, negt) in enumerate(trainloader):
        images = [image.to(device) for image in images]
        targets = [t.to(device) for t in targets]
        negt = [t.to(device) for t in negt]

        # Clear the previous step's gradients; without this they accumulate
        # across every iteration and SGD steps on the running sum.
        optimizer.zero_grad()
        images, loss = model(images, targets, negt)
        loss.backward()
        optimizer.step()

        del targets, negt, images
        print(i, 'loss:', loss.item())
    scheduler.step()
    torch.save(model.state_dict(), "yesu_" + str(epoch) + ".torch")


KeyboardInterrupt



In [None]:
# Save the current weights under a fixed checkpoint name.
final_state = model.state_dict()
torch.save(final_state, "yesu_yesu.torch")

In [None]:
# Apply the trained colour transform to every training image and save each
# 3-channel result (scaled to 0-255) as its own file under ../dataset4/yesu/.
model.eval()
# To use a saved checkpoint instead: model.load_state_dict(torch.load('./yesu_0.torch'))
src_dir = '../dataset4/images/'
out_dir = '../dataset4/yesu/'
os.makedirs(out_dir, exist_ok=True)
with torch.no_grad():  # inference only: no autograd graph needed
    for im_name in sorted(os.listdir(src_dir)):
        image_bgr = cv2.imread(src_dir + im_name, 1)
        if image_bgr is None:  # not a decodable image file
            continue
        image_hsv = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2HSV)
        image_lab = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2LAB)
        # 9 planes in the same order MMCellDataset uses (BGR, HSV, LAB), scaled to [0, 1].
        patch_img = np.concatenate((image_bgr, image_hsv, image_lab), 2).transpose(2, 0, 1)
        patch_img = torch.as_tensor(patch_img, dtype=torch.float32, device=device) / 255
        output = model([patch_img], None, None)[0]  # single forward pass per image
        print(output.shape)
        imgf = output.cpu().numpy().transpose(1, 2, 0) * 255
        # One file per input (previously every image overwrote test.jpg);
        # convert explicitly to 8-bit, which JPEG output requires.
        out_name = os.path.splitext(im_name)[0] + '.jpg'
        cv2.imwrite(out_dir + out_name, np.clip(np.rint(imgf), 0, 255).astype(np.uint8))