In [16]:
import sys
sys.path.append('../')

from nilearn.image import index_img, smooth_img
from nilearn.masking import apply_mask
from nibabel.nifti1 import Nifti1Image
import os
import torch
from torch import Tensor, nn
from torch.utils.data import DataLoader, Subset, Dataset, TensorDataset
from torch.utils.data.sampler import WeightedRandomSampler, SubsetRandomSampler
from torchvision import transforms
import pytorch_lightning as pl
import numpy as np
import glob
import pandas as pd
import math
from functools import partial
from argparse import ArgumentParser
from pytorch_lightning.loggers import WandbLogger
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
import wandb
import torchio as tio
from nilearn.image import crop_img, resample_to_img
from torch.optim.lr_scheduler import ReduceLROnPlateau, ExponentialLR
from torchmetrics.functional import accuracy

import warnings
from typing import (
    Callable,
    ClassVar,
    Dict,
    Iterable,
    List,
    NamedTuple,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
    Any
)


In [24]:
def class_imbalance_sampler(labels):
    # if type(labels) != torch.Tensor:
    #     labels = torch.tensor(labels)
    # print(labels.shape)

    class_sample_count = torch.tensor(
        [(labels == t).sum() for t in torch.unique(labels, sorted=True)])

    print(f"Class_count: {class_sample_count}")

    weight = 1. / class_sample_count.float()

    samples_weight = torch.tensor([weight[t] for t in labels])

    # Create sampler, dataset, loader
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight),replacement=True)

    pos_weight = torch.tensor(class_sample_count[1] / (class_sample_count[0] + 1e-5), dtype=torch.float)

    return sampler, weight

class MRIDataModuleIO(pl.LightningDataModule):
    def __init__(self, data_dir: str, labels: List[int], format: str, batch_size: int, augment: List[str],
                 mask: str = '', file_paths: List[str] = None, num_workers: int = 1, sampler: bool=True):
        super().__init__()
        self.data_dir = data_dir
        self.labels = torch.tensor(labels)
        self.format = format
        self.n = len(labels)
        self.n_train = int(.9*self.n)
        self.n_test = self.n - self.n_train
        self.mask = mask
        self.batch_size = batch_size
        self.augment = augment
        self.file_paths = file_paths
        self.num_workers = num_workers

        shuffled_ind = np.random.choice(range(self.n),len(range(self.n)),replace=False)

        self.train_labels = self.labels[shuffled_ind[:self.n_train]]
        self.test_labels = self.labels[shuffled_ind[self.n_train:]]
        self.train_paths = self.file_paths[shuffled_ind[:self.n_train]]
        self.test_paths = self.file_paths[shuffled_ind[self.n_train:]]
        self.sampler = sampler

        # check test distribution
        class_sample_count_test = torch.tensor(
            [(self.test_labels == int(t)).sum() for t in torch.unique(self.labels, sorted=True)])

        class_sample_count_train = torch.tensor(
            [(self.train_labels == int(t)).sum() for t in torch.unique(self.labels, sorted=True)])

        print(f"Class distribution in test set: {class_sample_count_test}")
        print(f"Class distribution in train set: {class_sample_count_train}")


    def get_max_shape(self, subjects):

        preprocess = tio.Compose([
            tio.EnsureShapeMultiple(2)
        ])

        dataset = tio.SubjectsDataset(subjects,transform=preprocess)
        shapes = np.array([s.spatial_shape for s in dataset])
        self.max_shape = shapes.max(axis=0)
        return self.max_shape

    def prepare_data(self):
        image_training_paths = self.train_paths
        label_training = self.train_labels
        image_test_paths = self.test_paths
        label_test = self.test_labels

        self.subjects = []
        for image_path, label in zip(image_training_paths, label_training):
            # 'image' and 'label' are arbitrary names for the images
            subject = tio.Subject(
                image=tio.ScalarImage(image_path),
                label=label
            )
            self.subjects.append(subject)

        self.test_subjects = []
        for image_path,label in  zip(image_test_paths,label_test):
            subject = tio.Subject(image=tio.ScalarImage(image_path),
                                  label=label)
            self.test_subjects.append(subject)

    def get_preprocessing_transform(self):
        preprocess = tio.Compose([
            tio.CropOrPad(self.get_max_shape(self.subjects + self.test_subjects)),
            tio.EnsureShapeMultiple(2),  
            tio.RescaleIntensity((-1, 1)),
        ])
        return preprocess

    def get_augmentation_transform(self):

        if self.augment:
            augment=[]
            for a in self.augment:
                if a == 'affine':
                    augment.append(tio.RandomAffine())
                elif a == 'noise':
                    augment.append(tio.RandomNoise(p=0.3))

                elif a == 'motion':
                    augment.append(tio.RandomMotion(p=0.2))

            augment = tio.Compose(augment)
            return augment
        else:
            return None
        
    def setup(self, stage=None):


        indices = range(self.n_train) #np.random.choice(range(self.n_train), range(self.n_train), replace=False)
        split = int(np.floor(.2 * self.n_train))
        train_indices, val_indices = indices[split:], indices[:split]
        # Creating PT data samplers and loaders:
        if self.sampler is not True:
            self.train_sampler = SubsetRandomSampler(train_indices)
            self.val_sampler = SubsetRandomSampler(val_indices)
        else:

            self.train_sampler, self.pos_weight = class_imbalance_sampler(self.train_labels[train_indices])


        train_subjects = [self.subjects[i] for i in train_indices]
        val_subjects = [self.subjects[i] for i in val_indices]

        self.preprocess = self.get_preprocessing_transform()
        augment = self.get_augmentation_transform()
        if augment is not None:
            self.transform = tio.Compose([self.preprocess, augment])
        else:
            self.transform = self.preprocess

        self.train_set = tio.SubjectsDataset(train_subjects, transform=self.transform)
        self.val_set = tio.SubjectsDataset(val_subjects, transform=self.preprocess)
        self.test_set = tio.SubjectsDataset(self.test_subjects, transform=self.preprocess)


    def train_dataloader(self):
        return DataLoader(self.train_set, self.batch_size, sampler=self.train_sampler, num_workers=self.num_workers, drop_last=True)

    def val_dataloader(self):
        return DataLoader(self.val_set, self.batch_size, num_workers=self.num_workers, drop_last=True)

    def test_dataloader(self):
        return DataLoader(self.test_set, self.batch_size)
    

