In [1]:
import pandas as pd
import numpy as np
import torch 
import tqdm 
import glob, os, pickle
import time
import seaborn as sns
import sys
import copy
import tifffile
import torch.nn as nn
import matplotlib
import fire, math
from matplotlib import pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets as Datasets, models
from torch.nn import functional as F
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from PIL import Image, ImageFile
from sklearn.metrics import classification_report, f1_score
from sklearn.utils.class_weight import compute_class_weight
from kornia.losses import DiceLoss
from torch.optim.lr_scheduler import ExponentialLR


In [2]:
def load_image(image_file, check_size=False, mmap_mode=None):
    img_ext=os.path.splitext(image_file)
    if img_ext[-1]==".npy":
        image=np.load(image_file, mmap_mode=mmap_mode)
    elif img_ext[-1] in [".svs",".tif",".tiff",".png"]:
        if check_size:
            import openslide
            slide=openslide.open_slide(image_file)
        image=tifffile.imread(image_file, aszarr=mmap_mode is not None)
        if mmap_mode is not None:
            import zarr
            image=zarr.open(image, mode=mmap_mode)
        if check_size and (not (int(slide.properties.get('aperio.AppMag',40))==20 or int(slide.properties.get('openslide.objective-power',40))==20)):
            image = cv2.resize(image,None,fx=1/2,fy=1/2,interpolation=cv2.INTER_CUBIC)
    else:
        raise NotImplementedError
    return 

In [3]:
class PickleDataset(Dataset):
    def __init__(self, pkl, transform, label_map):
        self.data=pickle.load(open(pkl,'rb'))
        self.X,self.targets=self.data['X'],self.data['y']
        self.aux_data=self.data.get("z",None)
        self.has_aux=(self.aux_data is not None)
        if self.has_aux and isinstance(self.aux_data,pd.DataFrame): self.aux_data=self.aux_data.values
        if self.has_aux: self.n_aux_features=self.aux_data.shape[1]
        self.transform=transform
        self.to_pil=lambda x: Image.fromarray(x)
        self.label_map=label_map
        if self.label_map:
            self.targets=pd.Series(self.targets).map(lambda x: self.label_map.get(x,-1)).values
            if -1 in self.targets:
                remove_bool=(self.targets!=-1)
                self.targets=self.targets[remove_bool]
                self.X=pd.Series(self.X).iloc[remove_bool].tolist()
                if self.has_aux: self.aux_data=self.aux_data[remove_bool]
        self.length=len(self.X)

    def __getitem__(self,idx):
        items=(self.transform(self.to_pil(self.X[idx])), torch.tensor(self.targets[idx]).long())
        if self.has_aux: items+=(torch.tensor(self.aux_data[idx]).float(),)
        return items

    def __len__(self):
        return self.length

class NPYRotatingStack(Dataset):
    def __init__(self, patch_dir, transform, sample_frac=1., sample_every=0, target_col={'old_y_true':'y_true'},npy_rotate_sets_pkl="",Set=""):
        self.npy_rotate_sets_pkl=npy_rotate_sets_pkl
        if npy_rotate_sets_pkl:
            self.patch_npy=pd.read_pickle(self.npy_rotate_sets_pkl)
            self.patch_pkl=self.patch_npy[self.patch_npy['Set']==Set]['pkl'].values
            self.patch_npy=self.patch_npy[self.patch_npy['Set']==Set]['npy'].values
        else:
            self.patch_npy=np.array(glob.glob(os.path.join(patch_dir,"*.npy")))
            self.patch_pkl=np.vectorize(lambda x: x.replace(".npy",".pkl"))(self.patch_npy)
        self.sample_every=sample_every
        self.sample_frac=sample_frac
        if self.sample_frac==1: self.sample_every=0
        self.target_col=list(target_col.items())[0]
        self.ref_index=None # dictionary
        self.data={}
        self.cache_npy=None # dictionary keys
        self.to_pil=lambda x: Image.fromarray(x)
        self.transform=transform
        assert self.target_col[1]=='y_true'
        self.targets=np.hstack([pd.read_pickle(pkl)[self.target_col[0]].values for pkl in self.patch_pkl])
        self.load_image_annot()

    def load_image_annot(self):
        if self.sample_frac<1.:
            idx=np.arange(len(self.patch_npy))
            idx=np.random.choice(idx,int(self.sample_frac*len(idx)))
            patch_npy=self.patch_npy[idx]
            patch_pkl=self.patch_pkl[idx]
            remove_npy=np.setdiff1d(self.patch_npy,patch_npy)
            for npy in remove_npy:
                if isinstance(self.cache_npy,type(None))==False and npy not in self.cache_npy:
                    del self.data[npy]
            new_data={npy:(dict(patches=load_image(npy),
                               patch_info=pd.read_pickle(pkl)) if (self.cache_npy is None or (npy not in self.cache_npy if self.cache_npy is not None else False)) else self.data[npy]) for npy,pkl in zip(patch_npy,patch_pkl)}
            self.data.clear()
            self.data=new_data
            self.cache_npy=sorted(list(self.data.keys()))
        else:
            self.data={npy:dict(patches=load_image(npy),
                               patch_info=pd.read_pickle(pkl)) for npy,pkl in zip(self.patch_npy,self.patch_pkl)}
            self.cache_npy=sorted(self.patch_npy)
        self.ref_index=np.vstack([np.array(([i]*self.data[npy]['patch_info'].shape[0],list(range(self.data[npy]['patch_info'].shape[0])))).T for i,npy in enumerate(self.cache_npy)])
        for npy in self.data: self.data[npy]['patch_info'][self.target_col[1]]=self.data[npy]['patch_info'][self.target_col[0]]
        self.length=self.ref_index.shape[0]

    def __getitem__(self,idx):
        i,j=self.ref_index[idx]
        npy=self.cache_npy[i]
        X=self.data[npy]['patches'][j]
        y=torch.LongTensor(np.array(self.data[npy]['patch_info'].iloc[j][self.target_col[1]]).reshape(1))
        X=self.transform(self.to_pil(X))
        return X, y

    def __len__(self):
        return self.length

