In [1]:
%load_ext autoreload
%autoreload 3

In [2]:
import torch
import torch.nn as nn

## ResNet18

In [3]:
from torchvision import models

# torchvision ResNet-18 built with the default constructor (no weights argument passed).
resnet18 = models.resnet18()

In [86]:
from thesis.writer_code.models import BatchNorm2dAdaptive

def replace_bn_adaptive(module: nn.Module, writer_code_size: int) -> list:
    """
    Replace all nn.BatchNorm2d layers with BatchNorm2dAdaptive layers, in place.

    Walks the module tree depth-first via ``named_children()`` and swaps every
    submodule whose type is exactly ``nn.BatchNorm2d`` (subclasses are left alone)
    for ``BatchNorm2dAdaptive(bn, writer_code_size)``. Each replacement is
    registered under the same attribute name. This includes index-named children
    of containers such as ``nn.Sequential`` (e.g. ResNet's ``downsample.1``), which
    a ``dir()``-based scan misses because ``nn.Module.__dir__`` drops names
    starting with a digit.

    Args:
        module: root module, modified in place.
        writer_code_size: forwarded unchanged to every BatchNorm2dAdaptive.

    Returns:
        list of all newly added BatchNorm2dAdaptive modules. The list is empty if
        the tree holds no plain nn.BatchNorm2d outside adaptive layers, e.g. when
        the function is called a second time on the same model.
    """
    new_mods = []
    # Don't descend into an adaptive layer: it is built from the original
    # BatchNorm2d and may hold it as a submodule, which must not be wrapped again.
    if isinstance(module, BatchNorm2dAdaptive):
        return new_mods
    # Snapshot the children: setattr below re-assigns entries of the child registry.
    for name, child in list(module.named_children()):
        if type(child) is nn.BatchNorm2d:
            new_bn = BatchNorm2dAdaptive(child, writer_code_size)
            setattr(module, name, new_bn)
            new_mods.append(new_bn)
        else:
            new_mods.extend(replace_bn_adaptive(child, writer_code_size))
    return new_mods
            
# Wrap the BatchNorm2d layers of resnet18 in BatchNorm2dAdaptive (writer_code_size=64).
# replace_bn_adaptive returns only *newly* added layers, so re-running this cell would
# yield [] and leave set_writer_code with nothing to update. Collect every adaptive
# layer from the model instead, which keeps the cell idempotent.
replace_bn_adaptive(resnet18, 64)
bn_layers = [m for m in resnet18.modules() if isinstance(m, BatchNorm2dAdaptive)]

In [87]:
from typing import Sequence

def set_writer_code(bn_layers: Sequence[BatchNorm2dAdaptive], writer_code: torch.Tensor):
    """
    Point every given BatchNorm2dAdaptive layer at the same writer code.

    Args:
        bn_layers: layers whose ``writer_code`` attribute is overwritten.
        writer_code: tensor assigned by reference (not copied) to each layer.
    """
    for adaptive_bn in bn_layers:
        adaptive_bn.writer_code = writer_code
        
# NOTE(review): the trailing 64 must equal the writer_code_size passed to
# replace_bn_adaptive; the leading 2 presumably matches the batch size of whatever
# forward pass follows (e.g. torchsummary's dummy input) — confirm. No RNG seed is
# set, so this random writer code is not reproducible.
set_writer_code(bn_layers, torch.rand(2, 64))

In [4]:
# NOTE(review): judging by execution counts, the recorded repr (In [4]) predates the
# BatchNorm2d -> BatchNorm2dAdaptive replacement (In [86]); re-run to see the current model.
resnet18

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [80]:
from torchsummary import summary

