In [1]:
from copy import deepcopy
from functools import reduce
import unittest
import sys
sys.path.append('..')
from ktools.modelling.ktools_models.pytorch_nns.ffn_pytorch_embedding_model import FFNPytorchEmbeddingModel
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau, CosineAnnealingWarmRestarts
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import StratifiedKFold
from ktools.fitting.cross_validation_executor import CrossValidationExecutor
from ktools.modelling.ktools_models.pytorch_embedding_model import PytorchEmbeddingModel
from ktools.modelling.model_transform_wrappers.survival_model_wrapper import SupportedSurvivalTransformation
from ktools.preprocessing.basic_feature_transformers import *
from ktools.utils.data_science_pipeline_settings import DataSciencePipelineSettings
from post_HCT_survival_notebooks.hct_utils import score
from ktools.modelling.ktools_models.pytorch_nns.odst_pytorch_embedding_model import ODSTPytorchEmbeddingModel
import math
from torch.optim.lr_scheduler import SequentialLR, CosineAnnealingLR
from typing import List
from pytorch_tabular.models.common.layers import ODST
from pytorch_lightning.utilities import grad_norm
from pytorch_lightning.callbacks import LearningRateMonitor, TQDMProgressBar, StochasticWeightAveraging

# Preamble: seeding, data paths, scoring and loss helpers

In [2]:
import random

def set_seed(seed: int = 42) -> None:
    """Seed Python's, NumPy's and PyTorch's RNGs and force deterministic cuDNN behaviour.

    Args:
        seed: value used to seed every generator.
    """
    # Imported here because `np` is never imported explicitly in this notebook; relying on a
    # name leaked from a wildcard import makes the helper fragile.
    import numpy as np

    random.seed(seed)  # Python random module
    np.random.seed(seed)  # NumPy
    torch.manual_seed(seed)  # PyTorch CPU
    torch.cuda.manual_seed_all(seed)  # PyTorch GPU (all devices)
    torch.backends.cudnn.deterministic = True  # Ensures deterministic behavior
    torch.backends.cudnn.benchmark = False  # Disables auto-tuning for convolutions

In [3]:
train_csv_path = "../data/post_hct_survival/train.csv"
test_csv_path = "../data/post_hct_survival/test.csv"
target_col_name = ['efs', 'efs_time']

def scci_metric(y_test, y_pred, id_col_name : str = "ID",
        survived_col_name : str = "efs",
        survival_time_col_name : str = "efs_time",
        stratify_col_name : str = "race_group"):
    idcs = y_test.index
    og_train = pd.read_csv(train_csv_path)
    
    y_true = og_train.loc[idcs, [id_col_name, survived_col_name, survival_time_col_name, stratify_col_name]].copy()
    y_pred_df = og_train.loc[idcs, [id_col_name]].copy()
    y_pred_df["prediction"] = y_pred
    scci = score(y_true.copy(), y_pred_df.copy(), id_col_name)
    return scci

In [26]:
@functools.lru_cache
def combinations(N):
    """All index pairs ``(i, j)`` with ``0 <= i < j < N`` as a long tensor of shape ``(N*(N-1)/2, 2)``.

    Results are memoised per ``N`` because the same batch sizes recur every step; the returned
    tensor is shared between callers and must not be modified in place.
    """
    return torch.combinations(torch.arange(N), r=2)

