In [8]:
from pydantic import BaseModel, Field
from typing import List, Optional
from fastapi import FastAPI, File, UploadFile

import torch
from skimage.io import imread
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from torch import nn
import os
import numpy as np
import torchmetrics
import torchvision
import matplotlib.pyplot as plt
from albumentations.pytorch import ToTensorV2
import albumentations as A
import warnings

# import pytorch_lightning as pl
# from pytorch_lightning.callbacks import ModelCheckpoint
# from pytorch_lightning.callbacks.early_stopping import EarlyStopping
# from pytorch_lightning.tuner import Tuner

In [9]:
# Per-channel normalization statistics (three channels): mean and standard deviation.
# NOTE(review): these appear to be the commonly quoted CIFAR-10 statistics rather
# than values computed on data_sirius — recompute on the train split if it matters.
MEAN, STD = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)

In [None]:
class LogoDataset(Dataset):
    """Image dataset over the files in ``data_dir``, split deterministically into train/test.

    Only file paths are stored; images are read lazily in ``__getitem__`` so the
    whole dataset never has to be held in memory at once.

    Args:
        mode: ``"train"`` selects the first ``fraction`` of the sorted file list,
            ``"test"`` selects the remainder.
        data_dir: Directory whose regular, non-hidden files are treated as images.
        fraction: Share of files assigned to the train split, in ``[0, 1]``.
        transform: Optional augmentation applied in ``__getitem__``. It is called
            as ``transform(image=image)["image"]`` (the albumentations convention,
            e.g. ``A.Compose``).

    Raises:
        ValueError: If ``mode`` is not ``"train"``/``"test"`` or ``fraction`` is
            outside ``[0, 1]``.
    """

    _MODES = ("train", "test")

    def __init__(self,
        mode,
        data_dir,
        fraction: float = 0.7,
        transform=None,
    ):
        if mode not in self._MODES:
            raise ValueError(f"mode must be one of {self._MODES}, got {mode!r}")
        if not 0.0 <= fraction <= 1.0:
            raise ValueError(f"fraction must be in [0, 1], got {fraction!r}")

        # Applied to each image in __getitem__.
        self._transform = transform

        # Sort so the train/test split does not depend on os.listdir's arbitrary
        # order (otherwise split membership can differ between machines/runs).
        # Skip directories (e.g. .ipynb_checkpoints) and hidden files
        # (e.g. .DS_Store), which are not images and would break imread.
        img_paths = [
            os.path.join(data_dir, name)
            for name in sorted(os.listdir(data_dir))
            if not name.startswith(".") and os.path.isfile(os.path.join(data_dir, name))
        ]

        # TODO: load real labels (e.g. from a ground-truth mapping); every image is 0 for now.
        items = [(path, 0) for path in img_paths]

        # list of (img_path, label) tuples; the two modes are complementary slices.
        train_size = int(fraction * len(items))
        self._items = items[:train_size] if mode == "train" else items[train_size:]

    def __len__(self):
        """Number of (image, label) pairs in this split."""
        return len(self._items)

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``.

        ``image`` is the array returned by ``imread``; single-channel images are
        stacked into three identical channels, and the result is passed through
        ``transform`` when one was given. ``label`` is a 0-d float64 numpy array.
        """
        img_path, label = self._items[index]

        image = imread(img_path)
        label = np.array(label, dtype=np.float64)

        # Replicate a single-channel (2-D) image into three identical channels.
        if image.ndim == 2:
            image = np.dstack((image, image, image))

        # Compare against None rather than relying on truthiness, so a transform
        # object that defines __len__ (and happens to be empty) is still applied.
        if self._transform is not None:
            image = self._transform(image=image)["image"]

        return image, label

In [22]:
# Shared split configuration: both splits must use the same directory and
# fraction, otherwise they stop being complementary (overlap or gaps).
DATA_DIR = "data_sirius"
TRAIN_FRACTION = 0.7

ds_train = LogoDataset(mode="train", data_dir=DATA_DIR, fraction=TRAIN_FRACTION)
ds_test = LogoDataset(mode="test", data_dir=DATA_DIR, fraction=TRAIN_FRACTION)

# Fail loudly on a wrong/empty data directory instead of training on nothing.
assert len(ds_train) + len(ds_test) > 0, f"no images found in {DATA_DIR!r}"