In [None]:
from collections import OrderedDict
from torch import Tensor
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from collections import OrderedDict
from torch import Tensor
from typing import Type, Any, Callable, Union, List, Optional
import glob
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Subset,DataLoader
import torchvision.transforms as transforms
import torchvision
import random
from google.colab import files
from sklearn.metrics import  confusion_matrix
from sklearn.model_selection import ShuffleSplit
import cv2
from google.colab.patches import cv2_imshow
from scipy.ndimage import distance_transform_edt
from torch.autograd import Variable
import skimage.segmentation
import skimage.io
import skimage
from scipy.optimize import linear_sum_assignment
import skimage.segmentation
import matplotlib.pyplot as plt
import skimage.io
import skimage.segmentation
from skimage import feature
from skimage import filters
import copy
import torchvision
from collections import OrderedDict
import math
import imageio

In [None]:
def randomHueSaturationValue(
    image,
    hue_shift_limit=(-40, 40),
    sat_shift_limit=(-10, 10),
    val_shift_limit=(-20, 20),
    u=0.5,
):
    """Randomly jitter the hue, saturation and value of an 8-bit BGR image.

    Args:
        image: image passed to ``cv2.cvtColor(..., cv2.COLOR_BGR2HSV)``; the
            hue arithmetic below assumes the 8-bit representation.
        hue_shift_limit: inclusive integer range the hue shift is drawn from.
        sat_shift_limit: range the saturation shift is drawn from (uniform).
        val_shift_limit: range the value shift is drawn from (uniform).
        u: probability of applying the jitter.

    Returns:
        The jittered image converted back to BGR, or ``image`` itself when
        the jitter is skipped.
    """
    if np.random.random() < u:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
        h, s, v = cv2.split(image)
        hue_shift = int(np.random.randint(
            hue_shift_limit[0], hue_shift_limit[1] + 1))
        # 8-bit OpenCV hue spans [0, 180). Casting a negative shift to uint8
        # wraps modulo 256 (and raises on recent NumPy), which lands on the
        # wrong colour; do the arithmetic in a wider type and wrap at 180.
        h = ((h.astype(np.int16) + hue_shift) % 180).astype(np.uint8)
        sat_shift = np.random.uniform(sat_shift_limit[0], sat_shift_limit[1])
        s = cv2.add(s, sat_shift)  # cv2.add is used so the result stays in the channel's range
        val_shift = np.random.uniform(val_shift_limit[0], val_shift_limit[1])
        v = cv2.add(v, val_shift)
        image = cv2.merge((h, s, v))
        image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)

    return image


def randomShiftScaleRotate(
    image,
    mask,
    shift_limit=(-0.1, 0.1),
    scale_limit=(-0.1, 0.1),
    aspect_limit=(-0.1, 0.1),
    rotate_limit=(-0, 0),
    borderMode=cv2.BORDER_CONSTANT,
    u=0.5,
):
    """Apply one random shift / scale / aspect / rotation warp to image and mask.

    The same perspective matrix is used for both arrays so they stay aligned.

    Args:
        image: array whose first two dimensions are height and width.
        mask: array warped with the same matrix as ``image``.
        shift_limit: range of the translation, as a fraction of width/height.
        scale_limit: range added to 1 to obtain the isotropic scale.
        aspect_limit: range added to 1 to obtain the aspect-ratio factor.
        rotate_limit: range of the rotation angle in degrees.
        borderMode: OpenCV border mode used for pixels warped in from outside.
        u: probability of applying the warp.

    Returns:
        ``(image, mask)``, warped or untouched.
    """
    if np.random.random() < u:
        # Only height/width are needed, so 2-D (single-channel) images work too.
        height, width = image.shape[:2]

        angle = np.random.uniform(rotate_limit[0], rotate_limit[1])
        scale = np.random.uniform(1 + scale_limit[0], 1 + scale_limit[1])
        aspect = np.random.uniform(1 + aspect_limit[0], 1 + aspect_limit[1])
        sx = scale * aspect / (aspect ** 0.5)
        sy = scale / (aspect ** 0.5)
        dx = round(np.random.uniform(shift_limit[0], shift_limit[1]) * width)
        dy = round(np.random.uniform(shift_limit[0], shift_limit[1]) * height)

        # `np.math` was only an alias of the stdlib module and has been removed
        # from NumPy; call `math` directly.
        cc = math.cos(angle / 180 * math.pi) * sx
        ss = math.sin(angle / 180 * math.pi) * sy
        rotate_matrix = np.array([[cc, -ss], [ss, cc]])

        # Map the four image corners through the transform (about the centre,
        # then translated) and let OpenCV solve for the matching 3x3 matrix.
        box0 = np.array([[0, 0], [width, 0], [width, height], [0, height]])
        box1 = box0 - np.array([width / 2, height / 2])
        box1 = np.dot(box1, rotate_matrix.T) + np.array(
            [width / 2 + dx, height / 2 + dy]
        )

        box0 = box0.astype(np.float32)
        box1 = box1.astype(np.float32)
        mat = cv2.getPerspectiveTransform(box0, box1)
        image = cv2.warpPerspective(
            image,
            mat,
            (width, height),
            flags=cv2.INTER_NEAREST,
            borderMode=borderMode,
            borderValue=(0, 0, 0),
        )
        mask = cv2.warpPerspective(
            mask,
            mat,
            (width, height),
            flags=cv2.INTER_NEAREST,
            borderMode=borderMode,
            borderValue=(0, 0, 0),
        )

    return image, mask