def pairwise_loss(race, event, event_time, risk, margin=0.2, weight_class:bool = False):
    """Mean hinge (margin ranking) loss over all comparable pairs of the batch.

    A pair (i, j), i < j, is comparable when both events were observed, or when the sample with
    the strictly earlier time had an observed event. For comparable pairs the loss is
    ``relu(direction * (risk_i - risk_j) + margin)`` with ``direction = +1`` if
    ``event_time_i > event_time_j`` and ``-1`` otherwise, i.e. the sample that fails earlier
    should have the larger risk by at least ``margin``.

    Args:
        race: per-sample race codes; only used when ``weight_class`` is True.
        event: per-sample event indicator (1 = event observed).
        event_time: per-sample event/censoring time.
        risk: per-sample risk scores, shape ``(n,)`` or ``(n, 1)``.
        margin: hinge margin.
        weight_class: if True, weight each pair by the mean inverse-frequency weight of its two
            members' race groups (group weights are normalised to average 1).

    Returns:
        Scalar loss; 0 (instead of NaN) when the batch contains no comparable pair.
    """
    n = event.shape[0]
    # Move the cached CPU pair indices to the data device once, instead of on every indexing op.
    first_of_pair, second_of_pair = combinations(n).to(event.device).T

    first_event = event[first_of_pair] == 1
    second_event = event[second_of_pair] == 1
    first_time = event_time[first_of_pair]
    second_time = event_time[second_of_pair]

    # Comparable (valid) pairs
    valid_mask = (
        (first_event & second_event)
        | (first_event & (first_time < second_time))
        | (second_event & (second_time < first_time))
    )

    # Pairwise hinge loss
    direction = 2 * (first_time > second_time).int() - 1
    risk_diff = (risk[first_of_pair] - risk[second_of_pair]).reshape(-1)
    margin_loss = F.relu(direction * risk_diff + margin)

    if weight_class:
        # Computed only on demand, fully vectorised: the previous Tensor.apply_ lookup is
        # CPU-only and its per-sample weights did not match the per-pair loss shape.
        _, race_inverse, race_counts = torch.unique(race, return_inverse=True, return_counts=True)
        class_weighting = n / race_counts
        class_weighting = class_weighting / class_weighting.sum() * len(race_counts)
        sample_weight = class_weighting[race_inverse]
        pair_weight = 0.5 * (sample_weight[first_of_pair] + sample_weight[second_of_pair])
        margin_loss = margin_loss * pair_weight

    # clamp_min(1): a batch without comparable pairs yields 0 rather than 0/0 = NaN.
    return (margin_loss * valid_mask).sum() / valid_mask.sum().clamp_min(1)

def race_equality_loss(race, event, event_time, risk, margin=0.2):
    """Spread (unbiased std) of the per-race ``pairwise_loss`` within the batch.

    ``pairwise_loss`` is evaluated separately on the samples of each race present in the batch;
    a large spread means the ranking quality differs between races.

    Returns:
        Scalar std of the per-race losses; 0 when fewer than two races are present
        (the std of a single value is undefined and would be NaN).
    """
    unq_races = torch.unique(race)
    if len(unq_races) < 2:
        return torch.zeros((), dtype=risk.dtype, device=risk.device)

    per_race_losses = []
    for r in unq_races:
        idcs = race == r
        # Every tensor, including `race`, is restricted to the same subset so shapes agree.
        per_race_losses.append(
            pairwise_loss(race[idcs], event[idcs], event_time[idcs], risk[idcs],
                          margin=margin, weight_class=False)
        )
    return torch.std(torch.stack(per_race_losses))

def event_time_loss(event, event_time, event_time_predicted):
    """Mean squared error of the predicted event time over samples whose event occurred.

    Args:
        event: per-sample event indicator (1 = event observed); censored samples are ignored.
        event_time: per-sample target time.
        event_time_predicted: predictions; reshaped to ``event_time``'s shape so an ``(n, 1)``
            head output cannot silently broadcast to an ``(n, n)`` error matrix.

    Returns:
        Scalar loss; 0 (instead of NaN) when no sample in the batch has an event.
    """
    event_occurred_mask = (event == 1)
    predicted = event_time_predicted.reshape(event_time.shape)
    squared_error = F.mse_loss(predicted, event_time, reduction='none')
    n_events = event_occurred_mask.sum().clamp_min(1)
    return (squared_error * event_occurred_mask).sum() / n_events

def event_occurred_loss(event, event_predicted):
    """Binary cross-entropy between predicted event probabilities and observed event flags.

    Args:
        event: per-sample event indicator (0/1, any numeric dtype).
        event_predicted: predicted probabilities in [0, 1] (e.g. a sigmoid output), shape
            ``(n,)`` or ``(n, 1)``.

    Returns:
        Scalar mean BCE.
    """
    # F.binary_cross_entropy(input, target): the predictions are the input, the labels the target.
    probabilities = event_predicted.reshape(event.shape)
    return F.binary_cross_entropy(probabilities, event.to(probabilities.dtype))

