<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->

In [None]:
#| code-fold: show
#| code-summary: "Exported source"
import warnings

import pandas as pd
from colorama import Fore, Style
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

from advanced_autoencoders.config import ConfigMaeLarge, ConfigVQVAE
from advanced_autoencoders.dataset import MyImageDataset
from advanced_autoencoders.trainers import (
    PlModelMAE,
    PlModelVQVAE,
    get_model_checkpoint_callback,
    get_trainer_mae,
    get_trainer_vq,
)
from advanced_autoencoders.utils import (
    get_train_transforms,
    make_images_dataframe,
    seed_everything,
)

# Short aliases for colorama foreground colours (presumably used elsewhere in
# the package to colour console messages — none are used in this module).
r_, b_, c_ = Fore.RED, Fore.BLUE, Fore.CYAN
g_, y_, m_ = Fore.GREEN, Fore.YELLOW, Fore.MAGENTA
# Escape sequence that resets colour/style back to the terminal default.
sr_ = Style.RESET_ALL
# NOTE(review): this silences *every* warning process-wide as an import side
# effect, not just warnings raised by this module.
warnings.filterwarnings("ignore")

In [1]:
#| echo: false
#| output: asis
show_doc(make_train_test_split)

---

[source](https://github.com/rohitMalhotra07/advanced_autoencoders/blob/main/advanced_autoencoders/training_pipelines.py#L40){target="_blank" style="float:right; font-size:smaller"}

### make_train_test_split

>      make_train_test_split (df:pandas.core.frame.DataFrame, cnfg)

In [None]:
#| code-fold: show
#| code-summary: "Exported source"
def make_train_test_split(df: pd.DataFrame, cnfg, random_state: int = 42):
    """Split *df* into train and validation partitions.

    Args:
        df: Dataframe of samples to split (one row per sample).
        cnfg: Config object. Only ``cnfg.size_val`` is read; it is forwarded to
            scikit-learn's ``train_test_split`` as ``test_size``, so it may be a
            float fraction or an int count.
        random_state: Seed controlling the shuffle before splitting. Defaults
            to 42, the value previously hard-coded, so existing callers get
            exactly the same split. Pass e.g. ``cnfg.seed`` to tie the split to
            the run's seed.

    Returns:
        Tuple ``(df_train, df_val)``.
    """
    df_train, df_val = train_test_split(
        df, test_size=cnfg.size_val, random_state=random_state
    )
    return df_train, df_val

In [2]:
#| echo: false
#| output: asis
show_doc(get_reconstruction_sample)

---

[source](https://github.com/rohitMalhotra07/advanced_autoencoders/blob/main/advanced_autoencoders/training_pipelines.py#L46){target="_blank" style="float:right; font-size:smaller"}

### get_reconstruction_sample

>      get_reconstruction_sample (cnfg, df)

In [None]:
#| code-fold: show
#| code-summary: "Exported source"
def get_reconstruction_sample(cnfg, df, batch_size: int = 4):
    """Return the first batch of *df*, moved to the first configured GPU.

    The batch is built in dataframe order (``shuffle=False``), so the same rows
    are returned on every call for the same *df*.

    Args:
        cnfg: Config object. Passed to ``get_train_transforms``; ``cnfg.gpus[0]``
            selects the target CUDA device.
        df: Dataframe consumed by ``MyImageDataset``.
        batch_size: Number of samples in the returned batch (default 4, the
            previously hard-coded value).

    Returns:
        The first collated batch, after ``.cuda(...)`` has been called on it.

    Raises:
        ValueError: If *df* yields no batches (e.g. it is empty). Previously
            this surfaced as an opaque ``UnboundLocalError``.
    """
    # NOTE(review): training transforms are applied to this validation sample;
    # if they are random augmentations the preview batch will vary between
    # runs — confirm this is intended.
    val_reconstruction_dataset = MyImageDataset(
        df, augmentations=get_train_transforms(cnfg)
    )
    val_dl_rec = DataLoader(
        val_reconstruction_dataset, batch_size=batch_size, shuffle=False
    )
    # Fetch exactly one batch; an empty loader raises StopIteration, which we
    # translate into a clear error instead of leaving the name unbound.
    try:
        test_sample = next(iter(val_dl_rec))
    except StopIteration:
        raise ValueError(
            "get_reconstruction_sample: dataframe produced no batches; "
            "cannot build a reconstruction sample."
        ) from None
    return test_sample.cuda(f"cuda:{cnfg.gpus[0]}")

In [3]:
#| echo: false
#| output: asis
show_doc(train_pipeline_mae)

---

[source](https://github.com/rohitMalhotra07/advanced_autoencoders/blob/main/advanced_autoencoders/training_pipelines.py#L58){target="_blank" style="float:right; font-size:smaller"}

### train_pipeline_mae

>      train_pipeline_mae ()

In [None]:
#| code-fold: show
#| code-summary: "Exported source"
def train_pipeline_mae():
    """Run the full MAE training pipeline with ``ConfigMaeLarge`` settings.

    Seeds the RNGs, builds the image dataframe, splits it into train/val,
    takes one validation batch for the model, then calls ``trainer.fit``.
    Checkpointing is configured via ``get_model_checkpoint_callback``.
    Returns ``None``.
    """
    cfg = ConfigMaeLarge()
    seed_everything(seed=cfg.seed)

    # Index the images and partition them into train / validation sets.
    images_df = make_images_dataframe(cfg)
    print(f"Making Image DF from DIR DONE! TOTAL IMAGES:{images_df.shape[0]}")
    train_df, val_df = make_train_test_split(images_df, cfg)
    n_train, n_val = train_df.shape[0], val_df.shape[0]
    print(f"TRAIN VAL SPLIT DONE! TOTAL IMAGES TRAIN:{n_train}, VAL:{n_val}")

    # A fixed validation batch handed to the Lightning module (presumably for
    # logging reconstructions — confirm in PlModelMAE).
    recon_batch = get_reconstruction_sample(cfg, val_df)
    trainer = get_trainer_mae(cfg, get_model_checkpoint_callback(cfg))
    model = PlModelMAE(cfg, train_df, val_df, recon_batch)

    trainer.fit(model)

In [4]:
#| echo: false
#| output: asis
show_doc(train_pipeline_vqvae)

---

[source](https://github.com/rohitMalhotra07/advanced_autoencoders/blob/main/advanced_autoencoders/training_pipelines.py#L75){target="_blank" style="float:right; font-size:smaller"}

### train_pipeline_vqvae

>      train_pipeline_vqvae ()

In [None]:
#| code-fold: show
#| code-summary: "Exported source"
def train_pipeline_vqvae():
    """Run the full VQ-VAE training pipeline with ``ConfigVQVAE`` settings.

    Seeds the RNGs, builds the image dataframe, splits it into train/val,
    takes one validation batch for the model, then calls ``trainer.fit``.
    Checkpointing is configured via ``get_model_checkpoint_callback``.
    Returns ``None``.
    """
    cfg = ConfigVQVAE()
    seed_everything(seed=cfg.seed)

    # Index the images and partition them into train / validation sets.
    images_df = make_images_dataframe(cfg)
    print(f"Making Image DF from DIR DONE! TOTAL IMAGES:{images_df.shape[0]}")
    train_df, val_df = make_train_test_split(images_df, cfg)
    n_train, n_val = train_df.shape[0], val_df.shape[0]
    print(f"TRAIN VAL SPLIT DONE! TOTAL IMAGES TRAIN:{n_train}, VAL:{n_val}")

    # A fixed validation batch handed to the Lightning module (presumably for
    # logging reconstructions — confirm in PlModelVQVAE).
    recon_batch = get_reconstruction_sample(cfg, val_df)
    trainer = get_trainer_vq(cfg, get_model_checkpoint_callback(cfg))
    model = PlModelVQVAE(cfg, train_df, val_df, recon_batch)

    trainer.fit(model)