# In [5]:
import os
import csv

import pandas as pd
from torchvision.io import read_image

from torch.utils.data import Dataset


class SingleBag(Dataset):
    """Dataset backed by one bag directory and its ``label.csv`` index.

    ``label.csv`` is loaded with pandas from ``root``. For every row, column 1
    holds the image path relative to ``root`` and columns 2 onward hold the
    label values for that image.
    """

    def __init__(self, root, transform=None, target_transform=None):
        """Load the label table of the bag stored in ``root``.

        Args:
            root: Directory containing ``label.csv`` and the images it lists.
            transform: Optional callable applied to each loaded image.
            target_transform: Optional callable applied to each label.
        """
        label_file = os.path.join(root, "label.csv")

        self.root = root
        self.pd_csv = pd.read_csv(label_file)
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of rows in the label table."""
        return len(self.pd_csv)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at row ``idx``.

        The image is read with ``read_image``; the label is the slice of the
        row starting at column 2. Each is passed through its transform when
        one was supplied.
        """
        rel_path = self.pd_csv.iloc[idx, 1]
        img = read_image(os.path.join(self.root, rel_path))

        target = self.pd_csv.iloc[idx, 2:]

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target


class MultiBag(SingleBag):
    """Dataset that concatenates every bag directory found under ``root``.

    Each sub-directory of ``root`` that contains a ``label.csv`` is treated as
    one bag. The CSV is read the same way ``SingleBag`` reads it (column 1 is
    the image path relative to the bag directory, columns 2 onward are the
    label values), so samples from both classes have the same form.

    Attributes:
        root: Directory holding the bag sub-directories.
        data: One ``(image_path, label)`` entry per sample, where
            ``image_path`` is relative to ``root``.
        splits: One half-open ``(start, end)`` index range per bag, in the
            order the bags were loaded.
    """

    def __init__(self, root, transform=None, target_transform=None):
        """Index all bags below ``root``.

        Args:
            root: Directory whose sub-directories are individual bags.
            transform: Optional callable applied to each loaded image.
            target_transform: Optional callable applied to each label.
        """
        # SingleBag.__init__ is deliberately not called: it would try to read
        # a ``label.csv`` directly inside ``root``, which is not a bag itself.
        self.root = root

        self.transform = transform
        self.target_transform = target_transform

        self.data = []
        self.splits = []

        start = 0
        # Sorted so that sample indices and ``splits`` are reproducible;
        # os.listdir returns entries in arbitrary order.
        for bag_name in sorted(os.listdir(self.root)):
            csv_path = os.path.join(self.root, bag_name, "label.csv")
            if not os.path.isfile(csv_path):
                # Stray files or directories without labels are not bags.
                continue

            labels = pd.read_csv(csv_path)
            for row in range(len(labels)):
                # Image paths in the CSV are relative to the bag directory;
                # store them relative to ``root`` instead.
                rel_path = os.path.join(bag_name, labels.iloc[row, 1])
                self.data.append((rel_path, labels.iloc[row, 2:]))

            end = start + len(labels)
            self.splits.append((start, end))
            start = end

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair at global index ``idx``."""
        path, vel_ang = self.data[idx]
        image = read_image(os.path.join(self.root, path))

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            vel_ang = self.target_transform(vel_ang)

        return image, vel_ang

    def __len__(self):
        """Return the total number of samples across all bags."""
        return len(self.data)
    
# Manual check: open one recorded bag and fetch its first sample.
# NOTE(review): the dataset location is a hard-coded absolute path.
single_dataset = SingleBag("/work/dataset/2021-08-24-12-41-32")

# x is the image and y the label returned by SingleBag.__getitem__ for row 0.
x, y = single_dataset[0]

# In [9]:
import torch

# Convert the label `y` obtained from the dataset in an earlier cell to a tensor.
y = torch.tensor(y)