In [3]:
#!pip install pytorch_lightning
#import google.colab
import numpy as np
import os
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn import functional as F

from torch.utils.data import random_split
#google.colab.drive.mount("/content/drive")

import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer

In [None]:
#!gdown --id 1jn98eiQqdYO2tgp_XsS8dO124WtIh29G #1HsvAl5WqvqXFJOnjwabZTmkSY3JrVzBa
# # !unzip dataset128.zip
#!unzip /content/drive/MyDrive/dataset2.zip

In [4]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

class DoubleConv(nn.Module):
    """Two stacked (Conv2d 3x3 -> BatchNorm2d -> ReLU) units.

    Stride 1 and padding 1 keep height and width unchanged. The six layers are
    stored in ``self.conv`` (indices 0..5) so state-dict keys such as
    ``conv.0.weight`` or ``conv.1.running_mean`` stay stable for checkpoints.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        layers = []
        # First unit maps in_channels -> out_channels, second keeps out_channels.
        for unit_in in (in_channels, out_channels):
            layers.extend([
                nn.Conv2d(unit_in, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)

class DoubleConv3d(nn.Module):
    """3D counterpart of DoubleConv: two (Conv3d 3x3x3 -> InstanceNorm3d -> ReLU) units.

    Stride 1 and padding 1 keep depth, height and width unchanged. The layers
    are stored in ``self.conv`` (indices 0..5) so checkpoint keys stay stable;
    with ``InstanceNorm3d`` defaults only the two conv weights are stored.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        def unit(unit_in):
            # One conv -> norm -> activation triple producing out_channels.
            return [
                nn.Conv3d(unit_in, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                nn.InstanceNorm3d(out_channels),
                nn.ReLU(inplace=True),
            ]

        self.conv = nn.Sequential(*unit(in_channels), *unit(out_channels))

    def forward(self, x):
        return self.conv(x)

class UNET(nn.Module):
    """U-Net with a 2D encoder and a 3D decoder.

    It maps a ``(B, in_channels, H, W)`` scan to a ``(B, 2, 16, H, W)`` volume.
    H and W must be divisible by 8.

    Encoder:
        Three DoubleConv stages (in_channels -> 64 -> 128 -> 256), each
        followed by 2x max-pooling, then a 256 -> 512 bottleneck conv.

    Reinterpreting 2D maps as 3D volumes:
        Each 2D feature map is reshaped into a 3D volume by splitting its
        channel axis into (channels, depth). For example, 64 channels become
        4 channels x 16 depth slices.

    Decoder:
        Grouped ConvTranspose3d layers double depth, height and width. After
        each one, the result is concatenated with the matching reinterpreted
        skip connection.

    Args:
        in_channels: number of scan channels fed to the first stage (default 31).
        out_channels: not used. Kept for backward compatibility; the network
            always ends in 2 channels (``up_4_h``).
    """

    def __init__(
            self, in_channels=31, out_channels=32,
    ):
        super(UNET, self).__init__()

        # Module names and creation order are unchanged so existing checkpoints load.
        self.down_1 = DoubleConv(in_channels, 64)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.down_2 = DoubleConv(64, 128)
        self.down_3 = DoubleConv(128, 256)

        self.bottleneck = nn.Conv2d(256, 512, 3, 1, 1, bias=False)

        # Decoder input widths are upsampled channels + skip channels:
        # 256 + 64 = 320, 40 + 16 = 56, 6 + 4 = 10.
        self.up_1 = nn.ConvTranspose3d(256, 256, kernel_size=2, stride=2, groups=2)
        self.up_1_h = DoubleConv3d(320, 80)
        self.up_2 = nn.ConvTranspose3d(80, 40, kernel_size=2, stride=2, groups=2)
        self.up_2_h = DoubleConv3d(56, 20)
        self.up_3 = nn.ConvTranspose3d(20, 6, kernel_size=2, stride=2, groups=2)
        self.up_3_h = DoubleConv3d(10, 6)
        self.up_4_h = DoubleConv3d(6, 2)

    @staticmethod
    def _as_volume(x, channels):
        """Reinterpret a (B, C, H, W) map as (B, channels, C // channels, H, W).

        Same element order as flattening and reshaping, but batch and spatial
        sizes are taken from ``x`` rather than hard-coded (previously 16 and 128).
        """
        batch, c, h, w = x.shape
        return x.reshape(batch, channels, c // channels, h, w)

    def forward(self, x):
        # Encoder: keep every stage's output (as a 3D volume) for the decoder.
        x = self.down_1(x)
        skip_1 = self._as_volume(x, 4)       # (B, 4, 16, H, W)
        x = self.maxpool(x)

        x = self.down_2(x)
        skip_2 = self._as_volume(x, 16)      # (B, 16, 8, H/2, W/2)
        x = self.maxpool(x)

        x = self.down_3(x)
        skip_3 = self._as_volume(x, 64)      # (B, 64, 4, H/4, W/4)
        x = self.maxpool(x)

        x = self.bottleneck(x)
        x = self._as_volume(x, 256)          # (B, 256, 2, H/8, W/8)

        # Decoder: upsample, concatenate the matching skip, refine.
        x = self.up_1(x)
        x = self.up_1_h(torch.cat([skip_3, x], dim=1))

        x = self.up_2(x)
        x = self.up_2_h(torch.cat([skip_2, x], dim=1))

        x = self.up_3(x)
        x = self.up_3_h(torch.cat([skip_1, x], dim=1))
        return self.up_4_h(x)

In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
from torchvision import transforms as T
from random import random, choice

# For .npz data    
class CustomData(Dataset):
  """Dataset of paired scans and segmentation structures stored as ``.npz`` files.

  Directory layout:
    ``dir`` must contain a ``scans`` folder plus either ``structures``
    (``mode == "2d"``) or ``3d_structures`` (any other mode).

  File matching:
    Scans and structures are paired by the 5-digit index just before the
    ``.npz`` extension (e.g. ``scan_04501.npz``). Dataset position ``i`` maps
    to the ``i``-th smallest scan index. For files numbered 4501, 4502, ...
    this is the same mapping as the old hard-coded ``idx + 4501``.

  Returned items:
    Each item is ``(scan, structure)``. The assertion at the end of
    ``__getitem__`` enforces the shapes ``(31, 128, 128)`` and
    ``(2, 16, 128, 128)``. Structure channel 0 marks label 1 and channel 1
    marks label 2.
  """

  def __init__(self, dir, mode):
    self.dir = dir
    self.mode = mode
    self.scans = self._list_npz('scans')
    self.scans_indexes = [self._file_index(f) for f in self.scans]

    if mode == "2d":
      self._struc_dir = 'structures'
      self.struc = self._list_npz(self._struc_dir)
      self.struc_indexes = [self._file_index(f) for f in self.struc]
      struc_files = self.struc
    else:
      self._struc_dir = '3d_structures'
      self.struc_3d = self._list_npz(self._struc_dir)
      self.struc_3d_indexes = [self._file_index(f) for f in self.struc_3d]
      struc_files = self.struc_3d

    # index -> file name maps give O(1) lookups (list.index was O(n) per item).
    # setdefault keeps the first file per index, as list.index did.
    self._scan_by_index = {}
    for f in self.scans:
      self._scan_by_index.setdefault(self._file_index(f), f)
    self._struc_by_index = {}
    for f in struc_files:
      self._struc_by_index.setdefault(self._file_index(f), f)
    self._ordered_indexes = sorted(self._scan_by_index)

    self.len = len(self._ordered_indexes)

    # Number of __getitem__ calls; per process, so not aggregated across DataLoader workers.
    self.count = 0

  def _list_npz(self, subdir):
    """File names ending in .npz inside ``dir/subdir``."""
    return [x for x in os.listdir(os.path.join(self.dir, subdir)) if x.endswith('.npz')]

  @staticmethod
  def _file_index(file_name):
    """Integer index encoded as the 5 characters before '.npz'."""
    return int(file_name[-9:-4])

  def _load_arr0(self, subdir, file_name):
    """Return the ``arr_0`` array of an .npz file, closing the archive afterwards."""
    with np.load(os.path.join(self.dir, subdir, file_name)) as archive:
      return archive['arr_0']

  @staticmethod
  def transform(scan, structure):
    """Apply the same random flips / 90-degree rotation to scan and structure.

    Each of horizontal flip, vertical flip and rotation (by +90 or -90 degrees)
    is applied with probability 0.5.
    """
    # Horizontal flip
    if random()>0.5:
      scan = T.functional.hflip(scan)
      structure = T.functional.hflip(structure)

    #Vertical flip
    if random()>0.5:
      scan = T.functional.vflip(scan)
      structure = T.functional.vflip(structure)

    #Rotation
    # NOTE(review): torch.rot90 would rotate exactly without resampling; TF.rotate is kept as-is.
    if random()>0.5:
      rand_deg = choice([90, -90])
      scan = T.functional.rotate(scan, angle = rand_deg)
      structure = T.functional.rotate(structure, angle = rand_deg)

    return scan, structure

  def __len__(self):
    return self.len

  def __getitem__(self, idx):
    file_idx = self._ordered_indexes[idx]

    self.count = self.count+1

    scan = torch.from_numpy(self._load_arr0('scans', self._scan_by_index[file_idx]))
    scan = torch.permute(scan, (2, 0, 1)).float()  # move last axis first -> (C, H, W)

    struc = self._load_arr0(self._struc_dir, self._struc_by_index[file_idx])
    # Transpose each 2D slice along the first axis.
    structure = torch.from_numpy(np.array([s.T for s in struc])).float()

    scan, structure = self.transform(scan, structure)

    # Split labels into two binary channels: label 1 -> channel 0, label 2 -> channel 1.
    structure = torch.stack([(structure == 1).float(), (structure == 2).float()], dim = 0)
    assert scan.shape == torch.Size([31, 128, 128]) and structure.shape == torch.Size([2, 16, 128, 128])

    return scan, structure

In [4]:
# Sanity check of CustomData.
# First argument: dataset root, which must contain `scans/` plus `3d_structures/`
# (or `structures/` in '2d' mode). Second argument: mode, '2d' or anything else for 3D.
check = False  # set True to plot one sample

dataset = CustomData('dataset2/', '3d')

loader = DataLoader(dataset, num_workers=0, batch_size=1, shuffle=True)

for x, y in loader:
    print("\ncycle")
    x = x.detach().numpy()
    y = y.detach().numpy()
    print("x", x.shape, x.dtype)
    print("y", y.shape, y.dtype)
    break
print(dataset.count)

if check:
    fig, axs = plt.subplots(1, 5, figsize=(14, 8))

    # y is (batch, 2, depth, H, W): show each label channel averaged over depth.
    axs[0].imshow(y[0, 0].mean(axis=0))
    axs[1].imshow(y[0, 1].mean(axis=0))

    # x is (batch, 31, H, W): show three scan channels.
    axs[2].imshow(x[0, 0])
    axs[3].imshow(x[0, 10])
    axs[4].imshow(x[0, 20])

    axs[0].set_title('struc gold')
    axs[1].set_title('struc alumin')

    axs[2].set_title('scan 1')
    axs[3].set_title('scan 10')
    axs[4].set_title('scan 20')

    plt.show()

struct_dataset = CustomData('dataset2/', '3d')
# Seeded so the printed split is reproducible.
train_set, val_set = torch.utils.data.random_split(
    struct_dataset, [len(struct_dataset) - 900, 900],
    generator=torch.Generator().manual_seed(42))
print("\nsplit check")
print(len(train_set))
print(len(val_set))


cycle
x (1, 31, 128, 128) float32
y (1, 2, 16, 128, 128) float32
1

split check
9500
900


In [5]:
# pred_3d and orig_3d are integer label tensors with the batch on dim 0,
# e.g. (batch_size, 16, 128, 128); if there is no batch dim, unsqueeze first.
def calc_val_data(pred_3d, orig_3d, n_cls=3):
    """Per-sample, per-class voxel counts for segmentation metrics.

    Args:
        pred_3d: predicted integer labels, batch on dim 0 (batch_size >= 1).
        orig_3d: ground-truth integer labels, same shape as ``pred_3d``.
        n_cls: number of classes (>= 1); labels 0 .. n_cls-1 are counted.

    Returns:
        ``(intersection, union, target)``. Each is an int64 CPU tensor of shape
        ``(batch_size, n_cls)``. For every sample and class c they hold:
            * intersection: count of voxels where ``pred == c`` and ``orig == c``
            * union: count of voxels where ``pred == c`` or ``orig == c``
            * target: count of voxels where ``orig == c``
    """
    bat_size = pred_3d.shape[0]
    # One row per sample so each count is a single reduction over dim 1.
    pred_flat = pred_3d.reshape(bat_size, -1)
    orig_flat = orig_3d.reshape(bat_size, -1)

    intersection, union, target = [], [], []
    for cls in range(n_cls):
        pred_c = pred_flat == cls
        orig_c = orig_flat == cls
        intersection.append(torch.logical_and(pred_c, orig_c).sum(dim=1))
        union.append(torch.logical_or(pred_c, orig_c).sum(dim=1))
        target.append(orig_c.sum(dim=1))

    # Stack classes on dim 1 -> (batch_size, n_cls); transfer to CPU once.
    return (torch.stack(intersection, dim=1).cpu(),
            torch.stack(union, dim=1).cpu(),
            torch.stack(target, dim=1).cpu())

def calc_val_loss(intersection, union, target, eps = 1e-7):
    """Reduce per-sample, per-class counts to scalar validation metrics.

    Args:
        intersection, union, target: count tensors of matching shape
            (e.g. ``(batch_size, n_cls)``).
        eps: smoothing added to numerator and denominator. Because of it, a
            class absent from both prediction and target scores 1.

    Returns:
        ``(mean_iou, mean_class_rec, mean_acc)``:
            * mean_iou: intersection / union, averaged over all entries.
            * mean_class_rec: intersection / target, averaged over all entries.
            * mean_acc: total intersection divided by total target count.
    """
    smoothed_hits = intersection + eps
    iou = smoothed_hits / (union + eps)
    recall = smoothed_hits / (target + eps)
    accuracy = torch.nansum(intersection) / torch.nansum(target)
    return iou.mean(), recall.mean(), accuracy

def calc_miou(pred_3d, orig_3d, n_cls=3):
    """Mean IoU, mean class recall and accuracy between two integer label tensors.

    Counts are computed with ``calc_val_data`` for labels 0 .. n_cls-1 (the
    given ``n_cls`` is forwarded) and reduced with ``calc_val_loss``.

    Returns:
        ``(mean_iou, mean_class_rec, mean_acc)`` as scalar tensors.
    """
    intersection, union, target = calc_val_data(pred_3d, orig_3d, n_cls=n_cls)
    return calc_val_loss(intersection, union, target, eps = 1e-7)

In [17]:
# Scratch check of the label-collapsing rule used in validation:
# a voxel is background (0) when both channels are below the threshold;
# otherwise it gets argmax + 1 (1 for channel 0, 2 for channel 1).
# A local generator keeps this reproducible without touching the global RNG.
test = torch.randn(16, 2, 16, 128, 128, generator=torch.Generator().manual_seed(0))
print(test.shape)

zero_mask = torch.logical_and(test[:, 0] < 0.1, test[:, 1] < 0.1)
print(zero_mask.shape)

test_ = torch.argmax(test, dim=1) + 1
print(test_.shape)
print(torch.unique(test_))

# Keep labels only outside the background mask.
test_ = test_ * ~zero_mask
print(test_.shape)
print(torch.unique(test_))

torch.Size([16, 2, 16, 128, 128])
torch.Size([16, 16, 128, 128])
torch.Size([16, 16, 128, 128])
tensor([1, 2])
torch.Size([16, 16, 128, 128])
tensor([0, 1, 2])


In [21]:
BATCH_SIZE = 16
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


class SomeModel(LightningModule):
    """LightningModule that trains ``UNET`` with voxel-wise MSE against 2-channel targets.

    Besides ``val_loss``, validation logs mean IoU, mean class recall and
    accuracy, computed on integer label volumes:
        * 0 = background
        * 1 = target channel 0
        * 2 = target channel 1

    Args:
        dataset: dataset split into train/test/val in ``setup``. Defaults to the
            notebook-global ``struct_dataset``, looked up when ``setup`` runs.
    """

    # Seed for the train/test/val split so it is reproducible between runs.
    SPLIT_SEED = 42

    def __init__(self, dataset=None):
        super().__init__()
        # Lightning moves the module and every batch to the trainer's device,
        # so no manual `.to(device)` calls are needed inside this class.
        self.model = UNET()
        self._dataset = dataset
        self.train_set = self.val_set = self.test_set = None

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.mse_loss(self(x), y)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    @staticmethod
    def _to_labels(channels, background):
        """Collapse (B, 2, D, H, W) channel values into (B, D, H, W) integer labels.

        Voxels flagged in ``background`` get 0. All others get argmax + 1,
        i.e. 1 for channel 0 and 2 for channel 1.
        """
        return (~background) * (torch.argmax(channels, dim=1) + 1)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)

        loss = F.mse_loss(y_pred, y)
        # Calling self.log will surface up scalars for you in TensorBoard
        self.log("val_loss", loss, prog_bar=True)

        # Ground truth: background where both one-hot channels are 0.
        y_val = self._to_labels(y, torch.logical_and(y[:, 0] == 0, y[:, 1] == 0))
        # Prediction: background where both channel scores are below 0.5.
        y_pred_val = self._to_labels(
            y_pred, torch.logical_and(y_pred[:, 0] < 0.5, y_pred[:, 1] < 0.5))

        mean_iou, mean_class_rec, mean_acc = calc_miou(y_pred_val.int(), y_val.int(), n_cls=3)
        self.log("mean_iou", mean_iou, prog_bar=True)
        self.log("mean_class_rec", mean_class_rec, prog_bar=True)
        self.log("mean_acc", mean_acc, prog_bar=True)

        return loss

    def test_step(self, batch, batch_idx):
        # Here we just reuse the validation_step for testing
        return self.validation_step(batch, batch_idx)

    def setup(self, stage=None):
        # Lightning calls setup() once per stage (fit, validate, test, ...).
        # Split only once, with a fixed seed. Otherwise each stage draws a new
        # random split and the test subset overlaps the training data.
        if self.train_set is not None:
            return
        dataset = self._dataset if self._dataset is not None else struct_dataset
        self.train_set, self.test_set, self.val_set = random_split(
            dataset,
            [len(dataset) - 800, 400, 400],
            generator=torch.Generator().manual_seed(self.SPLIT_SEED),
        )

    def _make_loader(self, subset, shuffle=False):
        """DataLoader shared by all stages.

        Worker processes are kept alive between iterations instead of being
        re-spawned each time the loader is re-iterated.
        NOTE(review): this may also reduce the "can only test a child process"
        noise seen in the training log; that effect is unverified.
        """
        return DataLoader(subset, batch_size=BATCH_SIZE, num_workers=4,
                          shuffle=shuffle, persistent_workers=True)

    def train_dataloader(self):
        # Reshuffle every epoch; the random split alone fixes one order for all epochs.
        return self._make_loader(self.train_set, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_set)

    def test_dataloader(self):
        return self._make_loader(self.test_set)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4, weight_decay=1e-4)

