In [1]:
import torch

from torchvision.models import resnet50
from torchvision.models import ResNet50_Weights
from models.model import get_model

from torch import nn
import torchinfo
from torchprofile import profile_macs

import timm

# Notebook-wide compute device: CUDA when PyTorch reports it as available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

In [3]:

# Instantiate the timm ConvNeXt V2 Large checkpoint with pretrained weights
# (the recorded output shows a ~792 MB safetensors download).
# NOTE(review): `model` is rebound by the get_model(...) cell that follows, so on a
# top-to-bottom run this instance is discarded before it is used.
model = timm.create_model(
    'convnextv2_large.fcmae_ft_in22k_in1k',
    pretrained=True,
)

model.safetensors:   0%|          | 0.00/792M [00:00<?, ?B/s]

In [8]:
# Build the project's "ConvNeXtLarge" model via models.model.get_model and place it on
# DEVICE, mirroring the ConvTransNeXtTiny cell further down. Without the explicit move
# the profiling cell feeds a DEVICE tensor to a model whose placement is unspecified.
# (num_classes=200 matches the [1, 200] output in the recorded summary; what
# freeze=True freezes is defined in get_model — presumably the backbone, TODO confirm.)
model = get_model("ConvNeXtLarge", pretrained=True, num_classes=200, freeze=True).to(DEVICE)

In [13]:
# Display the module tree (repr) of the current `model`.
# NOTE(review): this cell's execution count (13) is higher than the cells after it, and the
# saved repr starts with a 96-channel stem while the recorded summary below reports 192
# channels — the stored output may come from a different model; re-run after a restart.
model

ConvNeXt(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm2d((96,), eps=1e-06, elementwise_affine=True)
    )
    (1): Sequential(
      (0): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features=96, out_features=384, bias=True)
          (4): GELU(approximate='none')
          (5): Linear(in_features=384, out_features=96, bias=True)
          (6): Permute()
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
      )
      (1): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features=

In [9]:
# Per-layer table of output shapes and parameter counts for a (1, 3, 224, 224) input.
torchinfo.summary(
    model,
    input_size=(1, 3, 224, 224),
)

Layer (type:depth-idx)                        Output Shape              Param #
ConvNeXt                                      [1, 200]                  --
├─Sequential: 1-1                             [1, 1536, 7, 7]           --
│    └─Conv2dNormActivation: 2-1              [1, 192, 56, 56]          --
│    │    └─Conv2d: 3-1                       [1, 192, 56, 56]          (9,408)
│    │    └─LayerNorm2d: 3-2                  [1, 192, 56, 56]          (384)
│    └─Sequential: 2-2                        [1, 192, 56, 56]          --
│    │    └─CNBlock: 3-3                      [1, 192, 56, 56]          (306,048)
│    │    └─CNBlock: 3-4                      [1, 192, 56, 56]          (306,048)
│    │    └─CNBlock: 3-5                      [1, 192, 56, 56]          (306,048)
│    └─Sequential: 2-3                        [1, 384, 28, 28]          --
│    │    └─LayerNorm2d: 3-6                  [1, 192, 56, 56]          (384)
│    │    └─Conv2d: 3-7                       [1, 384, 28, 28] 

In [6]:

# Make sure the model lives on the same device as the input before tracing;
# a CPU model with a CUDA input (or vice versa) fails inside profile_macs.
model = model.to(DEVICE)

sample_input = torch.randn(1, 3, 224, 224).to(DEVICE)

# profile_macs returns multiply-accumulate operations (MACs), not FLOPs. The variable
# keeps its historical name `flops` so any cell that reads it still works.
# Tracing only needs a forward pass, so autograd bookkeeping is disabled.
with torch.no_grad():
    flops = profile_macs(model, sample_input)

# Value is in GMACs for one (1, 3, 224, 224) input; one MAC is roughly two FLOPs.
print(flops / 1e9)

34.398090276




In [2]:
# Build the project's 'ConvTransNeXtTiny' model on DEVICE and smoke-test its forward pass.
model = get_model('ConvTransNeXtTiny', pretrained=True, num_classes=200, freeze=True).to(DEVICE)

sample_input = torch.randn(32, 3, 224, 224).to(DEVICE)

# Only the output shape matters here, so skip autograd bookkeeping for the batch of 32.
with torch.no_grad():
    logits = model(sample_input)

# Fail loudly if the head does not produce one row per sample and one column per class.
assert logits.shape == (32, 200), f"unexpected output shape: {tuple(logits.shape)}"

logits.shape



torch.Size([32, 200])

In [3]:
# Per-layer table of output shapes and parameter counts for a (1, 3, 224, 224) input.
torchinfo.summary(
    model,
    input_size=(1, 3, 224, 224),
)

Layer (type:depth-idx)                             Output Shape              Param #
ConvTransNeXtTiny                                  [1, 200]                  --
├─ConvNeXt: 1-3                                    --                        (recursive)
│    └─Sequential: 2-1                             [1, 768, 7, 7]            --
│    │    └─Conv2dNormActivation: 3-1              [1, 96, 56, 56]           (4,896)
│    │    └─Sequential: 3-2                        [1, 96, 56, 56]           (237,888)
│    │    └─Sequential: 3-3                        [1, 192, 28, 28]          (74,112)
│    │    └─Sequential: 3-4                        [1, 192, 28, 28]          (918,144)
│    │    └─Sequential: 3-5                        [1, 384, 14, 14]          (295,680)
│    │    └─Sequential: 3-6                        [1, 384, 14, 14]          (10,817,280)
│    │    └─Sequential: 3-7                        [1, 768, 7, 7]            1,181,184
├─TransformerEncoder: 1-2                          [49, 1

In [4]:
# DEVICE is defined once in the configuration cell at the top of the notebook.
# Re-deriving it here could silently diverge from the device the model was moved to.
sample_input = torch.randn(1, 3, 224, 224).to(DEVICE)

# profile_macs returns multiply-accumulate operations (MACs), not FLOPs. The variable
# keeps its historical name `flops` so any cell that reads it still works.
# Tracing only needs a forward pass, so autograd bookkeeping is disabled.
with torch.no_grad():
    flops = profile_macs(model, sample_input)

# Value is in GMACs for one (1, 3, 224, 224) input; one MAC is roughly two FLOPs.
print(flops / 1e9)

4.340629261


