## Data preparation

In [None]:
#1
import os
import sys
# os.environ['KMP_DUPLICATE_LIB_OK']='True'
sys.path.append("./")

import cv2
import numpy as np
import pandas as pd
import random, tqdm
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

import warnings
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import albumentations as album
import segmentation_models_pytorch as smp

from torch.utils.data import DataLoader
#from dataloaders.datasets import Pathology

import ssl
ssl._create_default_https_context = ssl._create_unverified_context

# sdm
from scipy.ndimage import distance_transform_edt as distance
from skimage import segmentation as skimage_seg
from skimage import morphology

# test loader
import itertools

from scipy.ndimage import distance_transform_edt

In [None]:
#2
# Root folder of the labelled pathology data; every path below is derived from it.
DATA_DIR = './PathologyDataset/labelled/'

# Training split: image folder, colour-label folder and .npy folder.
x_train_dir, y_train_dir, train_dir = (
    os.path.join(DATA_DIR, sub) for sub in ('train/png', 'train/png_label', 'train/npy')
)

# Test split: same three folders.
x_test_dir, y_test_dir, test_dir = (
    os.path.join(DATA_DIR, sub) for sub in ('test/png', 'test/png_label', 'test/npy')
)

In [None]:
# Disabled one-off maintenance snippet. It is wrapped in a string literal, so nothing in
# it is executed. If enabled, it would compare the file stems in the training png folder
# with those in the training npy folder and move every PNG without a same-named npy file
# into train/no_npy.
'''
import shutil
print(x_train_dir+"/")
print(train_dir+"/")

def filename_without_ext(folder):
    filenames = os.listdir(folder)
    filename_without_ext = {os.path.splitext(filename)[0] for filename in filenames}
    return filename_without_ext

file_npy = filename_without_ext(train_dir)
file_png = filename_without_ext(x_train_dir)

print(len(file_npy))
print(len(file_png))

difference = file_png - file_npy
d_dir = os.path.join(DATA_DIR, "train/no_npy")
print(len(difference))
for files in difference:
    s_path = os.path.join(x_train_dir, files+".png")
    d_path = os.path.join(d_dir, files+".png")
    shutil.move(s_path, d_path)
    print(f"{files}.png has been moved to {d_path}")
'''



### data visualization

In [None]:
#3
# Helper for quick side-by-side inspection of images
def visualize(**images):
    """
    Draw every keyword-argument image next to each other in a single row.

    The keyword name of each image is used as its panel title, with underscores
    replaced by spaces and the words title-cased.
    """
    panel_count = len(images)
    plt.figure(figsize=(20,8))
    for position, (name, image) in enumerate(images.items(), start=1):
        plt.subplot(1, panel_count, position)
        # no tick marks on either axis
        plt.xticks([])
        plt.yticks([])
        plt.title(name.replace('_',' ').title(), fontsize=20)
        plt.imshow(image)
    plt.show()

# One-hot encode a colour-coded label image
def one_hot_encode(label, label_values):
    """
    Turn a colour-coded label array into a stack of boolean class masks.

    For every colour in ``label_values`` a mask is built that is True where all
    entries along the last axis of ``label`` equal that colour. The masks are
    stacked along a new last axis, in the order of ``label_values``.
    """
    per_class_masks = [
        np.all(np.equal(label, colour), axis=-1)
        for colour in label_values
    ]
    return np.stack(per_class_masks, axis=-1)
    
# Collapse one-hot labels / per-class scores back to class indices
def reverse_one_hot(image):
    """Return, for every position, the index of the largest value along the last axis."""
    return np.argmax(image, axis = -1)

# Map class indices to their display colours
def colour_code_segmentation(image, label_values):
    """
    Replace every class index in ``image`` by the matching row of ``label_values``.

    ``image`` is cast to int and used to index an array built from ``label_values``,
    so the result gains one trailing axis holding the colour components.
    """
    palette = np.array(label_values)
    indices = image.astype(int)
    return palette[indices]

### dataset

In [None]:
#4
# All classes known to the dataset, and the subset to work with
# (shortlisting is handy for datasets with many classes; here all three are kept).
class_names = ['background', 'benign', 'malignant']
select_classes = ['background', 'benign', 'malignant']

background=[[0, 0, 0],]

# RGB colour of each class in the label images, in the order of class_names
class_rgb_values = [[0, 0, 0], [255, 0, 0], [0, 255, 0]]

# Positions of the selected classes inside class_names, and their colours
select_class_indices = list(map(lambda name: class_names.index(name.lower()), select_classes))
select_class_rgb_values = np.take(np.array(class_rgb_values), select_class_indices, axis=0)

print('Selected classes and their corresponding RGB values in labels:')
print('Class Names: ', class_names)
print('Class RGB values: ', class_rgb_values)

### Data augmentation

In [None]:
#5
def get_training_augmentation():
    """
    Build the training-time augmentation pipeline.

    The pipeline holds a single ``OneOf`` group (p=1) that chooses between a
    horizontal flip, a vertical flip and ``RandomRotate90``.
    """
    flip_or_rotate = album.OneOf(
        [
            album.HorizontalFlip(p=1),
            album.VerticalFlip(p=1),
            album.RandomRotate90(p=1),
        ],
        p=1,
    )
    return album.Compose([flip_or_rotate])


def to_tensor(x, **kwargs):
    """
    Move the last axis of a 3-D array to the front (e.g. HWC -> CHW) and cast to float32.

    Extra keyword arguments are accepted and ignored.
    """
    channels_first = x.transpose((2, 0, 1))
    return channels_first.astype('float32')


