In [None]:
# default_exp train

In [None]:
# hide
# Reload edited library modules automatically so changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Training
> Functions that wrap complete training runs: each one builds a data module and a model from `hparams`, then fits the model.

In [None]:
# export
from grade_classif.imports import *
from pytorch_lightning.metrics import Accuracy

from grade_classif.data.modules import (
    ImageClassifDataModule,
    MILDataModule,
    NormDataModule,
    RNNAggDataModule,
    FeaturesClassifDataModule
)
from grade_classif.models.plmodules import (
    ImageClassifModel,
    MILModel,
    Normalizer,
    RNNAggregator,
    RNNAttention,
    WSITransformer
)

In [None]:
# export
def train_normalizer(hparams: Namespace) -> Normalizer:
    """Train a `Normalizer` on a `NormDataModule`.

    Args:
        hparams: parsed hyperparameters. They are converted to a dict and
            forwarded as keyword arguments to both the data module and the model.

    Returns:
        The fitted `Normalizer`.
    """
    params = vars(hparams)
    data_module = NormDataModule(**params)
    normalizer = Normalizer(**params)
    # Encoder freezing (`normalizer.freeze_encoder()`) is currently disabled.
    normalizer.fit(data_module)
    return normalizer

Trains a `Normalizer` U-Net with the parameters defined in `hparams`.

In [None]:
# export
def train_classifier(hparams: Namespace) -> ImageClassifModel:
    """Train a binary `ImageClassifModel` separating classes "1" and "3".

    Each image is labelled with the name of its grand-parent folder
    (``path.parts[-3]``). ``monitor="AUC_3"`` is passed to `fit`; presumably
    that metric drives checkpoint selection (confirm in `ImageClassifModel.fit`).

    Args:
        hparams: parsed hyperparameters, forwarded as keyword arguments to
            both the data module and the model.

    Returns:
        The fitted `ImageClassifModel`.
    """
    params = vars(hparams)
    class_names = ["1", "3"]

    def label_from_path(path):
        return path.parts[-3]

    data_module = ImageClassifDataModule(
        classes=class_names, label_func=label_from_path, **params
    )
    classifier = ImageClassifModel(
        classes=class_names, n_classes=len(class_names), **params
    )
    classifier.fit(data_module, monitor="AUC_3")
    return classifier

In [None]:
# export
def train_transformer(hparams: Namespace) -> WSITransformer:
    """Train a `WSITransformer` on a `FeaturesClassifDataModule` for classes "1" vs "3".

    Each sample is identified by its file name (``path.name``) and labelled
    with the name of its parent folder (``path.parts[-2]``). ``monitor="AUC_3"``
    is passed to `fit`.

    Args:
        hparams: parsed hyperparameters, forwarded as keyword arguments to
            both the data module and the model.

    Returns:
        The fitted `WSITransformer`. The annotation previously claimed
        `ImageClassifModel`, which is not what this function builds.
    """
    hparams = vars(hparams)
    classes = ["1", "3"]
    dm = FeaturesClassifDataModule(
        classes=classes,
        get_id=lambda x: x.name,
        label_func=lambda x: x.parts[-2],
        **hparams
    )
    model = WSITransformer(
        classes=classes,
        n_classes=len(classes),
        **hparams
    )
    model.fit(dm, monitor="AUC_3")
    return model

In [None]:
# export
def train_rnn_attention(hparams: Namespace) -> RNNAttention:
    """Train an `RNNAttention` model on images of classes "1" and "3".

    Images come from an `ImageClassifDataModule` and are labelled with the
    name of their grand-parent folder (``path.parts[-3]``).

    Args:
        hparams: parsed hyperparameters, forwarded as keyword arguments to
            both the data module and the model.

    Returns:
        The fitted `RNNAttention` model.
    """
    params = vars(hparams)
    class_names = ["1", "3"]

    def label_from_path(path):
        return path.parts[-3]

    data_module = ImageClassifDataModule(
        classes=class_names, label_func=label_from_path, **params
    )
    attention_model = RNNAttention(
        classes=class_names, n_classes=len(class_names), **params
    )
    attention_model.fit(data_module)
    return attention_model

In [None]:
# export
def train_reargmt(hparams: Namespace) -> ImageClassifModel:
    """Train an `ImageClassifModel` separating "NoReargmt" from "DHL_THL" images.

    Each image is labelled with the name of its grand-parent folder
    (``path.parts[-3]``).

    Args:
        hparams: parsed hyperparameters, forwarded as keyword arguments to
            both the data module and the model.

    Returns:
        The fitted `ImageClassifModel`.
    """
    class_names = ["NoReargmt", "DHL_THL"]
    params = vars(hparams)
    data_module = ImageClassifDataModule(
        classes=class_names,
        label_func=lambda path: path.parts[-3],
        **params
    )
    classifier = ImageClassifModel(
        classes=class_names, n_classes=len(class_names), **params
    )
    classifier.fit(data_module)
    return classifier

