In [25]:
import os
from glob import glob
import gc
import sys
# Make the parent of the notebook's working directory importable so the
# `src.*` imports below resolve.
path = os.path.join(os.getcwd(), '..')
sys.path.append(path)

import cv2
import torch
import numpy as np
import pandas as pd

from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

# NOTE(review): wildcard import; the names actually used from ensemble_boxes are
# not visible here — TODO replace with explicit imports.
from ensemble_boxes import *

from src.config.GlobalConfig import GlobalConfig
from src.model.efficientdet import get_efficientdet
from src.data.df_utils import read_boxes_df, get_kfolds_df
from src.common.utils import read_image, xywh2ltrb

In [15]:
# Data locations. Set the WHEAT_DATA_ROOT environment variable to point at a
# different copy of the data instead of editing this user-specific default.
DATA_ROOT_PATH = os.environ.get('WHEAT_DATA_ROOT', '/userhome/34/h3509807/wheat-data')
CSV_PATH = os.path.join(DATA_ROOT_PATH, 'train.csv')
MODEL_CKPT = '../models/effdet5-512/best-checkpoint-256.bin'

# NOTE(review): presumably the side length of the original images — confirm.
ORIGINAL_SZ = 1024
RESIZE_SZ = 512
# SCALE_FACTOR is used as an integer; a non-divisor size would silently truncate it.
assert ORIGINAL_SZ % RESIZE_SZ == 0, 'RESIZE_SZ must evenly divide ORIGINAL_SZ'
SCALE_FACTOR = ORIGINAL_SZ // RESIZE_SZ

#### 1. Dataset and Transforms

In [16]:
# One row per box (image_id, x, y, w, h), then a table indexed by image id
# whose `fold` column assigns each image to one of 5 folds.
boxes_df = read_boxes_df(CSV_PATH)
kfolds_df = get_kfolds_df(
    boxes_df,
    kfolds=5,
)



In [7]:
def get_valid_transforms(size=None):
    """Build the deterministic validation pipeline.

    Resizes the image to ``size`` x ``size`` and converts it to a torch tensor.
    Boxes passed as ``bboxes`` are rescaled together with the image.

    Args:
        size: target side length in pixels; defaults to ``RESIZE_SZ``.

    Returns:
        An ``A.Compose`` pipeline called as ``tfms(image=..., bboxes=..., labels=...)``.
        It expects ``bboxes`` in pascal_voc order (x_min, y_min, x_max, y_max),
        which is the ltrb output of ``xywh2ltrb``, plus a parallel ``labels`` field.
    """
    if size is None:
        size = RESIZE_SZ
    return A.Compose(
        [
            A.Resize(height=size, width=size, p=1.0),
            ToTensorV2(p=1.0),
        ],
        p=1.0,
        # Without bbox_params, albumentations either leaves the `bboxes` target
        # untransformed or rejects it (version-dependent), so boxes would not
        # match the resized image.
        bbox_params=A.BboxParams(
            format='pascal_voc',
            min_area=0,
            min_visibility=0,
            label_fields=['labels'],
        ),
    )

In [27]:
# Image ids assigned to fold 0.
# NOTE(review): `x` looks like a leftover scratch variable — remove if unused.
x = kfolds_df.loc[kfolds_df['fold'].eq(0)].index.values