def randomHorizontalFlip(image, mask, u=0.5):
    """With probability ``u``, mirror image and mask with ``cv2.flip(..., 1)``."""
    if not (np.random.random() < u):
        return image, mask

    flip_code = 1
    return cv2.flip(image, flip_code), cv2.flip(mask, flip_code)


def randomVerticleFlip(image, mask, u=0.5):
    """With probability ``u``, flip image and mask with ``cv2.flip(..., 0)``."""
    if not (np.random.random() < u):
        return image, mask

    flip_code = 0
    return cv2.flip(image, flip_code), cv2.flip(mask, flip_code)


def randomRotate90(image, mask, u=0.5):
    """With probability ``u``, apply ``np.rot90`` to both image and mask."""
    if not (np.random.random() < u):
        return image, mask

    rotated_image, rotated_mask = (np.rot90(arr) for arr in (image, mask))
    return rotated_image, rotated_mask

In [None]:
def extract_patches_2d(img,patch_shape,step=[1.0,1.0],batch_first=False):
    """Cut a 4-D tensor into fixed-size patches taken along dims 2 and 3.

    Args:
        img: 4-D tensor; dims 2 and 3 are the ones padded and unfolded here
            (presumably (batch, channels, height, width) -- TODO confirm with callers).
        patch_shape: sequence whose first two items are the patch height and width.
        step: per-axis stride. A ``float`` entry is multiplied by the patch size
            on that axis and truncated with ``int``; any other entry is used as-is.
        batch_first: when true, the first two dims of the result are swapped.

    Returns:
        Tensor of shape ``(num_patches, img.size(0), img.size(1), patch_H, patch_W)``,
        or with the first two dims swapped when ``batch_first`` is true.
    """
    patch_H, patch_W = patch_shape[0], patch_shape[1]
    # Inputs smaller than one patch are zero-padded up to the patch size; the
    # padding is split between both sides, with the odd element going last.
    if(img.size(2)<patch_H):
        num_padded_H_Top = (patch_H - img.size(2))//2
        num_padded_H_Bottom = patch_H - img.size(2) - num_padded_H_Top
        padding_H = nn.ConstantPad2d((0,0,num_padded_H_Top,num_padded_H_Bottom),0)
        img = padding_H(img)
    if(img.size(3)<patch_W):
        num_padded_W_Left = (patch_W - img.size(3))//2
        num_padded_W_Right = patch_W - img.size(3) - num_padded_W_Left
        padding_W = nn.ConstantPad2d((num_padded_W_Left,num_padded_W_Right,0,0),0)
        img = padding_W(img)
    # Resolve the stride: float entries are fractions of the patch size.
    step_int = [0,0]
    step_int[0] = int(patch_H*step[0]) if(isinstance(step[0], float)) else step[0]
    step_int[1] = int(patch_W*step[1]) if(isinstance(step[1], float)) else step[1]
    # Windows along dim 2. If the stride grid does not end exactly at the far
    # edge, one extra window aligned with that edge is appended.
    patches_fold_H = img.unfold(2, patch_H, step_int[0])
    if((img.size(2) - patch_H) % step_int[0] != 0):
        patches_fold_H = torch.cat((patches_fold_H,img[:,:,-patch_H:,].permute(0,1,3,2).unsqueeze(2)),dim=2)
    # Same procedure along dim 3, applied to the row windows.
    patches_fold_HW = patches_fold_H.unfold(3, patch_W, step_int[1])
    if((img.size(3) - patch_W) % step_int[1] != 0):
        patches_fold_HW = torch.cat((patches_fold_HW,patches_fold_H[:,:,:,-patch_W:,:].permute(0,1,2,4,3).unsqueeze(3)),dim=3)
    # Bring the two window-index dims to the front and flatten them into one
    # patch dim (row index varies slowest).
    patches = patches_fold_HW.permute(2,3,0,1,4,5)
    patches = patches.reshape(-1,img.size(0),img.size(1),patch_H,patch_W)
    if(batch_first):
        patches = patches.permute(1,0,2,3,4)
    return patches

