In [1]:
import cv2
import os
import glob
import numpy as np
from tqdm import tqdm
import random
import matplotlib.pyplot as plt
import json
import shutil

In [2]:
# Directory that is scanned for the images to process.
input_path = "./inputs"

In [3]:
# Every entry of the input directory whose name matches '*.*'.
search_pattern = os.path.join(input_path, '*.*')
all_images = glob.glob(search_pattern)

In [4]:
patch_size = 1000
stride = 500

# Always start from an empty cache so patches left over from a previous run
# are never handed to the detector. os.makedirs still raises if the old
# directory could not be removed.
shutil.rmtree('./patched_cache', ignore_errors=True)
os.makedirs('./patched_cache')


def _patch_origins(length, patch, step):
    """Return the sorted, de-duplicated patch start offsets along one axis.

    Offsets advance by ``step``; the final patch is anchored to the end of the
    axis (``length - patch``) so it never runs past the image. When ``length``
    is a multiple of ``step`` the anchored patch coincides with the previous
    one, which is why the offsets are collected in a set.
    Only meaningful for ``length > patch``.
    """
    count = length // step
    origins = {idx * step for idx in range(count - 1)}
    origins.add(length - patch)
    return sorted(origins)


for img_path in tqdm(all_images):

    img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # The '*.*' glob also matches non-image files; cv2.imread returns None for those.
        tqdm.write('skipping unreadable file: {}'.format(img_path))
        continue

    # Work in float so the offset below cannot wrap around for uint8/uint16 inputs.
    img = img.astype(np.float64)
    # NOTE(review): the minimum is *added* (not subtracted) exactly as before; this looks
    # like an intended min-max normalisation, but changing it would alter the detector's
    # input distribution -- confirm against the training preprocessing.
    img += img.min()
    peak = img.max()
    if peak > 0:
        # Rescale so the brightest pixel maps to 255; skipped for all-zero images
        # to avoid a division by zero.
        img = img / (peak / 255.)
    if img.ndim == 2:
        # Replicate single-channel images to three channels.
        img = np.stack([img] * 3, -1)
    img = np.clip(img, 0, 255).astype('uint8')
    shape = img.shape

    # File name without directory, truncated at the first dot.
    stem = os.path.basename(img_path).split('.')[0]

    if min(shape[:2]) > patch_size:
        # Overlapping patches; the file name records each patch's top-left corner
        # as '<stem>&<xmin>&<ymin>.png'.
        for xmin in _patch_origins(shape[1], patch_size, stride):
            for ymin in _patch_origins(shape[0], patch_size, stride):
                patch = img[ymin:ymin + patch_size, xmin:xmin + patch_size]
                cv2.imwrite('patched_cache/{}&{}&{}.png'.format(stem, xmin, ymin), patch)
    else:
        # Small images are written whole as '<stem>$0.png'.
        cv2.imwrite('patched_cache/{}${}.png'.format(stem, 0), img)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:04<00:00,  1.64s/it]


In [44]:
%%time
# Export the fold checkpoints with export.py, requesting a TensorRT engine (--include engine).
# NOTE(review): the recorded output of this cell ends in a KeyboardInterrupt, and the detection
# cell that follows passes the .pt checkpoints directly -- confirm whether this export is still needed.
!python export.py --img  1280 --weights runs/fold_4.pt runs/fold_3.pt runs/fold_2.pt runs/fold_1.pt runs/fold_0.pt  --half --iou-thres 0.5 --conf-thres=0.4 --device 0 --include engine

