# Import Dependencies

In [50]:
import os

import pandas as pd
import torchvision
from torchvision.transforms import (
    CenterCrop,
    Compose,
    ToTensor,
)
from PIL import Image
from torch.utils.data import DataLoader, Dataset

# Global Variables

In [28]:
# Relative paths: resolved against the notebook's current working directory.
DATA_PATH = "./../sample_data"
METADATA_DIR = DATA_PATH + "/metadata"  # ingredient and dish metadata CSVs
IMAGERY_DIR = DATA_PATH + "/imagery/realsense_overhead"  # sub-directories holding rgb.png

# Helper Functions for Data

In [None]:
def read_csv_variable_cols(filepath: str, delimiter: str = ",") -> pd.DataFrame:
    """Read a headerless CSV whose rows have differing numbers of fields.

    Without explicit column names, ``pd.read_csv`` infers the column count
    from the first row and raises a tokenizing error when a later row is
    wider. This helper first scans the file for the widest row, then passes
    integer column names ``0 .. width - 1`` explicitly so narrower rows are
    padded with NaN instead.

    Adapted from https://stackoverflow.com/a/57824142

    Args:
        filepath: Path to the CSV file. It is decoded as UTF-8, the same
            encoding ``pd.read_csv`` uses by default for the actual parse.
        delimiter: Single-character field separator (default ``","``).

    Returns:
        A DataFrame with integer column labels ``0 .. width - 1``. A file with
        no fields at all (empty or blank lines only) yields an empty DataFrame.
    """
    import csv

    # Count fields with the csv module rather than str.split so a quoted field
    # that contains the delimiter (e.g. "salt, sea") counts once, just as
    # pandas will parse it. Streaming the reader avoids loading every line
    # into memory only to measure row widths.
    with open(filepath, newline="", encoding="utf-8") as csv_file:
        rows = csv.reader(csv_file, delimiter=delimiter)
        width = max((len(row) for row in rows), default=0)

    if width == 0:
        # Nothing to parse; previously max() on an empty list raised an
        # unhelpful "empty sequence" ValueError here.
        return pd.DataFrame()

    return pd.read_csv(
        filepath,
        header=None,
        delimiter=delimiter,
        names=list(range(width)),
        low_memory=False,
    )

# Ingredient and Dish Metadata (Ground Truth)

In [None]:
# Ingredient lookup table. The first line (presumably the file's own header)
# is skipped and replaced by the short column names below.
ingredient_metadata = pd.read_csv(
    f"{METADATA_DIR}/ingredients_metadata.csv",
    skiprows=1,
    names=["name", "id", "cal", "fat", "carb", "protein"],
)
print(ingredient_metadata)

In [None]:
# Rows in the dish metadata do not share a fixed column count, so they go
# through the variable-width reader instead of a plain pd.read_csv.
# dish_metadata_cafe2.csv can presumably be loaded the same way.
dish_metadata_1 = read_csv_variable_cols(
    f"{METADATA_DIR}/dish_metadata_cafe1.csv"
)
print(dish_metadata_1)

# Datasets and DataLoaders

In [52]:
class RGBDataset(Dataset):
    """Dataset of RGB images, one ``rgb.png`` per sub-directory of ``data_dir``.

    From Section 4.2: "The input resolution to the network is a 256x256
    image, where images were downsized and center cropped in order to retain
    the most salient dish region."

    Our baseline should only need the RGB images (per Section 4.2).

    Args:
        data_dir: Directory whose immediate sub-directories each contain an
            ``rgb.png`` file. Entries are indexed in sorted name order.
        transforms: Callable applied to the loaded PIL image. When omitted, a
            256x256 ``CenterCrop`` followed by ``ToTensor()`` is used.

    NOTE(review): the default transform only center-crops; it does not
    downsize first as the quoted description says. Confirm this is intended.
    """

    # TODO: Also return the metadata here?
    def __init__(self, data_dir, transforms=None):
        self.data_dir = data_dir
        # Build the default pipeline per instance. As a default argument it
        # would be evaluated once at class definition and shared by every
        # instance.
        if transforms is None:
            transforms = Compose([CenterCrop((256, 256)), ToTensor()])
        self.transforms = transforms

        # Sorted for a deterministic index -> directory mapping. Only
        # directories are kept, so stray files (e.g. .DS_Store) don't become
        # samples whose "<file>/rgb.png" path can never be opened.
        self.img_paths = [
            os.path.join(data_dir, entry)
            for entry in sorted(os.listdir(data_dir))
            if os.path.isdir(os.path.join(data_dir, entry))
        ]

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Return ``self.transforms`` applied to the ``idx``-th ``rgb.png``.

        Indexing past the end raises IndexError (from the list lookup), which
        is also what stops a plain ``for`` loop over the dataset.
        """
        rgb_path = os.path.join(self.img_paths[idx], "rgb.png")
        # The context manager releases the file handle. convert("RGB") forces
        # three channels even if a PNG is stored as RGBA, grayscale or palette.
        with Image.open(rgb_path) as img:
            return self.transforms(img.convert("RGB"))

In [53]:
rgb_dataset = RGBDataset(IMAGERY_DIR)
# Sanity check: inspect only the first sample's tensor shape (if any).
if len(rgb_dataset) > 0:
    data = rgb_dataset[0]
    print(data.shape)
    # torch.Size([3, 256, 256])

torch.Size([3, 256, 256])