def get_preprocessing(preprocessing_fn=None):
    """
    Build the preprocessing pipeline.

    When ``preprocessing_fn`` is given (truthy) it is applied to the image first.
    ``to_tensor`` is then always applied to both image and mask.
    """
    steps = [album.Lambda(image=preprocessing_fn)] if preprocessing_fn else []
    steps.append(album.Lambda(image=to_tensor, mask=to_tensor))
    return album.Compose(steps)

In [None]:
#6
from torch.utils.data import Dataset
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
from utils.mypath import Path

class PathologyDataset(Dataset):
    """
    Pathology tiles stored as PNG images with per-image ``.npy`` label files.

    Images are listed from ``<pathology root>/labelled/<split>/png``. The label of an
    image is read from the file with the same stem in the sibling ``npy`` directory.
    """

    def __init__(self, args, augmentation, split:str):
        """
        Args:
            args: stored on the instance as-is; not otherwise used by this class.
            augmentation: callable invoked as ``augmentation(image=..., mask=...)``
                that returns a mapping with ``'image'`` and ``'mask'``, or ``None``
                to disable augmentation.
            split: name of the sub-directory under ``labelled`` (the notebook uses
                "train" and "test").
        """
        self.augmentation = augmentation
        self.args = args
        self.data_dir = os.path.join(Path.pathology_root_dir(), 'labelled', split, 'png')
        # Full paths of every entry in the png directory (order as returned by os.listdir).
        self.imgs = [os.path.join(self.data_dir, f) for f in os.listdir(self.data_dir)]
        # Built once here instead of on every __getitem__ call.
        self._to_tensor = A.Compose([ToTensorV2()])

    @staticmethod
    def _mask_path(img_path):
        """
        Return the label file for ``img_path``: same stem, ``.npy`` extension, located
        in the ``npy`` directory next to the image's own directory.

        Only the last directory and the extension are swapped, so a root folder or a
        file name that happens to contain "png" is left untouched.
        """
        split_dir = os.path.dirname(os.path.dirname(img_path))
        stem = os.path.splitext(os.path.basename(img_path))[0]
        return os.path.join(split_dir, 'npy', stem + '.npy')

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        """
        Load one sample.

        Returns:
            dict with ``'img'`` (the mapping returned by the ToTensorV2 pipeline, i.e.
            the tensor sits under ``['img']['image']``) and ``'label'`` (int64 tensor).

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        img_path = self.imgs[index]
        bgr = cv2.imread(img_path)
        if bgr is None:
            # cv2.imread signals failure by returning None; fail with a readable message.
            raise FileNotFoundError(f"could not read image: {img_path}")
        _img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        # The .npy file holds a pickled dict with a 'label' entry, hence allow_pickle.
        # NOTE: unpickling executes arbitrary code - only load label files you trust.
        _label = np.load(self._mask_path(img_path), allow_pickle=True).item().get('label')

        # apply augmentations jointly to image and label
        if self.augmentation is not None:
            sample = self.augmentation(image=_img, mask=_label)
            _img, _label = sample['image'], sample['mask']

        # scale pixel values by 1/255, then convert to a tensor
        _img = _img / 255
        _img = self._to_tensor(image=_img)
        _label = torch.as_tensor(_label).long()

        return {
            'img': _img,
            'label': _label,
        }

In [None]:
#7
# Second view of the training split in which every sample passes through the
# flip / rotate augmentation pipeline.
train_augmented_dataset = PathologyDataset(
    args=[],
    augmentation=get_training_augmentation(),
    split="train",
)

print(len(train_augmented_dataset))

In [None]:
def collate_fn(batch):
    """Collate a batch with PyTorch's default rule after discarding entries that are None."""
    usable = list(filter(lambda item: item is not None, batch))
    return torch.utils.data.dataloader.default_collate(usable)

In [None]:
#8
def make_loaders(args, num_workers, pin_memory=True, batch_size=16,
                 unlabel_ratio=0.5, val_ratio=0.2):
    """
    Build the labelled-train, validation, unlabelled and test data loaders.

    The training pool is the un-augmented training split concatenated with the
    module-level ``train_augmented_dataset``. It is split randomly into an
    "unlabelled" part and a labelled part, and the labelled part is split again
    into train and validation. Split sizes are derived from the pool size; for the
    60,006-sample pool the previously hard-coded sizes are reproduced exactly
    (30003 / 30003, then 24002 / 6001).

    Args:
        args: unused; kept for interface compatibility.
        num_workers: worker processes used by every loader.
        pin_memory: forwarded to every loader.
        batch_size: batch size of the train / validation / unlabelled loaders
            (the test loader keeps the DataLoader default of 1).
        unlabel_ratio: fraction of the training pool treated as unlabelled.
        val_ratio: fraction of the labelled part used for validation.

    Returns:
        (train_loader, val_loader, unlabel_loader, test_loader)

    Raises:
        ValueError: if a ratio is outside [0, 1].
    """
    for ratio_name, ratio in (("unlabel_ratio", unlabel_ratio), ("val_ratio", val_ratio)):
        if not 0 <= ratio <= 1:
            raise ValueError(f"{ratio_name} must be within [0, 1], got {ratio}")

    train_set = PathologyDataset([], None, "train")+train_augmented_dataset
    test_set = PathologyDataset([], None, "test")

    print("train 나누기 전:",len(train_set), "/ test 나누기 전:",len(test_set))

    # unlabelled / labelled split
    n_total = len(train_set)
    n_unlabel = int(n_total * unlabel_ratio)
    n_label = n_total - n_unlabel
    unlabel_dataset, label_dataset = torch.utils.data.random_split(train_set, [n_unlabel, n_label])
    print("unlabel data:",len(unlabel_dataset), "/ label data:",len(label_dataset))

    # labelled part -> train / validation
    n_val = int(round(n_label * val_ratio))
    train_set, val_set = torch.utils.data.random_split(label_dataset, [n_label - n_val, n_val])
    print("label train 나눈 후:",len(train_set), "/ label validation 나눈 후:",len(val_set))
    print("test 나눈 후:",len(test_set))

    # Settings shared by the three batched loaders; incomplete final batches are dropped.
    shared = dict(
        batch_size=batch_size, num_workers=num_workers,
        pin_memory=pin_memory, drop_last=True,
    )
    train_loader = DataLoader(train_set, shuffle=True, **shared)
    val_loader = DataLoader(val_set, shuffle=False, **shared)
    unlabel_loader = DataLoader(unlabel_dataset, shuffle=False, **shared)

    test_loader = DataLoader(
        test_set, shuffle=False, num_workers=num_workers,
        pin_memory=pin_memory, drop_last=True,
    )

    return train_loader, val_loader, unlabel_loader, test_loader

