In [1]:
# IPython magics: load the autoreload extension and select mode 2, so edited
# project modules are re-imported before each cell runs (no kernel restart needed).
%load_ext autoreload
%autoreload 2

In [2]:
"""
The main script that serves as the entry-point for all kinds of training experiments.
"""

import argparse
import dataclasses
import gc
import logging
import os

import numpy as np
import torch
from das.data.data_args import DataArguments
from das.utils.arg_parser import DASArgumentParser
from das.utils.basic_args import BasicArguments
from das.utils.basic_utils import configure_logger, create_logger

# module-level logger for this script, created through the project's helper
logger = create_logger(__name__)

# dataclasses the argument parser fills in; parse_args() passes this list to
# DASArgumentParser
ARG_DATA_CLASSES = [BasicArguments, DataArguments]

# torch hub bug fix https://github.com/pytorch/vision/issues/4156
# Replaces torch.hub's private validation hook with a stub that accepts any
# three arguments and always returns True.
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True


def parse_args(cfg):
    """
    Reads the script arguments from a YAML configuration file.

    Args:
        cfg: Path to the YAML file. It is converted to an absolute path
            (relative paths resolve against the current working directory)
            before it is handed to the parser.

    Returns:
        Whatever DASArgumentParser.parse_yaml_file returns for the dataclasses
        listed in ARG_DATA_CLASSES.
    """
    # build a parser for the configured argument dataclasses
    parser = DASArgumentParser(ARG_DATA_CLASSES)

    yaml_path = os.path.abspath(cfg)
    return parser.parse_yaml_file(yaml_path)

def print_args(title, args):
    """
    Prints a titled, tab-indented listing of a dataclass instance's fields.

    Args:
        title: Heading printed on its own line (followed by a colon).
        args: Dataclass instance; its fields are read via dataclasses.asdict.
    """
    # heading first, then one "\t<field>: <value>" line per dataclass field
    pieces = [f"\n{title}:\n"]
    pieces.extend(
        f"\t{name}: {value}\n"
        for name, value in dataclasses.asdict(args).items()
    )
    print("".join(pieces))


def empty_cache():
    """
    Frees memory that is no longer in use.

    The Python garbage collector runs first so that unreachable objects
    (including tensors kept alive only by reference cycles) are destroyed
    before the CUDA caching allocator is asked to release its unused blocks.
    In the opposite order, memory freed by the collector would stay cached.
    """
    gc.collect()
    torch.cuda.empty_cache()

"""
Initializes the training of a model given a dataset and their configurations.
"""

# release cached CUDA memory and run the garbage collector before starting
empty_cache()

# parse arguments from the YAML config; parse_args() makes the path absolute,
# so this relative name resolves against the current working directory
cfg = 'dataset.yaml'
basic_args, data_args = parse_args(cfg)

# configure pytorch-lightning logger with the project's logger settings
pl_logger = logging.getLogger("pytorch_lightning")
configure_logger(pl_logger)

# initialize torch random seed from the parsed basic arguments
torch.manual_seed(basic_args.seed)

<torch._C.Generator at 0x7f69240334f0>

In [3]:
from das.data.data_modules.base import DataModuleFactory
import matplotlib.pyplot as plt
from pathlib import Path
import ocrodeg 
import scipy.ndimage as ndi

# initialize data-handling module from the parsed arguments, set collate_fns later
datamodule = DataModuleFactory.create_datamodule(
    basic_args, data_args)

# prepare the modules; datamodule.train_dataset is read by the export loop
# after these two calls
datamodule.prepare_data()
datamodule.setup()

def random_transform_fn(image, params):
    """
    Applies an ocrodeg transform with freshly drawn random parameters.

    Args:
        image: Image handed unchanged to ocrodeg.transform_image.
        params: Unused; present so the function can be called as
            fn(image, params) like the other distortion functions.

    Returns:
        The result of ocrodeg.transform_image.
    """
    # keyword arguments are drawn anew on every call
    transform_kwargs = ocrodeg.random_transform()
    return ocrodeg.transform_image(image, **transform_kwargs)

def random_distortion_fn(image, params):
    """
    Distorts the image with a noise field from ocrodeg.bounded_gaussian_noise.

    Args:
        image: Image to distort; its shape sizes the noise field.
        params: Mapping with key 'sigma', forwarded as the second argument of
            ocrodeg.bounded_gaussian_noise.

    Returns:
        The result of ocrodeg.distort_with_noise.
    """
    # 5.0 is the fixed third argument of bounded_gaussian_noise
    # (presumably the maximum displacement -- confirm against ocrodeg)
    displacement = ocrodeg.bounded_gaussian_noise(image.shape, params['sigma'], 5.0)
    distorted = ocrodeg.distort_with_noise(image, displacement)
    return distorted

