In [1]:
import os
import datetime
import uuid
from tqdm import tqdm

import torch
import torch.nn    as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

from data_declaration import Task
from loader_helper    import LoaderHelper


from vapformer.model_components import thenet,UnetrPP
from evaluation import evaluate_model
import matplotlib.pyplot as plt
import numpy as np
import random
from torchmetrics.classification import BinaryAUROC
from vapformer.dynunet_block import get_conv_layer, UnetResBlock
from monai.networks.layers.utils import get_norm_layer

from sklearn import neighbors
import scipy.sparse as sp

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Fall back to CPU so the notebook still runs (slowly) on machines without a GPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NC vs AD binary classification task.
# NOTE(review): the first argument to get_*_dl is presumably a fold/split
# index — TODO confirm against LoaderHelper.
ld_helper = LoaderHelper(task=Task.NC_v_AD)
train_dl = ld_helper.get_train_dl(0, batch_size=8)
test_dl = ld_helper.get_test_dl(0, batch_size=16)

In [3]:
import torch
import torch.nn as nn
from torch.hub import load_state_dict_from_url

# Public names exported if this cell is reused as a module.
__all__ = ['ResNet3D', 'resnet3d18']

# Checkpoint locations used by _resnet3d when pretrained=True.
# NOTE(review): the '5c106cde' suffix matches torchvision's old *2D* resnet18
# checkpoint filename; a 3D checkpoint at this URL is unverified, and the stem
# below takes a single input channel. Confirm before using pretrained=True.
model_urls = {
    'resnet3d18': 'https://download.pytorch.org/models/resnet3d18-5c106cde.pth',
}