def reconstruct_from_patches_2d(patches,img_shape,step=[1.0,1.0],batch_first=False):
    """Reassemble a tensor from 5-D patches, averaging where patches overlap.

    Args:
        patches: 5-D tensor indexed as (patch, dim0, dim1, patch_H, patch_W), or
            with the first two dims swapped when ``batch_first`` is true.
            Presumably the output of ``extract_patches_2d`` -- TODO confirm.
        img_shape: sequence whose first two items are the target height and width.
        step: per-axis stride, interpreted as in the extraction step: ``float``
            entries are fractions of the patch size, other entries are used as-is.
        batch_first: set when ``patches`` has its first two dims swapped.

    Returns:
        Tensor of shape ``(dim0, dim1, img_shape[0], img_shape[1])`` in which each
        position is the mean of every patch value that covers it.
    """
    if(batch_first):
        patches = patches.permute(1,0,2,3,4)
    patch_H, patch_W = patches.size(3), patches.size(4)
    # The working canvas is at least one patch large; any surplus over
    # img_shape is cropped again at the end.
    img_size = (patches.size(1), patches.size(2),max(img_shape[0], patch_H), max(img_shape[1], patch_W))
    step_int = [0,0]
    step_int[0] = int(patch_H*step[0]) if(isinstance(step[0], float)) else step[0]
    step_int[1] = int(patch_W*step[1]) if(isinstance(step[1], float)) else step[1]
    # nrow/ncol count the windows on the regular stride grid; r_nrow/r_ncol add
    # the extra edge-aligned window when the grid does not reach the far edge.
    nrow, ncol = 1 + (img_size[-2] - patch_H)//step_int[0], 1 + (img_size[-1] - patch_W)//step_int[1]
    r_nrow = nrow + 1 if((img_size[2] - patch_H) % step_int[0] != 0) else nrow
    r_ncol = ncol + 1 if((img_size[3] - patch_W) % step_int[1] != 0) else ncol
    patches = patches.reshape(r_nrow,r_ncol,img_size[0],img_size[1],patch_H,patch_W)
    # NOTE: no dtype is passed, so both accumulators use torch's default dtype.
    img = torch.zeros(img_size, device = patches.device)
    overlap_counter = torch.zeros(img_size, device = patches.device)
    # Accumulate the patches that sit on the regular grid.
    for i in range(nrow):
        for j in range(ncol):
            img[:,:,i*step_int[0]:i*step_int[0]+patch_H,j*step_int[1]:j*step_int[1]+patch_W] += patches[i,j,]
            overlap_counter[:,:,i*step_int[0]:i*step_int[0]+patch_H,j*step_int[1]:j*step_int[1]+patch_W] += 1
    # Extra last row of patches, aligned with the bottom edge.
    if((img_size[2] - patch_H) % step_int[0] != 0):
        for j in range(ncol):
            img[:,:,-patch_H:,j*step_int[1]:j*step_int[1]+patch_W] += patches[-1,j,]
            overlap_counter[:,:,-patch_H:,j*step_int[1]:j*step_int[1]+patch_W] += 1
    # Extra last column of patches, aligned with the right edge.
    if((img_size[3] - patch_W) % step_int[1] != 0):
        for i in range(nrow):
            img[:,:,i*step_int[0]:i*step_int[0]+patch_H,-patch_W:] += patches[i,-1,]
            overlap_counter[:,:,i*step_int[0]:i*step_int[0]+patch_H,-patch_W:] += 1
    # Corner patch, present only when both extra row and extra column exist.
    if((img_size[2] - patch_H) % step_int[0] != 0 and (img_size[3] - patch_W) % step_int[1] != 0):
        img[:,:,-patch_H:,-patch_W:] += patches[-1,-1,]
        overlap_counter[:,:,-patch_H:,-patch_W:] += 1
    # Turn the per-position sums into means.
    img /= overlap_counter
    # Crop away the padding of a canvas that was enlarged to one patch; the
    # split mirrors the top/bottom and left/right padding arithmetic.
    if(img_shape[0]<patch_H):
        num_padded_H_Top = (patch_H - img_shape[0])//2
        num_padded_H_Bottom = patch_H - img_shape[0] - num_padded_H_Top
        img = img[:,:,num_padded_H_Top:-num_padded_H_Bottom,]
    if(img_shape[1]<patch_W):
        num_padded_W_Left = (patch_W - img_shape[1])//2
        num_padded_W_Right = patch_W - img_shape[1] - num_padded_W_Left
        img = img[:,:,:,num_padded_W_Left:-num_padded_W_Right]
    return img


In [None]:
# Training tiles and their raster labels, each listed in sorted order.
# NOTE(review): images and labels are later paired by list index, which
# presumably relies on matching file names in both folders -- confirm.
image_path_train = sorted(
    glob.glob('/content/drive/MyDrive/3. The cropped image tiles and raster labels/train/image*/**.tif')
)
label_path_train = sorted(
    glob.glob('/content/drive/MyDrive/3. The cropped image tiles and raster labels/train/label*/**.tif')
)

In [None]:
# Validation tiles and their raster labels, each listed in sorted order.
image_path_val = sorted(
    glob.glob('/content/drive/MyDrive/3. The cropped image tiles and raster labels/val/image*/**.tif')
)
label_path_val = sorted(
    glob.glob('/content/drive/MyDrive/3. The cropped image tiles and raster labels/val/label*/**.tif')
)

