
## Prepare dataset

- 데이터를 준비하는 과정입니다.
- Spleen dataset은 여기에서 받을 수 있습니다: <https://registry.opendata.aws/msd/>
- 해당 예시에서는 monai 코드를 참고하였습니다: https://github.com/aws-samples/amazon-sagemaker-medical-imaging-with-monai/blob/main/Segmentation/MONAI_BYOS_spleen_segmentation_3D_Demo.ipynb

------
Target: Spleen  
Modality: CT  
Size: 61 3D volumes (31 Training + 9 Validation + 1 Testing with label and 20 Testing without label)  
Source: Memorial Sloan Kettering Cancer Center  
Challenge: Large ranging foreground size


### 패키지 설치

- 패키지가 설치되지 않았다면 필요한 패키지를 설치합니다.

In [None]:
install = False

if install:
    !pip install  "monai[all]==0.8.0"
    !python -c "import monai" || pip install -q "monai-weekly[gdown, nibabel, tqdm, ignite]"
    !python -c "import matplotlib" || pip install -q matplotlib
    %matplotlib inline

In [None]:
import os
import shutil
import glob
import numpy as np
import json
import matplotlib.pyplot as plt
from pathlib import Path

from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImage,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    ScaleIntensityRanged,
    Spacingd,
    EnsureTyped,
    EnsureType,
    Invertd
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract

### 데이터 다운로드

- monai 데이터셋 샘플 경로에서 데이터를 다운로드 받습니다.
- 전처리를 진행한 후 샘플 이미지를 확인 해 봅니다.

In [None]:
# Archive location, expected MD5 checksum and local download target.
resource = "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar"
md5 = "410d4a301da4e5b2f6f86ec3ddba524e"
compressed_file = "./Task09_Spleen.tar"

# Root folder for everything this notebook writes locally.
data_dir = "Spleen3D"

# Test for the extracted dataset folder rather than `data_dir` itself:
# `data_dir` also receives the "processed" tree created later in this notebook
# (and may be left behind by an interrupted run), so its mere existence does
# not mean the dataset is available.
if not os.path.isdir(os.path.join(data_dir, "datasets", "Task09_Spleen")):
    download_and_extract(resource, compressed_file, f"{data_dir}/datasets", md5)

