# **Importing libraries**

In [None]:
import pandas as pd
import numpy as np
from PIL import Image

import matplotlib.pyplot as plt
from matplotlib import patches

import torch
from torch.utils.data import Dataset

import os
import json

In [None]:
# Root directory of the competition data; train.csv and train_images/ are read from here.
DATA_DIR = "/kaggle/input/tensorflow-great-barrier-reef/"

# **Looking at data**

In [None]:
# Load the training table from train.csv and preview its first 20 rows.
train_df = pd.read_csv(os.path.join(DATA_DIR, "train.csv"))
train_df.head(20)

In [None]:
# (number of rows, number of columns) of the training table.
train_df.shape

In [None]:
def vizualize(img, bboxes, class_name, color):
    """Display ``img`` with an outlined, labelled rectangle for every box.

    Args:
        img: Image data accepted by ``Axes.imshow``.
        bboxes: Iterable of 4-item boxes, each unpacked as ``(x, y, w, h)``
            and drawn as a rectangle anchored at ``(x, y)`` with width ``w``
            and height ``h``.
        class_name: Text written at the ``(x, y)`` anchor of every box.
        color: Colour used for the rectangle edge and the label background.
    """
    _figure, axes = plt.subplots(1, 1, figsize=(20, 20))
    plt.axis("off")
    axes.imshow(img)

    for left, top, width, height in bboxes:
        outline = patches.Rectangle(
            (left, top),
            width,
            height,
            edgecolor=color,
            fill=False,
            linewidth=2,
        )
        axes.add_patch(outline)

        # Label sits on a nearly opaque patch so it stays readable over the image.
        label_box = {"facecolor": color, "alpha": 0.9}
        axes.text(left, top, class_name, bbox=label_box, fontsize=11)

    plt.show()

In [None]:
# Draw the annotated boxes for the rows at positions 20 and 21 of the table.
for _, row in train_df[20:22].iterrows():
    vid, frame = row["video_id"], row["video_frame"]

    # The annotations string uses single quotes; swap them so json can parse it.
    annots = json.loads(row["annotations"].replace("'", '"'))
    bboxes = [
        [annot["x"], annot["y"], annot["width"], annot["height"]]
        for annot in annots
    ]

    frame_path = os.path.join(DATA_DIR, f"train_images/video_{vid}/{frame}.jpg")
    img = np.array(Image.open(frame_path))
    vizualize(img, bboxes, "starfish", "orange")

# **Preparing the dataset for the network**

In [None]:
class StarfishDataset(Dataset):
    """Map-style dataset yielding a video frame and its annotated boxes.

    Every row of the table supplies ``video_id``, ``video_frame`` and an
    ``annotations`` string. The frame image is read from
    ``<data_dir>/video_<video_id>/<video_frame>.jpg``.
    """

    def __init__(self, df, data_dir):
        """Keep a private copy of the table and remember the image root.

        Args:
            df: Table with ``video_id``, ``video_frame`` and ``annotations``
                columns, one row per frame.
            data_dir: Directory that contains the ``video_<id>`` folders.
        """
        # Copy so later edits to the caller's frame cannot change the dataset.
        self.df = df.copy()
        self.data_dir = data_dir

    @staticmethod
    def _parse_bboxes(raw_annotations):
        """Convert an annotations string into ``[x, y, width, height]`` lists."""
        # The stored strings use single quotes, which json does not accept.
        annots = json.loads(raw_annotations.replace("'", '"'))
        return [[a["x"], a["y"], a["width"], a["height"]] for a in annots]

    def __getitem__(self, idx):
        """Return ``(img, bboxes)`` for the row at position ``idx``.

        Returns:
            A tuple ``(img, bboxes)`` where ``img`` is a ``float32`` tensor
            holding the decoded pixel values divided by 255 (no axis
            permutation is applied), and ``bboxes`` is a list of
            ``[x, y, width, height]`` lists, empty when the row has no
            annotations.
        """
        row = self.df.iloc[idx]
        vid = row["video_id"]
        fid = row["video_frame"]
        bboxes = self._parse_bboxes(row["annotations"])

        frame_path = os.path.join(self.data_dir, f"video_{vid}/{fid}.jpg")
        # Decode inside the context manager so the file handle is released
        # as soon as the pixels have been copied into the array.
        with Image.open(frame_path) as frame:
            pixels = np.array(frame, dtype=np.float32) / 255
        img = torch.from_numpy(pixels)

        return img, bboxes

    def __len__(self):
        """Return the number of rows (frames) in the table."""
        return len(self.df)

In [None]:
# Dataset over every row of train_df, reading frames from DATA_DIR/train_images.
train_dataset = StarfishDataset(train_df, os.path.join(DATA_DIR, "train_images"))

In [None]:
# Fetch the sample at position 20 to check that loading works; yields (image tensor, boxes).
train_dataset[20]