def conv3x3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3D convolution with a 3x3x3 kernel and padding 1.

    With ``stride=1`` the spatial size is preserved; larger strides downsample.
    """
    return nn.Conv3d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )

def conv1x1x1(in_planes, out_planes, stride=1):
    """Return a bias-free pointwise (1x1x1) 3D convolution.

    Used on the residual shortcut to match the channel count and, with
    ``stride > 1``, the spatial size of the main branch.
    """
    return nn.Conv3d(in_channels=in_planes, out_channels=out_planes,
                     kernel_size=1, stride=stride, bias=False)

class BasicBlock(nn.Module):
    """Two-convolution 3D residual block (the ResNet-18/34 building block).

    Main branch: conv3x3x3 -> BN -> ReLU -> conv3x3x3 -> BN. It is added to the
    input (passed through ``downsample`` when given) and followed by a ReLU.
    """

    # Output channels = planes * expansion; 1 for this block type.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # Registration order is kept so state_dict keys and seeded init match.
        self.conv1 = conv3x3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3x3(planes, planes)
        self.bn2 = nn.BatchNorm3d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        branch = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
        shortcut = x if self.downsample is None else self.downsample(x)
        branch += shortcut
        return self.relu(branch)

class ResNet3D(nn.Module):
    """3D ResNet with a sigmoid-activated output head.

    Mirrors the torchvision 2D ResNet layout with every 2D op replaced by its 3D
    counterpart: a 7x7x7 stride-2 stem conv + stride-2 max-pool, four residual
    stages (64/128/256/512 planes; stages 2-4 downsample by 2), global average
    pooling, a linear layer and a sigmoid.

    Args:
        block: residual block class exposing an ``expansion`` attribute and a
            constructor ``(inplanes, planes, stride=1, downsample=None)``.
        layers: number of blocks in each of the four stages.
        num_classes: output width of the final linear layer.
        zero_init_residual: if True, zero the last BN weight of every
            ``BasicBlock`` so each residual branch starts as an identity.
        in_channels: number of channels in the input volume (default 1).
    """

    def __init__(self, block, layers, num_classes=1, zero_init_residual=False,
                 in_channels=1):
        super(ResNet3D, self).__init__()
        self.inplanes = 64
        # Stem: stride-2 conv followed by stride-2 max-pool.
        self.conv1 = nn.Conv3d(in_channels, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # Classification head: global average pool -> linear -> sigmoid.
        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        # Outputs are probabilities, so train with nn.BCELoss (not BCEWithLogitsLoss).
        self.sig = nn.Sigmoid()
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one residual stage of ``blocks`` blocks with ``planes`` channels.

        The first block of the stage gets a 1x1x1 conv + BN shortcut whenever it
        changes stride or channel count. ``self.inplanes`` is advanced to this
        stage's output width.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm3d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        layers.extend(block(self.inplanes, planes) for _ in range(1, blocks))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Map ``(N, in_channels, D, H, W)`` volumes to ``(N, num_classes)`` probabilities."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = torch.flatten(self.avgpool(x), 1)
        return self.sig(self.fc(x))

def _resnet3d(arch, block, layers, pretrained, progress, **kwargs):
    """Build a ResNet3D and, if ``pretrained``, load weights from ``model_urls[arch]``.

    ``progress`` toggles the download progress bar; ``kwargs`` go to ResNet3D.
    """
    net = ResNet3D(block, layers, **kwargs)
    if not pretrained:
        return net
    weights = load_state_dict_from_url(model_urls[arch], progress=progress)
    net.load_state_dict(weights)
    return net

def resnet3d18(pretrained=False, progress=True, **kwargs):
    """ResNet3D-18: four stages of two BasicBlocks each.

    Args:
        pretrained: load weights from ``model_urls['resnet3d18']`` when True.
        progress: show a progress bar while downloading weights.
        **kwargs: forwarded to ``ResNet3D`` (e.g. ``num_classes``).
    """
    blocks_per_stage = [2] * 4
    return _resnet3d('resnet3d18', BasicBlock, blocks_per_stage, pretrained,
                     progress, **kwargs)


# Randomly initialised 3D ResNet-18 used as the classifier in this notebook.
resnet18 = resnet3d18(pretrained=False)


In [4]:
def evaluate(model_in, test_dl, thresh=0.5, param_count=False):
    """Evaluate a binary classifier on every batch of ``test_dl``.

    Runs ``model_in`` in eval mode without gradients, thresholds its scores at
    ``thresh`` to get hard predictions, and computes ROC AUC on the raw scores.

    Args:
        model_in: model returning one score per sample (this notebook's
            ResNet3D ends in ``nn.Sigmoid``, so scores are probabilities).
        test_dl: iterable of dicts with ``'mri'`` and ``'label'`` tensors.
        thresh: a sample is predicted positive when its score is strictly
            greater than ``thresh``.
        param_count: unused; kept for call-site compatibility.

    Returns:
        ``(balanced_accuracy, sensitivity, specificity, auc)``. The first value
        is the mean of sensitivity and specificity, not plain accuracy. Ratios
        are rounded to 5 decimals; a ratio whose denominator is zero (no
        positives or no negatives in ``test_dl``) is reported as 0.0.

    Raises:
        ValueError: if ``test_dl`` yields no batches.
    """
    model_in.eval()
    score_chunks = []
    label_chunks = []

    with torch.no_grad():
        for sample_batched in test_dl:
            batch_X = sample_batched['mri'].to(DEVICE).float()
            batch_y = sample_batched['label'].to(DEVICE).float()
            # Flatten each batch so batches of different sizes (e.g. a short
            # final batch) concatenate cleanly along a single axis.
            score_chunks.append(model_in(batch_X).reshape(-1))
            label_chunks.append(batch_y.reshape(-1))

    if not score_chunks:
        raise ValueError("test_dl yielded no batches")

    scores = torch.cat(score_chunks)
    labels = torch.cat(label_chunks)

    predicted = scores > thresh
    positive = labels == 1
    negative = labels == 0

    tp = int((predicted & positive).sum())
    fn = int((~predicted & positive).sum())
    tn = int((~predicted & negative).sum())
    fp = int((predicted & negative).sum())

    def _ratio(num, den):
        # Explicit guard instead of seeding counts with a tiny epsilon.
        return round(num / den, 5) if den else 0.0

    sensitivity = _ratio(tp, tp + fn)
    specificity = _ratio(tn, tn + fp)
    accuracy = round((sensitivity + specificity) / 2, 5)

    metric = BinaryAUROC(thresholds=None)
    auc = metric(scores, labels.long()).item()

    return accuracy, sensitivity, specificity, auc


In [5]:
# Train the 3D ResNet-18 with BCE on its sigmoid outputs and evaluate on the
# test split after every epoch.
# Author's notes from earlier runs: "0.80 2mm", "0.80 1.5mm"
# (NOTE(review): presumably a score at two voxel spacings — TODO confirm).
resnet18.to(DEVICE)
optimizer = optim.AdamW(resnet18.parameters(), lr=0.00001, weight_decay=5e-4)

# resnet18 already ends in nn.Sigmoid, hence plain BCELoss.
loss_function = nn.BCELoss()
loss_fig = []  # mean training loss per epoch
eva_fig = []   # balanced accuracy on test_dl per epoch

epochs = 100
# Best test AUC seen so far. It is measured on the test split, so use it for
# reporting only, not for model/epoch selection.
best_auc = 0
nb_batch = len(train_dl)

for i in range(1, 1 + epochs):
    loss = 0.0
    resnet18.train()
    for sample_batched in tqdm(train_dl):
        batch_x = sample_batched['mri'].to(DEVICE).float()
        batch_y = sample_batched['label'].to(DEVICE).float()

        optimizer.zero_grad()
        outputs = resnet18(batch_x)
        batch_loss = loss_function(outputs, batch_y)
        batch_loss.backward()
        optimizer.step()

        # Running mean of per-batch losses over the epoch.
        loss += batch_loss.item() / nb_batch

    tqdm.write("Epoch: {}/{}, train loss: {}".format(i, epochs, round(loss, 5)))
    loss_fig.append(round(loss, 5))

    accuracy, sensitivity, specificity, auc = evaluate(resnet18, test_dl)
    eva_fig.append(accuracy)
    best_auc = max(best_auc, auc)
    tqdm.write("Epoch: {}/{}, evaluation (bal_acc, sens, spec, auc): {}".format(
        i, epochs, (accuracy, sensitivity, specificity, auc)))
    # filein.write("Epoch: {}/{}, evaluation loss: {}\n".format(i, epochs,(accuracy, sensitivity, specificity, auc)))




100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:09<00:00,  4.54it/s]

Epoch: 1/100, train loss: 0.69962





Epoch: 1/100, evaluation loss: (0.5, 0.0, 1.0, 0.5267750024795532)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.14it/s]

Epoch: 2/100, train loss: 0.60385





Epoch: 2/100, evaluation loss: (0.6193, 0.57, 0.6686, 0.6377325057983398)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.13it/s]

Epoch: 3/100, train loss: 0.53075





Epoch: 3/100, evaluation loss: (0.57906, 0.17526, 0.98286, 0.7361413836479187)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.13it/s]

Epoch: 4/100, train loss: 0.38503





Epoch: 4/100, evaluation loss: (0.64389, 0.40206, 0.88571, 0.7352577447891235)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.09it/s]

Epoch: 5/100, train loss: 0.18015





Epoch: 5/100, evaluation loss: (0.69352, 0.51064, 0.8764, 0.7689457535743713)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.21it/s]

Epoch: 6/100, train loss: 0.12364





Epoch: 6/100, evaluation loss: (0.71761, 0.55319, 0.88202, 0.8032512664794922)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.24it/s]

Epoch: 7/100, train loss: 0.09374





Epoch: 7/100, evaluation loss: (0.65184, 0.36082, 0.94286, 0.8027688264846802)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.20it/s]

Epoch: 8/100, train loss: 0.09568





Epoch: 8/100, evaluation loss: (0.71766, 0.84211, 0.59322, 0.8099316358566284)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.23it/s]

Epoch: 9/100, train loss: 0.03595





Epoch: 9/100, evaluation loss: (0.56559, 0.13684, 0.99435, 0.8351472020149231)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.24it/s]


Epoch: 10/100, train loss: 0.06525
Epoch: 10/100, evaluation loss: (0.54559, 0.09677, 0.99441, 0.8469393849372864)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.31it/s]

Epoch: 11/100, train loss: 0.03805





Epoch: 11/100, evaluation loss: (0.71127, 0.47872, 0.94382, 0.8380946516990662)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.22it/s]

Epoch: 12/100, train loss: 0.09865





Epoch: 12/100, evaluation loss: (0.64204, 0.3125, 0.97159, 0.855764627456665)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.25it/s]

Epoch: 13/100, train loss: 0.04436





Epoch: 13/100, evaluation loss: (0.71896, 0.51613, 0.92179, 0.8543881773948669)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.28it/s]

Epoch: 14/100, train loss: 0.04325





Epoch: 14/100, evaluation loss: (0.64321, 0.30928, 0.97714, 0.871045708656311)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.24it/s]

Epoch: 15/100, train loss: 0.03141





Epoch: 15/100, evaluation loss: (0.65787, 0.35484, 0.96089, 0.8507839441299438)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.24it/s]

Epoch: 16/100, train loss: 0.06641





Epoch: 16/100, evaluation loss: (0.51579, 0.03158, 1.0, 0.8494201302528381)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.21it/s]

Epoch: 17/100, train loss: 0.10977





Epoch: 17/100, evaluation loss: (0.77815, 0.79787, 0.75843, 0.8519303798675537)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.27it/s]

Epoch: 18/100, train loss: 0.06019





Epoch: 18/100, evaluation loss: (0.58937, 0.19588, 0.98286, 0.8670986294746399)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.27it/s]

Epoch: 19/100, train loss: 0.05918





Epoch: 19/100, evaluation loss: (0.68513, 0.42708, 0.94318, 0.8629261255264282)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.27it/s]

Epoch: 20/100, train loss: 0.03065





Epoch: 20/100, evaluation loss: (0.68527, 0.40426, 0.96629, 0.8706669807434082)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.19it/s]

Epoch: 21/100, train loss: 0.03284





Epoch: 21/100, evaluation loss: (0.61837, 0.27083, 0.96591, 0.8708570003509521)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.29it/s]

Epoch: 22/100, train loss: 0.0296





Epoch: 22/100, evaluation loss: (0.66525, 0.36458, 0.96591, 0.8635179996490479)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.27it/s]

Epoch: 23/100, train loss: 0.02691





Epoch: 23/100, evaluation loss: (0.59612, 0.19792, 0.99432, 0.8665364980697632)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.26it/s]


Epoch: 24/100, train loss: 0.01769
Epoch: 24/100, evaluation loss: (0.58154, 0.16854, 0.99454, 0.8635721802711487)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.27it/s]

Epoch: 25/100, train loss: 0.03533





Epoch: 25/100, evaluation loss: (0.70596, 0.46939, 0.94253, 0.8552075624465942)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.25it/s]


Epoch: 26/100, train loss: 0.06271
Epoch: 26/100, evaluation loss: (0.62861, 0.27957, 0.97765, 0.8494023084640503)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.22it/s]

Epoch: 27/100, train loss: 0.02686





Epoch: 27/100, evaluation loss: (0.77536, 0.71739, 0.83333, 0.865881621837616)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.16it/s]

Epoch: 28/100, train loss: 0.0144





Epoch: 28/100, evaluation loss: (0.67522, 0.38889, 0.96154, 0.8722221851348877)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.25it/s]

Epoch: 29/100, train loss: 0.03311





Epoch: 29/100, evaluation loss: (0.61679, 0.23913, 0.99444, 0.8684178590774536)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.22it/s]

Epoch: 30/100, train loss: 0.04991





Epoch: 30/100, evaluation loss: (0.64669, 0.35052, 0.94286, 0.8352872133255005)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.28it/s]

Epoch: 31/100, train loss: 0.03081





Epoch: 31/100, evaluation loss: (0.71022, 0.5, 0.92045, 0.8582504987716675)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.29it/s]

Epoch: 32/100, train loss: 0.05666





Epoch: 32/100, evaluation loss: (0.71142, 0.47312, 0.94972, 0.8718688488006592)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.27it/s]

Epoch: 33/100, train loss: 0.02605





Epoch: 33/100, evaluation loss: (0.67368, 0.36957, 0.97778, 0.8818236589431763)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.29it/s]

Epoch: 34/100, train loss: 0.01496





Epoch: 34/100, evaluation loss: (0.66179, 0.34043, 0.98315, 0.8734161853790283)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.23it/s]

Epoch: 35/100, train loss: 0.03405





Epoch: 35/100, evaluation loss: (0.67599, 0.39796, 0.95402, 0.8712174892425537)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.26it/s]

Epoch: 36/100, train loss: 0.10028





Epoch: 36/100, evaluation loss: (0.72284, 0.54737, 0.89831, 0.86815345287323)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.27it/s]


Epoch: 37/100, train loss: 0.06627
Epoch: 37/100, evaluation loss: (0.75598, 0.63265, 0.87931, 0.8639456033706665)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.19it/s]


Epoch: 38/100, train loss: 0.02476
Epoch: 38/100, evaluation loss: (0.78424, 0.68085, 0.88764, 0.882022500038147)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.23it/s]

Epoch: 39/100, train loss: 0.0176





Epoch: 39/100, evaluation loss: (0.62169, 0.26042, 0.98295, 0.8839370608329773)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.24it/s]

Epoch: 40/100, train loss: 0.02904





Epoch: 40/100, evaluation loss: (0.62861, 0.27957, 0.97765, 0.8648405075073242)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.17it/s]

Epoch: 41/100, train loss: 0.03999





Epoch: 41/100, evaluation loss: (0.79771, 0.77419, 0.82123, 0.8652610182762146)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.10it/s]

Epoch: 42/100, train loss: 0.02826





Epoch: 42/100, evaluation loss: (0.70054, 0.46465, 0.93642, 0.86857008934021)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.10it/s]

Epoch: 43/100, train loss: 0.03779





Epoch: 43/100, evaluation loss: (0.66726, 0.36842, 0.9661, 0.863038957118988)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.19it/s]

Epoch: 44/100, train loss: 0.06265





Epoch: 44/100, evaluation loss: (0.68246, 0.40426, 0.96067, 0.8711451292037964)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.22it/s]

Epoch: 45/100, train loss: 0.0341





Epoch: 45/100, evaluation loss: (0.76137, 0.625, 0.89773, 0.8548768758773804)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.24it/s]

Epoch: 46/100, train loss: 0.07026





Epoch: 46/100, evaluation loss: (0.53284, 0.07143, 0.99425, 0.8509266376495361)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.21it/s]

Epoch: 47/100, train loss: 0.09301





Epoch: 47/100, evaluation loss: (0.75142, 0.59375, 0.90909, 0.8584281206130981)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.11it/s]

Epoch: 48/100, train loss: 0.01494





Epoch: 48/100, evaluation loss: (0.64962, 0.33333, 0.96591, 0.8739938735961914)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.12it/s]

Epoch: 49/100, train loss: 0.04221





Epoch: 49/100, evaluation loss: (0.66066, 0.35484, 0.96648, 0.8626779317855835)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  4.91it/s]

Epoch: 50/100, train loss: 0.05365





Epoch: 50/100, evaluation loss: (0.72195, 0.49474, 0.94915, 0.8635146617889404)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.11it/s]

Epoch: 51/100, train loss: 0.03477





Epoch: 51/100, evaluation loss: (0.5963, 0.20408, 0.98851, 0.8661154508590698)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.03it/s]

Epoch: 52/100, train loss: 0.0161





Epoch: 52/100, evaluation loss: (0.64444, 0.30612, 0.98276, 0.8633005023002625)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.01it/s]

Epoch: 53/100, train loss: 0.02565





Epoch: 53/100, evaluation loss: (0.61923, 0.25532, 0.98315, 0.8806478381156921)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.11it/s]

Epoch: 54/100, train loss: 0.05483





Epoch: 54/100, evaluation loss: (0.78551, 0.71875, 0.85227, 0.8803858757019043)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  4.97it/s]

Epoch: 55/100, train loss: 0.04341





Epoch: 55/100, evaluation loss: (0.77272, 0.6875, 0.85795, 0.8668323755264282)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.12it/s]

Epoch: 56/100, train loss: 0.04046





Epoch: 56/100, evaluation loss: (0.51315, 0.03191, 0.99438, 0.8813650608062744)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.15it/s]

Epoch: 57/100, train loss: 0.01734





Epoch: 57/100, evaluation loss: (0.73734, 0.53684, 0.93785, 0.8785608410835266)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.19it/s]

Epoch: 58/100, train loss: 0.02369





Epoch: 58/100, evaluation loss: (0.63268, 0.2766, 0.98876, 0.8830982446670532)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.25it/s]


Epoch: 59/100, train loss: 0.00871
Epoch: 59/100, evaluation loss: (0.67009, 0.36842, 0.97175, 0.8874219655990601)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.18it/s]

Epoch: 60/100, train loss: 0.00868





Epoch: 60/100, evaluation loss: (0.62545, 0.26804, 0.98286, 0.8697496056556702)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.25it/s]

Epoch: 61/100, train loss: 0.01217





Epoch: 61/100, evaluation loss: (0.62861, 0.27957, 0.97765, 0.8719888925552368)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.30it/s]

Epoch: 62/100, train loss: 0.01029





Epoch: 62/100, evaluation loss: (0.62957, 0.28723, 0.97191, 0.8844728469848633)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.26it/s]

Epoch: 63/100, train loss: 0.027





Epoch: 63/100, evaluation loss: (0.59962, 0.21053, 0.9887, 0.8708295822143555)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.17it/s]

Epoch: 64/100, train loss: 0.01846





Epoch: 64/100, evaluation loss: (0.61411, 0.23958, 0.98864, 0.8666548132896423)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.31it/s]

Epoch: 65/100, train loss: 0.01475





Epoch: 65/100, evaluation loss: (0.60452, 0.22581, 0.98324, 0.8803988695144653)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.27it/s]

Epoch: 66/100, train loss: 0.04295





Epoch: 66/100, evaluation loss: (0.55821, 0.12766, 0.98876, 0.8794525861740112)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.29it/s]

Epoch: 67/100, train loss: 0.06798





Epoch: 67/100, evaluation loss: (0.61502, 0.25263, 0.9774, 0.8678561449050903)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.28it/s]


Epoch: 68/100, train loss: 0.02974
Epoch: 68/100, evaluation loss: (0.59337, 0.21, 0.97674, 0.8755813837051392)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.29it/s]

Epoch: 69/100, train loss: 0.0211





Epoch: 69/100, evaluation loss: (0.69504, 0.44565, 0.94444, 0.8621376752853394)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.28it/s]

Epoch: 70/100, train loss: 0.01231





Epoch: 70/100, evaluation loss: (0.67778, 0.38947, 0.9661, 0.8749925494194031)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.23it/s]

Epoch: 71/100, train loss: 0.01162





Epoch: 71/100, evaluation loss: (0.63136, 0.28571, 0.97701, 0.8748533725738525)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.24it/s]

Epoch: 72/100, train loss: 0.00517





Epoch: 72/100, evaluation loss: (0.65814, 0.33333, 0.98295, 0.8925189971923828)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.27it/s]

Epoch: 73/100, train loss: 0.00686





Epoch: 73/100, evaluation loss: (0.62086, 0.24742, 0.99429, 0.8881884813308716)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.22it/s]

Epoch: 74/100, train loss: 0.00991





Epoch: 74/100, evaluation loss: (0.63936, 0.30108, 0.97765, 0.8808794021606445)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.29it/s]

Epoch: 75/100, train loss: 0.03084





Epoch: 75/100, evaluation loss: (0.78807, 0.74468, 0.83146, 0.8637938499450684)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.26it/s]

Epoch: 76/100, train loss: 0.01316





Epoch: 76/100, evaluation loss: (0.62831, 0.30208, 0.95455, 0.8317945003509521)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.21it/s]

Epoch: 77/100, train loss: 0.04918





Epoch: 77/100, evaluation loss: (0.70346, 0.45161, 0.95531, 0.8701267242431641)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.29it/s]

Epoch: 78/100, train loss: 0.05582





Epoch: 78/100, evaluation loss: (0.5, 0.0, 1.0, 0.8790172338485718)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.29it/s]

Epoch: 79/100, train loss: 0.09403





Epoch: 79/100, evaluation loss: (0.79835, 0.74194, 0.85475, 0.860455334186554)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.22it/s]

Epoch: 80/100, train loss: 0.0393





Epoch: 80/100, evaluation loss: (0.71408, 0.47872, 0.94944, 0.8826799392700195)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.28it/s]

Epoch: 81/100, train loss: 0.1058





Epoch: 81/100, evaluation loss: (0.51896, 0.04348, 0.99444, 0.8244564533233643)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.24it/s]

Epoch: 82/100, train loss: 0.0843





Epoch: 82/100, evaluation loss: (0.82083, 0.85714, 0.78453, 0.8935098052024841)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.26it/s]

Epoch: 83/100, train loss: 0.02884





Epoch: 83/100, evaluation loss: (0.64698, 0.30526, 0.9887, 0.900029718875885)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.30it/s]

Epoch: 84/100, train loss: 0.01531





Epoch: 84/100, evaluation loss: (0.66726, 0.36842, 0.9661, 0.8961641788482666)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.31it/s]


Epoch: 85/100, train loss: 0.05973
Epoch: 85/100, evaluation loss: (0.6931, 0.42553, 0.96067, 0.8853693008422852)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.23it/s]

Epoch: 86/100, train loss: 0.00949





Epoch: 86/100, evaluation loss: (0.71686, 0.47917, 0.95455, 0.8798531889915466)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.31it/s]


Epoch: 87/100, train loss: 0.00741
Epoch: 87/100, evaluation loss: (0.61884, 0.26042, 0.97727, 0.8938210010528564)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.28it/s]

Epoch: 88/100, train loss: 0.03715





Epoch: 88/100, evaluation loss: (0.61254, 0.24242, 0.98266, 0.8754013776779175)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.28it/s]

Epoch: 89/100, train loss: 0.03855





Epoch: 89/100, evaluation loss: (0.69034, 0.4375, 0.94318, 0.8706794381141663)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.27it/s]

Epoch: 90/100, train loss: 0.01399





Epoch: 90/100, evaluation loss: (0.65756, 0.36082, 0.95429, 0.8787628412246704)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.29it/s]

Epoch: 91/100, train loss: 0.01593





Epoch: 91/100, evaluation loss: (0.58049, 0.16667, 0.99432, 0.8770123720169067)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.27it/s]

Epoch: 92/100, train loss: 0.00913





Epoch: 92/100, evaluation loss: (0.64204, 0.3125, 0.97159, 0.8752366900444031)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.28it/s]


Epoch: 93/100, train loss: 0.00704
Epoch: 93/100, evaluation loss: (0.64204, 0.3125, 0.97159, 0.8864820003509521)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.23it/s]

Epoch: 94/100, train loss: 0.00366





Epoch: 94/100, evaluation loss: (0.6284, 0.29032, 0.96648, 0.8743317127227783)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.28it/s]

Epoch: 95/100, train loss: 0.00866





Epoch: 95/100, evaluation loss: (0.73012, 0.5, 0.96023, 0.885061502456665)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.30it/s]

Epoch: 96/100, train loss: 0.00757





Epoch: 96/100, evaluation loss: (0.66326, 0.36082, 0.96571, 0.8758173584938049)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.32it/s]

Epoch: 97/100, train loss: 0.00286





Epoch: 97/100, evaluation loss: (0.6385, 0.30526, 0.97175, 0.8773714303970337)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.27it/s]

Epoch: 98/100, train loss: 0.00447





Epoch: 98/100, evaluation loss: (0.67162, 0.36559, 0.97765, 0.8888088464736938)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.21it/s]

Epoch: 99/100, train loss: 0.00541





Epoch: 99/100, evaluation loss: (0.66482, 0.35789, 0.97175, 0.8898602724075317)


100%|███████████████████████████████████████████████████████████████████████████| 43/43 [00:08<00:00,  5.26it/s]

Epoch: 100/100, train loss: 0.0056





Epoch: 100/100, evaluation loss: (0.64094, 0.31579, 0.9661, 0.8923580646514893)


In [None]:
# Collect the model's output for every training sample; the stacked rows are
# used below as node features for the kNN graph.
# NOTE(review): `model` is not defined in any visible cell. The [350, 512]
# shape printed below suggests a 512-d feature extractor, not `resnet18`, whose
# output is (B, 1). TODO define it explicitly so Restart & Run All works.
# NOTE(review): if train_dl shuffles, row order differs between runs — confirm
# before pairing rows with labels.
all_outputs = []
# Put the model in evaluation mode.
model.eval()
with torch.no_grad():
    for sample_batched in tqdm(train_dl):
        batch_x = sample_batched['mri'].to(DEVICE).float()
        all_outputs.append(model(batch_x))
# Concatenate once after the loop instead of re-concatenating on every batch.
stacked_outputs = torch.cat(all_outputs, dim=0)
    

In [10]:
# Sanity check: one row per training sample, one column per output feature.
stacked_outputs.shape

torch.Size([350, 512])

In [12]:
def create_graph_from_embedding(embedding, name, n = 30):
    """Build a symmetric adjacency matrix whose nodes are the rows of ``embedding``.

    Parameters
    ----------
    embedding : array-like, 2-D
        One row per node (here: the stacked CNN features, one per scan).
    name : str
        Graph construction method. Only ``'knn'`` is supported.
    n : int, default 30
        Number of nearest neighbours per node for ``'knn'``.

    Returns
    -------
    numpy.ndarray of shape (n_rows, n_rows)
        Dense adjacency. An entry is 1 when both nodes list each other as
        neighbours, 0.5 when only one does, and 0 otherwise.

    Raises
    ------
    ValueError
        If ``name`` is not a supported method. The previous version silently
        returned ``None``, which only failed much later.
    """
    if name != 'knn':
        raise ValueError(
            "Unsupported graph construction method: {!r} (expected 'knn')".format(name)
        )
    # kneighbors_graph yields a directed 0/1 connectivity matrix (node i -> its n
    # nearest neighbours). Averaging it with its transpose makes it symmetric.
    knn = neighbors.kneighbors_graph(embedding, n_neighbors=n).toarray()
    return (knn + np.transpose(knn)) / 2

In [14]:
# sklearn works on host memory, so move the CNN features off the GPU before building the graph.
stacked_outputs = stacked_outputs.cpu()
adj = create_graph_from_embedding(embedding=stacked_outputs, name='knn', n=30)

In [16]:
# Sanity check: the adjacency is square, one row/column per graph node.
np.shape(adj)

(350, 350)

In [17]:
from utils import *

In [18]:
# Normalise the kNN adjacency with the `utils` helper, then densify it.
# NOTE(review): this rebinds `adj`, so re-running the cell feeds an already-processed matrix back in.
normalized_adj, edge = preprocess_adj(adj)
adj = normalized_adj.todense()

In [21]:
class GraphConvolution(nn.Module):
    """Single graph-convolution layer computing ``adjacency @ (input_feature @ weight) + bias``.

    Args:
        input_dim: number of input features per node.
        output_dim: number of output features per node.
        use_bias: whether to add a learnable bias per output feature.
    """

    def __init__(self, input_dim, output_dim, use_bias=True):
        super(GraphConvolution, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.use_bias = use_bias
        self.weight = nn.Parameter(torch.empty(input_dim, output_dim))
        if self.use_bias:
            self.bias = nn.Parameter(torch.empty(output_dim))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise ``weight`` with Kaiming-uniform and ``bias`` (if any) with zeros."""
        # A bare `init` name is not imported in the notebook's imports cell, so
        # `init.kaiming_uniform_` raised NameError. Use the fully-qualified nn.init.
        nn.init.kaiming_uniform_(self.weight)
        if self.use_bias:
            nn.init.zeros_(self.bias)

    def forward(self, input_feature, adjacency):
        """Propagate node features over the graph.

        Args:
            input_feature: tensor of shape (num_nodes, input_dim).
            adjacency: tensor of shape (num_nodes, num_nodes). It must be a torch
                tensor; a numpy matrix has to be converted by the caller first.

        Returns:
            Tensor of shape (num_nodes, output_dim).
        """
        support = torch.mm(input_feature, self.weight)
        output = torch.mm(adjacency, support)
        if self.use_bias:
            output = output + self.bias
        return output

