In [10]:
from omegaconf import DictConfig
from lib.utils import load_config
import numpy as np
import pandas as pd
from pathlib import Path
import cv2
from pytorch_metric_learning import samplers
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch
from omegaconf import DictConfig
from typing import TypedDict
from lib.utils import load_config 
from collections import defaultdict



class MetaInfo:
    def __init__(self, cfg: DictConfig):
        dtype = {"image_id": str, "sketch_id": str}
        self.df = pd.read_csv(cfg.metainfo_path, dtype=dtype)

    @property
    def labels(self):
        return np.array(self.df["label"])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index: int):
        return self.df.iloc[index].to_dict()

class ShapeNetDataset:
    def __init__(
        self,
        cfg: DictConfig,
        stage: str = "train",
    ) -> None:
        self.stage = stage
        self.cfg = cfg
        self.metainfo = MetaInfo(cfg=cfg)
        self.data = self._load_data(cfg=cfg) 

    def _load_data(self, cfg: DictConfig):
        data = defaultdict(lambda: defaultdict(dict))
        for obj_id in cfg.obj_ids:
            for path in Path(cfg.dataset_path, obj_id, "images").glob("*.jpg"):
                data[obj_id]["images"][path.stem] = cv2.imread(path.as_posix())
            for path in Path(cfg.dataset_path, obj_id, "sketches").glob("*.jpg"):
                data[obj_id]["sketches"][path.stem] = cv2.imread(path.as_posix())
        return data

    def __len__(self):
        return len(self.metainfo)

    def __getitem__(self, index):
        # extract information
        info = self.metainfo[index]
        obj_id = info["obj_id"]
        image_id = info["image_id"]
        sketch_id = info["sketch_id"]
        label = info["label"]

        # gather the image data
        sketch = self.data[obj_id]["sketches"][sketch_id]
        image = self.data[obj_id]["images"][image_id]

        return {
            "sketch": sketch,
            "image": image,
            "label": label,
        }
    
cfg = load_config()
metainfo = MetaInfo(cfg=cfg)
dataset = ShapeNetDataset(cfg=cfg)
sampler = samplers.MPerClassSampler(metainfo.labels, m=4, batch_size=8, length_before_new_iter=20)
loader = DataLoader(dataset=dataset, sampler=sampler, batch_size=8, drop_last=True)
for index, batch in enumerate(iter(loader)):
    print(index, batch)

0 {'sketch': tensor([[[[255, 255, 255],
          [255, 255, 255],
          [255, 255, 255],
          ...,
          [255, 255, 255],
          [255, 255, 255],
          [255, 255, 255]],

         [[255, 255, 255],
          [255, 255, 255],
          [255, 255, 255],
          ...,
          [255, 255, 255],
          [255, 255, 255],
          [255, 255, 255]],

         [[255, 255, 255],
          [255, 255, 255],
          [255, 255, 255],
          ...,
          [255, 255, 255],
          [255, 255, 255],
          [255, 255, 255]],

         ...,

         [[255, 255, 255],
          [255, 255, 255],
          [255, 255, 255],
          ...,
          [255, 255, 255],
          [255, 255, 255],
          [255, 255, 255]],

         [[255, 255, 255],
          [255, 255, 255],
          [255, 255, 255],
          ...,
          [255, 255, 255],
          [255, 255, 255],
          [255, 255, 255]],

         [[255, 255, 255],
          [255, 255, 255],
          [255, 255, 25

In [7]:
metainfo.labels

array([0, 0, 0, ..., 1, 1, 1])

In [48]:
import pandas as pd
data = []
label = 0
for obj_id in cfg.obj_ids:
    for sketch_id in sorted(list(path.stem for path in Path(cfg.dataset_path, obj_id, "sketches").glob("*.jpg"))):
        for image_id in sorted(list(path.stem for path in Path(cfg.dataset_path, obj_id, "images").glob("*.jpg"))):
            data.append({
                "obj_id": str(obj_id),
                "sketch_id": str(sketch_id),
                "image_id": str(image_id),
                "label": label
            })
    label += 1 
pd.DataFrame(data).to_csv(cfg.metainfo_path, index=None)

In [46]:
import csv
df.to_csv("test.csv", index=None)
df1 = pd.read_csv("test.csv", dtype={"image_id": str, "sketch_id": str})
df1

Unnamed: 0,obj_id,sketch_id,image_id,label
0,3c4ed9c8f76c7a5ef51f77a6d7299806,00000,00000,0
1,3c4ed9c8f76c7a5ef51f77a6d7299806,00000,00001,0
2,3c4ed9c8f76c7a5ef51f77a6d7299806,00000,00002,0
3,3c4ed9c8f76c7a5ef51f77a6d7299806,00000,00003,0
4,3c4ed9c8f76c7a5ef51f77a6d7299806,00000,00004,0
...,...,...,...,...
4995,ffd9387a533fe59e251990397636975f,00049,00045,1
4996,ffd9387a533fe59e251990397636975f,00049,00046,1
4997,ffd9387a533fe59e251990397636975f,00049,00047,1
4998,ffd9387a533fe59e251990397636975f,00049,00048,1


In [30]:
df1

Unnamed: 0,index,obj_id,sketch_id,image_id,label
0,0,3c4ed9c8f76c7a5ef51f77a6d7299806,0,0,0
1,1,3c4ed9c8f76c7a5ef51f77a6d7299806,0,1,0
2,2,3c4ed9c8f76c7a5ef51f77a6d7299806,0,2,0
3,3,3c4ed9c8f76c7a5ef51f77a6d7299806,0,3,0
4,4,3c4ed9c8f76c7a5ef51f77a6d7299806,0,4,0
...,...,...,...,...,...
4995,4995,ffd9387a533fe59e251990397636975f,49,45,1
4996,4996,ffd9387a533fe59e251990397636975f,49,46,1
4997,4997,ffd9387a533fe59e251990397636975f,49,47,1
4998,4998,ffd9387a533fe59e251990397636975f,49,48,1


In [61]:
images = []
for obj

images = load_images("3c4ed9c8f76c7a5ef51f77a6d7299806", "sketches")

(256, 256, 3)

In [None]:


for sketch_id in range():