In [1]:
import shutil
import torch


from dataset import *
from utils import *
from settings_benchmark import *

from dataset import writer
from torch.utils.tensorboard import SummaryWriter

# Build every dataset once up front and report what this benchmark run covers.
all_dataset = prepareDatasets()
print(f"Models: {list(models)}")
print(f"Datasets: {list(all_dataset)}")

# Self-check: instantiate each model once so that a model which cannot be
# constructed fails here, before any training has started.
print("Trying to load each model...")
for name_model in models:
    model: nn.Module = models[name_model]()
    


# Root directory that receives one sub-directory per model.
root_result = "result"
# exist_ok=True makes creation idempotent and avoids the race between an
# existence check and the mkdir call.
os.makedirs(root_result, exist_ok=True)

# Index of the CUDA device to train on; stays 0 unless the user picks another.
id_card = 0
# Manual GPU selection: only prompt when there is more than one card to choose from.
count_card = torch.cuda.device_count()
if count_card > 1:
    while True:
        s = input(f"Please choose a video card number (0-{count_card-1}): ").strip()
        # isdecimal() rather than isdigit(): isdigit() also accepts characters
        # such as superscript digits, which int() rejects with ValueError.
        if s.isdecimal() and int(s) < count_card:
            id_card = int(s)
            break
        print("Invalid input!")

# Fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device_cuda = torch.device(f'cuda:{id_card}')
    print(f"\n\nVideo Card {id_card} will be used.")
else:
    device_cuda = torch.device('cpu')
    # Do not claim a video card is in use when the run is actually on the CPU.
    print("\n\nCUDA is not available; CPU will be used.")


        
def _log_scalars(section, split, metrics, step):
    """Write each entry of ``metrics`` to ``writer`` under ``{section}/{key}_{split}``."""
    for key in metrics:
        writer.add_scalar(tag=f"{section}/{key}_{split}",
                          scalar_value=metrics[key],
                          global_step=step)


# Upper bound on the number of epochs for one (model, dataset) run.
NUM_MAX_EPOCH = 300
# Early stopping: give up once validation dice has not improved for this many epochs.
threshold = 100

# Benchmark driver: train every model on every dataset, keep the checkpoint
# with the best validation dice, and record the test metrics measured at each
# new best epoch.
for name_model in models:
    root_result_model = os.path.join(root_result, name_model)
    os.makedirs(root_result_model, exist_ok=True)

    # Throw-away instance used only to report the parameter count.
    foo = models[name_model]()
    total = sum(param.nelement() for param in foo.parameters())
    print("Model:{}, Number of parameter: {:.3f}M".format(name_model, total/1e6))

    # Train on each dataset in turn.
    for name_dataset in all_dataset:
        root_result_model_dataset = os.path.join(root_result_model, name_dataset)
        path_flag = os.path.join(root_result_model_dataset, "finished.flag")
        # Check the completion marker before building loaders or moving a model
        # to the device, so finished combinations are skipped without cost.
        if os.path.exists(path_flag):
            continue
        # A directory without the flag is an interrupted run: start it over.
        if os.path.exists(root_result_model_dataset):
            shutil.rmtree(root_result_model_dataset)
        os.mkdir(root_result_model_dataset)

        dataset = all_dataset[name_dataset]
        trainLoader = DataLoader(dataset=dataset['train'], batch_size=2, shuffle=True,
                                 drop_last=False, num_workers=0)
        valLoader = DataLoader(dataset=dataset['val'])
        testLoader = DataLoader(dataset=dataset['test'])
        model: nn.Module = models[name_model]().to(device_cuda)

        print(f"\n\n\nCurrent Model:{name_model}, Current training dataset: {name_dataset}")

        log_section = f"{name_model}_{name_dataset}"

        # A dataset entry may override the loss and threshold; otherwise use
        # DiceLoss and no threshold.
        funcLoss = dataset['loss'] if 'loss' in dataset else DiceLoss()
        thresh_value = dataset['thresh'] if 'thresh' in dataset else None
        optimizer = torch.optim.Adam(
            [param for param in model.parameters() if param.requires_grad],
            lr=1e-4, weight_decay=0.001)

        bestResult = {"epoch": -1, "dice": -1}
        # Alternating entries: a summary string for each new best epoch,
        # followed by the test metrics measured at that epoch.
        ls_best_result = []
        for epoch in range(NUM_MAX_EPOCH):
            torch.cuda.empty_cache()

            # The writer is handed to traverseDataset only on every fifth epoch.
            epoch_writer = writer if epoch % 5 == 0 else None

            # train (an optimizer is supplied only for this pass)
            result_train = traverseDataset(model=model, loader=trainLoader,
                                           thresh_value=thresh_value,
                                           log_section=f"{log_section}_{epoch}_train",
                                           log_writer=epoch_writer,
                                           description=f"Train Epoch {epoch}", device=device_cuda,
                                           funcLoss=funcLoss, optimizer=optimizer)
            _log_scalars(log_section, "train", result_train, epoch)

            # val
            result = traverseDataset(model=model, loader=valLoader,
                                     thresh_value=thresh_value,
                                     log_section=f"{log_section}_{epoch}_val",
                                     log_writer=epoch_writer,
                                     description=f"Val Epoch {epoch}", device=device_cuda,
                                     funcLoss=funcLoss, optimizer=None)
            _log_scalars(log_section, "val", result, epoch)

            dice = result['dice']
            print(f"val dice:{dice}. ({name_model} on {name_dataset})")
            if dice > bestResult['dice']:
                bestResult['dice'] = dice
                bestResult['epoch'] = epoch
                ls_best_result.append("epoch={}, val_dice={:.3f}".format(epoch, dice))
                print("best dice found. evaluating on testset...")

                # test (only evaluated when validation dice reaches a new best)
                result_test = traverseDataset(model=model, loader=testLoader,
                                              thresh_value=thresh_value,
                                              log_section=None,
                                              log_writer=None,
                                              description=f"Test Epoch {epoch}", device=device_cuda,
                                              funcLoss=funcLoss, optimizer=None)
                ls_best_result.append(result_test)

                # Rewrite the history and checkpoint on every improvement so an
                # interrupted run still leaves its latest best on disk.
                path_json = os.path.join(root_result_model_dataset, "best_result.json")
                with open(path_json, "w") as f:
                    json.dump(ls_best_result, f, indent=2)
                path_model = os.path.join(root_result_model_dataset, 'model_best.pth')
                torch.save(model.state_dict(), path_model)
            elif epoch - bestResult['epoch'] >= threshold:
                print(f"Precision didn't improve in recent {threshold} epoches, stop training.")
                break

        # Completion marker: lets a later run skip this (model, dataset) pair.
        with open(path_flag, "w") as f:
            f.write("training and testing finished.")
            