In [60]:
class NN(nn.Module):
    """Tabular multi-task network: categorical embeddings + numerical inputs -> ODST trunk -> 3 heads.

    Each categorical column is embedded, the concatenated embeddings are projected to
    ``projection_size`` features and joined with the numerical columns. The result passes
    through a dropout / ODST / BatchNorm / dropout trunk of width ``hidden_dim``.
    ``forward`` returns ``(risk, efs_time, efs)``: a linear risk score, an unconstrained
    event-time output and a sigmoid event probability, each with ``output_dim`` columns.

    Args:
        input_dim: total number of input columns (categorical + numerical).
        output_dim: output width of every head.
        categorical_sizes: number of rows of each categorical embedding table.
        categorical_embedding: embedding width of each categorical column.
        projection_size: width of the projected embedding block.
        hidden_dim: output width of the ODST trunk.
        dropout: dropout probability applied before and after the trunk.
        aux_hidden_size: divisor of ``hidden_dim`` giving the hidden width of the two auxiliary heads.
    """

    # Activation names accepted by ``_get_activation``, mapped to their module constructors.
    _ACTIVATIONS = {
        'relu': nn.ReLU,
        'gelu': nn.GELU,
        'sigmoid': nn.Sigmoid,
        'none': nn.Identity,
    }

    def __init__(self,
                 input_dim : int,
                 output_dim : int,
                 categorical_sizes : List[int],
                 categorical_embedding : List[int],
                 projection_size : int = 112,
                 hidden_dim : int = 56,
                 dropout : float = 0.05,
                 aux_hidden_size : int = 3
                 ) -> None:

        super().__init__()

        self._input_dim = input_dim
        self._output_dim = output_dim
        self._categorical_sizes = categorical_sizes
        self._categorical_embedding = categorical_embedding
        self._projection_size = projection_size
        self._hidden_dim = hidden_dim
        self._dropout = dropout
        self._aux_hidden_size = aux_hidden_size
        self.embedding_layers = self._create_embedding_layers()

        self.projection_embeddings = nn.Sequential(
            nn.Linear(self.combined_emb_dim, projection_size),
            nn.GELU(),
            nn.Linear(projection_size, projection_size)
        )
        self.model = self._create_model()

        # Kept as a one-layer Sequential so state_dict keys stay `risk_adapter.0.*`.
        self.risk_adapter = nn.Sequential(
            nn.Linear(self._hidden_dim, self._output_dim),
        )

        aux_hidden = self._hidden_dim // self._aux_hidden_size
        self.event_time_adapter = nn.Sequential(
            nn.Linear(self._hidden_dim, aux_hidden),
            nn.GELU(),
            nn.Linear(aux_hidden, self._output_dim),
            self._get_activation('none')
        )

        self.event_adapter = nn.Sequential(
            nn.Linear(self._hidden_dim, aux_hidden),
            nn.GELU(),
            nn.Linear(aux_hidden, self._output_dim),
            self._get_activation('sigmoid')
        )

        # Xavier-normal weights and zero biases for every linear layer (the ODST block keeps its own init).
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.zeros_(m.bias)

    def _create_model(self):
        """Build the trunk: dropout -> ODST(projection + numericals -> hidden_dim) -> BatchNorm -> dropout."""
        model = nn.Sequential(
            nn.Dropout(self._dropout),
            ODST(self._projection_size + self.num_numericals, self._hidden_dim),
            nn.BatchNorm1d(self._hidden_dim),
            nn.Dropout(self._dropout)
        )
        return model

    def forward(self, x_cat : torch.Tensor, x_num : torch.Tensor) -> torch.Tensor:
        """Return ``(risk, efs_time, efs)`` for integer-coded ``x_cat`` and float ``x_num``."""
        embeddings = self.forward_embeddings(x_cat)
        x_proj = self.projection_embeddings(embeddings)
        x = torch.cat([x_proj, x_num], dim=1)
        x = self.model(x)
        risk = self.risk_adapter(x)
        efs_time = self.event_time_adapter(x)
        efs = self.event_adapter(x)
        return risk, efs_time, efs

    def forward_embeddings(self, x_cat : torch.Tensor) -> torch.Tensor:
        """Embed column ``i`` of ``x_cat`` with table ``i`` and concatenate along the feature axis."""
        cat_inputs = [self.embedding_layers[i](x_cat[:, i]) for i in range(self.num_categories)]
        return torch.cat(cat_inputs, dim=1)

    def _create_embedding_layers(self):
        """One ``nn.Embedding(size, width)`` per categorical column."""
        return nn.ModuleList(
            nn.Embedding(size, width)
            for size, width in zip(self._categorical_sizes[:self.num_categories],
                                   self._categorical_embedding)
        )

    def _get_activation(self, activation):
        """Return a fresh activation module for ``activation``.

        Raises:
            ValueError: for an unknown name (previously ``None`` was returned, which only failed
                later, inside ``forward``).
        """
        try:
            return self._ACTIVATIONS[activation]()
        except KeyError:
            raise ValueError(
                f"Unknown activation {activation!r}; expected one of {sorted(self._ACTIVATIONS)}"
            ) from None

    @property
    def num_categories(self):
        """Number of categorical input columns."""
        return len(self._categorical_sizes)

    @property
    def combined_emb_dim(self):
        """Total width of the concatenated categorical embeddings."""
        return sum(self._categorical_embedding)

    @property
    def num_numericals(self):
        """Number of numerical input columns (``input_dim`` minus the categorical ones)."""
        return self._input_dim - self.num_categories

