# Visualizing Swin Transformer

**by Pio Lauren T. Mendoza**

In [None]:
!wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt

In [5]:
from PIL import Image
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchinfo import summary

import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import timm



# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"device: {device}")

# Load the blackcellmagic IPython extension; re-running this cell in a live
# kernel only prints an "already loaded" notice.
%load_ext blackcellmagic

device: cpu
The blackcellmagic extension is already loaded. To reload it, use:
  %reload_ext blackcellmagic


In [2]:
# Read the class-name file into a list of raw lines (newline terminators kept).
with open("imagenet_classes.txt") as f:
    content = list(f)

# Trim quote characters and newlines from both ends of every entry.
labels = list(map(lambda entry: entry.strip("""'"\n"""), content))

In [3]:
# Preprocessing pipeline: convert the PIL image to a tensor, then normalise each
# channel with the ImageNet statistics. The timm constants imported at the top of
# the notebook replace the previously hard-coded mean/std literals (same values),
# so the numbers live in exactly one place.
transform = T.Compose(
    [
        T.ToTensor(),
        T.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
    ]
)

NameError: name 'T' is not defined

In [4]:
%%capture
model = timm.create_model("swin_base_patch4_window7_224", pretrained=True)
model.eval()
model.to(device)

In [None]:
# Ask for an image URL, download the image, and classify it with the Swin model.
url = input().strip()  # tolerate stray whitespace around a pasted URL

# Fail fast on HTTP errors and never hang forever on an unresponsive host;
# without these checks an error page surfaces as a confusing PIL decode error.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()

# Resize to the 224x224 input used throughout this notebook and force three channels.
img = Image.open(response.raw).resize((224, 224)).convert("RGB")
img_tens = transform(img).unsqueeze(0).to(device)  # add a leading batch dimension

with torch.no_grad():  # inference only, no gradients needed
    output = model(img_tens)

display(img)
# Index the label list with a plain Python int rather than a one-element tensor.
labels[output.argmax(dim=-1).item()]

In [5]:
# Per-layer summary of the model for an input of shape (8, 3, 224, 224).
summary(model, input_size=(8, 3, 224, 224))

Layer (type:depth-idx)                             Output Shape              Param #
SwinTransformer                                    --                        --
├─Sequential: 1                                    --                        --
│    └─BasicLayer: 2                               --                        --
│    │    └─ModuleList: 3-1                        --                        397,896
│    └─BasicLayer: 2                               --                        --
│    │    └─ModuleList: 3-2                        --                        1,582,224
│    └─BasicLayer: 2                               --                        --
│    │    └─ModuleList: 3-3                        --                        56,791,584
│    └─BasicLayer: 2                               --                        --
│    │    └─ModuleList: 3-4                        --                        25,203,264
├─PatchEmbed: 1-1                                  [8, 3136, 128]            --
│    └─

In [4]:
# Show the model's direct child modules as a list.
[*model.children()]

