In [1]:
!pip install pytorch-lightning

Defaulting to user installation because normal site-packages is not writeable
Collecting pytorch-lightning
  Downloading pytorch_lightning-1.5.9-py3-none-any.whl (527 kB)
[K     |████████████████████████████████| 527 kB 10.2 MB/s eta 0:00:01
Collecting future>=0.17.1
  Downloading future-0.18.2.tar.gz (829 kB)
[K     |████████████████████████████████| 829 kB 26.1 MB/s eta 0:00:01     |██████████████████████▏         | 573 kB 26.1 MB/s eta 0:00:01
Collecting torchmetrics>=0.4.1
  Downloading torchmetrics-0.7.0-py3-none-any.whl (396 kB)
[K     |████████████████████████████████| 396 kB 28.4 MB/s eta 0:00:01
Collecting setuptools==59.5.0
  Downloading setuptools-59.5.0-py3-none-any.whl (952 kB)
[K     |████████████████████████████████| 952 kB 22.6 MB/s eta 0:00:01
Collecting fsspec[http]!=2021.06.0,>=2021.05.0
  Downloading fsspec-2022.1.0-py3-none-any.whl (133 kB)
[K     |████████████████████████████████| 133 kB 29.3 MB/s eta 0:00:01
Collecting pyDeprecate==0.3.1
  Downloading pyDepr

Installing collected packages: setuptools, multidict, frozenlist, yarl, idna-ssl, charset-normalizer, asynctest, async-timeout, aiosignal, pyDeprecate, fsspec, aiohttp, torchmetrics, future, pytorch-lightning
Successfully installed aiohttp-3.8.1 aiosignal-1.2.0 async-timeout-4.0.2 asynctest-0.13.0 charset-normalizer-2.0.10 frozenlist-1.2.0 fsspec-2022.1.0 future-0.18.2 idna-ssl-1.1.0 multidict-5.2.0 pyDeprecate-0.3.1 pytorch-lightning-1.5.9 setuptools-59.5.0 torchmetrics-0.7.0 yarl-1.7.2
You should consider upgrading via the '/usr/bin/python3 -m pip install --upgrade pip' command.[0m


In [1]:
import os
gpu_number = "2"
# Restrict this process to one physical GPU. This has to happen before CUDA is
# initialised, i.e. before torch is imported in a later cell.
os.environ.update(CUDA_VISIBLE_DEVICES=gpu_number)

def windowing_brain(img_array, channel=3, return_uint8=True, num_slices=32):
    """Sample slices from a CT volume and apply brain intensity windows.

    Args:
        img_array: volume whose slice axis is the LAST axis; it is transposed
            with ``(2, 0, 1)`` so slices become axis 0. Values are presumably
            Hounsfield units (the offsets below read as HU windows) - TODO confirm.
        channel: 1 -> one window covering [-40, 120];
            3 -> three windows covering [5, 55], [0, 80] and [-20, 180],
            stacked on a new leading axis.
        return_uint8: if True, scale the [0, 1] result to ``uint8`` [0, 255];
            otherwise return the normalized floats.
        num_slices: number of slices to sample (default 32).

    Returns:
        channel=1: array of shape (num_slices, H, W);
        channel=3: array of shape (3, num_slices, H, W).

    Raises:
        ValueError: if ``channel`` is neither 1 nor 3.
    """
    if channel not in (1, 3):
        # Previously any other value fell through un-normalized and overflowed uint8.
        raise ValueError(f"channel must be 1 or 3, got {channel}")

    img_array = img_array.transpose((2, 0, 1))
    depth = img_array.shape[0]
    # Draw distinct slices when the volume is deep enough; only sample with
    # replacement (i.e. repeat slices) when it has fewer than num_slices.
    slice_range = np.random.choice(depth, num_slices, replace=depth < num_slices)
    slice_range = np.sort(slice_range)  # keep the original slice order
    img_array = img_array[slice_range]

    if channel == 1:
        # Single window [-40, 120] -> [0, 1] (the offset was previously applied twice).
        img_array = np.clip(img_array + 40, 0, 160) / 160.
    else:
        dcm0 = np.clip(img_array - 5, 0, 50) / 50.     # [5, 55]
        dcm1 = np.clip(img_array, 0, 80) / 80.         # [0, 80]
        dcm2 = np.clip(img_array + 20, 0, 200) / 200.  # [-20, 180]
        img_array = np.stack([dcm0, dcm1, dcm2], 0)

    if return_uint8:
        return np.uint8(img_array * (2 ** 8 - 1))
    return img_array  # normalized to [0, 1]

In [2]:
import os
import numpy as np
from src.data_loader.classification import ClassifyDataset
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from glob import glob

# --- Data-loading configuration ---------------------------------------------
batch_size = 2
on_memory = False  # forwarded to the train/valid datasets
# Augmentation probability for the training set only ("argumentation"
# spelling matches the dataset's keyword argument).
argumentation_proba = 0.8
augmentation_policy_dict = dict(
    positional=True,
    noise=False,
    elastic=False,
    brightness_contrast=False,
    color=False,
    to_jpeg=False,
)
image_channel_dict = dict(image="rgb")
preprocess_input = windowing_brain
target_size = (512, 512)
interpolation = "bilinear"
class_mode = "binary"  # alternative: "categorical"
dtype = "float32"

data_common_path = "./datasets/1.normal_npy/"

# glob/os.listdir return entries in filesystem order, which is arbitrary;
# sort everything so dataset order and label indices are reproducible.
train_image_path_list = sorted(glob(f"{data_common_path}/train/*/*.npy"))
valid_image_path_list = sorted(glob(f"{data_common_path}/valid/*/*.npy"))
test_image_path_list = sorted(glob(f"{data_common_path}/test/*/*.npy"))

# The sub-folder names of train/ are the labels.
label_list = sorted(os.listdir(f"{data_common_path}/train/"))

label_to_index_dict = {label: index for index, label in enumerate(label_list)}
index_to_label_dict = dict(enumerate(label_list))

def label_policy(label):
    """Parse a label folder name ``"<age>_<gender>"`` into ``[age, gender]`` floats.

    Example: ``"45_1"`` -> ``[45.0, 1.0]``. Raises ``ValueError`` if the name
    does not split into exactly two parts on ``_`` or a part is not numeric.
    (For a purely categorical target, ``label_to_index_dict[label]`` is the
    alternative mapping; the old lambda doing that was immediately shadowed.)
    """
    age, gender = label.split("_")
    return [float(age), float(gender)]

# Keyword arguments shared by the train / valid / test datasets.
common_arg_dict = dict(
    label_policy=label_policy,
    argumentation_policy_dict=augmentation_policy_dict,
    image_channel_dict=image_channel_dict,
    preprocess_input=preprocess_input,
    target_size=target_size,
    interpolation=interpolation,
    class_mode=class_mode,
    dtype=dtype,
)

num_workers = min(batch_size // 2, 8)

# Only the training split is augmented; valid/test use probability 0.
train_dataset = ClassifyDataset(
    image_path_list=train_image_path_list,
    on_memory=on_memory,
    argumentation_proba=argumentation_proba,
    **common_arg_dict,
)
valid_dataset = ClassifyDataset(
    image_path_list=valid_image_path_list,
    on_memory=on_memory,
    argumentation_proba=0,
    **common_arg_dict,
)
test_dataset = ClassifyDataset(
    image_path_list=test_image_path_list,
    on_memory=False,
    argumentation_proba=0,
    **common_arg_dict,
)

# Train/valid loaders share batching and worker settings; only train shuffles.
loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
valid_loader = DataLoader(valid_dataset, shuffle=False, **loader_kwargs)
# Test: one volume per batch, loaded in the main process (DataLoader defaults).
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

Total data num 20162
Total data num 2520
Total data num 2442


In [3]:
import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger
import torch
from torch import nn

In [26]:
# F.cross_entropy
# FM.accuracy
class Classifier(pl.LightningModule):
    """LightningModule wrapping a backbone for joint age / gender prediction.

    ``forward`` applies a sigmoid to every backbone output. The loss and metric
    callables then split those outputs; in this notebook they read column 0 as
    age and column 1 as gender.

    Args:
        model: backbone ``nn.Module``. Presumably returns ``(batch, 2)`` here
            (built with ``n_classes=2``) - TODO confirm for other backbones.
        optimizer: optimizer *class* (e.g. ``torch.optim.Adam``), instantiated
            in ``configure_optimizers``.
        loss_fun: callable ``(outputs, y) -> scalar loss tensor``.
        metric: callable ``(outputs, y) -> (age_metric, gender_accuracy)``.
        lr: learning rate passed to ``optimizer``.

    NOTE(review): the sigmoid bounds the age output to (0, 1). Unless the
    dataset normalizes ages, an L1 loss against raw ages cannot be fit - confirm.
    """

    def __init__(self, model, optimizer, loss_fun, metric, lr):
        super().__init__()
        self.model = model
        self.optimizer = optimizer
        self.loss_fun = loss_fun
        self.metric = metric
        self.lr = lr
        self.act_layer = nn.Sigmoid()

    def forward(self, x):
        """Run the backbone and squash every output into (0, 1)."""
        return self.act_layer(self.model(x))

    def _shared_step(self, batch, stage, **log_kwargs):
        """Compute loss and metrics for one batch and log them with a ``stage`` prefix.

        ``log_kwargs`` are forwarded to ``log_dict``. Returns the loss tensor.
        """
        x, y = batch
        outputs = self(x)  # sigmoid outputs, not raw logits
        loss = self.loss_fun(outputs, y)
        age_metric, gender_acc = self.metric(outputs, y)
        # Historical key names are kept ('*_precision' holds the age metric) so
        # existing CSV logs and the 'val_loss' checkpoint monitor keep working.
        metrics = {
            f"{stage}_precision": age_metric,
            f"{stage}_acc": gender_acc,
            f"{stage}_loss": loss,
        }
        self.log_dict(metrics, **log_kwargs)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train", on_step=True, on_epoch=True)

    def validation_step(self, batch, batch_idx):
        self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        self._shared_step(batch, "test")

    def configure_optimizers(self):
        # Only the backbone has parameters; the sigmoid layer has none.
        return self.optimizer(self.model.parameters(), lr=self.lr)

In [5]:
from src.model.net_3d.resnet import ResNet, BasicBlock, Bottleneck
import torch
import pytorch_model_summary

# 3D ResNet with Bottleneck blocks [3, 4, 6, 3] (ResNet-50-style depth),
# 3 input channels (matching windowing_brain's default channel=3) and 2 outputs.
base_model = ResNet(
    Bottleneck,
    [3, 4, 6, 3],
    [64, 128, 256, 512],
    n_input_channels=3,
    n_classes=2,
).cuda()
print(
    pytorch_model_summary.summary(
        base_model,
        torch.zeros(1, 3, 32, 512, 512).cuda(),  # dummy (batch, channel, slice, H, W)
        show_input=True,
    )
)

-----------------------------------------------------------------------------------
           Layer (type)                Input Shape         Param #     Tr. Param #
               Conv3d-1       [1, 3, 32, 512, 512]          65,856          65,856
          BatchNorm3d-2      [1, 64, 32, 256, 256]             128             128
                 ReLU-3      [1, 64, 32, 256, 256]               0               0
            MaxPool3d-4      [1, 64, 32, 256, 256]               0               0
           Bottleneck-5      [1, 64, 16, 128, 128]         148,736         148,736
           Bottleneck-6     [1, 256, 16, 128, 128]         144,128         144,128
           Bottleneck-7     [1, 256, 16, 128, 128]         144,128         144,128
           Bottleneck-8     [1, 256, 16, 128, 128]         674,304         674,304
           Bottleneck-9        [1, 512, 8, 64, 64]         574,976         574,976
          Bottleneck-10        [1, 512, 8, 64, 64]         574,976         574,976
   

In [27]:
from timm.models import nf_resnet101 
import torchmetrics
from pytorch_lightning.callbacks import ModelCheckpoint


# Optimizer *class*, not an instance: Classifier.configure_optimizers builds it
# from the model parameters and the learning rate.
optimizer = torch.optim.Adam


def loss_fun(logit, y):
    """Joint loss: L1 on the age column plus binary cross-entropy on the gender column.

    Args:
        logit: model outputs after the sigmoid in ``Classifier.forward`` (so
            probabilities, despite the name); ``[..., 0]`` is age and
            ``[..., 1]`` is gender.
        y: pair ``(age, gender)`` of target tensors, e.g. collated from the
            ``[age, gender]`` lists produced by ``label_policy``.

    Returns:
        Scalar tensor ``l1(age) + bce(gender)`` with mean reduction.
    """
    logit_age = logit[..., 0]
    logit_gender = logit[..., 1]
    # Cast targets to the prediction dtype. BCE requires matching dtypes, and
    # collated Python floats may arrive as float64.
    y_age = y[0].to(logit_age.dtype)
    y_gender = y[1].to(logit_gender.dtype)
    # Functional losses: no module is built per call and there is no
    # hard-coded .cuda(), so this runs on whatever device the inputs are on.
    age_loss = nn.functional.l1_loss(logit_age, y_age)
    gender_loss = nn.functional.binary_cross_entropy(logit_gender, y_gender)
    return age_loss + gender_loss

def metric(logit, y):
    """Batch metrics: mean absolute error on age and accuracy on gender.

    Args:
        logit: sigmoid outputs; ``[..., 0]`` is age and ``[..., 1]`` is the
            gender probability.
        y: pair ``(age, gender)`` of target tensors.

    Returns:
        ``(age_mae, gender_acc)`` as scalar tensors without gradient.
    """
    logit_age = logit[..., 0]
    logit_gender = logit[..., 1]
    y_age = y[0]
    y_gender = y[1].int()
    # This is the MAE plus binary accuracy at threshold 0.5 (preds >= 0.5) that
    # the torchmetrics modules produced. Computing it directly avoids building
    # fresh metric modules and moving them to CUDA on every batch. Metrics are
    # never back-propagated.
    with torch.no_grad():
        age_mae = (logit_age - y_age).abs().mean()
        gender_acc = ((logit_gender >= 0.5).int() == y_gender).float().mean()
    return age_mae, gender_acc

# Metrics CSVs go under logs/brain_classification/; checkpoints go directly into logs/.
logger = CSVLogger("logs", name="brain_classification")

# Keep the 5 checkpoints with the best (lowest, the default mode) validation loss.
checkpoint_callback = ModelCheckpoint(
    dirpath='logs/',
    filename='epoch{epoch:02d}-val_loss{val_loss:.2f}',
    monitor='val_loss',
    save_top_k=5,
    auto_insert_metric_name=False,
)

learning_rate = 1e-6
model = Classifier(base_model, optimizer, loss_fun, metric, learning_rate).cuda()

# To resume from a checkpoint instead, replace the line above with
# (lr is required because __init__ does not save hyperparameters):
# model = Classifier.load_from_checkpoint(
#     checkpoint_path="./logs/epoch03-val_loss0.14.ckpt",
#     model=base_model, optimizer=optimizer, loss_fun=loss_fun,
#     metric=metric, lr=learning_rate,
# )

In [28]:
# A single GPU; which physical device it is comes from CUDA_VISIBLE_DEVICES set earlier.
trainer = pl.Trainer(
    gpus=1,
    logger=logger,
    callbacks=[checkpoint_callback],
)
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=valid_loader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [2]

  | Name      | Type    | Params
--------------------------------------
0 | model     | ResNet  | 46.2 M
1 | act_layer | Sigmoid | 0     
--------------------------------------
46.2 M    Trainable params
0         Non-trainable params
46.2 M    Total params
184.812   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

# Baseline 성능
- sex accuracy: 98 %
- age: L1 distance (mean absolute error) < 3 years
    - age를 잘 맞추는 게 임상적으로 더 의미 있음. (Predicting age accurately is the more clinically meaningful goal.)
    
- modeling
    - Conv3D vs. CNN + Transformer
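        - Conv3D: 학습 결과 오차가 3살 이내라면 Conv3D 사용 (구조가 단순하므로) — if the trained Conv3D reaches < 3 years error, keep it because it is simpler.
        - 그렇지 못하면 -> CNN + Transformer (otherwise, switch to CNN + Transformer)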
- Augmentations
    - geometric: flip, rotation 
    - blur, ...?
    
- 32장 기준으로 만들기 (build every input volume from 32 slices)
    - 32장 
    - 32장 이상 -> 일정한 간격으로 32장 골라 쓰기 (more than 32 slices: pick 32 at regular intervals)

# Unused Code

In [7]:
import os
import random
import math
from tqdm import tqdm
from glob import glob
import shutil
data_common_path = "./datasets/1.normal_npy/"
split_names = ("train", "valid", "test")
split_seed = 0  # fixed so that re-running the split is reproducible

for split_name in split_names:
    os.makedirs(f"{data_common_path}/{split_name}", exist_ok=True)

# Source folders are every sub-directory except the split folders themselves.
# Filter by name: slicing the glob result with [:-3] assumed the split folders
# come last, but glob order follows the filesystem and is arbitrary.
data_path_list = sorted(
    path
    for path in glob(f"{data_common_path}/*")
    if os.path.isdir(path) and os.path.basename(path) not in split_names
)

split_rng = random.Random(split_seed)
for data_path in tqdm(data_path_list):
    source = os.path.basename(data_path)
    # Sort before shuffling so the seeded shuffle does not depend on glob order.
    npy_path_list = sorted(glob(f"{data_path}/*.npy"))
    split_rng.shuffle(npy_path_list)
    npy_num = len(npy_path_list)
    train_num = math.ceil(npy_num * 0.8)  # first 80 % -> train
    valid_num = math.ceil(npy_num * 0.9)  # next ~10 % -> valid, the rest -> test
    for npy_index, npy_path in enumerate(npy_path_list):
        if npy_index < train_num:
            target = "train"
        elif npy_index < valid_num:
            target = "valid"
        else:
            target = "test"
        target_dir = f"{data_common_path}/{target}/{source}"
        os.makedirs(target_dir, exist_ok=True)
        shutil.move(npy_path, f"{target_dir}/{os.path.basename(npy_path)}")



0it [00:00, ?it/s][A


In [13]:
temp = np.zeros((48, 512, 512))
# 32 sorted slice indices out of 48, drawn with replacement (np.random.choice's
# default), so the same index can appear more than once.
slice_range = np.sort(np.random.choice(np.arange(48), 32))

temp[slice_range].shape

(32, 512, 512)

In [None]:
# Display the sampled indices; repeats can occur because np.random.choice
# samples with replacement unless replace=False is passed.
slice_range