In [None]:
import numpy as np
from matplotlib import pyplot as plt
import os
import pandas as pd
import PIL
from PIL import Image
import torchvision
import torch
import glob
import nibabel as nib
import time
from sklearn.model_selection import StratifiedKFold

import torch.nn.functional as F
from torchvision import utils, transforms
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.model_selection import KFold
import torch.optim as optim

from torch.utils.tensorboard import SummaryWriter

%matplotlib inline

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    dev = torch.device("cuda")
else:
    dev = torch.device("cpu")
print(dev)

### 1) Data wrangling

In [None]:
# --- Experiment configuration ---
adni_num = "2"                        # cohort: "1", "2", "3", or "123" to keep every cohort
experiment_type = "txt-only_XTEST"
folds_num = 5                         # number of StratifiedKFold splits
epochs = 200

# Every tabular feature available in ADNI_ready.csv; only `required_columns` are used below.
all_columns = [
    'AGE', 'PTGENDER', 'ADAS11', 'MMSE', 'FAQ',
    'RAVLT_immediate', 'RAVLT_learning', 'RAVLT_forgetting',
    'CDRSB', 'APOE4',
]

required_columns = ['AGE', 'PTGENDER', 'APOE4']
# e.g. "age_ptgender_apoe4"; becomes part of the output directory path.
experiment_name = "_".join(required_columns).lower()
print(experiment_name)

# Seed torch and numpy so weight init and data splits are reproducible.
myseed = 1
torch.manual_seed(myseed)
np.random.seed(myseed)

num_classes = 1  # not referenced below; TextNN hard-codes a single output logit

#### Retrieve ADNI table, with normalized values (ADNI_ready.csv)
ADNI_ready.csv must be created manually after downloading the desired set of subjects from the ADNI website. It must include the following fields:
- subject_id
- `SRC`: the source dataset, stored as `ADNI1`, `ADNI2` or `ADNI3` depending on which dataset the subject belongs to
- `labels`: the class label, encoded as 0/1 (the model is a binary classifier)
- 'AGE','PTGENDER','APOE4' columns

In [None]:
data_path = "ADNI_csv/"
filename = "ADNI_ready.csv"

adni_tabular = pd.read_csv(os.path.join(data_path, filename))
# A bare `.head()` in the middle of a cell is never displayed, so print it explicitly.
print(adni_tabular.head())

# Fail early with a clear message instead of a KeyError later inside TxtDataset.
missing_columns = [c for c in ['labels', *required_columns] if c not in adni_tabular.columns]
assert not missing_columns, f"{filename} is missing required columns: {missing_columns}"

print(f"ALL adni has {len(adni_tabular)} entries")
print("Class distribution is organized as follow:")
print(f"\n {adni_tabular['labels'].value_counts()}")

# "123" keeps every cohort; any other value keeps only the rows whose SRC is ADNI{adni_num}.
if adni_num == '123':
    # NOTE(review): 'full10' does not reflect `required_columns` - confirm this name is intended.
    experiment_name = 'full10'
else:
    adni_tabular = adni_tabular[adni_tabular['SRC'] == f"ADNI{adni_num}"]
assert not adni_tabular.empty, f"No rows left after selecting ADNI{adni_num}"

print(f"Adni{adni_num} has {len(adni_tabular)} entries")
print("Class distribution is organized as follow:")
print(f"Final:\n {adni_tabular['labels'].value_counts()}")


## Check for duplicated rows
n_duplicated = int(adni_tabular.duplicated().sum())
if n_duplicated:
    print(f"WARNING: Dataframe contains {n_duplicated} duplicated rows!!!")

### 2) Dataset creator

In [None]:
class TxtDataset(Dataset):
    """Tabular-only dataset over the rows of an ADNI dataframe.

    Each item is ``(features, label)``:
    - ``features`` is a float32 tensor with the row's ``required_columns`` values, in that order;
    - ``label`` is the row's ``labels`` value.

    Args:
        adni_df: dataframe with a ``labels`` column and every column listed in ``required_columns``.
        required_columns: feature columns to return; defaults to the notebook-level ``required_columns``.
    """

    def __init__(self, adni_df, required_columns=required_columns):
        self.adni = adni_df
        # Store the columns on the instance so the argument is honoured and later edits to the
        # caller's list cannot silently change an existing dataset's features.
        self.required_columns = list(required_columns)

    def __len__(self):
        return len(self.adni)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        line = self.adni.iloc[idx]
        # Get Label
        y = line['labels']

        # Instances pickled before `required_columns` was stored fall back to the notebook global.
        columns = getattr(self, "required_columns", required_columns)
        # Explicit float32 conversion: the row Series may have object dtype when the frame mixes types.
        tabular = torch.tensor(line[columns].to_numpy(dtype=np.float32))

        return tabular, y

