In [18]:
import pandas as pd
import numpy as np
import torch 
import tqdm 
import glob, os, pickle
import time
import seaborn as sns
import sys
import copy
import tifffile
import torch.nn as nn
import matplotlib
import matplotlib.pyplot as plt
import fire, math
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report, f1_score
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets as Datasets
from torch.utils.data import TensorDataset
import torch.nn as nn
#from pathpretrain.models import generate_model#, ModelTrainer
#from pathpretrain.ModelTrainer import calc_best_confusion
from PIL import Image
import kornia.augmentation as K, kornia.geometry.transform as G
from pathpretrain.datasets import NPYDataset, PickleDataset, NPYRotatingStack
from pathpretrain.schedulers import Scheduler
#from pathpretrain.models import fit
from kornia.losses import DiceLoss
# import pysnooper

In [19]:
from pathpretrain.models import MLP
class AuxNet(nn.Module):
    def __init__(self,net,n_aux_features):
        super().__init__()
        self.net=net
        self.features=self.net.features
        self.output=self.net.output
        self.n_features=self.net.output.in_features
        self.n_aux_features=n_aux_features
        self.transform_nn=nn.Sequential(nn.Linear(self.n_aux_features,self.n_features),nn.LeakyReLU())
        self.gate_nn=MLP(self.n_features,[32],dropout_p=0.2,binary=False)#nn.Linear(self.n_features,1)
        
        for param in self.features[:4].parameters():
            param.requires_grad=False

    def forward(self,x,z=None):
        x=self.features(x)
        x = x.view(x.size(0), -1)
        if z is not None:
            z=self.transform_nn(z)
            #print(x.shape,z.shape,self.gate_nn(x).shape,self.gate_nn(z).shape)
            gate_h=F.softmax(torch.cat([self.gate_nn(xz) for xz in [x,z]],1),1)
            x = gate_h[:,0].unsqueeze(1) * x + gate_h[:,1].unsqueeze(1) * z
        x = self.output(x)
        return x

In [20]:
def prepare_model(model_name,
                  use_pretrained,
                  pretrained_model_file_path,
                  use_cuda=False,
                  use_data_parallel=True,
                  net_extra_kwargs=None,
                  load_ignore_extra=False,
                  num_classes=None,
                  in_channels=3,
                  remap_to_cpu=True,
                  remove_module=False,
                  semantic_segmentation=False,
                  n_aux_features=None):
    from pytorchcv.model_provider import get_model
    import segmentation_models_pytorch as smp

    kwargs = {"pretrained": use_pretrained}
    if num_classes is not None:
        kwargs["num_classes"] = num_classes
    if in_channels is not None:
        kwargs["in_channels"] = in_channels
    if net_extra_kwargs is not None:
        kwargs.update(net_extra_kwargs)

    if not semantic_segmentation:
        if kwargs['pretrained']:
            kwargs['pretrained']=False
            net = get_model(model_name, **kwargs)
            net_shape_dict = {k:v.shape for k,v in net.state_dict().items()}
            kwargs['num_classes']=1000
            kwargs['pretrained']=True
            net_pretrained=get_model(model_name, **kwargs).state_dict()
            net.load_state_dict({k:v for k,v in net_pretrained.items() if v.shape==net_shape_dict[k]},strict=False)
        else:
            net = get_model(model_name, **(kwargs))

        if n_aux_features is not None:
            net=AuxNet(net,n_aux_features)

    else:
        net = smp.Unet(model_name, classes=num_classes, in_channels=in_channels)

    return net

In [21]:
def generate_model(architecture, num_classes, semantic_segmentation, pretrained=False, n_aux_features=None):
    #    from pytorchcv.pytorch.utils import prepare_model
    if os.path.exists(architecture):
        model = torch.load(architecture,map_location='cpu')
    else:
        model = prepare_model(architecture,
                          use_pretrained=pretrained,
                          pretrained_model_file_path='',
                          use_cuda=False,
                          num_classes=num_classes,
                          semantic_segmentation=semantic_segmentation,
                          n_aux_features=n_aux_features)
    return model