def ruled_surface_distortion_fn(image, params):
    """
    Distorts the image with a noise field from ocrodeg.noise_distort1d.

    Args:
        image: Image to distort; its shape sizes the noise field.
        params: Mapping with key 'noise', forwarded as the `magnitude`
            keyword of ocrodeg.noise_distort1d.

    Returns:
        The result of ocrodeg.distort_with_noise.
    """
    displacement = ocrodeg.noise_distort1d(image.shape, magnitude=params['noise'])
    distorted = ocrodeg.distort_with_noise(image, displacement)
    return distorted

def gaussian_fn(image, params):
    """
    Blurs the image with scipy.ndimage.gaussian_filter.

    Args:
        image: Array to blur.
        params: Mapping with key 'str', used as the filter's sigma.

    Returns:
        The filtered array.
    """
    sigma = params['str']
    return ndi.gaussian_filter(image, sigma)

def threshold_fn(image, params):
    """
    Blurs the image and binarises the result at 0.5.

    Args:
        image: Array to process.
        params: Mapping with key 'str', used as the Gaussian filter's sigma.

    Returns:
        Float array holding 1.0 where the blurred value exceeds 0.5 and 0.0
        elsewhere.
    """
    smoothed = ndi.gaussian_filter(image, params['str'])
    above = smoothed > 0.5
    # multiplying by 1.0 turns the boolean mask into floats
    return 1.0 * above

def binary_blur_fn(image, params):
    """
    Applies ocrodeg.binary_blur to the image.

    Args:
        image: Image handed unchanged to ocrodeg.binary_blur.
        params: Mapping with key 'str', forwarded as the second argument.

    Returns:
        The result of ocrodeg.binary_blur.
    """
    blurred = ocrodeg.binary_blur(image, params['str'])
    return blurred

def noisy_binary_blur_fn(image, params):
    """
    Applies ocrodeg.binary_blur with a fixed second argument and added noise.

    Args:
        image: Image handed unchanged to ocrodeg.binary_blur.
        params: Mapping with key 'noise', forwarded as the `noise` keyword.

    Returns:
        The result of ocrodeg.binary_blur.
    """
    # the second argument is fixed at 0.1; only the noise level varies
    blurred = ocrodeg.binary_blur(image, 0.1, noise=params['noise'])
    return blurred

def random_blotches_fn(image, params=None):
    """
    Applies ocrodeg.random_blotches with fixed settings.

    Args:
        image: Image handed unchanged to ocrodeg.random_blotches.
        params: Unused; present so the function can be called as
            fn(image, params) like the other distortion functions.

    Returns:
        The result of ocrodeg.random_blotches.
    """
    # fixed second and third positional arguments of random_blotches
    blotch_args = (3e-4, 1e-4)
    return ocrodeg.random_blotches(image, *blotch_args)

def fibrous_noise_fn(image, params=None):
    """
    Applies ocrodeg.printlike_fibrous with its default settings.

    Args:
        image: Image handed unchanged to ocrodeg.printlike_fibrous.
        params: Unused; present so the function can be called as
            fn(image, params) like the other distortion functions.

    Returns:
        The result of ocrodeg.printlike_fibrous.
    """
    # NOTE: an earlier, disabled variant built the noise directly with
    # ocrodeg.make_fibrous_image((256, 256), 700, 300, 0.01)
    degraded = ocrodeg.printlike_fibrous(image)
    return degraded

def multiscale_noise_fn(image, params=None):
    """
    Applies ocrodeg.printlike_multiscale with its default settings.

    Args:
        image: Image handed unchanged to ocrodeg.printlike_multiscale.
        params: Unused; present so the function can be called as
            fn(image, params) like the other distortion functions.

    Returns:
        The result of ocrodeg.printlike_multiscale.
    """
    degraded = ocrodeg.printlike_multiscale(image)
    return degraded