def get_mri_data_beta(N,data_dir,cropped=False,test=False):

    name = f"data_split_c.csv"
    df = pd.read_csv(os.path.join(data_dir, name))
    labels = []

    dfg = df.groupby("class")
    data = []
    for name, subdata in dfg:
        print(f"Group: {name}")
        # shuffle
        K = subdata.shape[0]
        shuffled_ind = np.random.choice(range(K),len(range(K)),replace=False)
        #subsample
        shuffled_ind = shuffled_ind[:N]
        data.extend(subdata["filename"].values[shuffled_ind])
        labels.extend(subdata["class"].values[shuffled_ind])

    # shuffle
    data = np.array(data).reshape(-1)
    labels=np.array(labels).reshape(-1)
    ind = np.random.choice(len(labels),len(labels),replace=False)

    labels = labels[ind]
    data= data[ind]

    assert data.shape[0] == labels.shape[0]

    return data, labels

In [28]:
class VGG(pl.LightningModule):
    def __init__(self, features, **conf):
        super().__init__()
        self.save_hyperparameters(ignore=["features"])
        self.features = features
        self.avgpool = nn.AdaptiveAvgPool3d((7,7,7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, self.hparams.num_classes),
        )
        #if init_weights:
        self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
                
    def training_step(self, batch, batch_idx):
        x = batch['image'][tio.DATA]
        y = batch['label']
        raw_out = self(x)
        loss = self.loss(raw_out, y)
        preds = torch.argmax(torch.softmax(raw_out, dim=1), dim=1)
        acc = accuracy(preds, y)

        #print(f"Train Loss: {loss}")
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_acc', acc)

        return loss


    def evaluate(self, batch, stage=None):
        x = batch['image'][tio.DATA]
        y = batch['label']
        raw_out = self(x)
        loss = self.loss(raw_out, y)
        preds = torch.argmax(torch.softmax(raw_out, dim=1), dim=1)
        acc = accuracy(preds, y)

        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)


    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    def configure_optimizers(self):
        # self.hparams available because we called self.save_hyperparameters()
        if self.hparams.optim == 'adam':
            optim = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate,weight_decay=1e-4)
        elif self.hparams.optim == 'sgd':
            optim = torch.optim.SGD(self.parameters(), lr=self.hparams.learning_rate, weight_decay=1e-4)
        elif  self.hparams.optim == 'adamw':
            optim = torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate)

        else:
            raise NotImplementedError

        return {
            "optimizer": optim,
            "lr_scheduler": {
                "scheduler": ExponentialLR(optim, gamma=0.9), #ReduceLROnPlateau(optim, ...),
                "monitor": "valid_loss",
            },
        }


    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = parent_parser.add_argument_group("VGGNet")
        parser.add_argument('--learning_rate', type=float, default=0.001)
        parser.add_argument('--dropout', type=float, default=0.5)
        parser.add_argument('--name', type=str, default='vggnet')
        parser.add_argument('--optim', type=str, default='adam')
        return parent_parser        

