In [None]:
import glob
import os
import time
from functools import partial

import nibabel as nib
import numpy as np
import pandas as pd
import PIL
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from matplotlib import pyplot as plt
from PIL import Image
from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
from torch.utils.data import DataLoader, Dataset, random_split
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms, utils

%matplotlib inline

In [None]:
# Compute device: the default CUDA device when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    dev = torch.device("cuda")
else:
    dev = torch.device("cpu")
print(dev)

### 1) Data wrangling

In [None]:
# ---- Experiment configuration ----
adni_num = "2"                      # ADNI phase used for training (matches SRC == "ADNI<n>")
test_data_suffix = 'multi-input'    # tag used in the saved test-set / X-TEST file names
experiment_type = test_data_suffix + "_XTEST"
num_folds = 5
epochs = 120

# Every tabular feature available in ADNI_ready.csv.
all_columns = [
    'AGE', 'PTGENDER', 'ADAS11', 'MMSE', 'FAQ',
    'RAVLT_immediate', 'RAVLT_learning', 'RAVLT_forgetting',
    'CDRSB', 'APOE4',
]

# Subset of features actually fed to the tabular branch of the model.
required_columns = ['AGE', 'PTGENDER', 'APOE4']
experiment_name = "img_" + "_".join(required_columns).lower()
print(f"{experiment_name}")

# Seed torch and NumPy's global RNG for reproducibility.
myseed = 1
torch.manual_seed(myseed)
np.random.seed(myseed)

# Single logit output (binary classification with BCEWithLogitsLoss).
num_classes = 1

#### Retrieve img filenames and paths

The file `adniX_paths.pkl` is a manually generated pickle file containing two columns:
- PTID: subject ID
- IMG_PATH: the path to the T1 acquisition of the corresponding subject ID. The path must be absolute, because the data loader uses it directly to load the image.

In [None]:
# Folder holding the image-path table for the selected ADNI phase.
data_path = f"a{adni_num}"
img_df_filename = f"adni{adni_num}_paths.pkl"
filename = os.path.join(data_path, img_df_filename)

# NOTE(review): pd.read_pickle can execute arbitrary code while unpickling;
# only load path tables you generated yourself.
img_df = pd.read_pickle(filename)

print(f"Final data has {len(img_df)}")
img_df.head()

#### Retrieve ADNI table, with normalized values (ADNI_ready.csv)

ADNI_ready.csv must be created manually after downloading the preferred set of subjects from the ADNI website. ADNI_ready must include the following fields:
- `PTID`: subject ID (used to merge with the image-path table)
- `SRC`: the ADNI dataset the subject belongs to (`ADNI1`, `ADNI2` or `ADNI3`)
- `labels`: the binary class label (0/1)
- the 'AGE', 'PTGENDER' and 'APOE4' columns

In [None]:
data_path = "ADNI_csv/"
filename = "ADNI_ready.csv"

adni_tabular = pd.read_csv(os.path.join(data_path, filename))

# Class balance over the whole file (all ADNI phases).
print(f"ALL adni has {len(adni_tabular)} entries")
print(f"Class distribution is organized as follow:")
print(f"\n {adni_tabular['labels'].value_counts()}")

# Keep only the subjects of the selected ADNI phase.
adni_tabular = adni_tabular[adni_tabular['SRC'] == f"ADNI{adni_num}"]

print(f"Adni{adni_num} has {len(adni_tabular)} entries")
print(f"Class distribution is organized as follow:")
print(f"Final:\n {adni_tabular['labels'].value_counts()}")


## Check for duplicated rows
if adni_tabular.duplicated().any():
    print(f"WARNING: Dataframe contains duplicated rows!!!")

# The image/tabular merge validates one-to-one on PTID, so any repeated
# subject ID (even with different row contents) will make it fail.
# Report those IDs here, where the cause is easy to see.
dup_ptids = adni_tabular.loc[adni_tabular['PTID'].duplicated(), 'PTID'].unique()
if len(dup_ptids) > 0:
    print(f"WARNING: {len(dup_ptids)} PTID(s) appear more than once, e.g. {list(dup_ptids)[:10]}")

#### Merge the images and tabular dataframes

In [None]:
# Extract only ADNIx subjects
# Inner join on subject ID: keep only subjects that have both an image path
# and a tabular record. validate="one_to_one" raises if PTID repeats on
# either side.
adni = img_df.merge(
    adni_tabular,
    how="inner",
    on="PTID",
    suffixes=("_x", "_y"),
    copy=False,
    indicator=False,
    validate="one_to_one",
)

In [None]:
# Class balance after the merge (subjects with both modalities).
label_counts = adni['labels'].value_counts()
print(f"Class distribution is organized as follow:")
print(f"Final:\n {label_counts}")
adni.head()

### 2) Dataset creator

In [None]:
class ImgDataset(Dataset):
    """Tabular and image dataset built from a merged ADNI dataframe.

    Each item is a tuple ``(image, tabular, y)``:
      * ``image``: the NIfTI volume at the row's ``IMG_PATH``, loaded with
        nibabel, converted with ``transforms.functional.to_tensor`` and given
        a leading channel axis via ``unsqueeze(0)``;
      * ``tabular``: float64 tensor holding the row's ``required_columns``
        values, in that order;
      * ``y``: the row's ``labels`` value.
    """

    def __init__(self, adni_df, required_columns=required_columns):
        self.adni = adni_df
        # Store the column list so __getitem__ honours the constructor
        # argument instead of reading the notebook-global list.
        self.required_columns = list(required_columns)

    def __len__(self):
        return len(self.adni)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        line = self.adni.iloc[idx]
        # Get label
        y = line['labels']

        # Get tabular features. The row Series is object-typed because it also
        # holds IMG_PATH, so convert the selected values explicitly to float64.
        # getattr fallback: datasets pickled with torch.save before this
        # attribute existed do not carry it.
        columns = getattr(self, "required_columns", required_columns)
        tabular = torch.tensor(line[columns].to_numpy(dtype=np.float64), dtype=torch.float64)

        # Get image
        image = nib.load(line['IMG_PATH']).get_fdata()
        # to_tensor expects an (H, W, C) array and returns (C, H, W).
        image = transforms.functional.to_tensor(image)
        image = image.unsqueeze(dim=0)

        return image, tabular, y


img_data = ImgDataset(adni_df=adni)

### Detach test set and use remaining data for train-val k-fold split (further down)

In [None]:
labels = adni['labels'].tolist()
# Stratified split into train+val and held-out test indexes.
# An explicit random_state makes the split independent of how much of the
# global NumPy RNG earlier cells consumed, so re-running this cell yields the
# same test set.
tv_idx, test_idx = train_test_split(
    np.arange(len(labels)),
    test_size=0.1,
    shuffle=True,
    stratify=labels,
    random_state=myseed,
)

# Create train+val dataframe and show class balance
adni_tv = adni.iloc[tv_idx]
print(adni_tv['labels'].value_counts().sort_index())
tv_data = ImgDataset(adni_df=adni_tv)

# Create test dataframe and show class balance
adni_test = adni.iloc[test_idx]
print(adni_test['labels'].value_counts().sort_index())
test_data = ImgDataset(adni_df=adni_test)

In [None]:
# Saving the test split lets it be shared for a cross-test (X-TEST) evaluation
# of the models. Set SAVE_TEST_SET = True to write it for this adni_num.
SAVE_TEST_SET = False
# Index of the sample used for the sanity checks here and in the next cell.
i = 10

if SAVE_TEST_SET:
    test_set_path = f'test_adni{adni_num}_{test_data_suffix}.pt'
    torch.save(test_data, test_set_path)

    # Reload to verify the round-trip. The file is a pickled ImgDataset, so
    # full unpickling is required (newer PyTorch defaults to weights_only=True).
    # Only load files you trust.
    try:
        saved_test = torch.load(test_set_path, weights_only=False)
    except TypeError:  # PyTorch < 1.13 has no ``weights_only`` argument
        saved_test = torch.load(test_set_path)

    saved_image, saved_tabular, saved_label = saved_test[i]
    print(f"{saved_image.size()}, tabular = {saved_tabular}, label = {saved_label}")

In [None]:
# Each indexing call reloads the NIfTI volume from disk, so fetch the sample once.
sample_image, sample_tabular, sample_label = test_data[i]
print(f"{sample_image.size()}, tabular = {sample_tabular}, label = {sample_label}")

len(sample_tabular)

### 3. Model: img+tabular

In [None]:
def get_inplanes():
    """Return the base channel width of each of the four ResNet stages.

    Widths double at every stage, starting from 64. A new list is returned
    on each call.
    """
    return [64 * 2 ** stage for stage in range(4)]


def conv3x3x3(in_planes, out_planes, stride=1):
    """Bias-free 3x3x3 Conv3d with padding 1 (spatial size kept when stride is 1)."""
    layer = nn.Conv3d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer


def conv1x1x1(in_planes, out_planes, stride=1):
    """Bias-free 1x1x1 Conv3d (channel projection, optionally strided)."""
    layer = nn.Conv3d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return layer


class BasicBlock(nn.Module):
    """Two-convolution 3D residual block (the block used by ResNet-10/18/34).

    Main path: conv3x3x3 -> BN -> ReLU -> conv3x3x3 -> BN. The result is added
    to the input (or to ``downsample(input)`` when a projection is supplied)
    and passed through a final ReLU. ``expansion = 1``: the output has
    ``planes`` channels.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super().__init__()
        # Submodule creation order fixes both the parameter-initialisation
        # order (RNG consumption) and the state_dict keys; keep it stable.
        self.conv1 = conv3x3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3x3(planes, planes)
        self.bn2 = nn.BatchNorm3d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        hidden = self.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))

        # Identity shortcut unless a projection/downsampling module was given.
        shortcut = x if self.downsample is None else self.downsample(x)
        hidden += shortcut
        return self.relu(hidden)


class Bottleneck(nn.Module):
    """Three-convolution 3D bottleneck residual block (ResNet-50 and deeper).

    Main path: 1x1x1 reduce -> BN -> ReLU -> 3x3x3 (carries the stride) -> BN
    -> ReLU -> 1x1x1 expand -> BN. The result is added to the input (or to
    ``downsample(input)``) and passed through a final ReLU. The output has
    ``planes * expansion`` channels.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super().__init__()
        expanded = planes * self.expansion
        # Submodule creation order fixes both the parameter-initialisation
        # order (RNG consumption) and the state_dict keys; keep it stable.
        self.conv1 = conv1x1x1(in_planes, planes)
        self.bn1 = nn.BatchNorm3d(planes)
        self.conv2 = conv3x3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm3d(planes)
        self.conv3 = conv1x1x1(planes, expanded)
        self.bn3 = nn.BatchNorm3d(expanded)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        hidden = self.relu(self.bn1(self.conv1(x)))
        hidden = self.relu(self.bn2(self.conv2(hidden)))
        hidden = self.bn3(self.conv3(hidden))

        # Identity shortcut unless a projection/downsampling module was given.
        shortcut = x if self.downsample is None else self.downsample(x)
        hidden += shortcut
        return self.relu(hidden)


class ResNet(nn.Module):
    """3D ResNet image encoder fused with a small MLP over tabular features.

    The image path is a 3D ResNet whose final ``fc`` layer projects the pooled
    features to an ``img_contribution``-sized embedding rather than to class
    logits. The tabular path is an MLP
    ``tabular_val -> 50 -> 50 -> tabular_contribution``. Both embeddings are
    concatenated, passed through a ReLU and mapped to ``n_classes`` outputs
    by ``ln4``.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: number of blocks in each of the four stages.
        block_inplanes: base channel width of each stage.
        n_input_channels: channels of the input volume.
        conv1_t_size: stem kernel size along the first volume axis.
        conv1_t_stride: stem stride along the first volume axis.
        no_max_pool: skip the max-pool after the stem when True.
        shortcut_type: 'A' = parameter-free strided, zero-padded shortcut;
            anything else = 1x1x1 conv + BatchNorm projection.
        widen_factor: multiplier applied to ``block_inplanes``.
        n_classes: number of outputs of the fused head.
        img_contribution: size of the image embedding.
        tabular_val: number of tabular input features.
        tabular_contribution: size of the tabular embedding.
    """

    def __init__(self, block, layers, block_inplanes,
                 n_input_channels=3, conv1_t_size=7,
                 conv1_t_stride=1, no_max_pool=False,
                 shortcut_type='B', widen_factor=1.0,
                 n_classes=400, img_contribution=10,
                 tabular_val=10, tabular_contribution=10):
        super().__init__()

        block_inplanes = [int(x * widen_factor) for x in block_inplanes]

        self.in_planes = block_inplanes[0]
        self.no_max_pool = no_max_pool

        # ---- Image branch ----
        self.conv1 = nn.Conv3d(n_input_channels,
                               self.in_planes,
                               kernel_size=(conv1_t_size, 7, 7),
                               stride=(conv1_t_stride, 2, 2),
                               padding=(conv1_t_size // 2, 3, 3),
                               bias=False)
        self.bn1 = nn.BatchNorm3d(self.in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, block_inplanes[0], layers[0],
                                       shortcut_type)
        self.layer2 = self._make_layer(block,
                                       block_inplanes[1],
                                       layers[1],
                                       shortcut_type,
                                       stride=2)
        self.layer3 = self._make_layer(block,
                                       block_inplanes[2],
                                       layers[2],
                                       shortcut_type,
                                       stride=2)
        self.layer4 = self._make_layer(block,
                                       block_inplanes[3],
                                       layers[3],
                                       shortcut_type,
                                       stride=2)

        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        # Project pooled image features to an embedding (not to class logits)
        # so it can be concatenated with the tabular embedding.
        self.fc = nn.Linear(block_inplanes[3] * block.expansion, img_contribution)

        # ---- Tabular branch and fusion head ----
        # A separate (non-inplace) ReLU for the head. Re-assigning
        # ``self.relu`` here would silently replace the stem's inplace ReLU.
        # ReLU has no parameters, so state_dict keys are unaffected.
        self.head_relu = nn.ReLU()
        self.ln1 = nn.Linear(tabular_val, 50)  # tabular_val = number of tabular input columns
        self.ln2 = nn.Linear(50, 50)
        self.ln3 = nn.Linear(50, tabular_contribution)
        # Fusion layer over the concatenated [image | tabular] embeddings.
        self.ln4 = nn.Linear(tabular_contribution + img_contribution, n_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _downsample_basic_block(self, x, planes, stride):
        """Type-'A' shortcut: strided subsampling plus zero-padding of channels up to ``planes``."""
        out = F.avg_pool3d(x, kernel_size=1, stride=stride)
        # Allocate the padding with the same dtype/device as ``out`` so the
        # shortcut works on CUDA and in the float64 model. Keep ``out`` (not
        # ``out.data``) so gradients flow through the shortcut.
        zero_pads = out.new_zeros((out.size(0), planes - out.size(1),
                                   out.size(2), out.size(3), out.size(4)))
        return torch.cat([out, zero_pads], dim=1)

    def _make_layer(self, block, planes, blocks, shortcut_type, stride=1):
        """Build one ResNet stage of ``blocks`` blocks; updates ``self.in_planes``."""
        downsample = None
        if stride != 1 or self.in_planes != planes * block.expansion:
            if shortcut_type == 'A':
                downsample = partial(self._downsample_basic_block,
                                     planes=planes * block.expansion,
                                     stride=stride)
            else:
                downsample = nn.Sequential(
                    conv1x1x1(self.in_planes, planes * block.expansion, stride),
                    nn.BatchNorm3d(planes * block.expansion))

        layers = []
        layers.append(
            block(in_planes=self.in_planes,
                  planes=planes,
                  stride=stride,
                  downsample=downsample))
        self.in_planes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_planes, planes))

        return nn.Sequential(*layers)

    def forward(self, x, tab):
        """Return fused outputs of shape (batch, n_classes).

        Args:
            x: image batch; expected 5-D (batch, n_input_channels, D, H, W).
            tab: tabular features, (batch, tabular_val).
        """
        x = self.relu(self.bn1(self.conv1(x)))
        if not self.no_max_pool:
            x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        tab = self.head_relu(self.ln1(tab))
        tab = self.head_relu(self.ln2(tab))
        tab = self.head_relu(self.ln3(tab))

        x = torch.cat((x, tab), dim=1)
        x = self.head_relu(x)
        return self.ln4(x)


def ResNet18(in_channels, num_classes):
    """Build a ResNet-18 (BasicBlock, [2, 2, 2, 2]) fusion model.

    Args:
        in_channels: number of channels of the input volume.
        num_classes: number of outputs of the fused head.
    """
    # ResNet expects the per-stage widths positionally and names its keyword
    # arguments n_input_channels / n_classes.
    return ResNet(BasicBlock, [2, 2, 2, 2], get_inplanes(),
                  n_input_channels=in_channels, n_classes=num_classes)


def generate_model(model_depth, **kwargs):
    """Instantiate a ``ResNet`` of the given depth.

    Supported depths: 10, 18, 34 (BasicBlock) and 50, 101, 152, 200
    (Bottleneck). Extra keyword arguments are forwarded to ``ResNet``.
    Raises AssertionError for an unsupported depth.
    """
    architectures = {
        10: (BasicBlock, [1, 1, 1, 1]),
        18: (BasicBlock, [2, 2, 2, 2]),
        34: (BasicBlock, [3, 4, 6, 3]),
        50: (Bottleneck, [3, 4, 6, 3]),
        101: (Bottleneck, [3, 4, 23, 3]),
        152: (Bottleneck, [3, 8, 36, 3]),
        200: (Bottleneck, [3, 24, 36, 3]),
    }
    assert model_depth in tuple(architectures)

    block, stage_sizes = architectures[model_depth]
    return ResNet(block, stage_sizes, get_inplanes(), **kwargs)


# ResNet-18 image branch fused with a tabular MLP over `required_columns`,
# one output logit. Cast to float64 to match the dataset's float64 tensors.
model = generate_model(
    18,
    n_input_channels=1,
    widen_factor=1.0,
    n_classes=1,
    img_contribution=10,
    tabular_val=len(required_columns),
    tabular_contribution=10,
).double()
print(model)


all_params = list(model.parameters())
print('Total Parameters:',
      sum(p.numel() for p in all_params))
print('Trainable Parameters:',
      sum(p.numel() for p in all_params if p.requires_grad))

In [None]:
_SPLITS = ("train", "val", "test")


def _run_split(net, loader, is_train, optimizer, criterion, dev, scheduler):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Both metrics are averaged over samples rather than over batches, so a short
    final batch is not over-weighted. This assumes ``criterion`` uses
    ``reduction='mean'`` (the default of ``nn.BCEWithLogitsLoss``). When
    ``is_train`` is True, the optimizer and (if given) the scheduler are
    stepped once per batch.
    """
    net.train(is_train)
    total_loss, n_correct, n_seen = 0.0, 0, 0
    # Only the training split needs autograd graphs; skipping them for
    # val/test saves a lot of memory with 3D volumes.
    with torch.set_grad_enabled(is_train):
        for image, tabular, labels in loader:
            images = image.to(dev)
            tabular = tabular.to(dev)
            pred = net(images, tabular)
            # Logits are (batch, 1): reshape targets to match and cast them to
            # the logits' dtype (the model runs in float64).
            targets = labels.to(dev).unsqueeze(1).to(pred.dtype)
            loss = criterion(pred, targets)
            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # One scheduler step per optimizer step, which is what a
                # OneCycleLR built with steps_per_epoch=len(train_loader) expects.
                if scheduler is not None:
                    scheduler.step()
            batch_size = images.size(0)
            total_loss += loss.item() * batch_size
            # A logit >= 0 is a sigmoid probability >= 0.5 -> class 1.
            n_correct += ((pred >= 0).to(targets.dtype) == targets).sum().item()
            n_seen += batch_size
    if n_seen == 0:
        return 0.0, 0.0
    return total_loss / n_seen, n_correct / n_seen


def _save_checkpoint(net, optimizer, epoch, loss_value, path):
    """Save model and optimizer state to ``path``, creating its folder if needed."""
    try:
        # Unwrap nn.DataParallel so the saved keys have no "module." prefix.
        state_dict = net.module.state_dict()
    except AttributeError:
        state_dict = net.state_dict()
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    torch.save({
        'epoch': epoch,
        'model_state_dict': state_dict,
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss_value,
    }, path)


def _plot_history(history, title):
    """Plot one curve per split from ``history`` (split -> list of per-epoch values)."""
    plt.title(title)
    for split in _SPLITS:
        plt.plot(history[split], label=split)
    plt.legend()
    plt.show()


def train(net, loaders, optimizer, criterion, epochs=500, dev='cpu', save_param=True,
          model_name="adni_multi_input", scheduler=None, writer=None):
    """Train ``net`` and evaluate on val/test after every epoch.

    Args:
        net: model called as ``net(images, tabular)``, returning one logit per sample.
        loaders: dict with "train", "val" and "test" DataLoaders yielding
            ``(image, tabular, label)`` batches.
        optimizer, criterion: optimizer and loss (e.g. ``nn.BCEWithLogitsLoss``).
        epochs: number of epochs.
        dev: device to train on.
        save_param: when True, save a checkpoint to ``<model_name>/best_val.pth``
            each time the validation accuracy improves.
        model_name: folder receiving the checkpoint.
        scheduler: LR scheduler stepped once per training batch. Defaults to
            the notebook-level ``scheduler`` if one exists.
        writer: TensorBoard ``SummaryWriter``. Defaults to the notebook-level
            ``writer`` if one exists; logging is skipped when none is found.

    Loss/accuracy curves are plotted at the end, also after a KeyboardInterrupt.
    The checkpoint's 'loss' entry is the epoch's mean validation loss.
    """
    # Fall back to the notebook-level objects so existing calls keep working.
    if scheduler is None:
        scheduler = globals().get("scheduler")
    if writer is None:
        writer = globals().get("writer")

    torch.manual_seed(myseed)
    start_time = time.time()
    # Initialised before ``try`` so ``finally`` can always plot.
    history_loss = {split: [] for split in _SPLITS}
    history_accuracy = {split: [] for split in _SPLITS}
    best_val_accuracy = 0

    try:
        net = net.to(dev)
        for epoch in range(epochs):
            epoch_loss, epoch_accuracy = {}, {}
            for split in _SPLITS:
                epoch_loss[split], epoch_accuracy[split] = _run_split(
                    net, loaders[split], split == "train",
                    optimizer, criterion, dev, scheduler)

            for split in _SPLITS:
                history_loss[split].append(epoch_loss[split])
                history_accuracy[split].append(epoch_accuracy[split])

            if writer is not None:
                writer.add_scalar("Train Loss", epoch_loss['train'], epoch)
                writer.add_scalar("Valid Loss", epoch_loss['val'], epoch)
                writer.add_scalar("Test Loss", epoch_loss['test'], epoch)
                writer.add_scalar("Train Accuracy", epoch_accuracy['train'], epoch)
                writer.add_scalar("Valid Accuracy", epoch_accuracy['val'], epoch)
                writer.add_scalar("Test Accuracy", epoch_accuracy['test'], epoch)
                writer.add_scalar("ETA", time.time()-start_time, epoch)

            print(f"Epoch {epoch+1}:",
                  f"TrL={epoch_loss['train']:.4f},",
                  f"TrA={epoch_accuracy['train']:.4f},",
                  f"VL={epoch_loss['val']:.4f},",
                  f"VA={epoch_accuracy['val']:.4f},",
                  f"TeL={epoch_loss['test']:.4f},",
                  f"TeA={epoch_accuracy['test']:.4f},",
                  f"LR={optimizer.param_groups[0]['lr']:.5f},",
                  f"s={time.time()-start_time:.4f},")

            # Store params at the best validation accuracy
            if save_param and epoch_accuracy['val'] > best_val_accuracy:
                print(f"\nFound new best: {epoch_accuracy['val']} - Saving best at epoch: {epoch+1}")
                _save_checkpoint(net, optimizer, epoch, epoch_loss['val'],
                                 os.path.join(model_name, "best_val.pth"))
                best_val_accuracy = epoch_accuracy['val']

    except KeyboardInterrupt:
        print("Interrupted")
    finally:
        _plot_history(history_loss, "Loss")
        _plot_history(history_accuracy, "Accuracy")

In [None]:
def reset_weights(m):
    """Re-initialise one module in place; meant for ``model.apply(reset_weights)``.

    * ``nn.Conv3d``: default reset (which also covers a bias, if any), then
      Kaiming-normal weights (fan_out, relu), matching the initialisation in
      ``ResNet.__init__``.
    * ``nn.Linear``: PyTorch default reset.
    * BatchNorm layers: reset affine parameters (weight=1, bias=0) and running
      statistics, so normalisation statistics do not leak between trainings.

    Other module types are left untouched.
    """
    if isinstance(m, nn.Conv3d):
        m.reset_parameters()
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, nn.Linear):
        m.reset_parameters()
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        m.reset_parameters()


In [None]:
# Save test set for x-tests
outpath = os.path.join(f"runs",f"adni{adni_num}_{experiment_type}",f"{experiment_name}")
if not os.path.exists(outpath):
    os.makedirs(outpath)

generator = torch.Generator()
generator.manual_seed(myseed)
test_loader = DataLoader(test_data,  batch_size=8, num_workers=1, drop_last=False, shuffle=False, generator=generator)

tv_labels = adni_tv['labels'].tolist()
skf = StratifiedKFold(n_splits = num_folds)

for fold,(train_idx,val_idx) in enumerate(skf.split(tv_data, tv_labels)):
    
    writer = SummaryWriter(os.path.join(outpath,f"{fold}"), filename_suffix=f"_E{epochs}")
    print('------------fold no---------{}----------------------'.format(fold))   
    train_df = adni_tv.iloc[train_idx]
    train_set = ImgDataset(adni_df=train_df)

    val_df = adni_tv.iloc[val_idx]
    val_set = ImgDataset(adni_df=val_df)
    
    train_loader = DataLoader(train_set, batch_size=8, drop_last=False)
    val_loader = DataLoader(val_set, batch_size=8, drop_last=False)
    
    # Define dictionary of loaders
    loaders = {"train": train_loader,
               "val": val_loader,
               "test": test_loader}

    # Model Params
    optimizer = optim.Adam(model.parameters(), lr = 0.0001)
    # Define a loss 
    criterion = nn.BCEWithLogitsLoss()
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr = 0.01, epochs=epochs, steps_per_epoch=len(train_loader))
    
    
    #model, optimizer = ipex.optimize(model, optimizer=optimizer,dtype=torch.double)
    #model = model.float()
    # Train model
    train(model, loaders, optimizer, criterion, epochs=epochs, dev=dev, model_name=os.path.join(outpath,f"{fold}"))
    writer.flush()
    writer.close()
    model.apply(reset_weights)

### TEST on Cross datasets

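Please make sure that all the cross-test sets are stored in a single directory named `X-TEST_multi-input` (the default, i.e. `X-TEST_{test_data_suffix}`) inside this experiment's output folder (`outpath`). Each file must be named `test_adniN_multi-input.pt` (N = 1, 2, 3), as written by the test-set saving cell above. To use a different directory name, update the `x_test_dir` variable at the beginning of the next cell.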

In [None]:
x_test_dir = f'X-TEST_{test_data_suffix}'

external_test_names = ['test_adni1', 'test_adni2', 'test_adni3']
external_datal = {}
# Load external test sets (pickled ImgDataset objects written with torch.save)
for name in external_test_names:
    ext_test_path = os.path.join(outpath, x_test_dir, f'{name}_{test_data_suffix}.pt')
    # Whole Python objects need full unpickling (newer PyTorch defaults to
    # weights_only=True, which rejects them). Only load files you trust.
    try:
        loaded_test = torch.load(ext_test_path, weights_only=False)
    except TypeError:  # PyTorch < 1.13 has no ``weights_only`` argument
        loaded_test = torch.load(ext_test_path)

    # Keep the loaders in the dict instead of overwriting the global test_loader.
    external_datal[name] = DataLoader(loaded_test, batch_size=8, num_workers=4, drop_last=False,
                                      shuffle=False, generator=generator)

In [None]:
def _evaluate(net, loader, criterion, dev):
    """Return ``(mean_loss, accuracy)`` of ``net`` over ``loader``, averaged over samples."""
    total_loss, n_correct, n_seen = 0.0, 0, 0
    with torch.no_grad():
        for image, tabular, labels in loader:
            images = image.to(dev)
            tabular = tabular.to(dev)
            pred = net(images, tabular)
            # Logits are (batch, 1): match target shape and dtype.
            targets = labels.to(dev).unsqueeze(1).to(pred.dtype)
            batch_size = images.size(0)
            total_loss += criterion(pred, targets).item() * batch_size
            # A logit >= 0 is a sigmoid probability >= 0.5 -> class 1.
            n_correct += ((pred >= 0.0).to(targets.dtype) == targets).sum().item()
            n_seen += batch_size
    if n_seen == 0:
        return 0.0, 0.0
    return total_loss / n_seen, n_correct / n_seen


def _strip_module_prefix(state_dict):
    """Drop a leading "module." (added by nn.DataParallel) from every key, if present."""
    prefix = "module."
    return {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}


x_test_results = {}
eval_criterion = nn.BCEWithLogitsLoss()

for fold in range(num_folds):
    best_model_path = os.path.join(outpath, f"{fold}", "best_val.pth")

    model = generate_model(18, n_input_channels=1, widen_factor=1.0,
                           n_classes=1, img_contribution=10, tabular_val=len(required_columns),
                           tabular_contribution=10).double()

    # map_location lets GPU-saved checkpoints load on a CPU-only machine.
    checkpoint = torch.load(best_model_path, map_location=dev)
    model.load_state_dict(_strip_module_prefix(checkpoint['model_state_dict']))
    net = model.to(dev).eval()

    loss, accuracy = {}, {}
    for x_test, ext_loader in external_datal.items():
        loss[x_test], accuracy[x_test] = _evaluate(net, ext_loader, eval_criterion, dev)

    x_test_results[f"{fold}"] = {'loss': loss, 'accuracy': accuracy}

In [None]:
decimals = 4
final_summary = {}
for x_test in external_datal.keys():
    # One accuracy per fold for this external test set.
    local_summary = [x_test_results[f]['accuracy'][x_test] for f in x_test_results]
    final_summary[x_test] = local_summary

    if local_summary:
        avg = round(np.average(local_summary), decimals)
        std = round(np.std(local_summary), decimals)
    else:
        avg = std = float('nan')

    # Explicit "\n" separators: a backslash line continuation inside the
    # f-string would embed the source indentation in the printed text.
    print(f"{x_test},\n Values = {local_summary},\n avg = {avg}, std = {std}\n")

In [None]:
# Check TensorBoard after enabling port forwarding on the same port: localhost:XXXX
#!tensorboard --logdir /PATH/TO/LOG/DIR --bind_all --load_fast=false --port=XXXX