# Start 



In [None]:
# List the contents of the current working directory (IPython magic).
%ls


In [None]:
import os
import ast
from collections import namedtuple
import random
import collections
import uuid
from glob import glob
from datetime import datetime

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

from tqdm import tqdm
from PIL import Image

import joblib
from joblib import Parallel, delayed

import cv2
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
from albumentations.core.transforms_interface import DualTransform
from albumentations.augmentations.bbox_utils import denormalize_bbox, normalize_bbox

from sklearn.model_selection import StratifiedKFold

import torch
from torch.utils.data import DataLoader, Dataset
import torch.utils.data as data_utils

from matplotlib import pyplot as plt
import matplotlib.patches as patches
from matplotlib.image import imsave

In [None]:
# Constants
# Read-only competition input directory.
BASE_DIR = '/kaggle/input/global-wheat-detection'
# Writable working directory; later cells read and write images under WORK_DIR/train.
WORK_DIR = '/kaggle/working'

# Set seed for numpy for reproducibility
# NOTE(review): the seed is commented out, so np.random-based steps (cutout size and
# position, visualization sampling) differ between runs; Python's `random` (used for
# puzzle sampling) is not seeded in any visible cell either.
#np.random.seed(1996)

In [None]:
# Copy the read-only training images and annotations into the writable working
# directory; subsequent cells read from (and write to) /kaggle/working/train.
%cp -r /kaggle/input/global-wheat-detection/train /kaggle/working
%cp /kaggle/input/global-wheat-detection/train.csv /kaggle/working

# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

# START Puzzle Augmentation
1. Divide each image into 4 pieces ("puzzles"), keeping the bboxes that fall into each piece.
2. Build a pool of puzzle pieces from all original images (in train_data).
3. Sample 4 pieces and merge them into a new image (repeat k times).
4. Save the result as a dataframe and as jpg images.

In [None]:
def adjust_bbox_in_image(
    i_w,
    i_h,
    bbox,
    min_bbox_size=20,
    border_margin=3):
    """Clip an (x, y, w, h) bbox to a cropped image of size i_w x i_h.

    The box is first clipped to the image rectangle. It is then dropped
    (None is returned) when clipping leaves no area at all, or when it touches
    the image border (within ``border_margin`` pixels) while one of its sides is
    shorter than ``min_bbox_size`` (presumably a fragment of an object that the
    crop cut through).

    :param i_w: width of the cropped image
    :param i_h: height of the cropped image
    :param bbox: (x, y, w, h) in the cropped image's coordinates; may extend outside it
    :param min_bbox_size: minimum side length for a box that touches the border
    :param border_margin: distance in pixels from an edge that still counts as "on the border"
    :returns: the clipped (x, y, w, h) tuple, or None if the box should be discarded
    """
    x, y, w, h = bbox
    if x < 0:
        w += x
        x = 0
    if y < 0:
        h += y
        y = 0
    if i_w < x + w:
        w -= (x + w - i_w)
    if i_h < y + h:
        h -= (y + h - i_h)

    # Nothing of the box remains inside the image (or it was degenerate to begin
    # with); a zero/negative-size box is not a valid annotation.
    if w <= 0 or h <= 0:
        return None

    on_border = (
        x < border_margin
        or y < border_margin
        or i_w - border_margin < x + w
        or i_h - border_margin < y + h
    )
    under_min_size = w < min_bbox_size or h < min_bbox_size
    if on_border and under_min_size:
        return None

    return (x, y, w, h)


