In [1]:
from unet import unet_constructor as GUnet
import dataloader as dataloader
from loss import dice_loss, cross_entropy_loss, random_cross_entropy
import transforms as t
import torch.nn.functional as F
import torch
import torch.nn as nn
import torchvision.transforms as tt
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import matplotlib.pyplot as plt
import pickle
import skimage.io as io
import os

# TensorBoard logger used by the training loop below; with no log_dir given,
# SummaryWriter picks its own default run directory.
writer = SummaryWriter(log_dir=None)

In [2]:
# Training stacks: 300x300x19 random crops with geometric and intensity augmentation.
# NOTE(review): presumably joint_transforms are applied to image and masks together
# while image_transforms touch only the image — confirm in dataloader.stack.
train_joint_transforms = [
    t.to_float(),
    t.reshape(),
    t.nul_crop(),
    t.random_crop([300, 300, 19]),
    t.random_rotate(),
    t.random_affine(),
]
train_image_transforms = [
    t.drop_channel(1),
    t.random_gamma((.8, 1.2)),
    t.spekle(),
    t.normalize([0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5]),
]
data = dataloader.stack(path='./Data/train',
                        joint_transforms=train_joint_transforms,
                        image_transforms=train_image_transforms)

# Validation stacks: 512x512x19 random crops, random rotation, normalisation only.
# NOTE(review): this reads the same './Data/train' directory as `data`, so it is not
# held-out data. It also skips nul_crop() and drop_channel(1), which the training
# pipeline applies — confirm both differences are intended.
val_joint_transforms = [
    t.to_float(),
    t.reshape(),
    t.random_crop([512, 512, 19]),
    t.random_rotate(90),
]
val_image_transforms = [
    t.normalize([0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5]),
]
val_data = dataloader.stack(path='./Data/train',
                            joint_transforms=val_joint_transforms,
                            image_transforms=val_image_transforms)



In [3]:
# Build the grouped 3-D U-Net (4 input channels -> 1 output channel) on the GPU and
# push one training stack through it as a quick check of shapes and the loss.
device = 'cuda:0'

test = GUnet(image_dimensions=3,
             in_channels=4,
             out_channels=1,
             feature_sizes=[16, 32, 64, 128],
             kernel={'conv1': (3, 3, 2), 'conv2': (3, 3, 1)},
             upsample_kernel=(8, 8, 2),
             max_pool_kernel=(2, 2, 1),
             upsample_stride=(2, 2, 1),
             dilation=1,
             groups=2).to(device)

test = test.type(torch.float)

image, mask, pwl = data[0]

# Call the module itself (not .forward) so any registered hooks run. This output is
# only inspected, never back-propagated, so don't build an autograd graph for it.
with torch.no_grad():
    out = test(image.float().to(device))
    out_loss = dice_loss(out, mask.to(device))



In [4]:
# Restore saved weights, then move the model to the GPU and put it in training mode.
# .cuda() and .train() both return the module, so the chained call leaves the same
# module repr as the cell's displayed output.
test.load('Apr25_chris-MS-7C37_1.unet')
test.cuda().train()