train_loader, val_loader, unlabel_loader, test_loader = \
    make_loaders([], num_workers=8, pin_memory=True)

### Parameter setting

In [None]:
#9
#label smoothing
class LabelSmoothingLoss(nn.Module):
    """
    Cross-entropy with label smoothing.

    The target distribution puts ``1 - smoothing`` on the true class and spreads
    ``smoothing`` evenly over the other ``classes - 1`` classes.

    ``dim`` is the class dimension of ``pred``. It is used consistently for the
    log-softmax, for placing the true-class weight and for the final sum.
    Previously the true-class weight was always scattered along dim 1 while the
    softmax and sum used ``dim`` (default -1), which only agreed for 2-D
    ``(batch, classes)`` input. The default is now 1, which is identical for 2-D
    input and correct for ``(batch, classes, ...)`` segmentation maps.

    Args:
        classes: number of classes (at least 2).
        smoothing: smoothing mass in [0, 1).
        dim: class dimension of the predictions.

    Raises:
        ValueError: for ``classes < 2`` or ``smoothing`` outside [0, 1).
    """

    def __init__(self, classes, smoothing=0.0, dim=1):
        super().__init__()
        if classes < 2:
            raise ValueError(f"classes must be at least 2, got {classes}")
        if not 0 <= smoothing < 1:
            raise ValueError(f"smoothing must be in [0, 1), got {smoothing}")
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        """
        Args:
            pred: raw scores with the class axis at ``self.dim``.
            target: integer class indices shaped like ``pred`` without the class axis.

        Returns:
            Scalar loss averaged over all non-class positions.
        """
        log_prob = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            # smoothed target: small mass everywhere, `confidence` on the true class
            true_dist = torch.full_like(log_prob, self.smoothing / (self.cls - 1))
            true_dist.scatter_(self.dim, target.unsqueeze(self.dim), self.confidence)
        return torch.mean(torch.sum(-true_dist * log_prob, dim=self.dim))

In [None]:
#10
def accuracy(dataloader, net=None, device=None):
    """
    Percentage of correctly classified label elements (pixels) over a data loader.

    Each batch must be a dict with the input under ``['img']['image']`` and integer
    class labels under ``['label']``; the prediction is the argmax over dim 1 of the
    network output. The denominator is the number of label elements, so the result
    lies in [0, 100]. (It used to be the number of images while the numerator
    counted pixels.)

    Args:
        dataloader: iterable of batches as described above.
        net: network to evaluate; defaults to the module-level ``model``.
        device: device to run on; defaults to the module-level ``DEVICE``.

    Returns:
        Accuracy in percent.

    Raises:
        ValueError: if the loader yields no label elements.
    """
    net = model if net is None else net
    device = DEVICE if device is None else device

    was_training = net.training
    correct = 0
    total = 0
    net.eval()
    with torch.no_grad():
        for data in dataloader:
            images = data['img']['image'].float().to(device)
            labels = data['label'].to(device, dtype=torch.int64)
            predicted = net(images).argmax(dim=1)
            total += labels.numel()
            correct += (predicted == labels).sum().item()
    # put the network back into the mode it was in when we were called
    net.train(was_training)

    if total == 0:
        raise ValueError("accuracy(): the dataloader yielded no samples")
    return 100*correct/total

In [None]:
#11
import segmentation_models_pytorch as smp

# Encoder backbone (others tried: 'densenet201', 'resnet34') and its pretrained weights
ENCODER = 'timm-resnest50d'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = class_names
# Output activation of the segmentation head: None would give raw logits,
# 'softmax2d' gives per-pixel class probabilities.
# NOTE(review): the training/evaluation cells apply F.softmax to the model output again
# and LabelSmoothingLoss applies log_softmax on top - confirm this stacking is intended.
ACTIVATION ='softmax2d'

# Third GPU when CUDA is available, CPU otherwise
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:2")
else:
    DEVICE = torch.device("cpu")
print(DEVICE)

# Input normalisation matching the chosen encoder / weights
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

# U-Net with the pretrained encoder, one output channel per class, moved to DEVICE
model = smp.Unet(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=len(CLASSES),
    activation=ACTIVATION,
).to(DEVICE)

In [None]:
#12
# When False, only prediction is meant to run (using an older model checkpoint)
TRAINING = True

# Number of epochs the training loops iterate over
iterations = 80

print("Device : ",DEVICE)

# Loss: label-smoothed cross-entropy over 3 classes with smoothing 0.2
# (plain nn.CrossEntropyLoss and smoothing 0.4 were earlier alternatives)
criterion = LabelSmoothingLoss(classes=3, smoothing=0.2)

# Adam over all model parameters, given as a single parameter group
optimizer = torch.optim.Adam([
    {"params": model.parameters(), "lr": 0.0001, "betas": (0.92, 0.99)},
])