# NOTE(review): `img_data` is not referenced elsewhere in this notebook.
img_data = TxtDataset(adni_df=adni_tabular)

### Detach test set and use remaining data for train-val k-fold split (further down)

In [None]:
from sklearn.model_selection import train_test_split

labels = adni_tabular['labels'].tolist()
# Hold out a stratified 10% test set. An explicit random_state makes the split reproducible even when
# this cell is re-run on its own; otherwise it depends on numpy's global RNG state at that moment.
tv_idx, test_idx = train_test_split(np.arange(len(labels)), test_size=0.1, shuffle=True,
                                    stratify=labels, random_state=myseed)

# Create train+val dataframe and show class balance
adni_tv = adni_tabular.iloc[tv_idx]
print(adni_tv['labels'].value_counts())
tv_data = TxtDataset(adni_df=adni_tv)

# Create test dataframe and show class balance
adni_test = adni_tabular.iloc[test_idx]
print(adni_test['labels'].value_counts())
test_data = TxtDataset(adni_df=adni_test)

assert len(adni_tv) + len(adni_test) == len(adni_tabular)

In [None]:
# Optional: write this cohort's test split to test_adni{adni_num}.pt in the working directory and read it
# back as a sanity check, so it can be shared for a cross-test (X-TEST) evaluation.
# Disabled by default - change the condition to True to enable it.
# NOTE(review): the k-fold training cell also saves this split, under `outpath`.
if False:
    saved_path = f'test_adni{adni_num}.pt'
    torch.save(test_data, saved_path)
    saved_test = torch.load(saved_path)

    i = 10
    sample_tabular, sample_label = saved_test[i]
    print(f"tabular = {sample_tabular}, label = {sample_label}")

In [None]:
# Sanity check: show one test sample and how many features it carries.
i = 10
sample_features, sample_label = test_data[i]
print(f"tabular = {sample_features}, label = {sample_label}")

len(sample_features)

### 3) Model: tabular

In [None]:
class TextNN(nn.Module):
    """Fully connected network for tabular inputs: num_variables -> 50 -> 50 -> 10 -> 1.

    A ReLU follows each hidden layer. The output is a single raw logit (no sigmoid), meant to be paired
    with a logits-based loss such as nn.BCEWithLogitsLoss.

    The constructor re-seeds the global torch RNG with the notebook-level `myseed`, so instances with the
    same `num_variables` start from identical weights.
    """

    def __init__(self, num_variables):
        super().__init__()
        # Re-seed so the weight initialisation below is reproducible.
        torch.manual_seed(myseed)
        self.relu = nn.ReLU()
        # Creation order matters: it fixes which RNG draws initialise each layer.
        self.ln1 = nn.Linear(num_variables, 50)  # num_variables = number of input feature columns
        self.ln2 = nn.Linear(50, 50)
        self.ln3 = nn.Linear(50, 10)
        self.ln4 = nn.Linear(10, 1)

    def forward(self, tab):
        # The three hidden layers share the Linear -> ReLU pattern; the last layer emits the logit.
        for hidden_layer in (self.ln1, self.ln2, self.ln3):
            tab = self.relu(hidden_layer(tab))
        return self.ln4(tab)

# One input per feature column: TxtDataset returns exactly the `required_columns` values per sample.
# (Sizing from a test sample depended on a leaked `i` and on the test set having more than 10 rows.)
model = TextNN(len(required_columns))
print(model)

print('Total Parameters:',
      sum(torch.numel(p) for p in model.parameters()))
print('Trainable Parameters:',
      sum(torch.numel(p) for p in model.parameters() if p.requires_grad))

