In [1]:
import os
import torch
import torch.nn as nn
from PIL import Image
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt

In [2]:
DATASET_PATH = "./DATASET/"
data = pd.read_csv('annotations.csv', index_col=0)
data.head()

Unnamed: 0,file,bbox,class,size (cm)
0,26_05_21-B20,"[2272.628996958517, 1685.2591150516498, 837.97...",Sepia officinalis,8.435085
1,26_05_21-B25,"[1807.393678506843, 1585.7146585117644, 1094.7...",Mullus barbatus,16.21142
2,26_05_21-B25,"[2648.726235342625, 1368.0723135423739, 345.68...",Mullus barbatus,14.422977
3,26_05_21-B25,"[2845.1096085383197, 1180.704037335252, 426.91...",Mullus barbatus,15.604945
4,26_05_21-B25,"[3192.0335090997633, 1235.4288323372514, 337.1...",Mullus barbatus,14.829357


In [3]:
class DeepFishDataset(Dataset):
    def __init__(self, root_dir: str, data: pd.DataFrame) -> None:
        super().__init__()
        self.root_dir = root_dir
        self.data = data
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225]),
        ])
        self.__filter()

    def __filter(self) -> None:
        empty = []
        for i, image_name in enumerate(self.data['file']):
            exists = self.__is_image_exist(image_name=image_name)
            if not exists:
                empty.append(i)
        self.data.drop(index=empty, inplace=True)


    def __is_image_exist(self, image_name: str) -> bool:
        return os.path.exists(os.path.join(self.root_dir, f"{image_name}.jpg"))

    def __getitem__(self, index: int):
        image_name = self.__get_image_name(index)
        image = self.__get_image(image_name)
        bbox = self.__get_bbox(image_name)
        size = self.__get_size(image_name)
        
        if self.transform:
            image = self.transform(image)

        return {
            'image': image,
            'bbox': bbox,
            'size': size,
        }

    def __get_image_name(self, index: int) -> str:
        return self.data.iloc[index]['file']

    def __get_image(self, image_name: str) -> Image.Image:
        return Image.open(os.path.join(self.root_dir, f"{image_name}.jpg")).convert("RGB")

    def __get_bbox(self, image_name: str) -> list:
        strings = self.data[self.data['file'] == image_name]['bbox']
        strings = strings.str[1:-1]
        strings = strings.str.split(',')
        bbox = [float(s) for s in strings[0]]
        return torch.tensor(bbox, dtype=torch.float32)

    def __get_size(self, image_name: str) -> float:
        return self.data[self.data['file'] == image_name]['size (cm)']

    def __len__(self) -> int:
        return len(self.data)
    
dataset = DeepFishDataset(root_dir=DATASET_PATH, data=data)
dataset[0]


{'image': tensor([[[-1.6555, -1.6555, -1.6384,  ..., -0.8164, -0.7822, -0.7822],
          [-1.6042, -1.6213, -1.6384,  ..., -0.7308, -0.7137, -0.7822],
          [-1.6042, -1.6042, -1.6213,  ..., -0.7822, -0.7650, -0.8335],
          ...,
          [-1.0904, -1.0219, -1.0048,  ..., -1.2959, -1.3644, -1.3815],
          [-1.0562, -1.0390, -1.0048,  ..., -1.2788, -1.3473, -1.3815],
          [-1.0562, -1.0733, -1.0390,  ..., -1.2788, -1.3473, -1.3473]],
 
         [[-1.5455, -1.5455, -1.5280,  ..., -0.9153, -0.9153, -0.9153],
          [-1.4930, -1.5105, -1.5280,  ..., -0.8277, -0.8452, -0.9153],
          [-1.4930, -1.4930, -1.5105,  ..., -0.8803, -0.8627, -0.9328],
          ...,
          [-1.0728, -1.0028, -0.9853,  ..., -1.2829, -1.3529, -1.3704],
          [-1.0378, -1.0203, -0.9853,  ..., -1.2654, -1.3354, -1.3704],
          [-1.0378, -1.0553, -1.0203,  ..., -1.2654, -1.3354, -1.3354]],
 
         [[-1.2816, -1.2816, -1.2641,  ..., -0.6890, -0.6890, -0.6890],
          [-1.2293,

In [4]:
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)

In [5]:
len(dataset), len(train_dataset), len(test_dataset)

(3751, 3000, 751)

KeyError: 0