In [16]:
import torch
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import numpy as np
import cv2
import torchvision.transforms as transforms

In [20]:
class ThreatDataset(Dataset):
    """Image-classification dataset whose label is the name of each file's parent folder.

    Args:
        data: indexable sequence of ``pathlib.Path`` image paths. The label of a
            path is the position of ``path.parent.name`` in ``folder_names``.
        loader_type: split tag such as ``'train'``; stored as ``self.loader_type``
            but not otherwise used by this class.
        transforms: optional callable applied to the RGB image array produced by
            OpenCV before it is returned.
        folder_names: class folder names in label order. Defaults to
            ``['carrying', 'threat', 'normal']`` (labels 0, 1, 2).
    """

    DEFAULT_FOLDER_NAMES = ('carrying', 'threat', 'normal')

    def __init__(self, data, loader_type='train', transforms=None, folder_names=None):
        if folder_names is None:
            folder_names = self.DEFAULT_FOLDER_NAMES
        self.folder_names = list(folder_names)
        self.data = data
        self.loader_type = loader_type
        # The keyword is named `transforms` for backward compatibility; it only
        # shadows the torchvision module inside this method.
        self.transform = transforms

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th path.

        Raises:
            ValueError: if the file's parent folder is not in ``folder_names``.
            FileNotFoundError: if OpenCV cannot read the file.
        """
        path = self.data[idx]
        folder = path.parent.name
        try:
            label = self.folder_names.index(folder)
        except ValueError:
            raise ValueError(
                f"{path}: parent folder {folder!r} is not one of {self.folder_names}"
            ) from None
        # cv2.imread signals a missing/corrupt file by returning None instead of
        # raising, which would otherwise surface as an opaque cvtColor error.
        image = cv2.imread(str(path))
        if image is None:
            raise FileNotFoundError(f"could not read image: {path}")
        # OpenCV loads BGR; convert to RGB for the downstream transforms.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            image = self.transform(image)
        return image, label

In [13]:
import random
from pathlib import Path

def split_data(data_dir, train_size=0.8, val_size=0.1, seed=1234):
    """Shuffle the files under ``data_dir`` and split them into train/val/test lists.

    Files are collected exactly one level below ``data_dir``
    (``data_dir/<folder>/<file>``). Directories and ``.zip`` files are skipped.

    Args:
        data_dir: root directory (str or Path) whose sub-folders hold the images.
        train_size: fraction of files assigned to the training split.
        val_size: fraction of files assigned to the validation split; whatever
            remains after train and val goes to the test split.
        seed: seed for the shuffle, so the split is reproducible.

    Returns:
        ``(train_data, val_data, test_data)`` as lists of ``pathlib.Path``.

    Raises:
        ValueError: if a fraction is negative or the two sum to more than 1.
    """
    # Small tolerance so float noise such as 0.9 + 0.1 is not rejected.
    if train_size < 0 or val_size < 0 or train_size + val_size > 1 + 1e-9:
        raise ValueError(
            f"invalid split fractions: train_size={train_size}, val_size={val_size}"
        )

    files = sorted(
        p for p in Path(data_dir).glob('*/*')
        if p.is_file() and p.suffix != '.zip'
    )
    # Glob order is filesystem-dependent, so sort before the seeded shuffle to
    # make the split reproducible across machines. A private Random instance
    # avoids reseeding the global `random` module as a side effect.
    random.Random(seed).shuffle(files)

    n_train = int(len(files) * train_size)
    n_val = int(len(files) * val_size)
    train_data = files[:n_train]
    val_data = files[n_train:n_train + n_val]
    test_data = files[n_train + n_val:]
    return train_data, val_data, test_data


train_data, val_data, test_data = split_data('data')

In [21]:
# Model input resolution and the standard ImageNet per-channel mean/std
# (the normalization used by torchvision's pretrained weights).
IMAGE_SIZE = (224, 224)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
BATCH_SIZE = 32


def make_eval_transforms():
    """Build the deterministic (no augmentation) preprocessing used for val and test."""
    return transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ])


# Training adds random flip/rotation augmentation on top of the eval pipeline.
train_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
val_transforms = make_eval_transforms()
test_transforms = make_eval_transforms()

# Named `data_transforms` (not `transforms`) so it does not shadow the
# torchvision.transforms module; shadowing made this cell fail on re-run.
data_transforms = {
    'train': train_transforms,
    'val': val_transforms,
    'test': test_transforms,
}

train_dataset = ThreatDataset(train_data, loader_type="train", transforms=data_transforms['train'])
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [32]:
from scripts import dataloader

# Obtain train/val/test loaders from the project's scripts.dataloader module.
# NOTE(review): this rebinds `train_loader` built in the previous cell; the split
# and transforms used inside get_loaders() are not visible here — confirm they
# match the in-notebook pipeline, or drop one of the two cells.
(train_loader, val_loader, test_loader) = dataloader.get_loaders()

In [33]:
# Pull a single batch as a smoke test. Use the built-in next(): newer PyTorch
# DataLoader iterators no longer provide the Python 2 style `.next()` method.
x, y = next(iter(train_loader))
# Images and labels in one batch must have the same leading (batch) size.
assert len(x) == len(y)
y

tensor([2, 1, 0, 0, 0, 2, 1, 1, 0, 1, 0, 1, 2, 1, 2, 1, 2, 0, 1, 1, 0, 0, 0, 2,
        0, 0, 1, 0, 1, 1, 1, 2])