def make_puzzles(
    image_id,
    bboxes,
    min_bbox_size=20,
    image_root='/kaggle/working/train/'):
    """Divide an image into 4 quadrant pieces ("puzzles") and distribute its bboxes.

    Each bbox is assigned to every quadrant it overlaps, shifted into that
    quadrant's local coordinates and clipped with ``adjust_bbox_in_image``.

    :param image_id: file stem of the image (``<image_root>/<image_id>.jpg``)
    :param bboxes: iterable of (x, y, w, h) boxes in full-image pixel coordinates
    :param min_bbox_size: forwarded to ``adjust_bbox_in_image``
    :param image_root: directory holding the jpg images
    :returns: [(piece, piece_bboxes), ...] ordered left-top, right-top,
        left-bottom, right-bottom; pieces are RGB slices of the loaded image
    :raises FileNotFoundError: if the image cannot be read
    """
    img_path = os.path.join(image_root, '{}.jpg'.format(image_id))

    image = cv2.imread(img_path)
    if image is None:
        # cv2.imread returns None (instead of raising) for missing/unreadable files.
        raise FileNotFoundError('could not read image {!r}'.format(img_path))
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    row, col, _ = image.shape

    y_div = row // 2
    x_div = col // 2
    lt = image[:y_div, :x_div] # left top
    rt = image[:y_div, x_div:] # right top
    lb = image[y_div:, :x_div] # left bottom
    rb = image[y_div:, x_div:] # right bottom

    # Real piece sizes: when a side is odd the right/bottom pieces are one pixel
    # larger. Clipping to these sizes and shifting by the integer split index keeps
    # the boxes aligned with the sliced pixels (col/2, row/2 would not).
    left_w, right_w = x_div, col - x_div
    top_h, bottom_h = y_div, row - y_div

    lt_bboxes, rt_bboxes, lb_bboxes, rb_bboxes = [], [], [], []

    for x, y, w, h in bboxes:
        last_x = x + w - 1  # right-most pixel column of the box
        last_y = y + h - 1  # bottom-most pixel row of the box

        # Quadrant-2 (left top)
        if y < y_div and x < x_div:
            _bbox = adjust_bbox_in_image(left_w, top_h, (x, y, w, h), min_bbox_size=min_bbox_size)
            if _bbox is not None:
                lt_bboxes.append(_bbox)

        # Quadrant-1 (right top)
        if y < y_div and x_div <= last_x:
            _bbox = adjust_bbox_in_image(right_w, top_h, (x - x_div, y, w, h), min_bbox_size=min_bbox_size)
            if _bbox is not None:
                rt_bboxes.append(_bbox)

        # Quadrant-3 (left bottom)
        if y_div <= last_y and x < x_div:
            _bbox = adjust_bbox_in_image(left_w, bottom_h, (x, y - y_div, w, h), min_bbox_size=min_bbox_size)
            if _bbox is not None:
                lb_bboxes.append(_bbox)

        # Quadrant-4 (right bottom)
        if y_div <= last_y and x_div <= last_x:
            _bbox = adjust_bbox_in_image(right_w, bottom_h, (x - x_div, y - y_div, w, h), min_bbox_size=min_bbox_size)
            if _bbox is not None:
                rb_bboxes.append(_bbox)

    return [
        (lt, lt_bboxes),
        (rt, rt_bboxes),
        (lb, lb_bboxes),
        (rb, rb_bboxes),
    ]


def merge_random_4_puzzles(puzzle_bbox_pairs):
    """Stitch four (piece, bboxes) pairs into one 2x2 image.

    Pieces are placed left-top, right-top, left-bottom, right-bottom in the
    order given, and each piece's bboxes are shifted by that piece's offset.

    :param puzzle_bbox_pairs: sequence whose first four items are (image, bboxes);
        all four images must have the same shape, bboxes are (x, y, w, h) tuples
        in piece coordinates
    :returns: (uint8 image of shape (2*rows, 2*cols, ch), list of shifted bboxes)
    :raises ValueError: if the four pieces do not all share the same shape
    """
    lt_img, lt_bboxes = puzzle_bbox_pairs[0]
    rt_img, rt_bboxes = puzzle_bbox_pairs[1]
    lb_img, lb_bboxes = puzzle_bbox_pairs[2]
    rb_img, rb_bboxes = puzzle_bbox_pairs[3]

    row, col, ch = lt_img.shape
    # All offsets below assume equal-sized pieces; fail with a clear message
    # instead of a numpy broadcasting error.
    for name, img in (('right top', rt_img), ('left bottom', lb_img), ('right bottom', rb_img)):
        if img.shape != lt_img.shape:
            raise ValueError(
                'all puzzle pieces must have the same shape; left top is {} but {} is {}'.format(
                    lt_img.shape, name, img.shape))

    x_div = col
    y_div = row

    merged = np.zeros((row * 2, col * 2, ch), np.uint8)
    merged[:y_div, :x_div, :] = lt_img
    merged[:y_div, x_div:, :] = rt_img
    merged[y_div:, :x_div, :] = lb_img
    merged[y_div:, x_div:, :] = rb_img

    merged_bbox = (
        list(lt_bboxes)
        + [(x + x_div, y, w, h) for x, y, w, h in rt_bboxes]
        + [(x, y + y_div, w, h) for x, y, w, h in lb_bboxes]
        + [(x + x_div, y + y_div, w, h) for x, y, w, h in rb_bboxes]
    )

    return (merged, merged_bbox)
    