# Cosine annealing with a period of 7 (5 was also tried).
# NOTE(review): eta_min equals the optimizer lr, so the schedule has no range to anneal
# over, and no visible cell calls lr_scheduler.step() - confirm whether that is intended.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=7,
    eta_min=0.0001,
)

# define batch size
#device = "cuda:0"

# load best saved model checkpoint from previous commit (if present)
# if os.path.exists('./Resnest_model/best_model.pth'):
#     model = torch.load('./Resnest_model/best_model.pth', map_location=DEVICE)

In [None]:
#13
import time

# Loss histories (not appended to in the cells visible here) and wall-clock start time
trainLoss, valLoss = [], []
start = time.time()

# Pseudo-label weighting schedule used by the training loops:
#   up to epoch T1 alpha stays 0, between T1 and T2 it grows steadily,
#   from T2 on it is fixed at alpha_t.
T1, T2 = 5, 70
alpha_t = 1e-4
alpha = 0
# Weight of the SDM loss term (chosen following prior work; 1e-4 was an earlier value)
beta = 0.3
model.to(DEVICE)

In [None]:
# Signed distance map (SDM) computation
def compute_sdf(img_gt, out_shape):
    """
    Build a normalised signed distance map for each sample of a label batch.

    Args:
        img_gt: label array indexed as ``img_gt[b]`` for ``b`` in
            ``range(out_shape[0])``; it is cast to uint8 and entries equal to
            1 or 2 are treated as foreground.
        out_shape: shape of the returned array. Each sample's map is assigned to
            ``result[b]``, so NumPy broadcasting applies when ``out_shape`` has
            more axes than a single label map.

    Returns:
        Array of shape ``out_shape``. Per sample the map is the min-max normalised
        distance-to-foreground minus the min-max normalised distance-to-background:
        negative inside the foreground, positive outside, and set to 0 on the inner
        boundary of the foreground. Samples without foreground, or where either
        distance map has maximum 0, stay all zero.
    """
    labels = img_gt.astype(np.uint8)
    sdf_batch = np.zeros(out_shape)

    for b in range(out_shape[0]):
        # classes 1 and 2 together form the foreground
        fg = np.isin(labels[b], [1, 2]).astype(bool)
        if not fg.any():
            continue

        bg = ~fg
        dist_inside = distance(fg)     # foreground pixel -> nearest background
        dist_outside = distance(bg)    # background pixel -> nearest foreground
        inner_edge = skimage_seg.find_boundaries(fg, mode='inner').astype(np.uint8)

        inside_hi = np.max(dist_inside)
        outside_hi = np.max(dist_outside)
        if inside_hi == 0 or outside_hi == 0:
            sample_sdf = np.zeros_like(dist_inside)
        else:
            inside_lo = np.min(dist_inside)
            outside_lo = np.min(dist_outside)
            sample_sdf = (dist_outside - outside_lo) / (outside_hi - outside_lo) - (dist_inside - inside_lo) / (inside_hi - inside_lo)
            sample_sdf[inner_edge == 1] = 0
        # one map for the merged foreground of this sample
        sdf_batch[b] = sample_sdf

    return sdf_batch
    



In [None]:
#13-2 wandb init
import wandb

# Run name: "<month>-<day> <hour>-<minute>" local timestamp followed by the experiment tag
current_time = time.strftime("%m-%d %H-%M")
rate = "resnest-5:5-without_sdm"
wandb.init(project='sdm_pseudo', name=f"{current_time}{rate}")

In [None]:
#14
torch.backends.cudnn.benchmark = False

# Best validation accuracy over the whole run. It used to be reset to 0 at the top of
# every epoch, which made every epoch count as "best" and overwrite the checkpoint.
best_acc = 0
checkpoint_path = './model/resnest_pseudo5.pth'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

for epoch in range(iterations):
    print(epoch,"epoch start !")

    # Pseudo-label weight: unchanged up to epoch T1, linear ramp towards alpha_t
    # between T1 and T2, fixed at alpha_t from T2 on.
    if (epoch > T1) and (epoch < T2):
        alpha = alpha_t*(epoch - T1)/(T2 - T1)
    elif epoch >= T2:
        alpha = alpha_t

    correct = 0
    total = 0
    model.train()
    print("train_loader", len(train_loader))
    print("unlabel_loader", len(unlabel_loader))

    # zip() stops at the shorter of the two loaders
    for traindata, pseudodata in zip(train_loader, unlabel_loader):
        train_inputs = traindata['img']['image'].float().to(DEVICE)
        train_labels = traindata['label'].to(device=DEVICE, dtype=torch.int64)
        # the stored labels of the "unlabelled" split are not used; pseudo labels are
        # derived from the model's own prediction below
        pseudo_inputs = pseudodata['img']['image'].float().to(DEVICE)

        optimizer.zero_grad()
        pred = model(train_inputs)

        # SDM regression: tanh of the model output against the label-derived SDM
        pred_sdm = torch.tanh(pred)
        with torch.no_grad():
            label_sdm = compute_sdf(train_labels.cpu().numpy(), pred_sdm.shape)
            label_sdm = torch.from_numpy(label_sdm).float().to(device=DEVICE)
        # F.l2_loss does not exist in torch.nn.functional; mse_loss is the L2 loss
        sdm_loss = F.mse_loss(pred_sdm, label_sdm)

        # The criterion applies log_softmax itself, so it gets the model output directly
        # instead of an extra softmax of it.
        # NOTE(review): the model is built with activation 'softmax2d', so `pred` is
        # already a probability map - consider activation=None to feed real logits.
        loss = criterion(pred, train_labels) + beta*sdm_loss

        if alpha > 0:  # include the pseudo-labelled batch once its weight is non-zero
            pseudo_pred = model(pseudo_inputs)
            pseudo_labels = pseudo_pred.detach().argmax(dim=1)

            pseudo_sdm = torch.tanh(pseudo_pred)
            with torch.no_grad():
                pseudo_label_sdm = compute_sdf(pseudo_labels.cpu().numpy(), pseudo_sdm.shape)
                pseudo_label_sdm = torch.from_numpy(pseudo_label_sdm).float().to(device=DEVICE)
            pseudo_sdm_loss = F.mse_loss(pseudo_sdm, pseudo_label_sdm)

            loss = loss + alpha*criterion(pseudo_pred, pseudo_labels) + beta*pseudo_sdm_loss

        # Backpropagate, clip the gradient norm, update the parameters
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Pixel-wise training accuracy: count label elements, not images
        predicted = pred.detach().argmax(dim=1)
        total += train_labels.numel()
        correct += (predicted == train_labels).sum().item()

    train_acc = 100*correct/total if total else 0.0
    val_acc = accuracy(val_loader)
    # log every epoch so the curves have no gaps
    wandb.log({"train acc": train_acc, "validation acc":val_acc})

    if val_acc >= best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), checkpoint_path)
        print('[%d] train acc: %.2f, validation acc: %.2f - Saved the best model' %(epoch, train_acc, val_acc))
        print("*******")
    elif epoch % 10 == 0:
        print('[%d] train acc: %.2f, validation acc: %.2f' %(epoch, train_acc, val_acc))
        print("*******")


