In [1]:
import torch

from torchvision.models import resnet50
from torchvision.models import ResNet50_Weights
from models.model import get_model

from torch import nn
import torchinfo
from torchprofile import profile_macs

import timm

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:

# Instantiate the pretrained ConvNeXt V2 Large checkpoint through timm; with
# pretrained=True this pulls the weight file (model.safetensors) when needed.
# NOTE(review): `model` is rebound by the get_model(...) call in the next cell
# before any visible cell uses this instance — confirm this cell (and its
# large download) is still required.
model = timm.create_model('convnextv2_large.fcmae_ft_in22k_in1k', pretrained=True)

model.safetensors:   0%|          | 0.00/792M [00:00<?, ?B/s]

In [2]:
# Build the project's "ConvNeXtBase" model with a 200-way output and place it
# on DEVICE immediately. Later cells create their inputs with `.to(DEVICE)`,
# so the weights must live on the same device; relying on another library to
# move the model as a side effect makes the notebook version-dependent.
# NOTE(review): freeze=True presumably freezes the pretrained layers — confirm
# in models.model.get_model.
model = get_model("ConvNeXtBase", pretrained=True, num_classes=200, freeze=True).to(DEVICE)

In [3]:
# Show the model's repr (its nested module tree) as the cell output.
model

ConvNeXt(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm2d((128,), eps=1e-06, elementwise_affine=True)
    )
    (1): Sequential(
      (0): CNBlock(
        (block): Sequential(
          (0): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=128)
          (1): Permute()
          (2): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features=128, out_features=512, bias=True)
          (4): GELU(approximate='none')
          (5): Linear(in_features=512, out_features=128, bias=True)
          (6): Permute()
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
      )
      (1): CNBlock(
        (block): Sequential(
          (0): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=128)
          (1): Permute()
          (2): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
          (3): Linear(

In [4]:
# Layer-by-layer table of output shapes and parameter counts, computed for a
# single input of shape (1, 3, 224, 224).
torchinfo.summary(
    model,
    input_size=(1, 3, 224, 224),
)

Layer (type:depth-idx)                        Output Shape              Param #
ConvNeXt                                      [1, 200]                  --
├─Sequential: 1-1                             [1, 1024, 7, 7]           --
│    └─Conv2dNormActivation: 2-1              [1, 128, 56, 56]          --
│    │    └─Conv2d: 3-1                       [1, 128, 56, 56]          (6,272)
│    │    └─LayerNorm2d: 3-2                  [1, 128, 56, 56]          (256)
│    └─Sequential: 2-2                        [1, 128, 56, 56]          --
│    │    └─CNBlock: 3-3                      [1, 128, 56, 56]          (138,496)
│    │    └─CNBlock: 3-4                      [1, 128, 56, 56]          (138,496)
│    │    └─CNBlock: 3-5                      [1, 128, 56, 56]          (138,496)
│    └─Sequential: 2-3                        [1, 256, 28, 28]          --
│    │    └─LayerNorm2d: 3-6                  [1, 128, 56, 56]          (256)
│    │    └─Conv2d: 3-7                       [1, 256, 28, 28] 

In [5]:

# Create the probe input on whichever device the model's weights actually live
# on, so the profiling forward pass cannot hit a CPU/GPU mismatch.
model_device = next(model.parameters()).device
sample_input = torch.randn(1, 3, 224, 224).to(model_device)

# profile_macs counts multiply-accumulate operations (MACs) for one forward
# pass. The variable keeps its historical name `flops`, but the value is MACs
# (one MAC is commonly counted as two FLOPs).
flops = profile_macs(model, sample_input)

# Reported in units of 1e9, i.e. giga-MACs.
print(flops / 1e9)

15.368613888




In [6]:
# Replace `model` with the project's "ConvTransNeXtBase" variant (200-way
# output) and move it to DEVICE.
model = get_model('ConvTransNeXtBase', pretrained=True, num_classes=200, freeze=True).to(DEVICE)

# Batch of 32 random inputs, used only to smoke-test the forward pass.
sample_input = torch.randn(32, 3, 224, 224).to(DEVICE)

# Only the output shape is inspected, so run without autograd: this avoids
# keeping activations for a backward pass that never happens.
with torch.no_grad():
    logits = model(sample_input)

# Last expression of the cell: displayed as the cell output.
logits.shape



torch.Size([32, 200])

In [7]:
# Layer-by-layer table of output shapes and parameter counts for the current
# `model`, computed for a single input of shape (1, 3, 224, 224).
torchinfo.summary(
    model,
    input_size=(1, 3, 224, 224),
)

Layer (type:depth-idx)                             Output Shape              Param #
ConvTransNeXtBase                                  [1, 200]                  --
├─ConvNeXt: 1-4                                    --                        (recursive)
│    └─Sequential: 2-1                             [1, 512, 14, 14]          --
│    │    └─Conv2dNormActivation: 3-1              [1, 128, 56, 56]          (6,528)
│    │    └─Sequential: 3-2                        [1, 128, 56, 56]          (415,488)
│    │    └─Sequential: 3-3                        [1, 256, 28, 28]          (131,584)
│    │    └─Sequential: 3-4                        [1, 256, 28, 28]          (1,617,408)
│    │    └─Sequential: 3-5                        [1, 512, 14, 14]          (525,312)
│    │    └─Sequential: 3-6                        [1, 512, 14, 14]          (57,424,896)
├─PositionalEncoding: 1-2                          [196, 1, 512]             --
│    └─Dropout: 2-2                                [196, 1, 5

In [8]:
# DEVICE is already defined once in the notebook's first cell; it is not
# redefined here so there is a single source of truth for the device choice.
# The probe input is placed on the device that holds the model's weights.
model_device = next(model.parameters()).device
sample_input = torch.randn(1, 3, 224, 224).to(model_device)

# profile_macs counts multiply-accumulate operations (MACs) for one forward
# pass. The variable keeps its historical name `flops`, but the value is MACs
# (one MAC is commonly counted as two FLOPs).
flops = profile_macs(model, sample_input)

# Reported in units of 1e9, i.e. giga-MACs.
print(flops / 1e9)

15.258423821


