In [1]:
from project.retinanet.model import resnet34
from torchvision.transforms import Compose
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import collections
import torch.optim as optim
import numpy as np
from tqdm import tqdm

In [2]:
class MultiEpochsDataLoader(torch.utils.data.DataLoader):
    """DataLoader that reuses one underlying iterator for every epoch.

    The batch sampler is wrapped in ``_RepeatSampler`` so it never runs out.
    A single iterator is created once at construction time, and each call to
    ``__iter__`` draws exactly ``len(self)`` batches from it. With
    ``num_workers > 0`` this avoids re-creating the iterator, and hence its
    workers, at the start of every epoch.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # DataLoader refuses to reassign ``batch_sampler`` once initialized. Clear the
        # name-mangled flag for the swap and restore it immediately afterwards.
        self._DataLoader__initialized = False
        self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        # The one iterator shared by all epochs.
        self.iterator = super().__iter__()

    def __len__(self):
        # ``self.batch_sampler`` is now the repeating wrapper; its ``.sampler`` is the
        # original batch sampler, whose length is the number of batches per epoch.
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        batches_left = len(self)
        while batches_left:
            batches_left -= 1
            yield next(self.iterator)


class _RepeatSampler(object):
    """ Sampler that repeats forever.

    Args:
        sampler (Sampler)
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)

In [3]:
from project.datasetLoader.dataloader import EstimatedDeepFish, CustomPILToTensor, Normalizer, Resizer


# --- Loader configuration -------------------------------------------------
# Values that were previously inlined as magic numbers; tune them here.
DATASET_PATH = "./DATASET/"
ANNOTATIONS_CSV = "./annotations.csv"
IMAGE_SIZE = (480, 480)  # target size passed to Resizer
BATCH_SIZE = 64
NUM_WORKERS = 1
# Standard ImageNet per-channel statistics. Presumably chosen to match the
# pretrained ResNet-34 backbone built later in this notebook.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

dataset = EstimatedDeepFish(
    ANNOTATIONS_CSV,
    DATASET_PATH,
    Compose([
        CustomPILToTensor(),
        Normalizer(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        Resizer(IMAGE_SIZE, antialias=True),
    ]),
)
dataloader = MultiEpochsDataLoader(
    dataset, BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True
)

# One full pass over the loader as a smoke/throughput check: every sample must
# load and transform without error. Batches are counted and dropped inside the
# generator expression. No global loop variable is bound, so the last batch's
# tensors are not kept alive in the kernel after the cell finishes.
num_batches_seen = sum(1 for _ in tqdm(dataloader, desc="loader pass"))

print("DONE")

305 rows were removed from image dataset!
3751
Index(['file', 'bbox', 'class', 'size (cm)'], dtype='object')


  0%|          | 0/9 [00:00<?, ?it/s]

For getting from Loader: [0.12] seconds | After Transform Took: [0.41] seconds for Index: 170
For getting from Loader: [0.09] seconds | After Transform Took: [0.34] seconds for Index: 263
For getting from Loader: [0.09] seconds | After Transform Took: [0.32] seconds for Index: 433
For getting from Loader: [0.09] seconds | After Transform Took: [0.34] seconds for Index: 50
For getting from Loader: [0.09] seconds | After Transform Took: [0.34] seconds for Index: 316
For getting from Loader: [0.10] seconds | After Transform Took: [0.36] seconds for Index: 210
For getting from Loader: [0.09] seconds | After Transform Took: [0.33] seconds for Index: 477
For getting from Loader: [0.09] seconds | After Transform Took: [0.31] seconds for Index: 267
For getting from Loader: [0.09] seconds | After Transform Took: [0.32] seconds for Index: 9
For getting from Loader: [0.09] seconds | After Transform Took: [0.33] seconds for Index: 217
For getting from Loader: [0.09] seconds | After Transform Took:

 | After Transform Took: [0.36] seconds for Index: 177
For getting from Loader: [0.09] seconds

In [None]:
# Display the first element returned by EstimatedDeepFish.num_classes().
# The model cell below sizes its head with len() of this same element, so it is
# presumably the collection of class labels. TODO: confirm against the dataset
# implementation.
dataset.num_classes()[0]

In [None]:
class PipelineModel(nn.Module):
    """Thin ``nn.Module`` wrapper around the project's ``resnet34`` RetinaNet detector.

    Args:
        num_classes: number of object classes, forwarded to ``resnet34``.
        pretrained: forwarded to ``resnet34``. Defaults to ``True`` (the value that
            was previously hard-coded). Passing ``False`` presumably skips loading
            pretrained weights, which is handy for quick offline checks. Confirm
            against ``project.retinanet.model.resnet34``.
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        self.detector = resnet34(num_classes=num_classes, pretrained=pretrained)

    def forward(self, x):
        """Unpack a batch mapping and hand it to the detector as a single tuple.

        Args:
            x: mapping with keys ``'img'``, ``'annot'`` and ``'number'``. This is
                presumably one collated batch from ``EstimatedDeepFish``; TODO confirm.

        Returns:
            Whatever ``self.detector`` returns for the ``(img, annot, number)`` tuple.
        """
        return self.detector((x['img'], x['annot'], x['number']))
# Size the detector head from the dataset's class collection, then display a
# layer/parameter summary of the assembled model.
model = PipelineModel(num_classes=len(dataset.num_classes()[0]))

from torchinfo import summary
summary(model)