In [None]:
import numpy as np
import pandas as pd
import cv2
import re
import matplotlib.pyplot as plt

## Read the CSV Files

In [None]:
# Load train.csv from the competition input directory into a DataFrame.
train_df = pd.read_csv('../input/global-wheat-detection/train.csv')

In [None]:
# Report how many rows the training CSV contains.
print(f"Number of rows in training CSV: {len(train_df)}")

In [None]:
# Preview the first rows of the raw frame (the `bbox` column is still a single value per row here).
train_df.head()

In [None]:
# Confirm the container type, then look at the `image_id` value stored under index label 1.
print(type(train_df))
train_df['image_id'][1]

In [None]:
# Show the raw `bbox` value stored under index label 0, before any parsing.
print(train_df['bbox'][0])

In [None]:
def expand_bbox(x):
    """
    Split a bounding-box string such as `[834.0, 222.0, 56.0, 36.0]` into
    its individual numeric fields, e.g. `['834.0' '222.0' '56.0' '36.0']`.

    Parameters
    ----------
    x : str
        Raw text of one `bbox` cell. Anything that is not a string (for
        example a NaN from a missing cell) is treated as "no numbers found".

    Returns
    -------
    numpy.ndarray
        1-D array of strings, one per unsigned decimal number found in `x`,
        in order of appearance. When nothing is found the placeholder
        `['-1' '-1' '-1' '-1']` is returned, so the result is always an
        ndarray of strings that can be converted with `.astype(float)`.
    """
    # `re.findall` raises TypeError on non-string input, so guard it here.
    matches = re.findall("([0-9]+[.]?[0-9]*)", x) if isinstance(x, str) else []
    if not matches:
        # Keep the placeholder the same type as the normal result (strings in
        # an ndarray) so callers can treat both cases identically.
        matches = ['-1', '-1', '-1', '-1']
    return np.array(matches)

# Sanity-check the parser on the first bbox and convert one field to a number.
box = expand_bbox(train_df['bbox'][0])
print(box)
# The `np.float` alias was removed in NumPy 1.24; the builtin `float` is the
# documented replacement and parses the string field directly.
box_0 = float(box[0])
print(box_0)

## Separate the X, Y, W, and H for Each BBox

In [None]:
# Record the (rows, columns) shape before the new coordinate columns are added.
print(train_df.shape)

In [None]:
# Create the four coordinate columns up front, each filled with the placeholder -1.
for column in ('x', 'y', 'w', 'h'):
    train_df[column] = -1

In [None]:
# Parse every bbox value with `expand_bbox`, stack the per-row results into one
# 2-D array and convert it to float in a single step. The builtin `float`
# replaces the `np.float` alias, which was removed in NumPy 1.24.
# NOTE(review): assumes every bbox yields exactly four fields; np.stack raises
# if the per-row lengths differ.
bbox_fields = np.stack(train_df['bbox'].apply(expand_bbox).tolist()).astype(float)
train_df[['x', 'y', 'w', 'h']] = bbox_fields
# The original column is no longer needed once it has been split.
train_df.drop(columns=['bbox'], inplace=True)

In [None]:
# Preview the frame again to check the new x / y / w / h columns.
train_df.head()

In [None]:
# Save the reformatted frame to the working directory, without writing the index as a column.
train_df.to_csv('global_wheat_detection_formatted.csv', index=False)

### Group the CSV File According to the image_id

In [None]:
# Group the rows by image. The key is passed as a plain column name rather
# than a one-element list: with a list, pandas >= 2.0 yields 1-tuples such as
# ('<image_id>',) when iterating over the groups, which breaks any code that
# uses the group key directly as a string (e.g. to build a file name).
grouped = train_df.groupby('image_id')
# Number of groups, i.e. distinct image_id values.
print(len(grouped))
# First rows of every group.
grouped.head()

In [None]:
# Print the key, type, contents and row count of the first two groups.
count = 0
for image_id, boxes in grouped:
    # One print call with a newline separator emits the same text as
    # printing each value on its own, ending with the same blank lines.
    print(image_id, type(boxes), boxes, len(boxes), '\n', sep='\n')
    count += 1
    if count == 2:
        break

print(count)

## Visualize Images 

In [None]:
# Directory holding the training images; files are looked up as `<image_id>.jpg` inside it.
DIR_TRAIN = '../input/global-wheat-detection/train'

In [None]:
# number of images to show
viz_thres = 25
# `enumerate` replaces the manual counter so that the loop also stops
# correctly when `viz_thres` is 0.
for count, (key, item) in enumerate(grouped):
    if count >= viz_thres:
        break
    # Grouping by a one-element list makes pandas >= 2.0 yield 1-tuples as
    # keys; unwrap them so the file name is built from the bare image id.
    image_id = key[0] if isinstance(key, tuple) else key
    print(image_id)
    print('\n')
    image_path = f"{DIR_TRAIN}/{image_id}.jpg"
    image = cv2.imread(image_path)
    # cv2.imread returns None instead of raising when the file is missing or
    # unreadable; fail here with the offending path rather than later inside
    # a drawing call.
    if image is None:
        raise OSError(f"Could not read image: {image_path}")
    cv2.putText(image, f"{len(item)} wheat heads", (10, 50),
                cv2.FONT_HERSHEY_SIMPLEX, 2.0, (255, 255, 0), 5)
    # Walk the box columns once instead of doing four positional row lookups
    # per box.
    for x, y, w, h in item[['x', 'y', 'w', 'h']].to_numpy():
        x1 = int(x)
        y1 = int(y)
        # The far corner is the truncated origin plus width / height.
        x2 = int(x1 + w)
        y2 = int(y1 + h)
        cv2.rectangle(image, (x1, y1), (x2, y2), (0, 0, 255), 3)
    plt.figure(figsize=(15, 12))
    # Reverse the channel axis: OpenCV loads images as BGR, matplotlib expects RGB.
    plt.imshow(image[:, :, ::-1])
    plt.axis('off')
    plt.show()