## **Test Training Step**
This notebook tests the `training_step` method of the `ModelTrainer` class.

In [None]:
import numpy as np
import os
import torch
import matplotlib.pyplot as plt
import itertools
from torchvision import transforms

import reconstruction_deep_network
from reconstruction_deep_network.data_loader.custom_loader import CustomDataLoader
from reconstruction_deep_network.trainer.trainer import ModelTrainer

In [None]:
# Resolve data locations relative to the installed package: root_dir is the
# parent directory of the `reconstruction_deep_network` package directory.
module_dir = reconstruction_deep_network.__path__[0]
root_dir = os.path.dirname(module_dir)
data_dir = os.path.join(root_dir, "data", "v1")
embedding_dir = os.path.join(data_dir, "text_embeddings")
img_encoding_dir = os.path.join(data_dir, "image_latents")
# Create the image-latent directory if missing; exist_ok=True avoids the
# check-then-create race of a separate isdir() test.
os.makedirs(img_encoding_dir, exist_ok=True)

In [None]:
def load_text_encoding(scan_id: str, img_name: str, root_dir=None):
    """Load and stack the text embeddings stored for one image.

    Reads ``<root_dir>/<scan_id>/<img_name>.npz``. Each entry of the archive is
    expected to hold a pickled object wrapping a dict with an ``"embeddings_1"``
    array. Those arrays are converted to tensors and stacked along a new
    dimension 1, in archive order.

    Args:
        scan_id: Sub-directory containing the scan's embedding files.
        img_name: File stem of the ``.npz`` archive.
        root_dir: Directory to read from. Defaults to the module-level
            ``embedding_dir`` when ``None``.

    Returns:
        A tensor with one slice per archive entry along dim 1.

    Raises:
        ``torch.stack`` raises if the archive contains no entries.
    """
    base_dir = embedding_dir if root_dir is None else root_dir
    file_name = os.path.join(base_dir, scan_id, f"{img_name}.npz")
    # allow_pickle is required because each entry stores a Python dict.
    # NOTE(review): unpickling can execute arbitrary code -- only load trusted files.
    # The context manager closes the archive's file handle deterministically.
    with np.load(file_name, allow_pickle=True) as archive:
        embedding_vec = [
            torch.from_numpy(archive[key].item()["embeddings_1"])
            for key in archive.files
        ]

    return torch.stack(embedding_vec, dim=1)

def load_img_encoding(scan_id: str, img_name: str, root_dir=None):
    """Load the cached image latent for one image.

    Reads the ``"latent"`` entry of ``<root_dir>/<scan_id>/<img_name>.npz``.

    Args:
        scan_id: Sub-directory containing the scan's latent files.
        img_name: File stem of the ``.npz`` archive.
        root_dir: Directory to read from. Defaults to the module-level
            ``img_encoding_dir`` when ``None``.

    Returns:
        The ``"latent"`` array as a NumPy array (not a torch tensor).
    """
    base_dir = img_encoding_dir if root_dir is None else root_dir
    file_name = os.path.join(base_dir, scan_id, f"{img_name}.npz")
    # Indexing an NpzFile reads the array into memory, so the archive can be
    # closed (instead of leaking its file handle) before returning.
    with np.load(file_name) as archive:
        return archive["latent"]

In [None]:
# scan_id = "17DRP5sb8fy"
# img_name = "00ebbf3782c64d74aaf7dd39cd561175"
# embedding_vec = load_text_encoding(scan_id, img_name)

In [None]:
# scan_id = "17DRP5sb8fy"
# img_name = "00ebbf3782c64d74aaf7dd39cd561175"
# latents = load_img_encoding(scan_id, img_name)

In [None]:
# Trainer under test, built with its default constructor arguments.
# Its validation_step is called below; the training_step call is commented out.
model_trainer = ModelTrainer()

In [None]:
# Dataset shared by the training and validation loaders below.
dataset = CustomDataLoader()

# Training loader: fixed order (no shuffling), loads in the main process, and
# drops any trailing partial batch so every batch has exactly 4 samples.
train_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=4,
    shuffle=False,
    num_workers=0,
    drop_last=True,
)

In [None]:
# for batch in train_loader:
#     loss = model_trainer.training_step(batch)
#     print(loss)
#     break


## **Test Inference Pipeline**

In [None]:
# Validation loader: one sample per batch, deterministic order, main-process
# loading. (With batch_size=1, drop_last never drops anything.)
val_loader = torch.utils.data.DataLoader(
    dataset, batch_size=1, shuffle=False, num_workers=0, drop_last=True
)

In [None]:
# Smoke test: run validation_step on the first batch only. `batch` and `loss`
# stay bound as notebook globals for inspection in later cells.
for batch in itertools.islice(val_loader, 1):
    loss = model_trainer.validation_step(batch, 0)
    print(loss)

In [None]:
# images = loss.squeeze()
# combs = list(itertools.product(range(2), range(4)))

# fig, axs = plt.subplots(nrows=2, ncols=4, figsize=(12, 5))
# for itr, comb in enumerate(combs):
#     axs[comb].imshow(images[itr])

# plt.show()


In [None]:
# Resize/center-crop to 299x299 and normalize with the commonly used ImageNet
# mean/std (presumably Inception-style input preprocessing -- confirm intent).
# antialias=True is set explicitly so tensor resizing does not depend on the
# installed torchvision version's default, which changed for tensors in 0.17.
preprocess = transforms.Compose(
    [
        transforms.Resize(299, antialias=True),
        transforms.CenterCrop(299),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# NOTE: torch.randn values are not in [0, 1], so this only checks output
# shapes, not realistic normalized statistics.
x_inp = preprocess(torch.randn(8, 3, 512, 512))

In [None]:
# Display the preprocessed batch shape. Presumably (8, 3, 299, 299) for this
# square 512x512 input after Resize(299) + CenterCrop(299) -- verify on output.
x_inp.shape