def visualize_4_image_bbox(puzzle_bbox_pairs):
    """Show four (image, bboxes) pairs on a 2x2 grid with their boxes drawn in red.

    Boxes are drawn on copies, so the caller's images (e.g. pool pieces that are
    later merged and saved) are not modified.

    :param puzzle_bbox_pairs: sequence of 4 (RGB image, [(x, y, w, h), ...]) pairs
    """
    _, ax = plt.subplots(2, 2, figsize=(12, 12))
    ax = ax.flatten()

    labels = ['left top', 'right top', 'left bot', 'right bot']

    for i in range(4):
        image, bboxes = puzzle_bbox_pairs[i]
        # cv2.rectangle draws in place; work on a copy to leave the input untouched.
        image = image.copy()
        for row in bboxes:
            x, y, w, h = (int(n) for n in row)
            cv2.rectangle(image,
                          (x, y),
                          (x+w, y+h),
                          (220, 0, 0), 3)
        ax[i].set_axis_off()
        ax[i].imshow(image)
        ax[i].set_title(labels[i], color='yellow')

        
def visualize_image_bbox(image, bboxes):
    """Show one image with its (x, y, w, h) boxes drawn in red.

    Boxes are drawn on a copy, so the caller's image is not modified.

    :param image: RGB image array
    :param bboxes: iterable of (x, y, w, h) boxes in pixel coordinates
    """
    _, ax = plt.subplots(1, 1, figsize=(12, 12))

    # cv2.rectangle draws in place; work on a copy to leave the input untouched.
    image = image.copy()
    for row in bboxes:
        x, y, w, h = (int(n) for n in row)
        cv2.rectangle(image,
                      (x, y),
                      (x+w, y+h),
                      (220, 0, 0), 3)
    ax.imshow(image)


In [None]:
#Make puzzle pool
# Parse every bbox, then cut each training image into 4 pieces with their boxes.

df = pd.read_csv('train.csv')

# `bbox` is stored as a "[x, y, w, h]" string; split it into four numeric columns.
bboxs = np.stack(df['bbox'].apply(lambda x: np.fromstring(x[1:-1], sep=',')))
for i, column in enumerate(['x', 'y', 'w', 'h']):
    df[column] = bboxs[:,i]
df.drop(columns=['bbox'], inplace=True)


pool = []
image_ids = set(df.image_id.values)

# groupby visits each image once instead of re-filtering the whole frame per image
# (the previous per-image boolean mask made this loop quadratic).
for image_id, group in df.groupby('image_id', sort=False):
    pool += make_puzzles(image_id, group[['x', 'y', 'w', 'h']].values, min_bbox_size=20)

In [None]:
#%rm -r merged_puzzles
#%rm merged_puzzles.csv

k = 5000 # NUM NEW IMAGES
# One row per bbox of a merged image: pascal_voc corners plus box width/height/area.
AugData = collections.namedtuple('AugData', 'image_id,x_min,y_min,x_max,y_max,width,height,area,source')

aug_data = []

# exist_ok=True keeps the cell re-runnable (os.makedirs raises FileExistsError otherwise).
os.makedirs('./merged_puzzles', exist_ok=True)

for _ in range(k):
    # Draw 4 distinct pieces; random.sample avoids reshuffling the whole pool every time.
    merged_image, merged_bboxes = merge_random_4_puzzles(random.sample(pool, 4))
    image_id = str(uuid.uuid4())
    for x, y, w, h in merged_bboxes:
        aug_data.append(AugData(image_id=image_id, x_min=x, y_min=y, x_max=x+w, y_max=y+h, width=w, height=h, area=w*h, source='aug'))

    merged_image = cv2.cvtColor(merged_image, cv2.COLOR_RGB2BGR)
    # merged_puzzles/ holds the exported copy (zipped later). Downstream cells
    # (visualization, WheatDataset, create_dataset) read every image_id from
    # WORK_DIR/train, so the merged image is written there as well.
    for out_path in ('merged_puzzles/{}.jpg'.format(image_id),
                     os.path.join(WORK_DIR, 'train', '{}.jpg'.format(image_id))):
        # cv2.imwrite reports failure by returning False rather than raising.
        if not cv2.imwrite(out_path, merged_image):
            raise OSError('could not write {!r}'.format(out_path))

# SAVE new data

In [None]:
# Persist the bbox rows of the merged images (one row per AugData record).
aug_df = pd.DataFrame(data=aug_data)
aug_df.to_csv('merged_puzzles.csv', index=False)

In [None]:
# Sanity check: reload the CSV just written and show its first rows.
train_dfvv = pd.read_csv('merged_puzzles.csv')
train_dfvv.head()


In [None]:
# Count how many distinct merged images the puzzle CSV describes.
final_train_df = pd.read_csv('merged_puzzles.csv')
num_images_final = pd.unique(final_train_df['image_id'])
print(f'Total number of training images in csv: {len(num_images_final)}')


# Combine train.csv + merged_puzzles.csv

In [None]:
# Inputs to concatenate, listed explicitly and in a fixed order. Globbing '*.csv'
# returns files in arbitrary order and, when the notebook is re-run, would also
# pick up combined_csv.csv (and the k-fold train.csv written at the end).
all_filenames = ['train.csv', 'merged_puzzles.csv']
all_filenames

