<a href="https://colab.research.google.com/github/talhaanwarch/my_pytorch/blob/master/advance/3D_segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive/',force_remount=True)

Mounted at /content/drive/


In [2]:
%%capture
#!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl
!pip install "monai-weekly[nibabel]"
!pip install -U pytorch-lightning
!pip install torchmetrics
!pip install -U tqdm
%matplotlib inline

In [3]:
!nvidia-smi -L

GPU 0: Tesla T4 (UUID: GPU-1c2af932-d37e-2f0e-a4e8-702df125d137)


In [4]:
import tqdm
print(tqdm.__version__)
assert tqdm.__version__ >='4.47.0', 'tqdm version >=4.47.0'

4.61.2


In [5]:
#download data
import os
if os.path.isfile('Task09_Spleen.tar') is False:
  !wget https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar

In [6]:
import tarfile
if os.path.isdir('Task09_Spleen') is False:
  my_tar = tarfile.open('/content/Task09_Spleen.tar')
  my_tar.extractall('/content/')
  my_tar.close()

In [7]:
from glob import glob
import os
images = sorted(glob("Task09_Spleen/imagesTr/*.nii.gz"))
segs = sorted(glob("Task09_Spleen/labelsTr/*.nii.gz"))
data_dicts = [
            {"image": image_name, "label": label_name}
            for image_name, label_name in zip(images, segs)
        ]
train_files, val_files = data_dicts[:-9], data_dicts[-9:]

In [8]:
import nibabel as nib
for i in range(len(train_files)):
  img = nib.load(train_files[i]['image']).get_fdata()
  lab = nib.load(train_files[i]['label']).get_fdata()
  print(img.shape,lab.shape)
  break

(512, 512, 55) (512, 512, 55)


In [9]:
import nibabel as nib
raw_val_data=[]
for i in range(len(val_files)):
  img = nib.load(val_files[i]['image']).get_fdata()
  lab = nib.load(val_files[i]['label']).get_fdata()
  print(img.shape,lab.shape)
  raw_val_data.append((img,lab))
  

(512, 512, 33) (512, 512, 33)
(512, 512, 50) (512, 512, 50)
(512, 512, 135) (512, 512, 135)
(512, 512, 97) (512, 512, 97)
(512, 512, 101) (512, 512, 101)
(512, 512, 80) (512, 512, 80)
(512, 512, 60) (512, 512, 60)
(512, 512, 31) (512, 512, 31)
(512, 512, 41) (512, 512, 41)


In [10]:
from  monai import transforms as T
train_transforms = T.Compose(
            [
                T.LoadImaged(keys=["image", "label"]),
                T.AddChanneld(keys=["image", "label"]),
                T.Spacingd(keys=["image", "label"],pixdim=(1.5, 1.5, 2.0),mode=("bilinear", "nearest"),),
                T.Orientationd(keys=["image", "label"], axcodes="RAS"),
                T.ScaleIntensityd(keys=["image"]),
                T.CropForegroundd(keys=["image", "label"], source_key="image"),
                T.RandCropByPosNegLabeld(keys=["image", "label"],label_key="label",spatial_size=(64, 64, 64),pos=1,neg=1,
                    num_samples=8,image_key="image", image_threshold=0,),
                T.EnsureTyped(keys=["image", "label"]),
            ]
        )
val_transforms = T.Compose(
  [
      T.LoadImaged(keys=["image", "label"]),
      T.AddChanneld(keys=["image", "label"]),
      T.Spacingd(keys=["image", "label"],pixdim=(1.5, 1.5, 2.0),mode=("bilinear", "nearest"),),
      T.Orientationd(keys=["image", "label"], axcodes="RAS"),
      T.ScaleIntensityd(keys=["image"]),
      T.CropForegroundd(keys=["image", "label"], source_key="image"),
      T.EnsureTyped(keys=["image", "label"]),
  ]
)

In [11]:
from pytorch_lightning import seed_everything, LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping,ModelCheckpoint
from monai.losses import DiceLoss
from torch.utils.data import DataLoader, Dataset
from monai.data import CacheDataset, list_data_collate,decollate_batch
from pytorch_lightning.loggers import TensorBoardLogger
from monai.networks.nets import UNet
from monai.metrics import DiceMetric

