<a href="https://colab.research.google.com/github/thotakuria/surgical-tool-segmentation-using-deep-learning-techniques/blob/main/minmax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import numpy as np
import torchvision.transforms as transforms
# `datasets` was referenced below without ever being imported, which raised a
# NameError as soon as the cell ran; import it locally so the cell is
# self-contained.
from torchvision import datasets

# Pre-processing applied to every image: centre-crop to 224x224, then convert
# the PIL image to a tensor.
data_transforms = transforms.Compose([
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

# ImageFolder derives the class label of each image from its sub-directory.
image_datasets = datasets.ImageFolder(root="/content/gdrive/MyDrive/dataset", transform=data_transforms)

# Shuffled mini-batches of 32 images, loaded by two worker processes.
dataloaders = torch.utils.data.DataLoader(image_datasets, batch_size=32, shuffle=True, num_workers=2)

In [1]:
import os
import logging
import sys


def create_exp_dir(path, desc='Experiment dir: {}'):
    """Create the experiment directory ``path`` if needed and announce it.

    Args:
        path: Directory to create; missing parent directories are created too.
        desc: Format string with one ``{}`` placeholder that receives ``path``;
            the formatted text is printed to stdout.

    Raises:
        FileExistsError: If ``path`` exists but is not a directory.
    """
    # exist_ok=True removes the window between "does it exist?" and "create
    # it" in which another process could create the directory first.
    os.makedirs(path, exist_ok=True)
    print(desc.format(path))


def create_dir(path):
    """Create directory ``path`` (and any missing parents) if it does not exist.

    Args:
        path: Directory to create.

    Raises:
        FileExistsError: If ``path`` exists but is not a directory.
    """
    # A single call with exist_ok=True avoids the check-then-create race.
    os.makedirs(path, exist_ok=True)


def get_logger(log_dir):
    """Return the shared 'Nas Seg' logger, writing to stdout and ``run.log``.

    The root logger is configured to emit INFO records to stdout, and a file
    handler for ``<log_dir>/run.log`` is attached to the 'Nas Seg' logger.

    Args:
        log_dir: Directory that receives ``run.log``; created if missing.

    Returns:
        The ``logging.Logger`` named 'Nas Seg'.
    """
    create_exp_dir(log_dir)
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=log_format, datefmt='%m/%d %I:%M:%S %p')
    logger = logging.getLogger('Nas Seg')

    # logging.getLogger returns the same object on every call, so re-running
    # this function (e.g. re-executing a notebook cell) used to stack another
    # FileHandler and duplicate every line in run.log. Attach one handler per
    # distinct log file only.
    log_path = os.path.abspath(os.path.join(log_dir, 'run.log'))
    already_attached = any(
        isinstance(handler, logging.FileHandler) and handler.baseFilename == log_path
        for handler in logger.handlers
    )
    if not already_attached:
        fh = logging.FileHandler(log_path)
        fh.setFormatter(logging.Formatter(log_format))
        logger.addHandler(fh)
    return logger


In [2]:
import functools

import torch
import torch.nn as nn

class CONV_Block(nn.Module):
    """Two consecutive 3x3 convolution -> batch-norm -> LeakyReLU stages.

    Both convolutions use padding 1 with the default stride, so the spatial
    size of the input is preserved; only the channel count changes
    (in_channels -> middle_channels -> out_channels).
    """

    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        # A single activation module is shared by both stages.
        self.relu = nn.LeakyReLU()
        self.conv1 = nn.Conv2d(in_channels, middle_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Apply both stages to ``x`` and return the resulting feature map."""
        hidden = self.relu(self.bn1(self.conv1(x)))
        return self.relu(self.bn2(self.conv2(hidden)))

class conv(nn.Module):
    """Single 3x3 convolution followed by batch-norm and an in-place ReLU.

    Padding 1 with the default stride keeps the spatial size unchanged.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Return ``relu(bn(conv(x)))``."""
        return self.relu(self.bn(self.conv(x)))


class projectors(nn.Module):
    """Small convolutional projection head.

    Pipeline: conv(input_nc -> ndf) -> 2x2 max-pool -> conv(ndf -> 2*ndf)
    -> 2x2 max-pool -> 1x1 convolution (2*ndf -> 2*ndf). The two pooling
    steps reduce each spatial dimension by a factor of four.

    Args:
        input_nc: Number of channels of the input feature map.
        ndf: Base channel width of the head.
        norm_layer: Accepted for signature compatibility only. The original
            code derived an unused ``use_bias`` flag from it; that dead
            computation has been removed and the argument has no effect.
    """

    def __init__(self, input_nc=1, ndf=8, norm_layer=nn.BatchNorm2d):
        super().__init__()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv_1 = conv(input_nc, ndf)
        self.conv_2 = conv(ndf, ndf * 2)
        self.final = nn.Conv2d(ndf * 2, ndf * 2, kernel_size=1)

    def forward(self, input):
        """Project ``input`` and return the (N, 2*ndf, H/4, W/4) feature map."""
        x_0 = self.pool(self.conv_1(input))
        x_out = self.pool(self.conv_2(x_0))
        return self.final(x_out)
    
class classifier(nn.Module):
    """Three-stage convolutional head applied to segmentation outputs.

    Pipeline: three (conv -> 2x2 max-pool) stages widening the channels
    inp_dim -> ndf -> 2*ndf -> 4*ndf, followed by a 1x1 convolution that keeps
    4*ndf channels. The three pooling steps reduce each spatial dimension by a
    factor of eight.

    Args:
        inp_dim: Number of channels of the input map.
        ndf: Base channel width of the head.
        norm_layer: Accepted for signature compatibility only. The original
            code derived an unused ``use_bias`` flag from it; that dead
            computation has been removed and the argument has no effect.
    """

    def __init__(self, inp_dim=1, ndf=8, norm_layer=nn.BatchNorm2d):
        super().__init__()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv_1 = conv(inp_dim, ndf)
        self.conv_2 = conv(ndf, ndf * 2)
        self.conv_3 = conv(ndf * 2, ndf * 4)
        self.final = nn.Conv2d(ndf * 4, ndf * 4, kernel_size=1)

    def forward(self, input):
        """Return the (N, 4*ndf, H/8, W/8) feature map computed from ``input``."""
        x_0 = self.pool(self.conv_1(input))
        x_1 = self.pool(self.conv_2(x_0))
        x_2 = self.pool(self.conv_3(x_1))
        return self.final(x_2)
   
      
            

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import torchvision.models as models
import os

class CONV_Block(nn.Module):
    """Double-convolution block used by the decoder.

    Applies (3x3 conv -> batch-norm -> LeakyReLU) twice, mapping
    in_channels -> middle_channels -> out_channels. Padding 1 with the default
    stride leaves the spatial size unchanged.
    """

    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        self.relu = nn.LeakyReLU()
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Run ``x`` through both conv/norm/activation stages in order."""
        out = x
        for conv_layer, norm_layer in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = self.relu(norm_layer(conv_layer(out)))
        return out


class preUnet(nn.Module):
    """U-Net style segmentation network on a pretrained Res2Net encoder.

    The encoder is ``res2net101_v1b_26w_4s(pretrained=True)``, defined outside
    this cell. Its stem and first three stages are used; the decoder upsamples
    by 2 three times, concatenating the matching encoder feature map after
    each step, then upsamples once more before the 1x1 output convolution.

    Args:
        num_classes: Number of output channels of the final 1x1 convolution.
        input_channels: Not used by this implementation.
        **kwargs: Ignored.
    """

    def __init__(self, num_classes=1, input_channels=3, **kwargs):
        super().__init__()
        self.resnet = res2net101_v1b_26w_4s(pretrained=True)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Decoder blocks come in pairs: the first reduces channels after
        # upsampling, the second fuses the concatenated skip connection.
        # NOTE(review): the channel counts assume the encoder emits 1024 / 512 /
        # 256 / 64 channels at layer3 / layer2 / layer1 / stem - confirm
        # against the Res2Net definition.
        self.conv_up_1 = CONV_Block(1024, 1024, 512)
        self.conv_up_2 = CONV_Block(1024, 512, 512)
        self.conv_up_3 = CONV_Block(512, 512, 256)
        self.conv_up_4 = CONV_Block(512, 256, 256)
        self.conv_up_5 = CONV_Block(256, 256, 64)
        self.conv_up_6 = CONV_Block(128, 64, 64)
        self.final = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        """Return the raw (pre-sigmoid) segmentation map for image batch ``x``."""
        # Encoder: stem, then three residual stages.
        stem = self.resnet.conv1(x)
        stem = self.resnet.bn1(stem)
        stem = self.resnet.relu(stem)
        enc1 = self.resnet.layer1(self.resnet.maxpool(stem))
        enc2 = self.resnet.layer2(enc1)
        enc3 = self.resnet.layer3(enc2)

        # Decoder: upsample, reduce, fuse with the skip connection.
        dec = self.conv_up_1(self.up(enc3))
        dec = self.conv_up_2(torch.cat([enc2, dec], 1))

        dec = self.conv_up_3(self.up(dec))
        dec = self.conv_up_4(torch.cat([enc1, dec], 1))

        dec = self.conv_up_5(self.up(dec))
        dec = self.conv_up_6(torch.cat([stem, dec], 1))

        return self.final(self.up(dec))

In [None]:
import argparse
import os
from datetime import datetime
from distutils.dir_util import copy_tree
import torch
import yaml

from torch.autograd import Variable

import numpy as np
# Restrict the process to the first GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

# Command-line options of the training run.
parser = argparse.ArgumentParser()
parser.add_argument('--epoch', type=int, default=100, help='epoch number')
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
parser.add_argument('--batchsize', type=int, default=1, help='training batch size')
# NOTE(review): the default is a tuple, but a value given on the command line
# is parsed as a single int - confirm which form the data loader expects.
parser.add_argument('--trainsize', type=int, default=(512,288), help='training dataset size')
parser.add_argument('--dataset', type=str, default='kvasir', help='dataset name')
parser.add_argument('--split', type=float, default=1, help='training data ratio')
parser.add_argument('--momentum', default=0.9, type=float)
parser.add_argument('--ratio', type=float, default=0.5, help='labeled data ratio')
# Jupyter/Colab kernels put their own arguments (e.g. "-f <connection file>")
# into sys.argv; parse_args() aborts on them with SystemExit. parse_known_args()
# ignores arguments this parser does not define, so the cell runs both as a
# notebook cell and as a script.
opt, _unknown_args = parser.parse_known_args()

# Loss criteria used by Network.run; both classes are defined outside this cell.
pixel_wise_contrastive_loss_criter = ConLoss()
contrastive_loss_sup_criter = contrastive_loss_sup()


def adjust_lr(optimizer, init_lr, epoch, max_epoch):
    """Apply polynomial ("poly", power 0.9) learning-rate decay in place.

    Sets ``lr = init_lr * (1 - epoch / max_epoch) ** 0.9`` on every parameter
    group of ``optimizer``.

    Args:
        optimizer: Object exposing ``param_groups``, a list of dicts.
        init_lr: Learning rate at epoch 0.
        epoch: Current epoch.
        max_epoch: Epoch at which the learning rate reaches zero.

    Raises:
        ZeroDivisionError: If ``max_epoch`` is zero.
    """
    # Clamp progress at 1: for epoch > max_epoch the base would be negative,
    # and a negative float raised to 0.9 is a complex number in Python 3,
    # which would poison the optimizer's learning rate.
    progress = min(epoch / max_epoch, 1.0)
    lr_ = init_lr * (1.0 - progress) ** 0.9
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr_


class Network(object):
    """Trainer that co-trains two segmentation networks with contrastive heads.

    Two ``preUnet`` models are trained jointly on two labeled loaders and one
    unlabeled loader. The loss is the equally weighted sum of a supervised
    term, a cross-model agreement term on unlabeled data, and two contrastive
    terms computed on projector / classifier features. After every epoch both
    models are validated with the Dice coefficient, and the best checkpoint of
    each model is saved.

    Fixes relative to the original cell:
      * an over-indented line inside the batch loop made the cell fail with an
        IndentationError;
      * model 2 was never run, leaving ``prediction_2``, ``prediction_2_2``,
        ``feat_2`` and ``u_prediction_2`` undefined;
      * the learning-rate banner had no ``{}`` placeholder;
      * model 2 validation results were logged under model 1's label and
        TensorBoard tag;
      * the projector / classifier ``.cuda()`` moves are done once instead of
        on every batch.

    NOTE(review): only the parameters of the two segmentation models are given
    to the optimizer, so the projector and classifier heads keep their initial
    weights - confirm this is intended.
    NOTE(review): ``train()`` / ``eval()`` are never toggled, so validation
    runs with batch-norm in training mode - confirm this is intended.
    """

    def __init__(self):
        self.patience = 0
        # Best validation Dice seen so far for each model.
        self.best_dice_coeff_1 = 0.0
        self.best_dice_coeff_2 = 0.0
        self.model_1 = preUnet()
        self.model_2 = preUnet()
        self.projector_1 = projectors()
        self.projector_2 = projectors()
        self.classifier_1 = classifier()
        self.classifier_2 = classifier()
        self.best_mIoU, self.best_dice_coeff = 0, 0
        self._init_configure()
        self._init_logger()

    def _init_configure(self):
        """Load ``configs/config.yml`` into ``self.cfg``."""
        with open('configs/config.yml') as fp:
            self.cfg = yaml.safe_load(fp)

    def _init_logger(self):
        """Create the run directories, the logger and the TensorBoard writer."""
        log_dir = 'logs/' + opt.dataset + '/train/'

        self.logger = get_logger(log_dir)
        print('RUNDIR: {}'.format(log_dir))

        self.save_path = log_dir
        self.image_save_path_1 = log_dir + "/saved_images_1"
        self.image_save_path_2 = log_dir + "/saved_images_2"

        create_dir(self.image_save_path_1)
        create_dir(self.image_save_path_2)

        self.save_tbx_log = self.save_path + '/tbx_log'
        self.writer = SummaryWriter(self.save_tbx_log)

    def run(self):
        """Train both models for epochs 1 .. opt.epoch - 1, validating each epoch."""
        print('Learning Rate: {}'.format(opt.lr))
        self.model_1.cuda()
        self.model_2.cuda()
        # Moving a module to the GPU is loop-invariant; do it once up front.
        self.projector_1.cuda()
        self.projector_2.cuda()
        self.classifier_1.cuda()
        self.classifier_2.cuda()

        params = list(self.model_1.parameters()) + list(self.model_2.parameters())
        optimizer = torch.optim.Adam(params, lr=opt.lr)

        image_root = './data/' + opt.dataset + '/train/image/'
        gt_root = './data/' + opt.dataset + '/train/mask/'
        val_img_root = './data/' + opt.dataset + '/test/image/'
        val_gt_root = './data/' + opt.dataset + '/test/mask/'

        self.logger.info("Split Percentage : {} Labeled Data Ratio : {}".format(opt.split, opt.ratio))
        train_loader_1, train_loader_2, unlabeled_train_loader, val_loader = image_loader(
            image_root, gt_root, val_img_root, val_gt_root,
            opt.batchsize, opt.trainsize, opt.split, opt.ratio)
        self.logger.info(
            "train_loader_1 {} train_loader_2 {} unlabeled_train_loader {} val_loader {}".format(
                len(train_loader_1), len(train_loader_2), len(unlabeled_train_loader), len(val_loader)))
        print("Let's go!")

        # NOTE(review): range(1, opt.epoch) runs opt.epoch - 1 epochs; kept as-is.
        for epoch in range(1, opt.epoch):

            running_loss = 0.0
            running_dice_val_1 = 0.0
            running_dice_val_2 = 0.0

            # zip stops at the shortest of the three loaders.
            for batch_s1, batch_s2, batch_u in zip(train_loader_1, train_loader_2, unlabeled_train_loader):
                inputs_S1, labels_S1 = batch_s1[0].cuda(), batch_s1[1].cuda()
                inputs_S2, labels_S2 = batch_s2[0].cuda(), batch_s2[1].cuda()
                # Labels of the unlabeled batch are deliberately ignored.
                inputs_U = batch_u[0].cuda()

                optimizer.zero_grad()

                # Supervised branch: each model sees its own labeled batch.
                prediction_1 = self.model_1(inputs_S1)
                prediction_1_1 = torch.sigmoid(prediction_1)
                prediction_2 = self.model_2(inputs_S2)
                prediction_2_2 = torch.sigmoid(prediction_2)

                # Unsupervised branch: both models see the same unlabeled batch.
                feat_1 = self.model_1(inputs_U)
                u_prediction_1 = torch.sigmoid(feat_1)
                feat_2 = self.model_2(inputs_U)
                u_prediction_2 = torch.sigmoid(feat_2)

                feat_q = self.projector_1(feat_1)
                feat_k = self.projector_2(feat_2)
                feat_l_q = self.classifier_1(prediction_1)
                feat_l_k = self.classifier_2(prediction_2)

                Loss_sup = loss_sup(prediction_1_1, prediction_2_2, labels_S1, labels_S2)
                Loss_diff = loss_diff(u_prediction_1, u_prediction_2, opt.batchsize)
                Loss_contrast = pixel_wise_contrastive_loss_criter(feat_q, feat_k)
                Loss_contrast_2 = contrastive_loss_sup_criter(feat_l_q, feat_l_k)

                seg_loss = 0.25 * Loss_sup + 0.25 * Loss_diff + 0.25 * Loss_contrast + 0.25 * Loss_contrast_2

                seg_loss.backward()
                running_loss += seg_loss.item()
                optimizer.step()

                adjust_lr(optimizer, opt.lr, epoch, opt.epoch)

            # NOTE(review): normalised by the combined length of both labeled
            # loaders, not by the number of optimisation steps; kept as-is.
            epoch_loss = running_loss / (len(train_loader_1) + len(train_loader_2))
            self.logger.info('{} Epoch [{:03d}/{:03d}], total_loss : {:.4f}'.
                             format(datetime.now(), epoch, opt.epoch, epoch_loss))

            self.logger.info('Train loss: {}'.format(epoch_loss))
            self.writer.add_scalar('Train/Loss', epoch_loss, epoch)

            for pack in val_loader:
                with torch.no_grad():
                    images, gts = pack
                    images = images.cuda()
                    gts = gts.cuda()

                    prediction_1 = torch.sigmoid(self.model_1(images))
                    prediction_2 = torch.sigmoid(self.model_2(images))

                running_dice_val_1 += dice_coef(prediction_1, gts)
                running_dice_val_2 += dice_coef(prediction_2, gts)

            epoch_dice_val_1 = running_dice_val_1 / len(val_loader)
            self.logger.info('Validation dice coeff model 1: {}'.format(epoch_dice_val_1))
            self.writer.add_scalar('Validation_1/DSC', epoch_dice_val_1, epoch)

            # Model 2 now gets its own label and tag; it used to overwrite the
            # model 1 series.
            epoch_dice_val_2 = running_dice_val_2 / len(val_loader)
            self.logger.info('Validation dice coeff model 2: {}'.format(epoch_dice_val_2))
            self.writer.add_scalar('Validation_2/DSC', epoch_dice_val_2, epoch)

            mdice_coeff_1 = epoch_dice_val_1
            mdice_coeff_2 = epoch_dice_val_2

            if self.best_dice_coeff_1 < mdice_coeff_1:
                self.best_dice_coeff_1 = mdice_coeff_1
                self.save_best_model_1 = True

                os.makedirs(self.image_save_path_1, exist_ok=True)
                copy_tree(self.image_save_path_1, self.save_path + '/best_model_predictions_1')
                self.patience = 0
            else:
                self.save_best_model_1 = False
                self.patience += 1

            if self.best_dice_coeff_2 < mdice_coeff_2:
                self.best_dice_coeff_2 = mdice_coeff_2
                self.save_best_model_2 = True

                os.makedirs(self.image_save_path_2, exist_ok=True)
                copy_tree(self.image_save_path_2, self.save_path + '/best_model_predictions_2')
                self.patience = 0
            else:
                self.save_best_model_2 = False
                self.patience += 1

            Checkpoints_Path = self.save_path + '/Checkpoints'
            os.makedirs(Checkpoints_Path, exist_ok=True)

            if self.save_best_model_1:
                torch.save(self.model_1.state_dict(), Checkpoints_Path + '/Model_1.pth')
            if self.save_best_model_2:
                torch.save(self.model_2.state_dict(), Checkpoints_Path + '/Model_2.pth')

            self.logger.info(
                'current best dice coef model 1 {}, model 2 {}'.format(self.best_dice_coeff_1, self.best_dice_coeff_2))
            self.logger.info('current patience :{}'.format(self.patience))


# Entry point: build the trainer (models, logger, config) and start training.
# In a notebook __name__ is '__main__', so this runs when the cell is executed.
if __name__ == '__main__':
    train_network = Network()
    train_network.run()

In [None]:
import argparse
import glob
import os

import imageio
import numpy as np
import torch
import yaml
from PIL import Image
from sklearn.metrics import f1_score, mean_absolute_error
from torch.autograd import Variable
from torchvision import transforms
from utils import get_logger, create_dir
from model.pretrained_unet import preUnet
# Restrict the process to the first GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

# Command-line options of the evaluation run.
parser = argparse.ArgumentParser()
parser.add_argument('--batchsize', type=int, default=1, help='training batch size')
# NOTE(review): the default is a tuple, but a value given on the command line
# is parsed as a single int - confirm which form the data loader expects.
parser.add_argument('--trainsize', type=int, default=(512,288), help='training dataset size')
parser.add_argument('--dataset', type=str, default='kvasir', help='dataset name')
parser.add_argument('--threshold', type=float, default=0.5, help='threshold')
# Jupyter/Colab kernels put their own arguments (e.g. "-f <connection file>")
# into sys.argv; parse_args() aborts on them with SystemExit. parse_known_args()
# ignores arguments this parser does not define, so the cell runs both as a
# notebook cell and as a script.
opt, _unknown_args = parser.parse_known_args()


class Test(object):
    """Evaluate the two trained segmentation models on the test split.

    ``run`` loads both checkpoints, writes inputs, ground-truth masks and
    predictions as PNG files, and then reports the F1 score and mean absolute
    error of the thresholded predictions for each model.

    Fixes relative to the original cell:
      * ``model_2_load_path`` was never defined, so ``run`` raised
        AttributeError; ``model_1_load_path`` pointed at a dataset directory
        instead of a checkpoint. Both now point at the files the trainer
        writes (``logs/<dataset>/train/Checkpoints/Model_N.pth``);
      * prediction and ground-truth files are sorted before being paired
        (``glob`` returns files in arbitrary order);
      * the saver methods no longer modify the tensors they are given;
      * evaluation uses the run's own output directories instead of paths
        hard-coded to the 'kvasir' dataset;
      * the duplicated saver / evaluator bodies share private helpers.

    NOTE(review): the models are never switched to ``eval()`` mode, so
    batch-norm runs in training mode during inference - confirm this is
    intended.
    """

    def __init__(self):
        self._init_configure()
        self._init_logger()
        self.model_1 = preUnet()
        self.model_2 = preUnet()

    def _init_configure(self):
        """Load ``configs/config.yml`` into ``self.cfg``."""
        with open('configs/config.yml') as fp:
            self.cfg = yaml.safe_load(fp)

    def _init_logger(self):
        """Create the output directories and the logger; set checkpoint paths."""
        log_dir = 'logs/' + opt.dataset + '/test'

        self.logger = get_logger(log_dir)
        print('RUNDIR: {}'.format(log_dir))

        self.save_path = log_dir
        self.image_save_path_1 = log_dir + "/saved_images_1"
        create_dir(self.image_save_path_1)
        self.image_save_path_2 = log_dir + "/saved_images_2"
        create_dir(self.image_save_path_2)

        # Checkpoints are read from where the training run saves them.
        checkpoint_dir = 'logs/' + opt.dataset + '/train/Checkpoints'
        self.model_1_load_path = checkpoint_dir + '/Model_1.pth'
        self.model_2_load_path = checkpoint_dir + '/Model_2.pth'

    def _save_maps(self, var_map, file_name, path_prefixes):
        """Write each map of batch ``var_map`` as an 8-bit image.

        ``file_name`` is appended to every entry of ``path_prefixes``.
        NOTE(review): the name does not include the position inside the batch,
        so with a batch size above 1 later maps overwrite earlier ones
        (unchanged from the original).
        """
        for kk in range(var_map.shape[0]):
            pred_edge_kk = var_map[kk, :, :, :].detach().cpu().numpy().squeeze()
            # Scale out of place: for CPU tensors the numpy array shares memory
            # with the tensor, and an in-place multiply would corrupt it.
            pred_edge_kk = (pred_edge_kk * 255.0).astype(np.uint8)
            for prefix in path_prefixes:
                imageio.imwrite(prefix + file_name, pred_edge_kk)

    def visualize_val_input(self, var_map, i):
        """Save input image ``var_map`` (batch of one) as ``val_<i>_input.png``."""
        # squeeze (not squeeze_) so the caller's tensor keeps its batch axis.
        im = transforms.ToPILImage()(var_map.squeeze(0).detach().cpu()).convert("RGB")
        name = '{:02d}_input.png'.format(i)
        imageio.imwrite(self.image_save_path_1 + "/val_" + name, im)

    def visualize_gt(self, var_map, i):
        """Save the ground-truth mask into both model output directories."""
        self._save_maps(var_map, '{:02d}_gt.png'.format(i),
                        (self.image_save_path_1 + "/val_", self.image_save_path_2 + "/val_"))

    def visualize_prediction1(self, var_map, i):
        """Save model 1's prediction as ``val_<i>_pred_1.png``."""
        self._save_maps(var_map, '{:02d}_pred_1.png'.format(i), (self.image_save_path_1 + "/val_",))

    def visualize_prediction2(self, var_map, i):
        """Save model 2's prediction as ``val_<i>_pred_2.png``."""
        self._save_maps(var_map, '{:02d}_pred_2.png'.format(i), (self.image_save_path_2 + "/val_",))

    def visualize_uncertainity(self, var_map, i):
        """Save an uncertainty map as ``uncertainity_<i>_pred.png``."""
        self._save_maps(var_map, '{:02d}_pred.png'.format(i), (self.image_save_path_1 + "/uncertainity_",))

    def _load_binary_pixels(self, files):
        """Return the thresholded pixels of ``files`` as one flat float array."""
        # The leading empty float array keeps the result dtype (float64) and
        # the empty-input behaviour of the original incremental concatenation.
        chunks = [np.array([])]
        for file in files:
            with Image.open(file) as image:
                pixels = np.asarray(image).flatten() / 255
            chunks.append(pixels > opt.threshold)
        return np.concatenate(chunks, axis=None)

    def _evaluate(self, image_dir, pred_pattern, label):
        """Log F1 and MAE of the predictions matching ``pred_pattern``."""
        self.logger.info(image_dir)

        # Sort both listings so prediction k is compared with mask k.
        pred_files = sorted(glob.glob(image_dir + pred_pattern))
        gt_files = sorted(glob.glob(image_dir + 'val_*_gt.png'))

        output_pred_list = self._load_binary_pixels(pred_files)
        target_list = self._load_binary_pixels(gt_files)

        F1_score = f1_score(target_list, output_pred_list)
        self.logger.info("{} F1 score : {} ".format(label, F1_score))

        mae = mean_absolute_error(target_list, output_pred_list)
        self.logger.info("{} MAE : {} ".format(label, mae))

    def evaluate_model_1(self, image_dir):
        """Log F1 / MAE for model 1 from the PNG files in ``image_dir``.

        ``image_dir`` must end with a path separator.
        """
        self._evaluate(image_dir, 'val_*_pred_1.png', 'Model 1')

    def evaluate_model_2(self, image_dir):
        """Log F1 / MAE for model 2 from the PNG files in ``image_dir``.

        ``image_dir`` must end with a path separator.
        """
        self._evaluate(image_dir, 'val_*_pred_2.png', 'Model 2')

    def run(self):
        """Load both checkpoints, dump predictions to disk and score them."""
        self.model_1.load_state_dict(torch.load(self.model_1_load_path))
        self.model_1.cuda()

        self.model_2.load_state_dict(torch.load(self.model_2_load_path))
        self.model_2.cuda()

        image_root = './data/' + opt.dataset + '/train/image/'
        gt_root = './data/' + opt.dataset + '/train/mask/'
        val_img_root = './data/' + opt.dataset + '/test/image/'
        val_gt_root = './data/' + opt.dataset + '/test/mask/'

        _, _, _, val_loader = image_loader(image_root, gt_root, val_img_root, val_gt_root,
                                           opt.batchsize, opt.trainsize)

        for i, pack in enumerate(val_loader, start=1):
            with torch.no_grad():
                images, gts = pack
                images = images.cuda()
                gts = gts.cuda()

                prediction1 = torch.sigmoid(self.model_1(images))
                prediction2 = torch.sigmoid(self.model_2(images))

            self.visualize_val_input(images, i)
            self.visualize_gt(gts, i)
            self.visualize_prediction1(prediction1, i)
            self.visualize_prediction2(prediction2, i)

        # Score the directories this run wrote to (for the default dataset
        # these are the previously hard-coded 'logs/kvasir/test/...' paths).
        self.evaluate_model_1(self.image_save_path_1 + '/')
        self.evaluate_model_2(self.image_save_path_2 + '/')


# Entry point: build the evaluator (config, logger, models) and run it.
# In a notebook __name__ is '__main__', so this runs when the cell is executed.
if __name__ == '__main__':
    Test_network = Test()
    Test_network.run()

In [5]:
def dice_coef(output, target):
    """Return the smoothed soft Dice coefficient of two tensors.

    Computes ``(2 * sum(output * target) + s) / (sum(output) + sum(target) + s)``
    with ``s = 1e-5`` over all elements, so two all-zero inputs score 1.

    Args:
        output: Prediction tensor (any shape, any device).
        target: Ground-truth tensor with the same number of elements.

    Returns:
        The Dice coefficient as a NumPy scalar.
    """
    smooth = 1e-5

    # reshape(-1) also handles non-contiguous tensors (e.g. transposed or
    # sliced), for which view(-1) raises; detach() replaces the legacy .data.
    output = output.detach().reshape(-1).cpu().numpy()
    target = target.detach().reshape(-1).cpu().numpy()
    intersection = (output * target).sum()

    return (2. * intersection + smooth) / (output.sum() + target.sum() + smooth)


In [6]:
import numpy as np
import torch
from torch.autograd import Variable
import torch.nn.functional as F
# Module-level binary cross-entropy criterion (expects probabilities, not logits).
# NOTE(review): not referenced by the loss helpers in this cell - confirm it is
# used elsewhere before removing.
BCE = torch.nn.BCELoss()

def weighted_loss(pred, mask):
    """Boundary-weighted BCE plus weighted IoU loss.

    Each pixel gets the weight ``1 + 5 * |avg_pool31(mask) - mask|``, which is
    largest where the mask differs from its local average (object borders).
    The result is the batch mean of (weighted BCE + weighted IoU loss).

    Args:
        pred: Prediction tensor of shape (N, C, H, W).
        mask: Target tensor of the same shape, float-valued.

    Returns:
        Scalar tensor with the mean loss.

    The BCE term previously passed ``reduce='none'``. ``reduce`` is the
    deprecated boolean argument, and the truthy string made PyTorch average
    the BCE to a scalar, so the pixel weights cancelled out. ``reduction='none'``
    keeps the per-pixel values the weighting needs.

    NOTE(review): the BCE term treats ``pred`` as logits while the IoU term
    uses ``pred`` directly, and the callers in this notebook pass
    sigmoid outputs - confirm which convention is intended.
    """
    weit = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask)

    wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')
    wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3))

    inter = ((pred * mask) * weit).sum(dim=(2, 3))
    union = ((pred + mask) * weit).sum(dim=(2, 3))
    # The +1 terms smooth the ratio when the masks are empty.
    wiou = 1 - (inter + 1) / (union - inter + 1)

    return (wbce + wiou).mean()