In [None]:
#14-1: without sdm
torch.backends.cudnn.benchmark = False

# Best validation accuracy over the whole run. It used to be reset to 0 at the top of
# every epoch, which made every epoch count as "best" and overwrite the checkpoint.
best_acc = 0
checkpoint_path = './[SDM]pseudo_segmentation_model/resnest_pseudo5_withoutsdm.pth'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

for epoch in range(iterations):
    print(epoch,"epoch start !")

    # Pseudo-label weight: unchanged up to epoch T1, linear ramp towards alpha_t
    # between T1 and T2, fixed at alpha_t from T2 on.
    if (epoch > T1) and (epoch < T2):
        alpha = alpha_t*(epoch - T1)/(T2 - T1)
    elif epoch >= T2:
        alpha = alpha_t

    correct = 0
    total = 0
    model.train()

    # zip() stops at the shorter of the two loaders
    for traindata, pseudodata in zip(train_loader, unlabel_loader):
        train_inputs = traindata['img']['image'].float().to(DEVICE)
        train_labels = traindata['label'].to(device=DEVICE, dtype=torch.int64)
        # the stored labels of the "unlabelled" split are not used; pseudo labels are
        # derived from the model's own prediction below
        pseudo_inputs = pseudodata['img']['image'].float().to(DEVICE)

        optimizer.zero_grad()
        train_outputs = model(train_inputs)
        loss = criterion(train_outputs, train_labels)

        if alpha > 0:  # include the pseudo-labelled batch once its weight is non-zero
            pseudo_outputs = model(pseudo_inputs)
            pseudo_labels = pseudo_outputs.detach().argmax(dim=1)
            loss = loss + alpha*criterion(pseudo_outputs, pseudo_labels)

        # Backpropagate and update the parameters
        loss.backward()
        optimizer.step()

        # Pixel-wise training accuracy: count label elements, not images
        predicted = train_outputs.detach().argmax(dim=1)
        total += train_labels.numel()
        correct += (predicted == train_labels).sum().item()

    print("********")
    train_acc = 100*correct/total if total else 0.0
    val_acc = accuracy(val_loader)
    if val_acc >= best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), checkpoint_path)
        print('[%d] train acc: %.2f, validation acc: %.2f - Saved the best model' %(epoch, train_acc, val_acc))
        print("********")
    elif epoch % 10 == 0:
        print('[%d] train acc: %.2f, validation acc: %.2f' %(epoch, train_acc, val_acc))
        print("********")

In [None]:
# Ask PyTorch's caching allocator to release GPU memory it holds but no longer uses
torch.cuda.empty_cache()

### visualize result

In [None]:
# Metrics computed per test batch, and the reductions each of them is evaluated with.
_EVAL_METRICS = (
    ("iou_score", smp.metrics.iou_score),
    ("f1_score", smp.metrics.f1_score),
    ("accuracy", smp.metrics.accuracy),
    ("recall", smp.metrics.recall),
    ("precision", smp.metrics.precision),
    ("sensitivity", smp.metrics.sensitivity),
    ("specificity", smp.metrics.specificity),
)
_EVAL_REDUCTIONS = ("micro", "macro", "micro-imagewise", "macro-imagewise")

# One history list per (reduction, metric); each entry is one batch's value rounded to 3 digits.
_metric_history = {
    reduction: {metric_name: [] for metric_name, _ in _EVAL_METRICS}
    for reduction in _EVAL_REDUCTIONS
}

# The individual list names used by the plotting / reporting cells refer to those same lists.
(micro_iou_score_list, micro_f1_score_list, micro_accuracy_list, micro_recall_list,
 micro_precision_list, micro_sensitivity_list,
 micro_specificity_list) = _metric_history["micro"].values()

(macro_iou_score_list, macro_f1_score_list, macro_accuracy_list, macro_recall_list,
 macro_precision_list, macro_sensitivity_list,
 macro_specificity_list) = _metric_history["macro"].values()

(micro_imagewise_iou_score_list, micro_imagewise_f1_score_list,
 micro_imagewise_accuracy_list, micro_imagewise_recall_list,
 micro_imagewise_precision_list, micro_imagewise_sensitivity_list,
 micro_imagewise_specificity_list) = _metric_history["micro-imagewise"].values()

