<a href="https://colab.research.google.com/github/thotakuria/surgical-tool-segmentation-using-deep-learning-techniques/blob/main/duo_segnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import numpy as np
import torchvision.transforms as transforms
# `datasets` was referenced below but never imported in this cell, so running
# it on a fresh kernel raised NameError.
from torchvision import datasets

data_transforms = transforms.Compose([
                    transforms.CenterCrop(224),
                    transforms.ToTensor()])
image_datasets = datasets.ImageFolder(root= "/content/gdrive/MyDrive/dataset", transform=data_transforms)
dataloaders = torch.utils.data.DataLoader(image_datasets, batch_size=32, shuffle=True, num_workers=2)

In [1]:
import os
from PIL import Image
import torch.utils.data as data
from torchvision.transforms import transforms


class ObjDataset(data.Dataset):
    """Image/mask pair dataset used for the training splits.

    Images are read as RGB, resized to ``trainsize`` x ``trainsize``,
    converted to tensors and normalized with the per-channel mean/std given in
    ``img_transform``. Masks are read as 8-bit grayscale ("L"), resized and
    converted to tensors without normalization.

    Args:
        images: image file paths.
        gts: mask file paths. Both lists are sorted and then paired by
            position, so each mask must sort to the same index as its image.
        trainsize: side length, in pixels, of the square outputs.
    """

    def __init__(self, images, gts, trainsize):
        self.trainsize = trainsize
        # Pairing is positional, so sort both lists the same way.
        self.images = sorted(images)
        self.gts = sorted(gts)
        self.filter_files()
        self.size = len(self.images)
        self.img_transform = transforms.Compose([
            transforms.Resize((self.trainsize, self.trainsize)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        # NOTE(review): no interpolation mode is passed to Resize for masks, so
        # resized masks may contain intermediate values at object edges rather
        # than staying strictly binary - confirm this is intended.
        self.gt_transform = transforms.Compose([
            transforms.Resize((self.trainsize, self.trainsize)),
            transforms.ToTensor()])

    def __getitem__(self, index):
        """Return the transformed ``(image, mask)`` tensors for sample ``index``."""
        image = self.rgb_loader(self.images[index])
        gt = self.binary_loader(self.gts[index])

        image = self.img_transform(image)
        gt = self.gt_transform(gt)

        return image, gt

    def filter_files(self):
        """Keep only image/mask pairs whose pixel dimensions match.

        Raises:
            AssertionError: if the image and mask lists differ in length.
        """
        assert len(self.images) == len(self.gts), \
            'got {} images but {} masks'.format(len(self.images), len(self.gts))
        images = []
        gts = []
        for img_path, gt_path in zip(self.images, self.gts):
            # Context managers close both files right away. Only the header is
            # needed for .size; bare Image.open calls could keep the handles
            # open until the objects happened to be garbage-collected.
            with Image.open(img_path) as img, Image.open(gt_path) as gt:
                sizes_match = img.size == gt.size
            if sizes_match:
                images.append(img_path)
                gts.append(gt_path)
        self.images = images
        self.gts = gts

    def rgb_loader(self, path):
        """Load ``path`` as a 3-channel RGB PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def binary_loader(self, path):
        """Load ``path`` as an 8-bit single-channel ("L") PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('L')

    def resize(self, img, gt):
        """Upscale ``img``/``gt`` so that neither side is below ``trainsize``.

        The image uses bilinear and the mask nearest-neighbour resampling.
        Pairs that are already large enough are returned unchanged. Not called
        by the rest of this class.
        """
        assert img.size == gt.size
        w, h = img.size
        if h < self.trainsize or w < self.trainsize:
            h = max(h, self.trainsize)
            w = max(w, self.trainsize)
            return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST)
        else:
            return img, gt

    def __len__(self):
        return self.size


class ValObjDataset(data.Dataset):
    """Image/mask pair dataset used for the validation split.

    Images are read as RGB and masks as 8-bit grayscale ("L"). Both are
    resized to ``trainsize`` x ``trainsize`` and converted to tensors.

    NOTE(review): unlike ``ObjDataset``, the images are NOT normalized here,
    although the networks are trained on normalized ``ObjDataset`` inputs.
    Validation inputs are therefore on a different scale than training
    inputs; confirm whether this is intended. The test cell converts these
    tensors straight to PIL images for visualization, so adding ``Normalize``
    here would also require a change there.

    Args:
        images: image file paths.
        gts: mask file paths, paired with ``images`` by position after both
            lists are sorted.
        trainsize: side length, in pixels, of the square outputs.
    """

    def __init__(self, images, gts, trainsize):
        self.trainsize = trainsize
        # Pairing is positional, so sort both lists the same way.
        self.images = sorted(images)
        self.gts = sorted(gts)
        self.filter_files()
        self.size = len(self.images)
        self.img_transform = transforms.Compose([
            transforms.Resize((self.trainsize, self.trainsize)),
            transforms.ToTensor()])
        self.gt_transform = transforms.Compose([
            transforms.Resize((self.trainsize, self.trainsize)),
            transforms.ToTensor()])

    def __getitem__(self, index):
        """Return the transformed ``(image, mask)`` tensors for sample ``index``."""
        image = self.rgb_loader(self.images[index])
        gt = self.binary_loader(self.gts[index])

        image = self.img_transform(image)
        gt = self.gt_transform(gt)

        return image, gt

    def filter_files(self):
        """Keep only image/mask pairs whose pixel dimensions match.

        Raises:
            AssertionError: if the image and mask lists differ in length.
        """
        assert len(self.images) == len(self.gts), \
            'got {} images but {} masks'.format(len(self.images), len(self.gts))
        images = []
        gts = []
        for img_path, gt_path in zip(self.images, self.gts):
            # Close both files immediately; only the header is needed for .size.
            with Image.open(img_path) as img, Image.open(gt_path) as gt:
                sizes_match = img.size == gt.size
            if sizes_match:
                images.append(img_path)
                gts.append(gt_path)
        self.images = images
        self.gts = gts

    def rgb_loader(self, path):
        """Load ``path`` as a 3-channel RGB PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def binary_loader(self, path):
        """Load ``path`` as an 8-bit single-channel ("L") PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('L')

    def resize(self, img, gt):
        """Upscale ``img``/``gt`` so that neither side is below ``trainsize``.

        The image uses bilinear and the mask nearest-neighbour resampling.
        Pairs that are already large enough are returned unchanged. Not called
        by the rest of this class.
        """
        assert img.size == gt.size
        w, h = img.size
        if h < self.trainsize or w < self.trainsize:
            h = max(h, self.trainsize)
            w = max(w, self.trainsize)
            return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST)
        else:
            return img, gt

    def __len__(self):
        return self.size


def _list_image_files(root):
    """Return the sorted paths of the ``.jpg``/``.png`` files directly inside ``root``."""
    return sorted(os.path.join(root, name) for name in os.listdir(root)
                  if name.endswith(('.jpg', '.png')))


def _split_at(seq, fraction):
    """Split ``seq`` into ``(seq[:k], seq[k:])`` with ``k = int(len(seq) * fraction)``."""
    cut = int(len(seq) * fraction)
    return seq[:cut], seq[cut:]


def _make_loader(dataset, batch_size, shuffle):
    """Wrap ``dataset`` in a DataLoader with the worker/pinning settings shared by every split."""
    return data.DataLoader(dataset=dataset,
                           batch_size=batch_size,
                           num_workers=1,
                           pin_memory=True,
                           shuffle=shuffle)


def image_loader(image_root, gt_root, batch_size, image_size, split=0.8, labeled_ratio=0.05):
    """Build the two labeled, one unlabeled and one validation DataLoader.

    Files are split positionally (after sorting):
      * the first ``split`` fraction is training data, the rest validation;
      * the first ``labeled_ratio`` of the training data is the labeled set,
        and it is halved between the two labeled loaders;
      * the remaining training data forms the unlabeled set.

    Args:
        image_root: directory containing the input images.
        gt_root: directory containing the masks. Masks must sort into the same
            order as their images (e.g. identical file names).
        batch_size: batch size for every loader.
        image_size: side length passed to the datasets as ``trainsize``.
        split: fraction of files used for training.
        labeled_ratio: fraction of the training files treated as labeled.

    Returns:
        ``(labeled_loader_1, labeled_loader_2, unlabeled_loader, val_loader)``.
        The training loaders shuffle; the validation loader does not.
    """
    # Sort before splitting. os.listdir order is arbitrary and can differ
    # between the two directories, so unsorted lists could put an image and
    # its mask into different splits (the datasets only sort after the split).
    images = _list_image_files(image_root)
    gts = _list_image_files(gt_root)

    # The train/val cut is taken from the image count and applied to both lists.
    n_train = int(len(images) * split)
    train_images, val_images = images[:n_train], images[n_train:]
    train_gts, val_gts = gts[:n_train], gts[n_train:]

    labeled_train_images, unlabeled_train_images = _split_at(train_images, labeled_ratio)
    labeled_train_gts, unlabeled_train_gts = _split_at(train_gts, labeled_ratio)
    labeled_train_images_1, labeled_train_images_2 = _split_at(labeled_train_images, 0.5)
    labeled_train_gts_1, labeled_train_gts_2 = _split_at(labeled_train_gts, 0.5)

    labeled_train_dataset_1 = ObjDataset(labeled_train_images_1, labeled_train_gts_1, image_size)
    labeled_train_dataset_2 = ObjDataset(labeled_train_images_2, labeled_train_gts_2, image_size)
    unlabeled_train_dataset = ObjDataset(unlabeled_train_images, unlabeled_train_gts, image_size)
    val_dataset = ValObjDataset(val_images, val_gts, image_size)

    labeled_data_loader_1 = _make_loader(labeled_train_dataset_1, batch_size, shuffle=True)
    labeled_data_loader_2 = _make_loader(labeled_train_dataset_2, batch_size, shuffle=True)
    unlabeled_data_loader = _make_loader(unlabeled_train_dataset, batch_size, shuffle=True)
    val_loader = _make_loader(val_dataset, batch_size, shuffle=False)

    return labeled_data_loader_1, labeled_data_loader_2, unlabeled_data_loader, val_loader

In [3]:
import numpy as np
import torch
from torch.autograd import Variable
# Shared binary cross-entropy criterion (mean over all elements) used by the
# loss helpers below.
CE = torch.nn.BCELoss(reduction='mean')
# Discriminator target values: predictions are labeled 0, ground truth 1.
pred_label, gt_label = 0, 1


def make_Dis_label(label, gts, device='cuda'):
    """Return a float32 tensor with the shape of ``gts``, filled with ``label``.

    Used as the BCE target for the discriminator outputs.

    Args:
        label: fill value (e.g. ``pred_label`` or ``gt_label``).
        gts: any object with a ``.shape`` (torch tensor or numpy array); only
            its shape is used.
        device: device to allocate on. Defaults to ``'cuda'``, matching the
            previous hard-coded ``.cuda()`` call.

    Returns:
        A float32 tensor of shape ``gts.shape`` on ``device``, not requiring grad.
    """
    # Allocate directly on the target device instead of building a float64
    # numpy array on the host, converting it and copying it across.
    return torch.full(tuple(gts.shape), float(label), dtype=torch.float32, device=device)


def calc_loss(pred, target, bce_weight=0.5):
    """Weighted sum of binary cross-entropy and soft Dice loss.

    Args:
        pred: probabilities in [0, 1] (e.g. a sigmoid output), same shape as ``target``.
        target: mask values in [0, 1].
        bce_weight: weight of the BCE term; the Dice term gets ``1 - bce_weight``.
            With the default 0.5 both terms are weighted equally, as before.

    Returns:
        Scalar tensor ``bce_weight * BCE + (1 - bce_weight) * (1 - Dice)``.
    """
    smooth = 1e-5
    # Same as torch.nn.BCELoss() with default (mean) reduction.
    bce = torch.nn.functional.binary_cross_entropy(pred, target)
    # Soft Dice with torch ops so it stays in the autograd graph. The numpy-based
    # dice_coef() returns a detached value, so that Dice term produced no gradient.
    flat_pred = pred.reshape(-1)
    flat_target = target.reshape(-1)
    intersection = (flat_pred * flat_target).sum()
    dice = (2. * intersection + smooth) / (flat_pred.sum() + flat_target.sum() + smooth)
    return bce * bce_weight + (1 - dice) * (1 - bce_weight)


def loss_sup(logit_S1, logit_S2, labels_S1, labels_S2):
    """Supervised loss: ``calc_loss`` of each network on its own labeled batch, summed.

    Args:
        logit_S1, logit_S2: outputs of network 1 / 2 passed to ``calc_loss``.
        labels_S1, labels_S2: the matching target masks.
    """
    pairs = ((logit_S1, labels_S1), (logit_S2, labels_S2))
    return sum(calc_loss(pred, gt) for pred, gt in pairs)


def loss_diff(u_prediction_1, u_prediction_2, batch_size):
    """Symmetric BCE between the two networks' predictions on unlabeled data.

    Each network's prediction is scored against the other network's
    prediction, which is detached so it acts as a fixed soft target.

    Args:
        u_prediction_1, u_prediction_2: probabilities in [0, 1], same shape.
        batch_size: divisor applied to the summed loss.

    Returns:
        Scalar tensor ``(BCE(p1, p2) + BCE(p2, p1)) / batch_size``.
    """
    # NOTE(review): the previous version called .item() on both terms, so it
    # returned a Python float and this term contributed no gradient to the
    # training objective. It now returns a graph-connected tensor. Confirm that
    # the intended sign/weight of this term matches the method, because
    # enabling its gradient changes training.
    bce = torch.nn.functional.binary_cross_entropy
    a = bce(u_prediction_1, u_prediction_2.detach())
    b = bce(u_prediction_2, u_prediction_1.detach())
    return (a + b) / batch_size


def loss_adversarial_1(D_fake_1, D_fake_2, labels_S1, labels_S2):
    """BCE of two critic output maps against all-``pred_label`` targets, summed.

    Args:
        D_fake_1, D_fake_2: critic outputs for the two networks' predictions.
        labels_S1, labels_S2: tensors whose shapes size the target maps.
    """
    fake_target_1 = make_Dis_label(pred_label, labels_S1)
    fake_target_2 = make_Dis_label(pred_label, labels_S2)
    return CE(D_fake_1, fake_target_1) + CE(D_fake_2, fake_target_2)


def loss_adversarial_2(D_fake_1, D_real_1, D_fake_2, D_real_2, labels_S1, labels_S2):
    """Critic loss: fake maps scored against ``pred_label``, real maps against ``gt_label``.

    Args:
        D_fake_1, D_fake_2: critic outputs for the networks' predictions.
        D_real_1, D_real_2: critic outputs for the ground-truth masks.
        labels_S1, labels_S2: tensors whose shapes size the target maps.

    Returns:
        The sum of the four BCE terms.
    """
    fake_term_1 = CE(D_fake_1, make_Dis_label(pred_label, labels_S1))
    fake_term_2 = CE(D_fake_2, make_Dis_label(pred_label, labels_S2))
    real_term_1 = CE(D_real_1, make_Dis_label(gt_label, labels_S1))
    real_term_2 = CE(D_real_2, make_Dis_label(gt_label, labels_S2))
    return fake_term_1 + fake_term_2 + real_term_1 + real_term_2

In [4]:
def dice_coef(output, target):
    """Soft Dice coefficient between ``output`` and ``target`` (no thresholding).

    Both tensors are flattened and copied to host numpy arrays, so the result
    is a numpy scalar detached from autograd (suitable as a metric, not a loss).

    Returns:
        ``(2 * sum(output * target) + 1e-5) / (sum(output) + sum(target) + 1e-5)``
    """
    eps = 1e-5
    flat_out, flat_tgt = (t.view(-1).detach().cpu().numpy() for t in (output, target))
    overlap = (flat_out * flat_tgt).sum()
    denominator = flat_out.sum() + flat_tgt.sum() + eps
    return (2. * overlap + eps) / denominator


In [None]:
import argparse
import glob
import os

import imageio
import numpy as np
import torch
import yaml
from PIL import Image
from sklearn.metrics import f1_score, mean_absolute_error
from torch.autograd import Variable
from torchvision import transforms
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

parser = argparse.ArgumentParser()
parser.add_argument('--batchsize', type=int, default=1, help='training batch size')
parser.add_argument('--trainsize', type=int, default=256, help='training dataset size')
parser.add_argument('--dataset', type=str, default='nuclei', help='dataset name')
parser.add_argument('--threshold', type=float, default=0.5, help='threshold')
# parse_known_args instead of parse_args: inside a notebook kernel, sys.argv
# typically carries the kernel's own flags (e.g. "-f <connection file>"), which
# parse_args rejects as unrecognized arguments. Unknown flags are ignored here.
opt = parser.parse_known_args()[0]


class Test(object):
    """Evaluate the two trained U-Nets on the validation split.

    Loads the saved weights, writes input / ground-truth / prediction PNGs
    under ``logs/<dataset>/test`` and logs the pixel-level F1 score and MAE of
    each model. Reads settings from the module-level ``opt`` namespace and
    dataset paths from ``configs/config.yml``.
    """

    def __init__(self):
        self._init_configure()
        self._init_logger()
        self.model_1 = Unet()
        self.model_2 = Unet()

    def _init_configure(self):
        """Load the YAML configuration into ``self.cfg``."""
        with open('configs/config.yml') as fp:
            self.cfg = yaml.safe_load(fp)

    def _init_logger(self):
        """Create the output directories and logger, and resolve checkpoint paths."""
        log_dir = 'logs/' + opt.dataset + '/test'

        self.logger = get_logger(log_dir)
        print('RUNDIR: {}'.format(log_dir))

        self.save_path = log_dir
        self.image_save_path_1 = log_dir + "/saved_images_1"
        create_dir(self.image_save_path_1)
        self.image_save_path_2 = log_dir + "/saved_images_2"
        create_dir(self.image_save_path_2)

        train_dir = 'logs/' + opt.dataset + '/train'
        self.model_1_load_path = self._resolve_checkpoint(train_dir, 'Model_1.pth')
        self.model_2_load_path = self._resolve_checkpoint(train_dir, 'Model_2.pth')

    @staticmethod
    def _resolve_checkpoint(train_dir, filename):
        """Return the path of checkpoint ``filename`` produced under ``train_dir``.

        The training cell saves checkpoints into ``<train_dir>/Checkpoints``.
        The location this class used previously (``<train_dir>/<filename>``)
        is kept as a fallback for checkpoints copied there by hand.
        """
        checkpoint_path = os.path.join(train_dir, 'Checkpoints', filename)
        if os.path.exists(checkpoint_path):
            return checkpoint_path
        return train_dir + '/' + filename

    @staticmethod
    def _to_uint8(sample):
        """Convert one map with values in [0, 1] to a squeezed uint8 numpy array."""
        # Multiply out of place: for a CPU tensor .numpy() shares memory with the
        # tensor, so an in-place *= would rescale the caller's data.
        scaled = sample.detach().cpu().numpy().squeeze() * 255.0
        return scaled.astype(np.uint8)

    def _write_maps(self, var_map, name, save_dirs, prefix="/val_"):
        """Write each sample of batch ``var_map`` to ``<dir><prefix><name>`` for every dir.

        All samples share one file name, so with a batch size above 1 only the
        last sample of the batch remains on disk.
        """
        for kk in range(var_map.shape[0]):
            image = self._to_uint8(var_map[kk, :, :, :])
            for save_dir in save_dirs:
                imageio.imwrite(save_dir + prefix + name, image)

    def visualize_val_input(self, var_map, i):
        """Save the input image of batch ``i`` as ``val_<i>_input.png``.

        Expects a batch holding a single image.
        """
        # squeeze (not squeeze_) so the caller's tensor is not modified in place.
        im = transforms.ToPILImage()(var_map.detach().cpu().squeeze(0)).convert("RGB")
        name = '{:02d}_input.png'.format(i)
        imageio.imwrite(self.image_save_path_1 + "/val_" + name, im)

    def visualize_gt(self, var_map, i):
        """Save the ground-truth masks of batch ``i`` into both prediction directories."""
        self._write_maps(var_map, '{:02d}_gt.png'.format(i),
                         (self.image_save_path_1, self.image_save_path_2))

    def visualize_prediction1(self, var_map, i):
        """Save model 1's probability maps of batch ``i`` as ``val_<i>_pred_1.png``."""
        self._write_maps(var_map, '{:02d}_pred_1.png'.format(i), (self.image_save_path_1,))

    def visualize_prediction2(self, var_map, i):
        """Save model 2's probability maps of batch ``i`` as ``val_<i>_pred_2.png``."""
        self._write_maps(var_map, '{:02d}_pred_2.png'.format(i), (self.image_save_path_2,))

    def visualize_uncertainity(self, var_map, i):
        """Save the maps of batch ``i`` as ``uncertainity_<i>_pred.png`` (not called by ``run``)."""
        self._write_maps(var_map, '{:02d}_pred.png'.format(i), (self.image_save_path_1,),
                         prefix="/uncertainity_")

    @staticmethod
    def _load_mask(path):
        """Load a saved PNG and return its flattened pixels binarized at ``opt.threshold``."""
        with Image.open(path) as image:
            values = np.asarray(image).flatten() / 255
        return values > opt.threshold

    def _evaluate(self, image_dir, model_id):
        """Log the pixel-level F1 and MAE of model ``model_id`` from PNGs in ``image_dir``."""
        self.logger.info(image_dir)
        pred_suffix = '_pred_{}.png'.format(model_id)
        pred_files = sorted(glob.glob(os.path.join(image_dir, 'val_*' + pred_suffix)))

        preds, targets = [], []
        for pred_file in pred_files:
            # Pair each prediction with the ground truth saved under the same
            # index. Globbing the two lists independently (glob order is not
            # guaranteed) could compare pixels of different images.
            gt_file = pred_file[:-len(pred_suffix)] + '_gt.png'
            preds.append(self._load_mask(pred_file))
            targets.append(self._load_mask(gt_file))

        # Concatenate once (not per file) and keep float64 0/1 arrays as before.
        output_pred_list = np.concatenate(preds).astype(np.float64) if preds else np.array([])
        target_list = np.concatenate(targets).astype(np.float64) if targets else np.array([])

        F1_score = f1_score(target_list, output_pred_list)
        self.logger.info("Model {} F1 score : {} ".format(model_id, F1_score))
        mae = mean_absolute_error(target_list, output_pred_list)
        self.logger.info("Model {} MAE : {} ".format(model_id, mae))

    def evaluate_model_1(self, image_dir):
        """Log F1/MAE of model 1 from the PNGs saved in ``image_dir``."""
        self._evaluate(image_dir, 1)

    def evaluate_model_2(self, image_dir):
        """Log F1/MAE of model 2 from the PNGs saved in ``image_dir``."""
        self._evaluate(image_dir, 2)

    def run(self):
        """Load both models, predict on the validation split, save images and log metrics."""
        self.model_1.load_state_dict(torch.load(self.model_1_load_path, map_location='cpu'))
        self.model_1.cuda()

        self.model_2.load_state_dict(torch.load(self.model_2_load_path, map_location='cpu'))
        self.model_2.cuda()

        # Inference mode, so layers such as dropout / batch-norm (if Unet has
        # any) do not behave as in training.
        self.model_1.eval()
        self.model_2.eval()

        image_root = self.cfg[opt.dataset]['image_dir']
        gt_root = self.cfg[opt.dataset]['mask_dir']

        _, _, _, val_loader = image_loader(image_root, gt_root, opt.batchsize, opt.trainsize)

        for i, (images, gts) in enumerate(val_loader, start=1):
            with torch.no_grad():
                images = images.cuda()
                gts = gts.cuda()
                prediction1 = torch.sigmoid(self.model_1(images))
                prediction2 = torch.sigmoid(self.model_2(images))

            self.visualize_val_input(images, i)
            self.visualize_gt(gts, i)
            self.visualize_prediction1(prediction1, i)
            self.visualize_prediction2(prediction2, i)

        self.evaluate_model_1(self.cfg[opt.dataset]['predictions_dir_1'])
        self.evaluate_model_2(self.cfg[opt.dataset]['predictions_dir_2'])


if __name__ == '__main__':
    # Build the evaluator (loads configs/config.yml and creates the test log
    # directories), then predict on the validation split with both saved
    # models and log their metrics.
    Test_network = Test()
    Test_network.run()

In [7]:
import os
import logging
import sys


def create_exp_dir(path, desc='Experiment dir: {}'):
    """Create directory ``path`` (with parents) if missing, then print ``desc`` formatted with it."""
    # exist_ok avoids the race between an os.path.exists check and makedirs.
    os.makedirs(path, exist_ok=True)
    print(desc.format(path))


def create_dir(path):
    """Create directory ``path`` and any missing parents; do nothing if it exists."""
    # exist_ok avoids the race between an os.path.exists check and makedirs.
    os.makedirs(path, exist_ok=True)


def get_logger(log_dir):
    """Return the shared 'Nas Seg' logger, writing INFO messages to ``<log_dir>/run.log``.

    Also configures the root logger to print to stdout (via
    ``logging.basicConfig``) and creates ``log_dir`` if needed.
    """
    create_exp_dir(log_dir)
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=log_format, datefmt='%m/%d %I:%M:%S %p')
    logger = logging.getLogger('Nas Seg')
    # basicConfig does nothing when the root logger already has handlers (can
    # happen in notebooks). The effective level would then stay at the root's
    # level and .info() calls could be dropped, so set the level explicitly.
    logger.setLevel(logging.INFO)

    log_path = os.path.abspath(os.path.join(log_dir, 'run.log'))
    # Re-running a cell used to stack another handler on the same file,
    # duplicating every line in run.log; add one only if none exists yet.
    has_handler = any(isinstance(h, logging.FileHandler) and h.baseFilename == log_path
                      for h in logger.handlers)
    if not has_handler:
        fh = logging.FileHandler(log_path)
        fh.setFormatter(logging.Formatter(log_format))
        logger.addHandler(fh)
    return logger


