In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import re
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from albumentations import *
import random
import cv2
import torch
from matplotlib import pyplot as plt
# import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))


In [None]:
#show nhiều ảnh
def plot_imgs(imgs, cols=5, size=7, is_rgb=False):
    rows = len(imgs)//cols + 1
    fig = plt.figure(figsize=(cols*size, rows*size))
    for i, img in enumerate(imgs):
        fig.add_subplot(rows, cols, i+1)
        if is_rgb:
            plt.imshow(img)
        else:
            plt.imshow(img[:,:,::-1])
    plt.show()

# Draw bounding boxes on a copy of an image.
def visualize_bbox(img, boxes, thickness=3, color=(0, 0, 255)):
    """Return a copy of ``img`` with one rectangle drawn per entry of ``boxes``.

    Each box is read as ``(x1, y1, x2, y2)`` from its first four entries, and
    every coordinate is truncated with ``int`` because ``cv2.rectangle`` takes
    integer pixel points. ``color`` and ``thickness`` are passed straight to
    ``cv2.rectangle``. The input image itself is left untouched.
    """
    canvas = img.copy()
    for bbox in boxes:
        x_min, y_min, x_max, y_max = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
        canvas = cv2.rectangle(canvas, (x_min, y_min), (x_max, y_max), color, thickness)
    return canvas


In [None]:
def expand_bbox(x):
    """Parse a bbox string such as ``"[834.0, 222.0, 56.0, 36.0]"`` into numbers.

    Every unsigned decimal number in ``x`` is extracted with a regex and the
    result is always returned as a float64 ndarray, so callers get a numeric
    array whether or not numbers were found.

    Args:
        x: the raw ``bbox`` cell from the CSV.

    Returns:
        float64 array of the numbers found, or the sentinel ``[-1, -1, -1, -1]``
        when ``x`` is not a string (e.g. NaN from an empty cell) or contains
        no numbers.
    """
    # re.findall raises TypeError on non-str (e.g. a NaN float from pandas).
    if not isinstance(x, str):
        return np.full(4, -1.0)
    matches = re.findall("([0-9]+[.]?[0-9]*)", x)
    if not matches:
        return np.full(4, -1.0)
    return np.array([float(m) for m in matches], dtype=np.float64)

def read_data_in_csv(csv_path="./wheat-dataset/train.csv"):
    """Group the rows of a wheat ``train.csv`` into one record per image.

    Each CSV row holds one box in a ``bbox`` column formatted as ``[x, y, w, h]``.
    The boxes are parsed with ``expand_bbox``, converted to corner format
    ``(x1, y1, x2, y2) = (x, y, x + w, y + h)`` and grouped by ``image_id``.

    Args:
        csv_path: path (or file-like object) accepted by ``pd.read_csv``.

    Returns:
        List of dicts, sorted by ``img_id`` so the order is reproducible:
        ``{"img_id": ..., "boxes": float64 (N, 4) array, "area": float64 (N,) array}``
        where ``area = w * h``. An empty CSV yields an empty list.

    Raises:
        ValueError: if any ``bbox`` cell does not parse to exactly four numbers.
    """
    df = pd.read_csv(csv_path)
    if df.empty:
        return []

    # np.asarray(..., dtype=float) accepts both numeric arrays and arrays of
    # numeric strings, so this works with any version of expand_bbox.
    xywh = np.stack([np.asarray(expand_bbox(b), dtype=float) for b in df["bbox"]])
    if xywh.ndim != 2 or xywh.shape[1] != 4:
        raise ValueError(f"expected 4 numbers per bbox, got array of shape {xywh.shape}")

    area = xywh[:, 2] * xywh[:, 3]
    # (x, y, w, h) -> (x1, y1, x2, y2) for all rows at once.
    xyxy = np.concatenate([xywh[:, :2], xywh[:, :2] + xywh[:, 2:]], axis=1)

    # .indices maps each image_id to the positional rows belonging to it, which
    # avoids re-scanning the whole frame per image. Sorting the keys makes the
    # output order deterministic (iterating a set of str is not, because of
    # string hash randomisation).
    rows_by_image = df.groupby("image_id").indices
    objs = []
    for img_id in tqdm(sorted(rows_by_image)):
        rows = rows_by_image[img_id]
        objs.append({
            "img_id": img_id,
            "boxes": xyxy[rows],  # fancy indexing -> independent copy per image
            "area": area[rows],
        })
    return objs