In [None]:
# export
def train_FLFH(hparams: Namespace) -> ImageClassifModel:
    """Train an `ImageClassifModel` separating "FH" from "FL" images.

    Each image is labelled with the name of its grand-parent folder
    (``path.parts[-3]``).

    Args:
        hparams: parsed hyperparameters, forwarded as keyword arguments to
            both the data module and the model.

    Returns:
        The fitted `ImageClassifModel`.
    """
    class_names = ["FH", "FL"]
    params = vars(hparams)
    data_module = ImageClassifDataModule(
        classes=class_names,
        label_func=lambda path: path.parts[-3],
        **params
    )
    classifier = ImageClassifModel(
        classes=class_names, n_classes=len(class_names), **params
    )
    classifier.fit(data_module)
    return classifier

Trains an `ImageClassifModel` for grade classification with the parameters defined in `hparams`.

In [None]:
# export
def train_discriminator(hparams: Namespace) -> ImageClassifModel:
    """Train an `ImageClassifModel` whose classes are the codes "04", "05" and "08".

    An image's label is the first code ``c`` (checked in the order "04", "05",
    "08") for which ``f"PACS{c}"`` occurs in its file name.

    NOTE(review): the label function returns None when no code matches. How
    the data module treats a None label is not visible here; confirm it is
    filtered rather than silently mislabelled.

    Args:
        hparams: parsed hyperparameters, forwarded as keyword arguments to
            both the data module and the model.

    Returns:
        The fitted `ImageClassifModel`.
    """
    codes = ["04", "05", "08"]

    def _label_func(x):
        return next((code for code in codes if f"PACS{code}" in x.name), None)

    params = vars(hparams)
    data_module = ImageClassifDataModule(classes=codes, label_func=_label_func, **params)
    classifier = ImageClassifModel(classes=codes, n_classes=len(codes), **params)
    classifier.fit(data_module)
    return classifier

In [None]:
# export
def train_cancer_detector(hparams: Namespace) -> ImageClassifModel:
    """Train an `ImageClassifModel` over "artefact", "cancer" and "non_cancer" images.

    Each image is labelled with the name of its parent folder. Its id is the
    file name with the last two "_"-separated fields dropped, e.g.
    ``"a_b_c_d.png" -> "a_b"``. ``monitor="f_1_cancer"`` is passed to `fit`.

    Args:
        hparams: parsed hyperparameters, forwarded as keyword arguments to
            both the data module and the model.

    Returns:
        The fitted `ImageClassifModel`.
    """
    params = vars(hparams)
    class_names = ["artefact", "cancer", "non_cancer"]

    def label_from_path(path):
        return path.parent.name

    def id_from_path(path):
        fields = path.name.split("_")
        return "_".join(fields[:-2])

    data_module = ImageClassifDataModule(
        classes=class_names,
        label_func=label_from_path,
        get_id=id_from_path,
        **params
    )
    classifier = ImageClassifModel(
        classes=class_names, n_classes=len(class_names), **params
    )
    classifier.fit(data_module, monitor="f_1_cancer")
    return classifier

In [None]:
# export
def train_mil_cancer_detector(hparams: Namespace) -> MILModel:
    """Train a `MILModel` on a `MILDataModule` with classes "None" and "Infilt".

    Training runs with no sanity validation steps and rebuilds the dataloaders
    at every epoch.

    Args:
        hparams: parsed hyperparameters, forwarded as keyword arguments to
            both the data module and the model.

    Returns:
        The fitted `MILModel`.
    """
    params = vars(hparams)
    data_module = MILDataModule(classes=["None", "Infilt"], **params)
    mil_model = MILModel(**params)
    fit_options = dict(num_sanity_val_steps=0, reload_dataloaders_every_epoch=True)
    mil_model.fit(data_module, **fit_options)
    return mil_model

In [None]:
# export
def train_mil_reargmt(hparams: Namespace) -> MILModel:
    """Train a `MILModel` separating "NoReargmt" from "DHL_THL" slides.

    Slides (".mrxs" or ".svs" files) are labelled with the name of their
    grand-parent folder (``path.parts[-3]``).

    Args:
        hparams: parsed hyperparameters, forwarded as keyword arguments to
            both the data module and the model.

    Returns:
        The fitted `MILModel`.
    """
    hparams = vars(hparams)
    dm = MILDataModule(
        classes=["NoReargmt", "DHL_THL"],
        extensions=[".mrxs", ".svs"],
        label_func=lambda x: x.parts[-3],
        **hparams
    )
    model = MILModel(
        **hparams
    )
    # Assumes `fit` forwards these keywords to `pytorch_lightning.Trainer`; the other
    # options here are Trainer arguments too.
    model.fit(
        dm,
        num_sanity_val_steps=0,
        reload_dataloaders_every_epoch=True,
        # Trainer's argument is `check_val_every_n_epoch` (singular); the plural
        # spelling is not a Trainer parameter.
        check_val_every_n_epoch=5,
    )
    return model

