In [None]:
# !pip install monai

In [None]:
# Show GPU status (driver, memory, running processes) before training.
!nvidia-smi

In [None]:
 # GPU 및 라이브러리 불러오기
import warnings
# Suppress all Python warnings for the whole session.
# NOTE(review): this also hides deprecation warnings from torch/numpy/sklearn/albumentations.
warnings.filterwarnings("ignore")

# !nvidia-smi
import os
import torch
# Expose only GPU 0 to this process. CUDA reads this variable when it is first
# initialised, so it has no effect if CUDA was already used in this kernel.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# import pydicom
# import nibabel as nib
# Notebook-widget progress bars; used below as tqdm.tqdm / tqdm.trange.
import tqdm.notebook as tqdm

# NOTE(review): torch and numpy are imported again further down; harmless but redundant.
import torch
import numpy as np

import cv2
import matplotlib.pyplot as plt
import glob, natsort

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import segmentation_models_pytorch as smp
import monai
from monai.networks.utils import one_hot

import scipy
import pandas as pd
from monai.inferers import sliding_window_inference

import datetime
# Session timestamp (recomputed in the cell that creates the weight directory).
now = datetime.datetime.now()

In [None]:
import random


def seed_everything(seed: int = 42):
    """Seed every RNG this notebook uses so runs are reproducible.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and all
    visible GPUs). It also exports ``PYTHONHASHSEED``, which only affects
    child processes started after this call, not the running interpreter.

    cuDNN autotuning (``benchmark``) is switched off. When it is on, cuDNN
    times several convolution algorithms and may pick a different, possibly
    non-deterministic kernel on each run. That defeats ``cudnn.deterministic``.

    Args:
        seed: value used for every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # seeds every visible GPU; no-op without CUDA
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything()

## Data split (train / valid / test file lists)

In [None]:
import torch
from sklearn.model_selection import train_test_split

# File lists: every normal ('..._nl') image first, then every abnormal
# ('..._abnl') image; each group is natural-sorted independently.
train_patterns = ('preprocessed/processed_230107_nl/Train/*.png',
                  'preprocessed/processed_230107_abnl/Train/*.png')
test_patterns = ('preprocessed/processed_230107_nl/Test/*.png',
                 'preprocessed/processed_230107_abnl/Test/*.png')
x_train = [path for pattern in train_patterns for path in natsort.natsorted(glob.glob(pattern))]
x_test = [path for pattern in test_patterns for path in natsort.natsorted(glob.glob(pattern))]

# Hold out 20% of the training files for validation with a fixed seed.
# NOTE(review): the split is not stratified by class; consider `stratify=`.
x_train, x_valid = train_test_split(x_train, test_size=0.2, random_state=42, shuffle=True)
print(len(x_train), len(x_valid), len(x_test))

## augmentation

In [None]:
import albumentations as A
import cv2

# Training augmentation pipeline (albumentations) for (H, W, 3) images.
# Geometry first: centre-crop to 800 (H) x 1024 (W), constant-pad to 1024x1024,
# then resize to 512x512. All random transforms below operate at 512x512.
transform_train = A.Compose([
    A.CenterCrop(height=800,width=1024,p=1),
    A.PadIfNeeded(min_height=1024,min_width=1024,border_mode=cv2.BORDER_CONSTANT,p=1),
    A.Resize(height=512, width=512, interpolation=cv2.INTER_CUBIC,p=1), # resize down to 512 x 512
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.1),

    # Each OneOf group fires with its outer p, then applies one of its members.
    A.OneOf([
    A.InvertImg(p=0.5),
    A.ChannelShuffle(p=.5),
    ],p=0.2),

    # Photometric: brightness / contrast / gamma / tone curve / colour shifts.
    A.OneOf([
    A.RandomBrightnessContrast(brightness_limit=(-0.2, 0.2), contrast_limit=(-0.2, 0.2), brightness_by_max=False, p=0.5),
    A.RandomGamma(gamma_limit=(80,120), p=.5),
    A.RandomToneCurve(scale=0.4 ,p=.5),
    A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=20, val_shift_limit=20, p=.5),
    A.RGBShift(r_shift_limit=10, g_shift_limit=10, b_shift_limit=10, p=.5),
    ],p=0.5),

    # Blur or sharpen.
    A.OneOf([
    A.MotionBlur(blur_limit=7, p=0.5),
    A.MedianBlur(blur_limit=7, p=0.5),
    A.GlassBlur(sigma=0.3, max_delta=2, p=0.5),
    A.Sharpen(alpha=(0.1, 0.3), lightness=(0.7, 1.1), p=0.5)
    ],p=0.2),

    # Additive / multiplicative / sensor-style noise.
    # NOTE(review): GaussNoise's `var_limit`/`mean` are the older albumentations
    # argument names and may be deprecated in newer releases; confirm against
    # the pinned albumentations version.
    A.OneOf([
    A.GaussNoise(var_limit=(10.0, 50.0), mean=0, p=0.5),
    A.MultiplicativeNoise(multiplier=(0.98, 1.02), p=0.5),
    A.ISONoise(color_shift=(0.01, 0.02), intensity=(0.1, 0.3), p=0.5),
    ],p=0.3),

    # Elastic / grid / optical warps, constant-padded at the borders.
    # NOTE(review): ElasticTransform's `alpha_affine` may have been removed in
    # newer albumentations releases; confirm against the pinned version.
    A.OneOf([
    A.ElasticTransform(border_mode=cv2.BORDER_CONSTANT, interpolation=cv2.INTER_CUBIC, alpha=1, sigma=50, alpha_affine=50, p=0.5),
    A.GridDistortion(border_mode=cv2.BORDER_CONSTANT, interpolation=cv2.INTER_CUBIC, distort_limit=0.3, num_steps=5, p=0.5),
    A.OpticalDistortion(border_mode=cv2.BORDER_CONSTANT, interpolation=cv2.INTER_CUBIC, distort_limit=.05, shift_limit=0.05, p=0.5),
    ],p=0.5),
    
    A.ShiftScaleRotate(border_mode=cv2.BORDER_CONSTANT, interpolation=cv2.INTER_CUBIC, shift_limit=0.0625, scale_limit=0.0625, rotate_limit=20, p=0.5),
])

# Evaluation pipeline: the same deterministic crop / pad / resize as training,
# with no random transforms.
transform_valid = A.Compose([
    A.CenterCrop(height=800,width=1024,p=1),
    A.PadIfNeeded(min_height=1024,min_width=1024,border_mode=cv2.BORDER_CONSTANT,p=1),
    A.Resize(height=512, width=512, interpolation=cv2.INTER_CUBIC,p=1) # resize down to 512 x 512
])


import mclahe as mc
class datasets():
    """Image-classification dataset that labels images by their source folder.

    Args:
        x_list: list of image file paths.
        augmentation: optional albumentations-style callable that takes
            ``image=`` and returns a dict with key ``'image'``.

    Each item is a dict ``{'x': float32 array (C, H, W), 'y': np.array([label]),
    'fname': path}``, with label 0 = normal and 1 = abnormal.
    """
    def __init__(self, x_list, augmentation=None):
        self.x_list = x_list
        self.augmentation = augmentation

    def __len__(self):
        return len(self.x_list)

    @staticmethod
    def _label_from_path(path):
        """Return ``np.array([1])`` for abnormal images and ``np.array([0])`` otherwise.

        The class comes from the ``..._abnl`` source directory, the same rule
        ``dataset.__getitem__`` uses. The previous check (``'normal' in path``)
        also matched any file name containing ``'abnormal'``, which labelled
        those images as normal.
        """
        return np.array([1]) if 'abnl' in path else np.array([0])

    def __getitem__(self, idx):
        path_x = self.x_list[idx]
        # cv2 reads a 3-channel BGR uint8 image (H, W, 3). The model uses
        # in_channels=3 because many albumentations transforms need 3 channels.
        x = cv2.imread(path_x)
        x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)
        y = self._label_from_path(path_x)

        if self.augmentation:
            x = self.augmentation(image=x)['image']

        # mclahe contrast normalisation over roughly 128x128 patches.
        x = mc.mclahe(x, (128, 128, 3))

        # (H, W, C) -> (C, H, W) for PyTorch.
        x = np.moveaxis(x, -1, 0).astype(np.float32)

        return {'x': x, 'y': y, 'fname': path_x}
    

# 외부에서 따옴
class ImbalancedDatasetSampler(torch.utils.data.sampler.Sampler):
    """Sample indices with probability inversely proportional to class frequency.

    Adapted from an external implementation. Indices are drawn with
    replacement, each with weight ``1 / count(label)``, so every class is
    drawn about equally often per epoch.

    Args:
        dataset: indexable dataset. When ``callback_get_label`` is not given,
            its items must be dicts with a ``'y'`` label array.
        indices: subset of dataset indices to sample from (default: all).
        num_samples: number of indices drawn per epoch (default: ``len(indices)``).
        callback_get_label: optional ``f(dataset, idx) -> label``. Use it to
            read labels without loading and augmenting every item.
    """

    def __init__(self, dataset, indices=None, num_samples=None, callback_get_label=None):
        self.indices = list(range(len(dataset))) if indices is None else list(indices)
        self.callback_get_label = callback_get_label
        self.num_samples = len(self.indices) if num_samples is None else num_samples

        # Labels are collected in self.indices order, so weights[i] belongs to
        # self.indices[i] in __iter__. Earlier versions read labels for the whole
        # dataset and sorted the index, which misaligned or crashed when
        # `indices` was given.
        labels = pd.Series([self._get_label(dataset, idx) for idx in tqdm.tqdm(self.indices)])
        class_counts = labels.map(labels.value_counts())
        self.weights = torch.as_tensor((1.0 / class_counts).to_numpy(), dtype=torch.double)

    def _get_label(self, dataset, idx):
        """Binary label of ``dataset[idx]``: 1 if its label array contains 1, else 0."""
        if self.callback_get_label is not None:
            return self.callback_get_label(dataset, idx)
        return 1 if 1 in dataset[idx]['y'] else 0

    def __iter__(self):
        drawn = torch.multinomial(self.weights, self.num_samples, replacement=True)
        return (self.indices[i] for i in drawn.tolist())

    def __len__(self):
        return self.num_samples

## Dataloader

In [None]:
from torch.utils.data import DataLoader

class dataset:
    """Path-labelled image dataset: files under an ``..._abnl`` folder are class 1.

    Args:
        x_list: image file paths.
        augmentation: albumentations-style callable (``image=`` in, dict with
            ``'image'`` out), or a falsy value to skip augmentation.

    Items are dicts ``{'x': torch tensor (C, H, W), 'y': np.array([label]),
    'fname': path}``.
    """

    def __init__(self, x_list, augmentation=False):
        self.x_list = x_list
        self.augmentation = augmentation

    def __len__(self):
        return len(self.x_list)

    def __getitem__(self, idx):
        fname = self.x_list[idx]
        image = cv2.cvtColor(cv2.imread(fname), cv2.COLOR_BGR2RGB)

        # Images whose 90th-percentile intensity exceeds 250 are flipped
        # (255 - x). Presumably these are stored with inverted polarity.
        if np.percentile(image, 90) > 250:
            image = 255 - image

        label = np.array([int('abnl' in fname)])

        if self.augmentation:
            image = self.augmentation(image=image)['image']

        # mclahe contrast normalisation over roughly 128x128 patches, then HWC -> CHW.
        normalized = mc.mclahe(image, (128, 128, 3))
        chw = torch.tensor(np.moveaxis(normalized, -1, 0))
        return {'x': chw, 'y': label, 'fname': fname}

# One dataset per split; training alone uses the random augmentations.
train_dataset, valid_dataset, test_dataset = (
    dataset(x_train, transform_train),
    dataset(x_valid, transform_valid),
    dataset(x_test, transform_valid),
)

In [None]:
# Build the loaders once and cache them to disk. ImbalancedDatasetSampler reads
# every item (image load + augmentation + mclahe) to collect labels.
os.makedirs('dataloader', exist_ok=True)  # torch.save does not create parent directories

# With a sampler, `shuffle` must stay False; the sampler decides the order.
train_loader = DataLoader(train_dataset, batch_size=24, shuffle=False,
                          sampler=ImbalancedDatasetSampler(train_dataset), pin_memory=True)
torch.save(train_loader, 'dataloader/Imbalanced_Trainloader_Hip_2C_240503.pt')
# NOTE(review): the validation loader also resamples with replacement, so each
# epoch's validation loss comes from a different random multiset of images.
# Best-checkpoint selection on valid_loss is therefore noisy; consider a plain
# sequential loader for validation.
valid_loader = DataLoader(valid_dataset, batch_size=24, shuffle=False,
                          sampler=ImbalancedDatasetSampler(valid_dataset), pin_memory=True)
torch.save(valid_loader, 'dataloader/Imbalanced_Validloader_Hip_2C_240503.pt')
test_loader = DataLoader(test_dataset, batch_size=4)

print(len(train_loader), len(valid_loader), len(test_loader))

In [None]:
# Reload the cached loaders. These are pickled DataLoader objects, not tensors.
# Since PyTorch 2.6, torch.load defaults to weights_only=True and refuses to
# unpickle arbitrary objects. These files are written by this notebook, so full
# unpickling is acceptable here; never load such files from untrusted sources.
train_loader = torch.load('dataloader/Imbalanced_Trainloader_Hip_2C_240503.pt', weights_only=False)
valid_loader = torch.load('dataloader/Imbalanced_Validloader_Hip_2C_240503.pt', weights_only=False)
test_loader = DataLoader(test_dataset, batch_size=4)

print(len(train_loader), len(valid_loader), len(test_loader))

In [None]:
# Binary labels (normal = 0, abnormal = 1) taken from each file's class
# directory ('..._abnl' -> 1). This keeps them aligned with the shuffled
# x_train / x_valid lists from train_test_split.
# The previous hard-coded zeros/ones blocks could not match that order. They
# also had 400 + 983 = 1383 entries (= 1106 train + 277 valid), not len(x_train).
# y_valid was never defined either.
y_train = np.array([1.0 if 'abnl' in path else 0.0 for path in x_train])
y_valid = np.array([1.0 if 'abnl' in path else 0.0 for path in x_valid])
y_test = np.array([1.0 if 'abnl' in path else 0.0 for path in x_test])

In [None]:
from torch.utils.data import DataLoader

class dataset:
    """Dataset variant that takes explicit labels instead of parsing file paths.

    Args:
        x_list: image file paths.
        y_list: per-image labels, aligned with ``x_list``.
        augmentation: optional albumentations-style callable (``image=`` in,
            dict with ``'image'`` out). ``False``/``None`` disables it.

    Items are dicts ``{'x': array (C, H, W), 'y': np.array([label]), 'fname': path}``.
    NOTE(review): unlike the path-labelled ``dataset`` defined earlier, this
    variant applies no mclahe normalisation.
    """

    def __init__(self, x_list, y_list, augmentation = False):
        self.x_list = x_list
        self.y_list = y_list
        self.augmentation = augmentation

    def __len__(self):
        # x_list and y_list are expected to have the same length.
        return len(self.x_list)

    def __getitem__(self, idx):
        fname = self.x_list[idx]
        x = cv2.imread(fname)  # (H, W, 3), BGR channel order
        # Convert to RGB: colour augmentations such as HueSaturationValue and
        # RGBShift treat the channels as RGB.
        x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)
        y = np.array([self.y_list[idx]])

        # The augmentation used to be stored but never applied. Albumentations
        # works on (H, W, C), so it must run before the axis move.
        if self.augmentation:
            x = self.augmentation(image=x)['image']

        x = np.moveaxis(x, -1, 0)  # (H, W, C) -> (C, H, W)
        return {'x': x, 'y': y, 'fname': fname}

# Explicit-label datasets, with small batches for quick inspection.
# NOTE(review): running this cell replaces the sampler-based loaders built earlier.
train_dataset = dataset(x_train, y_train, transform_train)
valid_dataset = dataset(x_valid, y_valid, transform_valid)
test_dataset = dataset(x_test, y_test, transform_valid)

train_loader, valid_loader, test_loader = (
    DataLoader(train_dataset, batch_size=2, shuffle=True),
    DataLoader(valid_dataset, batch_size=2, shuffle=False),
    DataLoader(test_dataset, batch_size=2),
)

print(len(train_loader), len(valid_loader), len(test_loader))

In [None]:
# import cv2
# import pylab as plt
# img = cv2.imread('processed_nl/040712_000000_11457954_normal.dcm.png')
# plt.imshow(img)

In [None]:
# Inspect one test batch: images should be (B, C, H, W), labels (B, 1).
batch = next(iter(test_loader))
x, y = batch['x'], batch['y']
x.shape, y.shape

In [None]:
# NOTE(review): the triple-quoted string below is inert legacy code and is
# never executed. It refers to `transform_*_woCrop` pipelines that are not
# defined in the visible notebook. Because it is the cell's last expression,
# the string is echoed as the cell output.
'''
train_dataset = datasets(x_train,transform_train_woCrop)
valid_dataset = datasets(x_valid,transform_valid_woCrop)
test_dataset = datasets(x_test,transform_valid_woCrop)

train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, batch_size=4, pin_memory=True)
# train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=False, batch_size=4, pin_memory=True, sampler=ImbalancedDatasetSampler(train_dataset))  # 매 배치당 샘플러에 정의된 비율로 들어가게 해준다.
valid_loader = torch.utils.data.DataLoader(valid_dataset, shuffle=False, batch_size=2, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, batch_size=1)
'''

# Model

In [None]:
import monai

# EfficientNet-B3 binary classifier: 3-channel 2-D input, one output value that
# the training code passes through a sigmoid.
net = monai.networks.nets.EfficientNetBN(
    "efficientnet-b3",
    pretrained=True,
    adv_prop=True,
    spatial_dims=2,
    in_channels=3,
    num_classes=1,
    norm='batch',
)

In [None]:
from modules_smr.ArcFace import *
from modules_smr.NLB import *

In [None]:
import monai

# Append an NLBlockND (from modules_smr.NLB; presumably a non-local attention
# block) right after EfficientNet's head convolution `_conv_head`.
# The guard makes the cell safe to re-run. After the first run `_conv_head` is
# the wrapping Sequential, which has no `out_channels`, so wrapping again
# would raise AttributeError.
if not isinstance(net._conv_head, nn.Sequential):
    nnblock = NLBlockND(in_channels=net._conv_head.out_channels, dimension=2)
    net._conv_head = nn.Sequential(net._conv_head, nnblock)

net = net.to(device)

## Metric

In [None]:
from torchmetrics.classification import Accuracy
# Two independent binary-accuracy trackers: one for training, one for validation.
train_accuracy, valid_accuracy = (Accuracy(task='binary') for _ in range(2))

In [None]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl
import monai
from sklearn.metrics import *

binarization = monai.transforms.AsDiscrete(threshold=0.5)


def _as_flat_numpy(values):
    """Flatten a torch tensor (any device, with or without grad), array or sequence to 1-D NumPy."""
    if isinstance(values, torch.Tensor):
        return values.detach().cpu().numpy().ravel()
    return np.asarray(values).ravel()


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float('nan')


def metrics(yhat, y):
    """Binary-classification metrics from hard (0/1) predictions.

    Args:
        yhat: predicted labels as a torch tensor, NumPy array or sequence;
            flattened before use.
        y: ground-truth labels, same accepted types as ``yhat``.

    Returns:
        dict with 'accuracy', 'f1', 'iou', 'npv', 'sensitivity',
        'specificity' and 'ppv'; the raw counts 'TP', 'FP', 'FN', 'TN'; and the
        2x2 confusion matrix 'cm' (rows = true [0, 1], cols = predicted [0, 1]).
        A ratio with a zero denominator is reported as NaN. Values other than
        0/1 are ignored by the confusion matrix.

    NOTE: a later cell runs ``from sklearn import metrics``, which rebinds the
    name ``metrics`` to the sklearn module and hides this function.
    """
    yhat = _as_flat_numpy(yhat)
    y = _as_flat_numpy(y)

    # labels=[0, 1] forces a 2x2 matrix even when only one class is present.
    # Without it, ravel() returns a single value and the unpacking fails.
    cm = confusion_matrix(y, yhat, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    return {'accuracy': _safe_ratio(tp + tn, tn + fp + fn + tp),
            'f1': _safe_ratio(2 * tp, 2 * tp + fp + fn),
            'iou': _safe_ratio(tp, tp + fp + fn),
            'npv': _safe_ratio(tn, tn + fn),
            'sensitivity': _safe_ratio(tp, tp + fn),
            'specificity': _safe_ratio(tn, tn + fp),
            'ppv': _safe_ratio(tp, tp + fp),
            'TP': tp,
            'FP': fp,
            'FN': fn,
            'TN': tn,
            'cm': cm,
            }

## Train!

In [None]:
# Training hyper-parameters. BCELoss expects probabilities, so train/valid/test
# apply a sigmoid to the network output before computing the loss.
epochs = 400
lossfn = nn.BCELoss()
optimizer = torch.optim.Adam(params=net.parameters(), lr=1e-3)

def train(loader):
    """Run one training epoch over ``loader``.

    Uses the notebook globals ``net``, ``optimizer``, ``lossfn``, ``device``
    and ``train_accuracy``. Each batch is a dict with 'x' (images), 'y'
    (labels, expected to match the (B, 1) model output) and 'fname'.

    Returns:
        (mean batch loss, accuracy over this epoch as returned by
        ``train_accuracy.compute()``).
    """
    losses = []
    net.train()
    for batch in tqdm.tqdm(loader, desc='train', total=len(loader)):
        x = batch['x'].float().to(device)
        y = batch['y'].float().to(device)

        yhat = torch.sigmoid(net(x))  # logits -> probabilities for BCELoss
        loss = lossfn(yhat, y)  # PyTorch order is (input, target); sklearn uses (y_true, y_pred)

        # Round at 0.5 to get hard predictions for the running accuracy.
        train_accuracy.update(yhat.detach().cpu().round().to(torch.int64),
                              y.detach().cpu().to(torch.int64))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    total_train_accuracy = train_accuracy.compute()
    # Reset so the next epoch starts fresh. Previously the metric state grew
    # across all epochs; valid() already resets valid_accuracy the same way.
    train_accuracy.reset()
    return np.mean(losses), total_train_accuracy

def valid(loader):
    """Evaluate ``net`` on ``loader`` without updating any weights.

    Uses the notebook globals ``net``, ``lossfn``, ``device`` and
    ``valid_accuracy``. The accuracy state is reset after it is read.

    Returns:
        (mean batch loss, accuracy over the whole loader)
    """
    net.eval()
    batch_losses = []
    with torch.no_grad():
        for batch in tqdm.tqdm(loader, desc='valid', total=len(loader)):
            inputs = batch['x'].float().to(device)
            targets = batch['y'].float().to(device)
            probs = torch.sigmoid(net(inputs))
            batch_loss = lossfn(probs, targets)  # (input, target) order
            valid_accuracy.update(probs.cpu().round().to(torch.int64),
                                  targets.cpu().to(torch.int64))
            batch_losses.append(batch_loss.cpu().numpy())

    accuracy = valid_accuracy.compute()
    valid_accuracy.reset()
    return np.mean(batch_losses), accuracy

def test(loader):
    xs = []
    ys = []
    yhats = []
    fnames = []
    net.eval()
    for idx, batch in tqdm.tqdm(enumerate(loader), desc='test', total=len(loader)):
        x = batch['x'].float().to(device)
        y = batch['y'].float().to(device)
        fname = batch['fname']
        
        with torch.no_grad():
            yhat = F.sigmoid(net(x))
            loss = lossfn(yhat,y)

        xs.extend(x.cpu().detach().numpy())
        ys.extend(y.cpu().detach().numpy())
        yhats.extend(yhat.cpu().detach().numpy())
        fnames.extend(fname)

    return xs, ys, yhats, fnames

In [None]:
# Create a date-stamped output folder for this training run's weights.
import datetime
now = datetime.datetime.now()
print(now)

import os


def createDirectory(path=None):
    """Create the weight directory for this run if it does not exist yet.

    Args:
        path: directory to create. Defaults to
            ``weights/binary_ENb3_Imb_NLB_best_<YYYYMMDD>``, built from ``now``;
            the training loop saves checkpoints into this folder.

    Returns:
        The directory path, or ``None`` if it could not be created. The error
        is printed, not raised.
    """
    if path is None:
        path = f'weights/binary_ENb3_Imb_NLB_best_{now.year:02d}{now.month:02d}{now.day:02d}'
    try:
        # exist_ok replaces the separate exists() check and avoids its race.
        os.makedirs(path, exist_ok=True)
    except OSError as err:
        print(f"Error: Failed to create the directory. ({err})")
        return None
    return path


createDirectory()

# print(now.year, now.month, now.day)

In [None]:
from livelossplot import PlotLosses

# Checkpoints go into the date-stamped run folder. makedirs lets this cell run
# even if the directory-creation cell was skipped.
run_name = f'binary_ENb3_Imb_NLB_best_{now.year:02d}{now.month:02d}{now.day:02d}'
weight_dir = f'weights/{run_name}'
os.makedirs(weight_dir, exist_ok=True)

# Epochs 0..WARMUP_EPOCHS (inclusive) are excluded from best-checkpoint
# tracking and from the saved loss history.
WARMUP_EPOCHS = 20

plotlosses = PlotLosses()
train_losses = []
valid_losses = []

for epoch in tqdm.trange(epochs):
    train_loss, train_acc = train(train_loader)
    valid_loss, valid_acc = valid(valid_loader)

    if epoch > WARMUP_EPOCHS:
        train_losses.append(train_loss)
        valid_losses.append(valid_loss)

    # Save a checkpoint whenever this epoch ties or beats the lowest post-warm-up validation loss.
    if valid_losses and np.min(valid_losses) == valid_loss:
        weight_name = f'{weight_dir}/{run_name}_{epoch}.pt'
        torch.save(net, weight_name)
        print(f'{weight_name}: weight saved')

    # Log plain Python floats; the accuracies come back from torchmetrics'
    # compute(), presumably as 0-d tensors.
    plotlosses.update({
        'loss': float(train_loss),
        'val_loss': float(valid_loss),
        'acc': float(train_acc),
        'val_acc': float(valid_acc),
    })
    plotlosses.send()

In [None]:
import pandas as pd

# Save the logged (post-warm-up) loss history; row i is the i-th logged epoch.
df = pd.DataFrame(data={'train_losses': train_losses, 'valid_losses': valid_losses})
df.to_csv('losses_log.csv', index=True)

## Training finished: load saved weights

In [None]:
# Pick the checkpoint once and derive the file path from it, so the name used
# for plots and the file actually loaded cannot drift apart.
experiment_name = 'binary_ENb3_Imb_NLB_best_20230130'
weight_number = 379
weight_name = f'weights/{experiment_name}/{experiment_name}_{weight_number}.pt'

# The checkpoint is a whole pickled nn.Module (written by torch.save(net, ...)).
# map_location lets a GPU-trained checkpoint load on a CPU-only machine.
# weights_only=False is needed on PyTorch >= 2.6, where the default became
# True. Only load checkpoints you trust.
net = torch.load(weight_name, map_location=device, weights_only=False)
print(weight_name)

In [None]:
# Per-sample inputs, labels, probabilities and file names for the test split.
xs, ys, yhats, fnames = test(loader=test_loader)

In [None]:
# Sanity-check the collected test outputs: counts, container types, element types, shapes.
print(*map(len, (xs, ys, yhats, fnames)))
print(*map(type, (xs, ys, yhats, fnames)))
print(*(type(seq[0]) for seq in (xs, ys, yhats, fnames)))
print(*(arr[0].shape for arr in (xs, ys, yhats)))

# Confusion matrix

## Confusion matrix 0.5

In [None]:
# Find unique values and counts
# Distinct predicted probabilities and how often each one occurs.
unique_values, counts = np.unique(yhats, return_counts=True)
for position in range(len(unique_values)):
    value, count = unique_values[position], counts[position]
    print(f'Value: {value}, Count: {count}')

In [None]:
print(np.unique(ys))
# Hard predictions at the default 0.5 threshold.
binarization = monai.transforms.AsDiscrete(threshold=0.5)
yhats_binary50 = binarization(np.array(yhats))
for arr in (yhats, yhats_binary50):
    print(np.unique(arr))

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Confusion matrix at the 0.5 threshold. The inputs are lists of shape-(1,)
# arrays, so flatten them to 1-D.
y_true = np.ravel(ys)
y_pred = np.ravel(yhats_binary50)
# labels=[0, 1] keeps the matrix 2x2, matching the two display labels, even if
# one class is never predicted.
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                              display_labels=['normal', 'abnormal'])
disp.plot()
# Save the figure the display drew into, not whichever figure is current.
disp.figure_.savefig(f'weights/{experiment_name}/{experiment_name}_{weight_number}_50.png')
plt.show()

## AUROC

In [None]:
from sklearn import metrics
from sklearn.metrics import roc_curve
# xs, ys, yhats, fnames = test(test_loader)

In [None]:
# Import the sklearn functions directly. The bare name `metrics` is bound either
# to the notebook's metrics() helper or to the sklearn.metrics module,
# depending on which cell ran last.
from sklearn.metrics import RocCurveDisplay, auc, roc_curve

y_true = np.ravel(ys)
y_score = np.ravel(yhats)
fpr, tpr, thresholds = roc_curve(y_true, y_score, pos_label=1)
roc_auc = auc(fpr, tpr)
# Named roc_display so it does not shadow IPython's built-in display().
roc_display = RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc,
                              estimator_name=f'{experiment_name}_{weight_number}')

# https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_curve.html
fig, ax = plt.subplots()
roc_display.plot(ax=ax)

ax.set_xlabel('False Positive Rate (1-Specificity)')
ax.set_ylabel('True Positive Rate (Sensitivity)')
ax.set_title('ROC Curve')

fig.savefig(f'weights/{experiment_name}/{experiment_name}_{weight_number}_roc_curve_example.png')

plt.show()

In [None]:
def youden_index(y_true, y_score):
    """Return the ROC threshold that maximises Youden's J statistic (TPR - FPR).

    Inputs are flattened, so lists of shape-(1,) arrays (as returned by
    ``test``) are accepted. As with sklearn's ``roc_curve``, a sample counts
    as positive when ``score >= threshold``.
    """
    fpr, tpr, thresholds = roc_curve(np.ravel(y_true), np.ravel(y_score))
    idx = np.argmax(tpr - fpr)
    return thresholds[idx]


youden = youden_index(ys, yhats)

print(youden, type(youden))
# np.asscalar was removed in NumPy 1.23; .item() is the supported equivalent
# and returns a plain Python float.
youden = youden.item()

In [None]:
print(f'ys:{np.unique(ys)}')
# Hard predictions at the Youden-optimal threshold.
binarization = monai.transforms.AsDiscrete(threshold=youden)
yhats_binary_youden = binarization(np.asarray(yhats))
print(np.unique(yhats_binary_youden))
print(np.array(ys).shape, np.array(yhats_binary_youden).shape)

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Confusion matrix at the Youden threshold. The inputs are lists of shape-(1,)
# arrays, so flatten them to 1-D.
y_true = np.ravel(ys)
y_pred = np.ravel(yhats_binary_youden)
# labels=[0, 1] keeps the matrix 2x2, matching the two display labels, even if
# one class is never predicted.
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                              display_labels=['normal', 'abnormal'])
disp.plot()
# Save the figure the display drew into, not whichever figure is current.
disp.figure_.savefig(f'weights/{experiment_name}/{experiment_name}_{weight_number}_youden.png')
plt.show()