def calc_loss(pred, target, bce_weight=0.5):
    """Return ``weighted_loss(pred, target)``.

    ``bce_weight`` is accepted but has no effect on the result.
    """
    return weighted_loss(pred, target)


def loss_sup(logit_S1, logit_S2, labels_S1, labels_S2):
    """Supervised loss: sum of ``calc_loss`` over both (prediction, label) pairs."""
    first_term = calc_loss(logit_S1, labels_S1)
    second_term = calc_loss(logit_S2, labels_S2)
    return first_term + second_term



def loss_diff(u_prediction_1, u_prediction_2, batch_size):
    """Symmetric agreement term between the two models' unlabeled predictions.

    Evaluates ``weighted_loss`` in both directions (each prediction scored
    against the other as a target), sums the two values and divides by
    ``batch_size``.

    Returns:
        A Python float.

    NOTE(review): ``.item()`` turns both terms into plain numbers, so this
    value carries no gradient when it is added to the training loss - confirm
    that is intended.
    """
    target_from_2 = Variable(u_prediction_2, requires_grad=False)
    one_vs_two = weighted_loss(u_prediction_1, target_from_2).item()

    target_from_1 = Variable(u_prediction_1, requires_grad=False)
    two_vs_one = weighted_loss(u_prediction_2, target_from_1).item()

    return (one_vs_two + two_vs_one) / batch_size




