# **Imports**🎇

In [26]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from tqdm import tqdm
import os
import json
from PIL import Image
import warnings
warnings.filterwarnings("ignore")

import torch
from torch.utils.data import DataLoader, random_split, Dataset
import torch.optim as optim

import torchvision
from torchvision import transforms as T
from torchvision.utils import make_grid
from torchvision.transforms import functional as F

# **Configuration**

In [2]:
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Number of epochs to train for.
n_epoch = 50

# **Prepare Labels and Files**

In [9]:
# !labelme

In [10]:
# for folder in ['train', 'valid', 'test']:
#     for file in os.listdir(os.path.join(folder, 'photo/')):
#         filename = file.split('.')[0]+'.json'
#         existing_filepath = os.path.join('labels', filename)
#         if os.path.exists(existing_filepath):
#             new_filepath = os.path.join(folder, 'label', filename)
#             os.replace(existing_filepath, new_filepath)


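# **Label Parsing and Dataset**

Inspect one labelme annotation, then wrap the photos and their box labels in a PyTorch `Dataset`.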

In [48]:
# Peek at one labelme annotation: the two corner points of its second shape.
# `with` closes the file handle instead of leaking it.
with open('valid/label/images46.json') as f:
    load = json.load(f)
load['shapes'][1]['points']

[[59.47916666666667, 6.916666666666665],
 [23.854166666666668, 52.95833333333333]]

In [50]:
# Flatten each labelme shape's first two corner points into [x1, y1, x2, y2].
# NOTE: labelme keeps corners in drawing order, so x1 > x2 (or y1 > y2) can
# occur -- these raw values are not yet (x_min, y_min, x_max, y_max) boxes.
with open('valid/label/images46.json') as f:
    load = json.load(f)
points = [
    [shape['points'][0][0], shape['points'][0][1],
     shape['points'][1][0], shape['points'][1][1]]
    for shape in load['shapes']
]
points

[[243.02083333333334, 23.375, 218.43750000000003, 55.041666666666664],
 [59.47916666666667, 6.916666666666665, 23.854166666666668, 52.95833333333333],
 [99.47916666666667, 94.20833333333333, 56.35416666666667, 148.79166666666669]]

In [68]:
class TrumpDataset(Dataset):
    """Object-detection dataset of photos paired with labelme box annotations.

    Expected directory layout for a split ``phase``::

        <phase>/photo/<name>.<ext>   # image files
        <phase>/label/<name>.json    # matching labelme annotation

    Each item is ``(img, target)``:

    * ``img`` -- the image converted to RGB and passed through
      ``torchvision.transforms.functional.to_tensor``.
    * ``target`` -- dict with ``'boxes'`` (float32, shape ``(N, 4)``, in
      ``(x_min, y_min, x_max, y_max)`` order) and ``'labels'`` (int64, shape
      ``(N,)``, all ones -- a single foreground class; 0 is conventionally
      background for torchvision detection models).

    Args:
        phase: split directory name, e.g. ``'train'``, ``'valid'``, ``'test'``.
    """

    def __init__(self, phase):
        self.phase = phase

        photo_dir = os.path.join(phase, 'photo')
        self.images_list = []
        self.labels = []
        # os.listdir order is arbitrary; sort so index -> file is stable.
        for item in sorted(os.listdir(photo_dir)):
            stem = os.path.splitext(item)[0]
            self.images_list.append(os.path.join(photo_dir, item))
            self.labels.append(os.path.join(phase, 'label', stem + '.json'))

    @staticmethod
    def _load_target(label_path):
        """Read a labelme JSON file and build the detection target dict.

        Args:
            label_path: path to a labelme annotation whose ``'shapes'`` entries
                each hold (at least) two corner ``'points'`` as ``[x, y]``.

        Returns:
            dict with ``'boxes'`` (float32, ``(N, 4)``, x_min/y_min/x_max/y_max)
            and ``'labels'`` (int64, ``(N,)``, all ones).
        """
        with open(label_path) as f:
            annotation = json.load(f)

        boxes = []
        for shape in annotation['shapes']:
            (x1, y1), (x2, y2) = shape['points'][:2]
            # labelme stores corners in drawing order, so either one may come
            # first; normalize to (x_min, y_min, x_max, y_max).
            boxes.append([min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)])

        return {
            # reshape keeps a (0, 4) shape when an image has no boxes.
            'boxes': torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            'labels': torch.ones((len(boxes),), dtype=torch.int64),
        }

    def __getitem__(self, item):
        # Context manager releases the underlying file handle promptly.
        with Image.open(self.images_list[item]) as raw:
            img = raw.convert('RGB')
        img = F.to_tensor(img)
        target = self._load_target(self.labels[item])
        return img, target

    def __len__(self):
        return len(self.images_list)

In [69]:
# Training split: images under train/photo, labelme JSON under train/label.
train = TrumpDataset(phase='train')

In [70]:
# Sanity-check one sample: (image tensor, {'boxes': ..., 'labels': ...}).
train[1]

(tensor([[[0.4588, 0.4314, 0.3765,  ..., 0.2824, 0.2549, 0.2627],
          [0.4471, 0.4392, 0.4078,  ..., 0.2784, 0.2588, 0.2549],
          [0.4588, 0.4392, 0.4157,  ..., 0.2902, 0.2588, 0.2431],
          ...,
          [0.4745, 0.4588, 0.5020,  ..., 0.2784, 0.2549, 0.2549],
          [0.4706, 0.4824, 0.5059,  ..., 0.2627, 0.2314, 0.2314],
          [0.4549, 0.4745, 0.4980,  ..., 0.2667, 0.2627, 0.2824]],
 
         [[0.3647, 0.3529, 0.3059,  ..., 0.4667, 0.4863, 0.5451],
          [0.3608, 0.3608, 0.3333,  ..., 0.4275, 0.4627, 0.5137],
          [0.3843, 0.3686, 0.3490,  ..., 0.4078, 0.4353, 0.4588],
          ...,
          [0.5373, 0.5216, 0.5569,  ..., 0.2510, 0.2275, 0.2275],
          [0.5333, 0.5451, 0.5608,  ..., 0.2353, 0.2000, 0.2000],
          [0.5176, 0.5294, 0.5451,  ..., 0.2353, 0.2314, 0.2510]],
 
         [[0.3176, 0.3098, 0.2667,  ..., 0.5373, 0.5490, 0.5922],
          [0.3176, 0.3255, 0.3059,  ..., 0.5176, 0.5451, 0.5765],
          [0.3294, 0.3137, 0.3098,  ...,