(macro_imagewise_iou_score_list, macro_imagewise_f1_score_list,
 macro_imagewise_accuracy_list, macro_imagewise_recall_list,
 macro_imagewise_precision_list, macro_imagewise_sensitivity_list,
 macro_imagewise_specificity_list) = _metric_history["macro-imagewise"].values()

# `best_model` is the same object as `model`; the checkpoint is loaded into it in place.
best_model=model
n=0
best_model.load_state_dict(torch.load('./model/resnest_sdm.pth', map_location=DEVICE))
with torch.no_grad():  # no gradients are needed for evaluation
    best_model.eval()
    total=len(test_loader)
    # NOTE(review): the last batch of the loader is skipped (islice stops at total-1),
    # exactly as before; the reason is not visible here - confirm it is intended.
    for i, data in enumerate(itertools.islice(test_loader, total-1)):
        inputs = data['img']['image'].float().to(DEVICE)
        labels = data['label'].to(device=DEVICE, dtype=torch.int64)

        preds = best_model(inputs)
        # class index per pixel (softmax is monotonic, so argmax of the raw output is the same)
        output = preds.argmax(dim=1)
        target = labels

        # Prediction and target are class-index maps, which is the 'multiclass' input
        # format of get_stats. The former 'multilabel' mode with a 0.5 threshold merged
        # classes 1 and 2 in the prediction and misread the target layout.
        tp, fp, fn, tn = smp.metrics.get_stats(
            output, target, mode='multiclass', num_classes=len(CLASSES),
        )

        for reduction, per_metric in _metric_history.items():
            for metric_name, metric_fn in _EVAL_METRICS:
                score = metric_fn(tp, fp, fn, tn, reduction=reduction).item()
                per_metric[metric_name].append(round(score, 3))



In [None]:
# One small line plot per micro-averaged test metric (value per evaluated batch)
for _title, _series, _line_style, _legend_label in (
    ('micro_iou_score', micro_iou_score_list, 'r-', 'IoU_score'),
    ('micro_f1_score', micro_f1_score_list, 'b-', 'f1_score'),
    ('micro_accuracy', micro_accuracy_list, 'y-', 'accuracy'),
):
    fig = plt.figure(figsize=(5,3))
    plt.title(_title)
    plt.plot(_series, _line_style, label=_legend_label)
    plt.grid(color = 'gray', linestyle = ':', linewidth = 0.5)

In [None]:
import math
def metrics_average(metrics_list):
    """
    Mean of the valid entries of ``metrics_list``, rounded to 3 decimals.

    Entries that are not int/float, or that are NaN, are ignored in both the sum and
    the count. (Previously they were dropped from the sum but still counted in the
    denominator, which pulled the average down.)

    Returns:
        The rounded mean, or NaN when there is no valid entry (including an empty list).
    """
    valid = [x for x in metrics_list if isinstance(x, (int, float)) and not math.isnan(x)]
    if not valid:
        return float('nan')
    return round(sum(valid) / len(valid), 3)

In [None]:
# Report the average of every micro-averaged metric over the evaluated test batches
for _caption, _values in (
    ("micro_iou_score: ", micro_iou_score_list),
    ("micro_f1_score_score: ", micro_f1_score_list),
    ("micro_accuracy_score: ", micro_accuracy_list),
    ("micro_recall_score: ", micro_recall_list),
    ("micro_precision_score: ", micro_precision_list),
    ("micro_sensitivity_score: ", micro_sensitivity_list),
    ("micro_specificity_score: ", micro_specificity_list),
):
    print(_caption, metrics_average(_values))

# print("macro_iou_score: ",metrics_average(macro_iou_score_list))
# print("macro_f1_score_score: ",metrics_average(macro_f1_score_list))
# print("macro_accuracy_score: ",metrics_average(macro_accuracy_list))
# print("macro_recall_score: ",metrics_average(macro_recall_list))
# print("macro_precision_score: ",metrics_average(macro_precision_list))
# print("macro_sensitivity_score: ",metrics_average(macro_sensitivity_list))
# print("macro_specificity_score: ",metrics_average(macro_specificity_list))

# print("micro_imagewise_iou_score: ",metrics_average(micro_imagewise_iou_score_list))
# print("micro_imagewise_f1_score_score: ",metrics_average(micro_imagewise_f1_score_list))
# print("micro_imagewise_accuracy_score: ",metrics_average(micro_imagewise_accuracy_list))
# print("micro_imagewise_recall_score: ",metrics_average(micro_imagewise_recall_list))
# print("micro_imagewise_precision_score: ",metrics_average(micro_imagewise_precision_list))
# print("micro_imagewise_sensitivity_score: ",metrics_average(micro_imagewise_sensitivity_list))
# print("micro_imagewise_specificity_score: ",metrics_average(micro_imagewise_specificity_list))

# print("macro_imagewise_iou_score: ",metrics_average(macro_imagewise_iou_score_list))
# print("macro_imagewise_f1_score_score: ",metrics_average(macro_imagewise_f1_score_list))
# print("macro_imagewise_accuracy_score: ",metrics_average(macro_imagewise_accuracy_list))
# print("macro_imagewise_recall_score: ",metrics_average(macro_imagewise_recall_list))
# print("macro_imagewise_precision_score: ",metrics_average(macro_imagewise_precision_list))
# print("macro_imagewise_sensitivity_score: ",metrics_average(macro_imagewise_sensitivity_list))
# print("macro_imagewise_specificity_score: ",metrics_average(macro_imagewise_specificity_list))

