In [1]:
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from torch import Tensor
from torch.nn import Module
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.datasets import AFTNetVolumeDataset
from src.models.aft import PAFT_DC
from src.models.cunet import CUNet
from src import utils
import torchmri

In [2]:
class RACUNet1(Module):
    def __init__(self) -> None:
        super().__init__()
        self.acunet = CUNet(
            in_channels=4,
            out_channels=4,
            layer_channels=[32, 64, 128, 256, 512],
            attention=True
        )

    def forward(self, isp_in: Tensor, ksp_in: Tensor) -> Tensor:
        isp_pred = isp_in + self.acunet(isp_in)
        ksp_pred = torchmri.fft.fftn(isp_pred, dim=(-2, -1))
        ksp_pred = apply_k_space_consistency(ksp_in, ksp_pred)
        isp_pred = torchmri.fft.ifftn(ksp_pred, dim=(-2, -1))
        return isp_pred


class RACUNet2(Module):
    def __init__(self) -> None:
        super().__init__()
        self.acunet = CUNet(
            in_channels=4,
            out_channels=4,
            layer_channels=[32, 64, 128, 256, 512],
            attention=True
        )

    def forward(self, ksp_in: Tensor) -> Tensor:
        ksp_out = ksp_in + self.acunet(ksp_in)
        ksp_out = apply_k_space_consistency(ksp_in, ksp_out)
        return ksp_out


def apply_k_space_consistency(k_space_under_gt, k_space_output):
    """
    Apply k-space data consistency by replacing the values in the k-space output
    with the undersampled ground truth values where the ground truth is not zero.

    Args:
        k_space_gt (torch.Tensor): The undersampled ground truth k-space data.
        k_space_output (torch.Tensor): The output k-space data from the model.

    Returns:
        torch.Tensor: The k-space output with data consistency applied.
    """
    # Ensure the tensors are on the same device
    k_space_under_gt = k_space_under_gt.to(k_space_output.device)

    # Create a mask where the undersampled ground truth is not zero
    mask = k_space_under_gt != 0

    # Apply the mask to update the k-space output with the ground truth values
    k_space_output[mask] = k_space_under_gt[mask]

    return k_space_output


class RACUNet_PAFT_RACUNet(Module):
    def __init__(self) -> None:
        super().__init__()
        self.aft = PAFT_DC(320)
        self.racunet1 = RACUNet1()
        self.racunet2 = RACUNet2()

    def forward(self, ksp_in: Tensor) -> Tensor:
        return self.racunet1(self.aft(self.racunet2(ksp_in)), ksp_in)

In [5]:
device = 'cuda:1'
model = RACUNet_PAFT_RACUNet()
f = 'runs/best_model.pt'
ckp = torch.load(f, map_location='cpu')
model.load_state_dict(ckp)
model.to(device)
model.eval()

RACUNet_PAFT_RACUNet(
  (aft): PAFT_DC(
    (fc): Sequential(
      (0): Linear(
        (real_layer): Linear(in_features=320, out_features=640, bias=True)
        (imag_layer): Linear(in_features=320, out_features=640, bias=True)
      )
      (1): LeakyReLU()
      (2): Linear(
        (real_layer): Linear(in_features=640, out_features=640, bias=True)
        (imag_layer): Linear(in_features=640, out_features=640, bias=True)
      )
      (3): LeakyReLU()
      (4): Linear(
        (real_layer): Linear(in_features=640, out_features=320, bias=True)
        (imag_layer): Linear(in_features=640, out_features=320, bias=True)
      )
    )
  )
  (racunet1): RACUNet1(
    (acunet): CUNet(
      (encoder): Encoder(
        (layers): ModuleList(
          (0): BasicBlock(
            (conv1): Conv2d(
              (real_conv): Conv2d(4, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (imag_conv): Conv2d(4, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)

In [6]:
ps = sorted(Path('data/v0/test_50').glob('*.h5'))
df = pd.DataFrame(columns=['Acceleration Rate', 'Name', 'Metric', 'Value'])
Path('results').mkdir(exist_ok=True)

for acc in ['1x', '2x', '4x', '8x']:
    Path(f'results/{acc}').mkdir(exist_ok=True)
    test_dataset = AFTNetVolumeDataset(
        'data/v0', 'test_50', [acc], [0.], (640, 320))
    test_loader = DataLoader(
        test_dataset, batch_size=None, shuffle=False, num_workers=1)
    for i, batch in enumerate(tqdm(test_loader, desc=acc, dynamic_ncols=True)):
        isp_in, isp_true, ksp_in, ksp_true = batch['ispace_under'], batch[
            'ispace'], batch['kspace_under'], batch['kspace']
        if isp_true.shape[0] < 7:
            continue
        with torch.no_grad():
            isp_pred = model(ksp_in.to(device))

        img_true = torchmri.utils.rss(isp_true[:, :, 160:-160], dim=-3).numpy()
        img_pred = torchmri.utils.rss(
            isp_pred[:, :, 160:-160], dim=-3).cpu().numpy()
        np.save(f'results/{acc}/{ps[i].resolve().stem}.npy', img_pred)

        metrics_out = utils.img_metrics(img_true, img_pred)
        for k, v in metrics_out.items():
            df.loc[len(df)] = [acc, ps[i].resolve().stem, k, v]
df.to_csv('results/metrics.csv', index=False)

1x:   0%|          | 0/50 [00:00<?, ?it/s]

1x: 100%|██████████| 50/50 [02:07<00:00,  2.55s/it]
2x: 100%|██████████| 50/50 [02:00<00:00,  2.42s/it]
4x: 100%|██████████| 50/50 [02:01<00:00,  2.42s/it]
8x: 100%|██████████| 50/50 [02:00<00:00,  2.41s/it]


In [7]:
custom_dict = {
    '1x': 0, '2x': 1, '4x': 2, '8x': 3,
    'SSIM': 0, 'PSNR': 1, 'NRMSE': 2
}

df_mean = df.groupby(['Acceleration Rate', 'Metric'])['Value'].mean().sort_index(
    level=['Acceleration Rate', 'Metric'], key=lambda x: x.map(custom_dict))
df_std = df.groupby(['Acceleration Rate', 'Metric'])['Value'].std().sort_index(
    level=['Acceleration Rate', 'Metric'], key=lambda x: x.map(custom_dict))

mean_std = []
for mean, std in zip(df_mean, df_std):
    if mean < 1.1:
        mean_std.append(f'{mean:.3f} ± {std:.3f}')
    else:
        mean_std.append(f'{mean:.1f} ± {std:.1f}')

pd.Series(data=mean_std, index=df_mean.index, name='Value').to_csv(
    'results/metrics_mean-std.csv')

pd.Series(data=mean_std, index=df_mean.index).to_csv(
    'results/metrics_mean-std.txt', header=False, index=False)