Reference: https://github.com/xuebinqin/U-2-Net

**Pre-requisites**


1.   !git clone https://github.com/xuebinqin/U-2-Net.git
2.   !mv ./U-2-Net ./drive/MyDrive/



In [1]:
!ln -s /content/drive/MyDrive/OkraInsight/training_data ./
!ln -s /content/drive/MyDrive/Colab\ Notebooks/Mobile-Unet/utils ./
!ln -s ./drive/MyDrive/U-2-Net ./U2Net
!ln -s /content/drive/MyDrive/OkraInsight/.ipynb_checkpoints ./

ln: failed to create symbolic link './training_data': File exists
ln: failed to create symbolic link './utils': File exists


In [23]:
import os
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.optim as optim
import torchvision.transforms as standard_transforms

from utils.data_loading import BasicDataset, OkraDataset, OkraAugmentedDataset
from utils.dice_score import dice_loss

import numpy as np
import glob
import os

In [24]:
# ------- 1. define loss function --------

bce_loss = nn.BCELoss(size_average=True)

def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):

	loss0 = bce_loss(d0,labels_v)
	loss1 = bce_loss(d1,labels_v)
	loss2 = bce_loss(d2,labels_v)
	loss3 = bce_loss(d3,labels_v)
	loss4 = bce_loss(d4,labels_v)
	loss5 = bce_loss(d5,labels_v)
	loss6 = bce_loss(d6,labels_v)

	loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
	print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n"%(loss0.data.item(),loss1.data.item(),loss2.data.item(),loss3.data.item(),loss4.data.item(),loss5.data.item(),loss6.data.item()))

	return loss0, loss

In [25]:
# ------- 2. set the directory of training dataset --------

model_name = 'u2net' #'u2netp'



# data_dir = os.path.join(os.getcwd(), 'training_data' + os.sep)

# tra_image_dir = os.path.join('DUTS', 'DUTS-TR', 'DUTS-TR', 'im_aug' + os.sep)
# tra_label_dir = os.path.join('DUTS', 'DUTS-TR', 'DUTS-TR', 'gt_aug' + os.sep)

img_folder = './training_data/okra_images/class_okra_surface'
mask_folder ='./training_data/okra_segmentation_target_masks'

image_ext = '.jpeg'
label_ext = '.segmask'

model_dir = os.path.join(os.getcwd(), 'saved_models', model_name + os.sep)

epoch_num = 100
batch_size_train = 3
batch_size_val = 1
train_num = 0
val_num = 0


In [None]:
normalize = torchvision.transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
img_scale = 0.1

dataset_l2rUp = OkraDataset(img_folder, mask_folder, normalize, img_scale, 0)
dataset_r2lUp = OkraDataset(img_folder, mask_folder, normalize, img_scale, 1)
dataset_l2rDn = OkraDataset(img_folder, mask_folder, normalize, img_scale, 2)
dataset_r2lDn = OkraDataset(img_folder, mask_folder, normalize, img_scale, 3)

dataset = OkraAugmentedDataset([dataset_l2rUp, dataset_r2lUp, dataset_l2rDn, dataset_r2lDn])

# dataset = OkraDataset(img_folder, mask_folder, normalize, img_scale, 0)

In [27]:
    # 2. Split into train / validation partitions
from torch.utils.data import DataLoader, random_split

val_percent = 0.1
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

n_train = len(train_set)
n_val = len(val_set)

# train_sampler = torch.utils.data.RandomSampler(dataset)

if torch.cuda.is_available():
  device = torch.device('cuda')
  nworkers = 1
else:
  device = torch.device('cpu')
  nworkers = os.cpu_count()


batch_size = batch_size_train
loader_args = dict(batch_size=batch_size, num_workers=os.cpu_count(), pin_memory=True)
train_loader = DataLoader(train_set, shuffle=True, drop_last=True, **loader_args)


#data_loader = torch.utils.data.DataLoader(
#  train_set,
#  batch_size=batch_size_train,
#  sampler=train_sampler,
#  num_workers=nworkers,
#  drop_last=True,
#)



In [None]:
# ------- 3. define model --------
from U2Net.model.u2net import U2NET, U2NETP

# define the net
if(model_name=='u2net'):
    net = U2NET(3, 1)
elif(model_name=='u2netp'):
    net = U2NETP(3,1)

if torch.cuda.is_available():
    net.cuda()

# ------- 4. define optimizer --------
print("---define optimizer...")
optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

In [None]:
        # forward + backward + optimize

# ------- 5. training process --------
print("---start training...")
ite_num = 0
running_loss = 0.0
running_tar_loss = 0.0
ite_num4val = 0
save_frq = 2000 # save the model every 2000 iterations


i = 0
for epoch in range (0, epoch_num):
  net.train()
  ite_num = ite_num + 1
  ite_num4val = ite_num4val + 1

  for image_target in train_loader:

    inputs, labels = image_target['image'].to(device), image_target['mask'].to(device)

    inputs = inputs.type(torch.FloatTensor)
    labels = labels.type(torch.FloatTensor)

    # wrap them in Variable
    if torch.cuda.is_available():
      inputs_v, labels_v = Variable(inputs.cuda(), requires_grad=False), Variable(labels.cuda(),                                                                                    requires_grad=False)
    else:
      inputs_v, labels_v = Variable(inputs, requires_grad=False), Variable(labels, requires_grad=False)

    # y zero the parameter gradients
    optimizer.zero_grad()

    d0, d1, d2, d3, d4, d5, d6 = net(inputs_v)
    d0 = d0.squeeze(1)
    d1 = d0.squeeze(1)
    d2 = d0.squeeze(1)
    d3 = d0.squeeze(1)
    d4 = d0.squeeze(1)
    d5 = d0.squeeze(1)
    d6 = d6.squeeze(1)

    loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v)

    loss.backward()
    optimizer.step()

        # # print statistics
    running_loss += loss.data.item()
    running_tar_loss += loss2.data.item()

        # del temporary outputs and loss
    del d0, d1, d2, d3, d4, d5, d6, loss2, loss

    print("[epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f " % (
    epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val))

    if ite_num % save_frq == 0:
      torch.save(net.state_dict(), model_dir + model_name+"_bce_itr_%d_train_%3f_tar_%3f.pth" % (ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val))
      running_loss = 0.0
      running_tar_loss = 0.0
      net.train()  # resume train
      ite_num4val = 0
    i += 1

    if ((epoch %100) == 0 and epoch > 0):
      EPOCH = 5
      PATH = "okra_u2net_{epoch}.pth"
      LOSS = running_loss

      torch.save({
        'epoch': EPOCH,
        'model_state_dict': net.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': LOSS,
      }, PATH)

net.to(torch.device('cpu'))
traceable_net = net.eval()

scriptedm = torch.jit.script(traceable_net)

scriptedm.save(".ipynb_checkpoints/okra_u2net.pt")