# Create data

In [61]:
# Raw training features (targets removed). The unique-value counts of the numeric columns
# (NaN counted as its own value via fillna(-1)) drive HCTTransform's split between
# categorical-like and continuous columns.
train_df = (
    pd.read_csv(train_csv_path, index_col=0)
    .drop(columns=['efs', 'efs_time'])
)
numericals = train_df.select_dtypes(['number']).columns.tolist()
nunq = (
    train_df[numericals]
    .fillna(-1)
    .nunique()
)
from sklearn.impute import SimpleImputer

class AddFeatures(IFeatureTransformer):
    """Add an ``is_cyto_score_same`` flag and re-base ``year_hct`` to years since 2000."""

    @staticmethod
    def transform(original_settings : DataSciencePipelineSettings):
        """Return a deep copy of ``original_settings`` with the extra feature added.

        ``is_cyto_score_same`` is 1 where ``cyto_score`` equals ``cyto_score_detail`` and is
        registered as both a categorical and a training column.
        """
        settings = deepcopy(original_settings)
        frame = settings.combined_df
        flag_col = 'is_cyto_score_same'

        frame[flag_col] = (frame['cyto_score'] == frame['cyto_score_detail']).astype(int)
        settings.categorical_col_names += [flag_col]
        settings.training_col_names += [flag_col]

        frame['year_hct'] -= 2000
        return settings
    
# class FillNullMeanValuesWithIndicator(IFeatureTransformer):
#     @staticmethod
#     def transform(original_settings : DataSciencePipelineSettings, category_fill='missing'):
#         settings = deepcopy(original_settings)
#         for col_name in settings.training_col_names:
#             if pd.api.types.is_numeric_dtype(settings.combined_df[col_name]):
#                 # We only want to add indicator variables for features that are not intended to be categoricals
#                 not_categorical = col_name not in settings.categorical_col_names
#                 imputer = SimpleImputer(strategy="mean", add_indicator=not_categorical)
#                 result = imputer.fit_transform(settings.combined_df[[col_name]].values)
#                 if result.shape[1] > 1:
#                     print("adding indicator")
#                     settings.combined_df[[col_name, col_name + "_nan"]] = result
#                     settings.training_col_names += [col_name + "_nan"]
#                 else:
#                     settings.combined_df[col_name] = result
#             else:
#                 settings.combined_df[col_name] = settings.combined_df[col_name].fillna(category_fill)
#         return settings
    
class HCTTransform():
    """Split the raw numeric columns by cardinality into categorical copies and imputed continuous columns.

    Uses the notebook-level ``nunq`` (unique-value counts of the raw numeric training columns).
    Columns with 3-29 distinct values get a string-typed ``cat_<col>`` copy that is registered
    as a categorical training column. Every other numeric column that has missing values is
    mean-imputed in place, and a ``<col>_nan`` indicator column is added.
    """

    @staticmethod
    def transform(original_settings : DataSciencePipelineSettings):
        settings = deepcopy(original_settings)
        frame = settings.combined_df
        is_cat_like = (2 < nunq.values) & (nunq.values < 30)

        for source_col in nunq[is_cat_like].index:
            derived_col = "cat_" + source_col
            frame[derived_col] = frame[source_col].astype(str)
            settings.categorical_col_names += [derived_col]
            settings.training_col_names += [derived_col]

        for source_col in nunq[~is_cat_like].index:
            imputer = SimpleImputer(strategy="mean", add_indicator=True)
            imputed = imputer.fit_transform(frame[[source_col]].values)
            # A second output column only appears when the column had missing values.
            if imputed.shape[1] <= 1:
                continue
            print("adding indicator")
            frame[[source_col, source_col + "_nan"]] = imputed
            settings.training_col_names += [source_col + "_nan"]

        return settings