In [44]:
class kFoldsDataset(Dataset):
    """Dataset over the images of one k-fold split, yielding ``(image, target, image_id)``.

    Args:
        boxes_df: DataFrame with one row per box and at least the columns
            ``image_id``, ``x``, ``y``, ``w``, ``h``.
        kfolds_df: DataFrame indexed by image id with a ``fold`` column.
        kfolds_idx: fold whose images this dataset exposes.
        train_dir: directory handed to ``read_image`` to locate image files.
        transforms: optional albumentations-style callable invoked as
            ``transforms(image=..., bboxes=..., labels=...)`` and returning a
            dict with ``image``, ``bboxes`` and ``labels`` keys.
    """

    # How many times to re-run (possibly random) transforms that drop every box.
    max_transform_attempts = 10

    def __init__(self, boxes_df, kfolds_df, kfolds_idx,
                 train_dir, transforms = None):
        super().__init__()

        self.image_ids = kfolds_df[kfolds_df.fold == kfolds_idx].index.values
        self.boxes_df = boxes_df
        self.transforms = transforms
        self.train_dir = train_dir

    def __len__(self) -> int:
        return self.image_ids.shape[0]

    def __getitem__(self, index: int):
        """Return ``(image, target, image_id)`` for the ``index``-th image of the fold.

        ``target`` holds ``boxes``, ``labels`` and ``index``.

        - With ``transforms`` set, ``image`` is the transformed image and
          ``boxes`` is a float32 tensor in yxyx order.
        - Without transforms, ``image`` and ``boxes`` are returned exactly as
          ``load_image_and_boxes`` produced them (boxes in ltrb order).
        """
        image_id = self.image_ids[index]
        image, boxes = self.load_image_and_boxes(index)

        # only one class
        labels = torch.ones((boxes.shape[0],), dtype = torch.int64)

        target = {
            'boxes': boxes,
            'labels': labels,
            'index': torch.tensor([index]),
        }

        if self.transforms:
            image, target['boxes'], target['labels'] = self._apply_transforms(image, boxes, labels)

        assert target['boxes'].shape[0] == target['labels'].shape[0], 'boxes len != labels len'
        return image, target, image_id

    def _apply_transforms(self, image, boxes, labels):
        """Run ``self.transforms``, retrying until at least one box survives.

        Returns ``(image, boxes, labels)``, where ``boxes`` is a float32 tensor
        with columns reordered from ltrb to yxyx. If every attempt drops all
        boxes, the last transformed image is returned with empty ``(0, 4)``
        boxes and ``(0,)`` labels. This keeps the output type consistent; the
        alternative would be to return the untransformed input.
        """
        sample = None
        for _ in range(max(1, self.max_transform_attempts)):
            sample = self.transforms(image=image, bboxes=boxes, labels=labels)
            if len(sample['bboxes']) > 0:
                break

        if len(sample['bboxes']) == 0:
            return (sample['image'],
                    torch.zeros((0, 4), dtype=torch.float32),
                    torch.zeros((0,), dtype=torch.int64))

        out_boxes = torch.tensor(np.asarray(sample['bboxes'], dtype=np.float32))
        # yxyx: be warning — swap x/y columns (ltrb -> tlbr) for the model
        out_boxes[:, [0, 1, 2, 3]] = out_boxes[:, [1, 0, 3, 2]]
        return sample['image'], out_boxes, torch.as_tensor(sample['labels'])

    def load_image_and_boxes(self, index):
        """Read the ``index``-th image and its boxes converted by ``xywh2ltrb``."""
        image_id = self.image_ids[index]
        image = read_image(image_id, self.train_dir)
        records = self.boxes_df[self.boxes_df['image_id'] == image_id]
        boxes = records[['x', 'y', 'w', 'h']].values
        boxes = xywh2ltrb(boxes)
        return image, boxes
    
    
def collate_fn(batch):
    """Transpose a list of per-sample tuples into a tuple of per-field tuples.

    ``[(img0, tgt0, id0), (img1, tgt1, id1)]`` becomes
    ``((img0, img1), (tgt0, tgt1), (id0, id1))``. Targets with different box
    counts stay as plain tuples rather than being stacked.
    """
    transposed = zip(*batch)
    return tuple(field for field in transposed)

In [42]:
# Dataset over the fold-0 images, using the deterministic validation transforms.
tfms = get_valid_transforms()
train_dir = f'{DATA_ROOT_PATH}/train'

ds = kFoldsDataset(
    boxes_df,
    kfolds_df,
    kfolds_idx=0,
    train_dir=train_dir,
    transforms=tfms,
)

In [45]:
# In-order loader; collate_fn keeps per-image targets as tuples.
dl = DataLoader(
    ds,
    batch_size=2,
    shuffle=False,
    num_workers=4,
    drop_last=False,
    collate_fn=collate_fn,
)