In [None]:
# Test tiles and labels in sorted order; only entries 1..300 of each list are
# kept. NOTE(review): index 0 is skipped on purpose? -- confirm.
image_path_test1 = sorted(
    glob.glob('/content/drive/MyDrive/3. The cropped image tiles and raster labels/test/image*/**.tif')
)
label_path_test1 = sorted(
    glob.glob('/content/drive/MyDrive/3. The cropped image tiles and raster labels/test/label*/**.tif')
)

image_path_test = image_path_test1[1:301]
label_path_test = label_path_test1[1:301]


In [None]:
class CustomDataset(Dataset):
    """Image / mask pairs read from two parallel lists of file paths.

    Item ``i`` is built from ``image_paths[i]`` and ``target_paths[i]``. Both
    files are read with ``cv2.imread`` and converted with
    ``transforms.ToTensor()``; only the first channel of the mask is kept.
    """

    def __init__(self, image_paths, target_paths):
        # Items are paired by index, so lists of different length would either
        # fail late with an IndexError or silently drop files.
        if len(image_paths) != len(target_paths):
            raise ValueError(
                f"image_paths and target_paths differ in length: "
                f"{len(image_paths)} vs {len(target_paths)}"
            )

        self.image_paths = image_paths
        self.target_paths = target_paths
        self.transforms = transforms.ToTensor()

    @staticmethod
    def _read(path):
        """Read ``path`` with OpenCV, raising instead of returning ``None``."""
        array = cv2.imread(path)
        if array is None:
            # cv2.imread signals a missing/unreadable file with None rather
            # than an exception; surface it with the offending path.
            raise OSError(f"cv2.imread could not read {path!r}")
        return array

    def __getitem__(self, index):
        """Return ``(image, mask)`` tensors for position ``index``."""
        image = self._read(self.image_paths[index])
        mask = self._read(self.target_paths[index])
        image = self.transforms(image[:, :, :])
        mask = self.transforms(mask[:, :, 0:1])  # keep a single mask channel

        return image, mask

    def __len__(self):
        """Number of image paths."""
        return len(self.image_paths)


# One dataset and one loader per split; every loader uses batches of 4 and two
# worker processes. Only the test loader keeps the file order.
train_dataset = CustomDataset(image_path_train, label_path_train)
val_dataset = CustomDataset(image_path_val, label_path_val)
test_dataset = CustomDataset(image_path_test, label_path_test)

train_dl = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2)
val_dl = DataLoader(val_dataset, batch_size=4, shuffle=True, num_workers=2)
test_dl = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=2)

# Report the number of samples in each split.
for split_dataset in (train_dataset, val_dataset, test_dataset):
    print(len(split_dataset))

4736
1036
300


In [None]:
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """Build a bias-free 3x3 convolution whose padding equals its dilation."""
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """Build a bias-free 1x1 (pointwise) convolution."""
    pointwise = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise


In [None]:
class Upsample(nn.Module):
    """Module wrapper around ``F.interpolate`` with a fixed scale factor.

    Args:
        scale_factor: multiplier for the spatial size, forwarded to
            ``F.interpolate``.
        mode: interpolation mode, forwarded to ``F.interpolate``.
    """

    # Modes for which F.interpolate accepts an explicit ``align_corners``.
    _ALIGNABLE_MODES = ("linear", "bilinear", "bicubic", "trilinear")

    def __init__(self, scale_factor, mode="bilinear"):
        super(Upsample, self).__init__()
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, x):
        """Resize ``x`` by ``scale_factor`` using the configured mode."""
        # F.interpolate rejects align_corners for non-interpolating modes such
        # as "nearest" or "area", so only request it where it is meaningful.
        align_corners = True if self.mode in self._ALIGNABLE_MODES else None
        x = F.interpolate(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=align_corners,
            recompute_scale_factor=True,
        )
        return x

In [None]:
# Index Pooling
class pool(nn.Module):
    """Index pooling: split a feature map into its four 2x2 sub-grids.

    Each output is produced by a stride-2 depthwise convolution whose 2x2
    kernel has a single 1, so output ``k`` holds, for every 2x2 window, the
    element at position (0,0), (0,1), (1,0) or (1,1) respectively.

    Args:
        channels: number of channels of the tensors passed to ``forward``.
    """

    def __init__(self, channels):
        super(pool, self).__init__()
        self.channels = channels
        # The selector kernels are constants. Registering them as
        # non-persistent buffers makes them follow .to()/.cuda()/.cpu() without
        # requiring a GPU at construction time and without adding keys to the
        # state_dict (so existing checkpoints still load).
        for number, (row, col) in enumerate(((0, 0), (0, 1), (1, 0), (1, 1)), start=1):
            kernel = torch.zeros((channels, 1, 2, 2))
            kernel[:, :, row, col] = 1
            self.register_buffer(f"weight{number}", kernel, persistent=False)

    def forward(self, x):
        """Return the tuple ``(x1, x2, x3, x4)`` of half-resolution sub-grids."""
        # NOTE(review): the no_grad block means no gradient flows back through
        # this pooling step; kept as in the original -- confirm it is intended.
        with torch.no_grad():
            outputs = tuple(
                F.conv2d(
                    x,
                    # Match the input even if the module itself was never moved.
                    kernel.to(device=x.device, dtype=x.dtype),
                    stride=2,
                    groups=self.channels,
                    bias=None,
                )
                for kernel in (self.weight1, self.weight2, self.weight3, self.weight4)
            )
        return outputs