In [None]:
def _plot_history(history, title):
    """Plot one curve per split from a ``{split: [value per epoch]}`` dict and show the figure."""
    fig, ax = plt.subplots()
    for split, values in history.items():
        ax.plot(values, label=split)
    ax.set(title=title, xlabel="Epoch", ylabel=title)
    ax.legend()
    plt.show()


def train(net, loaders, optimizer, criterion, epochs=500, dev='cpu', save_param = True, model_name="adni_only-text",
          scheduler=None, writer=None):
    """Train `net` and evaluate it on the "val" and "test" splits after every epoch.

    Args:
        net: model to optimise; moved to `dev` (nn.Module.to works in place).
        loaders: dict with "train", "val" and "test" DataLoaders yielding ``(tabular, labels)`` batches.
        optimizer: optimizer over ``net.parameters()``.
        criterion: loss called on logits of shape (batch, 1) and float labels of shape (batch, 1).
            It is assumed to average over the batch (e.g. nn.BCEWithLogitsLoss with reduction='mean').
        epochs: number of epochs.
        dev: device for the model and the batches.
        save_param: if True, save ``<model_name>/best_val.pth`` whenever validation accuracy improves.
            The first epoch always saves, so a checkpoint exists after any completed epoch.
        model_name: checkpoint directory; created if missing.
        scheduler: LR scheduler. If None, the notebook-level ``scheduler`` is used when defined.
        writer: TensorBoard SummaryWriter. If None, the notebook-level ``writer`` is used when defined.

    Epoch loss/accuracy are per-sample averages. Curves are plotted when training ends, fails, or is
    interrupted with Ctrl-C.
    """
    splits = ["train", "val", "test"]
    # Backward-compatible fallback: earlier calls relied on these notebook globals.
    if scheduler is None:
        scheduler = globals().get("scheduler")
    if writer is None:
        writer = globals().get("writer")

    torch.manual_seed(myseed)
    # Initialised before `try` so the plotting in `finally` never hits a NameError.
    history_loss = {split: [] for split in splits}
    history_accuracy = {split: [] for split in splits}
    best_val_accuracy = float("-inf")
    try:
        net = net.to(dev)

        for epoch in range(epochs):
            epoch_loss = {}
            epoch_accuracy = {}
            for split in splits:
                is_train = split == "train"
                net.train(is_train)
                loss_sum, n_correct, n_seen = 0.0, 0, 0
                # Autograd is only needed for the training split.
                with torch.set_grad_enabled(is_train):
                    for tabular, labels in loaders[split]:
                        tabular = tabular.to(dev)
                        labels = labels.to(dev).unsqueeze(1).float()  # (batch,) -> (batch, 1) like the logits
                        pred = net(tabular)
                        loss = criterion(pred, labels)
                        if is_train:
                            optimizer.zero_grad()
                            loss.backward()
                            optimizer.step()
                        batch_size = tabular.size(0)
                        # Weight by batch size so a smaller last batch does not skew the epoch mean.
                        loss_sum += loss.item() * batch_size
                        pred_labels = (pred >= 0.0).long()  # logit >= 0  <=>  sigmoid >= 0.5
                        n_correct += (pred_labels == labels).sum().item()
                        n_seen += batch_size
                epoch_loss[split] = loss_sum / n_seen if n_seen else float("nan")
                epoch_accuracy[split] = n_correct / n_seen if n_seen else float("nan")
                # NOTE(review): OneCycleLR built with steps_per_epoch=len(train_loader) expects one step()
                # per training batch. Stepping once per split advances it only 3 times per epoch, so the
                # schedule never completes. Moving this after optimizer.step() also requires that nothing
                # else steps the scheduler afterwards, because OneCycleLR raises past total_steps.
                if scheduler is not None:
                    scheduler.step()

            # Update history
            for split in splits:
                history_loss[split].append(epoch_loss[split])
                history_accuracy[split].append(epoch_accuracy[split])

            if writer is not None:
                writer.add_scalar("Train Loss", epoch_loss['train'], epoch)
                writer.add_scalar("Valid Loss", epoch_loss['val'], epoch)
                writer.add_scalar("Test Loss", epoch_loss['test'], epoch)
                writer.add_scalar("Train Accuracy", epoch_accuracy['train'], epoch)
                writer.add_scalar("Valid Accuracy", epoch_accuracy['val'], epoch)
                writer.add_scalar("Test Accuracy", epoch_accuracy['test'], epoch)

            print(f"Epoch {epoch+1}:",
                  f"TrL={epoch_loss['train']:.4f},",
                  f"TrA={epoch_accuracy['train']:.4f},",
                  f"VL={epoch_loss['val']:.4f},",
                  f"VA={epoch_accuracy['val']:.4f},",
                  f"TeL={epoch_loss['test']:.4f},",
                  f"TeA={epoch_accuracy['test']:.4f},",
                  f"LR={optimizer.param_groups[0]['lr']:.5f},")

            # Store params at the best validation accuracy
            if save_param and epoch_accuracy['val'] > best_val_accuracy:
                print(f"\nFound new best: {epoch_accuracy['val']} - Saving best at epoch: {epoch+1}")
                os.makedirs(model_name, exist_ok=True)
                PATH = os.path.join(model_name, "best_val.pth")
                # Unwrap nn.DataParallel so saved keys carry no "module." prefix.
                try:
                    state_dict = net.module.state_dict()
                except AttributeError:
                    state_dict = net.state_dict()

                torch.save({
                            'epoch': epoch,
                            'model_state_dict': state_dict,
                            'optimizer_state_dict': optimizer.state_dict(),
                            'loss': loss,
                            }, PATH)
                best_val_accuracy = epoch_accuracy['val']

    except KeyboardInterrupt:
        print("Interrupted")
    finally:
        _plot_history(history_loss, "Loss")
        _plot_history(history_accuracy, "Accuracy")