In [22]:
class ModelTrainerr:
    
    def __init__(self, model, n_epoch=300, validation_dataloader=None, optimizer_opts=dict(name='adam', lr=1e-3, weight_decay=1e-4), scheduler_opts=dict(scheduler='warm_restarts', lr_scheduler_decay=0.5, T_max=10, eta_min=5e-8, T_mult=2), loss_fn='ce', reduction='mean', num_train_batches=None, opt_level='O1', checkpoints_dir='checkpoints',tensor_dataset=False,transforms=None,semantic_segmentation=False,save_metric='loss',save_after_n_batch=0):
        self.model = model
        # self.amp_handle = amp.init(enabled=True)
        optimizers = {'adam': torch.optim.Adam, 'sgd': torch.optim.SGD}
        loss_functions = {'bce': nn.BCEWithLogitsLoss(reduction=reduction), 'ce': nn.CrossEntropyLoss(
            reduction=reduction), 'mse': nn.MSELoss(reduction=reduction), 'nll': nn.NLLLoss(reduction=reduction),'dice':DiceLoss()}
        if 'name' not in list(optimizer_opts.keys()):
            optimizer_opts['name'] = 'adam'
        self.optimizer = optimizers[optimizer_opts.pop('name')](
            self.model.parameters(), **optimizer_opts)
        if False and torch.cuda.is_available():
            self.cuda = True
            self.model, self.optimizer = amp.initialize(
                self.model, self.optimizer, opt_level=opt_level)
        else:
            self.cuda = False
        self.scheduler = Scheduler(
            optimizer=self.optimizer, opts=scheduler_opts)
        self.n_epoch = n_epoch
        self.validation_dataloader = validation_dataloader
        self.loss_fn = loss_functions[loss_fn]
        self.loss_fn_name = loss_fn
        self.bce = (self.loss_fn_name == 'bce')
        self.sigmoid = nn.Sigmoid()
        self.original_loss_fn = copy.deepcopy(loss_functions[loss_fn])
        self.num_train_batches = num_train_batches
        self.val_loss_fn = copy.deepcopy(loss_functions[loss_fn])
        self.verbosity=0
        self.checkpoints_dir=checkpoints_dir
        self.tensor_dataset=tensor_dataset
        self.transforms=transforms
        self.semantic_segmentation=semantic_segmentation
        self.save_metric=save_metric
        self.save_after_n_batch=save_after_n_batch
        self.train_batch_count=0
        self.initial_seed=0
        self.seed=0

    def save_checkpoint(self,model,epoch,batch=0):
        os.makedirs(self.checkpoints_dir,exist_ok=True)
        out_name = f"{batch}.batch" if batch else f"{epoch}.epoch"
        torch.save(model,os.path.join(self.checkpoints_dir,f"{out_name}.checkpoint.pth"))

    def calc_loss(self, y_pred, y_true):
        """Calculates loss supplied in init statement and modified by reweighting.
        Parameters
        ----------
        y_pred:tensor
                Predictions.
        y_true:tensor
                True values.
        Returns
        -------
        loss
        """

        return self.loss_fn(y_pred, y_true)

    def calc_val_loss(self, y_pred, y_true):
        """Calculates loss supplied in init statement on validation set.
        Parameters
        ----------
        y_pred:tensor
                Predictions.
        y_true:tensor
                True values.
        Returns
        -------
        val_loss
        """

        return self.val_loss_fn(y_pred, y_true)

    def reset_loss_fn(self):
        """Resets loss to original specified loss."""
        self.loss_fn = self.original_loss_fn

    def add_class_balance_loss(self, y, custom_weights=''):
        """Updates loss function to handle class imbalance by weighting inverse to class appearance.
        Parameters
        ----------
        dataset:DynamicImageDataset
                Dataset to balance by.
        """
        self.class_weights = compute_class_weight('balanced',np.unique(y),y)#dataset.get_class_weights() if not custom_weights else np.array(
            #list(map(float, custom_weights.split(','))))
        if custom_weights:
            self.class_weights = self.class_weights / sum(self.class_weights)
        print('Weights:', self.class_weights)
        self.original_loss_fn = copy.deepcopy(self.loss_fn)
        weight = torch.tensor(self.class_weights, dtype=torch.float)
        if torch.cuda.is_available():
            weight = weight.cuda()
        if self.loss_fn_name == 'ce':
            self.loss_fn = nn.CrossEntropyLoss(weight=weight)
        elif self.loss_fn_name == 'nll':
            self.loss_fn = nn.NLLLoss(weight=weight)
        else:  # modify below for multi-target
            self.loss_fn = lambda y_pred, y_true: sum([self.class_weights[i] * self.original_loss_fn(
                y_pred[y_true == i], y_true[y_true == i]) if sum(y_true == i) else 0. for i in range(3)])
    

    def loss_backward(self, loss):
        """Backprop using mixed precision for added speed boost.
        Parameters
        ----------
        loss:loss
                Torch loss calculated.
        """
        # with self.amp_handle.scale_loss(loss, self.optimizer) as scaled_loss:
        # 	scaled_loss.backward()
        # loss.backward()
        if self.cuda:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

    # @pysnooper.snoop()
    def train_loop(self, epoch, train_dataloader):
        """One training epoch, calculate predictions, loss, backpropagate.
        Parameters
        ----------
        epoch:int
                Current epoch.
        train_dataloader:DataLoader
                Training data.
        Returns
        -------
        float
                Training loss for epoch
        """
        self.model.train(True)
        running_loss = 0.
        n_batch = len(
            train_dataloader.dataset) // train_dataloader.batch_size if self.num_train_batches == None else self.num_train_batches
        for i, batch in enumerate(train_dataloader):
            starttime = time.time()
            X, y_true = batch[:2]
            if len(batch)==3: Z=batch[2]
            else: Z=None

            if i == n_batch:
                break

            # X = Variable(batch[0], requires_grad=True)
            # y_true = Variable(batch[1])

            if torch.cuda.is_available():
                X = X.cuda()
                y_true = y_true.cuda()
                if Z is not None: Z=Z.cuda()

            if self.tensor_dataset:
                if self.semantic_segmentation: X,y_true=self.transforms['train'](X,y_true)
                else: X=self.transforms['train'](X)

            y_pred = self.model(X) if Z is None else self.model(X,Z)
            # y_true=y_true.argmax(dim=1)

            loss = self.calc_loss(y_pred, y_true if self.semantic_segmentation else y_true.flatten())
          #  loss.requires_grad=True# .view(-1,1)
            train_loss = loss.item()
            running_loss += train_loss
            self.optimizer.zero_grad()
            self.loss_backward(loss)  # loss.backward()
            self.optimizer.step()
            torch.cuda.empty_cache()
            endtime = time.time()
            if self.verbosity >=1:
                print("Epoch {}[{}/{}] Time:{}, Train Loss:{}".format(epoch,
                                                                  i, n_batch, round(endtime - starttime, 3), train_loss))
            self.train_batch_count+=1
            if self.save_after_n_batch and self.train_batch_count%self.save_after_n_batch==0:
                val_loss,val_f1=self.val_loop(epoch, self.val_dataloader)
                self.batch_val_losses.append(val_loss)
                self.batch_val_f1.append(val_f1)
                self.save_best_val_model(val_loss, val_f1, self.batch_val_losses, self.batch_val_f1, epoch, True, self.train_batch_count)
                self.model.train(True)
        
        self.scheduler.step()
        running_loss /= n_batch
        return running_loss

    def val_loop(self, epoch, val_dataloader, print_val_confusion=True, save_predictions=True):
        """Calculate loss over validation set.
        Parameters
        ----------
        epoch:int
                Current epoch.
        val_dataloader:DataLoader
                Validation iterator.
        print_val_confusion:bool
                Calculate confusion matrix and plot.
        save_predictions:int
                Print validation results.
        Returns
        -------
        float
                Validation loss for epoch.
        """
        self.model.train(False)
        n_batch = len(val_dataloader.dataset) // val_dataloader.batch_size
        running_loss = 0.
        Y = {'pred': [], 'true': []}
        with torch.no_grad():
            for i, batch in enumerate(val_dataloader):
                # X = Variable(batch[0], requires_grad=True)
                # y_true = Variable(batch[1])
                X, y_true = batch[:2]
                if len(batch)==3: Z=batch[2]
                else: Z=None
                if torch.cuda.is_available():
                    X = X.cuda()
                    y_true = y_true.cuda()
                    if Z is not None: Z=Z.cuda()

                if self.tensor_dataset:
                    if self.semantic_segmentation: X,y_true=self.transforms['val'](X,y_true)
                    else: X=self.transforms['val'](X)

                y_pred = self.model(X) if Z is None else self.model(X,Z)
                # y_true=y_true.argmax(dim=1)
                # if save_predictions:
                Y['true'].append(
                    y_true.detach().cpu().numpy().astype(int).flatten())
                y_pred_numpy = ((y_pred if not self.bce else self.sigmoid(
                    y_pred)).detach().cpu().numpy()).astype(float)
                if self.loss_fn_name in ['ce','dice']:
                    y_pred_numpy = y_pred_numpy.argmax(axis=1)
                Y['pred'].append(y_pred_numpy.flatten())

                loss = self.calc_val_loss(y_pred, y_true if self.semantic_segmentation else y_true.flatten())  # .view(-1,1)
                val_loss = loss.item()
                running_loss += val_loss
                if self.verbosity >=1:
                    print("Epoch {}[{}/{}] Val Loss:{}".format(epoch, i, n_batch, val_loss))
        # if print_val_confusion and save_predictions:
        y_pred, y_true = np.hstack(Y['pred']).flatten(), np.hstack(Y['true']).flatten()
        print(classification_report(y_true, y_pred))
        running_loss /= n_batch
        return running_loss, f1_score(y_true, y_pred,average='macro')

    # @pysnooper.snoop("test_loop.log")
    def test_loop(self, test_dataloader):
        """Calculate final predictions on loss.
        Parameters
        ----------
        test_dataloader:DataLoader
                Test dataset.
        Returns
        -------
        array
                Predictions or embeddings.
        """
        # self.model.train(False) KEEP DROPOUT? and BATCH NORM??
        self.model.eval()
        y_pred = []
        Y_true = []
        running_loss = 0.
        n_batch = len(
            test_dataloader.dataset) // test_dataloader.batch_size
        with torch.no_grad():
            for i, batch in tqdm.tqdm(enumerate(test_dataloader),total=n_batch):
                #X = Variable(batch[0],requires_grad=False)
                X, y_true = batch[:2]
                if len(batch)==3: Z=batch[2]
                else: Z=None
                if torch.cuda.is_available():
                    X = X.cuda()
                    y_true = y_true.cuda()
                    if Z is not None: Z=Z.cuda()

                prediction = self.model(X) if Z is None else self.model(X,Z)
                y_pred.append(prediction.detach().cpu().numpy())
                Y_true.append(y_true.detach().cpu().numpy())
        y_pred = np.concatenate(y_pred, axis=0)  # torch.cat(y_pred,0)
        y_true = np.concatenate(Y_true, axis=0).flatten()
        return y_pred,y_true
    
    def fit(self, train_dataloader, verbose=False, print_every=10, save_model=True, plot_training_curves=False, plot_save_file=None, print_val_confusion=True, save_val_predictions=True):
        # choose model with best f1
        self.train_losses = []
        self.val_losses = []
        self.val_f1 = []
        self.batch_val_losses = []
        self.batch_val_f1 = []
        if verbose:
            self.verbosity+=1
        for epoch in range(self.n_epoch):
            self.seed=self.initial_seed+epoch
            np.random.seed(self.seed)
            start_time = time.time()
            train_loss = self.train_loop(epoch, train_dataloader)
            current_time = time.time()
            train_time = current_time - start_time
            self.train_losses.append(train_loss)
            val_loss, val_f1 = self.val_loop(epoch, self.validation_dataloader,
                                     print_val_confusion=print_val_confusion, save_predictions=save_val_predictions)
            val_time = time.time() - current_time
            self.val_losses.append(val_loss)
            self.val_f1.append(val_f1)
            self.batch_val_losses.append(val_loss)
            self.batch_val_f1.append(val_f1)
            # if True:#verbose and not (epoch % print_every):
            if plot_training_curves:
                self.plot_train_val_curves(plot_save_file)
            print("Epoch {}: Train Loss {}, Val Loss {}, Train Time {}, Val Time {}".format(
                epoch, train_loss, val_loss, train_time, val_time))
            print('Training complete in {:.0f}m {:.0f}s'.format(train_time // 60, train_time % 60))
            self.save_best_val_model(val_loss, val_f1, self.val_losses, self.val_f1, epoch, save_model)
            if "save_every" in dir(train_dataloader.dataset) and train_dataloader.dataset.save_every and epoch%train_dataloader.dataset.save_every==0:
                train_dataloader.dataset.load_image_annot()
        if save_model:
            print("Saving best model at epoch {}".format(self.best_epoch))
            self.model.load_state_dict(self.best_model_state_dict)
        return self, self.min_val_loss_f1, self.best_epoch

    def save_best_val_model(self, val_loss, val_f1, val_loss_list, val_f1_list, epoch, save_model=True, batch=0):
        if (val_loss <= min(val_loss_list) if self.save_metric=='loss' else val_f1 >= max(val_f1_list)) and save_model:
            print("New best model at epoch {}".format(epoch))
            self.min_val_loss_f1 = val_loss if self.save_metric=='loss' else val_f1
            self.best_epoch = epoch
            if batch: self.best_batch = batch
            self.best_model_state_dict = copy.deepcopy(self.model.state_dict())
            self.save_checkpoint(self.best_model_state_dict,epoch,batch)

    def plot_train_val_curves(self, save_file=None):
            """Plots training and validation curves.
            Parameters
            ----------
            save_file:str
                    File to save to.
            """
            plt.figure()
            sns.lineplot('epoch', 'loss', hue='variable',
                         data=pd.DataFrame(np.vstack((np.arange(len(self.train_losses)), self.train_losses, self.val_losses)).T,
                                           columns=['epoch', 'train', 'val']).melt(id_vars=['epoch'], value_vars=['train', 'val']))
            if save_file is not None:
                plt.savefig(save_file, dpi=300)
            

    def predict(self, test_dataloader):
            """Make classification segmentation predictions on testing data.
            Parameters
            ----------
            test_dataloader:DataLoader
                    Test data.
            Returns
            -------
            array
                    Predictions.
            """
            y_pred,y_true = self.test_loop(test_dataloader)
            return y_pred,y_true

    def fit_predict(self, train_dataloader, test_dataloader):
            """Fit model to training data and make classification segmentation predictions on testing data.
            Parameters
            ----------
            train_dataloader:DataLoader
                    Train data.
            test_dataloader:DataLoader
                    Test data.
            Returns
            -------
            array
                    Predictions.
            """
            return self.fit(train_dataloader)[0].predict(test_dataloader)

    def return_model(self):
            """Returns pytorch model.
            """
            return self.model


In [23]:
class Reshape(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self,x):
        return x.view(x.shape[0],-1)

def worker_init_fn(worker_id):
    np.random.seed(np.random.get_state()[1][0] + worker_id)

def generate_transformers(image_size=224, resize=256, mean=[], std=[], include_jitter=False):
    train_transform = [transforms.Resize((resize,resize))]
    if include_jitter:
        train_transform.append(transforms.ColorJitter(brightness=0.4,
                                            contrast=0.4, saturation=0.4, hue=0.1))
    train_transform.extend([transforms.RandomHorizontalFlip(p=0.5),
           transforms.RandomVerticalFlip(p=0.5),
           transforms.RandomRotation(90),
           transforms.RandomResizedCrop((image_size,image_size)),
           transforms.ToTensor(),
           transforms.Normalize(mean if mean else [0.5, 0.5, 0.5],
                                std if std else [0.1, 0.1, 0.1])
           ])
    train_transform=transforms.Compose(train_transform)
    val_transform = transforms.Compose([
        transforms.Resize((resize,resize)),
        transforms.CenterCrop((image_size,image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean if mean else [0.5, 0.5, 0.5],
                             std if std else [0.1, 0.1, 0.1])
    ])
    normalization_transform = transforms.Compose([transforms.Resize((resize,resize)),
                                                  transforms.CenterCrop(
                                                      (image_size,image_size)),
                                                  transforms.ToTensor()])
    return {'train': train_transform, 'val': val_transform, 'test': val_transform, 'norm': normalization_transform}

def generate_kornia_transforms(image_size=224, resize=256, mean=[], std=[], include_jitter=False):
    mean=torch.tensor(mean) if mean else torch.tensor([0.5, 0.5, 0.5])
    std=torch.tensor(std) if std else torch.tensor([0.1, 0.1, 0.1])
    if torch.cuda.is_available():
        mean=mean.cuda()
        std=std.cuda()
    train_transforms=[G.Resize((resize,resize))]
    if include_jitter:
        train_transforms.append(K.ColorJitter(brightness=0.4, contrast=0.4,
                                   saturation=0.4, hue=0.1))
    train_transforms.extend([K.RandomHorizontalFlip(p=0.5),
           K.RandomVerticalFlip(p=0.5),
           K.RandomRotation(90),
           K.RandomResizedCrop((image_size,image_size)),
           K.Normalize(mean,std)
           ])
    val_transforms=[G.Resize((resize,resize)),
           K.CenterCrop((image_size,image_size)),
           K.Normalize(mean,std)
           ]
    transforms=dict(train=nn.Sequential(*train_transforms),
                val=nn.Sequential(*val_transforms))
    if torch.cuda.is_available():
        for k in transforms:
            transforms[k]=transforms[k].cuda()
    return transforms

class SegmentationTransform(nn.Module):
    def __init__(self,resize,image_size,mean,std,include_jitter=False,Set="train"):
        super().__init__()
        self.resize=G.Resize((resize,resize),align_corners=False)
        self.mask_resize=lambda x: torch.nn.functional.interpolate(x, size=(resize,resize), mode='nearest', align_corners=None)#G.Resize((resize,resize),interpolation='nearest',align_corners=False)#
        self.jit=K.ColorJitter(brightness=0.4, contrast=0.4,
                                   saturation=0.4, hue=0.1) if include_jitter else (lambda x: x)
        # self.rotations=nn.ModuleList([
        #        K.augmentation.RandomAffine([-90., 90.], [0., 0.15], [0.5, 1.5], [0., 0.15])
        #        # K.RandomHorizontalFlip(p=0.5),
        #        # K.RandomVerticalFlip(p=0.5),
        #        # K.RandomRotation(90),#K.RandomResizedCrop((image_size,image_size),interpolation="nearest")
        #        ])
        # self.rotations_mask=nn.ModuleList([
        #        K.augmentation.RandomAffine([-90., 90.], [0., 0.15], [0.5, 1.5], [0., 0.15],resample="NEAREST")
        #        ])
        self.affine=K.augmentation.RandomAffine([-90., 90.], [0., 0.15], None, [0., 0.15])
        self.affine_mask=K.augmentation.RandomAffine([-90., 90.], [0., 0.15], None, [0., 0.15],resample="NEAREST",align_corners=False)
        self.normalize=K.Normalize(mean,std)
        self.crop,self.mask_crop=K.CenterCrop((image_size,image_size)),K.CenterCrop((image_size,image_size),resample="NEAREST")
        self.Set=Set

    def forward(self,input,mask):
        mask=mask.unsqueeze(1).float()#torch.cat([mask.unsqueeze(1)]*3,1)
        if self.Set=='train':
            img=self.jit(self.resize(input))
            mask_out=self.mask_resize(mask)
            img=self.affine(img)
            mask_out=self.affine_mask(mask_out,self.affine._params)
            # for rotation in self.rotations: img=rotation(img)
            img=self.normalize(img)
            # for i in range(len(self.rotations_mask)): mask_out=self.rotations_mask[i](mask_out,self.rotations[i]._params)
        else:
            img=self.normalize(self.crop(self.resize(input)))
            mask_out=self.mask_crop(self.mask_resize(mask))
        return img,mask_out.squeeze(1).long()#[:,0,...]

def generate_kornia_segmentation_transforms(image_size=224, resize=256, mean=[], std=[], include_jitter=False):  # add this then IoU metric
    mean=torch.tensor(mean) if mean else torch.tensor([0.5, 0.5, 0.5])
    std=torch.tensor(std) if std else torch.tensor([0.1, 0.1, 0.1])
    transforms={k:SegmentationTransform(resize,image_size,mean,std,include_jitter=False,Set=k) for k in ['train','val']}
    if torch.cuda.is_available():
        for k in transforms:
            transforms[k]=transforms[k].cuda()
    return transforms

# @pysnooper.snoop()
def train_model(inputs_dir='inputs_training',
                learning_rate=1e-4,
                n_epochs=300,
                crop_size=224,
                resize=256,
                mean=[0.5, 0.5, 0.5],
                std=[0.1, 0.1, 0.1],
                num_classes=3,
                architecture='resnet50',
                batch_size=32,
                predict=False,
                model_save_loc='saved_model.pkl',
                pretrained_save_loc='pretrained_model.pkl',
                predictions_save_path='predictions.pkl',
                predict_set='test',
                verbose=False,
                class_balance=True,
                extract_embeddings="",
                extract_embeddings_df="",
                embedding_out_dir="./",
                gpu_id=-1,
                checkpoints_dir="checkpoints",
                tensor_dataset=False,
                pickle_dataset=True,
                label_map=dict(),
                semantic_segmentation=False,
                save_metric="loss",
                custom_dataset=None,
                save_predictions=True,
                pretrained=False,
                save_after_n_batch=0,
                include_test_set=False,
                use_npy_rotate=False,
                sample_frac=1.,
                sample_every=0,
                num_workers=0,
                npy_rotate_sets_pkl=""
                ):
    assert save_metric in ['loss','f1']
    if use_npy_rotate: tensor_dataset,pickle_dataset=False,False
    else: sample_every=0
    if predict: include_test_set=True
    if predict: assert not use_npy_rotate
    if extract_embeddings: assert predict, "Must be in prediction mode to extract embeddings"
    if tensor_dataset: assert not pickle_dataset, "Cannot have pickle and tensor classes activated"
    if semantic_segmentation and custom_dataset is None: assert tensor_dataset==True, "For now, can only perform semantic segmentation with TensorDataset"
    if gpu_id>=0: torch.cuda.set_device(gpu_id)
    transformers=generate_transformers if not tensor_dataset else generate_kornia_transforms
    if semantic_segmentation: transformers=generate_kornia_segmentation_transforms
    transformers = transformers(
        image_size=crop_size, resize=resize, mean=mean, std=std)
    if custom_dataset is not None:
        assert predict
        datasets={}
        datasets['custom']=custom_dataset
        predict_set='custom'
    else:
        if tensor_dataset:
            datasets = {x: torch.load(os.path.join(inputs_dir,f"{x}_data.pth")) for x in (['train','val']+(['test'] if include_test_set else [])) if os.path.exists(os.path.join(inputs_dir,f"{x}_data.pth"))}
            for k in datasets:
                if len(datasets[k].tensors[1].shape)>1 and not semantic_segmentation: datasets[k]=TensorDataset(datasets[k].tensors[0],datasets[k].tensors[1].flatten())
        elif pickle_dataset:
            datasets = {x: PickleDataset(os.path.join(inputs_dir,f"{x}_data.pkl"),transformers[x],label_map) for x in (['train','val']+(['test'] if include_test_set else [])) if os.path.exists(os.path.join(inputs_dir,f"{x}_data.pkl"))}
        elif use_npy_rotate:
            datasets = {x: NPYRotatingStack(os.path.join(inputs_dir,x),transformers[x],(sample_frac if x=='train' else 1.),sample_every,label_map,npy_rotate_sets_pkl,x) for x in (['train','val']+(['test'] if include_test_set else []))}
        else:
            datasets = {x: Datasets.ImageFolder(os.path.join(
                inputs_dir, x), transformers[x]) for x in (['train','val']+(['test'] if include_test_set else []))}

    if verbose: print(datasets)

    dataloaders = {x: DataLoader(
        datasets[x], batch_size=batch_size, num_workers=num_workers, shuffle=(x == 'train' and not predict), worker_init_fn=worker_init_fn) for x in datasets}

    model = generate_model(architecture,
                           num_classes,
                           semantic_segmentation=semantic_segmentation,
                           pretrained=pretrained,
                           n_aux_features=None if semantic_segmentation or "n_aux_features" not in dir(datasets.get('train',datasets.get('custom',None))) else datasets.get('train',datasets.get('custom',None)).n_aux_features)
    '''ct=0
    for child in model.children():
        if ct <=1:
            for param in child.parameters():
                param.requires_grad = False
        ct += 1
'''
    if verbose: print(model)

    if torch.cuda.is_available():
        model = model.cuda()

    optimizer_opts = dict(name='adam',
                          lr=learning_rate,
                          weight_decay=1e-4)

    scheduler_opts = dict(scheduler='warm_restarts',
                          lr_scheduler_decay=0.5,
                          T_max=10,
                          eta_min=5e-8,
                          T_mult=2)

    trainer = ModelTrainerr(model,
                           n_epochs,
                           None if predict else dataloaders['val'],
                           optimizer_opts,
                           scheduler_opts,
                           loss_fn='dice' if (semantic_segmentation and not class_balance) else 'ce',
                           checkpoints_dir=checkpoints_dir,
                           tensor_dataset=tensor_dataset,
                           transforms=transformers,
                           semantic_segmentation=semantic_segmentation,
                           save_metric=save_metric,
                           save_after_n_batch=save_after_n_batch)

    if os.path.exists(pretrained_save_loc):
        trainer.model.load_state_dict(torch.load(pretrained_save_loc,map_location=f"cuda:{gpu_id}" if gpu_id>=0 else "cpu"))

    if not predict:

        if class_balance:
            trainer.add_class_balance_loss(datasets['train'].targets if not tensor_dataset else datasets['train'].tensors[1].numpy().flatten())

        trainer, min_val_loss_f1, best_epoch=trainer.fit(dataloaders['train'],verbose=verbose)

        torch.save(trainer.model.state_dict(), model_save_loc)

        return trainer.model

    else:
        # assert not tensor_dataset, "Only ImageFolder and NPYDatasets allowed"

        if os.path.exists(model_save_loc):
            trainer.model.load_state_dict(torch.load(model_save_loc,map_location=f"cuda:{gpu_id}" if gpu_id>=0 else "cpu"))

        if extract_embeddings:
            assert not semantic_segmentation, "Semantic Segmentation not implemented for whole slide segmentation"
            trainer.model=nn.Sequential(trainer.model.fitfeatures,Reshape())#,trainer.model.output
            if predict_set=='custom':
                dataset=datasets['custom']
                assert 'embed' in dir(dataset), "Embedding method required for dataset with model input, batch size and embedding output directory as arguments."
            else:
                assert len(extract_embeddings_df)>0 and os.path.exists(extract_embeddings_df), "Must load data from SQL database or pickle if not using custom dataset"
                if extract_embeddings_df.endswith(".db"):
                    from pathflowai.utils import load_sql_df
                    patch_info=load_sql_df(extract_embeddings_df,resize)
                elif extract_embeddings_df.endswith(".pkl"):
                    patch_info=pd.read_pickle(extract_embeddings_df)
                    assert patch_info['patch_size'].iloc[0]==resize, "Patch size pickle does not match."
                else:
                    raise NotImplementedError
                dataset=NPYDataset(patch_info,extract_embeddings,transformers["test"],tensor_dataset)
            return dataset.embed(trainer.model,batch_size,embedding_out_dir)
            # return "Output Embeddings"
        else:
            Y = dict()

            Y['pred'],Y['true'] = trainer.predict(dataloaders[predict_set])

            # Y['model'] = trainer.model

            # Y['true'] = datasets[predict_set].targets

            if save_predictions: torch.save(Y, predictions_save_path)

            return Y

def main():
    fire.Fire(train_model)

if __name__ == '__main__':
    main()

KeyError: 'val'

In [None]:
#testing 
train_model(inputs_dir="/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Students/users/Sophie_Chen/cnn_model_input/",
                learning_rate=5e-4,
                n_epochs=1,
                num_classes=3,
                batch_size=64,
                predict=False,
                model_save_loc='/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Students/users/Sophie_Chen/cnn_model_124.pth',
                verbose=2,
                class_balance=True,
                checkpoints_dir='/dartfs-hpc/rc/home/3/f006n33/checkpoints/',
                pickle_dataset=True,
                custom_dataset=None,
                save_predictions=True,
                include_test_set=False,
                use_npy_rotate=True,
                label_map={'y_true':'y_true'},
                sample_frac=1.,
                sample_every=2,
                num_workers=2,
                npy_rotate_sets_pkl="/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Students/users/Sophie_Chen/rotating_stack.pkl"
                )

In [None]:
train_model(inputs_dir="/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Students/users/Sophie_Chen/cnn_model_input/",
                learning_rate=5e-4,
                n_epochs=10,
                num_classes=3,
                batch_size=64,
                predict=False,
                model_save_loc='/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Students/users/Sophie_Chen/cnn_model_123.pth',
                verbose=2,
                class_balance=True,
                pickle_dataset=True,
                custom_dataset=None,
                save_predictions=True,
                include_test_set=False,
                use_npy_rotate=True,
                label_map={'y_true':'y_true'},
                sample_frac=1.,
                sample_every=2,
                num_workers=2,
                npy_rotate_sets_pkl="/dartfs/rc/nosnapshots/V/VaickusL-nb/EDIT_Students/users/Sophie_Chen/rotating_stack.pkl"
                )