unet_constructor(
  (out_conv): Conv3d(16, 1, kernel_size=(1, 1, 1), stride=(1, 1, 1))
  (down_steps): ModuleList(
    (0): Down(
      (conv1): Conv3d(4, 16, kernel_size=(3, 3, 2), stride=(1, 1, 1), groups=2)
      (conv2): Conv3d(16, 16, kernel_size=(3, 3, 1), stride=(1, 1, 1), groups=2)
      (batch1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (batch2): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
    )
    (1): Down(
      (conv1): Conv3d(16, 32, kernel_size=(3, 3, 2), stride=(1, 1, 1), groups=2)
      (conv2): Conv3d(32, 32, kernel_size=(3, 3, 1), stride=(1, 1, 1), groups=2)
      (batch1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (batch2): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
    )
    (2): Down(
      (conv1): Conv3d(32, 64, kernel_size=(3, 3, 2), stride=(1, 1, 1), groups=2

In [5]:
# Training counters and loss histories. They live in their own cell so the training
# cell can be interrupted and re-run without resetting the epoch count.
# (`k` is not referenced anywhere else in this notebook.)
epoch = k = running_loss = 0
device = 'cuda:0'

losses_dice, avg_losses_dice = [], []
losses_BCE, avg_losses_BCE = [], []



In [6]:
# Adam with an exponential LR decay: each scheduler_1.step() multiplies lr by gamma.
lr, gamma = .01, .75

optimizer = torch.optim.Adam(test.parameters(), lr=lr)

scheduler_1 = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)
# NOTE(review): scheduler_2 is never stepped anywhere in this notebook.
scheduler_2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)


In [None]:
# Train on `data` with the joint loss random_cross_entropy + dice_loss, one epoch per
# pass over every stack. Runs until interrupted, or until `epoch` reaches MAX_EPOCHS
# if that is set. `epoch` and `running_loss` are globals, so re-running resumes.
MAX_EPOCHS = None     # int total epoch count to stop automatically; None = run until interrupted
LR_DECAY_EVERY = 150  # epochs between scheduler_1 (ExponentialLR) steps

test.train()  # an inference cell may have switched the model to eval mode

while MAX_EPOCHS is None or epoch < MAX_EPOCHS:
    epoch_loss_dice = []  # unused by the joint loss; kept so later inspection cells still resolve
    epoch_loss_BCE = []   # likewise unused
    epoch_loss_joint = []
    for i in range(len(data)):
        image, mask, pwl = data[i]
        mask_dev = mask.to(device)

        optimizer.zero_grad()
        out = test(image.float().to(device))

        # NOTE(review): the meaning of the positional 5 and 5000 is defined by
        # loss.random_cross_entropy — confirm there before tuning.
        out_loss = random_cross_entropy(out, mask_dev, pwl.to(device), 5, 5000) + \
                   dice_loss(out, mask_dev)
        loss_value = out_loss.item()  # one host sync per step
        epoch_loss_joint.append(loss_value)

        # POL = previous step's loss, OL = this step's loss.
        print(f'\r\033[32mEpoch: \033[0m{epoch} |'
              f' JOINT |'
              f' {i} |'
              f' \033[35mPOL: \033[0m{str(running_loss)[0:8]} -->'
              f' \033[35mOL: \033[0m{str(loss_value)[0:8]}', end='')

        out_loss.backward()
        optimizer.step()

        running_loss = loss_value

    # Decay the LR once every LR_DECAY_EVERY epochs. The parentheses matter:
    # `epoch+1 % 150` parses as `epoch + (1 % 150)`, which is never 0.
    if (epoch + 1) % LR_DECAY_EVERY == 0:
        scheduler_1.step()

    if epoch_loss_joint:  # avoid a 0/0 mean when `data` is empty
        writer.add_scalar('Joint Loss Train', np.mean(epoch_loss_joint), epoch)
    writer.add_scalar('Learning Rate', optimizer.param_groups[0]['lr'], epoch)
    epoch += 1




[32mEpoch: [0m206 | JOINT | 0 | [35mPOL: [0m0.642850 --> [35mOL: [0m0.582397

In [None]:
# Per-sample joint losses from the most recent (possibly interrupted) epoch.
# `epoch_loss_dice` is never appended to by the joint-loss training loop, so it is always [].
epoch_loss_joint

In [None]:
# Persist the current model via unet_constructor.save.
# NOTE(review): this is the same filename the checkpoint was loaded from earlier, so it
# presumably overwrites the previous weights — use a new name to keep them.
test.save('Apr25_chris-MS-7C37_1.unet')

In [None]:
# `avg_losses_dice` is only appended to in commented-out code, so plotting it shows
# nothing; plot the per-sample joint losses of the latest epoch instead.
fig, ax = plt.subplots()
ax.semilogy(epoch_loss_joint)
ax.set(xlabel='Sample index within epoch',
       ylabel='Joint loss (log scale)',
       title='Per-sample joint loss, latest epoch');

In [None]:
# Predict on one validation stack without changing the model.
# - eval(): the BatchNorm3d layers track running statistics. In train mode every
#   forward pass would update them using this validation stack.
# - no_grad(): skips building an autograd graph that nothing back-propagates through.
#   (A bare `torch.no_grad()` statement, not used as a context manager, does nothing.)
# The previous train/eval mode is restored so the training cell behaves as before.
# NOTE(review): eval() was commented out originally — if eval-mode output looks worse,
# the BatchNorm running statistics may be poorly estimated.
image, mask, pwl = val_data[8]
was_training = test.training
test.eval()
try:
    with torch.no_grad():
        out = test(image.float().cuda())
finally:
    test.train(was_training)


In [None]:
# Threshold the network output at probability 0.5 and show slices along the last axis.
# torch.sigmoid replaces the deprecated F.sigmoid.
pred = torch.sigmoid(out) > .5

pred_slices = pred[0, 0].detach().cpu().numpy()
n_slices = min(14, pred_slices.shape[-1])  # first 14 slices, but never past the stack depth
for z in range(n_slices):
    fig, ax = plt.subplots()
    ax.imshow(pred_slices[:, :, z], cmap='Greys_r')
    ax.set_title(f'Predicted mask, slice {z}')
    plt.show()

In [None]:
# Smallest value of the boolean mask: True only if every voxel was predicted foreground.
torch.min(pred)

In [None]:
# Export the prediction as a float32 TIFF stack. Drop the batch axis, then swap the
# first remaining axis with the last one so slices come first in the saved array.
pred_volume = pred.cpu().squeeze(0).transpose(0, 3).float().numpy()
io.imsave('best_yet.tif', pred_volume)

In [None]:
# `n` is not defined anywhere in this notebook (NameError on a fresh kernel);
# show the shape of the prediction that was just exported instead.
pred.shape