In [1]:
import sys
from pathlib import Path

# Make the project importable without hard-coding a user-specific absolute path.
# The repository root is taken to be the nearest directory (starting at the
# current working directory and walking upwards) that contains the ``source``
# package. If none is found the working directory itself is used, in which case
# the project imports below will fail with a clear ModuleNotFoundError.
_working_directory = Path.cwd().resolve()
PROJECT_ROOT = next(
    (
        candidate
        for candidate in (_working_directory, *_working_directory.parents)
        if (candidate / "source").is_dir()
    ),
    _working_directory,
)

# Append both the root and ``<root>/source`` (the same two entries as before),
# skipping entries that are already present so re-running the cell is harmless.
for _import_path in (PROJECT_ROOT, PROJECT_ROOT / "source"):
    if str(_import_path) not in sys.path:
        sys.path.append(str(_import_path))

In [2]:
from source.datasets.brain_dataset import BrainDataset

from torch.utils.data import DataLoader
from source.lightning.setup_model import setup_lightning_loftr
import lightning as L

ModuleNotFoundError: No module named 'source.lightning.setup_model'

In [6]:
# --- Crop geometry --------------------------------------------------------
crop_size = 640
patch_size = 16

# Grid sizes derived from the crop: the coarse grid is the crop side divided by
# the patch size, the fine grid is four times that along each side.
coarse_height_width = crop_size // patch_size
fine_height_width = coarse_height_width * 4

# --- Augmentation ranges passed to the dataset ------------------------------
affine_transformation_range = 0.25
perspective_transformation_range = 0.0001
max_translation_shift = 50

# --- Data location ----------------------------------------------------------
images_directory = "../../data/cyto_downscaled_3344_3904/"
use_train_data = True

# All dataset options gathered in one place. The normalisation uses a single
# mean/std entry (presumably single-channel images - confirm against the data).
brain_dataset_arguments = {
    "images_directory": images_directory,
    "train": use_train_data,
    "affine_transformation_range": affine_transformation_range,
    "perspective_transformation_range": perspective_transformation_range,
    "crop_size": crop_size,
    "patch_size": patch_size,
    "max_translation_shift": max_translation_shift,
    "fine_height_width": fine_height_width,
    "transform": v2.Compose([v2.Normalize(mean=[0.594], std=[0.204])]),
}

dataset_train = BrainDataset(**brain_dataset_arguments)

In [26]:
# Show the backbone's class name using the public type name instead of the
# underscore-prefixed `_get_name()` helper.
# NOTE: `backbone` must already exist in the kernel; it is created in the
# model-construction cell, so on a fresh kernel run that cell first.
type(backbone).__name__

'ResNetFPN_16_4'

In [7]:
# --- Pipeline switches ------------------------------------------------------
temperature = 0.2
use_coarse_context = False
clamp_predictions = True
use_l2_with_standard_deviation = True

# --- Backbone channel widths ------------------------------------------------
# Alternatives noted for the two backbone variants:
#   8_2  -> [96, 128, 192]      (fine feature size is entry 0)
#   16_4 -> [64, 96, 128, 192]  (fine feature size is entry 1)
block_dimensions = [64, 96, 128, 192]
fine_feature_size, coarse_feature_size = block_dimensions[1], block_dimensions[-1]

# --- Components, built in the same order as before ---------------------------
# (construction order is kept so any random weight initialisation is unchanged)
backbone = ResNetFPN_16_4(block_dimensions=block_dimensions)
positional_encoding = PositionalEncoding(coarse_feature_size)

# One self-attention followed by one cross-attention layer; the coarse
# transformer repeats this pair four times, the fine transformer uses it once.
self_cross_pair = ["self", "cross"]

coarse_loftr = LocalFeatureTransformer(
    feature_dimension=coarse_feature_size,
    number_of_heads=8,
    layer_names=self_cross_pair * 4,
)
coarse_matcher = CoarseMatching(
    temperature=temperature,
    confidence_threshold=0.2,
)
fine_preprocess = FinePreprocess(
    coarse_feature_size=coarse_feature_size,
    fine_feature_size=fine_feature_size,
    window_size=5,
    use_coarse_context=use_coarse_context,
)
fine_loftr = LocalFeatureTransformer(
    feature_dimension=fine_feature_size,
    number_of_heads=8,
    layer_names=list(self_cross_pair),
)
fine_matching = FineMatching(
    return_standard_deviation=use_l2_with_standard_deviation,
    clamp_predictions=clamp_predictions,
)