In [None]:
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import pdb
import math
import torchvision
 
 
class Grid(object):
    """GridMask augmentation: multiply an image by a randomly placed grid mask.

    Args:
        d1, d2: Bounds of the grid period; the period is drawn with
            ``np.random.randint(d1, d2)``.
        rotate: Upper bound (exclusive) of the random rotation angle.
        ratio: Fraction of each period covered by a stripe.
        mode: When 1 the mask is inverted before being applied.
        prob: Probability of applying the augmentation.
    """

    def __init__(self, d1, d2, rotate=1, ratio=0.5, mode=0, prob=1.):
        self.d1, self.d2 = d1, d2
        self.rotate = rotate
        self.ratio = ratio
        self.mode = mode
        # st_prob keeps the configured value; prob may be rescaled by set_prob.
        self.st_prob = prob
        self.prob = prob

    def set_prob(self, epoch, max_epoch):
        """Ramp ``prob`` linearly with ``epoch / max_epoch``, capped at ``st_prob``."""
        scale = min(1, epoch / max_epoch)
        self.prob = self.st_prob * scale

    def __call__(self, img):
        """Return ``img`` times a random grid mask, or ``img`` itself if skipped.

        ``img`` is read through ``size(1)`` / ``size(2)`` as height / width.
        """
        if np.random.rand() > self.prob:
            return img

        height, width = img.size(1), img.size(2)
        # A square with the image diagonal as side still covers the image
        # after any rotation.
        side = math.ceil(math.sqrt(height * height + width * width))

        period = np.random.randint(self.d1, self.d2)
        self.l = math.ceil(period * self.ratio)

        canvas = np.ones((side, side), np.float32)
        row_offset = np.random.randint(period)
        col_offset = np.random.randint(period)
        stripe_indices = range(-1, side // period + 1)

        # Horizontal stripes.
        for k in stripe_indices:
            start = period * k + row_offset
            stop = start + self.l
            start = max(min(start, side), 0)
            stop = max(min(stop, side), 0)
            canvas[start:stop, :] *= 0

        # Vertical stripes.
        for k in stripe_indices:
            start = period * k + col_offset
            stop = start + self.l
            start = max(min(start, side), 0)
            stop = max(min(stop, side), 0)
            canvas[:, start:stop] *= 0

        angle = np.random.randint(self.rotate)
        rotated = Image.fromarray(np.uint8(canvas)).rotate(angle)
        canvas = np.asarray(rotated)

        # Centre-crop the square canvas back to the image size.
        top = (side - height) // 2
        left = (side - width) // 2
        canvas = canvas[top:top + height, left:left + width]

        mask = torch.from_numpy(canvas).float()
        if self.mode == 1:
            mask = 1 - mask

        return img * mask.expand_as(img)
 
 
class GridMask(nn.Module):
    """Module wrapper around ``Grid`` that is active only in training mode."""

    def __init__(self, d1=20, d2=80, rotate=90, ratio=0.4, mode=1, prob=0.8):
        super().__init__()
        self.rotate, self.ratio, self.mode = rotate, ratio, mode
        self.st_prob = prob
        self.grid = Grid(d1, d2, rotate, ratio, mode, prob)

    def set_prob(self, epoch, max_epoch):
        """Forward the probability schedule update to the wrapped ``Grid``."""
        self.grid.set_prob(epoch, max_epoch)

    def forward(self, x):
        """Apply the grid mask to ``x`` when training; otherwise return ``x``."""
        return self.grid(x) if self.training else x
    
# Demo: apply GridMask to one training image and save the result.
# In a notebook __name__ is '__main__', so this runs when the cell is executed.
if __name__ == '__main__':
    import cv2
    from torchvision import transforms

    demo_path = './data/kvasir/train/image/ckcu8xad600033b5yc78xfyjx.jpg'
    img = cv2.imread(demo_path)
    # cv2.imread returns None instead of raising when the file is missing or
    # unreadable; fail here with a clear message rather than inside ToTensor.
    if img is None:
        raise FileNotFoundError('Could not read demo image: {}'.format(demo_path))

    img = torchvision.transforms.ToTensor()(img)
    # A freshly built module is in training mode, so the mask is applied.
    grid_mask = GridMask()
    img = grid_mask(img)

    # Back to an 8-bit height x width x channel array for OpenCV.
    img = img.mul(255).byte()
    img = img.numpy().transpose((1, 2, 0))
    cv2.imwrite('gridmask.jpg', img)

In [8]:
import os
from PIL import Image
import torch.utils.data as data
from torchvision.transforms import transforms
import numpy as np
import torch
import random
class ObjDataset(data.Dataset):
    """Training dataset of (image, mask) pairs with weak or strong augmentation.

    Args:
        images: Iterable of image file paths.
        gts: Iterable of mask file paths; paired with ``images`` after both
            lists are sorted.
        trainsize: Target size handed to ``transforms.Resize``.
        mode: 'weak' (resize + normalise) or 'strong' (random geometric and
            colour augmentation, then GridMask on the image). Any other value
            returns the untransformed PIL images.

    Fixes relative to the original:
      * ``filter_files`` closes every image it opens instead of leaking the
        file handles;
      * ``resize`` accepts a tuple ``trainsize`` as well as an int (the int
        comparison raised TypeError for the tuple used by this notebook).
    """

    def __init__(self, images, gts, trainsize, mode):
        self.trainsize = trainsize
        self.mode = mode
        self.images = sorted(images)
        self.gts = sorted(gts)
        self.filter_files()
        self.size = len(self.images)
        self.gridmask = GridMask()

        imagenet_norm = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

        # Weak augmentation: resize only.
        self.img_transform_w = transforms.Compose([
            transforms.Resize(self.trainsize),
            transforms.ToTensor(),
            imagenet_norm])
        self.gt_transform_w = transforms.Compose([
            transforms.Resize(self.trainsize),
            transforms.ToTensor()])

        # Strong augmentation: the image and mask pipelines start with the
        # same four geometric transforms in the same order, so re-seeding
        # torch before each pipeline gives both the same geometry. The
        # colour-only transforms follow the geometric ones on the image side.
        self.img_transform_s = transforms.Compose([
            transforms.RandomRotation(90),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomAffine(degrees=90, translate=(0.5, 0.5), shear=30),
            transforms.ColorJitter(hue=0.5),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(3),
            transforms.Resize(self.trainsize),
            transforms.ToTensor(),
            imagenet_norm])
        self.gt_transform_s = transforms.Compose([
            transforms.RandomRotation(90),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomAffine(degrees=90, translate=(0.5, 0.5), shear=30),
            transforms.Resize(self.trainsize),
            transforms.ToTensor()])

    def __getitem__(self, index):
        """Return the augmented ``(image, gt)`` pair at ``index``."""
        image = self.rgb_loader(self.images[index])
        gt = self.binary_loader(self.gts[index])

        # Drawn for every sample (also in weak mode) to keep NumPy's random
        # stream identical to the original implementation.
        seed = np.random.randint(2147483647)

        if self.mode == 'weak':
            image = self.img_transform_w(image)
            gt = self.gt_transform_w(gt)
        elif self.mode == 'strong':
            # NOTE: re-seeding resets torch's global RNG as a side effect.
            torch.manual_seed(seed)
            image = self.img_transform_s(image)
            torch.manual_seed(seed)
            gt = self.gt_transform_s(gt)
            # GridMask is applied to the image only; the mask stays intact.
            image = self.gridmask(image)

        return image, gt

    def filter_files(self):
        """Keep only the pairs whose image and mask have identical dimensions."""
        assert len(self.images) == len(self.gts)
        images = []
        gts = []
        for img_path, gt_path in zip(self.images, self.gts):
            # Context managers release the file handles immediately; the
            # original left one open handle per file.
            with Image.open(img_path) as img, Image.open(gt_path) as gt:
                same_size = img.size == gt.size
            if same_size:
                images.append(img_path)
                gts.append(gt_path)
        self.images = images
        self.gts = gts

    def rgb_loader(self, path):
        """Load the file at ``path`` as an RGB PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def binary_loader(self, path):
        """Load the file at ``path`` as a single-channel ('L') PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('L')

    def resize(self, img, gt):
        """Upscale ``img`` / ``gt`` so neither side is smaller than ``trainsize``.

        Pairs that are already large enough are returned unchanged.
        """
        assert img.size == gt.size
        w, h = img.size
        if isinstance(self.trainsize, (tuple, list)):
            # Assumes (height, width) order, as transforms.Resize interprets
            # the same value - TODO confirm.
            min_h, min_w = self.trainsize
        else:
            min_h = min_w = self.trainsize
        if h < min_h or w < min_w:
            h = max(h, min_h)
            w = max(w, min_w)
            return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST)
        return img, gt

    def __len__(self):
        return self.size