In [None]:
# export
def train_rnn_reargmt(hparams: Namespace) -> RNNAggregator:
    """Train an `RNNAggregator` separating "NoReargmt" from "DHL_THL" slides.

    Slides (".mrxs" or ".svs" files) are labelled with the name of their
    grand-parent folder (``path.parts[-3]``). The model is given the metrics
    ``accuracy``, ``precision``, ``recall`` and ``f_1``.

    NOTE(review): those four metric names are not imported in any cell visible
    in this notebook. Presumably they come from `grade_classif.imports`;
    confirm, otherwise this raises NameError.

    Args:
        hparams: parsed hyperparameters, forwarded as keyword arguments to
            both the data module and the model.

    Returns:
        The fitted `RNNAggregator`.
    """
    params = vars(hparams)
    class_names = ["NoReargmt", "DHL_THL"]
    data_module = RNNAggDataModule(
        classes=class_names,
        extensions=[".mrxs", ".svs"],
        label_func=lambda path: path.parts[-3],
        **params
    )
    aggregator = RNNAggregator(
        classes=class_names,
        **params,
        metrics=[accuracy, precision, recall, f_1]
    )
    aggregator.fit(data_module, log_every_n_steps=5)
    return aggregator

In [None]:
from grade_classif.params.parser import hparams

In [None]:
# The concepts CSV is optional (`hparams.concepts` may be None), so only read it when provided.
df = pd.read_csv(hparams.concepts) if hparams.concepts is not None else None

In [None]:
# The concept-classes CSV is optional (`hparams.concept_classes` may be None), so only read it when provided.
classes_df = (
    pd.read_csv(hparams.concept_classes, index_col=0)
    if hparams.concept_classes is not None
    else None
)

In [None]:
from grade_classif.data.dataset import ImageClassifDataset

In [None]:
# Optional concept-based patch filter. It is off by default, so the dataset
# below is built without a filter. Set to True to keep only patches whose
# concept has type "K_inter" (requires both concept CSVs).
USE_CONCEPT_FILTER = False

filt = None
if (
    USE_CONCEPT_FILTER
    and hparams.concepts is not None
    and hparams.concept_classes is not None
):
    conc_classes_df = pd.read_csv(hparams.concept_classes, index_col=0)
    # Concepts of type "K_inter"; a set gives constant-time membership per patch.
    ok = set(conc_classes_df.loc[conc_classes_df["type"] == "K_inter"].index.values)
    conc_df = pd.read_csv(hparams.concepts, index_col="patchId")

    def filt(x):
        """Keep patch `x` only if its concept (looked up by file stem) is of type "K_inter"."""
        return conc_df.loc[x.stem, "concept"] in ok

data = (
    ImageClassifDataset.from_folder(
        Path(hparams.data),
        lambda x: x.parts[-3],
        classes=["1", "3"],
        extensions=[".png"],
        include=["1", "3"],
        open_mode="3G",
        filterfunc=filt,
    )
    .split_by_csv(hparams.data_csv)
    .to_tensor(tfm_y=False)
)

In [None]:
# Labels of the training split of `data`.
labels = data.train.labels

In [None]:
# Number of training images per class. count_nonzero counts the matches
# directly instead of building an index array just to take its length.
n1 = np.count_nonzero(labels == "1")
n3 = np.count_nonzero(labels == "3")

In [None]:
# Class imbalance of the training split: number of class-"3" images per class-"1" image.
n3 / n1

11.067571469823852

In [None]:
# Never hardcode credentials in a notebook. Read the Comet API key from the
# environment, and prompt for it (without echoing) only when it is not set.
import getpass
import os

if not os.environ.get("COMET_API_KEY"):
    os.environ["COMET_API_KEY"] = getpass.getpass("Comet API key: ")

In [None]:
# hide
from nbdev.export import notebook2script

# Regenerate the library modules from all of the project's notebooks (nbdev export).
notebook2script()

Converted 00_core.ipynb.
Converted 01_train.ipynb.
Converted 02_predict.ipynb.
Converted 10_data.read.ipynb.
Converted 11_data.loaders.ipynb.
Converted 12_data.dataset.ipynb.
Converted 13_data.utils.ipynb.
Converted 14_data.transforms.ipynb.
Converted 15_data.color.ipynb.
Converted 16_data.modules.ipynb.
Converted 20_models.plmodules.ipynb.
Converted 21_models.modules.ipynb.
Converted 22_models.utils.ipynb.
Converted 23_models.hooks.ipynb.
Converted 24_models.metrics.ipynb.
Converted 25_models.losses.ipynb.
Converted 80_params.defaults.ipynb.
Converted 81_params.parser.ipynb.
Converted 99_index.ipynb.