In [8]:
class LitLoFTR(L.LightningModule):
    """Lightning module that trains the coarse-to-fine matching pipeline.

    The module owns every stage of the pipeline (backbone, positional encoding,
    coarse transformer, coarse matcher, fine preprocessing, fine transformer and
    fine matching) and combines a coarse and a fine loss into one training loss.

    Args:
        backbone: Feature extractor; called on an image crop it must return a
            ``(coarse_features, fine_features)`` pair.
        positional_encoding: Applied to the coarse feature maps.
        coarse_loftr: Transformer applied to the flattened coarse features of
            both images.
        coarse_matcher: Produces the coarse matches; its result must provide a
            ``"confidence_matrix"`` entry.
        fine_preprocess: Builds the fine-level inputs around the coarse matches.
        fine_loftr: Transformer applied to the fine-level inputs.
        fine_matching: Predicts relative coordinates from the fine features.
        coarse_loss: One of ``"focal"``, ``"official_focal"``, ``"cross_entropy"``.
        fine_loss: One of ``"l2"``, ``"l2_std"``.
        alpha: Forwarded to the focal coarse losses.
        gamma: Forwarded to the focal coarse losses.
        learning_rate: Learning rate of the Adam optimizer (default ``1e-3``).
    """

    def __init__(self, backbone, positional_encoding, coarse_loftr, coarse_matcher, fine_preprocess, fine_loftr, fine_matching, coarse_loss, fine_loss, alpha = None, gamma = None, learning_rate = 1e-3) -> None:
        super().__init__()
        self.backbone = backbone
        self.positional_encoding = positional_encoding
        self.coarse_loftr = coarse_loftr
        self.coarse_matcher = coarse_matcher
        self.fine_preprocess = fine_preprocess
        self.fine_loftr = fine_loftr
        self.fine_matching = fine_matching
        self.coarse_loss = coarse_loss
        self.fine_loss = fine_loss
        self.alpha = alpha
        self.gamma = gamma
        self.learning_rate = learning_rate

    def _compute_coarse_loss(self, confidence_matrix, match_matrix):
        """Evaluate the configured coarse loss on the predicted confidences.

        Raises:
            ValueError: If ``self.coarse_loss`` is not a supported loss name.
        """
        if self.coarse_loss == "focal":
            return coarse_focal_loss(
                predicted_confidence=confidence_matrix,
                ground_truth_confidence=match_matrix,
                alpha=self.alpha,
                gamma=self.gamma,
            )
        if self.coarse_loss == "official_focal":
            return coarse_official_focal_loss(
                predicted_confidence=confidence_matrix,
                ground_truth_confidence=match_matrix,
                alpha=self.alpha,
                gamma=self.gamma,
            )
        if self.coarse_loss == "cross_entropy":
            return coarse_cross_entropy_loss(
                predicted_confidence=confidence_matrix,
                ground_truth_confidence=match_matrix,
            )
        # Without this an unknown name would only surface later as an
        # UnboundLocalError when the losses are summed.
        raise ValueError(
            f"Unknown coarse_loss {self.coarse_loss!r}; expected 'focal', "
            "'official_focal' or 'cross_entropy'."
        )

    def _compute_fine_loss(self, predicted_relative_coordinates, relative_coordinates):
        """Evaluate the configured fine loss on the predicted coordinates.

        Raises:
            ValueError: If ``self.fine_loss`` is not a supported loss name.
        """
        if self.fine_loss == "l2":
            return fine_l2_loss(
                coordinates_predicted=predicted_relative_coordinates,
                coordinates_ground_truth=relative_coordinates,
            )
        if self.fine_loss == "l2_std":
            return fine_l2_loss_with_standard_deviation(
                coordinates_predicted=predicted_relative_coordinates,
                coordinates_ground_truth=relative_coordinates,
            )
        raise ValueError(
            f"Unknown fine_loss {self.fine_loss!r}; expected 'l2' or 'l2_std'."
        )

    def training_step(self, batch, batch_idx):
        """Run one forward pass and return the summed coarse + fine loss.

        Args:
            batch: Five-element sequence ``(image_1_crop, image_2_crop,
                match_matrix, relative_coordinates, coordinate_mapping)``; the
                last element is not used here.
            batch_idx: Index of the batch (unused).

        Returns:
            The coarse loss value plus the fine loss value.

        Raises:
            ValueError: If the configured coarse or fine loss name is unknown.
        """
        image_1_crop, image_2_crop, match_matrix, relative_coordinates, _ = batch

        # Always go through the submodules registered on this module rather than
        # same-named notebook globals, so the module is self-contained.
        coarse_image_feature_1, fine_image_feature_1 = self.backbone(image_1_crop)
        coarse_image_feature_2, fine_image_feature_2 = self.backbone(image_2_crop)

        # Only the last dimension is read, i.e. the feature maps are assumed to
        # be square - TODO confirm.
        coarse_height_width = coarse_image_feature_1.shape[-1]
        fine_height_width = fine_image_feature_1.shape[-1]

        coarse_image_feature_1 = self.positional_encoding(coarse_image_feature_1)
        coarse_image_feature_2 = self.positional_encoding(coarse_image_feature_2)

        # Flatten the spatial grid into a sequence for the transformer.
        coarse_image_feature_1 = rearrange(coarse_image_feature_1, "n c h w -> n (h w) c")
        coarse_image_feature_2 = rearrange(coarse_image_feature_2, "n c h w -> n (h w) c")

        coarse_image_feature_1, coarse_image_feature_2 = self.coarse_loftr(
            coarse_image_feature_1, coarse_image_feature_2
        )

        coarse_matches = self.coarse_matcher(coarse_image_feature_1, coarse_image_feature_2)

        # Indices of the non-zero ground-truth entries, computed once. The fine
        # stage is fed these ground-truth matches, not the predicted ones.
        match_indices = match_matrix.nonzero()
        coarse_matches_ground_truth = {
            "batch_indices": match_indices[:, 0],
            "row_indices": match_indices[:, 1],
            "column_indices": match_indices[:, 2],
        }

        fine_image_feature_1_unfold, fine_image_feature_2_unfold = self.fine_preprocess(
            coarse_image_feature_1=coarse_image_feature_1,
            coarse_image_feature_2=coarse_image_feature_2,
            fine_image_feature_1=fine_image_feature_1,
            fine_image_feature_2=fine_image_feature_2,
            coarse_matches=coarse_matches_ground_truth,
            fine_height_width=fine_height_width,
            coarse_height_width=coarse_height_width,
        )

        fine_image_feature_1_unfold, fine_image_feature_2_unfold = self.fine_loftr(
            fine_image_feature_1_unfold, fine_image_feature_2_unfold
        )

        predicted_relative_coordinates = self.fine_matching(
            fine_image_feature_1_unfold, fine_image_feature_2_unfold
        )

        coarse_loss_value = self._compute_coarse_loss(
            coarse_matches["confidence_matrix"], match_matrix
        )
        fine_loss_value = self._compute_fine_loss(
            predicted_relative_coordinates, relative_coordinates
        )

        return coarse_loss_value + fine_loss_value

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters of this module."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

