In [None]:
from math import sqrt
import copy

import numpy as np  # linear algebra
import pydicom
import os
import scipy.ndimage
import matplotlib.pyplot as plt
from skimage import measure, morphology
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import cv2
from PIL import Image

# Folder that load_scan() reads the DICOM files from. This is a hard-coded
# absolute path, so it has to be edited before running on another machine.
INPUT_FOLDER = '/nfs3-p1/zsxm/dataset/aorta_CTA/zhaoqifeng/img/'

In [None]:
# Load the scans in given folder path
def load_scan(path):
    """Read every entry of ``path`` as a DICOM dataset.

    Args:
        path: Directory whose entries are all passed to ``pydicom.dcmread``.

    Returns:
        List of datasets sorted in ascending order of ``InstanceNumber``
        (compared as floats).
    """
    datasets = []
    for file_name in os.listdir(path):
        datasets.append(pydicom.dcmread(os.path.join(path, file_name)))
    datasets.sort(key=lambda ds: float(ds.InstanceNumber))
    return datasets


def get_pixels_hu(slices):
    """Stack the slices' pixel arrays and rescale them slice by slice.

    Each slice is multiplied by its ``RescaleSlope`` (only when the slope is
    not 1) and then shifted by its ``RescaleIntercept``; per the original
    author's notes this converts the stored values to Hounsfield units.

    Args:
        slices: Sequence of objects exposing ``pixel_array``,
            ``RescaleSlope`` and ``RescaleIntercept``.

    Returns:
        ``np.int16`` array with one leading entry per slice.
    """
    # Work in int16 throughout; the original author assumes values stay
    # below 32k so the cast is lossless.
    volume = np.stack([ds.pixel_array for ds in slices]).astype(np.int16)

    # -2000 marks outside-of-scan pixels; reset them to 0 before rescaling.
    volume[volume == -2000] = 0

    for index, ds in enumerate(slices):
        plane = volume[index]  # view: edits below land in ``volume``
        if ds.RescaleSlope != 1:
            # Scale in float64, then store back into the int16 plane.
            plane[...] = ds.RescaleSlope * plane.astype(np.float64)
        plane += np.int16(ds.RescaleIntercept)

    return np.array(volume, dtype=np.int16)


# def set_window(image, w_center, w_width):
#     image_copy = image.copy()
#     for slice_number in range(len(image)):
#         image_copy[slice_number] = cv2.GaussianBlur(image_copy[slice_number], (3,3), 1)
#         image_copy[slice_number][image_copy[slice_number]>w_center+int(w_width/2)] = np.int16(w_center-int(w_width/2))
#         image_copy[slice_number][image_copy[slice_number]<w_center-int(w_width/2)] = np.int16(w_center-int(w_width/2))

#     return image_copy

In [None]:
# Read the series and convert it to a rescaled int16 volume.
patient = load_scan(INPUT_FOLDER)
patient_pixels = get_pixels_hu(patient)

# Histogram of all voxel values in the volume.
fig, ax = plt.subplots()
ax.hist(patient_pixels.flatten(), bins=80, color='c')
ax.set_xlabel("Hounsfield Units (HU)")
ax.set_ylabel("Frequency")
plt.show()

# Preview one slice (index 80 is hard-coded, so the series needs > 80 slices).
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(patient_pixels[80], cmap="gray")
plt.show()

In [None]:
def set_window(image, w_center, w_width):
    """Turn a volume into a per-slice binary mask of bright structures.

    For each slice: bilateral-filter it, replace every value outside
    ``[w_center - w_width/2, w_center + w_width/2]`` by the lower bound,
    mark pixels still above 250 with 255 (everything else 0), and clean the
    mask with one erosion followed by one dilation (5x5 rectangle).

    Args:
        image: Volume indexed by slice along the first axis
            (presumably the int16 output of ``get_pixels_hu``).
        w_center: Centre of the intensity window.
        w_width: Full width of the intensity window.

    Returns:
        ``np.uint8`` array of the same shape as ``image`` holding 0 / 255.
    """
    half_width = int(w_width/2)
    lower, upper = w_center - half_width, w_center + half_width
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5,5))

    result = image.copy()
    for index in range(len(image)):
        plane = result[index]  # view into ``result``

        # Edge-preserving smoothing; OpenCV needs float input here.
        plane[...] = np.int16(cv2.bilateralFilter(plane.astype(np.float32), 5, 150, 1))

        # Both out-of-window sides collapse to the lower bound.
        plane[plane > upper] = np.int16(lower)
        plane[plane < lower] = np.int16(lower)

        # Threshold to a 0 / 255 mask.
        mask = np.zeros_like(plane)
        mask[plane > 250] = 255
        plane[...] = mask

        # Morphological opening removes small specks.
        plane[...] = cv2.erode(plane, kernel, iterations=1)
        plane[...] = cv2.dilate(plane, kernel, iterations=1)
    return result.astype(np.uint8)