In [None]:
#DAMIP Module
class attn_pool(nn.Module):
    """Attention-guided index pooling (the notebook's "DAMIP" module).

    Both the feature map and a single-channel attention map are split into
    their four 2x2 sub-grids with ``pool``. Each feature sub-grid is combined
    with its attention sub-grid as ``a_k * feature_k + feature_k * map_k``,
    the four results are concatenated on the channel axis, and two
    convolutions reduce them to ``2 * feature_channels`` channels.

    Args:
        feature_channels: number of channels of ``feature``.
    """

    def __init__(self, feature_channels):
        super(attn_pool, self).__init__()
        self.pool1 = pool(feature_channels)
        self.pool2 = pool(1)
        self.conv1 = nn.Conv2d(feature_channels*4, 2*feature_channels, kernel_size=1, stride=1, bias=False)
        self.conv2 = nn.Conv2d(2*feature_channels, 2*feature_channels, kernel_size=7, stride=1, padding=3)
        self.bn2 = nn.BatchNorm2d(2*feature_channels)
        # torch.Tensor(1) returns uninitialised memory, so the mixing weights
        # used to start from arbitrary (possibly NaN/huge) values. Start them
        # at 1 so each branch begins as feature * (1 + map). Values loaded from
        # a checkpoint still override this.
        self.a1 = nn.Parameter(torch.ones(1))
        self.a2 = nn.Parameter(torch.ones(1))
        self.a3 = nn.Parameter(torch.ones(1))
        self.a4 = nn.Parameter(torch.ones(1))

    def forward(self, map, feature):
        """Pool ``feature`` under guidance of ``map``; halves the spatial size."""
        feature1, feature2, feature3, feature4 = self.pool1(feature)
        map1, map2, map3, map4 = self.pool2(map)

        # Learnable identity term plus attention-weighted term, per sub-grid.
        fm1 = self.a1*feature1 + feature1*map1
        fm2 = self.a2*feature2 + feature2*map2
        fm3 = self.a3*feature3 + feature3*map3
        fm4 = self.a4*feature4 + feature4*map4

        mat = torch.cat((fm1, fm2, fm3, fm4), 1)
        mat = self.conv1(mat)
        mat = F.relu(self.bn2(self.conv2(mat)))
        return mat

In [None]:
# DPMG Module
class dsup(nn.Module):
    """Deep-supervision head (the notebook's "DPMG" module).

    Three 3x3 convolutions reduce ``input_channels`` feature channels to a
    single-channel map that is squashed with a sigmoid.
    """

    def __init__(self, input_channels):
        super().__init__()
        half_channels = input_channels // 2
        self.conv1 = conv3x3(input_channels, half_channels)
        self.bn1 = nn.BatchNorm2d(half_channels)
        self.conv2 = conv3x3(half_channels, 32)
        self.conv3 = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Return the sigmoid-activated single-channel map for ``x``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        logits = self.conv3(self.conv2(hidden))
        return torch.sigmoid(logits)

In [None]:
# Dilated Convolutional Block
class conv_enc(nn.Module):
    """Residual block of 1x1 and dilated 3x3 convolutions.

    The main path is 1x1 -> dilated 3x3 -> 1x1 -> dilated 3x3, each followed
    by batch norm; the shortcut is a 1x1 convolution that maps the input to
    ``out_channels`` so the two paths can be summed.
    """

    def __init__(self, in_channels, out_channels, dil):
        super().__init__()
        self.conv1 = conv1x1(in_channels, in_channels)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv2 = conv3x3(in_channels, in_channels, dilation=dil)
        self.bn2 = nn.BatchNorm2d(in_channels)
        self.conv3 = conv1x1(in_channels, out_channels)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.conv4 = conv3x3(out_channels, out_channels, dilation=dil)
        self.bn4 = nn.BatchNorm2d(out_channels)
        # Shortcut projection so the channel counts match for the addition.
        self.conv0 = conv1x1(in_channels, out_channels)

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        shortcut = self.conv0(x)

        out = x
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)):
            out = F.relu(norm(conv(out)))
        # The last stage is activated only after the residual sum.
        out = self.bn4(self.conv4(out))

        return F.relu(out + shortcut)

In [None]:
class enc(nn.Module):
    """Encoder stage: two ``conv_enc`` blocks plus a ``dsup`` head.

    The second block uses twice the dilation of the first.
    """

    def __init__(self, input_channels, output_channels, dil):
        super().__init__()
        self.conv1 = conv_enc(input_channels, output_channels, dil)
        self.conv2 = conv_enc(output_channels, output_channels, 2 * dil)
        self.dp_sup = dsup(output_channels)

    def forward(self, x):
        """Return ``(features, side_map)`` for the stage."""
        features = self.conv2(self.conv1(x))
        side_map = self.dp_sup(features)
        return features, side_map