In [None]:
# Read each input CSV, stack them row-wise, and write one combined file.
frames = [pd.read_csv(path) for path in all_filenames]
combined_csv = pd.concat(frames)
combined_csv.to_csv("combined_csv.csv", index=False, encoding='utf-8-sig')


In [None]:
# Archive the merged_puzzles directory recursively and quietly (presumably for export/download).
!zip -r -qq merged_puzzles.zip merged_puzzles


# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

# Start working with the training data

In [None]:
# The cells below inspect `df_train`; it was previously never defined (NameError).
df_train = pd.read_csv('combined_csv.csv')
# Keep the name this cell originally bound, in case other code relies on it.
final_train_df = df_train

n_images = df_train['image_id'].nunique()
print(n_images)
# Average number of bbox rows per image.
print(df_train.shape[0] / n_images)

In [None]:
# Number of bbox rows per image_id (relies on `df_train` from an earlier cell).
print("Image_id v/s # of bounding boxes")
print(df_train['image_id'].value_counts())

In [None]:
# Distribution of the `height` and `width` columns.
# NOTE(review): rows coming from merged_puzzles.csv store the *bbox* size in
# width/height (see the AugData rows); presumably train.csv's width/height are
# image dimensions, so these counts mix two meanings -- confirm before using them.
print("Height")
print(df_train['height'].value_counts())
print("Width")
df_train['width'].value_counts()

In [None]:
train_df = pd.read_csv('combined_csv.csv')

# Normalise every row to explicit box columns x_min/y_min/width/height.
# Rows from train.csv carry the box as a "[x, y, w, h]" string in `bbox`.
# Rows from merged_puzzles.csv have no `bbox` (NaN after the concat, which made
# ast.literal_eval raise) but already hold the box in those columns.
# So only rows that have a bbox string are parsed.
box_columns = ['x_min', 'y_min', 'width', 'height']
for column in box_columns:
    if column not in train_df.columns:
        train_df[column] = np.nan
train_df[box_columns] = train_df[box_columns].astype(float)
if 'bbox' in train_df.columns:
    has_bbox = train_df['bbox'].notna()
    train_df.loc[has_bbox, box_columns] = pd.DataFrame(
        [ast.literal_eval(x) for x in train_df.loc[has_bbox, 'bbox']],
        index=train_df.index[has_bbox],
        columns=box_columns,
    )

# Derived columns are computed before selecting columns, so nothing is assigned
# into a slice of another frame (avoids SettingWithCopyWarning).
train_df['area'] = train_df['width'] * train_df['height']
train_df['x_max'] = train_df['x_min'] + train_df['width']
train_df['y_max'] = train_df['y_min'] + train_df['height']
train_df = train_df[['image_id', 'x_min', 'y_min', 'x_max', 'y_max', 'width', 'height', 'area', 'source']]

# There are some buggy annotations in training images having huge bounding boxes. Let's remove those bboxes
train_df = train_df[train_df['area'] < 100000]

train_df.head()

In [None]:
# (number of bbox rows, number of columns) after cleaning.
print(train_df.shape)

In [None]:
# Distinct image ids that still have at least one bbox after cleaning; later
# cells iterate over this array.
image_ids = train_df['image_id'].unique()
print(f'Total number of training images: {len(image_ids)}')

In [None]:
# Read the image on which data augmentation is to be performed
image_id = 'c14c1e300'
# NOTE(review): cv2.imread returns None for a missing file, which would surface
# as an error in cvtColor below.
image = cv2.imread(os.path.join(WORK_DIR, 'train', f'{image_id}.jpg'), cv2.IMREAD_COLOR)
# OpenCV loads BGR; convert to RGB and scale to [0, 1] floats for display.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
image /= 255.0
plt.figure(figsize = (10, 10))
plt.imshow(image)
plt.show()

There are two major formats of bounding boxes:

1. **pascal_voc**, which is [x_min, y_min, x_max, y_max]
2. **COCO**, which is [x_min, y_min, width, height]

We'll see how to perform image augmentations for both the formats. Let's first start with **pascal_voc** format.

In [None]:

# Boxes of the sample image in both formats, truncated to int32.
_sample_rows = train_df[train_df['image_id'] == image_id]
pascal_voc_boxes = _sample_rows[['x_min', 'y_min', 'x_max', 'y_max']].astype(np.int32).values
coco_boxes = _sample_rows[['x_min', 'y_min', 'width', 'height']].astype(np.int32).values
assert len(pascal_voc_boxes) == len(coco_boxes)
# One label per box; every label is 1.
labels = np.ones((len(pascal_voc_boxes), ))


# class WheatDataset