In [4]:
class Scheduler:
    """Scheduler class that modulates learning rate of torch optimizers over epochs.
    Parameters
    ----------
    optimizer : type
            torch.Optimizer object
    opts : type
            Options of setting the learning rate scheduler, see default.
    Attributes
    ----------
    schedulers : type
            Different types of schedulers to choose from.
    scheduler_step_fn : type
            How scheduler updates learning rate.
    initial_lr : type
            Initial set learning rate.
    scheduler_choice : type
            What scheduler type was chosen.
    scheduler : type
            Scheduler object chosen that will more directly update optimizer LR.
    """

    def __init__(self, optimizer=None, opts=dict(scheduler='null', lr_scheduler_decay=0.5, T_max=10, eta_min=5e-8, T_mult=2)):
        self.schedulers = {'exp': (lambda optimizer: ExponentialLR(optimizer, opts["lr_scheduler_decay"])),
                           'null': (lambda optimizer: None),
                           'warm_restarts': (lambda optimizer: CosineAnnealingWithRestartsLR(optimizer, T_max=opts['T_max'], eta_min=opts['eta_min'], last_epoch=-1, T_mult=opts['T_mult']))}
        self.scheduler_step_fn = {'exp': (lambda scheduler: scheduler.step()),
                                  'warm_restarts': (lambda scheduler: scheduler.step()),
                                  'null': (lambda scheduler: None)}
        self.initial_lr = optimizer.param_groups[0]['lr']
        self.scheduler_choice = opts['scheduler']
        self.scheduler = self.schedulers[self.scheduler_choice](
            optimizer) if optimizer is not None else None

    def step(self):
        """Update optimizer learning rate"""
        self.scheduler_step_fn[self.scheduler_choice](self.scheduler)

    def get_lr(self):
        """Return current learning rate.
Returns
-------
float
    Current learning rate.
"""
        lr = (self.initial_lr if self.scheduler_choice
              == 'null' else self.scheduler.optimizer.param_groups[0]['lr'])
        return lr

In [5]:
class CosineAnnealingWithRestartsLR(torch.optim.lr_scheduler._LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr and
    :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
     .. math::
             \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +
            \cos(\frac{T_{cur}}{T_{max}}\pi))
     When last_epoch=-1, sets initial lr as lr.
     It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. This implements
    the cosine annealing part of SGDR, the restarts and number of iterations multiplier.
     Args:
            optimizer (Optimizer): Wrapped optimizer.
            T_max (int): Maximum number of iterations.
            T_mult (float): Multiply T_max by this number after each restart. Default: 1.
            eta_min (float): Minimum learning rate. Default: 0.
            last_epoch (int): The index of last epoch. Default: -1.
     .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
            https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1, T_mult=1., alpha_decay=1.0):
        self.T_max = T_max
        self.T_mult = T_mult
        self.restart_every = T_max
        self.eta_min = eta_min
        self.restarts = 0
        self.restarted_at = 0
        self.alpha = alpha_decay
        super().__init__(optimizer, last_epoch)

    def restart(self):
        self.restarts += 1
        self.restart_every = int(round(self.restart_every * self.T_mult))
        self.restarted_at = self.last_epoch

    def cosine(self, base_lr):
        return self.eta_min + self.alpha**self.restarts * (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.step_n() / self.restart_every)) / 2

    def step_n(self):
        return self.last_epoch - self.restarted_at

    def get_lr(self):
        if self.step_n() >= self.restart_every:
            self.restart()
        return [self.cosine(base_lr) for base_lr in self.base_lrs]