# Distortion registry: name -> (distortion function, list of parameter dicts).
# Each parameter dict produces one output image; [{}] marks a distortion that
# takes no tunable parameters and is therefore applied once.
distortions = {
    'random_transform': (random_transform_fn, [{}]),
    'random_distortion': (
        random_distortion_fn,
        [{'sigma': sigma} for sigma in (1.0, 2.0, 5.0, 10.0, 20.0)],
    ),
    'ruled_surface_distortion': (
        ruled_surface_distortion_fn,
        [{'noise': noise} for noise in (5.0, 20.0, 100.0, 150.0, 200.0)],
    ),
    'gaussian': (
        gaussian_fn,
        [{'str': strength} for strength in (1.0, 2.0, 3.0, 4.0, 5.0)],
    ),
    'threshold': (
        threshold_fn,
        [{'str': strength} for strength in (0.1, 0.25, 0.5, 0.75, 1.0)],
    ),
    'binary_blur': (
        binary_blur_fn,
        [{'str': strength} for strength in (0.1, 0.5, 1.0, 1.5, 2.0)],
    ),
    'noisy_binary_blur': (
        noisy_binary_blur_fn,
        [{'noise': noise} for noise in (0.01, 0.05, 0.1, 0.2, 0.3)],
    ),
    'random_blotches': (random_blotches_fn, [{}]),
    'fibrous_noise': (fibrous_noise_fn, [{}]),
    'multiscale_noise': (multiscale_noise_fn, [{}]),
}

# Root directory that receives the original and distorted images.
output_dir = Path("/netscratch/saifullah/rvl-cdip-wo-tobacco3842-c/")
# parents=True also creates missing ancestor directories; exist_ok=True makes
# re-running the cell a no-op instead of a check-then-create race.
output_dir.mkdir(parents=True, exist_ok=True)

# Write the original image and one file per (distortion, parameter set) for
# each processed training sample.
# NOTE: only the first sample is exported; raise num_samples to
# len(datamodule.train_dataset) to cover the whole training set.
num_samples = 1
for idx in range(num_samples):
    sample = datamodule.train_dataset[idx]

    # Convert once per sample and reuse the array for the original image and
    # every distortion (assumes sample['image'] is a torch tensor, as the
    # .cpu().numpy() call implies).
    image = sample['image'].squeeze().cpu().numpy()

    # Mirror the dataset layout below output_dir using the part of the source
    # path after 'images/' (raises IndexError if that segment is missing).
    output_image_path = output_dir / sample['image_file_path'].split('images/')[1]
    output_image_path.parent.mkdir(parents=True, exist_ok=True)

    # <stem>_orig.jpg sits next to a <stem>/ directory holding the distortions
    output_image_stem = output_image_path.with_suffix('')
    plt.imsave(str(output_image_stem) + "_orig.jpg", image, cmap='gray')

    for distortion, (distortion_fn, distortion_params) in distortions.items():
        # one sub-directory per distortion type, created once per distortion
        output_image_path_dist = output_image_stem / distortion
        output_image_path_dist.mkdir(parents=True, exist_ok=True)

        for param in distortion_params:
            # File name encodes the parameters as "k=v_" pairs; the trailing
            # underscore before ".jpg" is kept so file names stay unchanged.
            params_string = ''.join(f'{k}={v}_' for k, v in param.items())
            if not params_string:
                params_string = 'params=fixed'
            params_string += '.jpg'

            distorted_image = distortion_fn(image, param)
            plt.imsave(output_image_path_dist / params_string, distorted_image, cmap='gray')

2021-11-16 11:16:51 port-41xx das.data.data_modules.base[29061] INFO Preparing / preprocesing dataset and saving to cache...
2021-11-16 11:16:51 port-41xx das.data.data_modules.base[29061] INFO Setting up train/validation dataset...
2021-11-16 11:16:51 port-41xx das.data.data_modules.base[29061] INFO Training stage == None
2021-11-16 11:16:51 port-41xx das.data.data_modules.base[29061] INFO Setting up train/validation dataset...
2021-11-16 11:16:51 port-41xx das.data.datasets.base[29061] INFO Loading dataset [rvlcdip-train] from cache directory: //netscratch/saifullah/document_analysis_stack/datasets/rvlcdip/train
2021-11-16 11:16:51 port-41xx das.data.datasets.base[29061] INFO Defining data transformations [train]:
2021-11-16 11:16:51 port-41xx das.data.datasets.base[29061] INFO Loading dataset [rvlcdip-val] from cache directory: //netscratch/saifullah/document_analysis_stack/datasets/rvlcdip/val
2021-11-16 11:16:51 port-41xx das.data.datasets.base[29061] INFO Defining data transforma

	 ConvertImageDtype()
	 ConvertImageDtype()
	 ConvertImageDtype()