import torch.nn as nn
import torch
import torchmetrics


In [12]:
from monai.networks.layers import Norm
from monai.inferers import sliding_window_inference

class OurModel(LightningModule):
  def __init__(self):
    super(OurModel,self).__init__()
    #architecute
    self.layer = UNet(
                      dimensions=3,
                      in_channels=1,
                      out_channels=2,
                      channels=(16, 32, 64, 128, 256),
                      strides=(2, 2, 2, 2),
                      num_res_units=2,
                      norm=Norm.BATCH,
                  )

  #parameters
    self.lr=1e-4
    self.bs=16
    self.numworker=2
    self.criterion = DiceLoss(to_onehot_y=True, softmax=True)
    self.post_pred = T.Compose([T.EnsureType(), T.AsDiscrete(argmax=True, to_onehot=True, n_classes=2)])
    self.post_label = T.Compose([T.EnsureType(), T.AsDiscrete(to_onehot=True, n_classes=2)])
    self.metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)
    self.metric1= DiceMetric(include_background=False, reduction="mean", get_not_nans=False)
    self.best_val_dice = 0
    self.best_val_epoch = 0

  def forward(self,x):
    return self.layer(x)


  def configure_optimizers(self):
    opt=torch.optim.Adam(self.parameters(), lr=self.lr)
    return opt

  def train_dataloader(self):
    ds = CacheDataset(data=train_files, transform=train_transforms,cache_rate=1.0, num_workers=self.numworker)
    loader=DataLoader(ds, batch_size=self.bs, shuffle=True,num_workers=self.numworker, collate_fn=list_data_collate)
    return loader

  def training_step(self,batch,batch_idx):
    image,segment=batch["image"], batch["label"]
    out=self(image)
    loss=self.criterion(out,segment)

    self.log('train/loss', loss, on_epoch=False,prog_bar=True)
    return loss

  def val_dataloader(self):
    ds = CacheDataset(data=val_files, transform=val_transforms,cache_rate=1.0, num_workers=self.numworker)
    loader=DataLoader(ds, batch_size=1, shuffle=False,num_workers=self.numworker, collate_fn=list_data_collate)
    return loader
    
  def validation_step(self,batch,batch_idx):
    image,segment=batch["image"], batch["label"]
    roi_size = (32, 32, 32)
    sw_batch_size = 4
    outputs = sliding_window_inference(image, roi_size, sw_batch_size, self.forward)
    loss=self.criterion(outputs,segment)
    outputs = [self.post_pred(i) for i in decollate_batch(outputs)]
    labels = [self.post_label(i) for i in decollate_batch(segment)]
    self.metric(y_pred=outputs, y=labels)
    self.metric1(y_pred=outputs, y=labels)
    dice = self.metric1.aggregate().item()
    self.log('val/loss', loss, on_epoch=True,prog_bar=True)
    self.log('val/dice', dice, on_epoch=True,prog_bar=True)
    return{"val_loss": loss, "val_number": len(outputs)}
  
  def validation_epoch_end(self, outputs):
        val_loss, num_items = 0, 0
        for output in outputs:
            val_loss += output["val_loss"].sum().item()
            num_items += output["val_number"]
        mean_val_dice = self.metric.aggregate().item()
        self.metric.reset()
        mean_val_loss = torch.tensor(val_loss / num_items)
        tensorboard_logs = {
            "val_dice": mean_val_dice,
            "val_loss": mean_val_loss,
        }
        if mean_val_dice > self.best_val_dice:
            self.best_val_dice = mean_val_dice
            self.best_val_epoch = self.current_epoch
        print(
            f"current epoch: {self.current_epoch} "
            f"current mean dice: {mean_val_dice:.4f}"
            f"\nbest mean dice: {self.best_val_dice:.4f} "
            f"at epoch: {self.best_val_epoch}"
        )
        return {"log": tensorboard_logs}

In [13]:
model = OurModel()
logger = TensorBoardLogger("logs", name="my_logs")
checkpoint_callback = ModelCheckpoint(monitor='val/loss',dirpath='/content/drive/MyDrive/',
                                      filename='spleen200' )