In [None]:
import argparse
import os
from datetime import datetime
from distutils.dir_util import copy_tree

import numpy as np
import torch
import torch.nn.functional as F
import yaml
!pip install tensorboardX
from tensorboardX import SummaryWriter
from torch.autograd import Variable
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

parser = argparse.ArgumentParser()
parser.add_argument('--epoch', type=int, default=200, help='epoch number')
parser.add_argument('--lr_gen', type=float, default=1e-2, help='learning rate')
parser.add_argument('--lr_dis', type=float, default=5e-5, help='learning rate')
parser.add_argument('--batchsize', type=int, default=1, help='training batch size')
parser.add_argument('--trainsize', type=int, default=256, help='training dataset size')
parser.add_argument('--dataset', type=str, default='nuclei', help='dataset name')
parser.add_argument('--split', type=float, default=0.8, help='training data ratio')
parser.add_argument('--momentum', default=0.9, type=float)
# NOTE(review): opt.decay is parsed but not passed to any optimizer in this notebook.
parser.add_argument('--decay', default=3e-5, type=float)
parser.add_argument('--ratio', type=float, default=0.05, help='labeled data ratio')

# parse_known_args instead of parse_args: inside a notebook kernel, sys.argv
# typically carries the kernel's own flags (e.g. "-f <connection file>"), which
# parse_args rejects as unrecognized arguments. Unknown flags are ignored here.
opt = parser.parse_known_args()[0]
CE = torch.nn.BCELoss()


