In [None]:
from src.engine.predictor import Predictor
from src.models.models import get_model
from src.utils.timer import Timer
from src.utils.util import summary_model_info

  warn(
  from .autonotebook import tqdm as notebook_tqdm


In [None]:
from pathlib import Path
import torch
from torch import nn
import numpy as np
from PIL import Image
from tqdm import tqdm
import rich
from pprint import pp, pprint
import json
from src.utils.transform import get_transforms, VisionTransformersBuilder

# Directory name for outputs (not referenced by the code visible in this cell).
output_dir = Path("TEMP")

# Build the "unet" model from an empty configuration mapping.
model_config = {}
model = get_model("unet", model_config)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


def preprocess(input: Path) -> torch.Tensor:
    """Load an image file and prepare it as a batched model input.

    The file is opened with PIL and converted to mode ``"L"`` (8-bit
    grayscale). It is then passed through a pipeline assembled with
    ``VisionTransformersBuilder`` using the steps ``resize((512, 512))``,
    ``to_pil_image()`` and ``convert_image_dtype()``, and the result gets a
    new leading dimension of size 1.

    Args:
        input: Path of the image file to load.

    Returns:
        The transformed image with an extra dimension inserted at index 0.
    """
    grayscale = Image.open(input).convert("L")

    # Assemble the pipeline one step at a time; every call hands back the
    # object on which the next step is invoked.
    stage = VisionTransformersBuilder()
    stage = stage.resize((512, 512))
    stage = stage.to_pil_image()
    stage = stage.convert_image_dtype()
    pipeline = stage.build()

    transformed = pipeline(grayscale)
    # Insert the batch axis in front.
    return transformed.unsqueeze(0)


def postprocess(pred: torch.Tensor):
    """Binarize a model output into a 0/255 ``uint8`` mask.

    Elements greater than or equal to 0.5 become 255; every other element
    (including NaN) becomes 0. The mask is built as a new tensor, so the
    caller's ``pred`` is left untouched instead of being overwritten in
    place.

    Args:
        pred: Tensor of scores to threshold.
            NOTE(review): a 0.5 threshold presumes the values are
            probabilities in [0, 1]; confirm the model does not emit raw
            logits.

    Returns:
        A ``torch.uint8`` tensor holding only 0 and 255, with dimension 0
        squeezed twice (each squeeze is a no-op when that dimension is not
        of size 1), e.g. ``(1, 1, H, W)`` becomes ``(H, W)``.
    """
    # The comparison yields a fresh bool tensor; casting it to uint8 gives
    # 0/1, and scaling by 255 produces the final mask without touching `pred`.
    mask = (pred >= 0.5).to(torch.uint8) * 255
    return mask.squeeze(0).squeeze(0)


class Predictor:
    """Run a model over image files one by one and report per-stage timings.

    Each file goes through ``preprocess``, a forward pass of the model, and
    ``postprocess``. Every stage is timed with a ``Timer`` under the names
    "preprocess", "inference" and "postprocess".
    """

    def __init__(self, model: nn.Module):
        """Move *model* to the module-level ``device`` and put it in eval mode.

        Args:
            model: The network used for prediction.
        """
        model = model.to(device)
        model.eval()
        self.model = model
        self.timer = Timer()

    @torch.inference_mode()
    def predict(self, inputs: list[Path]):
        """Predict every file in *inputs* and print a timing summary.

        The predictions themselves are discarded; only the elapsed times
        recorded by the timer are reported.

        Args:
            inputs: Paths of the image files to process. When empty, a short
                notice is printed and nothing else happens (this avoids
                dividing by zero when computing the average).
        """
        if not inputs:
            print("No inputs to predict.")
            return

        for path in tqdm(inputs, desc="Predicting..."):

            with self.timer.timeit("preprocess"):
                x = preprocess(path)

            with self.timer.timeit("inference"):
                x = x.to(device)
                pred = self.model(x)
                # CUDA kernels are launched asynchronously. Wait for them
                # here so the GPU time is charged to "inference" rather than
                # to the `.cpu()` copy in the next stage.
                if x.is_cuda:
                    torch.cuda.synchronize()

            with self.timer.timeit("postprocess"):
                # Result intentionally unused: this stage is run for timing.
                postprocess(pred).detach().cpu().numpy()

        cost = self.timer.total_elapsed_time()
        print(f"Predicting had cost {cost}s, average: {cost / len(inputs)}s")
        all_cost = self.timer.all_elapsed_time()
        rich.print_json(json.dumps(all_cost, indent=2))


# Collect the PNG files located directly inside the input directory.
# The suffix is compared for equality (case-insensitively) so that names such
# as "a.png_old" are not picked up, non-file entries are skipped, and the list
# is sorted because `iterdir` yields entries in arbitrary order.
input_path = Path("xxx")
inputs = sorted(
    path
    for path in input_path.iterdir()
    if path.is_file() and path.suffix.lower() == ".png"
)

# Run the timed prediction pass, then print the model summary for a
# (1, 512, 512) input on the selected device.
predictor = Predictor(model)
predictor.predict(inputs)
summary_model_info(model, (1, 512, 512), device)