class ValObjDataset(data.Dataset):
    """Validation dataset yielding ``(image, mask)`` tensor pairs.

    Image and mask paths are sorted independently and then paired by
    position, so matching files are expected to sort into the same order.
    Pairs whose two files differ in pixel size are dropped. No augmentation
    is applied: both items only pass through ``Resize`` and ``ToTensor``.

    Args:
        images: Paths of the RGB input images.
        gts: Paths of the ground-truth masks; must be as many as ``images``.
        trainsize: Size forwarded unchanged to ``transforms.Resize``.
    """

    def __init__(self, images, gts, trainsize):
        self.trainsize = trainsize
        self.images = sorted(images)
        self.gts = sorted(gts)
        self.filter_files()
        self.size = len(self.images)
        # NOTE(review): ``(self.trainsize)`` is the value itself, not a
        # 1-tuple. If an int is passed, torchvision's Resize presumably
        # scales the shorter edge instead of forcing a square -- confirm
        # that this is the intended behaviour.
        self.img_transform = transforms.Compose([
            transforms.Resize((self.trainsize)),
            transforms.ToTensor()])
        self.gt_transform = transforms.Compose([
            transforms.Resize((self.trainsize)),
            transforms.ToTensor()])

    def __getitem__(self, index):
        """Return the transformed ``(image, mask)`` pair at ``index``."""
        image = self.rgb_loader(self.images[index])
        gt = self.binary_loader(self.gts[index])

        image = self.img_transform(image)
        gt = self.gt_transform(gt)

        return image, gt

    def filter_files(self):
        """Drop image/mask pairs whose two files differ in pixel size.

        Raises:
            AssertionError: If the two path lists differ in length.
        """
        assert len(self.images) == len(self.gts)
        images = []
        gts = []
        for img_path, gt_path in zip(self.images, self.gts):
            # Context managers release both file handles right after the
            # size comparison instead of leaving them to the garbage collector.
            with Image.open(img_path) as img, Image.open(gt_path) as gt:
                same_size = img.size == gt.size
            if same_size:
                images.append(img_path)
                gts.append(gt_path)
        self.images = images
        self.gts = gts

    def rgb_loader(self, path):
        """Read the image file at ``path`` and return it in ``'RGB'`` mode."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def binary_loader(self, path):
        """Read the mask file at ``path`` and return it in ``'L'`` mode."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            # 'L' (8-bit grey) is used rather than the 1-bit '1' mode.
            return img.convert('L')

    def resize(self, img, gt):
        """Enlarge ``img`` and ``gt`` so neither side is below 256 pixels.

        Pairs that are already large enough are returned unchanged. The
        image is resampled with ``Image.BILINEAR`` and the mask with
        ``Image.NEAREST``.

        Raises:
            AssertionError: If ``img`` and ``gt`` differ in size.
        """
        assert img.size == gt.size
        w, h = img.size
        if h < 256 or w < 256:
            h = max(h, 256)
            w = max(w, 256)
            return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST)
        else:
            return img, gt

    def __len__(self):
        """Return the number of pairs kept after filtering."""
        return self.size


