In [None]:
# import libs

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as utils
import torch.nn.functional as F

from torchvision import transforms
from tqdm import tqdm
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from utils import CustomDataset, JointTransform
from unet_model import UNet
from train import train_model
from scipy.optimize import curve_fit
from PIL import Image

import os
import logging
import time

In [None]:
# Select the compute device once, then report which one training will use.
# To force CPU execution, replace the next line with: device = torch.device("cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(
    "CUDA is available! Training on GPU."
    if device.type == "cuda"
    else "CUDA is not available. Training on CPU."
)
print(device)

In [None]:
# log part
def log_creater(logger_file_path):
    """Configure the root logger to write to a timestamped file and to the console.

    The log file is named after the current time at minute resolution
    (``YYYY-mm-dd-HH-MM.log``) and is created inside ``logger_file_path``.

    Calling this function again (for example when the notebook cell is re-run)
    replaces the handlers installed by the previous call instead of stacking
    new ones, so each message is emitted once per destination.

    Args:
        logger_file_path: Directory for the log file; created if missing.

    Returns:
        The root ``logging.Logger``, with its level set to ``INFO``.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(logger_file_path, exist_ok=True)
    log_name = '{}.log'.format(time.strftime('%Y-%m-%d-%H-%M'))
    final_log_file = os.path.join(logger_file_path, log_name)

    logger = logging.getLogger()  # root logger object
    logger.setLevel(logging.INFO)  # logging level

    # Drop only the handlers this function installed earlier; handlers that
    # other code attached to the root logger are left untouched.
    for old_handler in list(logger.handlers):
        if getattr(old_handler, '_from_log_creater', False):
            logger.removeHandler(old_handler)
            old_handler.close()

    formatter = logging.Formatter(
        "%(asctime)s %(levelname)s: %(message)s "
    )

    file_handler = logging.FileHandler(final_log_file)  # file output
    console_handler = logging.StreamHandler()  # console output

    for new_handler in (file_handler, console_handler):
        new_handler.setFormatter(formatter)  # same format for both outputs
        new_handler._from_log_creater = True  # marker used by the cleanup above
        logger.addHandler(new_handler)

    return logger

In [None]:
# configurations

# Root folder of the training data. The default is the original author's local
# path; set the TRAIN_FILE_DIR environment variable to run on another machine
# without editing the notebook.
train_file_dir = os.environ.get(
    'TRAIN_FILE_DIR',
    'C:\\Users\\LJY\\Desktop\\研一后\\2024kth比赛\\20240413\\testcode\\',
)
file_Resultdir='./Result_unet/'
fname_gt ='_IVIMParam.npy'
fname_tissue ='_TissueType.npy'
fname_noisyDWIk = '_NoisyDWIk.npy'
model_name=  "UNet"

logger = log_creater('./log/')

# Number of segmentation classes; passed to the model factory and the datasets.
seg_parts = 8

# NOTE(review): `set_seed` and `get_model` are not imported or defined in the
# visible notebook cells — confirm which project module provides them and add
# the import to the first cell so "Restart & Run All" works.
set_seed(seed=10)

b_values = np.array([0, 5, 50, 100, 200, 500, 800, 1000])
loss_type = "L1L2"



#==========

learn_rate = 0.00015
batch_size = 8
time_point = 7
#==========

# Network
# With 7 time points the first b-value (b = 0) is dropped; otherwise all are kept.
# torch.tensor with an explicit dtype replaces the legacy FloatTensor constructor
# because b_values is an integer numpy array.
if time_point==7:
    b_values_no0 = torch.tensor(b_values[1:], dtype=torch.float32)
else:
    b_values_no0 = torch.tensor(b_values, dtype=torch.float32)
net = get_model(model_name, seg_parts, device)

optimizer = optim.AdamW(net.parameters(), lr = learn_rate, weight_decay=1e-3)
# use crossentropyloss
criterion = nn.CrossEntropyLoss()
# # use focalloss
# class FocalLoss(nn.Module):
#     def __init__(self, gamma=2.0, alpha=None):
#         super(FocalLoss, self).__init__()
#         self.gamma = gamma
#         self.alpha = alpha  # optional alpha, used to handle class imbalance

#     def forward(self, outputs, targets):
#         # compute the per-sample loss with CrossEntropyLoss
#         ce_loss = F.cross_entropy(outputs, targets, reduction='none')
#         pt = torch.exp(-ce_loss)  # compute the probability

#         # if alpha is given, apply class balancing
#         if self.alpha is not None:
#             at = self.alpha[targets]
#             ce_loss = at * ce_loss

#         # compute the focal loss
#         focal_loss = ((1 - pt) ** self.gamma) * ce_loss
#         return focal_loss.mean()
# criterion = FocalLoss(gamma=2.0, alpha=torch.tensor([0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.15, 0.3]).to(device))

In [6]:


print("seg parts = ", seg_parts)

# Augmentation pipeline: a random flip along each of the two spatial axes.
# Candidate extras that are currently switched off:
#   RandRotate(range_x=np.pi/12, prob=0.2)
#   RandGaussianNoise(prob=0.5, mean=0.0, std=0.01)
#   RandAdjustContrast(prob=0.3, gamma=(0.9, 1.1))
#   ScaleIntensity(minv=0.9, maxv=1.1)
# NOTE(review): this pipeline is built but never handed to a dataset (the
# training set below receives transform=None) — confirm whether that is intended.
# NOTE(review): Compose / RandFlip / BreastDataset_seg / if_denoise are not
# imported or defined in the visible cells — confirm their origin.
train_transforms = Compose(
    [RandFlip(spatial_axis=[flip_axis], prob=0.5) for flip_axis in (0, 1)]
)

# Training and validation splits of the same dataset class.
breast_data_train = BreastDataset_seg(
    time_point,
    train_file_dir,
    train_val="train",
    one_or_twoDim=False,
    denoise=if_denoise,
    seg_num=seg_parts,
    transform=None,
)
breast_data_val = BreastDataset_seg(
    time_point,
    train_file_dir,
    train_val="val",
    one_or_twoDim=False,
    denoise=if_denoise,
    seg_num=seg_parts,
)

# NOTE(review): the training loader does not shuffle — confirm this is deliberate.
trainloader = DataLoader(
    breast_data_train,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
    drop_last=True,
)

# Validation is evaluated one sample at a time and keeps every sample.
valloader = DataLoader(
    breast_data_val,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    drop_last=False,
)

seg parts =  8


In [None]:
# Visual sanity check on the first validation batch.
testimg, testlabel = next(iter(valloader))

print(testimg.shape)

# Derive the channel count from the tensor instead of hard-coding 7, and size
# the subplot grid so every channel gets its own panel.
n_channels = testimg.shape[1]
n_cols = 4
n_rows = (n_channels + n_cols - 1) // n_cols

# Figure 1: every input channel of the first sample, one panel per channel.
plt.figure()
for i in range(n_channels):
    plt.subplot(n_rows, n_cols, i + 1)
    plt.imshow(testimg[0, i, :, :])
    plt.title("channel {}".format(i))
plt.show()

# Figure 2: the label map. It is drawn once; the previous loop re-drew the
# same image onto the same axes on every iteration.
plt.figure()
plt.imshow(testlabel[0, :, :])
plt.title("label")
plt.show()

# Figure 3: the last channel at full size. The previous loop drew every
# channel onto one axes, so only the last one was ever visible.
plt.figure()
plt.imshow(testimg[0, n_channels - 1, :, :])
plt.title("channel {}".format(n_channels - 1))
plt.show()


In [None]:
# Number of training epochs; also reused later as the x-axis extent of the loss plot.
epochs = 15

# NOTE(review): `train_seg` is not imported in the visible cells (the import
# cell only brings in `train_model`) — confirm where it comes from.
model, results = train_seg(
    net,
    trainloader,
    valloader,
    criterion,
    optimizer,
    device,
    num_epochs=epochs,
    logger=logger,
)

In [None]:
# Qualitative validation: predict the first validation sample and show the
# predicted class map next to its ground-truth label.
img, label = next(iter(valloader))

# Switch to evaluation mode and disable autograd for the forward pass so that
# this check neither behaves like a training step nor builds a gradient graph.
model.eval()
with torch.no_grad():
    pred = model(img.float().to(device))

# argmax over dim 1 turns the per-class scores into a class-index map.
pred = pred.argmax(1).squeeze().cpu()
label = label.squeeze().cpu()
print(pred.shape)

# Both panels share one colour scale (class ids 0 .. seg_parts - 1) so that the
# single colorbar is valid for the prediction as well as for the label.
plt.figure()
plt.subplot(1,2,1)
plt.imshow(pred, vmin=0, vmax=seg_parts - 1)
plt.title("pred")
plt.subplot(1,2,2)
plt.imshow(label, vmin=0, vmax=seg_parts - 1)
plt.title("label")
plt.colorbar()
plt.show()

In [None]:
# Persist only the learned parameters (the state dict), not the whole model object.
torch.save(obj=model.state_dict(), f='model_weightsa.pth')

In [None]:
# dice = multiclass_dice_coeff(pred, label)
# print(dice)

import torch
import torch.nn.functional as F

def multiclass_dice_score(preds, labels, num_classes, eps=1e-6):
    """Mean Dice score over classes for a batch of segmentation maps.

    For every class the Dice coefficient ``(2*|P∩L| + eps) / (|P| + |L| + eps)``
    is computed per sample, averaged over the batch, and the per-class averages
    are then averaged over ``num_classes``. A class that is absent from both
    prediction and label therefore scores approximately 1.

    Args:
        preds: Either per-class maps of shape ``(N, C, H, W)`` (one-hot or
            probabilities) or class-index maps with the same shape as
            ``labels``. Index maps are one-hot encoded internally.
        labels: Class-index maps of shape ``(N, H, W)``, or ``(H, W)`` for a
            single unbatched sample. Any integer-valued dtype is accepted.
        num_classes: Number of classes ``C``.
        eps: Smoothing term that avoids division by zero.

    Returns:
        The averaged Dice score as a Python float.

    Raises:
        ValueError: If the tensor ranks do not match one of the layouts above.
    """
    # F.one_hot only accepts int64 indices.
    labels = labels.long()

    # Same rank as the labels means the prediction is a class-index map.
    preds_are_indices = preds.dim() == labels.dim()

    # Accept a single unbatched sample by adding the batch axis.
    if labels.dim() == 2:
        labels = labels.unsqueeze(0)
        preds = preds.unsqueeze(0)

    if preds_are_indices:
        preds = F.one_hot(preds.long(), num_classes).permute(0, 3, 1, 2).float()

    if labels.dim() != 3 or preds.dim() != 4:
        raise ValueError(
            "expected labels of shape (N, H, W) and preds of shape (N, C, H, W) "
            "or (N, H, W); got labels {} and preds {}".format(
                tuple(labels.shape), tuple(preds.shape)
            )
        )

    labels_one_hot = F.one_hot(labels, num_classes).permute(0, 3, 1, 2)

    # Only the first num_classes channels take part, as in a per-class loop.
    preds = preds[:, :num_classes]

    spatial_dims = (2, 3)
    intersection = (preds * labels_one_hot).sum(dim=spatial_dims)
    union = preds.sum(dim=spatial_dims) + labels_one_hot.sum(dim=spatial_dims)

    dice = (2 * intersection + eps) / (union + eps)  # shape (N, C)

    # Average over the batch first, then over the classes.
    return dice.mean(dim=0).mean().item()

num_classes = seg_parts

# `pred` and `label` come out of the validation cell squeezed to (H, W).
# Restore the leading batch axis and make both int64 before one-hot encoding;
# permute(0, 3, 1, 2) needs a 4-D tensor.
pred_batch = pred.reshape(-1, *pred.shape[-2:]).long()
label_batch = label.reshape(-1, *label.shape[-2:]).long()

# One-hot encode the predicted class map into (N, C, H, W).
p = F.one_hot(pred_batch, num_classes).permute(0, 3, 1, 2).float()

# Score the one-hot prediction (not the raw index map) against the labels.
dice_score = multiclass_dice_score(p, label_batch, num_classes)
print("Multi-class Dice Score:", dice_score)


In [None]:
# Inference setup: restore previously saved weights into `net` and build a
# loader over the "inf" split.

# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
checkpoint = torch.load('model_weights_8_21_34_50_l2_1e-3.pth', map_location=device)
net.load_state_dict(checkpoint)
net.eval()

# NOTE(review): `BreastDataset_seg` and `if_denoise` are not defined in the
# visible cells — confirm their origin.
breast_data_inf = BreastDataset_seg(time_point,train_file_dir,train_val="inf",one_or_twoDim=False,denoise=if_denoise, seg_num=seg_parts)


# One sample per batch, in dataset order, keeping every sample.
infloader = utils.DataLoader(breast_data_inf ,
                                batch_size = 1, 
                                shuffle = False,
                                num_workers = 0,
                                drop_last = False)

In [None]:
import train

# Evaluate the restored network on the inference loader.
# NOTE(review): the three returned values are presumably loss, accuracy and a
# list of Dice scores, judging by the names — confirm against train.check_accuracy.
inf_loss, inf_accuracy, inf_dicelist = train.check_accuracy(
    net,
    infloader,
    criterion,
    device,
)
print(inf_loss, inf_dicelist, sep="\n")

In [None]:
# Training vs. validation loss over the epochs of the run above.
plt.plot(range(epochs), results['train_loss'], label='train loss')
# NOTE(review): the x positions assume the validation loss was recorded every
# second epoch — confirm against the training loop.
plt.plot(range(0,epochs,2), results['val_loss'], label='val loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.show()