[34m[1mexport: [0mdata=data/coco128.yaml, weights=['runs/fold_4.pt', 'runs/fold_3.pt', 'runs/fold_2.pt', 'runs/fold_1.pt', 'runs/fold_0.pt'], imgsz=[1280], batch_size=1, device=0, half=True, inplace=False, keras=False, optimize=False, int8=False, dynamic=False, simplify=False, opset=12, verbose=False, workspace=4, nms=False, agnostic_nms=False, topk_per_class=100, topk_all=100, iou_thres=0.5, conf_thres=0.4, include=['engine']
YOLOv5 🚀 v6.2-189-g2f1eb21 Python-3.9.13 torch-1.12.1 CUDA:0 (Quadro RTX 5000, 16125MiB)

Fusing layers... 
Model summary: 206 layers, 12308200 parameters, 0 gradients, 16.1 GFLOPs

[34m[1mPyTorch:[0m starting from runs/fold_4.pt with output shape (1, 102000, 6) (24.3 MB)

[34m[1mONNX:[0m starting export with onnx 1.12.0...
[34m[1mONNX:[0m export success ✅ 5.1s, saved as runs/fold_4.onnx (24.3 MB)

[34m[1mTensorRT:[0m starting export with TensorRT 8.4.3.1...
[10/14/2022-22:47:01] [TRT] [I] [MemUsageChange] Init CUDA: CPU +303, GPU +0, now: CPU 2385

KeyboardInterrupt: 

In [36]:
%%time
# Run custom_det.py on everything in patched_cache using the five fold checkpoints.
# --save-txt / --save-conf produce the per-image label files (class x y w h conf) that the
# post-processing cell reads from patched_cache/detect/testa/labels.
# NOTE(review): the logged arguments show exist_ok=False; if custom_det.py increments the run
# directory like YOLOv5's detect.py, a second run without clearing patched_cache would not
# write to 'testa' -- confirm.
!python custom_det.py --img 1280 --source patched_cache --weights runs/fold_4.pt runs/fold_3.pt runs/fold_2.pt runs/fold_1.pt runs/fold_0.pt --name testa --max-det 20000 --half --iou-thres 0.5 --conf-thres=0.4 --save-txt --save-conf --line-thickness 1 --hide-labels --project patched_cache/detect --nosave

[34m[1mcustom_det: [0mweights=['runs/fold_4.pt', 'runs/fold_3.pt', 'runs/fold_2.pt', 'runs/fold_1.pt', 'runs/fold_0.pt'], source=patched_cache, data=data/coco128.yaml, imgsz=[1280, 1280], conf_thres=0.4, iou_thres=0.5, max_det=20000, device=, view_img=False, save_txt=True, save_conf=True, save_crop=False, nosave=True, classes=None, agnostic_nms=False, augment=False, visualize=False, update=False, project=patched_cache/detect, name=testa, exist_ok=False, line_thickness=1, hide_labels=True, hide_conf=False, half=True, dnn=False, vid_stride=1
Patching...
YOLO Inferencing...
YOLOv5 🚀 v6.2-189-g2f1eb21 Python-3.9.13 torch-1.12.1 CUDA:0 (Quadro RTX 5000, 16125MiB)

Fusing layers... 
Model summary: 206 layers, 12308200 parameters, 0 gradients, 16.1 GFLOPs
Fusing layers... 
Model summary: 206 layers, 12308200 parameters, 0 gradients, 16.1 GFLOPs
Fusing layers... 
Model summary: 206 layers, 12308200 parameters, 0 gradients, 16.1 GFLOPs
Fusing layers... 
Model summary: 206 layers, 12308200 pa

In [6]:
import albumentations as A
import torch
import torch.nn as nn
import timm

import segmentation_models_pytorch as smp
from torch.cuda.amp import autocast, GradScaler

In [7]:
from ensemble_boxes import weighted_boxes_fusion

In [8]:
import tifffile as tif

In [9]:
# Per-channel normalisation constants, given in RGB order.
# NOTE(review): these look like the commonly used ImageNet statistics -- confirm.
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

In [10]:
# Pre-processing applied to every crop before it reaches the model:
# resize to 224x224, then normalise with the module-level mean/std.
_valid_steps = [
    A.Resize(224, 224),
    A.Normalize(mean, std),
]
albu_transforms = {'valid': A.Compose(_valid_steps)}

In [11]:
class BboxDataset(torch.utils.data.Dataset):
    """Dataset that yields one pre-processed crop of ``img`` per bounding box.

    Args:
        img: image array indexed as ``img[row, col]``; crops are taken by slicing it.
        boxes: sequence of ``(xmin, ymin, xmax, ymax)`` boxes expressed as
            fractions of the image width / height.
        mode: stored on the instance; not used by this class.
    """

    def __init__(self,
                 img,
                 boxes,
                 mode='train'):
        self.img = img
        self.boxes = boxes
        self.mode = mode

    def __len__(self):
        return len(self.boxes)

    def __getitem__(self, idx: int):
        """Return the ``idx``-th crop after the 'valid' transform, as a channel-first tensor."""
        height, width = self.img.shape[:2]
        xmin, ymin, xmax, ymax = self.boxes[idx]

        # Convert to pixels, then clamp into the image and force a size of at
        # least one pixel: a negative start would be read as "from the end" by
        # the slice, and an empty crop would make the resize transform fail.
        xmin = min(max(round(xmin * width), 0), width - 1)
        ymin = min(max(round(ymin * height), 0), height - 1)
        xmax = min(max(round(xmax * width), xmin + 1), width)
        ymax = min(max(round(ymax * height), ymin + 1), height)

        crop = self.img[ymin:ymax, xmin:xmax]
        auged = albu_transforms['valid'](image=crop)
        # HWC -> CHW
        image = torch.from_numpy(auged['image']).permute(2, 0, 1)

        return image

In [12]:
class TimmSED(nn.Module):
    """Thin wrapper around a ``segmentation_models_pytorch`` U-Net.

    The network is stored as ``self.encoder``, so state-dict keys carry the
    ``encoder.`` prefix.

    Args:
        base_model_name: encoder name forwarded to ``smp.Unet`` (``encoder_name``).
        pretrained: encoder weights selector. ``None``/``False`` means no
            pre-trained weights, ``True`` is translated to ``'imagenet'``, and
            any other value (e.g. a weights name) is forwarded unchanged.
        num_classes: number of output channels (``classes``).
        in_channels: number of input channels.
    """

    def __init__(self, base_model_name: str, pretrained=False, num_classes=24, in_channels=1):
        super().__init__()

        # smp selects encoder weights by name (or None); a raw boolean is not a
        # valid selector, so translate it instead of forwarding it verbatim.
        if pretrained is True:
            encoder_weights = 'imagenet'
        elif pretrained is False:
            encoder_weights = None
        else:
            encoder_weights = pretrained

        self.encoder = smp.Unet(
            encoder_name=base_model_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=num_classes,
        )

    def forward(self, input_data):
        """Return the raw output of the wrapped U-Net for ``input_data``."""
        return self.encoder(input_data)

In [13]:
model = TimmSED(
    base_model_name="efficientnet-b0",
    pretrained=None,
    num_classes=2,
    in_channels=3)

# Deserialize onto the CPU first so loading does not depend on the device the
# checkpoint was saved from, then move the populated model to the GPU.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
state_dict = torch.load('../fold-0.bin', map_location='cpu')
model.load_state_dict(state_dict)
model.to('cuda')
model.eval()
print('loaded')

loaded


In [14]:
def bbox_inference(image, boxes, batch_size=32, num_workers=12, threshold=0.5):
    """Run the module-level ``model`` on one crop per box and return binary masks.

    Args:
        image: image array handed to ``BboxDataset`` for cropping.
        boxes: sequence of boxes handed to ``BboxDataset``.
        batch_size: number of crops per forward pass.
        num_workers: worker processes used by the ``DataLoader``.
        threshold: cut-off applied to the sigmoid of output channel 1.

    Returns:
        A list with one entry per box, in input order; each entry is a nested
        list of 0/1 integers (the thresholded map for that crop).
    """
    if len(boxes) == 0:
        # Nothing to segment: avoid starting DataLoader workers for an empty dataset.
        return []

    ds = BboxDataset(image, boxes)
    dl = torch.utils.data.DataLoader(
        ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        shuffle=False,   # keep mask order aligned with ``boxes``
        drop_last=False,
    )

    results = []
    with torch.no_grad():
        for data in tqdm(dl):
            with autocast():
                # Channel index 1 of the model output is squashed and thresholded.
                seg_results = torch.sigmoid(model(data.to('cuda'))[:, 1])
            results.extend((seg_results > threshold).int().to('cpu').numpy().tolist())
    return results

In [15]:
# Quick look at the label files the detector wrote for one image's patches.
# NOTE(review): img_path is a loop variable left over from an earlier cell, so this
# inspects whichever image that loop processed last.
image_stem = img_path.split('/')[-1].split('.')[0]
glob.glob('patched_cache/detect/testa/labels/{}&*&*.txt'.format(image_stem))

[]

In [18]:
LABEL_DIR = 'patched_cache/detect/testa/labels'
PATCH_PX = 1000     # side length of the cached patches (patch_size in the patching cell)
EDGE_MARGIN = 5     # boxes closer than this to a patch border are dropped as likely cut-off objects
MIN_BOX_PX = 2      # unpatched images: boxes must be larger than this in both dimensions


def _read_detections(label_path):
    """Parse a detector label file into a list of ``(x, y, w, h, conf)`` float tuples.

    Each line is expected to hold six whitespace-separated fields
    ``class x y w h conf``; the class field is ignored. Blank or malformed
    lines are skipped.
    """
    detections = []
    with open(label_path, 'r') as f:
        for line in f:
            fields = line.split()
            if len(fields) != 6:
                continue
            detections.append(tuple(float(value) for value in fields[1:]))
    return detections


# Start from an empty output directory on every run.
shutil.rmtree('./outputs', ignore_errors=True)
os.makedirs('./outputs')

for img_path in all_images:
    raw_image = cv2.imread(img_path)
    if raw_image is None:
        # Not a readable image (the input glob matches any '*.*' file).
        print('skipping unreadable file: {}'.format(img_path))
        continue
    shape = raw_image.shape
    stem = os.path.basename(img_path).split('.')[0]

    # Reset per image. Without this, an image that has no label files would
    # silently reuse the boxes of the previously processed image.
    boxes = []

    if min(shape[:2]) > PATCH_PX:
        print('patched')
        det_files = glob.glob('{}/{}&*&*.txt'.format(LABEL_DIR, stem))
        if len(det_files) > 0:
            box_set = []
            conf_set = []
            cls_set = []
            for file in det_files:
                # File names look like '<stem>&<xmin>&<ymin>.txt': recover the patch's
                # top-left corner as a fraction of the full image.
                name_parts = os.path.basename(file).split('&')
                leftx = int(name_parts[-2]) / shape[1]
                topy = int(name_parts[-1][:-4]) / shape[0]

                patch_boxes = []
                patch_confs = []
                for x, y, w, h, conf in _read_detections(file):
                    # Centre/size (fractions of the patch) -> corner coordinates in patch pixels.
                    xmin, ymin = (x - 0.5 * w) * PATCH_PX, (y - 0.5 * h) * PATCH_PX
                    xmax, ymax = (x + 0.5 * w) * PATCH_PX, (y + 0.5 * h) * PATCH_PX

                    if min(xmin, ymin, xmax, ymax) > EDGE_MARGIN and max(xmin, ymin, xmax, ymax) < PATCH_PX - EDGE_MARGIN:
                        # Patch pixels -> fractions of the full image.
                        patch_boxes.append([xmin / shape[1] + leftx, ymin / shape[0] + topy,
                                            xmax / shape[1] + leftx, ymax / shape[0] + topy])
                        patch_confs.append(conf)

                box_set.append(patch_boxes)
                conf_set.append(patch_confs)
                cls_set.append([0] * len(patch_confs))

            if any(box_set):
                # Merge duplicates coming from overlapping patches; each patch is
                # passed to the fusion as its own prediction set.
                boxes, _, _ = weighted_boxes_fusion(box_set, conf_set, cls_set)
    else:
        print('unpatched')
        file = '{}/{}$0.txt'.format(LABEL_DIR, stem)
        if os.path.isfile(file):
            for x, y, w, h, conf in _read_detections(file):
                xmin, ymin, xmax, ymax = (x - 0.5 * w), (y - 0.5 * h), (x + 0.5 * w), (y + 0.5 * h)
                if min((xmax - xmin) * shape[1], (ymax - ymin) * shape[0]) > MIN_BOX_PX:
                    boxes.append([xmin, ymin, xmax, ymax])

    # uint16 holds labels up to 65535; switch to uint32 when there may be more instances.
    label_dtype = 'uint32' if len(boxes) > 65000 else 'uint16'
    base = np.zeros((shape[0], shape[1]), dtype=label_dtype)

    if len(boxes) > 0:
        masks = bbox_inference(raw_image, boxes)

        cell_count = 1
        for box, mask in zip(boxes, masks):
            xmin, ymin, xmax, ymax = box
            xmin, ymin, xmax, ymax = round(xmin * shape[1]), round(ymin * shape[0]), round(xmax * shape[1]), round(ymax * shape[0])
            # Keep the target window inside the image and skip boxes that collapse
            # to zero area; cv2.resize cannot produce an empty output.
            xmin, ymin = max(xmin, 0), max(ymin, 0)
            xmax, ymax = min(xmax, shape[1]), min(ymax, shape[0])
            if xmax <= xmin or ymax <= ymin:
                continue

            # Stretch the fixed-size mask back to the box size and paint its
            # foreground pixels with the next instance id.
            mask = cv2.resize(np.array(mask, dtype='uint8'), (xmax - xmin, ymax - ymin),
                              interpolation=cv2.INTER_NEAREST).astype(bool)
            base[ymin:ymax, xmin:xmax][mask] = cell_count
            cell_count += 1

    tif.imwrite('./outputs/{}_label.tiff'.format(stem), base, compression='zlib')

patched


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.40it/s]


patched


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00,  3.88it/s]


unpatched


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00,  3.19it/s]
