In [1]:
import torch
import torch.nn as nn
from torch.nn import init
import functools
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
from model import ModelClass

In [None]:
def dice_loss(pred, target):
    """Soft Dice loss for real-valued predictions and targets.

    Both tensors are flattened to 1-D, so the loss is computed over every
    element of the batch at once (not per sample). All operations are
    differentiable, so the result can be back-propagated through `pred`.

    pred: tensor with first dimension as batch
    target: tensor with first dimension as batch; must hold the same number
        of elements as `pred` (the two are multiplied element-wise after
        flattening)

    Returns a scalar tensor: 1 - (2*sum(p*t) + s) / (sum(p*p) + sum(t*t) + s)
    with s = 1 as the smoothing term.
    """

    # Smoothing term: keeps the ratio defined when both tensors are all zeros.
    smooth = 1.

    # have to use contiguous since the inputs may come from a torch.view /
    # permute op and .view(-1) requires contiguous memory
    iflat = pred.contiguous().view(-1)
    tflat = target.contiguous().view(-1)
    intersection = (iflat * tflat).sum()

    # Squared mass of each tensor. The prediction term must be pred*pred;
    # using pred*target here would merely repeat the intersection and let an
    # over-covering prediction reach a loss of 0 (or below).
    A_sum = torch.sum(iflat * iflat)
    B_sum = torch.sum(tflat * tflat)

    return 1 - ((2. * intersection + smooth) / (A_sum + B_sum + smooth))

In [2]:
EPS = 1e-6
# slightly modified
def get_IoU(outputs, labels):
    """Mean intersection-over-union between two batches of masks.

    Both arguments are cast to int32 first, so fractional values are
    truncated toward zero (e.g. 0.9 becomes 0) before the bitwise
    AND / OR are applied.

    outputs: tensor whose dims 1 and 2 are the ones summed over
    labels: tensor of the same layout as `outputs`

    Returns a scalar tensor: the mean of (intersection + EPS) / (union + EPS)
    over whatever dimensions remain after summing dims 1 and 2.
    """
    # Adapted from:
    # https://www.kaggle.com/iezepov/fast-iou-scoring-metric-in-pytorch-and-numpy
    # (the thresholded variant from that kernel is intentionally not applied
    # here; the plain IoU is averaged instead).
    pred_bits = outputs.int()
    true_bits = labels.int()

    # Overlap is zero wherever either mask is zero.
    overlap = torch.sum((pred_bits & true_bits).float(), dim=(1, 2))
    # Combined area is zero only where both masks are zero.
    combined = torch.sum((pred_bits | true_bits).float(), dim=(1, 2))

    # EPS on both sides turns the 0/0 case (two empty masks) into 1.
    per_item_iou = (overlap + EPS) / (combined + EPS)
    return per_item_iou.mean()

In [None]:
## Training model

import torch
import torch.nn as nn
from torch.nn import init
import functools
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
from model import ModelClass


# ---------------------------------------------------------------------------
# Data loading and preprocessing
# ---------------------------------------------------------------------------
X = np.load("/content/drive/MyDrive/Images.npy")
Y = np.load("/content/drive/MyDrive/Masks.npy")
X = X.astype('float32')

# Standardize the images with a single global mean / std.
# NOTE(review): the statistics are computed over the whole array, i.e. they
# also include the samples that are held out as X_test further down.
mean = np.mean(X)  # mean for data centering
std = np.std(X)  # std for data normalization

X -= mean
X /= std

# Binarize the masks: values strictly greater than 1 become 1, all others 0.
# NOTE(review): presumably the masks are stored as 0/255 - confirm.
Y = np.where(Y > 1, 1, 0)

# Fixed split: the first 1900 samples train, the remainder is held out.
X_train = X[:1900]
Y_train = Y[:1900]
X_test = X[1900:]
Y_test = Y[1900:]


# ---------------------------------------------------------------------------
# Model, optimizer and run configuration
# ---------------------------------------------------------------------------
model = ModelClass().cuda()

lr = 3.00000002e-03
# Not used by the loop below (which optimizes dice_loss); kept so the name
# stays available to anything else in the notebook that refers to it.
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.8, nesterov=True, weight_decay=0.0003)

is_train = True
is_pretrain = False

total_epoch = 25

# One checkpoint location for both loading and saving. The original loaded
# from ".../MyDrive/..." but saved to ".../My Drive/..."; the "MyDrive" form
# is used because the dataset above is read from that same root.
checkpoint_path = '/content/drive/MyDrive/Unet_resnet4.pkl'

# Epochs (1-based) after which the learning rate is divided by 10: roughly
# 30%, 60% and 90% of the run, rounded half-up with integer arithmetic.
# Comparing (epoch + 1) / total_epoch to 0.3 / 0.6 / 0.9 with float equality
# only matches when total_epoch * fraction is an exact integer, so with 25
# epochs just one of the three decays ever fired.
lr_decay_epochs = {(total_epoch * percent + 50) // 100 for percent in (30, 60, 90)}

# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
if is_train:

    if is_pretrain:
        model.load_state_dict(torch.load(checkpoint_path))

    for epoch in range(total_epoch):
        model.train()
        lossavg = 0
        tot = 0
        # One sample per optimization step (effective batch size of 1).
        for inp, msk in zip(X_train, Y_train):
            # Tensors track gradients directly; the torch.autograd.Variable
            # wrapper is a no-op in current PyTorch and is not needed.
            image = torch.from_numpy(inp).type(torch.float32).cuda()
            mask = torch.from_numpy(msk).type(torch.float32).cuda()

            # Move the last axis to the front and add a leading batch axis.
            # NOTE(review): assumes each sample is stored as (H, W, C) - confirm.
            image = image.permute(2, 0, 1).unsqueeze(0)
            mask = mask.permute(2, 0, 1).unsqueeze(0)

            # Forward + Backward + Optimize
            optimizer.zero_grad()
            outputs = model(image)
            # dice_loss is defined in an earlier cell of this notebook.
            loss = dice_loss(outputs, mask)
            loss.backward()
            optimizer.step()

            lossavg += loss.item()
            tot += 1

        aver = lossavg / tot
        print(" ")

        print("Epoch [%d/%d],  AVERAGELoss: %.4f" %(epoch+1,total_epoch,aver))
        # NOTE(review): no evaluation on X_test / Y_test actually follows
        # this message - TODO add one or drop the message.
        print('evaluate test set:')

        # Decaying Learning Rate
        if (epoch + 1) in lr_decay_epochs:
            lr /= 10
            print('reset learning rate to:', lr)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
                print(param_group['lr'])

    # Save the Model
    torch.save(model.state_dict(), checkpoint_path)