class Network(object):
    """Semi-supervised trainer: two U-Nets plus a pixel-level critic.

    Each iteration draws one batch from each labeled loader and one from the
    unlabeled loader. It first updates both U-Nets with
    ``Loss_sup + 0.4 * Loss_diff + 0.2 * Loss_adv``, then updates the critic
    with ``loss_adversarial_2``. After every epoch both U-Nets are scored on
    the validation loader with ``dice_coef``, and the best-so-far weights are
    saved under ``<save_path>/Checkpoints``. Reads the module-level ``opt``.
    """

    def __init__(self):
        self.patience = 0
        # Best validation Dice seen so far for each model.
        self.best_dice_coeff_1 = 0.0
        self.best_dice_coeff_2 = 0.0
        self.model_1 = Unet()
        self.model_2 = Unet()
        self.critic = PixelDiscriminator()
        self.best_mIoU, self.best_dice_coeff = 0, 0
        self._init_configure()
        self._init_logger()

    def _init_configure(self):
        """Load the YAML configuration into ``self.cfg``."""
        with open('configs/config.yml') as fp:
            self.cfg = yaml.safe_load(fp)

    def _init_logger(self):
        """Create the logger, output directories and TensorBoard writer."""
        log_dir = 'logs/' + opt.dataset + '/train/'

        self.logger = get_logger(log_dir)
        print('RUNDIR: {}'.format(log_dir))

        self.save_path = log_dir
        self.image_save_path_1 = log_dir + "/saved_images_1"
        self.image_save_path_2 = log_dir + "/saved_images_2"

        create_dir(self.image_save_path_1)
        create_dir(self.image_save_path_2)

        self.save_tbx_log = self.save_path + '/tbx_log'
        self.writer = SummaryWriter(self.save_tbx_log)

    def _critic_map(self, seg_map):
        """Run the critic on ``seg_map`` and upsample its sigmoid output to the map's H x W."""
        return F.interpolate(torch.sigmoid(self.critic(seg_map)),
                             (seg_map.shape[2], seg_map.shape[3]),
                             mode='bilinear', align_corners=False)

    def _train_epoch(self, train_loader_1, train_loader_2, unlabeled_train_loader,
                     optimizer, dis_optimizer):
        """Run one pass over the zipped loaders; return (summed seg loss, iteration count)."""
        self.model_1.train()
        self.model_2.train()
        running_loss = 0.0
        n_iters = 0

        for (inputs_S1, labels_S1), (inputs_S2, labels_S2), (inputs_U, _) in zip(
                train_loader_1, train_loader_2, unlabeled_train_loader):
            inputs_S1, labels_S1 = inputs_S1.cuda(), labels_S1.cuda()
            inputs_S2, labels_S2 = inputs_S2.cuda(), labels_S2.cuda()
            inputs_U = inputs_U.cuda()

            # ---- segmentation networks step ----
            optimizer.zero_grad()
            prediction_1 = torch.sigmoid(self.model_1(inputs_S1))
            u_prediction_1 = torch.sigmoid(self.model_1(inputs_U))
            prediction_2 = torch.sigmoid(self.model_2(inputs_S2))
            u_prediction_2 = torch.sigmoid(self.model_2(inputs_U))

            Loss_sup = loss_sup(prediction_1, prediction_2, labels_S1, labels_S2)
            Loss_diff = loss_diff(u_prediction_1, u_prediction_2, opt.batchsize)

            # Labeled predictions are detached, so the labeled adversarial term
            # below does not back-propagate into the U-Nets.
            prediction_1 = prediction_1.detach()
            prediction_2 = prediction_2.detach()

            D_fake_1 = self._critic_map(prediction_1)
            D_fake_2 = self._critic_map(prediction_2)
            D_fake3 = self._critic_map(u_prediction_1)
            D_fake4 = self._critic_map(u_prediction_2)

            Loss_adv_labeled = loss_adversarial_1(D_fake_1, D_fake_2, labels_S1, labels_S2)
            # make_Dis_label only uses the .shape of its second argument, so pass
            # the critic maps directly (the old np.bool dummy masks break on
            # NumPy versions where that alias was removed).
            Loss_adv_unlabeled = (CE(D_fake3, make_Dis_label(gt_label, D_fake3))
                                  + CE(D_fake4, make_Dis_label(gt_label, D_fake4)))
            Loss_adv1 = Loss_adv_labeled + Loss_adv_unlabeled
            seg_loss = Loss_sup + 0.4 * Loss_diff + 0.2 * Loss_adv1

            seg_loss.backward()
            running_loss += seg_loss.item()
            optimizer.step()
            n_iters += 1

            # ---- critic step (also clears critic grads left by seg_loss.backward) ----
            dis_optimizer.zero_grad()
            D_fake_1 = self._critic_map(prediction_1)
            D_fake_2 = self._critic_map(prediction_2)
            D_real_1 = self._critic_map(labels_S1)
            D_real_2 = self._critic_map(labels_S2)

            Loss_adv2 = loss_adversarial_2(D_fake_1, D_real_1, D_fake_2, D_real_2, labels_S1, labels_S2)
            Loss_adv2.backward()
            dis_optimizer.step()

        return running_loss, n_iters

    def _validate(self, val_loader):
        """Return the mean per-batch ``dice_coef`` of model 1 and model 2 on ``val_loader``."""
        self.model_1.eval()
        self.model_2.eval()
        running_dice_val_1 = 0.0
        running_dice_val_2 = 0.0
        with torch.no_grad():
            for images, gts in val_loader:
                images = images.cuda()
                gts = gts.cuda()
                prediction_1 = torch.sigmoid(self.model_1(images))
                prediction_2 = torch.sigmoid(self.model_2(images))
                running_dice_val_1 += dice_coef(prediction_1, gts)
                running_dice_val_2 += dice_coef(prediction_2, gts)
        n_batches = max(len(val_loader), 1)
        return running_dice_val_1 / n_batches, running_dice_val_2 / n_batches

    def _update_best(self, mdice_coeff_1, mdice_coeff_2):
        """Track the best Dice per model, save improved weights and update patience."""
        Checkpoints_Path = os.path.join(self.save_path, 'Checkpoints')
        os.makedirs(Checkpoints_Path, exist_ok=True)

        self.save_best_model_1 = self.best_dice_coeff_1 < mdice_coeff_1
        if self.save_best_model_1:
            self.best_dice_coeff_1 = mdice_coeff_1
            os.makedirs(self.image_save_path_1, exist_ok=True)
            copy_tree(self.image_save_path_1, self.save_path + '/best_model_predictions_1')
            torch.save(self.model_1.state_dict(), os.path.join(Checkpoints_Path, 'Model_1.pth'))
            torch.save(self.critic.state_dict(), os.path.join(Checkpoints_Path, 'Critic.pth'))

        self.save_best_model_2 = self.best_dice_coeff_2 < mdice_coeff_2
        if self.save_best_model_2:
            self.best_dice_coeff_2 = mdice_coeff_2
            os.makedirs(self.image_save_path_2, exist_ok=True)
            copy_tree(self.image_save_path_2, self.save_path + '/best_model_predictions_2')
            torch.save(self.model_2.state_dict(), os.path.join(Checkpoints_Path, 'Model_2.pth'))
            torch.save(self.critic.state_dict(), os.path.join(Checkpoints_Path, 'Critic.pth'))

        # Patience = consecutive epochs in which neither model improved. It was
        # previously bumped once per non-improving model, and the result
        # depended on which model was checked last.
        if self.save_best_model_1 or self.save_best_model_2:
            self.patience = 0
        else:
            self.patience += 1

    def run(self):
        """Train both networks and the critic for ``opt.epoch`` epochs."""
        print('Learning Rate: gen {} dis {}'.format(opt.lr_gen, opt.lr_dis))

        self.model_1.cuda()
        self.model_2.cuda()
        # Move the critic once, before its optimizer is built (it was previously
        # moved again on every training iteration).
        self.critic.cuda()

        params = list(self.model_1.parameters()) + list(self.model_2.parameters())
        optimizer = torch.optim.SGD(params, lr=opt.lr_gen, momentum=opt.momentum)
        dis_optimizer = torch.optim.RMSprop(self.critic.parameters(), opt.lr_dis)

        image_root = self.cfg[opt.dataset]['image_dir']
        gt_root = self.cfg[opt.dataset]['mask_dir']

        self.logger.info("Split Percentage : {} Labeled Data Ratio : {}".format(opt.split, opt.ratio))
        train_loader_1, train_loader_2, unlabeled_train_loader, val_loader = image_loader(
            image_root, gt_root, opt.batchsize, opt.trainsize, opt.split, opt.ratio)
        self.logger.info(
            "train_loader_1 {} train_loader_2 {} unlabeled_train_loader {} val_loader {}".format(
                len(train_loader_1), len(train_loader_2), len(unlabeled_train_loader), len(val_loader)))
        print("Let's go!")

        # range(1, opt.epoch) ran only opt.epoch - 1 epochs.
        for epoch in range(1, opt.epoch + 1):
            running_loss, n_iters = self._train_epoch(
                train_loader_1, train_loader_2, unlabeled_train_loader, optimizer, dis_optimizer)
            # zip() stops at the shortest loader and the loss is accumulated once
            # per iteration, so average over the iterations actually run.
            epoch_loss = running_loss / max(n_iters, 1)
            self.logger.info('{} Epoch [{:03d}/{:03d}], total_loss : {:.4f}'.
                             format(datetime.now(), epoch, opt.epoch, epoch_loss))

            self.logger.info('Train loss: {}'.format(epoch_loss))
            self.writer.add_scalar('Train/Loss', epoch_loss, epoch)

            epoch_dice_val_1, epoch_dice_val_2 = self._validate(val_loader)

            self.logger.info('Validation dice coeff model 1: {}'.format(epoch_dice_val_1))
            self.writer.add_scalar('Validation_1/DSC', epoch_dice_val_1, epoch)
            self.logger.info('Validation dice coeff model 2: {}'.format(epoch_dice_val_2))
            self.writer.add_scalar('Validation_2/DSC', epoch_dice_val_2, epoch)

            self._update_best(epoch_dice_val_1, epoch_dice_val_2)

            self.logger.info(
                'current best dice coef model 1 {}, model 2 {}'.format(self.best_dice_coeff_1, self.best_dice_coeff_2))
            self.logger.info('current patience :{}'.format(self.patience))


if __name__ == '__main__':
    # Build the trainer (loads configs/config.yml, creates the train log
    # directories and TensorBoard writer), then run the training loop.
    train_network = Network()
    train_network.run()

In [None]:
# Min-max normalize `gray_image` to [0, 1] for display.
# NOTE(review): `gray_image` and `plt` are not defined in any visible cell;
# they must come from elsewhere in the session (plt presumably being
# matplotlib.pyplot).
value_range = np.max(gray_image) - np.min(gray_image)
if value_range:
    norm_image = (gray_image - np.min(gray_image)) / value_range
else:
    # Constant image: the plain formula divides 0 by 0 and yields all NaNs.
    norm_image = np.zeros_like(gray_image, dtype=float)
plt.imshow(norm_image)