In [None]:
# Candidate training augmentation pipeline (not referenced by WheatDataset).
# Boxes are declared as 'pascal_voc' (x_min, y_min, x_max, y_max) to match the
# corner-format boxes built by read_data_in_csv and the get_aug helper; the
# previous 'coco' format means (x, y, w, h), which these boxes are not.
# NOTE(review): load_img scales pixels to [0, 1]; confirm GaussNoise's default
# noise variance is appropriate for that range.
train_aug = Compose(
    [
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
        ToGray(p=0.01),
        GaussNoise(p=0.2),
        Blur(blur_limit=3, p=0.1),
        RandomBrightnessContrast(),
        HueSaturationValue(p=0.25),
    ],
    bbox_params=BboxParams(format='pascal_voc', min_area=0, min_visibility=0, label_fields=['labels']),
)


In [None]:
def get_aug(aug):
    """Wrap a list of albumentations transforms in a bbox-aware ``Compose``.

    Boxes are declared in the 'pascal_voc' format, ``min_area`` and
    ``min_visibility`` are both 0, and per-box class ids are read from and
    written to the ``labels`` keyword of the call.
    """
    box_params = BboxParams(
        format='pascal_voc',
        min_area=0,
        min_visibility=0,
        label_fields=['labels'],
    )
    return Compose(aug, bbox_params=box_params)

def load_img(img_id, folder, to_rgb=True):
    """Read ``{folder}/{img_id}.jpg`` as a float32 image scaled to [0, 1].

    Args:
        img_id: file stem of the image.
        folder: directory containing the ``.jpg`` files.
        to_rgb: ``cv2.imread`` returns channels in BGR order. When True
            (default) they are converted to RGB, which is what this notebook's
            display cells assume (they call ``plot_imgs(..., is_rgb=True)``).
            Pass False to keep OpenCV's BGR order.

    Returns:
        float32 ndarray of shape (H, W, 3) with values in [0, 1].

    Raises:
        FileNotFoundError: if OpenCV cannot read the file. ``cv2.imread``
            signals this by returning None instead of raising.
    """
    img_fn = f"{folder}/{img_id}.jpg"
    img = cv2.imread(img_fn)
    if img is None:
        raise FileNotFoundError(f"could not read image {img_fn!r}")
    if to_rgb:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img.astype(np.float32)
    img /= 255.0
    return img

class WheatDataset(Dataset):
    """Map-style dataset yielding ``(image, target)`` pairs for wheat-head detection.

    ``df`` is a list of per-image records shaped like the output of
    ``read_data_in_csv``: dicts with ``"img_id"`` and ``"boxes"`` (N x 4 in
    x1, y1, x2, y2 order). Images are loaded from ``img_dir`` with ``load_img``.

    In ``'train'`` mode the albumentations pipeline built in ``__init__`` is
    applied to the image and its boxes. In ``'valid'`` mode images and boxes are
    returned unaugmented.
    """

    def __init__(self, df, img_dir, img_size, mode='train', bbox_removal_threshold=0.25):
        assert mode in ['train', 'valid']
        self.df = df
        self.img_dir = img_dir
        self.mode = mode
        # NOTE(review): img_size and bbox_removal_threshold are stored but not
        # used anywhere in this class yet.
        self.img_size = img_size
        self.bbox_removal_threshold = bbox_removal_threshold

        if self.mode == 'train':
            # NOTE(review): p=1 flips *every* training image, which looks like a
            # debugging setting for checking box flipping; a random p (e.g. 0.5)
            # is what actually augments.
            self.transform = get_aug([HorizontalFlip(p=1)])
        else:
            self.transform = None

    def __getitem__(self, idx):
        """Return ``(image, target)`` for record ``idx``.

        ``image`` is whatever ``load_img`` / the transform produce (an H x W x C
        array, not a CHW tensor). NOTE(review): torchvision detection models
        expect CHW tensors; convert before feeding a model.

        ``target`` has ``boxes`` (float32, N x 4), ``labels`` (int64, all 1),
        ``area`` (float32, recomputed from the final boxes), ``iscrowd`` (int64
        zeros) and ``image_id`` (``tensor([idx])``). N may be 0.
        """
        img_data = self.df[idx]
        img = load_img(img_data["img_id"], self.img_dir)
        boxes = np.asarray(img_data["boxes"], dtype=np.float64).reshape(-1, 4)
        labels = [1] * len(boxes)  # every box gets class id 1

        if self.transform is not None:
            sample = self.transform(image=img, bboxes=boxes.tolist(), labels=labels)
            img = sample["image"]
            # The transform may clip or drop boxes, so boxes and labels are both
            # taken from its output to stay aligned with each other.
            boxes = np.asarray(sample["bboxes"], dtype=np.float64).reshape(-1, 4)
            labels = list(sample["labels"])

        boxes_t = torch.as_tensor(boxes, dtype=torch.float32)
        target = {
            'boxes': boxes_t,
            'labels': torch.as_tensor(labels, dtype=torch.int64),
            # Recomputed from the final boxes so it always matches them.
            'area': (boxes_t[:, 2] - boxes_t[:, 0]) * (boxes_t[:, 3] - boxes_t[:, 1]),
            'iscrowd': torch.zeros((len(boxes),), dtype=torch.int64),
            'image_id': torch.tensor([idx]),
        }
        return img, target

    def __len__(self):
        return len(self.df)
    