trainer = Trainer(max_epochs=200, auto_lr_find=False, auto_scale_batch_size=False,
                  #tpu_cores=8,precision=16,
                  gpus=-1,precision=16,
                  logger=logger,
                  progress_bar_refresh_rate=30,
                  #resume_from_checkpoint='/content/drive/MyDrive/spleen200-v1.ckpt',
                  callbacks=[checkpoint_callback]
                  )

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using native 16bit precision.


In [None]:
trainer.fit(model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type     | Params
---------------------------------------
0 | layer     | UNet     | 4.8 M 
1 | criterion | DiceLoss | 0     
---------------------------------------
4.8 M     Trainable params
0         Non-trainable params
4.8 M     Total params
19.236    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]


Loading dataset:   0%|          | 0/9 [00:00<?, ?it/s][A
Loading dataset:  11%|█         | 1/9 [00:02<00:19,  2.45s/it][A
Loading dataset:  22%|██▏       | 2/9 [00:04<00:15,  2.19s/it][A
Loading dataset:  33%|███▎      | 3/9 [00:08<00:16,  2.82s/it][A
Loading dataset:  44%|████▍     | 4/9 [00:11<00:14,  2.89s/it][A
Loading dataset:  56%|█████▌    | 5/9 [00:14<00:12,  3.05s/it][A
Loading dataset:  67%|██████▋   | 6/9 [00:14<00:06,  2.16s/it][A
Loading dataset:  78%|███████▊  | 7/9 [00:17<00:04,  2.43s/it][A
Loading dataset: 100%|██████████| 9/9 [00:18<00:00,  2.10s/it]


current epoch: 0 current mean dice: 0.0098
best mean dice: 0.0098 at epoch: 0


Loading dataset: 100%|██████████| 32/32 [01:40<00:00,  3.15s/it]


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

current epoch: 0 current mean dice: 0.0134
best mean dice: 0.0134 at epoch: 0


Validating: 0it [00:00, ?it/s]

current epoch: 1 current mean dice: 0.0144
best mean dice: 0.0144 at epoch: 1


Validating: 0it [00:00, ?it/s]

current epoch: 2 current mean dice: 0.0149
best mean dice: 0.0149 at epoch: 2


Validating: 0it [00:00, ?it/s]

current epoch: 3 current mean dice: 0.0149
best mean dice: 0.0149 at epoch: 3


Validating: 0it [00:00, ?it/s]

current epoch: 4 current mean dice: 0.0147
best mean dice: 0.0149 at epoch: 3


Validating: 0it [00:00, ?it/s]

current epoch: 5 current mean dice: 0.0144
best mean dice: 0.0149 at epoch: 3


Validating: 0it [00:00, ?it/s]

current epoch: 6 current mean dice: 0.0141
best mean dice: 0.0149 at epoch: 3


Validating: 0it [00:00, ?it/s]

current epoch: 7 current mean dice: 0.0141
best mean dice: 0.0149 at epoch: 3


Validating: 0it [00:00, ?it/s]

current epoch: 8 current mean dice: 0.0143
best mean dice: 0.0149 at epoch: 3


Validating: 0it [00:00, ?it/s]

current epoch: 9 current mean dice: 0.0144
best mean dice: 0.0149 at epoch: 3


Validating: 0it [00:00, ?it/s]

current epoch: 10 current mean dice: 0.0146
best mean dice: 0.0149 at epoch: 3


Validating: 0it [00:00, ?it/s]

current epoch: 11 current mean dice: 0.0147
best mean dice: 0.0149 at epoch: 3


Validating: 0it [00:00, ?it/s]

current epoch: 12 current mean dice: 0.0152
best mean dice: 0.0152 at epoch: 12


Validating: 0it [00:00, ?it/s]

current epoch: 13 current mean dice: 0.0155
best mean dice: 0.0155 at epoch: 13


Validating: 0it [00:00, ?it/s]

current epoch: 14 current mean dice: 0.0161
best mean dice: 0.0161 at epoch: 14


Validating: 0it [00:00, ?it/s]

current epoch: 15 current mean dice: 0.0166
best mean dice: 0.0166 at epoch: 15


Validating: 0it [00:00, ?it/s]

current epoch: 16 current mean dice: 0.0170
best mean dice: 0.0170 at epoch: 16


Validating: 0it [00:00, ?it/s]

current epoch: 17 current mean dice: 0.0177
best mean dice: 0.0177 at epoch: 17


Validating: 0it [00:00, ?it/s]

current epoch: 18 current mean dice: 0.0180
best mean dice: 0.0180 at epoch: 18


Validating: 0it [00:00, ?it/s]

current epoch: 19 current mean dice: 0.0182
best mean dice: 0.0182 at epoch: 19


Validating: 0it [00:00, ?it/s]

current epoch: 20 current mean dice: 0.0184
best mean dice: 0.0184 at epoch: 20


Validating: 0it [00:00, ?it/s]

current epoch: 21 current mean dice: 0.0187
best mean dice: 0.0187 at epoch: 21


Validating: 0it [00:00, ?it/s]

current epoch: 22 current mean dice: 0.0188
best mean dice: 0.0188 at epoch: 22


Validating: 0it [00:00, ?it/s]

current epoch: 23 current mean dice: 0.0190
best mean dice: 0.0190 at epoch: 23


Validating: 0it [00:00, ?it/s]

current epoch: 24 current mean dice: 0.0192
best mean dice: 0.0192 at epoch: 24


Validating: 0it [00:00, ?it/s]

current epoch: 25 current mean dice: 0.0194
best mean dice: 0.0194 at epoch: 25


Validating: 0it [00:00, ?it/s]

current epoch: 26 current mean dice: 0.0194
best mean dice: 0.0194 at epoch: 26


Validating: 0it [00:00, ?it/s]

current epoch: 27 current mean dice: 0.0194
best mean dice: 0.0194 at epoch: 26


Validating: 0it [00:00, ?it/s]

current epoch: 28 current mean dice: 0.0194
best mean dice: 0.0194 at epoch: 26


Validating: 0it [00:00, ?it/s]

current epoch: 29 current mean dice: 0.0199
best mean dice: 0.0199 at epoch: 29


Validating: 0it [00:00, ?it/s]

current epoch: 30 current mean dice: 0.0207
best mean dice: 0.0207 at epoch: 30


Validating: 0it [00:00, ?it/s]

current epoch: 31 current mean dice: 0.0209
best mean dice: 0.0209 at epoch: 31


Validating: 0it [00:00, ?it/s]

current epoch: 32 current mean dice: 0.0209
best mean dice: 0.0209 at epoch: 31


Validating: 0it [00:00, ?it/s]

current epoch: 33 current mean dice: 0.0212
best mean dice: 0.0212 at epoch: 33


Validating: 0it [00:00, ?it/s]

current epoch: 34 current mean dice: 0.0215
best mean dice: 0.0215 at epoch: 34


Validating: 0it [00:00, ?it/s]

current epoch: 35 current mean dice: 0.0205
best mean dice: 0.0215 at epoch: 34


Validating: 0it [00:00, ?it/s]

current epoch: 36 current mean dice: 0.0219
best mean dice: 0.0219 at epoch: 36


Validating: 0it [00:00, ?it/s]

current epoch: 37 current mean dice: 0.0219
best mean dice: 0.0219 at epoch: 37


Validating: 0it [00:00, ?it/s]

current epoch: 38 current mean dice: 0.0224
best mean dice: 0.0224 at epoch: 38


Validating: 0it [00:00, ?it/s]

current epoch: 39 current mean dice: 0.0217
best mean dice: 0.0224 at epoch: 38


Validating: 0it [00:00, ?it/s]

current epoch: 40 current mean dice: 0.0200
best mean dice: 0.0224 at epoch: 38


Validating: 0it [00:00, ?it/s]

current epoch: 41 current mean dice: 0.0192
best mean dice: 0.0224 at epoch: 38


Validating: 0it [00:00, ?it/s]

current epoch: 42 current mean dice: 0.0230
best mean dice: 0.0230 at epoch: 42


Validating: 0it [00:00, ?it/s]

current epoch: 43 current mean dice: 0.0236
best mean dice: 0.0236 at epoch: 43


Validating: 0it [00:00, ?it/s]

current epoch: 44 current mean dice: 0.0237
best mean dice: 0.0237 at epoch: 44


Validating: 0it [00:00, ?it/s]

current epoch: 45 current mean dice: 0.0257
best mean dice: 0.0257 at epoch: 45


Validating: 0it [00:00, ?it/s]

current epoch: 46 current mean dice: 0.0213
best mean dice: 0.0257 at epoch: 45


Validating: 0it [00:00, ?it/s]

current epoch: 47 current mean dice: 0.0214
best mean dice: 0.0257 at epoch: 45


Validating: 0it [00:00, ?it/s]

current epoch: 48 current mean dice: 0.0243
best mean dice: 0.0257 at epoch: 45


Validating: 0it [00:00, ?it/s]

current epoch: 49 current mean dice: 0.0239
best mean dice: 0.0257 at epoch: 45


Validating: 0it [00:00, ?it/s]

current epoch: 50 current mean dice: 0.0272
best mean dice: 0.0272 at epoch: 50


Validating: 0it [00:00, ?it/s]

current epoch: 51 current mean dice: 0.0270
best mean dice: 0.0272 at epoch: 50


Validating: 0it [00:00, ?it/s]

current epoch: 52 current mean dice: 0.0244
best mean dice: 0.0272 at epoch: 50


Validating: 0it [00:00, ?it/s]

current epoch: 53 current mean dice: 0.0256
best mean dice: 0.0272 at epoch: 50


Validating: 0it [00:00, ?it/s]

current epoch: 54 current mean dice: 0.0271
best mean dice: 0.0272 at epoch: 50


Validating: 0it [00:00, ?it/s]

current epoch: 55 current mean dice: 0.0277
best mean dice: 0.0277 at epoch: 55


Validating: 0it [00:00, ?it/s]

current epoch: 56 current mean dice: 0.0265
best mean dice: 0.0277 at epoch: 55


Validating: 0it [00:00, ?it/s]

current epoch: 57 current mean dice: 0.0223
best mean dice: 0.0277 at epoch: 55


Validating: 0it [00:00, ?it/s]

current epoch: 58 current mean dice: 0.0299
best mean dice: 0.0299 at epoch: 58


Validating: 0it [00:00, ?it/s]

current epoch: 59 current mean dice: 0.0305
best mean dice: 0.0305 at epoch: 59


Validating: 0it [00:00, ?it/s]

current epoch: 60 current mean dice: 0.0226
best mean dice: 0.0305 at epoch: 59


Validating: 0it [00:00, ?it/s]

current epoch: 61 current mean dice: 0.0311
best mean dice: 0.0311 at epoch: 61


Validating: 0it [00:00, ?it/s]

current epoch: 62 current mean dice: 0.0322
best mean dice: 0.0322 at epoch: 62


Validating: 0it [00:00, ?it/s]

current epoch: 63 current mean dice: 0.0293
best mean dice: 0.0322 at epoch: 62


Validating: 0it [00:00, ?it/s]

current epoch: 64 current mean dice: 0.0205
best mean dice: 0.0322 at epoch: 62


Validating: 0it [00:00, ?it/s]

current epoch: 65 current mean dice: 0.0323
best mean dice: 0.0323 at epoch: 65


Validating: 0it [00:00, ?it/s]

current epoch: 66 current mean dice: 0.0316
best mean dice: 0.0323 at epoch: 65


Validating: 0it [00:00, ?it/s]

current epoch: 67 current mean dice: 0.0322
best mean dice: 0.0323 at epoch: 65


Validating: 0it [00:00, ?it/s]

current epoch: 68 current mean dice: 0.0239
best mean dice: 0.0323 at epoch: 65


Validating: 0it [00:00, ?it/s]

current epoch: 69 current mean dice: 0.0306
best mean dice: 0.0323 at epoch: 65


Validating: 0it [00:00, ?it/s]

current epoch: 70 current mean dice: 0.0303
best mean dice: 0.0323 at epoch: 65


Validating: 0it [00:00, ?it/s]

current epoch: 71 current mean dice: 0.0242
best mean dice: 0.0323 at epoch: 65


Validating: 0it [00:00, ?it/s]

current epoch: 72 current mean dice: 0.0317
best mean dice: 0.0323 at epoch: 65


Validating: 0it [00:00, ?it/s]

current epoch: 73 current mean dice: 0.0338
best mean dice: 0.0338 at epoch: 73


Validating: 0it [00:00, ?it/s]

current epoch: 74 current mean dice: 0.0333
best mean dice: 0.0338 at epoch: 73


Validating: 0it [00:00, ?it/s]

current epoch: 75 current mean dice: 0.0346
best mean dice: 0.0346 at epoch: 75


Validating: 0it [00:00, ?it/s]

current epoch: 76 current mean dice: 0.0340
best mean dice: 0.0346 at epoch: 75


Validating: 0it [00:00, ?it/s]

current epoch: 77 current mean dice: 0.0305
best mean dice: 0.0346 at epoch: 75


Validating: 0it [00:00, ?it/s]

current epoch: 78 current mean dice: 0.0362
best mean dice: 0.0362 at epoch: 78


Validating: 0it [00:00, ?it/s]

current epoch: 79 current mean dice: 0.0367
best mean dice: 0.0367 at epoch: 79


Validating: 0it [00:00, ?it/s]

current epoch: 80 current mean dice: 0.0368
best mean dice: 0.0368 at epoch: 80


Validating: 0it [00:00, ?it/s]

current epoch: 81 current mean dice: 0.0357
best mean dice: 0.0368 at epoch: 80


Validating: 0it [00:00, ?it/s]

current epoch: 82 current mean dice: 0.0374
best mean dice: 0.0374 at epoch: 82


Validating: 0it [00:00, ?it/s]

current epoch: 83 current mean dice: 0.0383
best mean dice: 0.0383 at epoch: 83


Validating: 0it [00:00, ?it/s]

current epoch: 84 current mean dice: 0.0385
best mean dice: 0.0385 at epoch: 84


Validating: 0it [00:00, ?it/s]

current epoch: 85 current mean dice: 0.0370
best mean dice: 0.0385 at epoch: 84


Validating: 0it [00:00, ?it/s]

current epoch: 86 current mean dice: 0.0378
best mean dice: 0.0385 at epoch: 84


Validating: 0it [00:00, ?it/s]

current epoch: 87 current mean dice: 0.0411
best mean dice: 0.0411 at epoch: 87


Validating: 0it [00:00, ?it/s]

current epoch: 88 current mean dice: 0.0405
best mean dice: 0.0411 at epoch: 87


Validating: 0it [00:00, ?it/s]

current epoch: 89 current mean dice: 0.0384
best mean dice: 0.0411 at epoch: 87


Validating: 0it [00:00, ?it/s]

current epoch: 90 current mean dice: 0.0413
best mean dice: 0.0413 at epoch: 90


Validating: 0it [00:00, ?it/s]

current epoch: 91 current mean dice: 0.0395
best mean dice: 0.0413 at epoch: 90


Validating: 0it [00:00, ?it/s]

current epoch: 92 current mean dice: 0.0405
best mean dice: 0.0413 at epoch: 90


Validating: 0it [00:00, ?it/s]

current epoch: 93 current mean dice: 0.0404
best mean dice: 0.0413 at epoch: 90


Validating: 0it [00:00, ?it/s]

current epoch: 94 current mean dice: 0.0426
best mean dice: 0.0426 at epoch: 94


Validating: 0it [00:00, ?it/s]

current epoch: 95 current mean dice: 0.0397
best mean dice: 0.0426 at epoch: 94


Validating: 0it [00:00, ?it/s]

current epoch: 96 current mean dice: 0.0437
best mean dice: 0.0437 at epoch: 96


Validating: 0it [00:00, ?it/s]

current epoch: 97 current mean dice: 0.0435
best mean dice: 0.0437 at epoch: 96


Validating: 0it [00:00, ?it/s]

current epoch: 98 current mean dice: 0.0443
best mean dice: 0.0443 at epoch: 98


Validating: 0it [00:00, ?it/s]

current epoch: 99 current mean dice: 0.0389
best mean dice: 0.0443 at epoch: 98


Validating: 0it [00:00, ?it/s]

current epoch: 100 current mean dice: 0.0458
best mean dice: 0.0458 at epoch: 100


Validating: 0it [00:00, ?it/s]

current epoch: 101 current mean dice: 0.0430
best mean dice: 0.0458 at epoch: 100


Validating: 0it [00:00, ?it/s]

current epoch: 102 current mean dice: 0.0393
best mean dice: 0.0458 at epoch: 100


Validating: 0it [00:00, ?it/s]

current epoch: 103 current mean dice: 0.0400
best mean dice: 0.0458 at epoch: 100


Validating: 0it [00:00, ?it/s]

current epoch: 104 current mean dice: 0.0421
best mean dice: 0.0458 at epoch: 100


Validating: 0it [00:00, ?it/s]

current epoch: 105 current mean dice: 0.0445
best mean dice: 0.0458 at epoch: 100


Validating: 0it [00:00, ?it/s]

current epoch: 106 current mean dice: 0.0439
best mean dice: 0.0458 at epoch: 100


Validating: 0it [00:00, ?it/s]

current epoch: 107 current mean dice: 0.0483
best mean dice: 0.0483 at epoch: 107


Validating: 0it [00:00, ?it/s]

In [15]:
torch.save(model.state_dict(), '/content/drive/MyDrive/spleen_state200.pt')

In [None]:
#model.load_state_dict(torch.load('/content/drive/MyDrive/spleen_state600.pt'))


# inference

In [None]:
post_transforms = T.Compose([
        T.ToTensord(keys="pred"), 
        T.Activationsd(keys="pred", sigmoid=True),
        T.AsDiscreted(keys="pred", threshold_values=True),
        T.Invertd(
            keys=["pred",'label'],  # invert the `pred` data field, also support multiple fields
            transform=train_transforms,
            orig_keys="image",  # get the previously applied pre_transforms information on the `img` data field,
                              # then invert `pred` based on this information. we can use same info
                              # for multiple fields, also support different orig_keys for different fields
            meta_keys=["pred_meta_dict","label_meta_dict"],  # key field to save inverted meta data, every item maps to `keys`
            orig_meta_keys="image_meta_dict",  # get the meta data from `img_meta_dict` field when inverting,
                                             # for example, may need the `affine` to invert `Spacingd` transform,
                                             # multiple fields can use the same meta data to invert
            meta_key_postfix="meta_dict",  # if `meta_keys=None`, use "{keys}_{meta_key_postfix}" as the meta key,
                                           # if `orig_meta_keys=None`, use "{orig_keys}_{meta_key_postfix}",
                                           # otherwise, no need this arg during inverting
            nearest_interp=True,  # change to use "nearest" mode in interpolation when inverting
            to_tensor=True,  # convert to PyTorch Tensor after inverting
        ),
        #SaveImaged(keys="pred", meta_keys="pred_meta_dict", output_dir="./out", output_postfix="seg", resample=False),
    ])

In [None]:
val_dataloader=model.val_dataloader()

In [None]:
import matplotlib.pyplot as plt
import numpy as np
model.eval()
device = torch.device("cuda:0")
model.to(device)
ch=32
with torch.no_grad():
    for i, val_data in enumerate(val_dataloader):
        roi_size = (32, 32, 32)
        sw_batch_size = 2
        output = sliding_window_inference(
            val_data["image"].to(device), roi_size, sw_batch_size, model
        )
        # plot the slice [:, :, 80]
        
        plt.figure("check", (18, 6))
        plt.subplot(1, 3, 1)
        plt.title(f"image {i}")
        plt.imshow(val_data["image"][0, 0, :, :, ch], cmap="gray")
        plt.subplot(1, 3, 2)
        plt.title(f"label {i}")
        plt.imshow(val_data["label"][0, 0, :, :, ch])
        plt.subplot(1, 3, 3)
        plt.title(f"output {i}")
        plt.imshow(torch.argmax(
            output, dim=1).detach().cpu()[0, :, :, ch])
        plt.show()
        break
        

In [None]:
from monai.data import decollate_batch
model.eval()
post_data=[]
with torch.no_grad():
  for d in val_dataloader:
    images = d["image"].to(device)
    # define sliding window size and batch size for windows inference
    infer_outputs = sliding_window_inference(inputs=images, roi_size=(32 , 32, 32), sw_batch_size=2, predictor=model)
    infer_outputs = decollate_batch(infer_outputs)
    for (infer_output, infer_output_data) in zip(infer_outputs, decollate_batch(d)):
      infer_output_data["pred"] = infer_output
      post_data.append(post_transforms(infer_output_data))
      break
    break
    

In [None]:
out=post_data[0]
out['image'].shape,out['label'].shape,out['pred'].shape

In [None]:
raw_val=raw_val_data[0]
raw_val[0].shape,raw_val[1].shape#raw label and segment

In [None]:
fig,ax=plt.subplots(33,3,figsize=(10,60))

for i in range(33):
  ax[i,0].imshow(raw_val[1][:,:,i],cmap='gray')#raw image
  ax[i,1].imshow(out['label'][0,:,:,i],cmap='gray')#inverted image
  ax[i,2].imshow(torch.argmax(out['pred'],0)[:,:,i],cmap='gray')#predicted image
plt.subplots_adjust(wspace=0, hspace=-.5)


# test

In [None]:
post_transforms = T.Compose([
        T.ToTensord(keys="pred"), 
        T.Activationsd(keys="pred", sigmoid=True),
        T.AsDiscreted(keys="pred", threshold_values=True),
        T.Invertd(
            keys=["pred",'label'],  # invert the `pred` data field, also support multiple fields
            transform=val_transforms,
            orig_keys="image",  # get the previously applied pre_transforms information on the `img` data field,
                              # then invert `pred` based on this information. we can use same info
                              # for multiple fields, also support different orig_keys for different fields
            meta_keys=["pred_meta_dict","label_meta_dict"],  # key field to save inverted meta data, every item maps to `keys`
            orig_meta_keys="image_meta_dict",  # get the meta data from `img_meta_dict` field when inverting,
                                             # for example, may need the `affine` to invert `Spacingd` transform,
                                             # multiple fields can use the same meta data to invert
            meta_key_postfix="meta_dict",  # if `meta_keys=None`, use "{keys}_{meta_key_postfix}" as the meta key,
                                           # if `orig_meta_keys=None`, use "{orig_keys}_{meta_key_postfix}",
                                           # otherwise, no need this arg during inverting
            nearest_interp=True,  # change to use "nearest" mode in interpolation when inverting
            to_tensor=True,  # convert to PyTorch Tensor after inverting
        ),
        #SaveImaged(keys="pred", meta_keys="pred_meta_dict", output_dir="./out", output_postfix="seg", resample=False),
    ])

In [None]:
ds = CacheDataset(data=val_files, transform=val_transforms,cache_rate=1.0, num_workers=2)
test_loader=DataLoader(ds, batch_size=1, shuffle=False,num_workers=2, collate_fn=list_data_collate)
sample=next(iter(test_loader))

In [None]:
sample['image'].shape,sample['label'].shape

In [None]:
post_data=[]
for d in test_loader:
    images = d["image"].to(device)
    # define sliding window size and batch size for windows inference
    infer_outputs = sliding_window_inference(inputs=images, roi_size=(32 , 32, 32), sw_batch_size=2, predictor=model)
    infer_outputs = decollate_batch(infer_outputs)
    for (infer_output, infer_output_data) in zip(infer_outputs, decollate_batch(d)):
      infer_output_data["pred"] = infer_output
      post_data.append(post_transforms(infer_output_data))
      break
    break

In [None]:
out=post_data[0]
out['image'].shape,out['label'].shape,out['pred'].shape

In [None]:
raw_val=raw_val_data[0]
raw_val[0].shape,raw_val[1].shape#raw label and segment

In [None]:
fig,ax=plt.subplots(33,3,figsize=(10,60))

for i in range(33):
  ax[i,0].imshow(raw_val[1][:,:,i],cmap='gray')#raw image
  ax[i,1].imshow(out['label'][0,:,:,i],cmap='gray')#inverted image
  ax[i,2].imshow(torch.argmax(out['pred'],0)[:,:,i],cmap='gray')#predicted image
plt.subplots_adjust(wspace=0, hspace=-.5)