In [6]:
class AuxNet(nn.Module):
    def __init__(self,net,n_aux_features):
        super().__init__()
        self.net=net
        self.features=self.net.features
        self.output=self.net.output
        self.n_features=self.net.output.in_features
        self.n_aux_features=n_aux_features
        self.transform_nn=nn.Sequential(nn.Linear(self.n_aux_features,self.n_features),nn.LeakyReLU())
        self.gate_nn=MLP(self.n_features,[32],dropout_p=0.2,binary=False)#nn.Linear(self.n_features,1)

    def forward(self,x,z=None):
        x=self.features(x)
        x = x.view(x.size(0), -1)
        if z is not None:
            z=self.transform_nn(z)
            #print(x.shape,z.shape,self.gate_nn(x).shape,self.gate_nn(z).shape)
            gate_h=F.softmax(torch.cat([self.gate_nn(xz) for xz in [x,z]],1),1)
            x = gate_h[:,0].unsqueeze(1) * x + gate_h[:,1].unsqueeze(1) * z
        x = self.output(x)
        return x

def prepare_model(model_name,
                  use_pretrained,
                  pretrained_model_file_path,
                  use_cuda=False,
                  use_data_parallel=True,
                  load_ignore_extra=False,
                  num_classes=3,
                  in_channels=3,
                  remap_to_cpu=True,
                  remove_module=False,
                  n_aux_features=None):
    from pytorchcv.model_provider import get_model
    """ https://raw.githubusercontent.com/osmr/imgclsmob/master/pytorch/utils.py
        Create and initialize model by name.
        Parameters
        ----------
        model_name : str
        Model name.
        use_pretrained : bool
        Whether to use pretrained weights.
        pretrained_model_file_path : str
        Path to file with pretrained weights.
        use_cuda : bool
        Whether to use CUDA.
        use_data_parallel : bool, default True
        Whether to use parallelization.
        net_extra_kwargs : dict, default None
        Extra parameters for model.
        load_ignore_extra : bool, default False
        Whether to ignore extra layers in pretrained model.
        num_classes : int, default None
        Number of classes.
        in_channels : int, default None
        Number of input channels.
        remap_to_cpu : bool, default False
        Whether to remape model to CPU during loading.
        remove_module : bool, default False
        Whether to remove module from loaded model.
        Returns
        -------
        Module
        Model.
        """
  
    net = get_model(model_name)

    if n_aux_features is not None:
        net=AuxNet(net,n_aux_features)
    
    return net


def generate_model(architecture, num_classes, pretrained=False, n_aux_features=None):
    #    from pytorchcv.pytorch.utils import prepare_model
    if os.path.exists(architecture):
        model = torch.load(architecture,map_location='cpu')
    else:
        model = prepare_model(architecture,
                          use_pretrained=pretrained,
                          pretrained_model_file_path='',
                          use_cuda=False,
                          num_classes=num_classes,
                          n_aux_features=n_aux_features)
    
    
    thresh = 3
    ct = 0
    #here we freeze up to and including the 4th layer
    for child in model.children():
        if ct <= thresh:
            for param in child.parameters():
                param.requires_grad = False
            print(child, ct)
            ct += 1
    return model