Models: ['AttUNet', 'UNetppp', 'CSNet']
Datasets: ['OCTA500_6M', 'OCTA500_3M', 'ROSSA']
Trying to load each model...


  nn.init.kaiming_normal(m.weight)




Video Card 0 will be used.
Model:AttUNet, Number of parameter: 0.354M



Current Model:AttUNet, Current training dataset: OCTA500_6M


Train Epoch 0: 100%|██████████| 90/90 [09:52<00:00,  6.58s/batch, avg_loss=5.576, curr_loss=3.348]   
Val Epoch 0: 100%|██████████| 20/20 [00:22<00:00,  1.14s/batch, avg_loss=6.697, curr_loss=9.810]  


val dice:0.10260868577335573. (AttUNet on OCTA500_6M)
best dice found. evaluating on testset...


Test Epoch 0: 100%|██████████| 100/100 [01:51<00:00,  1.12s/batch, avg_loss=-0.617, curr_loss=2.158]   
Train Epoch 1: 100%|██████████| 90/90 [09:45<00:00,  6.50s/batch, avg_loss=3.447, curr_loss=2.859]
Val Epoch 1: 100%|██████████| 20/20 [00:21<00:00,  1.08s/batch, avg_loss=-31.770, curr_loss=16.996]  


val dice:0.11867581767731342. (AttUNet on OCTA500_6M)
best dice found. evaluating on testset...


Test Epoch 1: 100%|██████████| 100/100 [01:46<00:00,  1.06s/batch, avg_loss=3.269, curr_loss=2.336]  
Train Epoch 2: 100%|██████████| 90/90 [09:39<00:00,  6.44s/batch, avg_loss=2.933, curr_loss=2.821]
Val Epoch 2: 100%|██████████| 20/20 [00:22<00:00,  1.11s/batch, avg_loss=3.715, curr_loss=5.243]


val dice:0.10717957933221575. (AttUNet on OCTA500_6M)


Train Epoch 3: 100%|██████████| 90/90 [09:41<00:00,  6.46s/batch, avg_loss=2.517, curr_loss=2.594]
Val Epoch 3: 100%|██████████| 20/20 [00:21<00:00,  1.06s/batch, avg_loss=3.065, curr_loss=4.250]


val dice:0.11305412641714754. (AttUNet on OCTA500_6M)


Train Epoch 4:   6%|▌         | 5/90 [00:33<09:26,  6.66s/batch, avg_loss=2.305, curr_loss=2.278]


KeyboardInterrupt: 