In [8]:
from cgitb import reset
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
import skimage as sm
import skimage.io
from matplotlib import pyplot as plt
import tifffile
import timm
import os
from fastai.vision.all import *
from skimage.feature import blob, blob_dog, blob_log, blob_doh
from os.path import exists

# Inference settings.
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"
BATCH_SIZE = 8  # samples per DataLoader batch
NUM_WORKERS = 2  # DataLoader worker processes
IMAGE_HEIGHT = IMAGE_WIDTH = 512
PIN_MEMORY = True  # forwarded to DataLoader(pin_memory=...)

# util

def createFolder(directory):
    """Create ``directory`` (and any missing parents) if it does not exist yet.

    This is a best-effort helper: an ``OSError`` raised while creating the
    directory is reported on stdout and swallowed, so the caller keeps running.

    Args:
        directory: Path of the directory to create.
    """
    try:
        # exist_ok=True removes the race between "does it exist?" and
        # "create it" that a separate os.path.exists() check has.
        os.makedirs(directory, exist_ok=True)
    except OSError:
        print("Error: Creating directory. " + directory)


class VidDataset(Dataset):
    """Dataset over the image files found in a single directory.

    Every directory entry is treated as a sample except ``.DS_Store``,
    ``.ipynb_checkpoints`` and ``focus{filename}.tif``. Samples are served in
    sorted filename order.
    """

    def __init__(self, filename, image_dir, transform=None):
        """Collect and sort the sample filenames.

        Args:
            filename: Name used to build the excluded ``focus{filename}.tif`` entry.
            image_dir: Directory that is listed for samples.
            transform: Optional callable taking the keyword arguments
                ``image``, ``image0`` ... ``image8`` and returning a mapping
                with the same keys.
        """
        self.image_dir = image_dir
        self.transform = transform
        skipped = {".DS_Store", f"focus{filename}.tif", ".ipynb_checkpoints"}
        self.images = sorted(
            entry for entry in os.listdir(image_dir) if entry not in skipped
        )

    def __len__(self):
        """Return the number of samples."""
        return len(self.images)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            A tuple ``(images, pred_name)``: the file contents as a float
            tensor scaled by 1/256, and the filename with ``".tif"`` removed.
            When a transform is set, slices 0-9 along the first axis are
            replaced by the transform's outputs.
        """
        sample_name = self.images[index]
        raw = sm.io.imread(os.path.join(self.image_dir, sample_name)).astype(np.float32)
        images = torch.tensor(raw / 256).float()

        if self.transform is not None:
            # Slice 0 goes in under "image", slices 1..9 under "image0".."image8".
            target_keys = ["image"] + [f"image{k}" for k in range(9)]
            transformed = self.transform(
                **{key: raw[pos] for pos, key in enumerate(target_keys)}
            )
            for pos, key in enumerate(target_keys):
                images[pos] = transformed[key]

        pred_name = sample_name.replace(".tif", "")
        return images, pred_name


def load_checkpoint(checkpoint, model):
    """Load the weights stored under ``checkpoint["state_dict"]`` into ``model``.

    Args:
        checkpoint: Mapping that holds the model weights under ``"state_dict"``.
        model: Module whose ``load_state_dict`` receives those weights.
    """
    print("=> Loading checkpoint")
    weights = checkpoint["state_dict"]
    model.load_state_dict(weights)


def get_loaders(
    filename_dir,
    batch_size,
    filename_transform,
    filename,
    num_workers=4,
    pin_memory=True,
):
    """Build a shuffled ``DataLoader`` over a ``VidDataset`` of ``filename_dir``.

    Args:
        filename_dir: Directory handed to ``VidDataset`` as ``image_dir``.
        batch_size: Batch size of the loader.
        filename_transform: Transform handed to ``VidDataset``.
        filename: Name handed to ``VidDataset``.
        num_workers: Number of loader worker processes.
        pin_memory: Forwarded to ``DataLoader``.

    Returns:
        The configured ``DataLoader``.
    """
    dataset = VidDataset(filename, image_dir=filename_dir, transform=filename_transform)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=True,
    )


def make_predictions(loader, model, filename, device="cuda"):
    """Run ``model`` over ``loader`` and write one prediction TIFF per sample.

    Each batch is passed through the model and a sigmoid, scaled by 255, cast
    to ``uint8`` and written to
    ``dat_midpoint/{filename}/pred_{pred_name}.tif``.

    Args:
        loader: Iterable yielding ``(x, pred_name)`` batches.
        model: Module to evaluate. It is put in eval mode for the duration of
            the call and returned to the mode it was in before.
        filename: Sub-directory of ``dat_midpoint`` that receives the files.
        device: Device the inputs are moved to. A CUDA device is replaced by
            ``"cpu"`` when CUDA is not available, so the default also works on
            CPU-only machines.
    """
    if torch.device(device).type == "cuda" and not torch.cuda.is_available():
        device = "cpu"

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for x, pred_name in tqdm(loader):
                x = x.to(device)
                preds = torch.sigmoid(model(x))
                preds = preds.detach().cpu().numpy()
                preds = np.asarray(preds * 255, "uint8")
                for name, pred in zip(pred_name, preds):
                    tifffile.imwrite(
                        f"dat_midpoint/{filename}/pred_{name}.tif", pred)
    finally:
        # Restore the caller's mode even if inference or writing fails.
        model.train(was_training)
    

def make_predictions_Orientation(loader, model, filename, device="cuda"):
    """Run ``model`` over ``loader`` and write one mask TIFF per sample.

    Each batch is passed through the model and a sigmoid, scaled by 255, cast
    to ``uint8`` and written to
    ``dat_midpoint/{filename}/Masks/{pred_name}.tif``.

    Args:
        loader: Iterable yielding ``(x, pred_name)`` batches.
        model: Module to evaluate. It is put in eval mode for the duration of
            the call and returned to the mode it was in before.
        filename: Sub-directory of ``dat_midpoint`` whose ``Masks`` folder
            receives the files.
        device: Device the inputs are moved to. A CUDA device is replaced by
            ``"cpu"`` when CUDA is not available, so the default also works on
            CPU-only machines.
    """
    if torch.device(device).type == "cuda" and not torch.cuda.is_available():
        device = "cpu"

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for x, pred_name in tqdm(loader):
                x = x.to(device)
                preds = torch.sigmoid(model(x))
                preds = preds.detach().cpu().numpy()
                preds = np.asarray(preds * 255, "uint8")
                for name, pred in zip(pred_name, preds):
                    tifffile.imwrite(
                        f"dat_midpoint/{filename}/Masks/{name}.tif", pred)
    finally:
        # Restore the caller's mode even if inference or writing fails.
        model.train(was_training)
    

    
def sortDL(df, t, x, y):
    """Select rows of ``df`` close to ``(t, x, y)`` in neighbouring time points.

    Rows are kept when ``T`` is one of ``t-2, t-1, t+1, t+2`` (never ``t``
    itself) and ``X`` / ``Y`` fall in the half-open windows
    ``[x-13, x+13)`` / ``[y-13, y+13)``, with the window bounds clamped to
    ``[0, 511]``. Rows come back grouped by time offset in the order
    -2, -1, +1, +2.

    Args:
        df: DataFrame with at least the columns ``T``, ``X`` and ``Y``.
        t: Reference time point.
        x: Reference X position.
        y: Reference Y position.

    Returns:
        The filtered DataFrame.
    """
    nearby = pd.concat([df[df["T"] == t + offset] for offset in (-2, -1, 1, 2)])

    x_low, x_high = max(x - 13, 0), min(x + 13, 511)
    y_low, y_high = max(y - 13, 0), min(y + 13, 511)

    inside_x = nearby[(nearby["X"] >= x_low) & (nearby["X"] < x_high)]
    return inside_x[(inside_x["Y"] >= y_low) & (inside_x["Y"] < y_high)]


def intensity(vid, ti, xi, yi):
    """Return the mean of the positive values of frame ``ti`` inside a disk.

    Frame ``vid[ti]`` is placed on a zero canvas with a 20-pixel margin, so a
    disk close to the frame edge still indexes valid positions. The disk has
    radius 9 and is centred on row ``yi + 20``, column ``xi + 20`` of the
    canvas. Only strictly positive values enter the mean; when there are
    none, ``np.mean`` of an empty selection yields NaN.

    Args:
        vid: 3-D array indexed as ``vid[frame, row, col]``.
        ti: Index of the frame to sample.
        xi: Disk centre along the canvas column axis, before the margin offset.
        yi: Disk centre along the canvas row axis, before the margin offset.

    Returns:
        The mean of the positive values under the disk.
    """
    pad = 20
    _, n_rows, n_cols = vid.shape

    # Only frame ``ti`` is read, so pad just that frame with a single slice
    # assignment instead of copying the whole video pixel by pixel.
    # The canvas is at least 552 x 552 and grows for larger frames.
    canvas = np.zeros([max(552, n_rows + 2 * pad), max(552, n_cols + 2 * pad)])
    canvas[pad:pad + n_rows, pad:pad + n_cols] = vid[ti]

    rr, cc = sm.draw.disk([yi + pad, xi + pad], 9)
    div = canvas[rr, cc]
    div = div[div > 0]

    mu = np.mean(div)

    return mu
    
    
def maskOrientation(mask):
    """Return the orientation, in degrees, of the weight distribution in ``mask``.

    The mask values are used as weights. Coordinates are taken with ``x``
    along the columns and ``y`` along the rows with the row index flipped
    (``y = n_rows - 1 - row``). From the weighted second moments about the
    centroid a traceless 2x2 tensor ``q`` is built, and the angle is
    ``arctan2(q[0, 1], q[0, 0]) / 2``.

    For example, a horizontal bar gives 0, a vertical bar gives +/-90 and the
    main diagonal of the array gives -45.

    Args:
        mask: 2-D array of weights. Rectangular masks are supported.

    Returns:
        The angle in degrees, in the range (-90, 90]. NaN is returned when
        the weights sum to zero, because no centroid is defined.
    """
    n_rows, n_cols = mask.shape
    total = np.sum(mask)
    if total == 0:
        # Avoid the 0/0 divisions below; NaN is what they would produce.
        return np.float64(np.nan)

    # Each grid is sized by its own axis, so non-square masks broadcast too.
    x = np.zeros([n_rows, n_cols]) + np.arange(n_cols)
    y = np.zeros([n_rows, n_cols]) + (n_rows - 1 - np.arange(n_rows)).reshape(n_rows, 1)

    dx = x - np.sum(x * mask) / total
    dy = y - np.sum(y * mask) / total

    S = np.zeros([2, 2])
    S[0, 0] = -np.sum(dy ** 2 * mask) / total ** 2
    S[1, 0] = S[0, 1] = np.sum(dx * dy * mask) / total ** 2
    S[1, 1] = -np.sum(dx ** 2 * mask) / total ** 2

    # Subtract half the trace to obtain the traceless part of S.
    q = S - (S[0, 0] + S[1, 1]) * np.eye(2) / 2
    theta = np.arctan2(q[0, 1], q[0, 0]) / 2

    return theta * 180 / np.pi
    
    
def _build_unet(output_size, checkpoint_path):
    """Build the 10-input-channel ResNet34 ``DynamicUnet`` and load its checkpoint."""
    resnet = timm.create_model("resnet34", pretrained=True)
    # Replace the stem so the network accepts 10 input channels.
    resnet.conv1 = nn.Conv2d(10, 64, kernel_size=(
        7, 7), stride=(2, 2), padding=(3, 3), bias=False)

    encoder = nn.Sequential(*list(resnet.children())[:-2])
    model = DynamicUnet(encoder, 1, output_size, norm_type=None).to(DEVICE)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    load_checkpoint(torch.load(checkpoint_path, map_location=DEVICE), model)
    return model


def _clamped_window(centre, half_width, limit):
    """Clamp ``[centre - half_width, centre + half_width)`` to ``[0, limit]``.

    Returns ``(src_lo, src_hi, dst_lo, dst_hi)``: the clamped source bounds and
    the matching bounds inside a ``2 * half_width`` wide destination patch.
    """
    low, high = centre - half_width, centre + half_width
    dst_lo = max(-low, 0)
    dst_hi = 2 * half_width - max(high - limit, 0)
    return max(low, 0), min(high, limit), dst_lo, dst_hi


def _division_stack(focus, t, x, y):
    """Cut the 10-slice, 120x120 ``uint8`` stack around one detection.

    For five consecutive frames starting at ``t - 1``, a 60x60 window centred
    on ``(x, y)`` is taken from channel 1 and channel 0 of ``focus`` (in that
    order), zero-filled where it leaves the 512x512 frame, and resized to
    120x120.
    """
    xMin, xMax, xMinCrop, xMaxCrop = _clamped_window(x, 30, 512)
    yMin, yMax, yMinCrop, yMaxCrop = _clamped_window(y, 30, 512)

    stack = np.zeros([10, 120, 120])
    for i in range(5):
        for slot, channel in ((2 * i, 1), (2 * i + 1, 0)):
            image = np.zeros([60, 60])
            image[yMinCrop:yMaxCrop, xMinCrop:xMaxCrop] = focus[
                t - 1 + i, yMin:yMax, xMin:xMax, channel
            ]
            image = np.asarray(image, "uint8")
            # Resize in memory rather than through a shared scratch TIFF.
            stack[slot] = np.asarray(Image.fromarray(image).resize((120, 120)))

    return np.asarray(stack, "uint8")


def main():
    """Detect divisions in every entry of ``dat_pred`` and estimate their orientation.

    Steps:
      1. Divisions: for each entry without an existing
         ``dat_midpoint/{name}/dfDivisions{name}.pkl``, run the division
         U-Net, detect blobs in each prediction frame with ``blob_log``, and
         suppress detections that have a brighter neighbour in adjacent time
         points (see ``sortDL``). The table is saved and one cropped image
         stack per division is written to ``dat_midpoint/{name}/Divisions``.
      2. Masks: run the orientation U-Net on those stacks and write the
         results to ``dat_midpoint/{name}/Masks``.
      3. Orientation: compute an angle per mask with ``maskOrientation`` and
         save ``dat_output/dfDivision{name}.pkl``.
    """
    target10 = {f"image{k}": "image" for k in range(10)}
    filename_transform = A.Compose(
        [
            A.Normalize(
                mean=0,
                std=1,
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
        additional_targets=target10,
    )

    cwd = os.getcwd()
    filenames = os.listdir(cwd + "/dat_pred")
    for unwanted in (".DS_Store", ".ipynb_checkpoints"):
        if unwanted in filenames:
            filenames.remove(unwanted)
    filenames.sort()

    # ---------- Find Divisions ----------

    model = _build_unet((512, 512), "models/UNetCellDivision10.pth.tar")

    # Running detection counter shared by all files. It must not be reused
    # as a loop variable below, or numbering of the next file is corrupted.
    label = 0
    for filename in filenames:
        print(filename)
        path_to_file = f"dat_midpoint/{filename}/dfDivisions{filename}.pkl"
        if exists(path_to_file):
            continue  # this file was already processed

        print(label)
        createFolder(f"dat_midpoint/{filename}")
        FILENAME_IMG_DIR = f"dat_pred/{filename}/"
        filename_loader = get_loaders(
            FILENAME_IMG_DIR,
            BATCH_SIZE,
            filename_transform,
            filename,
            NUM_WORKERS,
            PIN_MEMORY,
        )

        focus = sm.io.imread(f"dat_pred/{filename}/focus{filename}.tif").astype(int)
        T = focus.shape[0] - 4

        make_predictions(filename_loader, model, filename, device=DEVICE)

        # Detect blobs frame by frame on a canvas padded by 20 pixels, and
        # append the frame index as a fourth column.
        vid = np.zeros([T, 512, 512])
        frame_blobs = []
        for t in range(T):
            img = np.zeros([552, 552])
            vid[t] = sm.io.imread(f"dat_midpoint/{filename}/pred_{filename}_{t}.tif").astype(int)
            img[20:532, 20:532] = vid[t]
            blobs = blob_log(img, min_sigma=10, max_sigma=25, num_sigma=25, threshold=30)
            frame_blobs.append(
                np.concatenate((blobs, np.zeros([len(blobs), 1]) + t), axis=1)
            )
        blobs_logs = np.concatenate(frame_blobs)

        _df = []
        for y, x, r, t in blobs_logs:
            mu = intensity(vid, int(t), int(x - 20), int(y - 20))
            _df.append(
                {
                    "Label": label,
                    "T": int(t + 1),
                    "X": int(x - 20),
                    "Y": 532 - int(y),  # map coords without boundary
                    "Intensity": mu,
                }
            )
            label += 1

        # Explicit columns keep the table usable when nothing was detected.
        df = pd.DataFrame(_df, columns=["Label", "T", "X", "Y", "Intensity"])

        # For every pair of neighbouring detections keep the brighter one.
        # The decisions depend only on ``df``, so the rejected labels can be
        # collected first and removed in one step.
        # NOTE(review): when two neighbours have equal intensity each rejects
        # the other, as in the original comparison -- confirm this is intended.
        rejected = set()
        for row in df.itertuples(index=False):
            neighbours = sortDL(df, row.T, row.X, row.Y)
            neighbours = neighbours.drop_duplicates(subset=["T", "X", "Y"])
            for other in neighbours.itertuples(index=False):
                if row.Intensity < other.Intensity:
                    rejected.add(row.Label)
                else:
                    rejected.add(other.Label)

        dfRemove = df[~df["Label"].isin(rejected)]
        dfDivisions = dfRemove.drop_duplicates(subset=["T", "X", "Y"])
        dfDivisions.to_pickle(f"dat_midpoint/{filename}/dfDivisions{filename}.pkl")

        createFolder(f"dat_midpoint/{filename}/Divisions")

        for k in range(len(dfDivisions)):
            # A separate name keeps the running ``label`` counter intact.
            division_label = int(dfDivisions["Label"].iloc[k])
            t = int(dfDivisions["T"].iloc[k])
            x = int(dfDivisions["X"].iloc[k])
            y = int(512 - dfDivisions["Y"].iloc[k])

            tifffile.imwrite(
                f"dat_midpoint/{filename}/Divisions/division{division_label}.tif",
                _division_stack(focus, t, x, y),
            )

    # ---------- Find Orientation ----------

    model = _build_unet((120, 120), "models/UNetOrientation.pth.tar")

    for filename in filenames:
        print(filename)
        createFolder(f"dat_midpoint/{filename}/Masks")

        filename_loader = get_loaders(
            f"dat_midpoint/{filename}/Divisions/",
            BATCH_SIZE*8,
            filename_transform,
            filename,
            NUM_WORKERS,
            PIN_MEMORY,
        )
        make_predictions_Orientation(filename_loader, model, filename, device=DEVICE)

    # The output directory is not created anywhere else.
    createFolder("dat_output")

    for filename in filenames:
        _df = []
        dfDivisions = pd.read_pickle(f"dat_midpoint/{filename}/dfDivisions{filename}.pkl")
        for k in range(len(dfDivisions)):
            division_label = int(dfDivisions["Label"].iloc[k])
            mask = sm.io.imread(
                f"dat_midpoint/{filename}/Masks/division{division_label}.tif"
            ).astype(int)[0]
            ori_mask = maskOrientation(mask)

            _df.append(
                {
                    "Label": division_label,
                    "T": dfDivisions["T"].iloc[k],
                    "X": dfDivisions["X"].iloc[k],
                    "Y": dfDivisions["Y"].iloc[k],
                    "Orientation": ori_mask,
                }
            )
            # int() raises on NaN (e.g. an all-zero mask), so print it as is.
            shown = int(ori_mask) if np.isfinite(ori_mask) else ori_mask
            print(division_label, shown, dfDivisions["T"].iloc[k],
                  dfDivisions["X"].iloc[k], dfDivisions["Y"].iloc[k])

        df = pd.DataFrame(_df)
        df.to_pickle(f"dat_output/dfDivision{filename}.pkl")

In [9]:
main()

=> Loading checkpoint
Unwound18h13


  0%|          | 0/3 [00:00<?, ?it/s]

=> Loading checkpoint
Unwound18h13


100%|██████████| 3/3 [00:55<00:00, 18.43s/it]


0 37 1 264 401
1 -64 1 300 450
2 14 1 91 256
4 82 2 232 70
5 -53 2 351 391
6 -53 3 350 361
7 -69 3 470 188
8 49 3 504 501
9 -62 3 419 462
10 25 3 402 364
11 45 4 8 56
12 -69 5 437 197
13 31 6 322 282
14 -29 6 150 369
15 -27 7 93 400
16 -83 7 200 370
17 80 7 277 389
18 -50 7 262 430
19 -70 7 382 317
20 58 8 30 353
21 47 8 456 253
22 65 9 84 477
23 62 10 282 505
24 -87 10 7 414
25 -77 10 192 447
26 6 11 179 346
27 -28 11 46 459
28 -52 12 223 267
29 -83 12 280 437
30 51 12 50 361
31 3 12 37 391
33 44 13 86 427
34 41 13 267 315
35 17 13 135 259
36 -3 14 27 411
37 12 14 69 63
38 -9 14 176 253
39 70 15 227 19
40 -54 15 307 373
41 -47 16 76 22
42 -63 16 104 6
43 -54 17 79 324
44 -59 18 45 345
45 -51 19 9 15
46 70 19 502 33
47 53 23 23 219
48 -49 24 116 26
49 -83 25 493 314
50 -61 25 496 486
51 -63 26 5 135
52 76 28 135 145
53 22 29 337 104
54 52 29 116 57
55 84 31 128 37
56 -31 33 64 414
57 -79 33 490 322
58 -80 35 425 80
59 34 35 194 141
60 -76 36 15 142
61 -81 37 160 296
62 -49 38 463 243
6