In [9]:
def make_layers(cfg, batch_norm=False):
    layers = []
    in_channels = 1
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool3d(kernel_size=2, stride=2)]
        else:
            conv3d = nn.Conv3d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv3d, nn.BatchNorm3d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv3d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)


In [19]:
cfgs = {
    'A': [8, 'M', 16, 'M', 32, 32, 'M', 64, 64, 'M'],
    #'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


In [29]:
def _vgg(arch, cfg, batch_norm, pretrained, progress, hparams):
    if pretrained:
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **hparams)
    if pretrained:
        pass
       #state_dict = load_state_dict_from_url(model_urls[arch],
                #                              progress=progress)
        #model.load_state_dict(state_dict)
    return model

def vgg11(pretrained=False, progress=True, hparams={}):
    r"""VGG 11-layer model (configuration "A") from
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_

    Args:
       	pretrained (bool): If True, returns a model pre-trained on ImageNet
       	progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _vgg('vgg11', 'A', False, pretrained, progress, hparams)


In [30]:
args = {'data_dir': '/scratch/spinney/enigma_drug/data',
       'batch_size': 16,
       'num_classes': 5,
       'num_workers': 4,
       'num_samples': -1,
       'seed': 0,
       'format': 'nifti',
       'test': False,
       'cropped': True,
       'augment': None,
       'learning_rate': 0.001,
       'dropout': 0.5,
       'name': 'vggnet',
       'optim': 'adam',
       'max_epochs': 20}

# set global seed
pl.seed_everything(args['seed'])

mask = ''

if args['num_samples'] == -1:
    args['num_samples'] = -1*args['num_classes']

# these are returned shuffled
file_paths, labels = get_mri_data_beta(args['num_samples']//args['num_classes'], args['data_dir'], cropped=args['cropped'], test=False)

dm = MRIDataModuleIO(args['data_dir'], labels, args['format'], args['batch_size'], args['augment'], mask, file_paths, args['num_workers'])
dm.prepare_data()
dm.setup(stage='fit')

print(f"Input shape used: {dm.max_shape}")
dict_args = args
dict_args['pos_weight'] = dm.pos_weight
dict_args['input_shape'] = dm.max_shape
dict_args['class_names'] = ["control","ALC","ATS","COC","NIC"] 
slurm = os.environ.get("SLURM_JOB_NUM_NODES")
num_nodes = int(slurm) if slurm else 1

Global seed set to 0


Group: 0
Group: 1
Group: 2
Group: 3
Group: 4
Class distribution in test set: tensor([ 61, 105,  11,  13,  40])
Class distribution in train set: tensor([582, 892, 165, 120, 305])
Class_count: tensor([462, 729, 139,  92, 230])




Input shape used: [172 176 218]


In [34]:
dm.pos_weight
model = vgg11(pretrained=False, progress=True, hparams=dict_args)

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

In [26]:


trainer = pl.Trainer(default_root_dir="/scratch/spinney/enigma_drug/checkpoints/",
                     gpus=torch.cuda.device_count(),
                     num_nodes=num_nodes,
                     strategy='ddp' if num_nodes > 1 else 'dp',
                     max_epochs=args['max_epochs'],
                     log_every_n_steps=10,
#                     logger=wandb_logger,
                     replace_sampler_ddp=False)#,
                     #precision=16)
                     #early_stop_callback=False)
                     #callbacks=[early_stopping_callback])

trainer.fit(model, dm)

# ------------
# testing
# ------------

dm.setup(stage='test')
trainer.test(datamodule=dm)

NameError: name 'kwargs' is not defined