[PatchEmbed(
   (proj): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
   (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
 ),
 Dropout(p=0.0, inplace=False),
 Sequential(
   (0): BasicLayer(
     dim=128, input_resolution=(56, 56), depth=2
     (blocks): ModuleList(
       (0): SwinTransformerBlock(
         (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
         (attn): WindowAttention(
           (qkv): Linear(in_features=128, out_features=384, bias=True)
           (attn_drop): Dropout(p=0.0, inplace=False)
           (proj): Linear(in_features=128, out_features=128, bias=True)
           (proj_drop): Dropout(p=0.0, inplace=False)
           (softmax): Softmax(dim=-1)
         )
         (drop_path): Identity()
         (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
         (mlp): Mlp(
           (fc1): Linear(in_features=128, out_features=512, bias=True)
           (act): GELU()
           (fc2): Linear(in_features=512, out_features

In [71]:
# Class of the first transformer block in the last stage (layers[3]); the same
# module the positional children()[2][3] -> children()[0] -> children()[0]
# lookup reaches, addressed by attribute name instead.
type(model.layers[3].blocks[0])

timm.models.swin_transformer.SwinTransformerBlock

In [8]:
# Print the dotted name of every module in the Swin model, one per line
# (the root module's name is the empty string, hence the leading blank line).
print("\n".join(name for name, _ in model.named_modules()))


patch_embed
patch_embed.proj
patch_embed.norm
pos_drop
layers
layers.0
layers.0.blocks
layers.0.blocks.0
layers.0.blocks.0.norm1
layers.0.blocks.0.attn
layers.0.blocks.0.attn.qkv
layers.0.blocks.0.attn.attn_drop
layers.0.blocks.0.attn.proj
layers.0.blocks.0.attn.proj_drop
layers.0.blocks.0.attn.softmax
layers.0.blocks.0.drop_path
layers.0.blocks.0.norm2
layers.0.blocks.0.mlp
layers.0.blocks.0.mlp.fc1
layers.0.blocks.0.mlp.act
layers.0.blocks.0.mlp.fc2
layers.0.blocks.0.mlp.drop
layers.0.blocks.1
layers.0.blocks.1.norm1
layers.0.blocks.1.attn
layers.0.blocks.1.attn.qkv
layers.0.blocks.1.attn.attn_drop
layers.0.blocks.1.attn.proj
layers.0.blocks.1.attn.proj_drop
layers.0.blocks.1.attn.softmax
layers.0.blocks.1.drop_path
layers.0.blocks.1.norm2
layers.0.blocks.1.mlp
layers.0.blocks.1.mlp.fc1
layers.0.blocks.1.mlp.act
layers.0.blocks.1.mlp.fc2
layers.0.blocks.1.mlp.drop
layers.0.downsample
layers.0.downsample.reduction
layers.0.downsample.norm
layers.1
layers.1.blocks
layers.1.blocks.0
la

In [6]:
# Load the pretrained DeiT model through torch.hub and put it in eval mode
# (eval() returns the module itself, so the chained call keeps the same object).
deit = torch.hub.load(
    'facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True
).eval()

Downloading: "https://github.com/facebookresearch/deit/archive/main.zip" to /home/pio/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth" to /home/pio/.cache/torch/hub/checkpoints/deit_base_patch16_224-b5f2ef4d.pth
100.0%


In [7]:
# Print the dotted name of every module in the DeiT model, one per line
# (the root module's name is the empty string, hence the leading blank line).
print("\n".join(name for name, _ in deit.named_modules()))


patch_embed
patch_embed.proj
patch_embed.norm
pos_drop
blocks
blocks.0
blocks.0.norm1
blocks.0.attn
blocks.0.attn.qkv
blocks.0.attn.attn_drop
blocks.0.attn.proj
blocks.0.attn.proj_drop
blocks.0.drop_path
blocks.0.norm2
blocks.0.mlp
blocks.0.mlp.fc1
blocks.0.mlp.act
blocks.0.mlp.fc2
blocks.0.mlp.drop
blocks.1
blocks.1.norm1
blocks.1.attn
blocks.1.attn.qkv
blocks.1.attn.attn_drop
blocks.1.attn.proj
blocks.1.attn.proj_drop
blocks.1.drop_path
blocks.1.norm2
blocks.1.mlp
blocks.1.mlp.fc1
blocks.1.mlp.act
blocks.1.mlp.fc2
blocks.1.mlp.drop
blocks.2
blocks.2.norm1
blocks.2.attn
blocks.2.attn.qkv
blocks.2.attn.attn_drop
blocks.2.attn.proj
blocks.2.attn.proj_drop
blocks.2.drop_path
blocks.2.norm2
blocks.2.mlp
blocks.2.mlp.fc1
blocks.2.mlp.act
blocks.2.mlp.fc2
blocks.2.mlp.drop
blocks.3
blocks.3.norm1
blocks.3.attn
blocks.3.attn.qkv
blocks.3.attn.attn_drop
blocks.3.attn.proj
blocks.3.attn.proj_drop
blocks.3.drop_path
blocks.3.norm2
blocks.3.mlp
blocks.3.mlp.fc1
blocks.3.mlp.act
blocks.3.mlp.fc2