# In [2]:
# !pip install albumentations==1.2.1
# !pip install opencv-python-headless==4.5.2.52
# !pip install timm

from cgitb import reset
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
import skimage as sm
import skimage.io
from matplotlib import pyplot as plt
import tifffile
import timm
from fastai.vision.all import *
import cv2

# Cell output: from .autonotebook import tqdm as notebook_tqdm


# In [10]:
# Runtime configuration for inference.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"  # prefer the GPU whenever PyTorch can see one
BATCH_SIZE = 8
NUM_WORKERS = 2  # DataLoader worker processes
IMAGE_HEIGHT = IMAGE_WIDTH = 401  # expected square input size in pixels — TODO confirm against the data
PIN_MEMORY = True  # forwarded to DataLoader(pin_memory=...) via get_loaders
LOAD_MODEL = True  # NOTE(review): not read anywhere in the visible code

# util

def createFolder(directory):
    """Create ``directory`` (including any missing parents) if it does not exist.

    This is best-effort: an ``OSError`` is reported on stdout instead of being
    raised, so callers continue either way.

    Args:
        directory: Path (str or path-like) of the directory to create.
    """
    try:
        # exist_ok=True replaces the racy "check exists, then create" sequence.
        os.makedirs(directory, exist_ok=True)
    except OSError:
        # f-string instead of "+" so path-like arguments don't raise TypeError here.
        print(f"Error: Creating directory. {directory}")


# dataset


class VidDataset(Dataset):
    """Dataset over every image file in a single directory.

    Each item is ``(images, name)``:
    - ``images`` is a float tensor with the channel axis moved first.
    - ``name`` is the file name without its extension.

    ``make_predictions`` unpacks each batch as ``(x, pred_name)`` and uses the
    name to build its output file name.
    """

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        # Sorted so the (unshuffled) loader order is deterministic; os.listdir
        # order is arbitrary. Sub-directories are skipped because they cannot
        # be read as images.
        self.images = sorted(
            name
            for name in os.listdir(image_dir)
            if os.path.isfile(os.path.join(image_dir, name))
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        file_name = self.images[index]
        img_path = os.path.join(self.image_dir, file_name)
        image = sm.io.imread(img_path).astype(np.float32)
        # assumes imread returns (H, W, C) with C >= 3 — TODO confirm
        image = np.transpose(image, (2, 0, 1))
        images = torch.tensor(image / 2**16).float()

        if self.transform is not None:
            # NOTE(review): the transform receives the raw, unscaled channels
            # (not image / 2**16), so the transform's own Normalize step sets the
            # scale of channels 0-2. Confirm this matches training.
            transformed = self.transform(image=image[0], image0=image[1], image1=image[2])
            images[0] = transformed["image"]
            images[1] = transformed["image0"]
            images[2] = transformed["image1"]

        # Return the base name as well so predictions can be written per input file.
        name = os.path.splitext(file_name)[0]
        return images, name


def load_checkpoint(checkpoint, model):
    """Copy the weights stored under ``checkpoint["state_dict"]`` into ``model``.

    Args:
        checkpoint: Mapping (e.g. the result of ``torch.load``) holding a
            ``"state_dict"`` entry.
        model: Module whose parameters are overwritten in place.
    """
    print("=> Loading checkpoint")
    weights = checkpoint["state_dict"]
    model.load_state_dict(weights)


def get_loaders(
    filename_dir,
    batch_size,
    filename_transform,
    num_workers=4,
    pin_memory=True
):
    """Build an unshuffled DataLoader over the images in ``filename_dir``.

    Despite the plural name, a single DataLoader is returned. Shuffling is
    disabled, so batches follow the dataset's own item order.

    Args:
        filename_dir: Directory passed to ``VidDataset`` as ``image_dir``.
        batch_size: Number of samples per batch.
        filename_transform: Transform passed to ``VidDataset`` (may be None).
        num_workers: DataLoader worker processes.
        pin_memory: Forwarded to ``DataLoader(pin_memory=...)``.
    """
    loader_options = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=False,
    )
    dataset = VidDataset(image_dir=filename_dir, transform=filename_transform)
    return DataLoader(dataset, **loader_options)


def make_predictions(loader, model, folder="dat/train/input", device="cuda", output_dir="dat_output"):
    """Run ``model`` over ``loader`` and write one uint8 TIFF per sample.

    Args:
        loader: Iterable yielding ``(x, names)`` batches, where ``x`` is the
            model input and ``names`` holds one output base-name per sample.
        model: Network whose raw outputs are passed through a sigmoid.
        folder: Unused; kept so existing callers that pass it keep working.
        device: Device each input batch is moved to before the forward pass.
        output_dir: Directory receiving ``<name>_pred.tif`` files; created if
            missing. Defaults to the previously hard-coded ``dat_output``.
    """
    was_training = model.training
    model.eval()
    os.makedirs(output_dir, exist_ok=True)
    try:
        with torch.no_grad():
            for x, pred_names in tqdm(loader):
                x = x.to(device)
                probs = torch.sigmoid(model(x)).cpu().numpy()
                # Quantise [0, 1] probabilities into 256 bins. Clip because a
                # saturated sigmoid can return exactly 1.0, and 1.0 * 256 = 256
                # does not fit in uint8.
                preds = np.clip(probs * 256, 0, 255).astype(np.uint8)
                for name, pred in zip(pred_names, preds):
                    tifffile.imwrite(os.path.join(output_dir, f"{name}_pred.tif"), pred)
    finally:
        # Restore the caller's train/eval mode even if inference fails part-way.
        model.train(was_training)


def main():
    """Build the UNet, load its checkpoint and predict on the testing images.

    Images are read from ``dat/testing/input/``. ``make_predictions`` writes
    the outputs; ``dat_output`` is created first.
    """
    target3 = {'image0': 'image', 'image1': 'image', 'image2': 'image', 'mask': 'mask'}
    filename_transform = A.Compose(
        [
            A.Normalize(
                mean=0,
                std=1,
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
        additional_targets=target3,
    )

    # Encoder = ResNet-101 with its last two children dropped (for timm ResNets,
    # presumably the global pool and classifier head).
    encoder = nn.Sequential(*list(timm.create_model("resnet101").children())[:-2])
    model = DynamicUnet(encoder, 1, (IMAGE_HEIGHT, IMAGE_WIDTH), norm_type=None).to(DEVICE)

    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    load_checkpoint(torch.load("models/UNetPetal.pth.tar", map_location=DEVICE), model)

    print("predicting images")

    createFolder("dat_output")
    FILENAME_IMG_DIR = "dat/testing/input/"
    filename_loader = get_loaders(
        FILENAME_IMG_DIR,
        BATCH_SIZE,
        filename_transform,
        NUM_WORKERS,
        PIN_MEMORY,
    )

    # Pass DEVICE explicitly: make_predictions defaults to "cuda", which would
    # mismatch the model's device on CPU-only hosts.
    make_predictions(filename_loader, model, folder=FILENAME_IMG_DIR, device=DEVICE)


# Run inference only when executed as a script/notebook, not when imported.
if __name__ == "__main__":
    main()