In [3]:
import itertools
import os
import warnings
from math import floor

import torch
from torch.utils.data import DataLoader
from torch.utils.data import random_split

from src.generators import BoxDatasetGenerator, ImageDataset, MixDatasetGenerator, MyAugmenter

In [7]:
# Root directory of the generated dataset. Set HWR_DATASET_DIR to run the notebook on
# another machine; the default is the original author's local location.
path = os.environ.get('HWR_DATASET_DIR', '/Users/vankudr/Documents/НИР-data/dataset_1')

### CREATING BOX DATASET FROM HANDWRITTEN IMAGES

In [8]:
import math
from torch import default_generator, randperm
from torch._utils import _accumulate
from torch.utils.data.dataset import Subset

def random_split(dataset, lengths,
                 generator=default_generator):
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    Mirrors ``torch.utils.data.random_split`` (including support for fractional
    lengths) and shadows the function of the same name imported from
    ``torch.utils.data`` earlier in the notebook.

    If a list of fractions that sum up to 1 is given,
    the lengths will be computed automatically as
    floor(frac * len(dataset)) for each fraction provided.

    After computing the lengths, if there are any remainders, 1 count will be
    distributed in round-robin fashion to the lengths
    until there are no remainders left.

    Optionally fix the generator for reproducible results, e.g.:

    >>> random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(42))
    >>> random_split(range(30), [0.3, 0.3, 0.4], generator=torch.Generator(
    ...   ).manual_seed(42))

    Args:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths or fractions of splits to be produced
        generator (Generator): Generator used for the random permutation.

    Returns:
        list of ``Subset``: one subset per entry of ``lengths``, in the same order.

    Raises:
        ValueError: if a fraction lies outside [0, 1], or if the (computed) lengths
            do not sum to ``len(dataset)``.
    """
    if math.isclose(sum(lengths), 1) and sum(lengths) <= 1:
        subset_lengths = []
        for i, frac in enumerate(lengths):
            if frac < 0 or frac > 1:
                raise ValueError(f"Fraction at index {i} is not between 0 and 1")
            n_items_in_split = int(
                math.floor(len(dataset) * frac)  # type: ignore[arg-type]
            )
            subset_lengths.append(n_items_in_split)
        remainder = len(dataset) - sum(subset_lengths)  # type: ignore[arg-type]
        # add 1 to all the lengths in round-robin fashion until the remainder is 0
        for i in range(remainder):
            idx_to_add_at = i % len(subset_lengths)
            subset_lengths[idx_to_add_at] += 1
        lengths = subset_lengths
        for i, length in enumerate(lengths):
            if length == 0:
                warnings.warn(f"Length of split at index {i} is 0. "
                              f"This might result in an empty dataset.")

    # Cannot verify that dataset is Sized
    if sum(lengths) != len(dataset):    # type: ignore[arg-type]
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    indices = randperm(sum(lengths), generator=generator).tolist()  # type: ignore[call-overload]
    # itertools.accumulate yields the running end offset of each split; this replaces the
    # private torch._utils._accumulate helper.
    return [Subset(dataset, indices[offset - length : offset])
            for offset, length in zip(itertools.accumulate(lengths), lengths)]

In [9]:
# Input directory with the handwritten images.
hwr_path = os.path.join(path, 'input/hwr/')
# Per-split output directories for the extracted boxes.
boxes_train_path, boxes_test_path, boxes_val_path = (
    os.path.join(path, 'train/boxes/'),
    os.path.join(path, 'test/boxes/'),
    os.path.join(path, 'val/boxes/'),
)

def pad(img, padding):
    """Crop ``img`` by removing a fraction of its size from each side.

    Despite the name, nothing is added: border rows/columns are sliced away.

    Args:
        img: image array indexed as ``img[row, col, ...]``. Both 2-D (grayscale) and
            3-D (channels-last) arrays are accepted.
        padding: four fractions ``(top, bottom, left, right)``. ``top``/``bottom`` are
            shares of the height removed from the first/last rows, and ``left``/``right``
            are shares of the width removed from the first/last columns.

    Returns:
        The sliced central region of ``img`` (basic slicing, so no copy for NumPy arrays).
    """
    # shape[:2] works for both grayscale (H, W) and colour (H, W, C) images.
    height, width = img.shape[:2]
    top, bottom, left, right = padding
    return img[int(top * height): int((1 - bottom) * height),
               int(left * width): int((1 - right) * width)]

def hwr_transformer(img):
    """Crop fixed fractions off each side of a handwritten image before box extraction."""
    # Remove 23% of the rows from the top, 25% from the bottom and 1% of the columns
    # on each side.
    return pad(img=img, padding=[0.23, 0.25, 0.01, 0.01])


hwr_dataset = ImageDataset(path=hwr_path, transform=hwr_transformer)
# Seed the split so that re-running the notebook assigns every handwritten image to the
# same subset. Otherwise, if the box directories are not cleared between runs, images
# could contribute boxes to more than one split.
hwr_train, hwr_test, hwr_val = random_split(
    hwr_dataset, [0.7, 0.2, 0.1], generator=torch.Generator().manual_seed(42)
)

def create_box_dataset(hwr_dataset, boxes_path, label, hwr_threshold=160):
    """Generate the box dataset for one split of the handwritten images.

    Args:
        hwr_dataset: dataset (or ``Subset``) of handwritten images for this split.
        boxes_path: directory passed to ``BoxDatasetGenerator`` as ``boxes_path``.
        label: split name; used only in the progress message.
        hwr_threshold: threshold forwarded to ``BoxDatasetGenerator`` (its meaning is
            defined there). The default of 160 is the value used before it became a
            parameter.
    """
    print(f'Creating {label} boxes dataset')

    # batch_size=1: the loader yields one handwritten image per batch.
    hwr_dataloader = DataLoader(hwr_dataset, batch_size=1, shuffle=True)
    box_gen = BoxDatasetGenerator(hwr_dataloader=hwr_dataloader,
                                  boxes_path=boxes_path,
                                  hwr_threshold=hwr_threshold)
    box_gen.create_dataset()

In [5]:
# Extract boxes for every handwritten split into its own directory.
create_box_dataset(hwr_dataset=hwr_train, boxes_path=boxes_train_path, label='train')
create_box_dataset(hwr_dataset=hwr_test, boxes_path=boxes_test_path, label='test')
create_box_dataset(hwr_dataset=hwr_val, boxes_path=boxes_val_path, label='val')

Creating train boxes dataset


Handwritten image processed: 100%|██████████| 1078/1078 [03:39<00:00,  4.90it/s]


Creating test boxes dataset


Handwritten image processed: 100%|██████████| 308/308 [00:49<00:00,  6.20it/s]


Creating val boxes dataset


Handwritten image processed: 100%|██████████| 153/153 [00:24<00:00,  6.25it/s]


### CREATING MIX DATASET FROM BOXES AND PRINTED IMAGES

In [10]:
def create_mix_dataset(printed_dataset, boxes_dataset, result_path, boxes_per_printed=250,
                       printed_threshold=100, hwr_threshold=160, printed_replica_factor=1):
    """Build the mixed dataset from printed images and handwritten boxes.

    Printed images are fed in order, and boxes in random order, to
    ``MixDatasetGenerator``, which writes its output under ``result_path``. Boxes are
    passed through a ``MyAugmenter`` configured with random scaling/rotation ranges.

    Args:
        printed_dataset: dataset of printed images.
        boxes_dataset: dataset of handwritten box images. It must contain at least
            ``boxes_per_printed * len(printed_dataset)`` items.
        result_path: output directory passed to ``MixDatasetGenerator``.
        boxes_per_printed: number of boxes requested per printed image.
        printed_threshold: forwarded to ``MixDatasetGenerator`` (default 100).
        hwr_threshold: forwarded to ``MixDatasetGenerator`` (default 160).
        printed_replica_factor: forwarded to ``MixDatasetGenerator`` (default 1).

    Raises:
        ValueError: if ``boxes_dataset`` is smaller than the number of boxes requested.
    """
    # Presumably every printed image consumes ``boxes_per_printed`` boxes, so the box
    # pool must be large enough for all of them. Fail fast before any work is done.
    n_printed = len(printed_dataset)
    required_boxes = boxes_per_printed * n_printed
    available_boxes = len(boxes_dataset)
    if required_boxes > available_boxes:
        raise ValueError(
            f'Not enough boxes: {boxes_per_printed} boxes per printed image x '
            f'{n_printed} printed images = {required_boxes} required, '
            f'but only {available_boxes} available'
        )

    printed_dataloader = DataLoader(printed_dataset, batch_size=1, shuffle=False)
    boxes_dataloader = DataLoader(boxes_dataset, batch_size=1, shuffle=True)
    boxes_aug = MyAugmenter(random_scaling=[0.15, 0.25], random_rotation=[-5, 5])
    mix_gen = MixDatasetGenerator(printed_dataloader=printed_dataloader,
                                  printed_replica_factor=printed_replica_factor,
                                  boxes_dataloader=boxes_dataloader,
                                  boxes_per_printed=boxes_per_printed,
                                  result_path=result_path,
                                  printed_threshold=printed_threshold,
                                  hwr_threshold=hwr_threshold,
                                  boxes_augmenter=boxes_aug)

    mix_gen.create_dataset()

In [11]:
# Printed images come from the PubLayNet10k directory rather than
# os.path.join(path, 'input/printed/'). Set HWR_PRINTED_DIR to override the location.
printed_path = os.environ.get('HWR_PRINTED_DIR', '/Users/vankudr/Documents/НИР-data/other/PubLayNet10k')
printed_dataset = ImageDataset(path=printed_path)
# Seeded so that the printed train/test/val split is identical on every run.
printed_train, printed_test, printed_val = random_split(
    printed_dataset, [0.7, 0.2, 0.1], generator=torch.Generator().manual_seed(42)
)

In [12]:
boxes_train_path = os.path.join(path, 'train/boxes/')
boxes_train = ImageDataset(
    path=boxes_train_path,
    read_in_memory=True,        # load every box image up front
    multiplication_factor=100,  # NOTE(review): presumably scales the dataset length — confirm in ImageDataset
)
result_train_path = os.path.join(path, 'train/result/')

Boxes read: 100%|██████████| 33181/33181 [01:19<00:00, 415.86it/s]


In [13]:
import cProfile

# Profile mix-dataset generation for the train split; stats are written to results_train.prof.
# runcall passes the objects directly instead of exec-ing a string in __main__.
train_profiler = cProfile.Profile()
try:
    train_profiler.runcall(create_mix_dataset, printed_train, boxes_train, result_train_path)
finally:
    # Dump even if generation fails so partial profiles are kept.
    train_profiler.dump_stats('results_train.prof')

Printed images batch processed: 100%|██████████| 7057/7057 [11:49<00:00,  9.94it/s]


In [14]:
boxes_test_path = os.path.join(path, 'test/boxes/')
boxes_test = ImageDataset(
    path=boxes_test_path,
    read_in_memory=True,        # load every box image up front
    multiplication_factor=100,  # NOTE(review): presumably scales the dataset length — confirm in ImageDataset
)
result_test_path = os.path.join(path, 'test/result/')

Boxes read: 100%|██████████| 9296/9296 [00:31<00:00, 295.13it/s]


In [15]:
import cProfile

# Profile mix-dataset generation for the test split; stats are written to results_test.prof.
# runcall passes the objects directly instead of exec-ing a string in __main__.
test_profiler = cProfile.Profile()
try:
    test_profiler.runcall(create_mix_dataset, printed_test, boxes_test, result_test_path)
finally:
    # Dump even if generation fails so partial profiles are kept.
    test_profiler.dump_stats('results_test.prof')

Printed images batch processed: 100%|██████████| 2016/2016 [03:26<00:00,  9.76it/s]


In [16]:
boxes_val_path = os.path.join(path, 'val/boxes/')
boxes_val = ImageDataset(
    path=boxes_val_path,
    read_in_memory=True,        # load every box image up front
    multiplication_factor=100,  # NOTE(review): presumably scales the dataset length — confirm in ImageDataset
)
result_val_path = os.path.join(path, 'val/result/')

Boxes read: 100%|██████████| 4587/4587 [00:17<00:00, 258.72it/s]


In [17]:
import cProfile

# Profile mix-dataset generation for the val split; stats are written to results_val.prof.
# runcall passes the objects directly instead of exec-ing a string in __main__.
val_profiler = cProfile.Profile()
try:
    val_profiler.runcall(create_mix_dataset, printed_val, boxes_val, result_val_path)
finally:
    # Dump even if generation fails so partial profiles are kept.
    val_profiler.dump_stats('results_val.prof')

Printed images batch processed: 100%|██████████| 1008/1008 [01:42<00:00,  9.84it/s]


In [18]:
import pstats

# Load the train-split profiling results into pstats.
stats = pstats.Stats('results_train.prof')

# Sort by own (exclusive) time and print the top 10 functions. The trailing ';'
# suppresses the redundant <pstats.Stats ...> repr in the cell output.
stats.sort_stats('tottime').print_stats(10);

Sun Apr 30 15:28:45 2023    results_train.prof

         131288736 function calls (129489104 primitive calls) in 710.137 seconds

   Ordered by: internal time
   List reduced from 264 to 10 due to restriction <10>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
  1771307   93.904    0.000   93.904    0.000 {built-in method torch.stack}
    14114   91.796    0.007   91.796    0.007 {imwrite}
     7057   87.688    0.012   87.688    0.012 {imread}
     7057   53.785    0.008  514.567    0.073 /Users/vankudr/Yandex.Disk.localized/MIPT/НИР/Handwriting-Segmentation/Dataset-generation/src/generators/hwr_generator.py:153(cover_printed_with_boxes)
  1771307   37.638    0.000   37.638    0.000 {cvtColor}
  1764250   37.270    0.000   37.270    0.000 {warpAffine}
  1764250   32.177    0.000   32.177    0.000 {resize}
  1764250   21.843    0.000   21.843    0.000 {dilate}
3542614/1771307   19.583    0.000  130.886    0.000 /opt/anaconda3/envs/scidev/lib/python3.9/site-pack

<pstats.Stats at 0x7fc190ad6d60>