### 시각화 코드(input, target, prediction, sdm)

In [None]:
# Ask PyTorch's caching allocator to release GPU memory it holds but no longer uses
torch.cuda.empty_cache()

In [None]:
# Class-name -> class-index lookup and the indices of the benign / malignant classes
sem_classes = ['background', 'benign', 'malignant']
sem_class_to_idx = dict(zip(sem_classes, range(len(sem_classes))))
benign_category = sem_class_to_idx["benign"]
malignant_category = sem_class_to_idx["malignant"]

# `best_model` is the same object as `model`; the checkpoint weights are loaded into it in place
best_model = model
best_model.load_state_dict(
    torch.load('/home/SDMSegmentationCode/model/resnest_pseudo5.pth', map_location=DEVICE)
)
def visualize_one_sample(test_x_data, test_y_data, model):
    """Plot input, target, predicted benign mask and its signed distance map.

    The model is switched to eval mode and run without gradients on the given
    batch; only the first sample of the batch is drawn, as a 1x4 figure
    (Input, Target, Prediction, SDM Output).

    Args:
        test_x_data: Input batch tensor passed to the model after a float
            cast (presumably shaped (B, C, H, W) -- confirm with the caller).
        test_y_data: Label batch tensor. The first sample is squeezed; a 3-D
            result is shown through its first channel, a 2-D result as is.
        model: Network whose output carries the class scores on dim 1.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    model.eval()
    with torch.no_grad():
        batch = test_x_data.float().to(DEVICE)
        prediction = model(batch)

        # Per-pixel class index = position of the highest score along dim 1.
        _, pred_mask = torch.max(prediction, 1)
        pred_mask = pred_mask[0].cpu().numpy()

        # The label is only plotted, so it never needs to visit DEVICE.
        target_mask = test_y_data[0].detach().float().cpu().numpy().squeeze()

        input_img = batch[0].cpu().numpy()
        if input_img.ndim == 3 and input_img.shape[0] == 3:
            # Channels-first colour image -> channels-last for imshow.
            input_img = np.transpose(input_img, (1, 2, 0))
            cmap_input = None
        else:
            if input_img.ndim == 3 and input_img.shape[0] == 1:
                input_img = input_img.squeeze(0)
            # Single-channel (or unrecognised) layouts use the viridis colormap.
            cmap_input = 'viridis'

        # Binary mask of the pixels predicted as the benign class.
        binary_mask = (pred_mask == benign_category).astype(np.uint8)

        sdm = compute_sdf(binary_mask, binary_mask.shape).squeeze()
        sdm = torch.from_numpy(sdm).float()

    # A multi-channel label is shown through its first channel; a label that
    # squeezed down to 2-D is shown directly instead of indexing a single row.
    target_view = target_mask[0] if target_mask.ndim == 3 else target_mask

    fig, axs = plt.subplots(1, 4, figsize=(20, 5))

    axs[0].imshow(input_img, cmap=cmap_input)
    axs[0].set_title("Input")
    axs[0].axis("off")

    axs[1].imshow(target_view, cmap='viridis')
    axs[1].set_title("Target")
    axs[1].axis("off")

    axs[2].imshow(binary_mask, cmap='viridis')
    axs[2].set_title("Prediction")
    axs[2].axis("off")

    im = axs[3].imshow(sdm, cmap='seismic')
    axs[3].set_title("SDM Output")
    axs[3].axis("off")
    fig.colorbar(im, ax=axs[3], fraction=0.046, pad=0.04)

    plt.tight_layout()
    plt.show()

In [None]:
import os
import cv2
import torchvision.transforms.functional as TF
from torchvision.io import read_image


# Test images and their label images are expected under the same dataset root.
test_x_data_path = '/home/SDMSegmentationCode/PathologyDataset/labelled/test/png'
# NOTE(review): this path previously started with '/home/DMSegmentationCode';
# it now uses the same root as the images and the checkpoint -- confirm that
# this is the correct location on the target machine.
test_y_data_path = '/home/SDMSegmentationCode/PathologyDataset/labelled/test/png_label'

test_x_data = sorted(os.listdir(test_x_data_path))
test_y_data = sorted(os.listdir(test_y_data_path))

# Image and label are paired purely by position in the two sorted listings.
for i in range(50, 52):
    xname = test_x_data[i]
    yname = test_y_data[i]

    x_path = os.path.join(test_x_data_path, xname)
    y_path = os.path.join(test_y_data_path, yname)

    x = cv2.imread(x_path)
    y = cv2.imread(y_path)
    # cv2.imread signals an unreadable file by returning None, not by raising.
    if x is None:
        raise FileNotFoundError(f"Could not read image: {x_path}")
    if y is None:
        raise FileNotFoundError(f"Could not read label: {y_path}")

    # OpenCV loads BGR; convert both to RGB channel order.
    x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)
    y = cv2.cvtColor(y, cv2.COLOR_BGR2RGB)

    # Print the file names of the pair being visualized.
    print("x 파일명 : ",xname)
    print("y 파일명 : ",yname)

    # HWC -> 1xCxH×W float tensors; only the input image is scaled by 1/255.
    X = (torch.from_numpy(x).float() / 255).permute(2, 0, 1).unsqueeze(0)
    Y = torch.from_numpy(y).float().permute(2, 0, 1).unsqueeze(0)

    visualize_one_sample(X, Y, best_model)

#### Visualize SDM (single image)

In [None]:
# NOTE(review): this path uses a different root ('/home/SegmentationCode') than
# the other cells in this notebook -- confirm it points at the intended dataset.
label_image = '/home/SegmentationCode/PathologyDataset/labelled/test/png_label/S21-3099_x100_11_[1088_1088].png'

label = cv2.imread(label_image)
# cv2.imread signals an unreadable file by returning None, not by raising.
if label is None:
    raise FileNotFoundError(f"Could not read label image: {label_image}")

# HWC -> CHW, then divide by 255 (maps 8-bit pixel values into [0, 1]).
label = np.transpose(label, (2, 0, 1))
label = label/255

# OpenCV loads channels in BGR order, so index 2 is the red channel.
sdm = compute_sdf(label[2], label[2].shape)

fig, axs = plt.subplots(1, 2, figsize=(10, 5))

# Left panel: the label channel the SDM was computed from.
axs[0].imshow(label[2], cmap='viridis')
axs[0].set_title('Label Image')

# Right panel: signed distance map, colour range fixed to [-1, 1].
im = axs[1].imshow(sdm, cmap='seismic', vmin=-1, vmax=1)
axs[1].set_title('Signed Distance Map')

fig.colorbar(im, ax=axs[1])

plt.show()


In [None]:
from PIL import Image

# Side-by-side view: sample 1 of `inputs` rearranged to channels-last on the
# left, the mask replicated to three channels on the right.
both_images = np.concatenate(
    [
        inputs[1].permute(1, 2, 0).cpu().detach().numpy(),
        np.repeat(malignant_mask_uint8[:, :, None], 3, axis=-1),
    ],
    axis=1,
)
Image.fromarray(both_images.astype(np.uint8))

In [None]:
# Single-element target list built from category index 1 and the float mask.
# NOTE(review): index 1 is 'benign' in sem_classes, while the mask variable is
# named malignant -- confirm which class is intended here.
targets = [SemanticSegmentationTarget(1, malignant_mask_float)]

In [None]:
# Display the target list for inspection.
targets

In [None]:
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image

In [None]:
# Show sample 4 of `inputs`, with its first dimension moved last for imshow.
plt.imshow(inputs[4].permute(1,2,0).cpu().detach().numpy())

In [None]:
# Show element 1 of `output` as an image.
plt.imshow(output[1].cpu().detach().numpy())

In [None]:
# Sample 1 of `inputs` as a NumPy array with its first dimension moved last;
# used later as the background image for the CAM overlay.
rgb_img = inputs[1].permute(1,2,0).cpu().detach().numpy()

In [None]:
# NOTE(review): this rebinds `inputs` from the batch tensor to a single NumPy
# image (first dimension moved last). Cells that run afterwards and index
# `inputs[...]` or pass it as a model input will see this array instead of the
# batch, and re-running this cell fails because ndarrays have no `.permute`.
inputs = inputs[1].permute(1,2,0).cpu().detach().numpy()

In [None]:
# Compute a Grad-CAM heatmap for `targets` and blend it over `rgb_img`.
# NOTE(review): another cell rebinds `inputs` to a single channels-last NumPy
# image, yet it is passed here as `input_tensor`; the library presumably
# expects a batched torch tensor -- confirm which `inputs` is meant.
# NOTE(review): `GradCAM` and `target_layers` are not defined in the nearby
# cells; they must already exist in the notebook state.
with GradCAM(model=model,
             target_layers=target_layers,
             use_cuda=torch.cuda.is_available()) as cam:
    # [0, :] keeps the heatmap of the first element of the returned batch.
    grayscale_cam = cam(input_tensor=inputs,
                        targets=targets)[0, :]
    cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

Image.fromarray(cam_image)

# Known issue: the original image is not displayed together with the heatmap.

In [None]:
# Show the raw Grad-CAM heatmap without the image overlay.
plt.imshow(grayscale_cam)

In [None]:
# Inspect the heatmap dimensions.
grayscale_cam.shape

In [None]:
# Inspect the dimensions of whatever `inputs` is currently bound to.
inputs.shape

In [None]:
# Show the per-pixel argmax over axis 0 of element 4 of `preds`.
# NOTE(review): vmax=1 caps the colour scale at 1, so any index above 1 is
# drawn in the same colour as 1 -- confirm this is intended if `preds` has
# more than two channels.
plt.imshow(np.argmax(preds[4].cpu().detach().numpy(), axis=0), vmin = 0, vmax =1)

In [None]:
# The Dice coefficient measures the pixel-wise agreement between a predicted
# segmentation and its ground truth.
def dice_coef(y_pred, y_true, smooth=1):
    """Return the mean per-sample Dice coefficient of two batches.

    Both inputs are flattened per sample and binarised at 0.5. Dice is then
    computed for every sample and averaged over the batch. The batch size is
    taken from ``y_pred`` itself, so the function does not depend on a global
    ``batch_size`` and also works for a smaller final batch.

    Args:
        y_pred: Tensor whose first dimension is the batch; values above 0.5
            count as foreground.
        y_true: Tensor with the same number of elements per sample as
            ``y_pred``; values above 0.5 count as foreground.
        smooth: Constant added to numerator and denominator so that two
            empty masks give a score of 1 and not 0/0.

    Returns:
        A 0-dim tensor holding the batch-mean Dice coefficient.
    """
    n_samples = y_pred.shape[0]
    # reshape (not view) so non-contiguous inputs are accepted as well.
    pred_flat = (y_pred.reshape(n_samples, -1) > 0.5).float()
    true_flat = (y_true.reshape(n_samples, -1) > 0.5).float()

    # Reduce over the pixel dimension only, so each sample gets its own
    # score and the final mean over the batch is meaningful.
    intersection = torch.sum(pred_flat * true_flat, dim=1)
    denominator = torch.sum(pred_flat, dim=1) + torch.sum(true_flat, dim=1)
    dice = (2 * intersection + smooth) / (denominator + smooth)
    return torch.mean(dice, dim=0)