def _list_image_files(root):
    """Return the sorted paths of the ``.jpg``/``.png`` files directly in ``root``."""
    return sorted(os.path.join(root, name) for name in os.listdir(root)
                  if name.endswith(('.jpg', '.png')))


def image_loader(image_root, gt_root, val_img_root, val_gt_root, batch_size, image_size, split=1,
                 labeled_ratio=0.05, mode='weak_1'):
    """Build the two labeled, the unlabeled and the validation data loaders.

    The directory listings are sorted *before* they are sliced. ``os.listdir``
    returns entries in arbitrary order, so slicing the raw listings could put
    different samples into the image subset and the mask subset of the same
    split; sorting first keeps every subset aligned.

    Args:
        image_root: Directory with the training images.
        gt_root: Directory with the training masks.
        val_img_root: Directory with the validation images.
        val_gt_root: Directory with the validation masks.
        batch_size: Batch size used by all four loaders.
        image_size: Size handed to the dataset classes.
        split: Fraction of each listing (taken from its start) that is used.
        labeled_ratio: Fraction of the training files (taken from the start)
            treated as labeled. The labeled part is cut in half to feed two
            separate loaders; the remainder forms the unlabeled set.
        mode: Unused; kept so existing callers keep working.

    Returns:
        tuple: ``(labeled_data_loader_1, labeled_data_loader_2,
        unlabeled_data_loader, val_loader)``.
    """
    images = _list_image_files(image_root)
    gts = _list_image_files(gt_root)
    val_img = _list_image_files(val_img_root)
    val_label = _list_image_files(val_gt_root)

    train_images = images[:int(len(images) * split)]
    train_gts = gts[:int(len(images) * split)]
    val_images = val_img[:int(len(val_img) * split)]
    val_gts = val_label[:int(len(val_label) * split)]

    n_labeled_images = int(len(train_images) * labeled_ratio)
    labeled_train_images = train_images[:n_labeled_images]
    unlabeled_train_images = train_images[n_labeled_images:]
    half_images = int(len(labeled_train_images) * 0.5)
    labeled_train_images_1 = labeled_train_images[:half_images]
    labeled_train_images_2 = labeled_train_images[half_images:]

    n_labeled_gts = int(len(train_gts) * labeled_ratio)
    labeled_train_gts = train_gts[:n_labeled_gts]
    unlabeled_train_gts = train_gts[n_labeled_gts:]
    half_gts = int(len(labeled_train_gts) * 0.5)
    labeled_train_gts_1 = labeled_train_gts[:half_gts]
    labeled_train_gts_2 = labeled_train_gts[half_gts:]

    # Labeled halves use the 'weak' pipeline, the unlabeled set the 'strong' one.
    labeled_train_dataset_1 = ObjDataset(labeled_train_images_1, labeled_train_gts_1, image_size, mode='weak')
    labeled_train_dataset_2 = ObjDataset(labeled_train_images_2, labeled_train_gts_2, image_size, mode='weak')
    unlabeled_train_dataset = ObjDataset(unlabeled_train_images, unlabeled_train_gts, image_size, mode='strong')
    val_dataset = ValObjDataset(val_images, val_gts, image_size)

    def make_loader(dataset, shuffle):
        # All four loaders share every setting except shuffling.
        return data.DataLoader(dataset=dataset,
                               batch_size=batch_size,
                               num_workers=1,
                               pin_memory=True,
                               shuffle=shuffle)

    labeled_data_loader_1 = make_loader(labeled_train_dataset_1, True)
    labeled_data_loader_2 = make_loader(labeled_train_dataset_2, True)
    unlabeled_data_loader = make_loader(unlabeled_train_dataset, True)
    val_loader = make_loader(val_dataset, False)

    return labeled_data_loader_1, labeled_data_loader_2, unlabeled_data_loader, val_loader