In [None]:
class WheatDataset(Dataset):
    """Map-style dataset returning one image and its boxes per item.

    :param df: one row per bbox with columns image_id, x_min, y_min, x_max,
        y_max and area. Images are read from ``WORK_DIR/train/<image_id>.jpg``.
    """

    def __init__(self, df):
        self.df = df
        self.image_ids = self.df['image_id'].unique()
        # image_id -> positional row indices, computed once so __getitem__ does
        # not rescan the whole frame with a boolean mask for every sample.
        self._rows_by_image = self.df.groupby('image_id').indices

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, index):
        """Return a dict with the image_id, the RGB image as float32 scaled to
        [0, 1], its boxes as [x_min, y_min, x_max, y_max] lists and their areas.

        :raises FileNotFoundError: if the image file cannot be read
        """
        image_id = self.image_ids[index]
        image_path = os.path.join(WORK_DIR, 'train', f'{image_id}.jpg')
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread returns None (instead of raising) for missing/unreadable files.
            raise FileNotFoundError(f'could not read image {image_path!r}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        image /= 255.0  # Normalize

        # Get bbox coordinates for each wheat head(s), all rows at once.
        bboxes_df = self.df.iloc[self._rows_by_image[image_id]]
        boxes = bboxes_df[['x_min', 'y_min', 'x_max', 'y_max']].values.tolist()
        areas = bboxes_df['area'].tolist()

        return {
            'image_id': image_id,
            'image': image,
            'boxes': boxes,
            'area': areas,
        }
    
    
def _as_array(items):
    """np.array(items), falling back to a 1-D object array for ragged nested lists."""
    try:
        return np.array(items)
    except ValueError:
        # NumPy >= 1.24 refuses to build an array from ragged sequences (older
        # versions returned an object array with a deprecation warning).
        ragged = np.empty(len(items), dtype=object)
        for i, item in enumerate(items):
            ragged[i] = item
        return ragged


def collate_fn(batch):
    """Collate WheatDataset samples into a tuple of numpy arrays.

    :param batch: list of dicts as returned by ``WheatDataset.__getitem__``
    :returns: (images, boxes, areas, image_ids). When images have different
        numbers of boxes, ``boxes``/``areas`` are 1-D object arrays holding one
        list per image.
    """
    images = [data['image'] for data in batch]
    bboxes = [data['boxes'] for data in batch]
    areas = [data['area'] for data in batch]
    image_ids = [data['image_id'] for data in batch]

    return _as_array(images), _as_array(bboxes), _as_array(areas), _as_array(image_ids)

In [None]:
BATCH_SIZE = 16

train_dataset = WheatDataset(train_df)
# Shuffled batches of 16, loaded by 4 worker processes; collate_fn (defined
# above) turns each batch into numpy arrays instead of the default collation.
train_loader = data_utils.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, collate_fn=collate_fn)

# class CustomCutout - square cutout augmentation