In [None]:
## transform the images through Compose
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),  ## keys include image and label with image first
        EnsureChannelFirstd(keys=["image", "label"]),
        Spacingd(keys=["image", "label"], pixdim=(
            1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        ScaleIntensityRanged(
            keys=["image"], a_min=-57, a_max=164,
            b_min=0.0, b_max=1.0, clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        EnsureTyped(keys=["image", "label"]),
    ]
)

In [None]:
# Collect image and label volumes. Both lists are sorted so that entry i of one
# list pairs with entry i of the other — this assumes images and labels share
# file names (TODO confirm against the dataset layout).
task_dir = os.path.join(data_dir, "datasets/Task09_Spleen")
train_images = sorted(glob.glob(os.path.join(task_dir, "imagesTr", "*.nii.gz")))
train_labels = sorted(glob.glob(os.path.join(task_dir, "labelsTr", "*.nii.gz")))

# zip() silently truncates to the shorter list, which would hide missing files
# and could mis-pair images with labels; an empty result would only fail later
# with an obscure error. Fail early with an explicit message instead.
if not train_images or len(train_images) != len(train_labels):
    raise ValueError(
        f"Expected a non-empty, equally sized set of images and labels under "
        f"{task_dir!r}, found {len(train_images)} images and "
        f"{len(train_labels)} labels"
    )

data_dicts = [
    {"image": image_name, "label": label_name}
    for image_name, label_name in zip(train_images, train_labels)
]
# Hold out the last pair as the single demo test case; the rest are for training.
train_files, test_demo_files = data_dicts[:-1], data_dicts[-1:]

In [None]:
# Run the held-out demo sample through the preprocessing pipeline.
check_ds = Dataset(data=test_demo_files, transform=val_transforms)
check_loader = DataLoader(check_ds, batch_size=1)
check_data = first(check_loader)

# [0][0] selects the first batch entry and then its first channel.
image, label = (check_data["image"][0][0], check_data["label"][0][0])
print(f"image shape: {image.shape}, label shape: {label.shape}")

# Plot slice 80 along the last axis; clamp to the last available slice so a
# volume with fewer than 81 slices does not raise an IndexError.
slice_idx = min(80, image.shape[-1] - 1)
plt.figure("check", (12, 6))
plt.subplot(1, 2, 1)
plt.title("image")
plt.imshow(image[:, :, slice_idx], cmap="gray")
plt.subplot(1, 2, 2)
plt.title("label")
plt.imshow(label[:, :, slice_idx])
plt.show()

### 데이터 저장 및 s3 업로드

- 데이터를 train, test (여기서는 1개만 사용)로 나누어 저장합니다. 이 노트북에서는 전처리를 적용한 결과가 아니라 원본 `.nii.gz` 파일을 그대로 복사합니다.


In [None]:
# Key prefix used for the S3 uploads later in this notebook.
prefix = "medical_segmentation/monai_spleen_3d"

# Local layout: <data_dir>/processed/{train,test}/{imagesTr,labelsTr}
processed_root = os.path.join(data_dir, "processed")
processed_train_path = os.path.join(processed_root, "train")
processed_test_path = os.path.join(processed_root, "test")

processed_train_images_path = os.path.join(processed_train_path, "imagesTr")
processed_train_labels_path = os.path.join(processed_train_path, "labelsTr")

processed_test_images_path = os.path.join(processed_test_path, "imagesTr")
processed_test_labels_path = os.path.join(processed_test_path, "labelsTr")

# Create the image/label folders of each split (parents included, no error if
# they already exist), then report the split root.
for split_root, split_subdirs in (
    (processed_train_path, (processed_train_images_path, processed_train_labels_path)),
    (processed_test_path, (processed_test_images_path, processed_test_labels_path)),
):
    for subdir in split_subdirs:
        os.makedirs(subdir, exist_ok=True)
    print("Directory '%s' created" % split_root)

In [None]:
## copy the training image/label files into the processed train folders
for pair in train_files:
    src_image, src_label = pair["image"], pair["label"]
    shutil.copy(src_image, processed_train_images_path)
    shutil.copy(src_label, processed_train_labels_path)

In [None]:
## copy the held-out demo test image/label files into the processed test folders
for pair in test_demo_files:
    src_image, src_label = pair["image"], pair["label"]
    shutil.copy(src_image, processed_test_images_path)
    shutil.copy(src_label, processed_test_labels_path)

### Local 에서 테스트 진행

- 앞의 과정에서 전처리 데이터를 저장해 놓았기 때문에, 일반적인 방식(IDE로 작업 후 CLI 등을 활용해서 training을 진행)으로 학습이 가능해졌습니다.
- `code` 디렉토리의 `run_local_train.sh`를 참고해서 local에서 학습이 잘 돌아가는지 확인해 봅니다.
- 실행 시 에러가 발생할 경우 해당 스크립트에 있는 이슈를 참고하면 됩니다.

### S3에 데이터 업로드
- 정상적으로 동작하는 것을 확인했다면, SageMaker managed training을 위해서 S3에 업로드해 놓도록 합니다.

In [None]:
import sagemaker

# Create a SageMaker session and look up its default bucket; both are used by
# the upload calls that follow.
sess = sagemaker.Session()
bucket = sess.default_bucket()

In [None]:
## upload training dataset to S3
s3_inputs = sess.upload_data(
    path=processed_train_path,
    key_prefix=f"{prefix}/train",
    bucket=bucket 
)

## upload testing dataset to S3
s3_demo_test = sess.upload_data(
    path=processed_test_images_path,
    key_prefix=f"{prefix}/test",
    bucket=bucket 
)

In [None]:
# Show the values returned by the two uploads, one per line.
for uploaded_location in (s3_inputs, s3_demo_test):
    print(uploaded_location)

In [None]:
# Persist the training upload result and the bucket name with IPython's %store
# magic so they can be restored later (e.g. with `%store -r` in another notebook).
%store s3_inputs
%store bucket

In [None]:
# Persist the S3 key prefix and the demo test objects as well.
# NOTE(review): check_ds / check_loader / check_data are a live Dataset,
# DataLoader and loaded batch; %store has to serialize them, so confirm that
# they serialize cleanly and are actually needed downstream.
%store prefix
%store test_demo_files
%store check_ds
%store check_loader
%store check_data