In [9]:
import torch
from torch import nn
import torch.nn.functional as F
import warnings

# Silence every Python warning from this point on (applies process-wide).
warnings.filterwarnings("ignore")

class ConLoss(torch.nn.Module):
    """Patch-wise contrastive (InfoNCE-style) loss between two feature maps.

    Adapted from ``models/patchnce.py`` of "Contrastive Learning for Unpaired
    Image-to-Image Translation". Every spatial location of ``feat_q`` is a
    query: its positive is the same location in ``feat_k`` and its negatives
    are the other locations of ``feat_k`` (within the same sample unless
    ``nce_includes_all_negatives_from_minibatch`` is set).

    Args:
        temperature: Divisor applied to the logits before cross-entropy.
        base_temperature: Stored on the module but not used by ``forward``.
    """

    def __init__(self, temperature=0.07, base_temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.base_temperature = base_temperature
        self.nce_includes_all_negatives_from_minibatch = False
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
        self.mask_dtype = torch.bool

    def forward(self, feat_q, feat_k):
        """Return the scalar contrastive loss.

        Args:
            feat_q: Query features of shape ``(batch, dim, *spatial)`` with at
                least one spatial axis.
            feat_k: Key features with the same shape; gradients do not flow
                into them.

        Raises:
            AssertionError: If the two inputs differ in shape.
        """
        assert feat_q.size() == feat_k.size(), (feat_q.size(), feat_k.size())
        batch_size = feat_q.shape[0]
        dim = feat_q.shape[1]
        # flatten(2) instead of view(): also accepts non-contiguous feature
        # maps (e.g. transposed or sliced tensors), on which view() raises.
        feat_q = feat_q.flatten(2).permute(0, 2, 1)
        feat_k = feat_k.flatten(2).permute(0, 2, 1)
        # L1-normalise every patch vector along the channel axis.
        feat_q = F.normalize(feat_q, dim=-1, p=1)
        feat_k = F.normalize(feat_k, dim=-1, p=1)
        feat_k = feat_k.detach()

        # Positive logit: dot product of each query with the key at the same
        # batch index and location.
        l_pos = torch.bmm(feat_q.reshape(-1, 1, dim), feat_k.reshape(-1, dim, 1))
        l_pos = l_pos.view(-1, 1)

        if self.nce_includes_all_negatives_from_minibatch:
            # Treat the whole minibatch as a single pool of negatives.
            batch_dim_for_bmm = 1
        else:
            batch_dim_for_bmm = batch_size
        feat_q = feat_q.reshape(batch_dim_for_bmm, -1, dim)
        feat_k = feat_k.reshape(batch_dim_for_bmm, -1, dim)
        npatches = feat_q.size(1)
        l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1))

        # The diagonal pairs a query with its own positive; overwrite it with
        # a large negative value so it does not count as a negative.
        diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[None, :, :]
        l_neg_curbatch.masked_fill_(diagonal, -10.0)
        l_neg = l_neg_curbatch.view(-1, npatches)

        # Column 0 holds the positive, hence the all-zero target vector.
        out = torch.cat((l_pos, l_neg), dim=1) / self.temperature
        target = torch.zeros(out.size(0), dtype=torch.long, device=feat_q.device)
        return self.cross_entropy_loss(out, target)
    
    