settings = DataSciencePipelineSettings(train_csv_path,
                                        test_csv_path,
                                        target_col_name,
                                        )
transforms = [
            # AddFeatures.transform,
            StandardScaleNumerical.transform,
            HCTTransform.transform,
            FillNullValues.transform,
            OrdinalEncode.transform,
            ConvertObjectToCategorical.transform,
            # GenerateSurvivalTarget('kaplanmeier').transform
            ]

# Apply the transforms in order; each receives the settings returned by the previous one.
settings = reduce(lambda acc, func: func(acc), transforms, settings)

# One update() call yields the processed train/test frames (a second, discarded call was redundant).
train, test_df = settings.update()
X, y = train.drop(columns=settings.target_col_name), train[settings.target_col_name]

INPUT_DIM = len(settings.training_col_names)
OUTPUT_DIM = 1

cat_names = settings.categorical_col_names
# nn.Embedding is indexed by the encoded value, so each table needs (largest code + 1) rows.
# Taking the max over train AND test also covers codes that only occur in the test set; for
# contiguous 0..k-1 codes it equals the number of unique values.
# NOTE(review): assumes OrdinalEncode yields non-negative integer codes — confirm.
encoded_cats = pd.concat([X[cat_names], test_df[cat_names]])
cat_sizes = [int(size) for size in (encoded_cats.max() + 1).values]
cat_emb = [16] * len(cat_sizes)
categorical_idcs = [X.columns.get_loc(col) for col in cat_names]
# Explicit ascending order instead of relying on set iteration order.
categorical_set = set(categorical_idcs)
numerical_idcs = [i for i in range(X.shape[1]) if i not in categorical_set]

In [63]:
# Sanity check of the numerical / categorical feature split used by the dataloaders.
print("number of numerical columns: ", len(numerical_idcs),
      "number of categorical columns: ", len(categorical_idcs))

number of numerical columns:  23 number of categorical columns:  55


In [64]:
def standardize(df, col):
    """Return ``df[col]`` as a NumPy array centred on its mean and divided by its sample std.

    A zero or undefined std (constant column, single row) falls back to 1, so the result is
    the centred values instead of NaN/inf.
    """
    values = df[col]
    std = values.std()
    if not math.isfinite(std) or std == 0:
        std = 1.0
    return (values.values - values.mean()) / std


def create_dataloader(X : pd.DataFrame, y : pd.DataFrame, shuffle : bool = False,
                      batch_size : int = 2048, cat_idcs=None, num_idcs=None):
    """Wrap features and survival targets in a ``DataLoader``.

    Each batch is ``(row_index, x_cat, x_num, efs, efs_time_standardized)``.

    Args:
        X: feature frame; its index is passed through as ``row_index``.
        y: frame with ``efs`` and ``efs_time`` columns aligned with ``X``.
        shuffle: shuffle the samples each epoch.
        batch_size: samples per batch.
        cat_idcs, num_idcs: column positions of ``X`` used as categorical / numerical inputs;
            default to the notebook-level ``categorical_idcs`` / ``numerical_idcs``.

    NOTE(review): ``efs_time`` is standardised with this frame's own mean/std, so train and
    validation loaders use different scales.
    """
    cat_idcs = categorical_idcs if cat_idcs is None else cat_idcs
    num_idcs = numerical_idcs if num_idcs is None else num_idcs

    features = X.to_numpy()
    data = TensorDataset(
            torch.tensor(X.index.to_numpy(), dtype=torch.long),
            torch.tensor(features[:, cat_idcs], dtype=torch.long),
            torch.tensor(features[:, num_idcs], dtype=torch.float32),
            torch.tensor(y['efs'].to_numpy(), dtype=torch.int),
            torch.tensor(standardize(y, 'efs_time'), dtype=torch.float32)
        )
    return DataLoader(data, batch_size=batch_size, shuffle=shuffle)

In [65]:
# Tuned hyperparameters (only margin, lr, weight_decay and dropout are consumed below).
hparams = dict(
    embedding_dim=16,
    projection_dim=112,
    hidden_dim=56,
    lr=0.06464861983337984,
    dropout=0.05463240181423116,
    aux_weight=0.26545778308743806,
    margin=0.2588153271003354,
    weight_decay=0.0002773544957610778,
)