In [None]:
class SYNet_AD(nn.Module):
    """Two graph-convolution layers followed by a three-layer linear head (AD vs NC).

    Args:
        input_dim: per-node feature size after ``feature.t()`` in ``forward``.
        output_dim: size of the returned logits.
    """

    def __init__(self, input_dim, output_dim):
        super(SYNet_AD, self).__init__()
        nd_1 = 128
        nd_2 = 64
        n_4 = 2

        self.gcn1_1 = GraphConvolution(input_dim, nd_1)  # AD-NC
        self.gcn1_2 = GraphConvolution(nd_1, nd_2)

        # linear1 consumes gcn1_2's output, whose last dimension is nd_2.
        # The previously hard-coded in_features=93 could never match it.
        self.linear1 = nn.Linear(nd_2, n_4)  # AD-NC
        self.linear2 = nn.Linear(n_4, output_dim)
        self.linear3 = nn.Linear(output_dim, output_dim)

    def forward(self, adjacency1, feature, adjacency2):
        """Run the GCN + MLP head.

        Args:
            adjacency1: graph adjacency ("S") used by both GCN layers, as a torch tensor.
            feature: node features ("X"). It is transposed before the first GCN layer,
                so presumably it is (input_dim, num_nodes); TODO confirm.
            adjacency2: second adjacency ("A"). Currently unused; kept for interface compatibility.

        Returns:
            ``(logits, None, None)``. The last two slots were meant to hold ``X_a`` and
            ``X_s``, which this method never computed; the old code raised NameError on
            them. They are ``None`` placeholders so that 3-way unpacking still works.
            TODO: implement or drop.
        """
        # F (torch.nn.functional) is not imported in the notebook; use qualified names instead.
        y_hat = torch.relu(self.gcn1_1(feature.t(), adjacency1))
        y_hat = nn.functional.dropout(y_hat, 0.2, training=self.training)
        y_hat = self.gcn1_2(y_hat, adjacency1)

        h = self.linear1(y_hat)
        h = torch.relu(self.linear2(h))
        logits = self.linear3(h)

        return logits, None, None

