# DEPRECATED
Use `dataset/grappa.py` instead.

In [1]:
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, Final, List, Optional, Tuple, Union

import fastmri
import h5py
import lightning as L
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

from common.utils import save_reconstructions
from models.external.NAFNet.NAFNet_arch import NAFNet

# Leaderboard acceleration folders whose reconstructions are saved under the
# "public" output split; every other folder goes to "private".
PUBLIC_ACCS = [
    "acc4",
    "acc5",
    "acc8",
]

# Where the leaderboard HDF5 data is read from, and where the exported GRAPPA
# reconstructions are written.
root_path = "/home/Data/leaderboard"
out_path = "reconstructions/grappa"

In [4]:
# Export the "image_grappa" volume of every leaderboard file via
# save_reconstructions: accelerations listed in PUBLIC_ACCS go to
# `{out_path}/public`, all others to `{out_path}/private`.
for acc in os.listdir(root_path):
    image_dir = Path(root_path) / acc / "image"
    reconstructions = {}
    for fname in os.listdir(image_dir):
        # `with` closes the HDF5 handle; previously every file stayed open.
        with h5py.File(image_dir / fname, "r") as hf:
            # Read the whole dataset at once: identical to stacking slices
            # 0..N-1 in index order, and also handles zero-slice volumes.
            reconstructions[fname] = hf["image_grappa"][()]

    split = "public" if acc in PUBLIC_ACCS else "private"
    save_reconstructions(reconstructions, Path(out_path) / split)

KeyboardInterrupt: 

In [2]:
Grappa_DataType = Tuple[
    Dict[int, torch.Tensor],  # GRAPPA image slices of one volume, keyed by slice index
    Optional[Dict[int, torch.Tensor]],  # target slices of the same volume, or None without target_key
]


class GrappaDataset(Dataset):
    """Volume-level dataset of GRAPPA images (and optional targets) held in memory.

    Every file under ``{root_path}/image`` is read eagerly in ``__init__``. Each
    item is one file (volume), returned as a dict mapping slice index -> tensor.

    Args:
        root_path: directory containing an ``image`` sub-directory of HDF5 files.
        input_key: HDF5 dataset name of the input images (e.g. "image_grappa").
        target_key: optional HDF5 dataset name of the target images; when None,
            ``__getitem__`` returns None as the target.
    """

    def __init__(self, root_path: Union[str, Path], input_key: str, target_key: Optional[str] = None) -> None:
        self.root_path = root_path
        self.input_key = input_key
        self.target_key = target_key

        # Sorted so item indices map to files deterministically (os.listdir order is arbitrary).
        self.file_list = sorted(os.listdir(f"{root_path}/image"))
        self.grappa_data = defaultdict(dict)
        self.target_data = defaultdict(dict)

        for file in self.file_list:
            # One handle per file, closed by `with`; previously each file was opened
            # (twice when a target was requested) and never closed.
            with h5py.File(f"{root_path}/image/{file}", "r") as hf:
                self.grappa_data[file] = self._read_slices(hf[input_key])
                if target_key is not None:
                    self.target_data[file] = self._read_slices(hf[target_key])

    @staticmethod
    def _read_slices(dataset) -> Dict[int, torch.Tensor]:
        """Copy each slice along the first axis of ``dataset`` into its own tensor.

        ``dataset`` is an h5py dataset (or any array-like with ``.shape`` and integer indexing).
        """
        return {i: torch.tensor(dataset[i]) for i in range(dataset.shape[0])}

    def __len__(self) -> int:
        """Number of volumes (files), not slices."""
        return len(self.file_list)

    def __getitem__(self, idx: int) -> Grappa_DataType:
        """Return ``(input_slices, target_slices)`` for the ``idx``-th file."""
        fname = self.file_list[idx]
        # NOTE(review): torch's default collate_fn cannot batch a None target, so a
        # DataLoader over a target-less dataset presumably needs a custom collate_fn.
        target = self.target_data[fname] if self.target_key is not None else None
        return self.grappa_data[fname], target


In [3]:
# grappa_dataloaders = {}
# for acc in os.listdir(root_path):
#     grappa_dataset = GrappaDataset(root_path, acc)
#     grappa_raw_data = h5py.File(f"{root_path}/{acc}/image/{file}", "r")["image_grappa"]
#     grappa_data = torch.tensor(grappa_raw_data)

