<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->

In [None]:
#| code-fold: show
#| code-summary: "Exported source"
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from advanced_autoencoders.config import ConfigMaeLarge, ConfigVQVAE
from advanced_autoencoders.dataset import MyImageDataset
from advanced_autoencoders.models import (
    get_embeddings_mae,
    get_embeddings_vae,
    get_mae_model,
    get_vae_model,
)
from advanced_autoencoders.utils import (
    get_test_transforms,
    make_images_dataframe,
    seed_everything,
)

In [1]:
#| echo: false
#| output: asis
# Renders the API doc for `get_test_data_loader` shown below; `show_doc` is not
# imported in this view (presumably nbdev.showdoc).
show_doc(get_test_data_loader)

---

[source](https://github.com/rohitMalhotra07/advanced_autoencoders/blob/main/advanced_autoencoders/inference.py#L29){target="_blank" style="float:right; font-size:smaller"}

### get_test_data_loader

>      get_test_data_loader (df, cnfg)

In [None]:
#| code-fold: show
#| code-summary: "Exported source"
def get_test_data_loader(df, cnfg, batch_size=32):
    """Build a non-shuffled DataLoader over ``df`` for inference.

    Parameters
    ----------
    df : pandas.DataFrame
        Image table handed to ``MyImageDataset``.
    cnfg : config object
        Passed to ``get_test_transforms`` to build the inference-time transforms.
    batch_size : int, default 32
        Number of samples per batch.

    Returns
    -------
    torch.utils.data.DataLoader
        Loader with ``shuffle=False``, so samples are yielded in dataset index
        order and embeddings can be matched back to their source rows.
    """
    dataset = MyImageDataset(df, augmentations=get_test_transforms(cnfg))
    return DataLoader(dataset, batch_size=batch_size, shuffle=False)

In [2]:
#| echo: false
#| output: asis
# Renders the API doc for `load_model` shown below; `show_doc` is not imported
# in this view (presumably nbdev.showdoc).
show_doc(load_model)

---

[source](https://github.com/rohitMalhotra07/advanced_autoencoders/blob/main/advanced_autoencoders/inference.py#L36){target="_blank" style="float:right; font-size:smaller"}

### load_model

>      load_model (cnfg, model)

In [None]:
#| code-fold: show
#| code-summary: "Exported source"
def load_model(cnfg, model, device="cuda"):
    """Load trained weights into ``model`` and prepare it for inference.

    The checkpoint path is ``f"{cnfg.MODELS_DIR}{cnfg.model_name}.bin"``. This is
    plain string concatenation, so ``MODELS_DIR`` must already end with a path
    separator.

    Parameters
    ----------
    cnfg : config object
        Must expose ``MODELS_DIR`` and ``model_name``.
    model : torch.nn.Module
        Architecture whose ``state_dict`` matches the checkpoint.
    device : str or torch.device, default "cuda"
        Device the weights are loaded onto and the model is moved to.

    Returns
    -------
    torch.nn.Module
        The same ``model`` object, on ``device`` and in eval mode.
    """
    checkpoint_path = f"{cnfg.MODELS_DIR}{cnfg.model_name}.bin"
    # map_location puts the tensors straight on the target device. Without it,
    # a GPU-saved checkpoint cannot be loaded on a CPU-only machine.
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    # Disable dropout / use running batch-norm statistics for inference.
    model.eval()

    return model

In [3]:
#| echo: false
#| output: asis
# Renders the API doc for `generate_embeddings_df` shown below; `show_doc` is not
# imported in this view (presumably nbdev.showdoc).
show_doc(generate_embeddings_df)

---

[source](https://github.com/rohitMalhotra07/advanced_autoencoders/blob/main/advanced_autoencoders/inference.py#L44){target="_blank" style="float:right; font-size:smaller"}

### generate_embeddings_df

>      generate_embeddings_df (cnfg, model, dl, embd_name, get_embd_fnc)

In [None]:
#| code-fold: show
#| code-summary: "Exported source"
def generate_embeddings_df(cnfg, model, dl, embd_name, get_embd_fnc):
    """Run ``model`` over every batch of ``dl`` and collect the embeddings.

    Parameters
    ----------
    cnfg : config object
        Must expose ``EMBEDDING_COL_NAMES``, one column name per embedding value.
    model : torch.nn.Module
        Passed unchanged to ``get_embd_fnc``.
    dl : iterable
        Batches of input samples; each batch is moved to the GPU with ``.cuda()``.
    embd_name : str
        Name of an extra column holding each row's full embedding vector as a list.
    get_embd_fnc : callable
        ``get_embd_fnc(model, batch)``; expected to return one embedding row per
        sample, as a NumPy array or a torch tensor (assumption, verify against
        ``get_embeddings_mae`` / ``get_embeddings_vae``).

    Returns
    -------
    pandas.DataFrame
        One row per sample. It has one column per name in
        ``cnfg.EMBEDDING_COL_NAMES`` plus the ``embd_name`` column.
    """
    col_names = list(cnfg.EMBEDDING_COL_NAMES)
    batches = []
    with torch.no_grad():
        for samples in tqdm(dl):
            embeddings = get_embd_fnc(model, samples.cuda())
            # Tensors (possibly on the GPU) must be copied to host memory before
            # NumPy can stack them.
            if isinstance(embeddings, torch.Tensor):
                embeddings = embeddings.detach().cpu().numpy()
            batches.append(np.asarray(embeddings))

    if batches:
        embeddings_array = np.concatenate(batches, axis=0)
    else:
        # Empty loader: keep the 2-D shape so the column list still matches.
        embeddings_array = np.empty((0, len(col_names)))

    final_df = pd.DataFrame(data=embeddings_array, columns=col_names)
    # Full embedding vector per row, stored as a Python list in one column.
    final_df[embd_name] = pd.Series(
        embeddings_array.tolist(), index=final_df.index, dtype=object
    )

    return final_df

In [4]:
#| echo: false
#| output: asis
# Renders the API doc for `save_data` shown below; `show_doc` is not imported
# in this view (presumably nbdev.showdoc).
show_doc(save_data)

---

[source](https://github.com/rohitMalhotra07/advanced_autoencoders/blob/main/advanced_autoencoders/inference.py#L61){target="_blank" style="float:right; font-size:smaller"}

### save_data

>      save_data (df, cnfg)

In [None]:
#| code-fold: show
#| code-summary: "Exported source"
def save_data(df, cnfg):
    """Persist ``df`` as CSV at ``cnfg.EMBEDDING_FILE_PATH``, omitting the index."""
    output_path = cnfg.EMBEDDING_FILE_PATH
    df.to_csv(path_or_buf=output_path, index=False)

In [5]:
#| echo: false
#| output: asis
# Renders the API doc for `generate_embedding_mae_pipeline` shown below;
# `show_doc` is not imported in this view (presumably nbdev.showdoc).
show_doc(generate_embedding_mae_pipeline)

---

[source](https://github.com/rohitMalhotra07/advanced_autoencoders/blob/main/advanced_autoencoders/inference.py#L65){target="_blank" style="float:right; font-size:smaller"}

### generate_embedding_mae_pipeline

>      generate_embedding_mae_pipeline ()

In [None]:
#| code-fold: show
#| code-summary: "Exported source"
def generate_embedding_mae_pipeline():
    """Embed every image with the trained MAE model and save the result.

    Steps:
    1. Build ``ConfigMaeLarge`` and seed the RNGs.
    2. List the images and build the inference loader.
    3. Load the trained MAE weights.
    4. Compute the embeddings and write them via ``save_data``.
    """
    CONFIG = ConfigMaeLarge()
    seed_everything(seed=CONFIG.seed)
    df_all = make_images_dataframe(CONFIG)
    # Sanity check: table shape vs. number of distinct image names.
    print(df_all.shape, df_all.image_name.unique().size)
    dl = get_test_data_loader(df_all, CONFIG)
    # The model must be built from this function's CONFIG; no `cnfg` exists here.
    model = get_mae_model(CONFIG)
    model = load_model(CONFIG, model)

    final_df = generate_embeddings_df(CONFIG, model, dl, "mae_emb", get_embeddings_mae)

    save_data(final_df, CONFIG)

In [6]:
#| echo: false
#| output: asis
# Renders the API doc for `generate_embedding_vqvae_pipeline` shown below;
# `show_doc` is not imported in this view (presumably nbdev.showdoc).
show_doc(generate_embedding_vqvae_pipeline)

---

[source](https://github.com/rohitMalhotra07/advanced_autoencoders/blob/main/advanced_autoencoders/inference.py#L80){target="_blank" style="float:right; font-size:smaller"}

### generate_embedding_vqvae_pipeline

>      generate_embedding_vqvae_pipeline ()

In [None]:
#| code-fold: show
#| code-summary: "Exported source"
def generate_embedding_vqvae_pipeline():
    """Embed every image with the trained VQ-VAE model and save the result.

    Steps:
    1. Build ``ConfigVQVAE`` and seed the RNGs.
    2. List the images and build the inference loader.
    3. Load the trained VAE weights.
    4. Compute the embeddings and write them via ``save_data``.
    """
    CONFIG = ConfigVQVAE()
    seed_everything(seed=CONFIG.seed)
    df_all = make_images_dataframe(CONFIG)
    # Sanity check: table shape vs. number of distinct image names.
    print(df_all.shape, df_all.image_name.unique().size)
    dl = get_test_data_loader(df_all, CONFIG)
    # The model must be built from this function's CONFIG; no `cnfg` exists here.
    model = get_vae_model(CONFIG)
    model = load_model(CONFIG, model)

    final_df = generate_embeddings_df(CONFIG, model, dl, "vae_emb", get_embeddings_vae)

    save_data(final_df, CONFIG)