class ModelTrainer:
    """Trainer for the neural network model that wraps it into a scikit-learn like interface.
    Parameters
    ----------
    model:nn.Module
            Deep learning pytorch model.
    n_epoch:int
            Number training epochs.
    validation_dataloader:DataLoader
            Dataloader of validation dataset.
    optimizer_opts:dict
            Options for optimizer.
    scheduler_opts:dict
            Options for learning rate scheduler.
    loss_fn:str
            String to call a particular loss function for model.
    reduction:str
            Mean or sum reduction of loss.
    num_train_batches:int
            Number of training batches for epoch.
    """

    def __init__(self, model, n_epoch=300, validation_dataloader=None, optimizer_opts=dict(name='adam', lr=1e-3, weight_decay=1e-4), scheduler_opts=dict(scheduler='warm_restarts', lr_scheduler_decay=0.5, T_max=10, eta_min=5e-8, T_mult=2), loss_fn='ce', reduction='mean', num_train_batches=None, opt_level='O1', checkpoints_dir='checkpoints',tensor_dataset=False,transforms=None,save_metric='loss',save_after_n_batch=0):

        self.model = model
        # self.amp_handle = amp.init(enabled=True)
        optimizers = {'adam': torch.optim.Adam, 'sgd': torch.optim.SGD}
        loss_functions = {'bce': nn.BCEWithLogitsLoss(reduction=reduction), 'ce': nn.CrossEntropyLoss(
            reduction=reduction), 'mse': nn.MSELoss(reduction=reduction), 'nll': nn.NLLLoss(reduction=reduction),'dice':DiceLoss()}
        
        if 'name' not in list(optimizer_opts.keys()):
            optimizer_opts['name'] = 'adam'
        
        self.optimizer = optimizers[optimizer_opts.pop('name')](
            self.model.parameters(), **optimizer_opts)
        
        if False and torch.cuda.is_available():
            self.cuda = True
            self.model, self.optimizer = amp.initialize(
                self.model, self.optimizer, opt_level=opt_level)
        else:
            self.cuda = False
        self.scheduler = Scheduler(
            optimizer=self.optimizer, opts=scheduler_opts)
        self.n_epoch = n_epoch
        self.validation_dataloader = validation_dataloader
        self.loss_fn = loss_functions[loss_fn]
        self.loss_fn_name = loss_fn
        self.bce = (self.loss_fn_name == 'bce')
        self.sigmoid = nn.Sigmoid()
        self.original_loss_fn = copy.deepcopy(loss_functions[loss_fn])
        self.num_train_batches = num_train_batches
        self.val_loss_fn = copy.deepcopy(loss_functions[loss_fn])
        self.verbosity=0
        self.checkpoints_dir=checkpoints_dir
        self.tensor_dataset=tensor_dataset
        self.transforms=transforms
        self.save_metric=save_metric
        self.save_after_n_batch=save_after_n_batch
        self.train_batch_count=0
        self.initial_seed=0
        self.seed=0

    def save_checkpoint(self,model,epoch,batch=0):
        os.makedirs(self.checkpoints_dir,exist_ok=True)
        out_name = f"{batch}.batch" if batch else f"{epoch}.epoch"
        torch.save(model,os.path.join(self.checkpoints_dir,f"{out_name}.checkpoint.pth"))

    def calc_loss(self, y_pred, y_true):
        """Calculates loss supplied in init statement and modified by reweighting.
        Parameters
        ----------
        y_pred:tensor
                Predictions.
        y_true:tensor
                True values.
        Returns
        -------
        loss
        """

        return self.loss_fn(y_pred, y_true)

    def calc_val_loss(self, y_pred, y_true):
        """Calculates loss supplied in init statement on validation set.
        Parameters
        ----------
        y_pred:tensor
                Predictions.
        y_true:tensor
                True values.
        Returns
        -------
        val_loss
        """

        return self.val_loss_fn(y_pred, y_true)

    def reset_loss_fn(self):
        """Resets loss to original specified loss."""
        self.loss_fn = self.original_loss_fn

    def add_class_balance_loss(self, y, custom_weights=''):
        """Updates loss function to handle class imbalance by weighting inverse to class appearance.
        Parameters
        ----------
        dataset:DynamicImageDataset
                Dataset to balance by.
        """
        self.class_weights = compute_class_weight(class_weight='balanced',classes=np.unique(y),y=y)#dataset.get_class_weights() if not custom_weights else np.array(
            #list(map(float, custom_weights.split(','))))
        if custom_weights:
            self.class_weights = self.class_weights / sum(self.class_weights)
        print('Weights:', self.class_weights)
        self.original_loss_fn = copy.deepcopy(self.loss_fn)
        weight = torch.tensor(self.class_weights, dtype=torch.float)
        if torch.cuda.is_available():
            weight = weight.cuda()
        if self.loss_fn_name == 'ce':
            self.loss_fn = nn.CrossEntropyLoss(weight=weight)
        elif self.loss_fn_name == 'nll':
            self.loss_fn = nn.NLLLoss(weight=weight)
        else:  # modify below for multi-target
            self.loss_fn = lambda y_pred, y_true: sum([self.class_weights[i] * self.original_loss_fn(
                y_pred[y_true == i], y_true[y_true == i]) if sum(y_true == i) else 0. for i in range(3)])

    def calc_best_confusion(self, y_pred, y_true):
        """Calculate confusion matrix on validation set for classification/segmentation tasks, optimize threshold where positive.
        Parameters
        ----------
        y_pred:array
                Predictions.
        y_true:array
                Ground truth.
        Returns
        -------
        float
                Optimized threshold to use on test set.
        dataframe
                Confusion matrix.
        """
        fpr, tpr, thresholds = roc_curve(y_true, y_pred)
        threshold = thresholds[np.argmin(
            np.sum((np.array([0, 1]) - np.vstack((fpr, tpr)).T)**2, axis=1)**.5)]
        y_pred = (y_pred > threshold).astype(int)
        return threshold, pd.DataFrame(confusion_matrix(y_true, y_pred), index=['F', 'T'], columns=['-', '+']).iloc[::-1, ::-1].T

    def loss_backward(self, loss):
        """Backprop using mixed precision for added speed boost.
        Parameters
        ----------
        loss:loss
                Torch loss calculated.
        """
        # with self.amp_handle.scale_loss(loss, self.optimizer) as scaled_loss:
        # 	scaled_loss.backward()
        # loss.backward()
        if self.cuda:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

    # @pysnooper.snoop()
    def train_loop(self, epoch, train_dataloader):
        """One training epoch, calculate predictions, loss, backpropagate.
        Parameters
        ----------
        epoch:int
                Current epoch.
        train_dataloader:DataLoader
                Training data.
        Returns
        -------
        float
                Training loss for epoch
        """
        self.model.train(True)
        running_loss = 0.
        n_batch = len(
            train_dataloader.dataset) // train_dataloader.batch_size if self.num_train_batches == None else self.num_train_batches
        for i, batch in enumerate(train_dataloader):
            starttime = time.time()
            X, y_true = batch[:2]
            if len(batch)==3: Z=batch[2]
            else: Z=None

            if i == n_batch:
                break

            # X = Variable(batch[0], requires_grad=True)
            # y_true = Variable(batch[1])

            if torch.cuda.is_available():
                X = X.cuda()
                y_true = y_true.cuda()
                if Z is not None: Z=Z.cuda()


            y_pred = self.model(X) if Z is None else self.model(X,Z)
            # y_true=y_true.argmax(dim=1)

            loss = self.calc_loss(y_pred, y_true.flatten())  # .view(-1,1)
            train_loss = loss.item()
            running_loss += train_loss
            self.optimizer.zero_grad()
            self.loss_backward(loss)  # loss.backward()
            self.optimizer.step()
            torch.cuda.empty_cache()
            endtime = time.time()
            if self.verbosity >=1:
                print("Epoch {}[{}/{}] Time:{}, Train Loss:{}".format(epoch,
                                                                  i, n_batch, round(endtime - starttime, 3), train_loss))
            self.train_batch_count+=1
            if self.save_after_n_batch and self.train_batch_count%self.save_after_n_batch==0:
                val_loss,val_f1=self.val_loop(epoch, self.val_dataloader)
                self.batch_val_losses.append(val_loss)
                self.batch_val_f1.append(val_f1)
                self.save_best_val_model(val_loss, val_f1, self.batch_val_losses, self.batch_val_f1, epoch, True, self.train_batch_count)
                self.model.train(True)

        self.scheduler.step()
        running_loss /= n_batch
        return running_loss

    def val_loop(self, epoch, val_dataloader, print_val_confusion=True, save_predictions=True):
        """Calculate loss over validation set.
        Parameters
        ----------
        epoch:int
                Current epoch.
        val_dataloader:DataLoader
                Validation iterator.
        print_val_confusion:bool
                Calculate confusion matrix and plot.
        save_predictions:int
                Print validation results.
        Returns
        -------
        float
                Validation loss for epoch.
        """
        self.model.train(False)
        n_batch = len(val_dataloader.dataset) // val_dataloader.batch_size
        running_loss = 0.
        Y = {'pred': [], 'true': []}
        with torch.no_grad():
            for i, batch in enumerate(val_dataloader):
                # X = Variable(batch[0], requires_grad=True)
                # y_true = Variable(batch[1])
                X, y_true = batch[:2]
                if len(batch)==3: Z=batch[2]
                else: Z=None
                if torch.cuda.is_available():
                    X = X.cuda()
                    y_true = y_true.cuda()
                    if Z is not None: Z=Z.cuda()


                y_pred = self.model(X) if Z is None else self.model(X,Z)
                # y_true=y_true.argmax(dim=1)
                # if save_predictions:
                Y['true'].append(
                    y_true.detach().cpu().numpy().astype(int).flatten())
                y_pred_numpy = ((y_pred if not self.bce else self.sigmoid(y_pred)).detach().cpu().numpy()).astype(float)
                if self.loss_fn_name in ['ce','dice']:
                    y_pred_numpy = y_pred_numpy.argmax(axis=1)
                Y['pred'].append(y_pred_numpy.flatten())

                loss = nn.CrossEntropyLoss(y_pred,y_true.flatten())#.view(-1,1)
                val_loss = loss.item()
                running_loss += val_loss
                if self.verbosity >=1:
                    print("Epoch {}[{}/{}] Val Loss:{}".format(epoch, i, n_batch, val_loss))
        # if print_val_confusion and save_predictions:
        y_pred, y_true = np.hstack(Y['pred']).flatten(), np.hstack(Y['true']).flatten()
        print(classification_report(y_true, y_pred))
        running_loss /= n_batch
        return running_loss, f1_score(y_true, y_pred,average='macro')

    # @pysnooper.snoop("test_loop.log")
    def test_loop(self, test_dataloader):
        """Calculate final predictions on loss.
        Parameters
        ----------
        test_dataloader:DataLoader
                Test dataset.
        Returns
        -------
        array
                Predictions or embeddings.
        """
        # self.model.train(False) KEEP DROPOUT? and BATCH NORM??
        self.model.eval()
        y_pred = []
        Y_true = []
        running_loss = 0.
        n_batch = len(
            test_dataloader.dataset) // test_dataloader.batch_size
        with torch.no_grad():
            for i, batch in tqdm.tqdm(enumerate(test_dataloader),total=n_batch):
                #X = Variable(batch[0],requires_grad=False)
                X, y_true = batch[:2]
                if len(batch)==3: Z=batch[2]
                else: Z=None
                if torch.cuda.is_available():
                    X = X.cuda()
                    y_true = y_true.cuda()
                    if Z is not None: Z=Z.cuda()

                prediction = self.model(X) if Z is None else self.model(X,Z)
                y_pred.append(prediction.detach().cpu().numpy())
                Y_true.append(y_true.detach().cpu().numpy())
        y_pred = np.concatenate(y_pred, axis=0)  # torch.cat(y_pred,0)
        y_true = np.concatenate(Y_true, axis=0).flatten()
        return y_pred,y_true

    def fit(self, train_dataloader, verbose=False, print_every=10, save_model=True, plot_training_curves=True, plot_save_file='/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Students/users/Sophie_Chen/', print_val_confusion=True, save_val_predictions=True):
        """Fits the segmentation or classification model to the patches, saving the model with the lowest validation score.
        Parameters
        ----------
        train_dataloader:DataLoader
                Training dataset.
        verbose:bool
                Print training and validation loss?
        print_every:int
                Number of epochs until print?
        save_model:bool
                Whether to save model when reaching lowest validation loss.
        plot_training_curves:bool
                Plot training curves over epochs.
        plot_save_file:str
                File to save training curves.
        print_val_confusion:bool
                Print validation confusion matrix.
        save_val_predictions:bool
                Print validation results.
        Returns
        -------
        self
                Trainer.
        float
                Minimum val loss.
        int
                Best validation epoch with lowest loss.
        """
        # choose model with best f1
        self.train_losses = []
        self.val_losses = []
        self.val_f1 = []
        self.batch_val_losses = []
        self.batch_val_f1 = []
        if verbose:
            self.verbosity+=1
        for epoch in range(self.n_epoch):
            self.seed=self.initial_seed+epoch
            np.random.seed(self.seed)
            start_time = time.time()
            train_loss = self.train_loop(epoch, train_dataloader)
            current_time = time.time()
            train_time = current_time - start_time
            self.train_losses.append(train_loss)
            val_loss, val_f1 = self.val_loop(epoch, self.validation_dataloader,
                                     print_val_confusion=print_val_confusion, save_predictions=save_val_predictions)
            val_time = time.time() - current_time
            self.val_losses.append(val_loss)
            self.val_f1.append(val_f1)
            self.batch_val_losses.append(val_loss)
            self.batch_val_f1.append(val_f1)
            # if True:#verbose and not (epoch % print_every):
            if plot_training_curves:
                self.plot_train_val_curves(plot_save_file)
            print("Epoch {}: Train Loss {}, Val Loss {}, Train Time {}, Val Time {}".format(
                epoch, train_loss, val_loss, train_time, val_time))
            print('Training complete in {:.0f}m {:.0f}s'.format(train_time // 60, train_time % 60))
            
            self.save_best_val_model(val_loss, val_f1, self.val_losses, self.val_f1, epoch, save_model)
            if "save_every" in dir(train_dataloader.dataset) and train_dataloader.dataset.save_every and epoch%train_dataloader.dataset.save_every==0:
                train_dataloader.dataset.load_image_annot()
        if save_model:
            print("Saving best model at epoch {}".format(self.best_epoch))
            self.model.load_state_dict(self.best_model_state_dict)
        return self, self.min_val_loss_f1, self.best_epoch

    def save_best_val_model(self, val_loss, val_f1, val_loss_list, val_f1_list, epoch, save_model=True, batch=0):
        if (val_loss <= min(val_loss_list) if self.save_metric=='loss' else val_f1 >= max(val_f1_list)) and save_model:
            print("New best model at epoch {}".format(epoch))
            self.min_val_loss_f1 = val_loss if self.save_metric=='loss' else val_f1
            self.best_epoch = epoch
            if batch: self.best_batch = batch
            self.best_model_state_dict = copy.deepcopy(self.model.state_dict())
            self.save_checkpoint(self.best_model_state_dict,epoch,batch)

    def plot_train_val_curves(self, save_file='/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Students/users/Sophie_Chen/'):
        """Plots training and validation curves.
        Parameters
        ----------
        save_file:str
                File to save to.
        """
        plt.figure()
        sns.lineplot('epoch', 'value', hue='variable',
                     data=pd.DataFrame(np.vstack((np.arange(len(self.train_losses)), self.train_losses, self.val_losses)).T,
                                       columns=['epoch', 'train', 'val']).melt(id_vars=['epoch'], value_vars=['train', 'val']))
        if save_file is not None:
            plt.savefig(save_file, dpi=300)

    def predict(self, test_dataloader):
        """Make classification segmentation predictions on testing data.
        Parameters
        ----------
        test_dataloader:DataLoader
                Test data.
        Returns
        -------
        array
                Predictions.
        """
        y_pred,y_true = self.test_loop(test_dataloader)
        return y_pred,y_true

    def fit_predict(self, train_dataloader, test_dataloader):
        """Fit model to training data and make classification segmentation predictions on testing data.
        Parameters
        ----------
        train_dataloader:DataLoader
                Train data.
        test_dataloader:DataLoader
                Test data.
        Returns
        -------
        array
                Predictions.
        """
        return self.fit(train_dataloader)[0].predict(test_dataloader)

    def return_model(self):
        """Returns pytorch model.
        """
        return self.model

In [7]:
class Reshape(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self,x):
        return x.view(x.shape[0],-1)

def worker_init_fn(worker_id):
    np.random.seed(np.random.get_state()[1][0] + worker_id)

def generate_transformers(image_size=224, resize=256, mean=[], std=[], include_jitter=False):
    train_transform = [transforms.Resize((resize,resize))]
    if include_jitter:
        train_transform.append(transforms.ColorJitter(brightness=0.4,
                                            contrast=0.4, saturation=0.4, hue=0.1))
    train_transform.extend([transforms.RandomHorizontalFlip(p=0.5),
           transforms.RandomVerticalFlip(p=0.5),
           transforms.RandomRotation(90),
           transforms.RandomResizedCrop((image_size,image_size)),
           transforms.ToTensor(),
           transforms.Normalize(mean if mean else [0.5, 0.5, 0.5],
                                std if std else [0.1, 0.1, 0.1])
           ])
    train_transform=transforms.Compose(train_transform)
    val_transform = transforms.Compose([
        transforms.Resize((resize,resize)),
        transforms.CenterCrop((image_size,image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean if mean else [0.5, 0.5, 0.5],
                             std if std else [0.1, 0.1, 0.1])
    ])
    normalization_transform = transforms.Compose([transforms.Resize((resize,resize)),
                                                  transforms.CenterCrop(
                                                      (image_size,image_size)),
                                                  transforms.ToTensor()])
    return {'train': train_transform, 'val': val_transform, 'test': val_transform, 'norm': normalization_transform}

def generate_kornia_transforms(image_size=224, resize=256, mean=[], std=[], include_jitter=False):
    mean=torch.tensor(mean) if mean else torch.tensor([0.5, 0.5, 0.5])
    std=torch.tensor(std) if std else torch.tensor([0.1, 0.1, 0.1])
    if torch.cuda.is_available():
        mean=mean.cuda()
        std=std.cuda()
    train_transforms=[G.Resize((resize,resize))]
    if include_jitter:
        train_transforms.append(K.ColorJitter(brightness=0.4, contrast=0.4,
                                   saturation=0.4, hue=0.1))
    train_transforms.extend([K.RandomHorizontalFlip(p=0.5),
           K.RandomVerticalFlip(p=0.5),
           K.RandomRotation(90),
           K.RandomResizedCrop((image_size,image_size)),
           K.Normalize(mean,std)
           ])
    val_transforms=[G.Resize((resize,resize)),
           K.CenterCrop((image_size,image_size)),
           K.Normalize(mean,std)
           ]
    transforms=dict(train=nn.Sequential(*train_transforms),
                val=nn.Sequential(*val_transforms))
    if torch.cuda.is_available():
        for k in transforms:
            transforms[k]=transforms[k].cuda()
    return transforms


In [8]:
def train_model(inputs_dir="/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Students/users/Sophie_Chen/cnn_model_input/",
                learning_rate=1e-4,
                n_epochs=300,
                crop_size=224,
                resize=256,
                mean=[0.5, 0.5, 0.5],
                std=[0.1, 0.1, 0.1],
                num_classes=3,
                architecture='resnet50',
                batch_size=32,
                predict=False,
                model_save_loc='saved_model.pkl',
                pretrained_save_loc='pretrained_model.pkl',
                predictions_save_path='predictions.pkl',
                predict_set='test',
                verbose=False,
                class_balance=True,
                extract_embeddings="",
                extract_embeddings_df="",
                embedding_out_dir="./",
                gpu_id=-1,
                checkpoints_dir="checkpoints",
                pickle_dataset=True,
                label_map=dict(),
                save_metric="loss",
                custom_dataset=None,
                save_predictions=True,
                pretrained=False,
                save_after_n_batch=0,
                include_test_set=False,
                use_npy_rotate=False,
                sample_frac=1.,
                sample_every=0,
                num_workers=0,
                npy_rotate_sets_pkl="",
                visualize_predictions=True,
                ):
    assert save_metric in ['loss','f1']
    if use_npy_rotate: tensor_dataset,pickle_dataset=False,False
    else: sample_every=0
    if predict: include_test_set=True
    if predict: assert not use_npy_rotate
    if extract_embeddings: assert predict, "Must be in prediction mode to extract embeddings"
    if gpu_id>=0: torch.cuda.set_device(gpu_id)
    transformers=generate_transformers 
    transformers = transformers(
        image_size=crop_size, resize=resize, mean=mean, std=std)
    if custom_dataset is not None:
        assert predict
        datasets={}
        datasets['custom']=custom_dataset
        predict_set='custom'
    else:
        if pickle_dataset:
            datasets = {x: PickleDataset(os.path.join(inputs_dir,f"{x}_data.pkl"),transformers[x],label_map) for x in (['train','val']+(['test'] if include_test_set else [])) if os.path.exists(os.path.join(inputs_dir,f"{x}_data.pkl"))}
        elif use_npy_rotate:
            datasets = {x: NPYRotatingStack(os.path.join(inputs_dir,x),transformers[x],(sample_frac if x=='train' else 1.),sample_every,label_map,npy_rotate_sets_pkl,x) for x in (['train','val']+(['test'] if include_test_set else []))}
        else:
            datasets = {x: Datasets.ImageFolder(os.path.join(
                inputs_dir, x), transformers[x]) for x in (['train','val']+(['test'] if include_test_set else []))}

    if verbose: print(datasets)

    dataloaders = {x: DataLoader(
        datasets[x], batch_size=batch_size, num_workers=num_workers, shuffle=(x == 'train' and not predict), worker_init_fn=None) for x in datasets}

    model = generate_model(architecture,
                           num_classes,
                           pretrained=pretrained,
                           n_aux_features=None if  "n_aux_features" not in dir(datasets.get('train',datasets.get('custom',None))) else datasets.get('train',datasets.get('custom',None)).n_aux_features)
    
    
    if verbose: print(model)

    if torch.cuda.is_available():
        model = model.cuda()

    optimizer_opts = dict(name='adam',
                          lr=learning_rate,
                          weight_decay=1e-4)

    scheduler_opts = dict(scheduler='warm_restarts',
                          lr_scheduler_decay=0.5,
                          T_max=10,
                          eta_min=5e-8,
                          T_mult=2)

    trainer = ModelTrainer(model,
                           n_epochs,
                           None if predict else dataloaders['val'],
                           optimizer_opts,
                           scheduler_opts,
                           loss_fn='dice' if not class_balance else 'ce',
                           checkpoints_dir=checkpoints_dir,
                           tensor_dataset=None,
                           transforms=transformers,
                           save_metric=save_metric,
                           save_after_n_batch=save_after_n_batch)

    if os.path.exists(pretrained_save_loc):
        trainer.model.load_state_dict(torch.load(pretrained_save_loc,map_location=f"cuda:{gpu_id}" if gpu_id>=0 else "cpu"))

    if not predict:

        if class_balance:
            trainer.add_class_balance_loss(datasets['train'].targets)

        trainer, min_val_loss_f1, best_epoch=trainer.fit(dataloaders['train'],verbose=verbose)

        torch.save(trainer.model.state_dict(), model_save_loc)

        return trainer.model

    else:
        # assert not tensor_dataset, "Only ImageFolder and NPYDatasets allowed"

        if os.path.exists(model_save_loc):
            trainer.model.load_state_dict(torch.load(model_save_loc,map_location=f"cuda:{gpu_id}" if gpu_id>=0 else "cpu"))

        if extract_embeddings:
            trainer.model=nn.Sequential(trainer.model.features,Reshape())#,trainer.model.output
            if predict_set=='custom':
                dataset=datasets['custom']
                assert 'embed' in dir(dataset), "Embedding method required for dataset with model input, batch size and embedding output directory as arguments."
            else:
                assert len(extract_embeddings_df)>0 and os.path.exists(extract_embeddings_df), "Must load data from SQL database or pickle if not using custom dataset"
                if extract_embeddings_df.endswith(".db"):
                    from pathflowai.utils import load_sql_df
                    patch_info=load_sql_df(extract_embeddings_df,resize)
                elif extract_embeddings_df.endswith(".pkl"):
                    patch_info=pd.read_pickle(extract_embeddings_df)
                    assert patch_info['patch_size'].iloc[0]==resize, "Patch size pickle does not match."
                else:
                    raise NotImplementedError
                dataset=NPYDataset(patch_info,extract_embeddings,transformers["test"],tensor_dataset)
            return dataset.embed(trainer.model,batch_size,embedding_out_dir)
            # return "Output Embeddings"
            
        else:
            Y = dict()

            Y['pred'],Y['true'] = trainer.predict(dataloaders[predict_set])

            # Y['model'] = trainer.model

            # Y['true'] = datasets[predict_set].targets

            if save_predictions: torch.save(Y, predictions_save_path)

            return Y
    #if visualize_predictions:
        
        

def main():
    fire.Fire(train_model)

if __name__ == '__main__':
    main()

Sequential(
  (init_block): ResInitBlock(
    (conv): ConvBlock(
      (conv): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activ): ReLU(inplace=True)
    )
    (pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (stage1): Sequential(
    (unit1): ResUnit(
      (body): ResBottleneck(
        (conv1): ConvBlock(
          (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activ): ReLU(inplace=True)
        )
        (conv2): ConvBlock(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activ): ReLU(inplace=True)
        )
        (conv3): ConvBlock

RuntimeError: weight tensor should be defined either for all 1000 classes or no classes but got weight tensor of shape: [3]

In [9]:
train_model(inputs_dir="/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Students/users/Sophie_Chen/cnn_model_input/",
                learning_rate=1e-4,
                n_epochs=10,
                num_classes=3,
                batch_size=32,
                predict=False,
                model_save_loc='/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Students/users/Sophie_Chen/cnn_model_123.pth',
                verbose=2,
                class_balance=True,
                pickle_dataset=True,
                custom_dataset=None,
                save_predictions=True,
                include_test_set=False,
                use_npy_rotate=True,
                label_map={'y_true':'y_true'},
                sample_frac=1.,
                sample_every=3,
                num_workers=0,
                npy_rotate_sets_pkl="/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Students/users/Sophie_Chen/rotating_stack.pkl"
                )

{'train': <__main__.NPYRotatingStack object at 0x2adfa215fca0>, 'val': <__main__.NPYRotatingStack object at 0x2adfa2151940>}
Sequential(
  (init_block): ResInitBlock(
    (conv): ConvBlock(
      (conv): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activ): ReLU(inplace=True)
    )
    (pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (stage1): Sequential(
    (unit1): ResUnit(
      (body): ResBottleneck(
        (conv1): ConvBlock(
          (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (activ): ReLU(inplace=True)
        )
        (conv2): ConvBlock(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, mo

TypeError: 'NoneType' object is not subscriptable