In [4]:
class GrappaDataModule(L.LightningDataModule):
    """Lightning data module serving GrappaDataset splits.

    Directory layout read under ``root``:
        train/                      -> "fit" data (inputs + targets)
        val/                        -> validation data (inputs + targets)
        leaderboard/<public acc>/   -> "test" data (inputs only)
        leaderboard/<private acc>/  -> "predict" data (inputs only)
    The public acceleration directory is the one whose name is in PUBLIC_ACCS.

    Args:
        root: dataset root directory.
        batch_size: batch size for every dataloader.
        input_key: HDF5 dataset name of the inputs.
        target_key: HDF5 dataset name of the targets (used for train/val only).
        num_workers: worker processes per DataLoader (previously hard-coded to 8).
    """

    PUBLIC_ACCS: Final[List[str]] = ["acc4", "acc5", "acc8"]

    def __init__(
        self,
        root: Union[str, Path],
        batch_size: int = 1,
        input_key: str = "image_grappa",
        target_key: str = "image_label",
        num_workers: int = 8,
    ) -> None:
        super().__init__()
        self.root = Path(root)
        self.batch_size = batch_size
        self.input_key = input_key
        self.target_key = target_key
        self.num_workers = num_workers

        # Use self.root (always a Path): `root / ...` raised TypeError for str roots.
        public_acc, private_acc = self._split_leaderboard(self.root / "leaderboard")

        self.path_train = self.root / "train"
        self.path_val = self.root / "val"
        self.path_test = self.root / "leaderboard" / public_acc
        self.path_predict = self.root / "leaderboard" / private_acc

        # These names must match those read by setup() and the *_dataloader methods;
        # previously `data_*` was initialised here, so setup() raised AttributeError.
        self.dataset_train: Optional[GrappaDataset] = None
        self.dataset_val: Optional[GrappaDataset] = None
        self.dataset_test: Optional[GrappaDataset] = None
        self.dataset_predict: Optional[GrappaDataset] = None

    def _split_leaderboard(self, leaderboard: Path) -> Tuple[str, str]:
        """Return ``(public, private)`` acceleration directory names under ``leaderboard``.

        Raises:
            ValueError: unless there is exactly one public and one private entry.
        """
        entries = os.listdir(leaderboard)
        public = [e for e in entries if e in self.PUBLIC_ACCS]
        private = [e for e in entries if e not in self.PUBLIC_ACCS]
        if len(public) != 1 or len(private) != 1:
            raise ValueError(
                f"expected exactly one public and one private acceleration directory in "
                f"{leaderboard}, found public={public}, private={private}"
            )
        return public[0], private[0]

    def setup(self, stage: Optional[str] = None) -> None:
        """Build (once) the datasets needed for ``stage``; ``None`` builds all of them.

        Lightning passes "fit", "validate", "test" or "predict"; the legacy
        "train"/"val" names are still accepted.
        """
        every_stage = stage is None

        if (every_stage or stage in ("fit", "train")) and self.dataset_train is None:
            self.dataset_train = GrappaDataset(self.path_train, self.input_key, self.target_key)

        # "fit" also needs validation data; "validate" is Lightning's name for that stage.
        if (every_stage or stage in ("fit", "validate", "val")) and self.dataset_val is None:
            self.dataset_val = GrappaDataset(self.path_val, self.input_key, self.target_key)

        # Leaderboard splits have no targets.
        if (every_stage or stage == "test") and self.dataset_test is None:
            self.dataset_test = GrappaDataset(self.path_test, self.input_key)

        if (every_stage or stage == "predict") and self.dataset_predict is None:
            self.dataset_predict = GrappaDataset(self.path_predict, self.input_key)

    def _make_loader(self, dataset: Optional[GrappaDataset], shuffle: bool) -> DataLoader:
        """Wrap ``dataset`` in a DataLoader using this module's batch size and worker count."""
        if dataset is None:
            raise RuntimeError("dataset is not set up; call setup() with the matching stage first")
        # NOTE(review): test/predict datasets yield None targets, which torch's
        # default collate_fn cannot batch — presumably a custom collate_fn is needed.
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle, num_workers=self.num_workers)

    def train_dataloader(self) -> DataLoader:
        return self._make_loader(self.dataset_train, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        return self._make_loader(self.dataset_val, shuffle=False)

    def test_dataloader(self) -> DataLoader:
        return self._make_loader(self.dataset_test, shuffle=False)

    def predict_dataloader(self) -> DataLoader:
        return self._make_loader(self.dataset_predict, shuffle=False)

In [5]:
# NAFNet for 2-channel images; every other hyper-parameter keeps its default.
model = NAFNet(img_channel=2)

In [6]:
# Move the model to the GPU (nn.Module.to works in place); the returned module is echoed as the cell output below.
model.to("cuda")

NAFNet(
  (intro): Conv2d(2, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (ending): Conv2d(16, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (encoders): ModuleList()
  (decoders): ModuleList()
  (middle_blks): Sequential(
    (0): NAFBlock(
      (conv1): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1))
      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)
      (conv3): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))
      (sca): Sequential(
        (0): AdaptiveAvgPool2d(output_size=1)
        (1): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))
      )
      (sg): SimpleGate()
      (conv4): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1))
      (conv5): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))
      (norm1): LayerNorm2d()
      (norm2): LayerNorm2d()
      (dropout1): Identity()
      (dropout2): Identity()
    )
  )
  (ups): ModuleList()
  (downs): ModuleList()
)

In [4]:
# Wider NAFNet (width=64) for 2-channel images, with four encoder/decoder stages
# and 4 middle blocks, then moved to the GPU.
# NOTE(review): the captured output below shows a width-16 model with no
# encoders, so it appears stale relative to this configuration.
model = NAFNet(
    img_channel=2,
    width=64,
    middle_blk_num=4,
    enc_blk_nums=[1, 1, 1, 28],
    dec_blk_nums=[1, 1, 1, 1],
)
model = model.to("cuda")

NAFNet(
  (intro): Conv2d(2, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (ending): Conv2d(16, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (encoders): ModuleList()
  (decoders): ModuleList()
  (middle_blks): Sequential(
    (0): NAFBlock(
      (conv1): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1))
      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)
      (conv3): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))
      (sca): Sequential(
        (0): AdaptiveAvgPool2d(output_size=1)
        (1): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))
      )
      (sg): SimpleGate()
      (conv4): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1))
      (conv5): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))
      (norm1): LayerNorm2d()
      (norm2): LayerNorm2d()
      (dropout1): Identity()
      (dropout2): Identity()
    )
    (1): NAFBlock(
      (conv1): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1))
      (conv2