In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.data import DataLoader
import lightning as L
from lightning.pytorch.loggers import WandbLogger, TensorBoardLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelSummary
from lightning.pytorch.profilers import SimpleProfiler
import wandb

from stn import STN_ON
from svtrnet import SVTRNet
from rnn import SequenceEncoder
from ctc_head import CTCHead
from dataloader.dataset import TNGODataset
from rec_postprocess import CTCLabelDecode

In [32]:
class LitSVTR(L.LightningModule):
    """Lightning module chaining STN_ON -> SVTRNet -> SequenceEncoder -> CTCHead.

    The training, validation and test steps all run the same forward pass and
    CTC loss; they differ only in the metric name that is logged
    (``train_loss`` / ``val_loss`` / ``test_loss``).

    Args:
        num_classes: Number of output classes of the CTC head (passed to
            ``CTCHead`` as ``out_channels``).
        lr: Learning rate for AdamW.
        weight_decay: Weight decay for AdamW.

    All arguments have defaults equal to the previously hard-coded values, so
    ``LitSVTR()`` builds the same model as before. They are recorded through
    ``save_hyperparameters()`` and available as ``self.hparams``.
    """

    def __init__(self, num_classes=228, lr=2.5e-4, weight_decay=0.05):
        super().__init__()
        # Capture the constructor arguments (stored in checkpoints / loggers).
        self.save_hyperparameters()

        # Channel width shared by the neck input and the head input.
        # NOTE(review): presumably this must equal the channel width produced
        # by SVTRNet() with its default arguments - confirm before changing.
        feature_channels = 384

        self.transform = STN_ON()
        self.backbone = SVTRNet()
        self.neck = SequenceEncoder(in_channels=feature_channels, encoder_type="reshape")
        self.head = CTCHead(in_channels=feature_channels, out_channels=num_classes)
        # zero_infinity=True replaces infinite losses (and their gradients) with zero.
        self.criterion = torch.nn.CTCLoss(zero_infinity=True)

    def forward(self, x):
        """Run the full pipeline and return the head's output unchanged."""
        x = self.transform(x)
        x = self.backbone(x)
        x = self.neck(x)
        x = self.head(x)
        return x

    def _shared_step(self, batch, stage):
        """Compute and log the CTC loss for one batch.

        Args:
            batch: Mapping with the keys ``'image'``, ``'label'`` and ``'length'``.
            stage: Prefix of the logged metric (``"train"``, ``"val"`` or ``"test"``).

        Returns:
            The CTC loss for the batch.
        """
        image, label, label_length = batch['image'], batch['label'], batch['length']
        output = self(image)
        # The head returns an indexable result; element 0 is fed to the loss.
        # CTCLoss expects (time, batch, classes), so the first two axes are
        # swapped - assumes output[0] is (batch, time, classes); TODO confirm.
        log_probs = output[0].permute(1, 0, 2)
        num_steps, batch_size, _ = log_probs.shape
        # Every sample uses the full sequence length as its input length.
        output_length = torch.full((batch_size,), num_steps, dtype=torch.long)
        loss = self.criterion(log_probs, label, output_length, label_length)
        self.log(f"{stage}_loss", loss)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the CTC loss for a training batch and log it as ``train_loss``."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Return the CTC loss for a validation batch and log it as ``val_loss``."""
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        """Return the CTC loss for a test batch and log it as ``test_loss``."""
        return self._shared_step(batch, "test")

    def configure_optimizers(self):
        """Build the AdamW optimizer from the stored hyperparameters."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
        return optimizer

In [33]:
# Build the Lightning module with its default constructor arguments.
model = LitSVTR()

In [34]:
# Random stand-in input for a forward-pass smoke test.
# Presumably laid out as (batch, channels, height, width) - confirm against the model.
data = torch.randn((3, 3, 64, 200))

In [36]:
# Smoke test: push the random batch through the module's forward pass and
# print whatever the head returns.
# NOTE(review): no model.eval() / torch.no_grad() is visible before this call,
# so the module is presumably still in training mode with autograd enabled -
# confirm that this is intended for an inspection-only run.
x = model(data)
print(x)

(tensor([[[-5.3292, -5.3887, -5.5442,  ..., -5.3901, -5.7036, -5.3127],
         [-5.3604, -5.5886, -5.5700,  ..., -5.4864, -5.4082, -5.2874],
         [-5.3934, -5.5143, -5.3608,  ..., -5.6791, -5.5466, -5.3420],
         ...,
         [-5.3495, -5.6305, -5.5196,  ..., -5.6110, -5.3837, -5.4450],
         [-5.2740, -5.5005, -5.5561,  ..., -5.5618, -5.5543, -5.3591],
         [-5.5143, -5.3324, -5.3566,  ..., -5.4195, -5.5106, -5.3526]],

        [[-5.3880, -5.5124, -5.4925,  ..., -5.5264, -5.6279, -5.4979],
         [-5.4535, -5.5706, -5.5068,  ..., -5.6015, -5.4023, -5.3388],
         [-5.4316, -5.4360, -5.2944,  ..., -5.5977, -5.4082, -5.3362],
         ...,
         [-5.3925, -5.3003, -5.4736,  ..., -5.3267, -5.5204, -5.5819],
         [-5.2556, -5.3577, -5.5833,  ..., -5.3967, -5.5319, -5.4793],
         [-5.2211, -5.4525, -5.4842,  ..., -5.4151, -5.4296, -5.4203]],

        [[-5.4516, -5.5308, -5.3927,  ..., -5.4739, -5.4734, -5.3238],
         [-5.4681, -5.4653, -5.3301,  ..., -