In [22]:
from pytorch_lightning.callbacks import ModelCheckpoint

# Save a checkpoint every 5th epoch and keep all of them (save_top_k=-1).
# The file names look like checkpoints/3d_unet-epoch=49.ckpt.
checkpoint_callback = ModelCheckpoint(every_n_epochs=5,
                                      save_top_k=-1,
                                      filename="3d_unet-{epoch:02d}",
                                      dirpath='checkpoints/'
                                      )


# Init our model
some_model = SomeModel().to(device)

# Initialize a trainer; fall back to CPU when no GPU is available.
trainer = Trainer(
    gpus=1 if torch.cuda.is_available() else 0,
    val_check_interval=0.2,   # run validation 5 times per training epoch
    max_epochs=50,
    progress_bar_refresh_rate=1,
    callbacks=[checkpoint_callback],
)


# Train the model ⚡
trainer.fit(some_model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type | Params
-------------------------------
0 | model | UNET | 3.5 M 
-------------------------------
3.5 M     Trainable params
0         Non-trainable params
3.5 M     Total params
14.103    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x1554c89ab550>
Traceback (most recent call last):
  File "/trinity/home/a.razorenova/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/trinity/home/a.razorenova/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/trinity/shared/opt/python/3.8.5/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x1554c89ab550>
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x1554c89ab550>Traceback (most recent call last):

  File "/trinity/home/a.razorenova/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x1554c89ab550>
Traceback (most recent call last):
  File "/trinity/home/a.razorenova/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/trinity/home/a.razorenova/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/trinity/shared/opt/python/3.8.5/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x1554c89ab550>
Traceback (most recent call last):
  File "/trinity/home/a.razorenova/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/trinity/home/a.razorenova/.local/lib/python3.8/si

Validation: 0it [00:00, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x1554c89ab550>
Traceback (most recent call last):
  File "/trinity/home/a.razorenova/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/trinity/home/a.razorenova/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/trinity/shared/opt/python/3.8.5/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x1554c89ab550>
Traceback (most recent call last):
  File "/trinity/home/a.razorenova/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
Exception ignored in:     self._shutdown_workers()
Exception ignored in:   File "/trini

Validation: 0it [00:00, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x1554c89ab550>
Traceback (most recent call last):
  File "/trinity/home/a.razorenova/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
Exception ignored in:     self._shutdown_workers()<function _MultiProcessingDataLoaderIter.__del__ at 0x1554c89ab550>

Exception ignored in: Traceback (most recent call last):
  File "/trinity/home/a.razorenova/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
<function _MultiProcessingDataLoaderIter.__del__ at 0x1554c89ab550>  File "/trinity/home/a.razorenova/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__

    Traceback (most recent call last):
    self._shutdown_workers()  File "/trinity/home/a.razorenova/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    if w.is_alive():
  File "/trinity/shared/opt/pytho

Validation: 0it [00:00, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x1554c89ab550>
Traceback (most recent call last):
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x1554c89ab550>
Traceback (most recent call last):
  File "/trinity/home/a.razorenova/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
  File "/trinity/home/a.razorenova/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()    
  File "/trinity/home/a.razorenova/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
self._shutdown_workers()
      File "/trinity/home/a.razorenova/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
if w.is_alive():    
  File "/trinity/shared/opt/python/3.8.5/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
if w.is_alive():
Exception ignored in:   

Validation: 0it [00:00, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x1554c89ab550>
Traceback (most recent call last):
  File "/trinity/home/a.razorenova/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x1554c89ab550>  File "/trinity/home/a.razorenova/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
Exception ignored in:     if w.is_alive():
Traceback (most recent call last):
<function _MultiProcessingDataLoaderIter.__del__ at 0x1554c89ab550>

  File "/trinity/shared/opt/python/3.8.5/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
Traceback (most recent call last):
      File "/trinity/home/a.razorenova/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    assert self._parent_pid == os.getpid(), 'can only test a child process'self.

Validation: 0it [00:00, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x1554c89ab550>
Traceback (most recent call last):
Exception ignored in:   File "/trinity/home/a.razorenova/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
<function _MultiProcessingDataLoaderIter.__del__ at 0x1554c89ab550>
Traceback (most recent call last):
    self._shutdown_workers()  File "/trinity/home/a.razorenova/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x1554c89ab550>    

self._shutdown_workers()Traceback (most recent call last):

  File "/trinity/home/a.razorenova/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
  File "/trinity/home/a.razorenova/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
  File "/trinity/home/a.razorenova/.local/lib/p

  File "/trinity/home/a.razorenova/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
can only test a child process
        self._shutdown_workers()self._shutdown_workers()
Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x1554c89ab550>
  File "/trinity/home/a.razorenova/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
Traceback (most recent call last):
  File "/trinity/home/a.razorenova/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
  File "/trinity/home/a.razorenova/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
          File "/trinity/shared/opt/python/3.8.5/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
Exception ignored in:     self._shutdown_workers()
  File "/trinity/home/a.razorenova/.local/lib/python3.8/site-packages/torch/utils

Validation: 0it [00:00, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x1554c89ab550>
<function _MultiProcessingDataLoaderIter.__del__ at 0x1554c89ab550>Traceback (most recent call last):
Exception ignored in: 
  File "/trinity/home/a.razorenova/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
<function _MultiProcessingDataLoaderIter.__del__ at 0x1554c89ab550>
    Traceback (most recent call last):
self._shutdown_workers()  File "/trinity/home/a.razorenova/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__

  File "/trinity/home/a.razorenova/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    Traceback (most recent call last):
self._shutdown_workers()
  File "/trinity/home/a.razorenova/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    Exception ignored in:   File "/trinity/home/a.razoren

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [None]:
# Save only the inner UNET weights; these keys have no "model." prefix, unlike the Lightning checkpoints.
# NOTE(review): the file name says 20 epochs, but the trainer above uses max_epochs=50 — confirm the name.
torch.save(some_model.model.state_dict(), '3d_mse_20epoch.pkl')

In [8]:
import os
from collections import OrderedDict

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # fall back to CPU without a GPU

def load_weights(model, checkpoint_path, prefix='model.'):
    """Load a Lightning checkpoint's weights into a bare model and switch it to eval mode.

    A module wrapped as ``self.model`` inside a LightningModule (as ``SomeModel``
    does) stores its keys as ``model.<name>``. That prefix is stripped before
    loading.

    Args:
        model: module to populate; it is modified in place.
        checkpoint_path: path to a checkpoint file with a ``state_dict`` entry.
        prefix: key prefix to strip (default ``'model.'``). Keys without it
            are kept as-is.

    Returns:
        The same ``model``, in eval mode.

    Note:
        ``torch.load`` unpickles the file, so only load trusted checkpoints.
    """
    # Load onto CPU so GPU-saved checkpoints also open on CPU-only machines;
    # load_state_dict copies the tensors into the model's own parameters.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')['state_dict']
    state_d = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in checkpoint.items()
    )

    model.load_state_dict(state_d)
    print('\nLoaded:\n', checkpoint_path)

    model.eval()

    return model
    
    
def read_numpy(path):
    """Load an array from a ``.npy`` file, or the ``arr_0`` entry of a ``.npz`` archive.

    A path whose last character is ``"z"`` is treated as an ``.npz`` archive.
    The archive is closed before returning, so no file handle is left open.
    """
    if path[-1] == "z":
        with np.load(path) as archive:
            return archive["arr_0"]
    return np.load(path)

print(device)

# Trained model: restore the epoch-49 Lightning checkpoint into a fresh UNET
# (load_weights returns the model in eval mode), then move it to `device`.
rec_model = load_weights(UNET(), 'checkpoints/3d_unet-epoch=49.ckpt').to(device)

cuda
cuda
\Loaded:
 checkpoints/3d_unet-epoch=49.ckpt


In [9]:


# Reconstruct a 3D label volume for every scan in `filepath` with `rec_model`.
# Each result is saved as a uint8 array (0 = background, 1 / 2 = output
# channel 0 / 1) under rec3d_struct/<model><ckpt>/.
model = "new_half_3d"
ckpt = "_ep{}".format(50)

save_path = "rec3d_struct/" + model + ckpt
os.makedirs(save_path, exist_ok=True)  # also creates rec3d_struct/
save_file = "/rec3d_{}"

# data to rec
filepath = "compare_data/test4k/scans/"
file_list = sorted(os.listdir(filepath))  # deterministic processing order

with torch.no_grad():  # inference only: do not build an autograd graph
    for filename in file_list:
        print(filename)
        name = filename.split("_")[-1]  # e.g. "00222.npy"

        scan = torch.from_numpy(read_numpy(filepath + filename)).float()
        scan = torch.permute(scan, (2, 0, 1)).unsqueeze(0)  # (1, C, H, W)

        # Pad with 15 all-zero scans up to the training batch size of 16,
        # which UNET.forward has assumed. Only row 0 is used below.
        scan = torch.cat((scan, torch.zeros(15, 31, 128, 128)), 0).to(device)

        rec = rec_model(scan)[0].cpu().numpy()  # (2, depth, H, W) channel scores

        # Background where both channel scores are below 0.5; otherwise the
        # argmax channel maps to label 1 (channel 0) or 2 (channel 1).
        zero_mask = np.logical_and(rec[0] < 0.5, rec[1] < 0.5)
        rec_integ = np.argmax(rec, axis=0) + 1
        rec_integ[zero_mask] = 0

        np.save(save_path + save_file.format(name), rec_integ.astype(np.uint8))

scan_00222.npy
scan_01001.npy
scan_00999.npy
scan_00444.npy
scan_00888.npy
scan_00111.npy
scan_01111.npy
scan_01361.npy
scan_00555.npy
scan_01497.npy
scan_151193.npy
scan_00023.npy
scan_00665.npy
scan_00777.npy
scan_03849.npy
scan_00581.npy
scan_03333.npy
scan_00333.npy
scan_02459.npy
scan_03533.npy
scan_02613.npy
scan_04496.npy
scan_04444.npy
scan_00007.npy
scan_02136.npy
scan_02222.npy