In [22]:
def evaluate(model_in, test_dl, thresh=0.5, param_count=False):
    """Evaluate a binary classifier on ``test_dl``.

    Args:
        model_in: model called as ``model_in(mri, clinical)``. Its output is thresholded
            at ``thresh`` for the confusion counts and used directly as the AUROC score.
        test_dl: iterable of dicts with keys ``'mri'``, ``'clin_t'`` and ``'label'``.
        thresh: decision threshold applied to the model output.
        param_count: unused; kept for backward compatibility.

    Returns:
        ``(accuracy, sensitivity, specificity, auc)``. The first three are rounded to
        5 decimal places. ``accuracy`` is the *balanced* accuracy,
        ``(sensitivity + specificity) / 2``.

    Raises:
        ValueError: if ``test_dl`` yields no batches, or if a batch's output count
            does not match its label count.
    """
    model_in.eval()
    # Start counts at a tiny epsilon so the ratios below never divide by zero.
    TP = TN = FP = FN = 0.000001
    all_scores = []
    all_labels = []

    with torch.no_grad():
        for sample_batched in test_dl:
            # .float() keeps the input dtype consistent with the training loop.
            batch_X = sample_batched['mri'].to(DEVICE).float()
            batch_clinical = sample_batched['clin_t'].to(DEVICE)
            batch_y = sample_batched['label'].to(DEVICE)

            net_out = model_in(batch_X, batch_clinical)
            # Flatten to one score / one label per sample.
            scores = net_out.reshape(-1)
            labels = batch_y.reshape(-1)
            if scores.numel() != labels.numel():
                raise ValueError(
                    "model output has {} values but batch has {} labels".format(
                        scores.numel(), labels.numel()))
            all_scores.append(scores)
            all_labels.append(labels)

            predicted = (scores > thresh).long()
            TP += ((predicted == 1) & (labels == 1)).sum().item()
            TN += ((predicted == 0) & (labels == 0)).sum().item()
            FP += ((predicted == 1) & (labels == 0)).sum().item()
            FN += ((predicted == 0) & (labels == 1)).sum().item()

    if not all_scores:
        raise ValueError("test_dl yielded no batches")

    # Concatenate along the sample axis (dim 0); the old dim-1 concat broke on an
    # uneven last batch. AUROC targets must be integer labels.
    metric = BinaryAUROC(thresholds=None)
    auc = metric(torch.cat(all_scores), torch.cat(all_labels).long()).item()

    sensitivity = round(TP / (TP + FN), 5)
    specificity = round(TN / (TN + FP), 5)
    accuracy = round((sensitivity + specificity) / 2, 5)
    return accuracy, sensitivity, specificity, auc


