In [1]:
import sys

# Make the project importable from both the Linux checkout (/p/home/...) and the
# Windows checkout (c:\...). The membership check keeps this cell idempotent:
# re-running it does not keep appending duplicate entries to sys.path.
for project_path in (
    "/p/home/jusers/pierschke1/shared/HyperBrain",
    # Use a POSIX separator here; "HyperBrain\\source" would name a non-existent
    # directory on Linux.
    "/p/home/jusers/pierschke1/shared/HyperBrain/source",
    "c:\\Users\\robin\\Documents\\HyperBrain",
    "c:\\Users\\robin\\Documents\\HyperBrain\\source",
):
    if project_path not in sys.path:
        sys.path.append(project_path)

In [75]:
import numpy as np
import torch
from einops.einops import rearrange
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision.transforms import v2

from source.data_processing.image_reading import read_image
from source.data_processing.keypoints import translate_patch_midpoints_and_refine
from source.datasets.brain_dataset import BrainDataset
from source.loftr.backbone import ResNetFPN_32_8
from source.loftr.coarse_matching import CoarseMatching
from source.loftr.fine_matching import FineMatching
from source.loftr.fine_preprocess import FinePreprocess
from source.loftr.positional_encoding import PositionalEncoding
from source.loftr.transformer import LocalFeatureTransformer



In [92]:
patch_size = 32
confidence_threshold = 0.9
attention_type = "linear"
model_name = "5"
block_dimensions = [32, 64, 96, 128, 192]
fine_feature_size = block_dimensions[2]
coarse_feature_size = block_dimensions[-1]
model_directory = f"../../models/coarse_matching/{model_name}"


def _load_checkpoint(module: torch.nn.Module, checkpoint_path: str) -> torch.nn.Module:
    """Load ``checkpoint_path`` into ``module`` and switch it to inference mode.

    A leading ``"module."`` is stripped from every key (the checkpoints were
    presumably saved from a ``DataParallel``-style wrapper). Only the prefix is
    removed, so keys that merely contain the substring (e.g. ``"submodule."``)
    stay intact.

    Returns:
        ``module`` itself, in eval mode.
    """
    # map_location="cpu" makes loading independent of the device the checkpoint
    # was saved from; load_state_dict copies the values onto the module's device.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    prefix = "module."
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    module.load_state_dict(state_dict)
    # Modules start in training mode; inference must run in eval mode so layers
    # such as dropout/normalisation (if the project modules contain any) behave
    # deterministically.
    return module.eval()


backbone = _load_checkpoint(
    ResNetFPN_32_8(block_dimensions=block_dimensions).cuda(),
    f"{model_directory}/backbone.pt",
)

positional_encoding = PositionalEncoding(coarse_feature_size).cuda().eval()

coarse_loftr = _load_checkpoint(
    LocalFeatureTransformer(
        feature_dimension=coarse_feature_size,
        number_of_heads=8,
        layer_names=["self", "cross"] * 4,
        attention_type=attention_type,
    ).cuda(),
    f"{model_directory}/coarse_loftr.pt",
)

coarse_matcher = CoarseMatching(
    temperature=0.1, confidence_threshold=confidence_threshold
).cuda().eval()

fine_preprocess = FinePreprocess(
    coarse_feature_size=coarse_feature_size,
    fine_feature_size=fine_feature_size,
    window_size=5,
    use_coarse_context=False,
).cuda().eval()

fine_loftr = _load_checkpoint(
    LocalFeatureTransformer(
        feature_dimension=fine_feature_size,
        number_of_heads=8,
        layer_names=["self", "cross"],
        attention_type=attention_type,
    ).cuda(),
    f"{model_directory}/fine_loftr.pt",
)

fine_matching = FineMatching().cuda().eval()

In [93]:
# Load both evaluation slices and convert each one straight to a (C, H, W) tensor.
image_1, image_2 = (
    ToTensor()(read_image(image_path))
    for image_path in (
        "../../data/cyto_downscaled_3344_3904_evaluation/B20_0524_Slice15.tif",
        "../../data/cyto_downscaled_3344_3904_evaluation/B20_0525_Slice15.tif",
    )
)

In [94]:
# (channels, height, width) of both slices; they differ, so each gets its own window layout.
image_1.shape, image_2.shape

(torch.Size([1, 8000, 3462]), torch.Size([1, 7382, 3668]))

In [95]:
crop_size = 1280


