In [1]:
import os
import math
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
import torchvision.models as tm
import torchvision.transforms.functional as tf
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torch.utils.data import DataLoader
from PIL import Image
from tqdm import tqdm
%matplotlib inline
# Make the project-local packages imported in the next cell (`core`, `nn_tools`)
# findable. The paths are relative to the notebook's working directory.
# Guarded so re-running this cell does not keep appending duplicate entries.
for _extra_path in ('../code', '../../nn_tools/'):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)

In [2]:
from core.data_utils import read_dcms, get_spacing, read_annotation, SPINAL_DISC_ID, SPINAL_VERTEBRA_ID, rotate_point
from core.visilization import visilize_coord, visilize_distmap
from core.key_point import KeyPointAcc, KeyPointDataLoader, KeyPointModel, NullLoss, SpinalModelBase
from core.key_point import KeyPointBCELossV2, SpinalModel, KeyPointModelV2, KeyPointBCELoss
from core.structure import DICOM, Study, construct_studies
from core.disease import DiseaseModel, DisDataLoader, Evaluator
from nn_tools import torch_utils

In [3]:
# Load the training studies (150 in the run log) and the validation studies
# (51 in the run log), each paired with its annotation file.
# NOTE(review): validation images are read from '../data/train/' while the
# training set uses '../data/lumbar_train150' -- confirm the directory naming.
train_studies, train_annotation = construct_studies(
    '../data/lumbar_train150',
    '../data/lumbar_train150_annotation.json',
)
valid_studies, valid_annotation = construct_studies(
    '../data/train/',
    '../data/lumbar_train51_annotation.json',
)

100%|##########| 150/150 [00:57<00:00,  2.60it/s]
  0%|          | 0/51 [00:00<?, ?it/s]

{'T11-T12': 1}


100%|##########| 51/51 [00:29<00:00,  1.71it/s]


In [4]:
# Training loader for the disease model.
# NOTE(review): prob_rotate=1 / max_angel=180 presumably mean "always rotate, by
# up to 180 degrees" -- confirm in DisDataLoader. `max_angel` is the keyword
# spelling DisDataLoader itself expects, so it is kept as-is.
train_dataloader = DisDataLoader(
    train_studies,
    train_annotation,
    batch_size=4,
    sagittal_size=[512, 512],
    transverse_size=[256, 256],
    k_nearest=1,
    num_workers=2,
    num_rep=1,
    prob_rotate=1,
    max_angel=180,
)

In [5]:
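# Disabled sanity checks, kept for reference (uncomment to re-run): every study's
# middle T2 sagittal frame has an annotation, studies without a T2 transverse series
# are listed, and the train loader can be iterated end to end.
# for study_uid, study in train_studies.items():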
#     frame = study.t2_sagittal_middle_frame
#     assert (study_uid, frame.series_uid, frame.instance_uid) in train_annotation

# # for study_uid, study in valid_studies.items():
# #     frame = study.t2_sagittal_middle_frame
# #     assert (study_uid, frame.series_uid, frame.instance_uid) in valid_annotation

# for study_uid, study in train_studies.items():
#     if study.t2_transverse_uid is None:
#         print(study_uid)

# for data, label in tqdm(train_dataloader, ascii=True):
#     pass

In [6]:
# Build the disease model on top of a previously trained key-point model and
# move it to GPU 0.
# map_location='cuda:0' restores the checkpoint straight onto the device the
# model is moved to below. Without it, torch.load puts tensors back on whatever
# device they were saved from, which fails if that GPU index is not available.
# NOTE(review): torch.load unpickles a full model object -- only load trusted files.
# NOTE(review): torch >= 2.6 defaults to weights_only=True, which rejects pickled
# nn.Module objects; pass weights_only=False there.
# NOTE(review): k_nearest=2 here but the DisDataLoader cell uses k_nearest=1 --
# confirm this mismatch is intended.
dis_model = DiseaseModel(
    torch.load('../models/size512_rotate1_180_AdamW_1e-5.kp_model', map_location='cuda:0'),
    k_nearest=2, sagittal_size=[512, 512], transverse_size=[256, 256], agg_method='avg'
).cuda(0)
dis_model



DiseaseModel(
  (kp_model): KeyPointModel(
    (backbone): BackboneWithFPN(
      (body): IntermediateLayerGetter(
        (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (bn1): FrozenBatchNorm2d()
        (relu): ReLU(inplace=True)
        (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        (layer1): Sequential(
          (0): Bottleneck(
            (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn1): FrozenBatchNorm2d()
            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): FrozenBatchNorm2d()
            (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn3): FrozenBatchNorm2d()
            (relu): ReLU(inplace=True)
            (downsample): Sequential(
              (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): FrozenBat

In [7]:
# Fine-tune the disease model. valid_data=None, so validation presumably happens
# only through `evaluator` (its disease f1 / key point accuracy appear in the log).
steps_per_epoch = len(train_dataloader)  # 38 batches per pass in the run log
evaluator = Evaluator(dis_model, valid_studies, '../data/lumbar_train51_annotation.json')
optimizer = torch.optim.AdamW(dis_model.parameters(), lr=1e-6)
max_step = 50 * steps_per_epoch  # 50 passes over the training loader
# Stored as `fit_result`, not `result`: the prediction cell later assigns its own
# `result` list, which would otherwise silently discard the training output.
fit_result = torch_utils.fit(
    dis_model,
    train_data=train_dataloader,
    valid_data=None,
    optimizer=optimizer,
    max_step=max_step,
    loss=NullLoss(),
    metrics=[NullLoss()],
    # NOTE(review): the log reports "NullLoss(higher is better)" as the early-stopping
    # metric. If that value is a loss, lower should be better -- confirm which metric
    # fit actually monitors when valid_data is None.
    is_higher_better=True,
    evaluate_per_steps=steps_per_epoch,  # evaluate once per pass over the loader
    evaluate_fn=evaluator,
    # checkpoint_dir='../models',
    early_stopping=5 * steps_per_epoch,  # presumably patience, measured in steps
)

using NullLoss as training loss, using NullLoss(higher is better) as early stopping metric


100%|##########| 38/38 [00:12<00:00,  2.96it/s]
100%|##########| 51/51 [00:04<00:00, 11.56it/s]


step 38 train NullLoss: 0.81514573097229; valid disease f1: 0.29345607461821155, key point accuracy: 0.8467023172905526, 


100%|##########| 38/38 [00:12<00:00,  3.00it/s]
100%|##########| 51/51 [00:04<00:00, 12.17it/s]


step 76 train NullLoss: 0.8092933297157288; valid disease f1: 0.2934561001556939, key point accuracy: 0.8556149732620321, 


100%|##########| 38/38 [00:12<00:00,  2.99it/s]
100%|##########| 51/51 [00:04<00:00, 12.01it/s]


step 114 train NullLoss: 0.8006364703178406; valid disease f1: 0.2987552101216311, key point accuracy: 0.8591800356506238, 


100%|##########| 38/38 [00:12<00:00,  2.98it/s]
100%|##########| 51/51 [00:04<00:00, 12.04it/s]


step 152 train NullLoss: 0.8206588625907898; valid disease f1: 0.29699765786360155, key point accuracy: 0.8449197860962567, 


100%|##########| 38/38 [00:12<00:00,  2.97it/s]
100%|##########| 51/51 [00:04<00:00, 12.03it/s]


step 190 train NullLoss: 0.8195558190345764; valid disease f1: 0.28535543726519985, key point accuracy: 0.8074866310160428, 


100%|##########| 38/38 [00:12<00:00,  2.97it/s]
100%|##########| 51/51 [00:04<00:00, 12.04it/s]


step 228 train NullLoss: 0.8390138149261475; valid disease f1: 0.277997372809638, key point accuracy: 0.7878787878787878, 


100%|##########| 38/38 [00:12<00:00,  2.97it/s]
100%|##########| 51/51 [00:04<00:00, 11.89it/s]


step 266 train NullLoss: 0.8093878030776978; valid disease f1: 0.3005039143076388, key point accuracy: 0.8431372549019608, 


 42%|####2     | 16/38 [00:05<00:07,  2.78it/s]


KeyboardInterrupt: 

In [9]:
# Run inference on every validation study, in valid_studies order.
# eval() returns the module itself, so it is called once up front rather than on
# every iteration. no_grad() avoids building autograd graphs for pure inference.
# NOTE(review): the positional `True` is interpreted by DiseaseModel's forward;
# judging by the displayed output it requests dict-formatted predictions -- confirm.
dis_model.eval()
with torch.no_grad():
    result = [dis_model(study, True) for study in valid_studies.values()]

In [10]:
# Display the predictions: a list with one dict per validation study
# ('studyUid' plus per-instance predicted points).
result

[{'studyUid': '1.2.3.4.5.75760356.2089',
  'data': [{'instanceUid': '1.2.3.4.5.75760356.2089.5731.5778357',
    'seriesUid': '1.2.3.4.5.75760356.2089.5731',
    'annotation': [{'point': [{'coord': [145, 71],
        'tag': {'identification': 'L1', 'vertebra': 'v2'},
        'zIndex': 4},
       {'coord': [140, 96],
        'tag': {'identification': 'L2', 'vertebra': 'v2'},
        'zIndex': 4},
       {'coord': [136, 120],
        'tag': {'identification': 'L3', 'vertebra': 'v2'},
        'zIndex': 4},
       {'coord': [134, 149],
        'tag': {'identification': 'L4', 'vertebra': 'v2'},
        'zIndex': 4},
       {'coord': [138, 174],
        'tag': {'identification': 'L5', 'vertebra': 'v2'},
        'zIndex': 4},
       {'coord': [148, 60],
        'tag': {'identification': 'T12-L1', 'disc': 'v2'},
        'zIndex': 4},
       {'coord': [142, 84],
        'tag': {'identification': 'L1-L2', 'disc': 'v2'},
        'zIndex': 4},
       {'coord': [137, 110],
        'tag': {'identific