class contrastive_loss_sup(torch.nn.Module):
    """Variant of the patch contrastive loss whose positive logit is fixed at 0.

    The negatives are computed exactly as in ``ConLoss`` (similarities between
    each query location of ``feat_q`` and the other locations of ``feat_k``),
    but the positive column is a constant zero instead of a query/key
    similarity.

    Args:
        temperature: Divisor applied to the logits before cross-entropy.
        base_temperature: Stored on the module but not used by ``forward``.
    """

    def __init__(self, temperature=0.07, base_temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.base_temperature = base_temperature
        self.nce_includes_all_negatives_from_minibatch = False
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
        self.mask_dtype = torch.bool

    def forward(self, feat_q, feat_k):
        """Return the scalar loss.

        Args:
            feat_q: Query features of shape ``(batch, dim, *spatial)`` with at
                least one spatial axis.
            feat_k: Key features with the same shape; gradients do not flow
                into them.

        Raises:
            AssertionError: If the two inputs differ in shape.
        """
        assert feat_q.size() == feat_k.size(), (feat_q.size(), feat_k.size())
        batch_size = feat_q.shape[0]
        dim = feat_q.shape[1]
        # flatten(2) also accepts non-contiguous feature maps.
        feat_q = feat_q.flatten(2).permute(0, 2, 1)
        feat_k = feat_k.flatten(2).permute(0, 2, 1)
        # L1-normalise every patch vector along the channel axis.
        feat_q = F.normalize(feat_q, dim=-1, p=1)
        feat_k = feat_k_normalized = F.normalize(feat_k, dim=-1, p=1)
        feat_k = feat_k_normalized.detach()

        # One zero logit per query patch. The row count and device follow the
        # input rather than the former hard-coded 2304 patches on CUDA, so any
        # spatial size and CPU tensors work too.
        num_queries = batch_size * feat_q.size(1)
        l_pos = torch.zeros((num_queries, 1), device=feat_q.device)

        if self.nce_includes_all_negatives_from_minibatch:
            # Treat the whole minibatch as a single pool of negatives.
            batch_dim_for_bmm = 1
        else:
            batch_dim_for_bmm = batch_size
        feat_q = feat_q.reshape(batch_dim_for_bmm, -1, dim)
        feat_k = feat_k.reshape(batch_dim_for_bmm, -1, dim)
        npatches = feat_q.size(1)
        l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1))

        # Overwrite the diagonal (query paired with its own location) with a
        # large negative value so it does not count as a negative.
        diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[None, :, :]
        l_neg_curbatch.masked_fill_(diagonal, -10.0)
        l_neg = l_neg_curbatch.view(-1, npatches)

        # Column 0 holds the positive, hence the all-zero target vector.
        out = torch.cat((l_pos, l_neg), dim=1) / self.temperature
        target = torch.zeros(out.size(0), dtype=torch.long, device=feat_q.device)
        return self.cross_entropy_loss(out, target)
    
  
    

In [None]:

import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
import torch
import torch.nn.functional as F

# Names exported by a star-import of this module.
# NOTE(review): res2net101_v1b_26w_4s and res2net152_v1b_26w_4s are defined
# further down but are not listed here.
__all__ = ['Res2Net', 'res2net50_v1b', 'res2net101_v1b', 'res2net50_v1b_26w_4s']

# Download locations of the pretrained checkpoints, keyed by model name.
# NOTE(review): there is no 'res2net152_v1b_26w_4s' entry although
# res2net152_v1b_26w_4s() looks that key up when pretrained=True.
model_urls = {
    'res2net50_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_v1b_26w_4s-3cf99910.pth',
    'res2net101_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net101_v1b_26w_4s-0812c246.pth',
}