In [None]:
# DAMSCA Module
class kqcbam(nn.Module):
    """Map-guided spatial and channel attention (the notebook's "DAMSCA" module).

    The single-channel map gates the features twice: element-wise (spatial),
    and through a 1x1 convolution, global average pooling and a sigmoid
    (one gate per channel). The two gated tensors are summed, passed through
    ReLU, and upsampled by ``scale_factor``.
    """

    def __init__(self, input_channels, scale_factor=2):
        super().__init__()
        self.conv1 = nn.Conv2d(1, input_channels, kernel_size=1)
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.upsample = Upsample(scale_factor)

    def forward(self, map, feature):
        """Return the attention-refined, upsampled features."""
        spatial_branch = map * feature
        channel_gate = torch.sigmoid(self.gap(self.conv1(map)))
        channel_branch = channel_gate * feature
        fused = F.relu(spatial_branch + channel_branch)
        return self.upsample(fused)

In [None]:
# Decoder Module
class decoder(nn.Module):
    """Decoder head: 2x transposed-conv upsampling, then 3x3 convolutions.

    Channels go ``input_channels`` -> 128 -> 64 -> 32 -> 1 and the result is
    passed through a sigmoid.
    """

    def __init__(self, input_channels):
        super().__init__()
        self.conv1 = nn.ConvTranspose2d(input_channels, 128, kernel_size=4, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(128)
        self.conv2 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.conv_out = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Return the single-channel sigmoid map at twice the input resolution."""
        upsampled = F.relu(self.bn1(self.conv1(x)))
        refined = F.relu(self.bn2(self.conv2(upsampled)))
        # The last two convolutions are applied back to back, without activation.
        logits = self.conv_out(self.conv3(refined))
        return torch.sigmoid(logits)

In [None]:
# MSSDMPA-Net
class dsmpnet(nn.Module):
    """MSSDMPA-Net: four encoder stages with side outputs and a fused decoder.

    A stride-2 7x7 stem feeds four ``enc`` stages (64, 128, 256, 512
    channels). Between stages an ``attn_pool`` halves the resolution, guided
    by the previous stage's side map. Each stage's features are refined by a
    ``kqcbam`` block with scale factors 1, 2, 4 and 8, concatenated
    (64 + 128 + 256 + 512 = 960 channels) and decoded.
    """

    def __init__(self, input_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)

        self.pool1 = attn_pool(64)
        self.pool2 = attn_pool(128)
        self.pool3 = attn_pool(256)

        self.path1 = enc(64, 64, 1)
        self.path2 = enc(128, 128, 2)
        self.path3 = enc(256, 256, 3)
        self.path4 = enc(512, 512, 4)

        self.cbm1 = kqcbam(64, 1)
        self.cbm2 = kqcbam(128, 2)
        self.cbm3 = kqcbam(256, 4)
        self.cbm4 = kqcbam(512, 8)

        self.decoder = decoder(960)

    def forward(self, x):
        """Return ``(final_map, side1, side2, side3, side4)``."""
        stem = F.relu(self.bn1(self.conv1(x)))

        # Encoder: every stage yields features and a side map; the side map
        # guides the pooling that feeds the next stage.
        feat1, side1 = self.path1(stem)
        feat2, side2 = self.path2(self.pool1(side1, stem))
        pooled2 = self.pool2(side2, self.pool1(side1, stem)) if False else None  # placeholder, replaced below
        del pooled2
        stage2_in = self.pool1(side1, stem) if False else None  # placeholder, replaced below
        del stage2_in
        return self._forward_impl(x)

    def _forward_impl(self, x):
        """Straight-line forward pass; see ``forward`` for the contract."""
        current = F.relu(self.bn1(self.conv1(x)))
        feat1, side1 = self.path1(current)
        current = self.pool1(side1, current)
        feat2, side2 = self.path2(current)
        current = self.pool2(side2, current)
        feat3, side3 = self.path3(current)
        current = self.pool3(side3, current)
        feat4, side4 = self.path4(current)

        refined = (
            self.cbm1(side1, feat1),
            self.cbm2(side2, feat2),
            self.cbm3(side3, feat3),
            self.cbm4(side4, feat4),
        )
        final_map = self.decoder(torch.cat(refined, 1))
        return final_map, side1, side2, side3, side4


In [None]:
# Build the network for 3-channel input and place it on the first GPU.
model = dsmpnet(3).cuda().to('cuda:0')

In [None]:
class gen_loss(nn.Module):
    """Binary cross-entropy plus a generalised (noise-robust style) Dice term.

    The Dice term is ``sum(|p - g| ** gamma) / (sum(p^2) + sum(g^2) + eps)``,
    computed over the whole batch.

    Args:
        gamma: exponent applied to the absolute error.
        batch: accepted for signature compatibility; not used.
    """

    def __init__(self, gamma=1.5, batch=True):
        super(gen_loss, self).__init__()
        self.bce_loss = nn.BCELoss()
        self.gamma = gamma

    def gen_dice(self, y_pred, y_true):
        """Return the scalar generalised Dice term for the batch."""
        epsilon = 1e-8
        error_sum = torch.sum(torch.abs(y_pred - y_true) ** self.gamma)
        # Epsilon belongs in the denominator: with it only in the numerator an
        # all-zero prediction against an all-zero target divided by zero and
        # produced inf instead of a zero loss.
        denominator = torch.sum(y_pred * y_pred) + torch.sum(y_true * y_true) + epsilon
        return error_sum / denominator

    def forward(self, y_pred, y_true):
        """Return ``BCE(y_pred, y_true) + gen_dice(y_pred, y_true)``.

        ``y_pred`` must already be probabilities in [0, 1], as required by
        ``nn.BCELoss``.
        """
        # Implemented as forward (not __call__) so nn.Module hooks keep working;
        # instances are still invoked as ``loss_fn(pred, target)``.
        return self.bce_loss(y_pred, y_true) + self.gen_dice(y_pred, y_true)

In [None]:
def y_bce_loss(prediction1,prediction2,prediction3,prediction4,prediction5,label):
    """Sum of ``gen_loss`` over the final output and the four side outputs.

    Args:
        prediction1: prediction compared against ``label`` as given.
        prediction2, prediction3, prediction4, prediction5: side outputs. For
            each one the label is resized to that prediction's spatial size
            with nearest-neighbour interpolation before the loss is computed.
        label: 4-D target tensor at the resolution of ``prediction1``.

    Returns:
        Scalar tensor: the unweighted sum of the five losses.
    """
    dice = gen_loss()
    loss = dice(prediction1, label)
    # The label is resized progressively (each step starts from the previous,
    # already reduced label). Sizes come from the predictions themselves, so
    # inputs other than 512x512 work; for 512x512 inputs this reproduces the
    # former fixed 256/128/64/32 schedule.
    for prediction in (prediction2, prediction3, prediction4, prediction5):
        label = torch.nn.functional.interpolate(
            label, size=tuple(prediction.shape[-2:]), scale_factor=None, mode='nearest')
        loss = loss + dice(prediction, label)
    return loss

In [None]:
class IoU(nn.Module):
    """Thresholded segmentation metrics: accuracy, recall, precision and IoU.

    Args:
        threshold: values strictly above it count as foreground, for both the
            target and the prediction.
    """

    def __init__(self, threshold=0.5):
        super().__init__()
        self.threshold = threshold

    def forward(self, target, input):
        """Return a CPU tensor ``[accuracy, recall, precision, iou]``.

        All four entries are 0 when the binarised prediction and target share
        (almost) no foreground.
        """
        eps = 1e-10
        pred_bin = (input > self.threshold).float()
        true_bin = (target > self.threshold).float()

        overlap = torch.clamp(pred_bin * true_bin, 0, 1).mean()
        combined = torch.clamp(pred_bin + true_bin, 0, 1).mean()

        if overlap.lt(eps):
            return torch.Tensor([0., 0., 0., 0.])

        accuracy = (pred_bin == true_bin).float().mean()
        iou_score = overlap / combined
        recall = overlap / true_bin.mean()
        precision = overlap / pred_bin.mean()
        return torch.Tensor([accuracy, recall, precision, iou_score])


# Shared metric instance used by the training and validation loops.
iou = IoU()

In [None]:
def dice_coeff(y_true,y_pred,batch=True):
    """Smoothed Dice coefficient ``(2*|A.B| + s) / (|A| + |B| + s)``.

    Args:
        y_true: target tensor.
        y_pred: prediction tensor of the same shape.
        batch: when true, sums run over every element and one score is
            produced; otherwise dims 1, 2 and 3 are summed per sample and the
            per-sample scores are averaged.

    Returns:
        Scalar tensor with the (mean) Dice score.
    """
    smooth = 1e-8

    if batch:
        reduce = torch.sum
    else:
        def reduce(tensor):
            # Collapse dims 1..3 one at a time, leaving one value per sample.
            return tensor.sum(1).sum(1).sum(1)

    true_total = reduce(y_true)
    pred_total = reduce(y_pred)
    overlap = reduce(y_true * y_pred)

    score = (2. * overlap + smooth) / (true_total + pred_total + smooth)
    return score.mean()


In [None]:
def train_one_epoch_net(model, train_dl, learn, opt=None):
    """Run one optimisation pass over ``train_dl``.

    Args:
        model: network returning five maps per batch (final output first).
        train_dl: iterable of ``(image, label)`` batches.
        learn: learning rate for the Adam optimizer created when ``opt`` is
            not supplied.
        opt: optional optimizer to reuse. Passing the same instance on every
            epoch preserves its internal state (e.g. Adam moments); when
            omitted a fresh Adam optimizer is created for this epoch, as before.

    Returns:
        ``(model, dice_epoch, metric_epoch, running_loss_image)`` where the last
        three are per-batch averages of the Dice score, the ``iou`` metric
        vector and the loss.
    """
    if opt is None:
        opt = torch.optim.Adam(model.parameters(), lr=learn)
    # Run on whatever device the model lives on instead of assuming CUDA.
    device = next(model.parameters()).device
    # Make sure BatchNorm is in training mode even if an evaluation ran before.
    model.train()

    running_loss_image = 0.0
    metric_epoch = 0.0
    dice_epoch = 0.0
    for a, b in train_dl:
        a = a.float()
        label_loss = b.type(torch.float)
        pred1, pred2, pred3, pred4, pred5 = model(a.to(device))
        loss = y_bce_loss(pred1, pred2, pred3, pred4, pred5, label_loss.to(device))
        opt.zero_grad()
        loss.backward()
        opt.step()
        # Detach before accumulating: summing the live loss tensor would keep
        # every batch's autograd graph alive until the end of the epoch.
        running_loss_image += loss.detach()
        final_pred = pred1.detach().cpu()
        metric = iou(label_loss, final_pred)
        dice = dice_coeff(label_loss, final_pred)
        dice_epoch += dice
        metric_epoch += metric
    running_loss_image /= len(train_dl)
    metric_epoch /= len(train_dl)
    dice_epoch /= len(train_dl)
    return model, dice_epoch, metric_epoch, running_loss_image

def validate_one_epoch_net(model, val_dl):
    """Evaluate ``model`` on ``val_dl`` without updating any weights.

    Args:
        model: network returning five maps per batch (final output first).
        val_dl: iterable of ``(image, label)`` batches.

    Returns:
        ``(dice_epoch, metric_epoch, running_loss_image)``: per-batch averages
        of the Dice score, the ``iou`` metric vector and the loss.
    """
    # Run on whatever device the model lives on instead of assuming CUDA.
    device = next(model.parameters()).device
    # Evaluation must use BatchNorm's running statistics; in training mode the
    # layers would normalise with batch statistics and also update their
    # running averages from the evaluation data. The previous mode is restored
    # afterwards so callers that keep training are unaffected.
    was_training = model.training
    model.eval()

    running_loss_image = 0.0
    metric_epoch = 0.0
    dice_epoch = 0.0
    try:
        with torch.no_grad():
            for a, b in val_dl:
                a = a.float()
                label_loss = b.type(torch.float)
                pred1, pred2, pred3, pred4, pred5 = model(a.to(device))
                loss = y_bce_loss(pred1, pred2, pred3, pred4, pred5, label_loss.to(device))
                running_loss_image += loss
                final_pred = pred1.detach().cpu()
                metric = iou(label_loss, final_pred)
                dice = dice_coeff(label_loss, final_pred)
                dice_epoch += dice
                metric_epoch += metric
    finally:
        model.train(was_training)
    running_loss_image /= len(val_dl)
    metric_epoch /= len(val_dl)
    dice_epoch /= len(val_dl)
    return dice_epoch, metric_epoch, running_loss_image


In [None]:
def train_epoches_net(model,train_dl,test_dl,epoches,learn,path):
    """Train for ``epoches`` epochs, evaluating and checkpointing after each.

    Args:
        model: network to train (updated in place).
        train_dl: loader passed to ``train_one_epoch_net``.
        test_dl: loader passed to ``validate_one_epoch_net`` after every epoch.
        epoches: number of epochs to run.
        learn: learning rate forwarded to ``train_one_epoch_net``.
        path: directory that receives one ``state_dict`` file per epoch, named
            ``epoch{i}_test_loss{loss:.4f}.pth``. Created if missing.
    """
    # torch.save does not create parent directories.
    os.makedirs(path, exist_ok=True)
    for i in range(epoches):
        model, dice_train, iou_train, loss_train = train_one_epoch_net(model, train_dl, learn)
        dice_test, iou_test, loss_test = validate_one_epoch_net(model, test_dl)
        print('epoch finished' +" " + str(i+1))
        # Index 3 of the metric vector is the IoU entry.
        print(f'train_loss: {loss_train:.6f}, train_dice: {dice_train:.6f}, train_iou: {iou_train[3]:.6f}')
        print(f'test_loss: {loss_test:.6f}, test_dice: {dice_test:.6f}, test_iou: {iou_test[3]:.6f}')
        path_final = os.path.join(path,
                                    f"epoch{i}_test_loss{loss_test:.4f}.pth")
        torch.save(model.state_dict(), path_final)

In [None]:
%mkdir tgrs1
train_epoches_net(model,test_dl,test_dl,20,0.00001,'/content/tgrs1')


In [None]:
# Restore previously trained weights from Drive; the cell's last expression
# shows load_state_dict's key-matching report.
checkpoint_state = torch.load('/content/drive/MyDrive/weights/building_big/tgrs_new/epoch4_test_iou0.6948')
model.load_state_dict(checkpoint_state)

<All keys matched successfully>

In [None]:
# Rebuild the test loader with batches of 8 and evaluate the loaded weights.
test_dl = DataLoader(dataset=test_dataset, batch_size=8, num_workers=2)
# x: mean Dice, y: mean [accuracy, recall, precision, IoU], z: mean loss.
x, y, z = validate_one_epoch_net(model, test_dl)
print(x, y, z)

tensor(0.7974) tensor([0.9830, 0.7974, 0.8574, 0.7000]) tensor(1.3297, device='cuda:0')