def collate_fn(batch):
    """Regroup a list of ``(image, target)`` samples into ``(images, targets)``.

    Samples are transposed field-by-field into tuples rather than stacked,
    because each target holds a different number of boxes.
    """
    regrouped = zip(*batch)
    return tuple(regrouped)

In [None]:
# One record per image: dicts with img_id, boxes (x1, y1, x2, y2) and area.
# NOTE: despite its name, `df` is a list of dicts, not a DataFrame.
df = read_data_in_csv(csv_path='../input/global-wheat-detection/train.csv')


In [None]:
# Small debugging dataset built from only two image records of `df`.
data_set = WheatDataset(df[:2], img_dir='../input/global-wheat-detection/train', img_size=1024)

In [None]:
# A seeded generator makes the shuffle order (and the worker seeds PyTorch
# derives from it) reproducible across Restart & Run All.
train_loader = DataLoader(
    data_set,
    batch_size=16,
    shuffle=True,
    num_workers=2,
    collate_fn=collate_fn,  # targets differ in size per image: keep as tuples
    generator=torch.Generator().manual_seed(42),
)

In [None]:
# Manual iterator over the loader, used below to pull a single batch for inspection.
temp = iter(train_loader)

In [None]:
# One batch, split by collate_fn into a tuple of images and a tuple of target
# dicts (so `labels` here holds whole targets, not just class ids).
inputs, labels = next(temp)

In [None]:
# Show every image in the batch.
# NOTE(review): is_rgb=True skips plot_imgs' BGR->RGB flip, but load_img reads
# with cv2.imread (BGR order) — confirm the images are really RGB here, otherwise
# the colours appear swapped.
plot_imgs(inputs,is_rgb=True)

In [None]:
# Draw the first sample's ground-truth boxes. This reuses visualize_bbox from
# above instead of redefining it, and uses a new name so the batch iterator
# `temp` is not overwritten. load_img scales pixels to [0, 1], so the colour is
# given on that scale; (1, 0, 0) fills the first channel, which imshow renders
# as red.
annotated = visualize_bbox(inputs[0], labels[0]['boxes'].numpy(), color=(1.0, 0.0, 0.0))
fig, ax = plt.subplots(figsize=(7, 7))
ax.imshow(annotated)
ax.set_title('First sample with ground-truth boxes');

In [None]:
# Load two specific training images and show them side by side.
image_1, image_2 = (
    load_img(image_id, '../input/global-wheat-detection/train')
    for image_id in ('740d3c904', '6d1879f19')
)
plot_imgs([image_1, image_2], is_rgb=True)

In [None]:
# Random (x, y) point: each coordinate drawn uniformly between 0.25 * 1024 and
# 0.75 * 1024, then truncated to int.
xc, yc = (int(random.uniform(0.25 * 1024, 0.75 * 1024)) for _ in range(2))
xc, yc