In [None]:
def reset_weights(m):
    """Re-initialise `m` in place when it is an nn.Conv3d or nn.Linear layer; leave any other module untouched.

    Designed for ``model.apply(reset_weights)``, which calls it on every submodule.
    """
    if not isinstance(m, (nn.Conv3d, nn.Linear)):
        return
    m.reset_parameters()


In [None]:
# Save test set for x-tests
outpath = f"runs/adni{adni_num}_{experiment_type}/{experiment_name}"
os.makedirs(outpath, exist_ok=True)

torch.save(test_data, os.path.join(outpath, f'test_adni{adni_num}.pt'))

generator = torch.Generator()
generator.manual_seed(myseed)
test_loader = DataLoader(test_data, batch_size=8, num_workers=4, drop_last=False, shuffle=False, generator=generator)

tv_labels = adni_tv['labels'].tolist()

skf = StratifiedKFold(n_splits=folds_num)

# StratifiedKFold only uses X for its length; the stratification comes from tv_labels.
for fold, (train_idx, val_idx) in enumerate(skf.split(np.zeros(len(tv_labels)), tv_labels)):
    fold_dir = os.path.join(outpath, f"{fold}")
    writer = SummaryWriter(fold_dir, filename_suffix=f"_E{epochs}")
    print('------------fold no---------{}----------------------'.format(fold))
    train_df = adni_tv.iloc[train_idx]
    train_set = TxtDataset(adni_df=train_df)

    val_df = adni_tv.iloc[val_idx]
    val_set = TxtDataset(adni_df=val_df)

    # Reshuffle training batches every epoch; a seeded generator keeps the order reproducible.
    train_generator = torch.Generator()
    train_generator.manual_seed(myseed)
    train_loader = DataLoader(train_set, batch_size=8, num_workers=1, drop_last=False,
                              shuffle=True, generator=train_generator)
    val_loader = DataLoader(val_set, batch_size=8, num_workers=1, drop_last=False)

    # Define dictionary of loaders
    loaders = {"train": train_loader,
               "val": val_loader,
               "test": test_loader}

    # Fresh model per fold: TextNN re-seeds torch in its constructor, so every fold starts from the
    # same initial weights instead of a re-initialisation drawn from whatever RNG state training left.
    model = TextNN(len(required_columns))
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    # Define a loss
    criterion = nn.BCEWithLogitsLoss()
    # NOTE: OneCycleLR overrides the optimizer lr; it starts at max_lr / div_factor (0.01 / 25 by default).
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, epochs=epochs,
                                                    steps_per_epoch=len(train_loader))

    # Train model; always release the TensorBoard writer, even if training raises.
    try:
        train(model, loaders, optimizer, criterion, epochs=epochs, dev=dev, model_name=fold_dir)
    finally:
        writer.flush()
        writer.close()