In [None]:
class CustomCutout(DualTransform):
    """
    Custom Cutout augmentation with handling of bounding boxes
    Note: (only supports square cutout regions)

    One square patch filled with ``fill_value`` is placed at a random position.
    A bbox is replaced by an empty (0, 0, 0, 0) box when the fraction of its
    area covered by the patch exceeds ``bbox_removal_threshold``.
    """

    # (x, y) of the patch's top-left corner; defined once instead of per call.
    _Point = namedtuple('Point', 'x y')

    def __init__(
        self,
        fill_value=0,
        bbox_removal_threshold=0.50,
        min_cutout_size=120,  # SIZE
        max_cutout_size=512,
        always_apply=False,
        p=0.5
    ):
        """
        Class constructor

        :param fill_value: Value written into the cutout patch (default is 0 or black color)
        :param bbox_removal_threshold: Bboxes whose area is covered by the patch by more than this fraction are removed
        :param min_cutout_size: minimum side of the square patch (default 120)
        :param max_cutout_size: maximum side of the square patch (default 512); both are capped at the image size
        :param always_apply: forwarded to DualTransform
        :param p: probability of applying the transform
        """
        super().__init__(always_apply, p)  # Initialize parent class
        self.fill_value = fill_value
        self.bbox_removal_threshold = bbox_removal_threshold
        self.min_cutout_size = min_cutout_size
        self.max_cutout_size = max_cutout_size

    def _get_cutout_position(self, img_height, img_width, cutout_size):
        """
        Randomly generates cutout position as a named tuple

        :param img_height: height of the original image
        :param img_width: width of the original image
        :param cutout_size: size of the cutout patch (square)
        :returns position of the patch's top-left corner as a named tuple (x, y)
        """
        return self._Point(
            np.random.randint(0, img_width - cutout_size + 1),
            np.random.randint(0, img_height - cutout_size + 1)
        )

    def _get_cutout(self, img_height, img_width):
        """
        Creates a cutout patch with given fill value and determines the position in the original image

        :param img_height: height of the original image
        :param img_width: width of the original image
        :returns (cutout patch, cutout size, cutout position)
        """
        # Never draw a patch larger than the image: the position sampler would
        # otherwise get an empty range and np.random.randint would raise.
        max_size = min(self.max_cutout_size, img_height, img_width)
        min_size = min(self.min_cutout_size, max_size)
        cutout_size = np.random.randint(min_size, max_size + 1)
        cutout_position = self._get_cutout_position(img_height, img_width, cutout_size)
        return np.full((cutout_size, cutout_size, 3), self.fill_value), cutout_size, cutout_position

    def apply(self, image, **params):
        """
        Applies the cutout augmentation on the given image

        Also records the patch geometry on the instance; ``apply_to_bbox`` reads
        it, so it must run after ``apply`` for the same call.

        :param image: The image to be augmented, (height, width, channels)
        :returns augmented image (a copy; the input is left unchanged)
        """
        image = image.copy()  # Don't change the original image
        self.img_height, self.img_width, _ = image.shape
        _, cutout_size, cutout_pos = self._get_cutout(self.img_height, self.img_width)

        # Set to instance variables to use this later
        self.image = image
        self.cutout_pos = cutout_pos
        self.cutout_size = cutout_size

        # A scalar fill works for any channel count (the patch from _get_cutout
        # is hard-wired to 3 channels).
        image[cutout_pos.y:cutout_pos.y + cutout_size, cutout_pos.x:cutout_pos.x + cutout_size] = self.fill_value
        return image

    def apply_to_bbox(self, bbox, **params):
        """
        Removes the bounding boxes which are covered by the applied cutout

        The covered fraction is computed geometrically from the patch rectangle
        recorded by ``apply``. Counting pixels equal to ``fill_value`` would also
        treat genuinely dark image pixels inside the box as "cut out".

        :param bbox: A single bounding box in normalized pascal_voc format
        :returns the bbox unchanged, or an empty (0, 0, 0, 0) box if it is mostly covered
        """
        # Denormalize the bbox coordinates
        denorm = denormalize_bbox(bbox, self.img_height, self.img_width)
        x_min, y_min, x_max, y_max = denorm[:4]

        bbox_size = (x_max - x_min) * (y_max - y_min)  # width * height
        if bbox_size <= 0:
            # Degenerate box: there is no area to compare against, keep it.
            return normalize_bbox(denorm, self.img_height, self.img_width)

        cut_x0, cut_y0 = self.cutout_pos.x, self.cutout_pos.y
        cut_x1, cut_y1 = cut_x0 + self.cutout_size, cut_y0 + self.cutout_size
        overlap_w = max(0, min(x_max, cut_x1) - max(x_min, cut_x0))
        overlap_h = max(0, min(y_max, cut_y1) - max(y_min, cut_y0))

        # Remove the bbox if more than the threshold fraction of it is inside the cutout patch
        if overlap_w * overlap_h / bbox_size > self.bbox_removal_threshold:
            return normalize_bbox((0, 0, 0, 0), self.img_height, self.img_width)

        return normalize_bbox(denorm, self.img_height, self.img_width)

    def get_transform_init_args_names(self):
        """
        Fetches the parameter(s) of __init__ method
        :returns: tuple of parameter(s) of __init__ method
        """
        return ('fill_value', 'bbox_removal_threshold', 'min_cutout_size', 'max_cutout_size', 'always_apply', 'p')

# Add albumentations augmentations

In [None]:
#rm
# Archived first version of the augmentation pipeline, kept only as text: this
# string is never executed. NOTE(review): it refers to `albumentations.Compose`,
# but the notebook imports the package as `A`, so it would not run as-is.
first_version = '''
augmentation0 = albumentations.Compose([
    CustomCutout(p=0.5),
    A.Flip(p=0.60),
    A.RandomRotate90(p=0.5),
    A.RandomBrightness(limit=0.3, p=0.60),
    A.OneOf([  # One of blur or adding gauss noise
        A.Blur(p=0.50),  # Blurs the image
        A.GaussNoise(var_limit=5.0 / 255.0, p=0.50)  # Adds Gauss noise to image
    ], p=0.5)
], bbox_params = {
    'format': 'pascal_voc',
    'label_fields': ['labels']
})

'''

### Determine the augmentations used

In [None]:
# CustomCutout(p=.5) # function call

# Geometric augmentations.
flip = A.Flip(p=.6)
rot90 = A.RandomRotate90(p=.5)

