In [None]:
# default_exp train

In [None]:
# hide
# Reload edited library modules automatically so changes to grade_classif are picked up live.
%load_ext autoreload
%autoreload 2

# Training
> Functions that wrap the whole training process of each model.

In [None]:
# export
from grade_classif.imports import *
from grade_classif.models.metrics import accuracy, f_1, precision, recall
from grade_classif.models.plmodules import (
    DiscrimDataModule,
    GradeClassifDataModule,
    GradesClassifModel,
    Normalizer,
    NormalizerAN,
    NormDataModule,
    PACSDiscriminator,
)

In [None]:
# export
def train_normalizer(hparams: Namespace) -> Union[Normalizer, NormalizerAN]:
    """Build the normalization data module and fit a normalizer on it.

    Args:
        hparams: parsed hyper-parameters. Every attribute is forwarded as a
            keyword argument to both `NormDataModule` and the model constructor.
            The `adversarial` attribute selects the model class.

    Returns:
        The fitted model: a `NormalizerAN` when `hparams.adversarial` is truthy,
        otherwise a plain `Normalizer`.
    """
    kwargs = vars(hparams)
    dm = NormDataModule(**kwargs)
    # `vars` returns a plain dict, so the flag must be read with item access;
    # attribute access on the dict would raise AttributeError.
    model_cls = NormalizerAN if kwargs.get("adversarial", False) else Normalizer
    model = model_cls(**kwargs)
    model.fit(dm)
    return model

Trains a `Normalizer` U-Net (or a `NormalizerAN` when `hparams.adversarial` is set) with the parameters defined in `hparams`, and returns the fitted model.

In [None]:
# export
def train_classifier(hparams: Namespace) -> GradesClassifModel:
    """Build the grade-classification data module and fit a `GradesClassifModel`.

    Args:
        hparams: parsed hyper-parameters. Every attribute is forwarded as a
            keyword argument to both `GradeClassifDataModule` and the model.

    Returns:
        The fitted `GradesClassifModel`. It tracks accuracy, plus precision,
        recall and F1 each evaluated with `cat=0` and with `cat=1`.
    """
    kwargs = vars(hparams)
    dm = GradeClassifDataModule(**kwargs)
    # Order: precision, recall, f_1 for cat=0, then the same three for cat=1.
    per_cat_metrics = [
        partial(metric, cat=cat)
        for cat in range(2)
        for metric in (precision, recall, f_1)
    ]
    model = GradesClassifModel(**kwargs, metrics=[accuracy] + per_cat_metrics)
    model.fit(dm)
    return model

Trains a `GradesClassifModel` for grade classification with the parameters defined in `hparams`, tracking accuracy plus per-class precision, recall and F1 score, and returns the fitted model.

In [None]:
# export
def train_discriminator(hparams: Namespace) -> PACSDiscriminator:
    """Build the discriminator data module and fit a `PACSDiscriminator`.

    Args:
        hparams: parsed hyper-parameters. Every attribute is forwarded as a
            keyword argument to both `DiscrimDataModule` and the model.

    Returns:
        The fitted `PACSDiscriminator`. It tracks accuracy, precision, recall
        and F1. Unlike `train_classifier`, no `cat` argument is bound here.
    """
    kwargs = vars(hparams)
    dm = DiscrimDataModule(**kwargs)
    model = PACSDiscriminator(**kwargs, metrics=[accuracy, precision, recall, f_1])
    model.fit(dm)
    return model

In [None]:
from grade_classif.params.parser import hparams

In [None]:
# Per-patch concept table; `hparams.concepts` is optional, so only read it when set.
df = pd.read_csv(hparams.concepts) if hparams.concepts is not None else None

In [None]:
# Concept-class table; `hparams.concept_classes` is optional, so only read it when set.
classes_df = (
    pd.read_csv(hparams.concept_classes, index_col=0)
    if hparams.concept_classes is not None
    else None
)

In [None]:
from grade_classif.data.dataset import ImageClassifDataset

In [None]:
# Set to True to keep only patches whose concept is of type "K_inter".
# This needs both `hparams.concepts` and `hparams.concept_classes`.
# With False (or a missing table), no filter is passed to the dataset.
USE_CONCEPT_FILTER = False

filt = None
if (
    USE_CONCEPT_FILTER
    and hparams.concepts is not None
    and hparams.concept_classes is not None
):
    conc_classes_df = pd.read_csv(hparams.concept_classes, index_col=0)
    # A set gives O(1) membership tests; `in` on an ndarray scans every element.
    ok = set(conc_classes_df.loc[conc_classes_df["type"] == "K_inter"].index.values)
    conc_df = pd.read_csv(hparams.concepts, index_col="patchId")

    def filt(x):
        """Return True if the patch at path `x` has a "K_inter" concept."""
        return conc_df.loc[x.stem, "concept"] in ok

# The second argument presumably maps a file path to its label. Here that is
# the name of the directory two levels above the file (x.parts[-3]).
data = (
    ImageClassifDataset.from_folder(
        Path(hparams.data),
        lambda x: x.parts[-3],
        classes=["1", "3"],
        extensions=[".png"],
        include=["1", "3"],
        open_mode="3G",
        filterfunc=filt,
    )
    .split_by_csv(hparams.data_csv)
    .to_tensor(tfm_y=False)
)

In [None]:
# Labels of the training split, used below to measure class balance.
labels = data.train.labels

In [None]:
# Number of training samples per label. `count_nonzero` counts without
# allocating an index array; `asarray` also makes a plain list compare element-wise.
n1 = int(np.count_nonzero(np.asarray(labels) == "1"))
n3 = int(np.count_nonzero(np.asarray(labels) == "3"))

In [None]:
assert n1 > 0, 'no training samples with label "1"; cannot compute the class ratio'
# Class imbalance: number of label-"3" samples per label-"1" sample in the training split.
n3 / n1

11.067571469823852

In [None]:
import os
from getpass import getpass

# Never hardcode credentials in a notebook. Use the key already exported in the
# environment, or prompt for it without echoing. Any key that was ever
# committed to this notebook must be treated as leaked and rotated.
if not os.environ.get("COMET_API_KEY"):
    os.environ["COMET_API_KEY"] = getpass("COMET_API_KEY: ")

In [None]:
# hide
from nbdev.export import notebook2script

notebook2script()

Converted 00_core.ipynb.
Converted 01_train.ipynb.
Converted 02_predict.ipynb.
Converted 10_data.read.ipynb.
Converted 11_data.loaders.ipynb.
Converted 12_data.dataset.ipynb.
Converted 13_data.utils.ipynb.
Converted 14_data.transforms.ipynb.
Converted 15_data.color.ipynb.
Converted 20_models.plmodules.ipynb.
Converted 21_models.modules.ipynb.
Converted 22_models.utils.ipynb.
Converted 23_models.hooks.ipynb.
Converted 24_models.metrics.ipynb.
Converted 25_models.losses.ipynb.
Converted 80_params.defaults.ipynb.
Converted 81_params.parser.ipynb.
Converted 99_index.ipynb.