In [75]:
class PairwiseRankingNeuralNetwork(pl.LightningModule):
    """Lightning module that trains ``NN`` as a survival risk ranker.

    The optimised loss is::

        pairwise_loss
        + race_std_loss_multiplier * race_equality_loss
        + event_time_prediction_multiplier * event_time_loss
        + event_prediction_multiplier * BCE(event probability, event)

    Validation risk scores are collected per batch and scored with ``scci_metric`` at the end of
    every validation epoch; the score is logged as ``"cindex"``.

    Args:
        race_col_idx: column of ``x_cat`` holding the race group.
        margin: hinge margin of the ranking losses.
        learning_rate: Adam learning rate.
        weight_decay: Adam weight decay.
        dropout: dropout passed to ``NN``.
        race_std_loss_multiplier: weight of the cross-race loss-spread term.
        event_time_prediction_multiplier: weight of the event-time head loss.
        event_prediction_multiplier: weight of the event-probability head loss.
        scheduler_t_max: ``T_max`` (in epochs) of the cosine annealing schedule.
        scheduler_eta_min: minimum learning rate of the cosine annealing schedule.
    """

    def __init__(self,
                 race_col_idx : int,
                 margin : float = 0.26,
                 learning_rate : float = 6.5e-3,
                 weight_decay : float = 3e-4,
                 dropout : float = 0.05,
                 race_std_loss_multiplier : float = 10,
                 event_time_prediction_multiplier : float = 0.1,
                 event_prediction_multiplier : float = 0.001,
                 scheduler_t_max : int = 45,
                 scheduler_eta_min : float = 6e-3):

        super().__init__()
        self.save_hyperparameters()

        self._race_col_idx = race_col_idx
        self._margin = margin
        self._learning_rate = learning_rate
        self._weight_decay = weight_decay
        self._race_std_loss_multiplier = race_std_loss_multiplier
        self._etpm = event_time_prediction_multiplier
        self._epm = event_prediction_multiplier
        self._scheduler_t_max = scheduler_t_max
        self._scheduler_eta_min = scheduler_eta_min

        self.model = NN(INPUT_DIM,
                        OUTPUT_DIM,
                        cat_sizes,
                        cat_emb,
                        dropout=dropout)

        # Per-batch validation outputs [event_time, risk, event, row_index]; scored and
        # cleared in on_validation_epoch_end.
        self.targets = []

        self.binary_loss = nn.BCELoss()

    def forward(self, x_cat, x_num):
        """Return ``(risk, efs_time, efs)`` from the wrapped ``NN``."""
        return self.model(x_cat, x_num)

    def _compute_loss(self, x_cat, event, event_time, risk_score, efs_time, efs):
        """Weighted sum of the ranking, race-spread, event-time and event-probability losses."""
        race = x_cat[:, self._race_col_idx]
        # squeeze(-1) rather than squeeze(): a batch of size 1 must stay 1-D to match `event`.
        efs_time = efs_time.squeeze(-1)
        efs = efs.squeeze(-1)

        loss = pairwise_loss(race, event, event_time, risk_score, margin=self._margin)
        race_std_loss = race_equality_loss(race, event, event_time, risk_score, margin=self._margin)
        efs_time_loss = event_time_loss(event, event_time, efs_time)
        efs_loss = self.binary_loss(efs, event.float())

        return (loss
                + self._race_std_loss_multiplier * race_std_loss
                + self._etpm * efs_time_loss
                + self._epm * efs_loss
                )

    def training_step(self, batch, batch_idx):
        idx, x_cat, x_num, event, event_time = batch
        risk_score, efs_time, efs = self(x_cat, x_num)
        loss = self._compute_loss(x_cat, event, event_time, risk_score, efs_time, efs)
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True, batch_size=len(event))
        return loss

    def validation_step(self, batch, batch_idx):
        idx, x_cat, x_num, event, event_time = batch
        risk_score, efs_time, efs = self(x_cat, x_num)
        self.targets.append([event_time, risk_score.detach(), event, idx])

        loss = self._compute_loss(x_cat, event, event_time, risk_score, efs_time, efs)
        # Logged so the validation loss is visible; returning it alone has no effect in Lightning.
        self.log('val_loss', loss, on_epoch=True, logger=True, batch_size=len(event))
        return loss

    def on_validation_epoch_end(self):
        """Compute and log the c-index of the collected validation predictions, then reset the buffer."""
        if not self.targets:
            return
        try:
            scci, _ = self._calc_cindex(with_reverse=False)
            self.log("cindex", scci, on_epoch=True, prog_bar=True, logger=True)
        finally:
            # Always reset, so a failed scoring pass cannot leak stale batches into the next epoch.
            self.targets.clear()

    def _calc_cindex(self, with_reverse: bool = True):
        """Score the collected validation risks with ``scci_metric``.

        Returns:
            ``(scci, scci_rev)``; ``scci_rev`` scores the negated risks and is ``None`` when
            ``with_reverse`` is False (it is not needed for logging and costs a second scoring pass).
        """
        y_hat = torch.cat([t[1] for t in self.targets]).cpu().numpy()
        idx = torch.cat([t[3] for t in self.targets]).cpu().numpy()
        val_idcs = pd.Series(idx, index=idx)
        scci = scci_metric(val_idcs, y_hat)
        scci_rev = scci_metric(val_idcs, -y_hat) if with_reverse else None
        return scci, scci_rev

    def configure_optimizers(self):
        """Adam with a per-epoch cosine-annealed learning rate."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self._learning_rate, weight_decay=self._weight_decay)
        scheduler_config = {
            "scheduler": torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=self._scheduler_t_max,
                eta_min=self._scheduler_eta_min
            ),
            "interval": "epoch",
            "frequency": 1,
            "strict": False,
        }

        return {"optimizer": optimizer, "lr_scheduler": scheduler_config}

In [76]:
import pytorch_lightning
import torch
import pytorch_tabular

# Record the library versions this notebook was run with.
for _module in (torch, pytorch_lightning, pytorch_tabular):
    print(_module.__version__)

2.5.0
2.4.0
1.1.1


In [None]:
class HCTLightningWrapper():
    """sklearn-style ``fit``/``predict`` wrapper that trains a ``PairwiseRankingNeuralNetwork`` with Lightning.

    Args:
        categorical_idcs: column positions of ``X`` used as categorical inputs in ``predict``.
        numerical_idcs: column positions of ``X`` used as numerical inputs in ``predict``.
        *model_args, **model_kwargs: forwarded to ``PairwiseRankingNeuralNetwork``.
        epochs: maximum number of training epochs.
        random_state: seed for ``set_seed`` (called at construction) and for the validation split.
        verbose: stored; not used by this class.
        accelerator: Lightning accelerator used for training; also selects the inference device.
    """

    def __init__(self,
                 categorical_idcs : List[int],
                 numerical_idcs : List[int],
                 *model_args,
                 epochs=60,
                 random_state : int = 42,
                 verbose : int = 1,
                 accelerator : str = 'cuda',
                 **model_kwargs) -> None:

        set_seed(random_state)

        self._categorical_idcs = categorical_idcs
        self._numerical_idcs = numerical_idcs
        self._epochs = epochs
        self._random_state = random_state
        self._verbose = verbose
        self._accelerator = accelerator
        self._model_args = model_args
        self._model_kwargs = model_kwargs

    def initialize_weights(self):
        """Re-initialise every linear layer of the fitted model (Xavier-uniform weights, zero biases)."""
        for module in self.model.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def fit(self, X, y, validation_set = None, val_size=0.05):
        """Train a fresh model; returns ``self``.

        Args:
            X, y: training features and ``efs``/``efs_time`` targets.
            validation_set: optional ``(X_valid, y_valid)``; if None, ``val_size`` of the data is
                held out with a seeded random split.
            val_size: hold-out fraction used when ``validation_set`` is None.

        NOTE(review): create_dataloader uses the notebook-level column index lists, not
        ``self._categorical_idcs`` / ``self._numerical_idcs``.
        """
        if validation_set is None:
            # Imported here: train_test_split is not among the notebook's explicit imports.
            from sklearn.model_selection import train_test_split
            X_train, X_valid, y_train, y_valid = train_test_split(X,
                                                                  y,
                                                                  test_size=val_size,
                                                                  random_state=self._random_state)
        else:
            X_train, y_train = X, y
            X_valid, y_valid = validation_set

        dl_train = create_dataloader(X_train, y_train, shuffle=True)
        dl_test = create_dataloader(X_valid, y_valid)

        self.model = PairwiseRankingNeuralNetwork(*self._model_args, **self._model_kwargs)
        checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor="cindex", save_top_k=1)
        trainer = pl.Trainer(
            accelerator=self._accelerator,
            max_epochs=self._epochs,
            log_every_n_steps=5,
            callbacks=[
                checkpoint_callback,
                LearningRateMonitor(logging_interval='epoch'),
                TQDMProgressBar(),
                StochasticWeightAveraging(swa_lrs=1e-5, swa_epoch_start=40, annealing_epochs=15)
            ],
        )
        trainer.fit(self.model, train_dataloaders=dl_train, val_dataloaders=dl_test)

        return self

    def _inference_device(self) -> torch.device:
        """Map the Lightning accelerator name to the torch device used by ``predict``."""
        if self._accelerator == 'auto':
            return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if self._accelerator == 'gpu':
            return torch.device('cuda')
        return torch.device(self._accelerator)

    def predict(self, X):
        """Return risk scores for ``X`` as a CPU NumPy array of shape ``(n, 1)``.

        Runs without autograd and moves the result to the CPU, so it can be consumed by pandas /
        ``scci_metric`` (the same array form the module scores during validation).
        """
        device = self._inference_device()
        features = X.to_numpy()
        model = self.model.to(device)
        model.eval()
        with torch.no_grad():
            y_pred, _, _ = model(
                torch.tensor(features[:, self._categorical_idcs], dtype=torch.long, device=device),
                torch.tensor(features[:, self._numerical_idcs], dtype=torch.float32, device=device),
            )
        return y_pred.cpu().numpy()

In [77]:
# Position of the race column inside x_cat (x_cat columns follow the order of cat_names).
race_idx = pd.Index(cat_names).get_loc('race_group')

In [None]:
# The splitter is defined in this cell so that it does not depend on a later cell (fresh-kernel safe).
# NOTE(review): stratified by race to mirror the manual CV loop below; this assumes
# CrossValidationExecutor stratifies on `groups`. Use KFold(5, shuffle=True, random_state=42)
# instead if it splits on the (multi-output) y.
kf = StratifiedKFold(5, shuffle=True, random_state=42)

hct_nn = HCTLightningWrapper(categorical_idcs,
                             numerical_idcs,
                             race_idx,
                             margin=hparams['margin'],
                             learning_rate=hparams['lr'],
                             weight_decay=hparams['weight_decay'],
                             dropout=hparams['dropout'],
                             race_std_loss_multiplier=0.1)

score_tuple, oofs, model_list, test_preds = CrossValidationExecutor(hct_nn,
                                                                    scci_metric,
                                                                    kf,
                                                                    verbose=2).run(X, y, test_data=test_df, groups=X['race_group'].values)

In [79]:
kf = StratifiedKFold(5, shuffle=True, random_state=42)
set_seed(129)  # Call this before training

# Best validation c-index of each fold (as reported by the checkpoint callback).
fold_scores = []
for fold, (train_idx, test_idx) in enumerate(kf.split(X, X['race_group'])):
    # kf.split yields positional indices, so rows are selected with .iloc (not label-based .loc).
    X_train, y_train = X.iloc[train_idx], y.iloc[train_idx]
    X_test, y_test = X.iloc[test_idx], y.iloc[test_idx]

    dl_train = create_dataloader(X_train, y_train, shuffle=True)
    dl_test = create_dataloader(X_test, y_test)

    model = PairwiseRankingNeuralNetwork(race_idx,
                                         margin=hparams['margin'],
                                         learning_rate=hparams['lr'],
                                         weight_decay=hparams['weight_decay'],
                                         dropout=hparams['dropout'],
                                         race_std_loss_multiplier=0.1)
    checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor="cindex", save_top_k=1)
    trainer = pl.Trainer(
        accelerator='cpu',
        max_epochs=60,
        log_every_n_steps=1,
        callbacks=[
            checkpoint_callback,
            LearningRateMonitor(logging_interval='epoch'),
            TQDMProgressBar(),
            StochasticWeightAveraging(swa_lrs=1e-5, swa_epoch_start=40, annealing_epochs=15)
        ],
    )
    trainer.fit(model, train_dataloaders=dl_train, val_dataloaders=dl_test)

    fold_scores.append(checkpoint_callback.best_model_score)
    print(f"fold {fold}: best cindex = {checkpoint_callback.best_model_score}")

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/yuwei-1/anaconda3/envs/ktools/lib/python3.12/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.

  | Name        | Type    | Params | Mode 
------------------------------------------------
0 | model       | NN      | 140 K  | train
1 | binary_loss | BCELoss | 0      | train
------------------------------------------------
140 K     Trainable params
769       Non-trainable params
140 K     Total params
0.563     Total estimated model params size (MB)
79        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/yuwei-1/anaconda3/envs/ktools/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/Users/yuwei-1/anaconda3/envs/ktools/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined