In [1]:
import json
import os
import math
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
import torchvision.models as tm
import torchvision.transforms.functional as tf
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torch.utils.data import DataLoader
from PIL import Image
from tqdm import tqdm
%matplotlib inline
sys.path.append('../code')
sys.path.append('../../nn_tools/')

In [2]:
from core.data_utils import read_dcms, get_spacing, read_annotation, SPINAL_DISC_ID, SPINAL_VERTEBRA_ID, rotate_point
from core.visilization import visilize_coord, visilize_distmap, visilize_annotation
from core.key_point import KeyPointAcc, KeyPointDataLoader, KeyPointModel, NullLoss, SpinalModelBase
from core.key_point import KeyPointBCELossV2, SpinalModel, KeyPointBCELoss, KeyPointModelV2
from core.structure import DICOM, Study, construct_studies
from core.disease import DiseaseModelBase, DiseaseModel, DisDataLoader, Evaluator
from nn_tools import torch_utils

In [3]:
train_studies, train_annotation, train_counter = construct_studies(
    '../data/lumbar_train150', '../data/lumbar_train150_annotation.json')
valid_studies, valid_annotation, valid_counter = construct_studies(
    '../data/train/', '../data/lumbar_train51_annotation.json')

100%|##########| 150/150 [01:22<00:00,  1.82it/s]
  0%|          | 0/51 [00:00<?, ?it/s]

{'T11-T12': 1}


100%|##########| 51/51 [00:26<00:00,  1.91it/s]


In [4]:
# vertebra_labels, disc_labels = [], []
# for k, v in train_annotation.items():
#     vertebra_labels.append(v[0])
#     disc_labels.append(v[1])
# vertebra_labels = torch.cat(vertebra_labels)
# disc_labels = torch.cat(disc_labels)

# from collections import Counter
# vertebra_counter = Counter(vertebra_labels[:, -1].numpy().tolist())
# vertebra_counter = torch.tensor([vertebra_counter[i] for i in range(len(vertebra_counter))], dtype=torch.float)
# vertebra_weight = vertebra_counter.mean() / vertebra_counter
# print(vertebra_weight)

# disc_counter = Counter(disc_labels[:, -1].numpy().tolist())
# disc_counter = torch.tensor([disc_counter[i] for i in range(len(disc_counter))], dtype=torch.float)
# disc_weight = disc_counter.mean() / disc_counter
# print(disc_weight)

In [5]:
# {k: len(v) for k, v in train_counter.items()}

In [6]:
# {k: len(v) for k, v in valid_counter.items()}

In [14]:
train_dataloader = DisDataLoader(train_studies, train_annotation, batch_size=8, sagittal_size=[512, 512],
                                 transverse_size=[256, 256], k_nearest=1, num_workers=4, num_rep=20,
                                 prob_rotate=1, max_angel=180)

In [15]:
# for study_uid, study in train_studies.items():
#     frame = study.t2_sagittal_middle_frame
#     assert (study_uid, frame.series_uid, frame.instance_uid) in train_annotation

# for study_uid, study in valid_studies.items():
#     frame = study.t2_sagittal_middle_frame
#     assert (study_uid, frame.series_uid, frame.instance_uid) in valid_annotation

# for study_uid, study in train_studies.items():
#     if study.t2_transverse_uid is None:
#         print(study_uid)

# for data, label in tqdm(train_dataloader, ascii=True):
#     pass

In [16]:
train_images = {}
for study_uid, study in train_studies.items():
    frame = study.t2_sagittal_middle_frame
    train_images[(study_uid, frame.series_uid, frame.instance_uid)] = frame.image

In [17]:
backbone = resnet_fpn_backbone('resnet50', True)
spinal_model = SpinalModel(train_images, train_annotation,
                           num_candidates=128, num_selected_templates=8,
                           max_translation=0.05, scale_range=[0.9, 1.1], max_angel=10)
