In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

%cd ..
!hostname

/p/fastdata/pli/Private/oberstrass1/datasets/vervet1818-3d
jrlogin05.jureca


In [2]:
import torch
from torch.utils.data import DataLoader

import numpy as np
import h5py as h5

import math

from vervet1818_3d.resnet_wider import resnet50x1

In [17]:
from collections import namedtuple

import torchvision.transforms as transforms

from pli_styles.modality.inclination import corr_factor_transmittance_weighted, inclination_from_retardation
from pli_styles.modality.fom import hsv_fom

from pli.data import Section


Coord = namedtuple("Coord", ('x', 'y'))


def generate_fom(trans, dir, ret):
    corr = corr_factor_transmittance_weighted(
        trans,
        t_M=0.32,  # 0.23
        t_c=0.65,  # 0.65
        r_ref_wm=0.96, # 0.96
        r_ref_gm=0.16, # 0.16
        median_kernel_size=3
    )
    incl = inclination_from_retardation(
        ret,
        corr
    )
    fom = hsv_fom(
        np.rad2deg(dir),
        incl,
        saturation_min=0,
        inclination_scale='cosinus'
    )
    return fom

class SectionDataset(torch.utils.data.Dataset):

    def __init__(self, trans_file, dir_file, ret_file, patch_shape, out_shape, ram=True, norm_trans=None, norm_ret=None):
        # Expands the dataset to size input by repeating the provided ROIs
        # rois is a list of dicts with entries 'mask', 'ntrans', 'ret' and 'dir'
        super().__init__()

        # Scale to size that was used in Imagenet training
        self.transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(256),
            transforms.CenterCrop(256),
        ])

        self.ram = ram
        self.trans_section_mod = Section(path=trans_file)
        self.dir_section_mod = Section(path=dir_file)
        self.ret_section_mod = Section(path=ret_file)
        if ram:
            print("Load sections to RAM...")
            self.trans_section = np.array(self.trans_section_mod.image)
            self.dir_section = np.array(self.dir_section_mod.image)
            self.ret_section = np.array(self.ret_section_mod.image)
            print("All sections loaded to RAM")
        else:
            print("Do not load sections to RAM")
            self.trans_section = self.trans_section_mod.image
            self.dir_section = self.dir_section_mod.image
            self.ret_section = self.ret_section_mod.image

        if norm_trans is None:
            if self.trans_section_mod.norm_value is not None:
                self.norm_trans = self.trans_section_mod.norm_value
            else:
                print("[WARNING] Did not find a normalization value for Transmittance")
                self.norm_trans = 1.0
        else:
            self.norm_trans = norm_trans
            print(f"Normalize Transmittance by value of {self.norm_trans}")
        if norm_ret is None:
            self.norm_ret = 1.0
        else:
            self.norm_ret = norm_ret
            print(f"Normalize Retardation by value of {self.norm_ret}")
        self.brain_id = self.trans_section_mod.brain_id
        self.section_id = self.trans_section_mod.id
        self.section_roi = self.trans_section_mod.roi

        assert (patch_shape[0] - out_shape[0]) % 2 == 0  # Border symmetric
        assert (patch_shape[1] - out_shape[1]) % 2 == 0  # Border symmetric
        self.patch_shape = patch_shape
        self.out_shape = out_shape
        self.border = ((patch_shape[0] - out_shape[0]) // 2, (patch_shape[1] - out_shape[1]) // 2)
        self.shape = self.trans_section.shape

        self.coords = [Coord(x=x, y=y) for x in np.arange(0, self.shape[1], out_shape[1]) for y in
                       np.arange(0, self.shape[0], out_shape[0])]

    def __getitem__(self, i):
        x = self.coords[i].x
        y = self.coords[i].y

        b_y = self.border[0]
        b_x = self.border[1]

        pad_y_0 = max(b_y - y, 0)
        pad_x_0 = max(b_x - x, 0)
        pad_y_1 = max(y + (self.patch_shape[0] - b_y) - self.shape[0], 0)
        pad_x_1 = max(x + (self.patch_shape[1] - b_x) - self.shape[1], 0)

        trans_crop = np.array(
            self.trans_section[max(0, y - b_y):min(self.shape[0], y + self.patch_shape[0] - b_y),
            max(0, x - b_x):min(self.shape[1], x + self.patch_shape[1] - b_x)],
            dtype=np.float32
        ) / self.norm_trans
        ret_crop = np.array(
            self.ret_section[max(0, y - b_y):min(self.shape[0], y + self.patch_shape[0] - b_y),
            max(0, x - b_x):min(self.shape[1], x + self.patch_shape[1] - b_x)],
            dtype=np.float32
        ) / self.norm_ret
        dir_crop = np.deg2rad(
            self.dir_section[max(0, y - b_y):min(self.shape[0], y + self.patch_shape[0] - b_y),
            max(0, x - b_x):min(self.shape[1], x + self.patch_shape[1] - b_x)],
            dtype=np.float32
        )

        fom_crop = generate_fom(trans_crop, dir_crop, ret_crop)

        fom_crop = np.pad(fom_crop, ((pad_y_0, pad_y_1), (pad_x_0, pad_x_1), (0, 0)), mode='constant', constant_values=0.0)
        
        fom_crop = self.transforms(fom_crop)
       
        return {'x': x, 'y': y, 'crop': fom_crop}

    def __len__(self):
        return len(self.coords)

In [110]:
import os
import re
from typing import Tuple
from glob import glob
from tqdm import tqdm


def get_files(
        trans: str,
        dir: str,
        ret: str,
        out: str,
        rank: int = 0,
        size: int = 1
):
    print(trans)
    trans_files = sorted(glob(trans))
    dir_files = sorted(glob(dir))
    ret_files = sorted(glob(ret))

    if os.path.isdir(out):
        ft_files = []
        for d_f in dir_files:
            d_fname = os.path.splitext(os.path.basename(d_f))[0]
            d_base = os.path.splitext(d_fname)[0]
            ft_file = re.sub("direction", "Features", d_base, flags=re.IGNORECASE)
            if "Features" not in ft_file:
                ft_file += "_Features.h5"
            else:
                ft_file += ".h5"
            ft_files.append(os.path.join(out, ft_file))
    else:
        ft_files = [out]

    for i, (trans_file, dir_file, ret_file, ft_file) \
            in enumerate(zip(trans_files, dir_files, ret_files, ft_files)):
        if i % size == rank:
            if not os.path.isfile(ft_file):
                yield trans_file, dir_file, ret_file, ft_file
            else:
                print(f"{ft_file} already exists. Skip.")


def create_features(
        encoder: torch.nn.Module,
        section_loader: DataLoader,
        h_size: int,
        out_size: Tuple[int, ...],
        stride: Tuple[int, ...],
        rank: int
):
    print("Initialize output featuremaps...")
    h_features = np.zeros((*out_size, h_size), dtype=np.float32)
    
    def get_outputs(batch, network):
        with torch.no_grad():
            network.eval()
            h = network(
                batch['crop'].to(network.device),
            )
        return {'x': batch['x'], 'y': batch['y'], 'h': h}

    def transfer(batch, network):
        b = get_outputs(batch, network)
        for x, y, h in zip(b['x'], b['y'], b['h']):
            try:
                h_features[y // stride[0], x // stride[1]] = h.cpu().numpy()
            except:
                raise Exception(f"ERROR creating mask at x={x}, y={y}, shape={h_features.shape}")

    print("Start feature generation...")
    for batch in tqdm(section_loader, desc=f"Rank {rank}"):
        transfer(batch, encoder)

    return h_features


def save_features(
        h_features: np.ndarray,
        z_features: np.ndarray,
        ft_file: str,
        spacing: Tuple[float, ...] = (1.0, 1.0),
        origin: Tuple[float, ...] = (0.0, 0.0),
        dtype: str = None,
):
    print("Save features...")
    with h5.File(ft_file, "w") as f:
        feature_group = f.create_group("Features")
        dset_h = feature_group.create_dataset(f"{h_features.shape[-1]}", data=h_features.transpose(2, 0, 1), dtype=dtype)
        dset_h.attrs['spacing'] = spacing
        dset_h.attrs['origin'] = origin
        dset_z = feature_group.create_dataset(f"{z_features.shape[-1]}", data=z_features.transpose(2, 0, 1), dtype=dtype)
        dset_z.attrs['spacing'] = spacing
        dset_z.attrs['origin'] = origin
    print(f"Featuremaps created at {ft_file}")

In [10]:
sd = 'models/simclr_converter/resnet50-1x.pth'
device = 'cpu'

###

encoder = resnet50x1()
sd = torch.load(sd, map_location='cpu')
encoder.load_state_dict(sd['state_dict'])

encoder.device = device
encoder.to(device);

In [11]:
patch_size = 128
overlap = 0.0

###

patch_shape = (patch_size, patch_size)
stride = (int((1 - overlap) * patch_size), int((1 - overlap) * patch_size))

h_size = encoder.fc.in_features

In [112]:
trans = "in/volume-reconstruction/Transmittance_Retardation/nifti/Vervet1818aa_*NTransmittance.nii.gz"
dir = "in/volume-reconstruction/Transmittance_Retardation/nifti/Vervet1818aa_*Direction.nii.gz"
ret = "in/volume-reconstruction/Transmittance_Retardation/nifti/Vervet1818aa_*Retardation.nii.gz"
  
out = "data/aa/features/simclr-imagenet"

batch_size = 1
num_workers = 1

rank = 0
size = 1

ram = False
dtype = "float16"

###

for trans_file, dir_file, ret_file, ft_file in get_files(trans, dir, ret, out, rank, size):
    print(f"Initialize DataLoader for {trans_file}, {dir_file}, {ret_file}")

    section_dataset = SectionDataset(
        trans_file=trans_file,
        dir_file=dir_file,
        ret_file=ret_file,
        patch_shape=patch_shape,
        out_shape=stride,
        ram=ram,
    )
    section_loader = DataLoader(
        section_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers
    )

    out_size = tuple(math.ceil(s / stride[i]) for i, s in enumerate(section_dataset.shape[:2]))

    h_features = create_features(encoder, section_loader, h_size, out_size, stride, rank)

    spacing = tuple(stride[i] * s for i, s in enumerate(section_dataset.trans_section_mod.spacing))
    origin = section_dataset.trans_section_mod.origin

    save_features(h_features, ft_file, spacing, origin, dtype)

in/volume-reconstruction/Transmittance_Retardation/nifti/Vervet1818aa_*NTransmittance.nii.gz
Initialize DataLoader for in/volume-reconstruction/Transmittance_Retardation/nifti/Vervet1818aa_60mu_70ms_s0841_x00-21_y00-14_NTransmittance.nii.gz, in/volume-reconstruction/Transmittance_Retardation/nifti/Vervet1818aa_60mu_70ms_s0841_x00-21_y00-14_Direction.nii.gz, in/volume-reconstruction/Transmittance_Retardation/nifti/Vervet1818aa_60mu_70ms_s0841_x00-21_y00-14_Retardation.nii.gz
Do not load sections to RAM
Initialize output featuremaps...
Start feature generation...


  h_features[y // stride[0], x // stride[1]] = h.cpu().numpy()
Rank 0:   0%|          | 2/54675 [00:27<211:36:10, 13.93s/it]


KeyboardInterrupt: 