# Photometric augmentations (br_contr is defined but not used by required_aug).
br_contr = A.RandomBrightnessContrast(brightness_limit=.3, contrast_limit=.3, p=.5)
brigh = A.RandomBrightness(limit=.3, p=.6)
contrast = A.RandomContrast(limit=.3, p=.6)

blur = A.Blur(p=.3)
noise = A.GaussNoise(var_limit=5.0 / 255.0, p=.3)

def oneof(arr=(blur, noise), p=.5):
    """Wrap the transforms in ``arr`` in ``A.OneOf`` with probability ``p``.

    The default is an immutable tuple (instead of a shared list default), and a
    fresh list is handed to ``A.OneOf`` on every call.
    """
    return A.OneOf(list(arr), p=p)

In [None]:
# Two pipelines that differ only in the photometric step (brightness vs. contrast);
# each pipeline gets its own CustomCutout and OneOf instances.
required_aug = [
    [CustomCutout(p=.5), flip, rot90, photometric, oneof()]
    for photometric in (brigh, contrast)
]

# Visualize augmented images

In [None]:
def get_bbox(bboxes, col, color='white'):
    """Draw pascal_voc boxes (x_min, y_min, x_max, y_max, ...) onto a matplotlib Axes.

    :param bboxes: sequence of boxes whose first four values are the corners
    :param col: the Axes to draw on
    :param color: edge colour of the rectangles
    """
    for x_min, y_min, x_max, y_max, *_ in bboxes:
        # Unfilled rectangle anchored at the top-left corner.
        col.add_patch(patches.Rectangle(
            (x_min, y_min),
            x_max - x_min,
            y_max - y_min,
            linewidth=2,
            edgecolor=color,
            facecolor='none',
        ))

In [None]:
# Build the first pipeline; boxes are passed in pascal_voc format and the
# per-box labels travel through the `labels` keyword argument.
transform = A.Compose(
    required_aug[0], 
    bbox_params = {
    'format': 'pascal_voc',
    'label_fields': ['labels']
})

In [None]:
num_images = 5
# randint's upper bound is exclusive: "+ 1" lets the last window be drawn too,
# and max(..., 0) keeps the call valid when there are only num_images ids.
rand_start = np.random.randint(0, max(len(image_ids) - num_images, 0) + 1)
fig, ax = plt.subplots(nrows=num_images, ncols=2, figsize=(16, 40))

for index, image_id in enumerate(image_ids[rand_start : rand_start + num_images]):
    # Read the image from image id
    image_path = os.path.join(WORK_DIR, 'train', f'{image_id}.jpg')
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread returns None (instead of raising) for missing/unreadable files.
        raise FileNotFoundError(f'could not read image {image_path!r}')
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
    image /= 255.0  # Normalize

    # Get the bboxes details and apply all the augmentations
    bboxes = train_df[train_df['image_id'] == image_id][['x_min', 'y_min', 'x_max', 'y_max']].astype(np.int32).values
    labels = np.ones((len(bboxes), ))  # As we have only one class (wheat heads)

    aug_result = transform(image=image, bboxes=bboxes, labels=labels)

    # Left column: original image; right column: augmented image.
    panels = (
        (ax[index][0], bboxes, 'Original Image', image),
        (ax[index][1], aug_result['bboxes'],
         f'Augmented Image: Removed bboxes: {len(bboxes) - len(aug_result["bboxes"])}',
         aug_result['image']),
    )
    for axis, boxes, title, shown in panels:
        get_bbox(boxes, axis, color='red')
        axis.grid(False)
        axis.set_xticks([])
        axis.set_yticks([])
        axis.title.set_text(title)
        axis.imshow(shown)
plt.show()

# Create dataset

In [None]:
# Rebuild both pipelines with fresh CustomCutout / OneOf instances; they differ
# only in the photometric step (brightness vs. contrast).
required_aug = []
for photometric in (brigh, contrast):
    required_aug.append([
        CustomCutout(p=.5),
        flip,
        rot90,
        photometric,
        oneof(),
    ])

In [None]:
# Number of pipelines in required_aug; the generation loop below makes one
# augmented copy of every image per pipeline.
coefficient = len(required_aug)
coefficient

