In [None]:
# Enable the autoreload extension in mode 2 so that edits to imported modules
# are picked up on the next cell run without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from IPython.core.interactiveshell import InteractiveShell
# Display the value of every top-level expression in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = 'all'

import warnings
# Suppress all warnings for the rest of the session.
warnings.simplefilter('ignore')

import gc

from os import path
import sys
# Put the parent directory on the import path; presumably this is where the
# `src` package imported further down lives (verify against the repo layout).
sys.path.append(path.abspath('..'))

In [None]:
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning import Trainer
from timm import create_model
from torch.nn import functional as F
from torch.optim import Adam
from nvidia.dali.plugin.pytorch import DALIGenericIterator, LastBatchPolicy

from src.dali import CustomPipeline

In [None]:
class LitDALI(LightningModule):
    """LightningModule that trains a timm model on batches served by a DALI pipeline.

    All hyperparameters default to the values that were previously hard-coded,
    so ``LitDALI()`` builds the same model, optimizer and data pipeline as before.

    Args:
        model_name: Architecture name passed to ``timm.create_model``.
        num_classes: Number of output units requested from ``create_model``.
        lr: Learning rate for the Adam optimizer.
        batch_size: Batch size passed to ``CustomPipeline``.
        num_threads: ``num_threads`` passed to ``CustomPipeline``.
        device_id: ``device_id`` passed to ``CustomPipeline``.
        epoch_size: Value passed as ``size`` to ``DALIGenericIterator``.
            NOTE(review): the default ``68811 - 7`` is presumably the number
            of usable training samples -- confirm against the dataset.
    """

    def __init__(
        self,
        model_name='gernet_s',
        num_classes=5,
        lr=1e-3,
        batch_size=60,
        num_threads=4,
        device_id=0,
        epoch_size=68811 - 7,
    ):
        super().__init__()
        self.lr = lr
        self.batch_size = batch_size
        self.num_threads = num_threads
        self.device_id = device_id
        self.epoch_size = epoch_size
        # Created in ``setup``; kept as an attribute so existing code that
        # reads ``model.train_loader`` keeps working.
        self.train_loader = None
        self.model = create_model(model_name=model_name, num_classes=num_classes)

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)

    def process_batch(self, batch):
        """Extract the ``(image, label)`` pair from a DALI iterator batch.

        Only the first element of ``batch`` is read; it is indexed with the
        output names given to ``DALIGenericIterator``.
        """
        return batch[0]['image'], batch[0]['label']

    def training_step(self, batch, batch_idx):
        """Compute the binary cross-entropy-with-logits loss for one batch."""
        x, y = self.process_batch(batch)
        logits = self(x)
        # NOTE(review): this loss needs ``y`` to be a float tensor with the
        # same shape as ``logits``; confirm that CustomPipeline emits labels
        # in that form.
        loss = F.binary_cross_entropy_with_logits(logits, y)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``self.lr``."""
        return Adam(self.parameters(), lr=self.lr)

    def _build_train_loader(self):
        """Create the DALI pipeline and wrap it in a PyTorch-style iterator."""
        pipeline = CustomPipeline(
            batch_size=self.batch_size,
            num_threads=self.num_threads,
            device_id=self.device_id,
        )
        return DALIGenericIterator(
            pipeline,
            ['image', 'label'],
            size=self.epoch_size,
            auto_reset=True,
            last_batch_policy=LastBatchPolicy.PARTIAL,
        )

    def setup(self, stage=None):
        """Build the training iterator once per process.

        This used to live in ``prepare_data``. Lightning calls that hook on a
        single process only and advises against assigning state there, so a
        multi-process run would leave ``train_loader`` unset on the other
        ranks. ``setup`` runs on every process.
        """
        if self.train_loader is None:
            self.train_loader = self._build_train_loader()

    def train_dataloader(self):
        """Return the DALI iterator created in ``setup``."""
        return self.train_loader

In [None]:
# Instantiate the Lightning module with its default settings.
model = LitDALI()

In [None]:
# Single-GPU trainer limited to one epoch.
trainer = Trainer(gpus=1, max_epochs=1)

In [None]:
%%time
# Run training; the %%time cell magic reports how long this cell took.
trainer.fit(model)

In [None]:
# Timing recorded for the `trainer.fit` cell above: 2min 24s