### TEST on Cross datasets

Please make sure that all the cross-test sets (`test_adni1.pt`, `test_adni2.pt`, `test_adni3.pt`) have been stored in a single directory named "X-TEST_txt-only" (default value), located inside this experiment's output directory (`outpath`). To use a different directory name, update the `x_test_dir` variable at the beginning of the next cell.

In [None]:
x_test_dir = 'X-TEST_txt-only'

external_test_names = ['test_adni1', 'test_adni2', 'test_adni3']
external_datal = {}
# Load external test_sets
for name in external_test_names:
    ext_test_path = os.path.join(outpath, x_test_dir, f'{name}.pt')
    # These files hold whole pickled TxtDataset objects, not just tensors, so they need weights_only=False
    # (the default became True in PyTorch 2.6; the argument exists since PyTorch 1.13).
    # SECURITY: unpickling can execute arbitrary code - only load files from trusted sources.
    loaded_test = torch.load(ext_test_path, weights_only=False)

    # Create DataLoader
    test_loader = DataLoader(loaded_test, batch_size=8, num_workers=4, drop_last=False, shuffle=False, generator=generator)
    external_datal[name] = test_loader

In [None]:
def _evaluate(net, loader, criterion, dev):
    """Return ``(mean loss, accuracy)`` of `net` over `loader`, both averaged per sample.

    Logits >= 0 are predicted as class 1 and compared with the 0/1 labels. Returns NaNs for an empty loader.
    """
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    with torch.no_grad():
        for tabular, labels in loader:
            tabular = tabular.to(dev)
            labels = labels.to(dev).unsqueeze(1).float()  # (batch,) -> (batch, 1) like the logits
            pred = net(tabular)
            batch_size = tabular.size(0)
            loss_sum += criterion(pred, labels).item() * batch_size
            n_correct += ((pred >= 0.0).long() == labels).sum().item()
            n_seen += batch_size
    if n_seen == 0:
        return float("nan"), float("nan")
    return loss_sum / n_seen, n_correct / n_seen


criterion = nn.BCEWithLogitsLoss()  # same loss as used for training
x_test_results = {}

for fold in range(folds_num):
    best_model_path = os.path.join(outpath, f"{fold}", "best_val.pth")

    # Build the model on the evaluation device and map the checkpoint there too, so weights and inputs
    # share a device and CUDA-trained checkpoints also load on CPU-only machines.
    model = TextNN(len(required_columns)).to(dev)
    checkpoint = torch.load(best_model_path, map_location=dev)
    # Strip a DataParallel "module." prefix if present (train() normally saves unwrapped keys).
    state_dict = {
        (key[len("module."):] if key.startswith("module.") else key): value
        for key, value in checkpoint['model_state_dict'].items()
    }
    model.load_state_dict(state_dict)
    net = model.eval()

    fold_loss = {}
    fold_accuracy = {}
    for x_test, test_loader in external_datal.items():
        fold_loss[x_test], fold_accuracy[x_test] = _evaluate(net, test_loader, criterion, dev)

    x_test_results[f"{fold}"] = {'loss': fold_loss, 'accuracy': fold_accuracy}

In [None]:
## Return results
# For each external test set: the per-fold accuracies, their mean, and their population (ddof=0) std.
decimals = 4
final_summary = {}
for x_test in external_datal.keys():
    local_summary = [x_test_results[f]['accuracy'][x_test] for f in x_test_results]
    final_summary[x_test] = local_summary
    # Implicit concatenation of separate f-strings (instead of a backslash continuation inside one
    # literal) keeps the source indentation out of the printed text.
    print(f"{x_test},"
          f"\n Values = {local_summary},"
          f"\n avg = {round(np.average(local_summary), decimals)}, std = {round(np.std(local_summary), decimals)}\n")

In [None]:
# Check TensorBoard after enabling port forwarding on the same port, then open localhost:XXXX
#!tensorboard --logdir /PATH/TO/LOG/DIR --bind_all --load_fast=false --port=XXXX