def _window_layout(length: int, crop_size: int) -> tuple:
    """Return ``(number_of_windows, step_size)`` for sliding windows along one axis.

    Window ``i`` starts at ``i * step_size`` for ``i in range(number_of_windows)``,
    so the first window starts at 0 and, when ``length >= crop_size``, the last one
    ends exactly at ``length``. An axis that fits in a single window gets a step of
    0.0 instead of a division by zero.
    """
    number_of_windows = max(1, -(-length // crop_size))  # integer ceil(length / crop_size)
    if number_of_windows == 1:
        return number_of_windows, 0.0
    return number_of_windows, (length - crop_size) / (number_of_windows - 1)


# Tensors are (C, H, W): dim 2 is the x axis, dim 1 is the y axis.
image_1_x_num_windows, image_1_x_step_size = _window_layout(image_1.shape[2], crop_size)
image_1_y_num_windows, image_1_y_step_size = _window_layout(image_1.shape[1], crop_size)

image_2_x_num_windows, image_2_x_step_size = _window_layout(image_2.shape[2], crop_size)
image_2_y_num_windows, image_2_y_step_size = _window_layout(image_2.shape[1], crop_size)

In [115]:
def _extract_crops(image, x_num_windows, y_num_windows, x_step_size, y_step_size, crop_size):
    """Cut a ``(C, H, W)`` image into overlapping ``crop_size`` windows.

    Returns:
        A dict mapping ``"x_offset, y_offset"`` (the pixel position of the crop's
        top-left corner in ``image``) to the crop. The matching cell parses this key
        and adds it to crop-local coordinates, so it must be the very offset used
        for slicing.
    """
    crops = {}
    for x_index in range(int(x_num_windows)):
        # Window 0 always starts at 0; this also avoids multiplying by a
        # non-finite step size on an axis with a single window.
        x_offset = int(round(x_index * x_step_size)) if x_index else 0
        for y_index in range(int(y_num_windows)):
            y_offset = int(round(y_index * y_step_size)) if y_index else 0
            crops[f"{x_offset}, {y_offset}"] = image[
                :, y_offset:y_offset + crop_size, x_offset:x_offset + crop_size
            ]
    return crops


# Both images use every window so that their full extent is covered.
image_1_crops = _extract_crops(
    image_1,
    image_1_x_num_windows,
    image_1_y_num_windows,
    image_1_x_step_size,
    image_1_y_step_size,
    crop_size,
)
image_2_crops = _extract_crops(
    image_2,
    image_2_x_num_windows,
    image_2_y_num_windows,
    image_2_x_step_size,
    image_2_y_step_size,
    crop_size,
)

In [116]:
matches_image_1 = []
matches_image_2 = []


def _crop_offset(coordinates: str) -> torch.Tensor:
    """Parse an ``"x, y"`` crop key into a float offset tensor."""
    return torch.Tensor([int(value) for value in coordinates.split(",")])


with torch.no_grad():
    for coordinates_1, crop_1 in image_1_crops.items():
        # Everything derived from crop_1 alone does not depend on crop_2, so it is
        # computed once per crop_1 instead of once per (crop_1, crop_2) pair.
        coarse_crop_feature_1, fine_image_feature_1 = backbone(crop_1.cuda().unsqueeze(0))
        fine_height_width = fine_image_feature_1.shape[-1]
        coarse_height_width = coarse_crop_feature_1.shape[-1]
        coarse_crop_feature_1 = rearrange(
            positional_encoding(coarse_crop_feature_1), "n c h w -> n (h w) c"
        )
        offset_1 = _crop_offset(coordinates_1)

        for coordinates_2, crop_2 in image_2_crops.items():
            coarse_image_feature_2, fine_image_feature_2 = backbone(crop_2.cuda().unsqueeze(0))
            coarse_image_feature_2 = rearrange(
                positional_encoding(coarse_image_feature_2), "n c h w -> n (h w) c"
            )

            # NOTE(review): reusing the cached crop_1 features assumes coarse_loftr
            # and fine_preprocess do not modify their inputs in place - confirm.
            coarse_image_feature_1, coarse_image_feature_2 = coarse_loftr(
                coarse_crop_feature_1, coarse_image_feature_2
            )

            coarse_matches_predicted = coarse_matcher(
                coarse_image_feature_1, coarse_image_feature_2
            )
            match_matrix_predicted = coarse_matches_predicted["match_matrix"]

            fine_image_feature_1_unfold, fine_image_feature_2_unfold = fine_preprocess(
                coarse_image_feature_1=coarse_image_feature_1,
                coarse_image_feature_2=coarse_image_feature_2,
                fine_image_feature_1=fine_image_feature_1,
                fine_image_feature_2=fine_image_feature_2,
                coarse_matches=coarse_matches_predicted,
                fine_height_width=fine_height_width,
                coarse_height_width=coarse_height_width,
            )

            # Skip crop pairs that do not contain any coarse match.
            if fine_image_feature_1_unfold.size(0) == 0:
                continue

            fine_image_feature_1_unfold, fine_image_feature_2_unfold = fine_loftr(
                fine_image_feature_1_unfold.to("cuda"),
                fine_image_feature_2_unfold.to("cuda"),
            )

            predicted_relative_coordinates = fine_matching(
                fine_image_feature_1_unfold, fine_image_feature_2_unfold
            ).cpu()

            (
                crop_1_patch_mid_coordinates,
                crop_2_patch_mid_coordinates,
                crop_2_patch_mid_coordinates_refined,
            ) = translate_patch_midpoints_and_refine(
                match_matrix=match_matrix_predicted.cpu(),
                patch_size=patch_size,
                relative_coordinates=predicted_relative_coordinates,
                fine_feature_size=fine_image_feature_1.shape[-1],
            )

            # Shift crop-local coordinates into full-image coordinates. The additions
            # are out of place, so the tensors returned by the helper are not mutated.
            matches_image_1.append(crop_1_patch_mid_coordinates.float() + offset_1)
            matches_image_2.append(
                crop_2_patch_mid_coordinates_refined + _crop_offset(coordinates_2)
            )


In [117]:
def _concatenate_matches(matches):
    """Concatenate the per-crop-pair match tensors into a single tensor.

    Input that is already a tensor is returned unchanged, so re-running this cell
    does not pass a tensor (instead of a list) to ``torch.concat``.

    Raises:
        ValueError: if no crop pair produced any match.
    """
    if isinstance(matches, torch.Tensor):
        return matches
    if len(matches) == 0:
        raise ValueError("No matches were found for any crop pair; nothing to concatenate.")
    return torch.concat(matches)


matches_image_1 = _concatenate_matches(matches_image_1)
matches_image_2 = _concatenate_matches(matches_image_2)

In [103]:
from source.miscellaneous.evaluation import read_deformation, evaluate_test_image_pair

# TODO
- Fix the `read_deformation` function.
- Make sure that adding the crop offsets to the match coordinates is correct: why are some coordinates too large?

In [106]:
def read_deformation() -> torch.Tensor:
    """Read the ground-truth deformation for the evaluation image pair.

    NOTE(review): this definition shadows the ``read_deformation`` imported from
    ``source.miscellaneous.evaluation`` in the previous cell, and it is not called
    in the visible cells (the evaluation cell ``torch.load``s the file directly).
    It also needs ``h5py`` and ``kornia``, which are not imported in the visible cells.

    Returns:
        The ``"deformation"`` dataset divided by 10, axis-permuted, resized with
        nearest-neighbour sampling to ``size=(3463, 8000)``, permuted back, and
        flipped along its last axis, as a ``long`` tensor.
    """
    deformation_path = (
        r"../../data/cyto_downscaled_3344_3904_evaluation/deformation.pt"
    )
    # NOTE(review): the file has a ``.pt`` suffix but is opened as HDF5 - confirm
    # the on-disk format.
    # The context manager closes the HDF5 handle once the dataset has been copied
    # into memory, so the file is not leaked.
    with h5py.File(deformation_path, "r") as deformation_file:
        deformation = torch.Tensor(np.array(deformation_file["deformation"]) / 10)
    deformation = kornia.augmentation.Resize(size=(3463, 8000), resample="NEAREST")(
        deformation.permute(2, 1, 0)
    )
    deformation = deformation.permute(0, 3, 2, 1).squeeze(0)
    deformation = torch.flip(deformation, dims=[-1])

    return deformation.long()

In [118]:
# Ground-truth deformation used to score the predicted matches.
# NOTE(review): the shape output further down reports (8000, 3463, 2), while
# image_1 is (1, 8000, 3462) - confirm the one-pixel width difference is intended.
deformation = torch.load(
    r"../../data/cyto_downscaled_3344_3904_evaluation/deformation.pt"
)
number_of_matches, average_distance, match_precision, auc, matches_per_patch, entropy = (
    evaluate_test_image_pair(matches_image_1, matches_image_2, deformation)
)

  x_indices = torch.searchsorted(x_borders, x_coords) - 1


In [123]:
# Inspect the loaded ground-truth deformation's dimensions.
deformation.size()

torch.Size([8000, 3463, 2])