class Bottle2neck(nn.Module):
    """Res2Net bottleneck: 1x1 conv, a group of chained 3x3 convs, 1x1 conv.

    The output of the first 1x1 convolution is split into ``scale`` groups of
    ``width`` channels. All but the last group pass through their own 3x3
    convolution; in ``'normal'`` blocks each group is first summed with the
    previous group's output. The last group is appended unchanged
    (``'normal'``) or average-pooled (``'stage'``).

    Args:
        inplanes: Input channels.
        planes: Base channel count; the block outputs ``planes * expansion``.
        stride: Stride of the 3x3 convolutions and of the ``'stage'`` pooling.
        downsample: Optional module applied to the input for the residual path.
        baseWidth: Scales the per-group width: ``floor(planes * baseWidth / 64)``.
        scale: Number of channel groups.
        stype: ``'normal'`` or ``'stage'``.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, baseWidth=26, scale=4, stype='normal'):
        super().__init__()

        width = int(math.floor(planes * (baseWidth / 64.0)))
        grouped_channels = width * scale
        self.conv1 = nn.Conv2d(inplanes, grouped_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(grouped_channels)

        # Number of groups that get their own 3x3 convolution.
        self.nums = 1 if scale == 1 else scale - 1
        if stype == 'stage':
            self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)
        self.convs = nn.ModuleList([
            nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, bias=False)
            for _ in range(self.nums)
        ])
        self.bns = nn.ModuleList([nn.BatchNorm2d(width) for _ in range(self.nums)])

        out_channels = planes * self.expansion
        self.conv3 = nn.Conv2d(grouped_channels, out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stype = stype
        self.scale = scale
        self.width = width

    def forward(self, x):
        """Apply the block to ``x`` and return the activated residual sum."""
        residual = x

        out = self.relu(self.bn1(self.conv1(x)))

        spx = torch.split(out, self.width, 1)
        is_stage = self.stype == 'stage'
        pieces = []
        sp = None
        for i in range(self.nums):
            # 'stage' blocks process every group independently; 'normal'
            # blocks feed the previous group's output into the next one.
            sp = spx[i] if (i == 0 or is_stage) else sp + spx[i]
            sp = self.relu(self.bns[i](self.convs[i](sp)))
            pieces.append(sp)

        if self.scale != 1:
            if self.stype == 'normal':
                pieces.append(spx[self.nums])
            elif is_stage:
                pieces.append(self.pool(spx[self.nums]))
        out = pieces[0] if len(pieces) == 1 else torch.cat(pieces, 1)

        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        return self.relu(out)


class Res2Net(nn.Module):
    """Res2Net classifier: three-conv stem, four residual stages, linear head.

    Args:
        block: Block class; must expose ``expansion`` and accept the keyword
            arguments used in ``_make_layer``.
        layers: Number of blocks in each of the four stages.
        baseWidth: Forwarded to every block.
        scale: Forwarded to every block.
        num_classes: Output size of the final linear layer.
    """

    def __init__(self, block, layers, baseWidth=26, scale=4, num_classes=1000):
        super().__init__()
        # Running input-channel count, updated by _make_layer.
        self.inplanes = 64
        self.baseWidth = baseWidth
        self.scale = scale

        stem = [
            nn.Conv2d(3, 32, 3, 2, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, 3, 1, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, 1, 1, bias=False),
        ]
        self.conv1 = nn.Sequential(*stem)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        depth1, depth2, depth3, depth4 = layers[0], layers[1], layers[2], layers[3]
        self.layer1 = self._make_layer(block, 64, depth1)
        self.layer2 = self._make_layer(block, 128, depth2, stride=2)
        self.layer3 = self._make_layer(block, 256, depth3, stride=2)
        self.layer4 = self._make_layer(block, 512, depth4, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Re-initialise every convolution and batch-norm layer.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` blocks and advance ``self.inplanes``.

        The first block is created with ``stype='stage'`` and receives the
        stride and, when the stride or channel count changes, a pooled 1x1
        projection for its residual path.
        """
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.AvgPool2d(kernel_size=stride, stride=stride,
                             ceil_mode=True, count_include_pad=False),
                nn.Conv2d(self.inplanes, out_planes,
                          kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        first = block(self.inplanes, planes, stride, downsample=downsample,
                      stype='stage', baseWidth=self.baseWidth, scale=self.scale)
        self.inplanes = out_planes
        rest = [block(self.inplanes, planes, baseWidth=self.baseWidth, scale=self.scale)
                for _ in range(1, blocks)]

        return nn.Sequential(first, *rest)

    def forward(self, x):
        """Return the class logits for the image batch ``x``."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)


def res2net50_v1b(pretrained=False, **kwargs):
    """Build a Res2Net with stage depths ``[3, 4, 6, 3]`` (baseWidth 26, scale 4).

    Args:
        pretrained: When true, download the checkpoint registered under
            ``'res2net50_v1b_26w_4s'`` in ``model_urls`` and load it.
        **kwargs: Forwarded to ``Res2Net``.
    """
    net = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs)
    if not pretrained:
        return net
    weights = model_zoo.load_url(model_urls['res2net50_v1b_26w_4s'])
    net.load_state_dict(weights)
    return net


def res2net101_v1b(pretrained=False, **kwargs):
    """Build a Res2Net with stage depths ``[3, 4, 23, 3]`` (baseWidth 26, scale 4).

    Args:
        pretrained: When true, download the checkpoint registered under
            ``'res2net101_v1b_26w_4s'`` in ``model_urls`` and load it.
        **kwargs: Forwarded to ``Res2Net``.
    """
    model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth=26, scale=4, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['res2net101_v1b_26w_4s']))
    return model


def res2net50_v1b_26w_4s(pretrained=False, **kwargs):
    """Build a Res2Net with stage depths ``[3, 4, 6, 3]`` (baseWidth 26, scale 4).

    Args:
        pretrained: When true, load the checkpoint from the local file used
            so far if it exists; otherwise download the checkpoint registered
            under ``'res2net50_v1b_26w_4s'`` in ``model_urls``.
        **kwargs: Forwarded to ``Res2Net``.
    """
    import os  # local import keeps this function self-contained

    model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs)
    if pretrained:
        local_weights = 'D:/HarDNet-MSEG-master/res2net50_v1b_26w_4s-3cf99910.pth'
        if os.path.isfile(local_weights):
            model_state = torch.load(local_weights)
        else:
            # The hard-coded path only exists on the original author's
            # machine; fall back to the registered download elsewhere.
            model_state = model_zoo.load_url(model_urls['res2net50_v1b_26w_4s'])
        model.load_state_dict(model_state)
    return model


def res2net101_v1b_26w_4s(pretrained=False, **kwargs):
    """Build a Res2Net with stage depths ``[3, 4, 23, 3]`` (baseWidth 26, scale 4).

    Args:
        pretrained: When true, download the checkpoint registered under
            ``'res2net101_v1b_26w_4s'`` in ``model_urls`` and load it.
        **kwargs: Forwarded to ``Res2Net``.
    """
    model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth=26, scale=4, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['res2net101_v1b_26w_4s']))
    return model


def res2net152_v1b_26w_4s(pretrained=False, **kwargs):
    """Build a Res2Net with stage depths ``[3, 8, 36, 3]`` (baseWidth 26, scale 4).

    Args:
        pretrained: When true, download the checkpoint registered under
            ``'res2net152_v1b_26w_4s'`` in ``model_urls`` and load it.
        **kwargs: Forwarded to ``Res2Net``.

    Raises:
        KeyError: If ``pretrained`` is true but ``model_urls`` has no
            ``'res2net152_v1b_26w_4s'`` entry.
    """
    key = 'res2net152_v1b_26w_4s'
    # Check the URL before building the (large) model so a missing entry
    # fails fast and with an explanatory message.
    if pretrained and key not in model_urls:
        raise KeyError(
            "no pretrained checkpoint URL is registered for %r in model_urls" % key)
    model = Res2Net(Bottle2neck, [3, 8, 36, 3], baseWidth=26, scale=4, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls[key]))
    return model


if __name__ == '__main__':
    # Smoke test: push one random 224x224 RGB image through the pretrained
    # network and print the output size. Uses GPU 0 when CUDA is available
    # and falls back to the CPU otherwise instead of crashing.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    images = torch.rand(1, 3, 224, 224).to(device)
    model = res2net50_v1b_26w_4s(pretrained=True)
    model = model.to(device)
    print(model(images).size())