kp_model = KeyPointModelV2(backbone, len(SPINAL_VERTEBRA_ID), len(SPINAL_DISC_ID),
                           pixel_mean=torch.tensor(0.5), pixel_std=torch.tensor(1), dropout=0,
                           loss=KeyPointBCELossV2(lamb=1), spinal_model=spinal_model, loss_scaler=100,
                           num_cascades=2
                           ).cuda(0)
dis_model = DiseaseModel(
    kp_model, sagittal_size=[512, 512], loss_scaler=0.01, use_kp_loss=True, share_backbone=True
).cuda(0)
# dis_model = torch.nn.DataParallel(dis_model, device_ids=[0, 1])
dis_model

DiseaseModel(
  (backbone): KeyPointModelV2(
    (backbone): BackboneWithFPN(
      (body): IntermediateLayerGetter(
        (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (bn1): FrozenBatchNorm2d()
        (relu): ReLU(inplace=True)
        (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        (layer1): Sequential(
          (0): Bottleneck(
            (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn1): FrozenBatchNorm2d()
            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): FrozenBatchNorm2d()
            (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (bn3): FrozenBatchNorm2d()
            (relu): ReLU(inplace=True)
            (downsample): Sequential(
              (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (1): FrozenB

In [11]:
# small_studies, train_annotation, train_counter = construct_studies(
#     '../data/small/', '../data/lumbar_train150_annotation.json')
# small_dataloader = DisDataLoader(small_studies, train_annotation, batch_size=8, sagittal_size=[512, 512],
#                                  transverse_size=[256, 256], k_nearest=1, num_workers=3, num_rep=10,
#                                  prob_rotate=1, max_angel=180)
# batch_data, batch_label = next(iter(small_dataloader))

# visilize_coord(batch_data[0][0], batch_data[2][0])

# batch_pred = dis_model(*batch_data)

# tf.to_pil_image(batch_pred[1][0])

In [12]:
# study = Study('../data/lumbar_train150/study10/')
# pred_coord, *_ = dis_model.eval().module(study)
# human_coord = study.t2_sagittal_middle_frame.pixel_coord2human_coord(pred_coord.cpu())
# human_coord
# study.t2_transverse.point_distance(human_coord)
# [dicom.image_position for dicom in study.t2_transverse]
# study.t2_transverse.k_nearest(human_coord, 1, 8)
# k_nearest = study.t2_transverse_k_nearest(pred_coord.cpu(), 2, [256, 256], 8)
# k_nearest.shape
# tf.to_pil_image(k_nearest[8][0])

In [None]:
evaluator = Evaluator(
    dis_model, valid_studies, '../data/lumbar_train51_annotation.json', num_rep=20, max_dist=6,
#     metric='macro precision'
)
optimizer = torch.optim.AdamW(dis_model.parameters(), lr=1e-5)
max_step = 50*len(train_dataloader)
fit_result = torch_utils.fit(
    dis_model,
    train_data=train_dataloader,
    valid_data=None,
    optimizer=optimizer,
    max_step=max_step,
    loss=NullLoss(),
    metrics=[NullLoss()],
    is_higher_better=True,
    evaluate_per_steps=len(train_dataloader),
    evaluate_fn=evaluator,
    checkpoint_dir='../models',
#     early_stopping=5*len(train_dataloader),
)

using NullLoss as training loss, using NullLoss(higher is better) as early stopping metric


100%|##########| 375/375 [01:39<00:00,  3.77it/s]
100%|##########| 1020/1020 [00:58<00:00, 17.44it/s]


step 375 train NullLoss: 0.3456895649433136
valid macro f1: 0.11663050384249678
valid micro f1: 0.10184402905461612
valid avg key point acc: 0.159950784663046
valid macro precision: 0.3769636241674067
valid disc v1 precision: 0.17266189405310114
valid disc v2 precision: 0.290456448924776
valid disc v3 precision: 0.3000000285714245
valid disc v4 precision: 0.5164835128607664
valid disc v5 precision: 0.277310961796477
valid vertebra v1 precision: 0.5999999789473729
valid vertebra v2 precision: 0.4818325440179291


100%|##########| 375/375 [01:41<00:00,  3.70it/s]
100%|##########| 1020/1020 [00:59<00:00, 17.28it/s]


step 750 train NullLoss: 0.11682767421007156
valid macro f1: 0.10651488892341233
valid micro f1: 0.12390269252631435
valid avg key point acc: 0.14873584327171624
valid macro precision: 0.37436403564142984
valid disc v1 precision: 0.34482760404280416
valid disc v2 precision: 0.19913422518318397
valid disc v3 precision: 0.244186076257433
valid disc v4 precision: 0.4571428693877516
valid disc v5 precision: 0.04255328881844918
valid vertebra v1 precision: 0.5223880563599916
valid vertebra v2 precision: 0.8103161294403951


100%|##########| 375/375 [01:41<00:00,  3.69it/s]
100%|##########| 1020/1020 [00:57<00:00, 17.61it/s]


step 1125 train NullLoss: 0.08557390421628952
valid macro f1: 0.2037840898055412
valid micro f1: 0.2343221397741601
valid avg key point acc: 0.2630228812916918
valid macro precision: 0.4311544587635584
valid disc v1 precision: 0.4063745057221313
valid disc v2 precision: 0.4157782551906928
valid disc v3 precision: 0.273062747375443
valid disc v4 precision: 0.5849056443574255
valid disc v5 precision: 0.017751536360764927
valid vertebra v1 precision: 0.5634517702079421
valid vertebra v2 precision: 0.756756752130509


100%|##########| 375/375 [01:41<00:00,  3.70it/s]
100%|##########| 1020/1020 [00:58<00:00, 17.38it/s]


step 1500 train NullLoss: 0.07829757779836655
valid macro f1: 0.1848568152398223
valid micro f1: 0.2211652856182101
valid avg key point acc: 0.23787586551488596
valid macro precision: 0.4324737644247196
valid disc v1 precision: 0.5985130074902227
valid disc v2 precision: 0.46300716167030254
valid disc v3 precision: 0.1805054382306543
valid disc v4 precision: 0.3863636621900768
valid disc v5 precision: 0.06944450424381884
valid vertebra v1 precision: 0.688311663855628
valid vertebra v2 precision: 0.641170913292334


100%|##########| 375/375 [01:39<00:00,  3.77it/s]
100%|##########| 1020/1020 [00:59<00:00, 17.21it/s]


step 1875 train NullLoss: 0.06678938120603561
valid macro f1: 0.21801682853135992
valid micro f1: 0.23348815827940772
valid avg key point acc: 0.3003010984609384
valid macro precision: 0.421848249117435
valid disc v1 precision: 0.48059149794486145
valid disc v2 precision: 0.11478600720677017
valid disc v3 precision: 0.2843750134765617
valid disc v4 precision: 0.6277372076296047
valid disc v5 precision: 0.12376241348886995
valid vertebra v1 precision: 0.6449999855000015
valid vertebra v2 precision: 0.6766856185753758


100%|##########| 375/375 [01:40<00:00,  3.73it/s]
100%|##########| 1020/1020 [00:58<00:00, 17.42it/s]


step 2250 train NullLoss: 0.060649801045656204
valid macro f1: 0.2541467784872816
valid micro f1: 0.28256544167164216
valid avg key point acc: 0.3360970841924703
valid macro precision: 0.4414392490482601
valid disc v1 precision: 0.5729166644965279
valid disc v2 precision: 0.3234265796004693
valid disc v3 precision: 0.047120442559139136
valid disc v4 precision: 0.6136363464187354
valid disc v5 precision: 0.2500000240384592
valid vertebra v1 precision: 0.6465517115041628
valid vertebra v2 precision: 0.6364229747203267


100%|##########| 375/375 [01:41<00:00,  3.68it/s]
100%|##########| 1020/1020 [00:57<00:00, 17.70it/s]


step 2625 train NullLoss: 0.059307653456926346
valid macro f1: 0.20352458156473294
valid micro f1: 0.22257426357855098
valid avg key point acc: 0.29890934080669646
valid macro precision: 0.3842788624041819
valid disc v1 precision: 0.5369003676420528
valid disc v2 precision: 0.20445345326099382
valid disc v3 precision: 0.038690503649374784
valid disc v4 precision: 0.507812498779297
valid disc v5 precision: 0.2331288671007525
valid vertebra v1 precision: 0.5450980356785854
valid vertebra v2 precision: 0.6238683107182171


100%|##########| 375/375 [01:40<00:00,  3.72it/s]
Traceback (most recent call last):
  File "/home/cxt/anaconda3/envs/torch15/lib/python3.8/multiprocessing/queues.py", line 245, in _feed
    send_bytes(obj)
  File "/home/cxt/anaconda3/envs/torch15/lib/python3.8/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/home/cxt/anaconda3/envs/torch15/lib/python3.8/multiprocessing/connection.py", line 411, in _send_bytes
    self._send(header + buf)
  File "/home/cxt/anaconda3/envs/torch15/lib/python3.8/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
OSError: [Errno 9] Bad file descriptor
Traceback (most recent call last):
  File "/home/cxt/anaconda3/envs/torch15/lib/python3.8/multiprocessing/queues.py", line 235, in _feed
    close()
  File "/home/cxt/anaconda3/envs/torch15/lib/python3.8/multiprocessing/connection.py", line 177, in close
    self._close()
  File "/home/cxt/anaconda3/envs/torch15/lib

step 3000 train NullLoss: 0.053144317120313644
valid macro f1: 0.26902734912373144
valid micro f1: 0.2961275669767867
valid avg key point acc: 0.3533634501407075
valid macro precision: 0.44675514873089195
valid disc v1 precision: 0.6294765804551906
valid disc v2 precision: 0.247818507929546
valid disc v3 precision: 0.14218011174501843
valid disc v4 precision: 0.5151515128558314
valid disc v5 precision: 0.23076926035502635
valid vertebra v1 precision: 0.7900355665455113
valid vertebra v2 precision: 0.5718545012301193


100%|##########| 375/375 [01:39<00:00,  3.77it/s]
100%|##########| 1020/1020 [00:59<00:00, 17.25it/s]


step 3375 train NullLoss: 0.05078485608100891
valid macro f1: 0.27897641563776693
valid micro f1: 0.3267467040870124
valid avg key point acc: 0.37439219379463967
valid macro precision: 0.4449820411963089
valid disc v1 precision: 0.6077694208579093
valid disc v2 precision: 0.2834645737491473
valid disc v3 precision: 0.5172413785176377
valid disc v4 precision: 0.2686567509467536
valid disc v5 precision: 0.08715600117834851
valid vertebra v1 precision: 0.7368420886426604
valid vertebra v2 precision: 0.6137440744817053


100%|##########| 375/375 [01:40<00:00,  3.72it/s]
100%|##########| 1020/1020 [00:58<00:00, 17.45it/s]


step 3750 train NullLoss: 0.04977820813655853
valid macro f1: 0.3303986313022262
valid micro f1: 0.3939307207056541
valid avg key point acc: 0.4236565568806844
valid macro precision: 0.47213611452656146
valid disc v1 precision: 0.5372781056265538
valid disc v2 precision: 0.3894878735987096
valid disc v3 precision: 0.4130879381150128
valid disc v4 precision: 0.3949044719866915
valid disc v5 precision: 0.2622222433580228
valid vertebra v1 precision: 0.5236686376527433
valid vertebra v2 precision: 0.7843035313481962


100%|##########| 375/375 [01:40<00:00,  3.72it/s]
100%|##########| 1020/1020 [00:57<00:00, 17.80it/s]


step 4125 train NullLoss: 0.04714526981115341
valid macro f1: 0.34386695961157
valid micro f1: 0.4135737026652078
valid avg key point acc: 0.46656877393888363
valid macro precision: 0.4643366540622344
valid disc v1 precision: 0.6077812805934536
valid disc v2 precision: 0.16775885536674037
valid disc v3 precision: 0.18525521038732667
valid disc v4 precision: 0.602150526650481
valid disc v5 precision: 0.2629482260599023
valid vertebra v1 precision: 0.5421994863325071
valid vertebra v2 precision: 0.8822629930452295


100%|##########| 375/375 [01:40<00:00,  3.74it/s]
100%|##########| 1020/1020 [00:58<00:00, 17.47it/s]


step 4500 train NullLoss: 0.045596715062856674
valid macro f1: 0.31970403482447346
valid micro f1: 0.3618046458384217
valid avg key point acc: 0.42213283407584395
valid macro precision: 0.4602354489411279
valid disc v1 precision: 0.710526310971938
valid disc v2 precision: 0.5405007351840726
valid disc v3 precision: 0.1670146276820615
valid disc v4 precision: 0.38888890123456654
valid disc v5 precision: 0.0796460548985792
valid vertebra v1 precision: 0.7242424106519751
valid vertebra v2 precision: 0.6108291019647027


100%|##########| 375/375 [01:39<00:00,  3.78it/s]
100%|##########| 1020/1020 [00:58<00:00, 17.30it/s]


step 4875 train NullLoss: 0.044691555202007294
valid macro f1: 0.3518974231412559
valid micro f1: 0.424821003864251
valid avg key point acc: 0.47372614678142594
valid macro precision: 0.4623134867622703
valid disc v1 precision: 0.6102175948870843
valid disc v2 precision: 0.2574385573754578
valid disc v3 precision: 0.4477611959790599
valid disc v4 precision: 0.3945946059897723
valid disc v5 precision: 0.08914731867075049
valid vertebra v1 precision: 0.6934097310366917
valid vertebra v2 precision: 0.7436254033970755


100%|##########| 375/375 [01:40<00:00,  3.71it/s]
 56%|#####6    | 574/1020 [00:32<00:26, 17.11it/s]

In [11]:
dis_model.kp_model = torch.load('../models/2020070102.kp_model_v2', map_location='cuda:0')
evaluator = Evaluator(dis_model, valid_studies, '../data/lumbar_train51_annotation.json',
                      num_rep=100, max_dist=6)
evaluator()

100%|##########| 5100/5100 [03:04<00:00, 27.65it/s]


[('macro f1', 0.6117107510148846),
 ('micro f1', 0.7109671186430598),
 ('avg key point acc', 0.8585061880156308),
 ('macro precision', 0.5431540563052295),
 ('disc v1 precision', 0.7410078856683826),
 ('disc v2 precision', 0.5044110818622162),
 ('disc v3 precision', 0.33070411761945323),
 ('disc v4 precision', 0.5620817834634679),
 ('disc v5 precision', 0.1303009601210605),
 ('vertebra v1 precision', 0.7496183190863004),
 ('vertebra v2 precision', 0.7839542463157259)]

In [22]:
# torch.save(dis_model.cpu(), '../models/2020070103.dis_model')

In [18]:
# testA_studies = construct_studies('../data/lumbar_testA50/')

# result = []
# for study in tqdm(testA_studies.values(), ascii=True):
#     result.append(dis_model.eval()(study, True))

# i = 9
# visilize_annotation(testA_studies[result[i]['studyUid']].t2_sagittal_middle_frame.image, result[i])

# with open('../predictions/2020070102.json', 'w') as file:
#     json.dump(result, file)
    

100%|##########| 50/50 [00:34<00:00,  1.43it/s]
