In [1]:
import numpy as np
from tqdm import tqdm
from echo_lv.data import LV_CAMUS_Dataset, LV_EKB_Dataset
from echo_lv.metrics import dice as dice_np
import torch
import torchvision
from torch.utils.data import DataLoader
from torch import sigmoid
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt
from torchsummary import summary
from echo_lv.utils import AverageMeter
import segmentation_models_pytorch as smp
from com_unet import UNet

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

random_state = 17
torch.manual_seed(random_state)
torch.cuda.manual_seed(random_state)
torch.backends.cudnn.deterministic = True

cuda:0


In [2]:
model = UNet(n_channels = 1, n_classes = 1, bilinear=False, batchnorm=False).to(device)

In [3]:
summary(model, (1, 572, 572), device='cuda')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 570, 570]             640
              ReLU-2         [-1, 64, 570, 570]               0
            Conv2d-3         [-1, 64, 568, 568]          36,928
              ReLU-4         [-1, 64, 568, 568]               0
        DoubleConv-5         [-1, 64, 568, 568]               0
         MaxPool2d-6         [-1, 64, 284, 284]               0
            Conv2d-7        [-1, 128, 282, 282]          73,856
              ReLU-8        [-1, 128, 282, 282]               0
            Conv2d-9        [-1, 128, 280, 280]         147,584
             ReLU-10        [-1, 128, 280, 280]               0
       DoubleConv-11        [-1, 128, 280, 280]               0
             Down-12        [-1, 128, 280, 280]               0
        MaxPool2d-13        [-1, 128, 140, 140]               0
           Conv2d-14        [-1, 256, 1

In [4]:
folds = 10
lv_camus = LV_CAMUS_Dataset(img_size = (388,388), classes = {0, 1}, folds=folds)
# lv_camus_valid = LV_CAMUS_Dataset(img_size = (572,572), classes = {0, 1}, folds=10, subset='valid')
lv_camus.set_state('train', 0)
train_loader = DataLoader(lv_camus, batch_size=1, shuffle=True, num_workers=4)

In [7]:
lv_camus.set_state(subset='valid', fold=9)
len(lv_camus)

144

In [8]:
lv_camus.set_state(subset='train', fold=9)
len(lv_camus)

1656

In [5]:

epochs = 50

# criterion = torch.nn.functional.binary_cross_entropy_with_logits #DiceLoss_1()
# criterion = smp.utils.losses.DiceLoss(activation='sigmoid')# + smp.utils.losses.BCEWithLogitsLoss()
weight = 10 * torch.ones((1,1,388,388), device=device)
criterion = smp.utils.losses.BCEWithLogitsLoss(pos_weight=weight)
# criterion = smp.utils.losses.DiceLoss(activation='sigmoid')
dice = smp.utils.metrics.Fscore(activation='sigmoid', threshold=0.5)#Dice()
iou = smp.utils.metrics.IoU(activation='sigmoid', threshold=0.5)
result = []

for fold in range(1):
    
    model = UNet(n_channels = 1, n_classes = 1, bilinear=False).to(device)
    
    optimizer = torch.optim.SGD([
            {'params': model.parameters(), 'lr': 1e-3},   
        ])
    
     # создаем даталодер
    
    
    
    val_dice = 0
    
    t = tqdm(total=epochs, 
            bar_format='{desc} | {postfix[0]}/'+ str(epochs) +' | ' +
            '{postfix[1]} : {postfix[2]:>2.4f} | {postfix[3]} : {postfix[4]:>2.4f} | ' +
            '{postfix[5]} : {postfix[6]:>2.4f} | {postfix[7]} : {postfix[8]:>2.4f}',
            postfix=[0, 'loss', 0, 'dice_lv', 0,  
                     'val_loss', 0, 'val_dice_lv', val_dice, 0], 
            desc = 'Train ' + 'unet' +   ' on fold ' + str(fold),
            position=0, leave=True
         )
    
#         t.reset()
            
        
        
    for epoch in range(0, epochs):
        average_total_loss = AverageMeter()
        average_dice = AverageMeter()
        torch.cuda.empty_cache()
        model.train()

        t.postfix[0] = epoch + 1
        
        lv_camus.set_state('train', fold)
        train_loader = DataLoader(lv_camus, batch_size=2, shuffle=True, num_workers=4)
        for data in train_loader:

            torch.cuda.empty_cache()
            inputs, masks, *_ = data
            shape = inputs.shape
            inputs = torch.cat([torch.zeros((shape[0], shape[1], shape[2], 92), dtype=float), inputs, torch.zeros((shape[0], shape[1], shape[2], 92), dtype=float)], axis=3)
            shape = inputs.shape
            inputs = torch.cat([torch.zeros((shape[0], shape[1], 92, shape[3]), dtype=float), inputs, torch.zeros((shape[0], shape[1], 92, shape[3]), dtype=float)], axis=2)
            inputs=inputs.to(device).float()
#             masks=masks[:,:,92:480,92:480].to(device).float()
            masks=masks.to(device).float()

            optimizer.zero_grad()
            outputs = model(inputs)

            loss = criterion(outputs, masks)

            average_total_loss.update(loss.data.item())

            average_dice.update(dice(outputs, masks).item())

            loss.backward()
            optimizer.step()

            t.postfix[2] = average_total_loss.average()
            t.postfix[4] = average_dice.average()
            t.update(n=1)

         # validation
        average_val_total_loss = AverageMeter()
        average_val_dice = AverageMeter()
        model.eval()
        
        lv_camus.set_state('valid', fold)
        valid_loader = DataLoader(lv_camus, batch_size=2, shuffle=True, num_workers=2)
        for data in valid_loader:
            inputs, masks, *_ = data

            shape = inputs.shape
            inputs = torch.cat([torch.zeros((shape[0], shape[1], shape[2], 92), dtype=float), inputs, torch.zeros((shape[0], shape[1], shape[2], 92), dtype=float)], axis=3)
            shape = inputs.shape
            inputs = torch.cat([torch.zeros((shape[0], shape[1], 92, shape[3]), dtype=float), inputs, torch.zeros((shape[0], shape[1], 92, shape[3]), dtype=float)], axis=2)
            inputs=inputs.to(device).float()
#             masks=masks[:,:,92:480,92:480].to(device).float()
            masks=masks.to(device).float()

            outputs = model(inputs)

            loss = criterion(outputs, masks)

            average_val_total_loss.update(loss.data.item())

            average_val_dice.update(dice(outputs, masks).item())

            t.postfix[6] = average_val_total_loss.average()
            t.postfix[8] = average_val_dice.average()
            t.update(n=0)

        val_loss = average_val_total_loss.average()
        val_dice = average_val_dice.average()

            
#             if val_dice_0 < 2e-2 or val_dice_1 < 2e-2 or val_dice_2 < 2e-2 or val_dice_3 < 2e-2:
#                 break
                
#             if epoch > 5 and val_loss > 0.15:
#                 val_dice_0 = 0
#                 val_dice_1 = 0
#                 val_dice_2 = 0
#                 val_dice_3 = 0
#                 break
                    
    t.close()
#     result.append([average_total_loss.average(), average_dice[0].average(), average_dice[1].average(), average_dice[2].average(), average_dice[3].average(),
#                    average_val_total_loss.average(), average_val_dice[0].average(), average_val_dice[1].average(), average_val_dice[2].average(), average_val_dice[3].average()
#                   ])

# df = pd.DataFrame(np.array(result), columns=['loss', 'dice_bg', 'dice_lv', 'dice_ep', 'dice_la', 'val_loss', 'val_dice_bg', 'val_dice_lv', 'val_dice_ep', 'val_dice_la'])
# df.to_csv('result.csv')

Train unet on fold 0 | 50/50 | loss : 0.0577 | dice_lv : 0.9237 | val_loss : 0.1575 | val_dice_lv : 0.8810


In [10]:
torch.save(model.state_dict(), 'common_unet_wo_bn.pth')