In [None]:
def create_dataset(index, image_id, coefficient = 0):
    """Write one augmented copy of an image and return its bbox rows.

    Applies the pipeline ``required_aug[coefficient]`` to
    ``WORK_DIR/train/<image_id>.jpg`` and saves the result next to it as
    ``<image_id>_aug_<coefficient>.jpg``.

    :param index: position of the image in the caller's iteration (unused; kept for the call signature)
    :param image_id: id / file stem of the source image
    :param coefficient: index into the module-level ``required_aug`` list
    :returns: one dict per surviving bbox with keys image_id, x_min, y_min,
        x_max, y_max, width, height, area, source (coordinates truncated to int)
    :raises FileNotFoundError: if the source image cannot be read
    """
    # Read the image from image id
    image_path = os.path.join(WORK_DIR, 'train', f'{image_id}.jpg')
    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread returns None (instead of raising) for missing/unreadable files.
        raise FileNotFoundError(f'could not read image {image_path!r}')
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Select this image's rows once and reuse the slice.
    image_rows = train_df[train_df['image_id'] == image_id]
    bboxes = image_rows[['x_min', 'y_min', 'x_max', 'y_max']].astype(np.int32).values
    source = image_rows['source'].unique()[0]
    labels = np.ones((len(bboxes), ))  # As we have only one class (wheat heads)

    transform = A.Compose(
        required_aug[coefficient],
        bbox_params = {
        'format': 'pascal_voc',
        'label_fields': ['labels']})

    aug_result = transform(image=image, bboxes=bboxes, labels=labels)
    name_img_aug = f'{image_id}_aug_{coefficient}'

    # Only the augmented copy is written. The original used to be re-saved to the
    # very path it was just read from, which re-encoded that JPEG (lossily) on
    # every call and for every pipeline.
    Image.fromarray(aug_result['image']).save(os.path.join(WORK_DIR, 'train', f'{name_img_aug}.jpg'))

    image_metadata = []
    for bbox in aug_result['bboxes']:
        x_min, y_min, x_max, y_max = tuple(map(int, bbox))[:4]
        width, height = x_max - x_min, y_max - y_min
        image_metadata.append({
            'image_id': name_img_aug,
            'x_min': x_min,
            'y_min': y_min,
            'x_max': x_max,
            'y_max': y_max,
            'width': width,
            'height': height,
            'area': width * height,
            'source': source
        })
    return image_metadata

In [None]:
#%rm train.csv
#%rm -r train

# Make sure the working train/ directory exists; nothing happens if it already does.
os.makedirs('train', exist_ok=True)

In [None]:
# For every pipeline index, augment all images in parallel (8 joblib workers)
# and append the returned bbox rows to train_df.
for el_k in range(coefficient):
    image_metadata = Parallel(n_jobs=8)(delayed(create_dataset)(index, image_id, el_k) for index, image_id in tqdm(enumerate(image_ids), total=len(image_ids)))
    # Flatten the per-image lists of row dicts into one list.
    image_metadata = [item for sublist in image_metadata for item in sublist]
    # NOTE(review): range(coefficient) was already evaluated, so this does not
    # shorten the loop; it leaves coefficient == 0 afterwards, which turns a
    # re-run of this cell into a no-op (presumably to avoid duplicate rows -- confirm).
    coefficient -= 1
    aug_train_df = pd.DataFrame(image_metadata)
    train_df = pd.concat([train_df, aug_train_df]).reset_index(drop=True)

In [None]:
# Rows added by the last pipeline, then the total size of the training table.
print(aug_train_df.shape)
train_df.shape

# Save result

In [None]:
# Add a new column to store kfold indices
# (-1 = not yet assigned; the next cell fills in a validation fold for every row)
train_df.loc[:, 'kfold'] = -1

In [None]:
# Folds are assigned per *original* image. Augmented copies are named
# "<image_id>_aug_<k>" (see create_dataset). Splitting them independently would
# put copies of a validation image into the training folds (leakage).
# NOTE(review): merged puzzle images (source 'aug') also contain pieces of
# original images; they are still treated as independent groups here.
base_ids = train_df['image_id'].str.replace(r'_aug_\d+$', '', regex=True)

image_source = (
    pd.DataFrame({'image_id': base_ids, 'source': train_df['source']})
    .drop_duplicates(subset='image_id')
)

# get lists for (base) image_ids and sources
image_ids = image_source['image_id'].to_numpy()
sources = image_source['source'].to_numpy()

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=1996)
split = skf.split(image_ids, sources) # second argument is what we are stratifying by

for fold, (train_idx, val_idx) in enumerate(split):
    # Every row (original or augmented) whose base id is in this validation fold.
    val_rows = base_ids.isin(image_ids[val_idx])
    translated_val_idx = train_df.index[val_rows].values
    print(len(translated_val_idx))
    train_df.loc[translated_val_idx, 'kfold'] = fold

train_df.to_csv('train.csv', index=False)

# Check result

In [None]:
# Reload the saved fold file and report how many distinct images it references.
final_train_df = pd.read_csv('train.csv')
num_images_final = pd.unique(final_train_df['image_id'])
print(f'Total number of training images: {len(num_images_final)}')

In [None]:
# Number of entries in the working train/ directory (originals + generated images).
num_files = len(os.listdir('./train/'))
num_files