In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from IPython.core.interactiveshell import InteractiveShell
# Display the value of every top-level expression in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = 'all'

import warnings
# Suppress all warnings for the rest of the session.
warnings.simplefilter('ignore')
import gc

from os import path
import sys
# Put the parent directory on the import path so that `src.dali` can be imported below.
sys.path.append(path.abspath('..'))

In [None]:
import nvidia.dali.ops as ops
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.types as types
from nvidia.dali.plugin.pytorch import TorchPythonFunction, DALIGenericIterator, LastBatchPolicy


import kornia.augmentation as augs
import cupy as cp
import torch
from tqdm.notebook import tqdm

from src.dali import ExternalInputIterator

In [None]:
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt

def show_images(image_batch, columns=4):
    """Plot every image of a batch on a grid with ``columns`` images per row.

    Args:
        image_batch: Batch object supporting ``len()`` and ``.at(index)``;
            each item is handed to ``plt.imshow`` unchanged.
        columns: Number of grid columns. Defaults to 4.

    Returns:
        None. The figure is created through the pyplot state machine.
    """
    n_images = len(image_batch)
    if n_images == 0:
        # Nothing to lay out for an empty batch.
        return
    # Ceiling division: a partially filled last row still needs its own row.
    rows = -(-n_images // columns)
    plt.figure(figsize=(32, (32 // columns) * rows))
    gs = gridspec.GridSpec(rows, columns)
    # Fill only as many cells as there are images; the grid may have more
    # cells than the batch has items when the last row is incomplete.
    for j in range(n_images):
        plt.subplot(gs[j])
        plt.axis("off")
        plt.imshow(image_batch.at(j))

In [None]:

def add_number(image, number=20.0):
    """Return a copy of ``image`` with ``number`` added to index 0 of its third axis.

    The input array is left untouched; the sum is written back into the copy,
    so the copy keeps the dtype of ``image``.
    Assumes ``image`` is laid out as (height, width, channels) - TODO confirm.
    """
    shifted = cp.copy(image)
    first_channel = shifted[:, :, 0] + number
    shifted[:, :, 0] = first_channel
    return shifted

# Single sample: HWC -> CHW for the augmentation, then back to HWC.
def channel_shuffle(image):
    """Apply kornia's ``RandomChannelShuffle`` to one HWC image tensor.

    The tensor is permuted to channel-first and cast to float before the
    augmentation; element 0 of the augmentation output is taken (presumably
    kornia returns a batched result - TODO confirm), permuted back to
    channel-last and cast to ``torch.uint8``.
    """
    channels_first = image.permute(2, 0, 1).float()
    shuffle = augs.RandomChannelShuffle()
    shuffled = shuffle(channels_first)[0]
    return shuffled.permute(1, 2, 0).type(torch.uint8)

def channel_shuffle_batch(images):
    """Apply kornia's ``RandomChannelShuffle`` to a whole batch in one call.

    ``images`` is a sequence of equally shaped HWC tensors. They are stacked,
    permuted to NCHW and cast to float for the augmentation, then permuted
    back to NHWC and cast to ``torch.uint8``. A single stacked tensor is
    returned.
    """
    stacked = torch.stack(images)
    channels_first = stacked.permute(0, 3, 1, 2).float()
    shuffled = augs.RandomChannelShuffle()(channels_first)
    return shuffled.permute(0, 2, 3, 1).type(torch.uint8)

In [None]:
# Brightness 
# BrightnessContrast 
# ColorTwist 
# Contrast 
# Erase 
# GaussianBlur 
# Hsv 
# Jitter 
# Rotate 
# Sphere 
# WarpAffine 
# Water

In [None]:
# Names exposed by `nvidia.dali.ops` that contain no underscore.
[op_name for op_name in dir(ops) if '_' not in op_name]

In [None]:
class AugmentationsPipeline(Pipeline):
    """DALI pipeline: external source -> mixed-device decode -> GPU resize -> channel shuffle.

    The pipeline runs a Python-function operator, so it is created with
    asynchronous and pipelined execution switched off.

    Args:
        batch_size: Samples per batch; also passed to ``ExternalInputIterator``.
        num_threads: Forwarded to ``Pipeline``.
        device_id: Forwarded to ``Pipeline``. Defaults to 0.
        image_size: Target size handed to ``ops.Resize``. Defaults to ``(224, 224)``.
        seed: Pipeline seed. Defaults to ``0xDEAD``.
    """

    def __init__(self, batch_size, num_threads, device_id=0, image_size=(224, 224), seed=0xDEAD):
        super().__init__(
            batch_size=batch_size,
            num_threads=num_threads,
            device_id=device_id,
            seed=seed,
            # Required for custom (Python-function) operators.
            exec_async=False,
            exec_pipelined=False,
        )
        # The external source yields two outputs per batch: encoded images and labels.
        self.input = ops.ExternalSource(
            source=ExternalInputIterator(batch_size),
            num_outputs=2,
        )
        self.decode = ops.ImageDecoder(device='mixed', output_type=types.RGB)
        self.resize = ops.Resize(device='gpu', size=image_size, interp_type=types.INTERP_TRIANGULAR)

        # --- Alternative augmentations, kept as switchable experiments ---

#         self.rotate = ops.Rotate(device='gpu')
#         self.rotate_range = ops.Uniform(range=(-20., 20.))

#         self.sphere = ops.Sphere(device='gpu')
#         self.sphere_apply = ops.CoinFlip()

#         Adjusts hue, saturation and brightness of the image
#         self.twist = ops.ColorTwist(device='gpu')
#         self.range1 = ops.Uniform(range=[0.5, 2.])
#         self.range2 = ops.Uniform(range=[-15, 15])

#         self.add_number = ops.PythonFunction(
#             function=add_number,
#             num_outputs=1,
#             device='gpu',
#             output_layouts=types.NHWC,
#         )
#         self.random_number = ops.Uniform(values=list(range(0, 20)))

        # Active augmentation: per-sample channel shuffle implemented in PyTorch.
        self.channel_shuffle = TorchPythonFunction(
            function=channel_shuffle,
            num_outputs=1,
            device='gpu',
            batch_processing=False,
            output_layouts=types.NHWC,
        )
#         self.channel_shuffle_batch = TorchPythonFunction(
#             function=channel_shuffle_batch,
#             num_outputs=1,
#             device='gpu',
#             batch_processing=True,
#             output_layouts=types.NHWC,
#         )
#         self.normalization = ops.CropMirrorNormalize(
#             device="gpu",
#             dtype=types.FLOAT,
#             output_layout=types.NCHW,
#             mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
#             std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
#         )

    def define_graph(self):
        """Connect the operators and return the ``(images, labels)`` outputs."""
        jpegs, labels = self.input()
        images = self.decode(jpegs)
        images = self.resize(images)

#         angle = self.rotate_range()
#         images = self.rotate(images, angle=angle, fill_value=0, keep_size=True)

#         images = self.sphere(images, mask=self.sphere_apply())

#         images = self.twist(
#             images,
#             saturation=self.range1(),
#             contrast=self.range1(),
#             brightness=self.range1(),
#             hue=self.range2(),
#         )

#         number = self.random_number().gpu()  # the random op does not run on the GPU, so move its output there ourselves
#         images = self.add_number(images, number)

        images = self.channel_shuffle(images)
#         images = self.channel_shuffle_batch(images)
#         images = self.normalization(images)

        return images, labels

# Посмотрим картиночки

In [None]:
# Build a 12-sample, single-thread pipeline for visual inspection of one batch.
pipeline = AugmentationsPipeline(batch_size=12, num_threads=1)
pipeline.build()

In [None]:
# Run one iteration; the outputs are (images, labels), as returned by define_graph.
pipeline_output = pipeline.run()
images, labels = pipeline_output

In [None]:
# Copy the batch to host memory with as_cpu() before plotting it.
show_images(images.as_cpu())

In [None]:
# original
# NOTE(review): this plots the same `images` variable as the previous cell; presumably it was
# run with the augmentation disabled to capture the reference picture - confirm.
show_images(images.as_cpu())

# Сравним скорость

In [None]:
# Pipeline and iterator for the loading-speed benchmark (this rebinds `pipeline`).
pipeline = AugmentationsPipeline(batch_size=60, num_threads=8)

# NOTE(review): `size` is 68811-7 here but 68811-6 in LitDALI - confirm the intended sample count.
loader = DALIGenericIterator(
            pipeline,
            ['image', 'label'],
            size=68811-7,
            auto_reset=True,
        )

In [None]:
# Iterate over the whole loader once; the body only reads the batch shape,
# so the tqdm timing is dominated by data loading.
for data in tqdm(loader):
    _ = data[0]['image'].shape
#     break

In [None]:
# Data-loading time for one pass over the loader (mm:ss):
# async + pipelined execution: 00:23
# no async: 00:34 - acceptable
# no async + sphere: 00:35 - perfectly fine
# no async + add_number: 00:42 - noticeably worse
# no async + channel_shuffle: 01:06, over a minute - very bad, but possibly still better than not using DALI
# no async + channel_shuffle_batch: 00:44 - very good

# Сравним скорость обучения

In [None]:
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning import Trainer

from timm.models import gernet_s
from torch.nn import functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader

In [None]:
class LitDALI(LightningModule):
    """Lightning module that trains timm's ``gernet_s`` on batches from ``AugmentationsPipeline``.

    Every constructor argument is optional; the defaults are the values that
    were previously hard-coded in this class.

    Args:
        num_classes: Output size passed to ``gernet_s``.
        batch_size: Batch size of the DALI pipeline.
        num_threads: Thread count of the DALI pipeline.
        device_id: Device id of the DALI pipeline.
        dataset_size: ``size`` passed to ``DALIGenericIterator``.
        lr: Learning rate of the Adam optimizer.
    """

    def __init__(self, num_classes=5, batch_size=60, num_threads=8, device_id=0,
                 dataset_size=68811 - 6, lr=1e-3):
        super().__init__()
        self.model = gernet_s(num_classes=num_classes)
        self._batch_size = batch_size
        self._num_threads = num_threads
        self._device_id = device_id
        self._dataset_size = dataset_size
        self._lr = lr
        # Created by prepare_data(), or lazily by train_dataloader().
        self.train_loader = None

    def forward(self, x):
        """Run the wrapped model on ``x``."""
        return self.model(x)

    def process_batch(self, batch):
        """Extract ``(image, label)`` from the first output dict of a DALI iterator batch."""
        return batch[0]['image'], batch[0]['label']

    def training_step(self, batch, batch_idx):
        """Compute the binary cross-entropy-with-logits loss for one batch."""
        x, y = self.process_batch(batch)
        logits = self(x)
        # NOTE(review): assumes the pipeline yields float labels shaped like the logits - confirm.
        loss = F.binary_cross_entropy_with_logits(logits, y)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all module parameters."""
        return Adam(self.parameters(), lr=self._lr)

    def _build_train_loader(self):
        """Create the DALI pipeline and wrap it in a ``DALIGenericIterator``."""
        pipeline = AugmentationsPipeline(
            batch_size=self._batch_size,
            num_threads=self._num_threads,
            device_id=self._device_id,
        )
        return DALIGenericIterator(
            pipeline,
            ['image', 'label'],
            size=self._dataset_size,
            auto_reset=True,
            last_batch_policy=LastBatchPolicy.PARTIAL,
        )

    def prepare_data(self):
        """Build the training loader (kept for compatibility with existing callers)."""
        self.train_loader = self._build_train_loader()

    def train_dataloader(self):
        """Return the training loader, building it first if it does not exist yet."""
        # prepare_data() is not guaranteed to run in every process, so do not
        # rely on it having created the loader.
        if self.train_loader is None:
            self.train_loader = self._build_train_loader()
        return self.train_loader

In [None]:
# Default model and a single-GPU trainer limited to one epoch.
model = LitDALI()
trainer = Trainer(gpus=1, max_epochs=1)

In [None]:
%%time
trainer.fit(model)

In [None]:
# Wall-clock time for one training epoch (mm:ss):
# async + pipelined execution, no augmentations - 02:27
# no async, no augmentations - 02:44
# no async, sphere - 02:45
# no async, add_number - 02:54
# no async, channel_shuffle - 03:17
# no async, channel_shuffle_batch - 02:56