In [1]:
import argparse
import json
import os
import subprocess
import matplotlib.pyplot as plt
import torch
import yaml
from crnn import CRNN, ArtifactRemovalCRNN
from dataloader import SliceDataset, SimulatedDataset, SimulatedSPFDataset, SliceDatasetAug
from deepinv.transform import Transform
from einops import rearrange
from radial import RadialDCLayer, to_torch_complex, MCNUFFT_CRNN
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
from transform import VideoRotate, VideoDiffeo, SubsampleTime, MonophasicTimeWarp, TemporalNoise, TimeReverse
from ei import EILoss
from mc import MCLoss
from lsfpnet import LSFPNet, ArtifactRemovalLSFPNet
from radial_lsfp import MCNUFFT
from utils import prep_nufft, log_gradient_stats, plot_enhancement_curve, get_cosine_ei_weight, plot_reconstruction_sample, get_git_commit, save_checkpoint, load_checkpoint, to_torch_complex
from eval import eval_grasp, eval_sample
import csv

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# load data
split_file = "/gpfs/data/karczmar-lab/workspaces/rachelgordon/breastMRI-recon/ddei/data/data_split.json"
with open(split_file, "r") as fp:
    splits = json.load(fp)


train_patient_ids = splits["train"]
    
    
    
train_dataset = SliceDataset(
    root_dir="/ess/scratch/scratch1/rachelgordon/dce-8tf/binned_kspace",
    patient_ids=train_patient_ids,
    dataset_key="ktspace",
    file_pattern="*.h5",
    slice_idx=41,
    N_coils=16
)

train_loader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
)

Number of files in root directory:  298


In [3]:
for measured_kspace, csmap, grasp_img in train_loader:
    print("kspace shape: ", measured_kspace.shape)

kspace shape:  torch.Size([1, 2, 8, 16, 36, 640])
kspace shape:  torch.Size([1, 2, 8, 16, 36, 640])
kspace shape:  torch.Size([1, 2, 8, 16, 36, 640])
kspace shape:  torch.Size([1, 2, 8, 16, 36, 640])
kspace shape:  torch.Size([1, 2, 8, 16, 36, 640])
kspace shape:  torch.Size([1, 2, 8, 16, 36, 640])
kspace shape:  torch.Size([1, 2, 8, 16, 36, 640])
kspace shape:  torch.Size([1, 2, 8, 16, 36, 640])
kspace shape:  torch.Size([1, 2, 8, 16, 36, 640])
kspace shape:  torch.Size([1, 2, 8, 16, 36, 640])
kspace shape:  torch.Size([1, 2, 8, 16, 36, 640])
kspace shape:  torch.Size([1, 2, 8, 16, 36, 640])
kspace shape:  torch.Size([1, 2, 8, 16, 36, 640])
kspace shape:  torch.Size([1, 2, 8, 16, 36, 640])
kspace shape:  torch.Size([1, 2, 8, 16, 36, 640])
kspace shape:  torch.Size([1, 2, 8, 16, 36, 640])
kspace shape:  torch.Size([1, 2, 8, 16, 36, 640])
kspace shape:  torch.Size([1, 2, 8, 16, 36, 640])
kspace shape:  torch.Size([1, 2, 8, 16, 36, 640])
kspace shape:  torch.Size([1, 2, 8, 16, 36, 640])


KeyboardInterrupt: 

In [2]:
# load data
split_file = "/gpfs/data/karczmar-lab/workspaces/rachelgordon/breastMRI-recon/ddei/data/data_split.json"
with open(split_file, "r") as fp:
    splits = json.load(fp)


train_patient_ids = splits["train"]
    
    
    
train_dataset = SliceDatasetAug(
    root_dir="/ess/scratch/scratch1/rachelgordon/fastMRI_breast_data/full_kspace",
    csmaps_dir="/ess/scratch/scratch1/rachelgordon/dce-8tf/cs_maps",
    patient_ids=train_patient_ids,
    dataset_key="kspace",
    file_pattern="*.h5",
    slice_idx=41,
    N_coils=16
)

train_loader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
)



# output shape: kspace shape:  torch.Size([1, 2, 8, 16, 36, 640]) -- (B, I, T, C, Sp, Sam)

for measured_kspace, csmap, grasp_img in train_loader:
    print("kspace shape: ", measured_kspace.shape)

Number of files in root directory:  300
kspace shape:  torch.Size([1, 2, 9, 16, 32, 640])
kspace shape:  torch.Size([1, 2, 18, 16, 16, 640])
kspace shape:  torch.Size([1, 2, 9, 16, 32, 640])
kspace shape:  torch.Size([1, 2, 18, 16, 16, 640])
kspace shape:  torch.Size([1, 2, 6, 16, 48, 640])
kspace shape:  torch.Size([1, 2, 6, 16, 48, 640])
kspace shape:  torch.Size([1, 2, 18, 16, 16, 640])
kspace shape:  torch.Size([1, 2, 24, 16, 12, 640])
kspace shape:  torch.Size([1, 2, 6, 16, 48, 640])
kspace shape:  torch.Size([1, 2, 6, 16, 48, 640])
kspace shape:  torch.Size([1, 2, 9, 16, 32, 640])
kspace shape:  torch.Size([1, 2, 8, 16, 36, 640])
kspace shape:  torch.Size([1, 2, 24, 16, 12, 640])
kspace shape:  torch.Size([1, 2, 12, 16, 24, 640])
kspace shape:  torch.Size([1, 2, 6, 16, 48, 640])
kspace shape:  torch.Size([1, 2, 18, 16, 16, 640])
kspace shape:  torch.Size([1, 2, 12, 16, 24, 640])
kspace shape:  torch.Size([1, 2, 8, 16, 36, 640])
kspace shape:  torch.Size([1, 2, 6, 16, 48, 640])
ks

KeyboardInterrupt: 