In [31]:
optimizer = optim.AdamW(model.parameters(), lr=0.0001, weight_decay=5e-4)

# BCELoss expects probabilities with the same shape as the (float) targets.
loss_function = nn.BCELoss()
loss_fig = []
eva_fig = []

epochs = 100
best_auc = 0
nb_batch = len(train_dl)

# Train
for i in range(1, 1 + epochs):
    loss = 0.0
    model.train()
    for sample_batched in tqdm(train_dl):
        batch_x = sample_batched['mri'].to(DEVICE).float()
        # BCELoss requires floating-point targets.
        batch_y = sample_batched['label'].to(DEVICE).float()

        optimizer.zero_grad()
        outputs = model(batch_x)
        # NOTE(review): the recorded run fails here because `model` returns 512-d
        # features (B, 512) while targets are (B, 1). A classification head with a
        # sigmoid is needed before BCELoss can be applied.
        batch_loss = loss_function(outputs, batch_y)
        batch_loss.backward()
        optimizer.step()

        loss += batch_loss.item() / nb_batch

    tqdm.write("Epoch: {}/{}, train loss: {}".format(i, epochs, round(loss, 5)))
    loss_fig.append(round(loss, 5))
    # Evaluate the model just trained. `model_in` only exists as evaluate()'s
    # parameter name and was undefined here. NOTE(review): evaluate() calls the
    # model with (mri, clinical), whereas training passes mri only; confirm the
    # model signature.
    accuracy, sensitivity, specificity, auc = evaluate(model, test_dl)
    eva_fig.append(accuracy)
    tqdm.write("Epoch: {}/{}, evaluation loss: {}".format(i, epochs, (accuracy, sensitivity, specificity, auc)))




  0%|                                                                                   | 0/175 [00:00<?, ?it/s]


ValueError: Using a target size (torch.Size([2, 1])) that is different to the input size (torch.Size([2, 512])) is deprecated. Please ensure they have the same size.