# Binary mask volume produced by set_window (centre 300, width 600); the
# later cropping cell reads it as `patient_copy`.
patient_copy = set_window(patient_pixels, 300, 600)
# plt.figure(figsize=(10,10))
# Preview a fixed crop of slice 80 (rows 200-349, columns 175-324).
plt.imshow(patient_copy[80, 200:350, 175:325], cmap='Greys')
plt.show()
# Disabled: same preview for every slice.
# for i in range(len(patient_copy)):
#     plt.imshow(patient_copy[i, 200:350, 175:325], cmap=plt.cm.gray)
#     plt.show()

In [None]:
class Elem:
    """A chain of contours, one per appended key.

    Attributes are plain instance fields:
      root     -- key passed at construction (never changed afterwards)
      end      -- key of the most recently added contour
      contours -- contours in insertion order
      bboxes   -- only exists after ``get_bboxes()`` has been called
    """

    def __init__(self, key, contour):
        self.root = self.end = key
        self.contours = [contour]

    def __len__(self):
        """Number of contours stored in the chain."""
        return len(self.contours)

    def append(self, key, contour):
        """Add ``contour`` to the chain and make ``key`` the new end key."""
        self.contours.append(contour)
        self.end = key

    def get_bboxes(self):
        """Compute ``self.bboxes``: one ``(sx, sy, ex, ey)`` box per contour.

        Each box is centred on the contour's bounding rectangle and extends
        ``max(w, h)`` pixels from the centre in every direction, i.e. it is a
        square whose side is twice the rectangle's longer side.
        """
        boxes = []
        for contour in self.contours:
            left, top, width, height = cv2.boundingRect(contour)
            center_x = left + width // 2
            center_y = top + height // 2
            reach = max(width, height)
            boxes.append((center_x - reach, center_y - reach,
                          center_x + reach, center_y + reach))
        self.bboxes = boxes

In [None]:
# Crop every slice to a fixed window: columns within +/-15% of the image
# width around the horizontal centre, rows within +/-15% of the image height
# around 55% of the height. The offsets are reused later to map boxes back
# to full-image coordinates.
n_rows, n_cols = patient_copy.shape[1], patient_copy.shape[2]
start_x = int(n_cols/2 - n_cols*0.15)
end_x = int(n_cols/2 + n_cols*0.15)
start_y = int(n_rows*0.55 - n_rows*0.15)
end_y = int(n_rows*0.55 + n_rows*0.15)
patient_cut = patient_copy[:, start_y:end_y, start_x:end_x]

In [None]:
# Report the cropped volume's shape and dtype, then show every cropped slice
# titled with its slice index.
print(patient_cut.shape)
print(patient_cut.dtype)
for slice_idx, cropped in enumerate(patient_cut):
    plt.title(str(slice_idx))
    plt.imshow(cropped, cmap="gray")
    plt.show()