# NOTE(review): the recorded summary (In [80]) predates the BN replacement in In [86],
# so it still lists plain BatchNorm2d layers.
summary(resnet18, input_size=(3, 128, 128))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 64, 64]           9,408
       BatchNorm2d-2           [-1, 64, 64, 64]             128
              ReLU-3           [-1, 64, 64, 64]               0
         MaxPool2d-4           [-1, 64, 32, 32]               0
            Conv2d-5           [-1, 64, 32, 32]          36,864
       BatchNorm2d-6           [-1, 64, 32, 32]             128
              ReLU-7           [-1, 64, 32, 32]               0
            Conv2d-8           [-1, 64, 32, 32]          36,864
       BatchNorm2d-9           [-1, 64, 32, 32]             128
             ReLU-10           [-1, 64, 32, 32]               0
       BasicBlock-11           [-1, 64, 32, 32]               0
           Conv2d-12           [-1, 64, 32, 32]          36,864
      BatchNorm2d-13           [-1, 64, 32, 32]             128
             ReLU-14           [-1, 64,

## ResNet31

In [2]:
from htr.models.sar.sar import ShowAttendRead

In [3]:
# NOTE(review): this cell raises TypeError on a fresh kernel — ShowAttendRead.__init__
# requires a `label_encoder` argument (see the recorded traceback). Pass one here so
# that `model`/`resnet31` are actually defined.
model = ShowAttendRead()
resnet31 = model.encoder
resnet31

TypeError: __init__() missing 1 required positional argument: 'label_encoder'

In [17]:
from torchsummary import summary

# NOTE(review): the cell that should define `resnet31` raises TypeError, so on a fresh
# kernel this line fails with NameError; the recorded output presumably comes from
# kernel state this notebook does not reproduce.
summary(resnet31, input_size=(1, 128, 128))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 128, 128]             640
       BatchNorm2d-2         [-1, 64, 128, 128]             128
              ReLU-3         [-1, 64, 128, 128]               0
            Conv2d-4        [-1, 128, 128, 128]          73,856
       BatchNorm2d-5        [-1, 128, 128, 128]             256
              ReLU-6        [-1, 128, 128, 128]               0
         MaxPool2d-7          [-1, 128, 64, 64]               0
            Conv2d-8          [-1, 256, 64, 64]         294,912
       BatchNorm2d-9          [-1, 256, 64, 64]             512
             ReLU-10          [-1, 256, 64, 64]               0
           Conv2d-11          [-1, 256, 64, 64]         589,824
      BatchNorm2d-12          [-1, 256, 64, 64]             512
           Conv2d-13          [-1, 256, 64, 64]          32,768
      BatchNorm2d-14          [-1, 256,

## Transformer

In [10]:
from htr.models.fphtr.fphtr import FullPageHTRDecoder

# NOTE(review): ten positional arguments make this call hard to read; prefer keyword
# arguments. The repr below suggests a vocabulary of 10 and model dim of 8
# (Embedding(10, 8)), a feed-forward width of 2 (linear1: 8 -> 2) and dropout 0.1 —
# confirm against FullPageHTRDecoder's signature.
fphtr_decoder = FullPageHTRDecoder(10, 10, 0, 1, 2, 8, 2, 2, 2, 0.1)

In [11]:
# Display the decoder's module tree (nn.Module repr).
fphtr_decoder

FullPageHTRDecoder(
  (emb): Embedding(10, 8)
  (pos_emb): PositionalEmbedding1D()
  (decoder): TransformerDecoder(
    (layers): ModuleList(
      (0): TransformerDecoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=8, out_features=8, bias=True)
        )
        (multihead_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=8, out_features=8, bias=True)
        )
        (linear1): Linear(in_features=8, out_features=2, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2, out_features=8, bias=True)
        (norm1): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
        (norm3): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
        (dropout3): Dropout(p=0

In [8]:
# NOTE(review): the recorded output (4,) has execution count 8, older than the In [10]
# construction whose repr shows norm1 = LayerNorm((8,)); it is stale — re-run to refresh.
fphtr_decoder.decoder.layers[0].norm1.normalized_shape

(4,)