In [10]:
# Assemble the Lightning module from the components built above; the focal
# coarse loss is used with alpha=0.7 / gamma=2.5 and the fine loss is "l2_std".
loftr = LitLoFTR(
    backbone=backbone,
    positional_encoding=positional_encoding,
    coarse_loftr=coarse_loftr,
    coarse_matcher=coarse_matcher,
    fine_preprocess=fine_preprocess,
    fine_loftr=fine_loftr,
    fine_matching=fine_matching,
    coarse_loss="focal",
    fine_loss="l2_std",
    alpha=0.7,
    gamma=2.5,
)

# Single-epoch run over the training set, one shuffled sample per batch.
trainer = L.Trainer(max_epochs=1)
train_dataloader = DataLoader(dataset_train, batch_size=1, shuffle=True)
trainer.fit(loftr, train_dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
c:\Users\robin\miniconda3\envs\superbrain\Lib\site-packages\lightning\pytorch\trainer\connectors\logger_connector\logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
Missing logger folder: c:\Users\robin\Documents\HyperBrain\notebooks\training\lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                | Type                    | Params
----------------------------------------------------------------
0 | backbone            | ResNetFPN_16_4    

Epoch 0: 100%|██████████| 7/7 [00:17<00:00,  0.40it/s, v_num=0]

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 7/7 [00:17<00:00,  0.39it/s, v_num=0]