In [None]:
def get_intersection(origin, first, second):
    """Ratio of the overlap of two filled contours to the area of ``second``.

    Both contours are rasterised onto blank images shaped like ``origin``
    with fill values 125 and 130; pixels where the sum reaches 255 are
    treated as the overlap. (Assumes ``origin`` has a dtype such as uint8
    that can hold 255 -- TODO confirm.)

    Args:
        origin: Image used only as a shape/dtype template.
        first: First contour (point array accepted by ``cv2.fillPoly``).
        second: Second contour; its ``cv2.contourArea`` is the denominator.

    Returns:
        ``0`` when ``second`` has zero area, otherwise
        ``overlap_area / second_area``.

    Raises:
        AssertionError: If the measured overlap area exceeds ``second``'s area.
    """
    mask_first = np.zeros_like(origin)
    mask_second = np.zeros_like(origin)
    cv2.fillPoly(mask_first, [first], 125)
    cv2.fillPoly(mask_second, [second], 130)

    # 125 + 130 == 255 only where both polygons were filled.
    overlap = mask_first + mask_second
    overlap[overlap < 255] = 0

    overlap_contours, _ = cv2.findContours(overlap, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    overlap_area = sum(cv2.contourArea(piece) for piece in overlap_contours)

    second_area = cv2.contourArea(second)
    if second_area == 0:
        return 0
    assert overlap_area <= second_area
    return overlap_area / second_area

In [None]:
# Link contours of consecutive slices into paths (Elem objects).
# path_dict maps the (slice index, contour index) of a path's LAST contour to
# that path. A contour k in slice i is a candidate successor of contour j in
# slice i-1 when the distance between their minimum enclosing circle centres
# is at most max_r - 0.5 * min_r; a candidate is accepted when
# get_intersection(...) is at least 0.8. Every accepted candidate gets its own
# deep copy of the path, so one path may branch into several.
pre_contours, _ = cv2.findContours(patient_cut[0], cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
pre_circle = [cv2.minEnclosingCircle(c) for c in pre_contours]
path_dict = {}
for i in range(1, len(patient_cut)):
    cur_contours, _ = cv2.findContours(patient_cut[i], cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    cur_circle = [cv2.minEnclosingCircle(c) for c in cur_contours]
    for j in range(len(pre_circle)):
        (pre_x, pre_y), pre_r = pre_circle[j]
        candidate_list = []
        for k in range(len(cur_circle)):
            (cur_x, cur_y), cur_r = cur_circle[k]
            dis = sqrt((pre_x-cur_x)**2+(pre_y-cur_y)**2)
            max_r, min_r = max(pre_r, cur_r), min(pre_r, cur_r)
            if dis <= max_r - 0.5 * min_r:
                candidate_list.append((k, dis, max_r-min_r))

        if len(candidate_list) == 0:
            # Nothing nearby: a path ending at (i-1, j) simply stays in path_dict.
            continue
        # Closest centre first, ties broken by the smaller radius difference.
        candidate_list.sort(key=lambda x:x[1:3])

        key = (i-1, j)
        path = path_dict.get(key)
        if path is None:
            # Contour j starts a new path; it is only stored once it is extended.
            path = Elem(key, pre_contours[j])

        extended = False
        for k, _dis, _radius_diff in candidate_list:
            if get_intersection(patient_cut[i], pre_contours[j], cur_contours[k]) < 0.8:
                continue
            # Copy lazily: only accepted candidates need their own branch.
            branch = copy.deepcopy(path)
            branch.append((i, k), cur_contours[k])
            path_dict[(i, k)] = branch
            extended = True

        # Drop the old entry only when at least one branch replaced it.
        # Previously the path was popped unconditionally, so a path whose
        # candidates all failed the overlap test vanished, while a path with
        # no candidates at all was kept.
        if extended:
            path_dict.pop(key, None)

    pre_contours = cur_contours
    pre_circle = cur_circle

In [None]:
# Number of paths currently stored in path_dict.
print(len(path_dict))

In [None]:
# Length (number of contours) of the longest path; -1 when there are no paths.
max_len = max((len(tracked) for tracked in path_dict.values()), default=-1)

print(max_len)

In [None]:
# All paths, longest first (ties keep path_dict's insertion order).
path_list = sorted(path_dict.values(), key=len, reverse=True)

In [None]:
# Lengths of the four longest paths (raises IndexError if there are fewer).
for rank in range(4):
    print(len(path_list[rank]))

In [None]:
# Render the longest path as filled contours on a blank copy of the cropped
# volume. `path`, `start` and `end` are reused by later cells.
path = path_list[0]
canvas = np.zeros_like(patient_cut)
start = path.root[0]
end = path.end[0]+1
# The n-th contour of the path belongs to slice start + n.
for slice_idx, contour in zip(range(start, end), path.contours):
    cv2.fillPoly(canvas[slice_idx], [contour], 255)

for slice_idx, frame in enumerate(canvas):
    plt.title(str(slice_idx))
    plt.imshow(frame, cmap="gray")
    plt.show()

In [None]:
# Populate path.bboxes (one box per contour of the selected path); the
# drawing cells below read it.
path.get_bboxes()

In [None]:
# Outline the path's boxes (crop coordinates) on a copy of the cropped masks.
canvas = patient_cut.copy()
start = path.root[0]
end = path.end[0]+1
for slice_idx, box in zip(range(start, end), path.bboxes):
    top_left, bottom_right = box[0:2], box[2:4]
    cv2.rectangle(canvas[slice_idx], top_left, bottom_right, 255)

for slice_idx, frame in enumerate(canvas):
    plt.title(str(slice_idx))
    plt.imshow(frame, cmap="gray")
    plt.show()

In [None]:
def set_window2(image, w_center, w_width):
    """Clip a volume to an intensity window and min-max scale each slice.

    Every slice is clipped to ``[w_center - int(w_width/2),
    w_center + int(w_width/2)]`` and then rescaled with its OWN minimum and
    maximum after clipping, so the scaling can differ between slices.

    Args:
        image: Array indexed by slice along the first axis.
        w_center: Centre of the intensity window.
        w_width: Full width of the intensity window.

    Returns:
        ``np.float32`` copy of ``image`` with each slice scaled to ``[0, 1]``.
        A slice that is constant after clipping is returned as all zeros
        (the plain formula would divide by zero and yield NaN).
    """
    half_width = int(w_width/2)
    lower, upper = w_center - half_width, w_center + half_width

    image_copy = image.copy().astype(np.float32)
    for slice_number in range(len(image_copy)):
        clipped = np.clip(image_copy[slice_number], lower, upper)
        low = clipped.min()
        value_range = clipped.max() - low
        if value_range > 0:
            image_copy[slice_number] = (clipped - low) / value_range
        else:
            # Flat slice: no contrast to stretch, avoid 0/0.
            image_copy[slice_number] = 0.0

    return image_copy

In [None]:
# Windowed, [0, 1]-scaled copy of the full volume for display.
patient_img = set_window2(patient_pixels, 40, 400)

# Draw the path's boxes on the full-size slices: box coordinates are relative
# to the crop, so shift them by the crop origin (start_x, start_y).
for slice_idx, (sx, sy, ex, ey) in zip(range(start, end), path.bboxes):
    top_left = (sx + start_x, sy + start_y)
    bottom_right = (ex + start_x, ey + start_y)
    cv2.rectangle(patient_img[slice_idx], top_left, bottom_right, 1, 2)

for slice_idx, frame in enumerate(patient_img):
    plt.title(str(slice_idx))
    plt.imshow(frame, cmap="gray")
    plt.show()

In [None]:
def draw_bbox(slices, path, start_x, start_y, save_path):
    """Write every slice to ``save_path`` with the path's box drawn on it.

    Slices whose index lies in ``[path.root[0], path.end[0]]`` get the
    matching entry of ``path.bboxes`` drawn as a 1-pixel rectangle with value
    3000; all slices are saved as ``<index>.dcm`` using the Explicit VR
    Little Endian transfer syntax.

    Args:
        slices: Sequence of pydicom datasets (modified in place).
        path: Object with ``root``, ``end`` and ``bboxes``
            (``get_bboxes()`` must have been called on it).
        start_x: Column offset added to box x-coordinates (crop origin).
        start_y: Row offset added to box y-coordinates (crop origin).
        save_path: Output directory; created if missing.
    """
    os.makedirs(save_path, exist_ok=True)

    # Index boxes relative to the path's own first slice instead of relying
    # on a notebook-global ``start`` that may be stale or undefined.
    first, last = path.root[0], path.end[0]

    for i, s in enumerate(slices):
        img = s.pixel_array
        if first <= i <= last:
            sx, sy, ex, ey = path.bboxes[i - first]
            cv2.rectangle(img, (sx + start_x, sy + start_y),
                          (ex + start_x, ey + start_y), 3000, 1)

        # Explicit VR Little Endian is an uncompressed syntax, so Pixel Data
        # must be the raw array bytes; encapsulation is only for compressed
        # transfer syntaxes and would prepend item headers to the image.
        s.PixelData = img.tobytes()
        s['PixelData'].is_undefined_length = False
        s.file_meta.TransferSyntaxUID = '1.2.840.10008.1.2.1'
        s.save_as(os.path.join(save_path, f'{i}.dcm'), write_like_original=False)

In [None]:
# Load the series from disk again, then save copies of all slices with the
# selected path's boxes drawn (box coordinates shifted by the crop origin).
patient = load_scan(INPUT_FOLDER)
draw_bbox(patient, path, start_x, start_y, '/nfs3-p1/zsxm/dataset/aorta_CTA/zhaoqifeng/save/')