# 模型介绍

**Transformer in Transformer**

github pytorch代码: [https://github.com/huawei-noah/noah-research/tree/master/TNT](https://github.com/huawei-noah/noah-research/tree/master/TNT)

论文地址: [https://arxiv.org/pdf/2103.00112.pdf](https://arxiv.org/pdf/2103.00112.pdf)

1. TNT同时对patch级和pixel级的表征进行建模。
2. 在每个TNT Block中，outer transformer block用于处理patch embedding，而inner transformer block则从pixel embedding中提取局部特征。
3. 通过线性变换层将pixel级特征投影到patch embedding的空间，然后将其加到对应的patch embedding上。
4. 嵌套transformer的思想是：先将图像划分为patch，再在每个patch内部对pixel序列做transformer。
5. 对于patch级序列，每个patch有一个独立的可学习位置编码；对于pixel级序列，位置编码表示pixel在patch内的相对位置，因此所有patch中相同位置的pixel共享同一个位置编码。


![](https://ai-studio-static-online.cdn.bcebos.com/7efb4a8855594784a93e95e6e15326a5afc8a518c5454100b17048146aef6eee)

![](https://ai-studio-static-online.cdn.bcebos.com/96958dec231445429d7cfcc7a6bc7932eeb010d2b2eb4b248400055114b562b3)

![](https://ai-studio-static-online.cdn.bcebos.com/0b2b469436214903895bb3cca68363ba647d317a141d4a2c87e2bbd6bc22d074)




# 关于数据集ImageNet

ImageNet图像数据集始于2009年，当时李飞飞教授等在CVPR2009上发表了一篇名为《ImageNet: A Large-Scale Hierarchical Image Database》的论文，之后就是基于ImageNet数据集的7届ImageNet挑战赛(2010年开始)，2017年后，ImageNet由Kaggle(Kaggle公司是由联合创始人兼首席执行官Anthony Goldbloom 2010年在墨尔本创立的，主要是为开发商和数据科学家提供举办机器学习竞赛、托管数据库、编写和分享代码的平台)继续维护。

本项目的模型训练在线下完成，因此在AI Studio上只使用了验证集进行验证。

![](https://ai-studio-static-online.cdn.bcebos.com/1e8613aebb754b96bc799dd3c0c51278da5ab0599e264467912c9e2782821a24)


In [None]:
# Extract the ILSVRC2012 validation archive into ~/data/ILSVRC2012
# (only the validation split is used in this project).
# NOTE: `mkdir` fails if the directory already exists; `mkdir -p` would make the cell re-runnable.
!mkdir ~/data/ILSVRC2012
!tar -xf ~/data/data68594/ILSVRC2012_img_val.tar -C ~/data/ILSVRC2012

In [3]:
#加载数据集
import os
import shutil
import numpy as np
import paddle
from paddle.io import Dataset
from paddle.vision.datasets import DatasetFolder, ImageFolder
# from paddle.vision.transforms import Compose, Resize, Transpose, Normalize
import paddle.vision.transforms as T
# Every split points at the same extracted directory: only the ILSVRC2012
# validation archive is extracted in this notebook.
_DATA_ROOT = '/home/aistudio/data/ILSVRC2012'
train_parameters = dict(
    train_image_dir=_DATA_ROOT,
    eval_image_dir=_DATA_ROOT,
    test_image_dir=_DATA_ROOT,
)

class CatDataset(Dataset):
    """Folder-based image dataset for the ``train`` / ``eval`` / ``test`` splits.

    Images are resized (bicubic, 256), center-cropped to 224, converted to
    tensors and normalized with the ImageNet mean/std.

    Args:
        mode: ``'train'`` or ``'eval'`` (labelled samples read with
            ``DatasetFolder``) or ``'test'`` (unlabelled samples read with
            ``ImageFolder``). The directory for each mode is taken from
            ``train_parameters``.

    Raises:
        ValueError: if ``mode`` is not one of the three values above.
    """

    # Key in ``train_parameters`` holding the image directory of each mode.
    _DIR_KEYS = {
        'train': 'train_image_dir',
        'eval': 'eval_image_dir',
        'test': 'test_image_dir',
    }

    def __init__(self, mode='train'):
        super(CatDataset, self).__init__()
        if mode not in self._DIR_KEYS:
            raise ValueError(
                f"mode must be one of {sorted(self._DIR_KEYS)}, got {mode!r}")

        data_transforms = T.Compose([
            T.Resize(256, interpolation='bicubic'),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # Only the folder for the requested split is scanned; building all
        # three would walk the (possibly huge) directory tree three times.
        image_dir = train_parameters[self._DIR_KEYS[mode]]
        folder_cls = ImageFolder if mode == 'test' else DatasetFolder
        self.data = folder_cls(image_dir, transform=data_transforms)
        self.mode = mode
        print(mode, len(self.data))

    def __getitem__(self, index):
        """Return ``data`` in test mode, otherwise ``(data, label)``.

        ``data`` is the transformed image cast to float32. ``label`` is an
        int64 array of shape (1,).
        """
        # Fetch the sample once: every ``self.data[index]`` call decodes and
        # transforms the image again.
        sample = self.data[index]
        data = sample[0].astype('float32')
        if self.mode == 'test':
            return data
        label = np.array([sample[1]]).astype('int64')
        return data, label

    def __len__(self):
        return len(self.data)

# 模型结构搭建

In [1]:
import paddle
import paddle.nn as nn
from paddle.nn.initializer import TruncatedNormal, Constant
import math
import copy
import paddle.nn.functional as F
import numpy as np
import pickle

# Module-level parameter initializers:
# constant 0, constant 1 and truncated normal with std 0.02.
zeros_ = Constant(value=0.)
ones_ = Constant(value=1.)
trunc_normal_ = TruncatedNormal(std=.02)
from paddle.io import Dataset

class Identity(nn.Layer):
    """Pass-through layer: ``forward`` returns its input unchanged.

    Used as a stand-in when an optional sub-layer (DropPath, classifier head)
    is disabled.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        output = input
        return output

class PixelEmbed(nn.Layer):
    """Image-to-pixel embedding for the inner (pixel-level) transformer.

    A strided 7x7 convolution maps the image to ``in_dim`` channels. The
    feature map is then cut into non-overlapping ``new_patch_size`` x
    ``new_patch_size`` windows, one per patch, and a pixel position
    embedding is added.

    Args:
        img_size: expected (square) input resolution.
        patch_size: side of an image patch, in input pixels.
        in_chans: number of input channels.
        in_dim: pixel embedding size (convolution output channels).
        stride: stride of the stem convolution; each patch becomes
            ``ceil(patch_size / stride)`` pixels per side.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4):
        super().__init__()
        self.img_size = img_size
        self.num_patches = (img_size // patch_size) ** 2
        self.in_dim = in_dim
        self.new_patch_size = math.ceil(patch_size / stride)
        self.proj = nn.Conv2D(in_chans, self.in_dim, kernel_size=7, padding=3, stride=stride)

    def forward(self, x, pixel_pos):
        """Embed ``x`` into pixel tokens.

        Args:
            x: image batch whose shape is unpacked as ``B, C, H, W``; H and W
                must both equal ``img_size``.
            pixel_pos: position embedding added to every patch's pixel grid;
                must broadcast against
                ``(B * num_patches, in_dim, new_patch_size, new_patch_size)``.

        Returns:
            Tensor of shape ``(B * num_patches, new_patch_size ** 2, in_dim)``.
        """
        batch, _, height, width = x.shape
        assert height == self.img_size and width == self.img_size, \
            f"Input image size ({height}*{width}) doesn't match model ({self.img_size}*{self.img_size})."
        k = self.new_patch_size
        feats = self.proj(x)

        # Non-overlapping k x k windows, one column per window.
        cols = F.unfold(feats, k, k)
        # -> (B * num_patches, in_dim, k, k)
        pixels = cols.transpose((0, 2, 1)).reshape((batch * self.num_patches, self.in_dim, k, k))
        pixels = pixels + pixel_pos
        # -> (B * num_patches, k * k, in_dim)
        return pixels.reshape((batch * self.num_patches, self.in_dim, -1)).transpose((0, 2, 1))

class Attention(nn.Layer):
    """Multi-head self-attention with separate query/key and value widths.

    Queries and keys are projected to ``hidden_dim`` (split over
    ``num_heads``). Values keep width ``dim``, so the output has the same
    last dimension as the input.

    Args:
        dim: input/output feature size.
        hidden_dim: total query/key size across all heads.
        num_heads: number of attention heads.
        qkv_bias: add a bias to the q/k and v projections.
        attn_drop: dropout rate on the attention weights.
        proj_drop: dropout rate after the output projection.
    """

    def __init__(self, dim, hidden_dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        # 1 / sqrt(head_dim) scaling of the dot products
        self.scale = self.head_dim ** -0.5

        self.qk = nn.Linear(dim, hidden_dim * 2, bias_attr=qkv_bias)
        self.v = nn.Linear(dim, dim, bias_attr=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Attend over the tokens of ``x`` (unpacked as ``B, N, C``); output is (B, N, dim)."""
        batch, tokens, _ = x.shape
        heads = self.num_heads

        # (B, N, 2 * hidden_dim) -> (2, B, heads, N, head_dim)
        qk = self.qk(x).reshape((batch, tokens, 2, heads, self.head_dim)).transpose((2, 0, 3, 1, 4))
        query, key = qk[0], qk[1]
        # (B, N, dim) -> (B, heads, N, dim // heads)
        value = self.v(x).reshape((batch, tokens, heads, -1)).transpose((0, 2, 1, 3))

        scores = (query @ key.transpose((0, 1, 3, 2))) * self.scale
        weights = self.attn_drop(F.softmax(scores, axis=-1))

        out = (weights @ value).transpose((0, 2, 1, 3)).reshape((batch, tokens, -1))
        return self.proj_drop(self.proj(out))

class Mlp(nn.Layer):
    """Feed-forward network: Linear -> activation -> dropout -> Linear -> dropout.

    Args:
        in_features: input size.
        hidden_features: hidden size (falls back to ``in_features`` when falsy).
        out_features: output size (falls back to ``in_features`` when falsy).
        act_layer: activation layer class, instantiated without arguments.
        drop: dropout rate applied after the activation and after ``fc2``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden = hidden_features if hidden_features else in_features
        out = out_features if out_features else in_features
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

def drop_path(x, drop_prob=0., training=False):
    """Stochastic depth: randomly zero whole samples of a residual branch.

    In training mode each sample along axis 0 is kept with probability
    ``1 - drop_prob``. Kept samples are scaled by ``1 / (1 - drop_prob)``, so
    the expected value is unchanged.

    Args:
        x: input tensor; axis 0 is treated as the sample axis.
        drop_prob: probability of dropping a sample. ``0`` or ``None``
            disables dropping.
        training: whether the caller is in training mode.

    Returns:
        ``x`` unchanged when not training or when ``drop_prob`` is 0/None.
        Otherwise a tensor of the same shape in which dropped samples are zero.
    """
    if not drop_prob or not training:
        return x
    keep_prob = 1. - drop_prob
    if keep_prob <= 0.:
        # Every sample is dropped. Scaling by 1 / keep_prob would divide by
        # zero and turn the zeroed output into NaN (0 * inf).
        return paddle.zeros_like(x)
    # One Bernoulli(keep_prob) draw per sample, broadcast over the other axes.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = paddle.floor(keep_prob + paddle.rand(shape, dtype=x.dtype))  # binarize
    return (x / keep_prob) * random_tensor


class DropPath(nn.Layer):
    """Layer wrapper around ``drop_path`` (per-sample stochastic depth).

    Args:
        drop_prob: probability of dropping a sample's residual branch.
            ``None`` (the default) is treated as 0, i.e. no dropping.
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Map the ``None`` default to 0 so a default-constructed layer is a
        # no-op instead of failing on ``1 - None`` in training mode.
        drop_prob = 0. if self.drop_prob is None else self.drop_prob
        return drop_path(x, drop_prob, self.training)


class Block(nn.Layer):
    """TNT block: an inner (pixel-level) transformer followed by an outer
    (patch-level) transformer.

    The inner transformer refines the pixel tokens. The result is flattened
    per patch, projected to ``dim`` and added to every patch token except the
    first one (the class token). The outer transformer then runs on the patch
    tokens.

    Args:
        dim: patch (outer) embedding size.
        in_dim: pixel (inner) embedding size.
        num_pixel: pixels per patch; the inner-to-outer projection takes
            ``in_dim * num_pixel`` features.
        num_heads: heads of the outer attention.
        in_num_head: heads of the inner attention.
        mlp_ratio: hidden-size ratio of the outer MLP (the inner MLP uses 4).
        qkv_bias: add a bias to the attention projections.
        drop: dropout rate in the attention projections and MLPs.
        attn_drop: dropout rate on the attention weights.
        drop_path: stochastic depth rate of the residual branches.
        act_layer: activation layer class for the MLPs.
        norm_layer: normalization layer class.
    """
    def __init__(self, dim, in_dim, num_pixel, num_heads=12, in_num_head=4, mlp_ratio=4.,
                 qkv_bias=False, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        # Inner transformer (over the pixels of each patch)
        self.norm_in = norm_layer(in_dim)
        self.attn_in = Attention(
            in_dim, in_dim, num_heads=in_num_head, qkv_bias=qkv_bias,
            attn_drop=attn_drop, proj_drop=drop)

        self.norm_mlp_in = norm_layer(in_dim)
        self.mlp_in = Mlp(in_features=in_dim, hidden_features=int(in_dim * 4),
                          out_features=in_dim, act_layer=act_layer, drop=drop)

        # Projection of a patch's flattened pixel tokens onto the patch embedding
        self.norm1_proj = norm_layer(in_dim)
        self.proj = nn.Linear(in_dim * num_pixel, dim, bias_attr=True)

        # Outer transformer (over the patch tokens)
        self.norm_out = norm_layer(dim)
        self.attn_out = Attention(
            dim, dim, num_heads=num_heads, qkv_bias=qkv_bias,
            attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()

        self.norm_mlp = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio),
                       out_features=dim, act_layer=act_layer, drop=drop)

    def forward(self, pixel_embed, patch_embed):
        """Run one TNT block.

        Args:
            pixel_embed: pixel tokens. After normalization they are reshaped
                to ``(B, N - 1, -1)``, so they must hold
                ``num_pixel * in_dim`` values per patch.
            patch_embed: patch tokens, unpacked as ``B, N, C``, with the class
                token first.

        Returns:
            Tuple ``(pixel_embed, patch_embed)`` of updated tokens. The input
            tensors are not modified.
        """
        # inner: pre-norm residual attention + MLP over the pixel tokens
        pixel_embed = pixel_embed + self.drop_path(self.attn_in(self.norm_in(pixel_embed)))
        pixel_embed = pixel_embed + self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed)))

        # outer: add the projected pixel features to every non-class token
        B, N, _ = patch_embed.shape
        pixel_to_patch = self.proj(self.norm1_proj(pixel_embed).reshape((B, N - 1, -1)))
        # Concatenate rather than assign into ``patch_embed[:, 1:]``. This leaves
        # the caller's tensor untouched and keeps the update a regular
        # differentiable op instead of an in-place write into the graph.
        patch_embed = paddle.concat((patch_embed[:, :1], patch_embed[:, 1:] + pixel_to_patch), axis=1)
        patch_embed = patch_embed + self.drop_path(self.attn_out(self.norm_out(patch_embed)))
        patch_embed = patch_embed + self.drop_path(self.mlp(self.norm_mlp(patch_embed)))
        return pixel_embed, patch_embed


class TNT(nn.Layer):
    """TNT (Transformer in Transformer) image classifier.

    Args:
        img_size: input resolution (square); must be divisible by ``patch_size``.
        patch_size: patch side in input pixels.
        in_chans: number of input channels.
        num_classes: classifier outputs; ``<= 0`` replaces the head with
            ``Identity`` so ``forward`` returns features.
        embed_dim: patch (outer) embedding size; must be divisible by ``num_heads``.
        in_dim: pixel (inner) embedding size.
        depth: number of TNT blocks.
        num_heads: heads of the outer attention.
        in_num_head: heads of the inner attention.
        mlp_ratio: outer MLP hidden-size ratio.
        qkv_bias: add a bias to the attention projections.
        drop_rate: dropout rate (position-embedding dropout, attention
            projections and MLPs).
        attn_drop_rate: dropout rate on the attention weights.
        drop_path_rate: stochastic depth rate of the last block. Rates grow
            linearly from 0 over the blocks.
        norm_layer: normalization layer class.
        first_stride: stride of the pixel-embedding convolution.
    """
    def __init__(
            self,
            img_size=224,
            patch_size=16,
            in_chans=3,
            num_classes=1000,
            embed_dim=384,
            in_dim=24,
            depth=12,
            num_heads=6,
            in_num_head=4,
            mlp_ratio=4.,
            qkv_bias=False,
            drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.,
            norm_layer=nn.LayerNorm,
            first_stride=4):
        super(TNT, self).__init__()

        assert embed_dim % num_heads == 0
        assert img_size % patch_size == 0

        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.pixel_embed = PixelEmbed(img_size, patch_size, in_chans, in_dim, first_stride)
        num_patches = self.pixel_embed.num_patches
        self.num_patches = num_patches
        new_patch_size = self.pixel_embed.new_patch_size
        num_pixel = new_patch_size ** 2

        # Initial projection of each patch's flattened pixel tokens to embed_dim
        self.norm1_proj = norm_layer(num_pixel * in_dim)
        self.proj = nn.Linear(num_pixel * in_dim, embed_dim)
        self.norm2_proj = norm_layer(embed_dim)
        # Class token, patch position embedding (one per patch plus the class
        # token) and pixel position embedding (one grid shared by all patches)
        self.cls_token = self.create_parameter(shape=(1, 1, embed_dim), default_initializer=zeros_)
        self.patch_pos = self.create_parameter(shape=(1, self.num_patches + 1, embed_dim), default_initializer=zeros_)
        self.pixel_pos = self.create_parameter(shape=(1, in_dim, new_patch_size, new_patch_size), default_initializer=zeros_)
        # paddle.nn.Dropout takes the probability of *dropping* an element.
        # Passing ``1 - drop_rate`` would drop everything during training at
        # the default drop_rate=0.
        self.pos_drop = nn.Dropout(drop_rate)

        # Stochastic depth: rate grows linearly from 0 (first block) to
        # drop_path_rate (last block).
        dpr = [float(rate) for rate in np.linspace(0, drop_path_rate, depth)]
        blocks = []
        for i in range(depth):
            blocks.append(Block(
                dim=embed_dim, in_dim=in_dim, num_pixel=num_pixel, num_heads=num_heads, in_num_head=in_num_head,
                mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate,
                drop_path=dpr[i], norm_layer=norm_layer))
        self.blocks = nn.LayerList(blocks)
        self.norm = norm_layer(embed_dim)

        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else Identity()

        trunc_normal_(self.cls_token)
        trunc_normal_(self.patch_pos)
        trunc_normal_(self.pixel_pos)

    def get_classifier(self):
        """Return the classification head layer."""
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        """Replace the head with a fresh one for ``num_classes`` outputs.

        ``global_pool`` is accepted for interface compatibility and ignored.
        """
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else Identity()

    def forward_features(self, x):
        """Return the normalized class-token feature of shape (B, embed_dim)."""
        B = x.shape[0]
        # (B * num_patches, num_pixel, in_dim)
        pixel_embed = self.pixel_embed(x, self.pixel_pos)

        # (B, num_patches, num_pixel * in_dim) -> (B, num_patches, embed_dim)
        patch_embed = self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape((B, self.num_patches, -1)))))
        # Prepend the class token -> (B, num_patches + 1, embed_dim)
        patch_embed = paddle.concat((self.cls_token.expand((B, -1, -1)), patch_embed), axis=1)
        patch_embed = patch_embed + self.patch_pos
        patch_embed = self.pos_drop(patch_embed)
        for blk in self.blocks:
            pixel_embed, patch_embed = blk(pixel_embed, patch_embed)

        patch_embed = self.norm(patch_embed)
        return patch_embed[:, 0]

    def forward(self, x):
        """Return class logits (or features when the head is ``Identity``)."""
        x = self.forward_features(x)
        x = self.head(x)
        return x

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  def convert_to_list(value, n, name, dtype=np.int):


#  精度对齐

因为是简单的图像分类模型，这里只验证在相同输入下，torch实现与paddle实现的输出结果是否一致。

**torch的输出**

![](https://ai-studio-static-online.cdn.bcebos.com/f7932104c75e462c9b7dc20466b9fbf7d5fab4861a234e0f92d840e971972327)

**paddle的输出**

![](https://ai-studio-static-online.cdn.bcebos.com/b304ef2f36d24353a9656312c02d77a5e8aa6864dd994c51b474e41eb2dca76c)


# 训练模型

由于ImageNet训练集非常大，AI Studio暂时无法承载完整训练，这里仅用验证集数据训练了两轮。

In [7]:
# 在AIStuido里测试时加载的数据集
import cv2
from PIL import Image

# Evaluation preprocessing: bicubic resize to 256, 224x224 center crop,
# then per-channel normalization with mean 0.5 and std 0.5.
_norm_mean = [0.5] * 3
_norm_std = [0.5] * 3
transforms = T.Compose([
    T.Resize(256, interpolation='bicubic'),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=_norm_mean, std=_norm_std),
])

# Build the dataset
class ILSVRC2012(paddle.io.Dataset):
    """ImageNet (ILSVRC2012) dataset driven by a plain-text label list.

    Each non-empty line of ``label_list`` is
    ``<image path relative to root> <integer label>``.

    Args:
        root: directory that the image paths are relative to.
        label_list: path of the label-list text file.
        transform: callable applied to each decoded RGB image.
        backend: ``'cv2'`` decodes with OpenCV; any other value uses PIL.
    """

    def __init__(self, root, label_list, transform, backend='pil'):
        self.transform = transform
        self.root = root
        self.label_list = label_list
        self.backend = backend
        self.load_datas()

    def load_datas(self):
        """Fill ``self.imgs`` (joined paths) and ``self.labels`` (ints) from ``label_list``."""
        self.imgs = []
        self.labels = []
        with open(self.label_list, 'r') as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    continue  # tolerate blank lines, e.g. a trailing empty line
                # Strip + split on the last whitespace run. ``line[:-1]`` would
                # cut off the last label digit when the final line has no
                # trailing newline, and would keep a stray '\r' on CRLF files.
                img, label = line.rsplit(maxsplit=1)
                self.imgs.append(os.path.join(self.root, img))
                self.labels.append(int(label))

    def __getitem__(self, idx):
        """Return ``(image, label)``: transformed float32 image and int64 label."""
        label = self.labels[idx]
        path = self.imgs[idx]
        if self.backend == 'cv2':
            image = cv2.imread(path)
            if image is None:
                raise OSError(f"cv2 could not read image: {path}")
            # OpenCV decodes to BGR; convert so both backends yield RGB.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        else:
            # Close the file handle; convert() returns an independent image.
            with Image.open(path) as img:
                image = img.convert('RGB')
        image = self.transform(image)
        return image.astype('float32'), np.array(label).astype('int64')

    def __len__(self):
        return len(self.imgs)

# Validation split: image list and labels from val_list.txt, paths relative to data/ILSVRC2012.
val_dataset = ILSVRC2012(
    root='data/ILSVRC2012',
    label_list='data/data68594/val_list.txt',
    transform=transforms,
)

In [5]:
# Save a checkpoint under ./checkpoints after every epoch
callback = paddle.callbacks.ModelCheckpoint(save_dir='./checkpoints', save_freq=1)

# Build the model. NOTE: no pretrained weights are loaded in this cell, so
# training starts from the random initialization set up in TNT.__init__.
model = TNT(num_classes=1000)
run_model = paddle.Model(model)

# Train for 2 epochs with SGD + cross-entropy, tracking accuracy.
# NOTE(review): val_dataset is passed as both the training and the evaluation
# set, so the reported eval accuracy is not measured on held-out data.
optim = paddle.optimizer.SGD(learning_rate=0.0001, weight_decay=6e-5, parameters=run_model.parameters())
run_model.prepare(optimizer= optim,
              loss=paddle.nn.CrossEntropyLoss(),
              metrics=paddle.metric.Accuracy())
run_model.fit(val_dataset, val_dataset, epochs=2, batch_size=128, callbacks=callback, verbose=1)

The loss value printed in the log is the current step, and the metric is the average value of previous step.
Epoch 1/2


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if isinstance(slot[0], (np.ndarray, np.bool, numbers.Number)):
  return (isinstance(seq, collections.Sequence) and


save checkpoint at /home/aistudio/checkpoints/0
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
Eval samples: 50000
Epoch 2/2
save checkpoint at /home/aistudio/checkpoints/1
Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.
Eval samples: 50000
save checkpoint at /home/aistudio/checkpoints/final


# 验证模型

原作者没有公开完整的训练细节，仅靠上面用验证集训练的两轮很难接近论文中的精度，因此这里改为加载预训练权重（/home/aistudio/work/tnt.pdparams）在验证集上进行评估。

In [8]:
# Build TNT and load the pretrained weights from /home/aistudio/work/tnt.pdparams
model = TNT(num_classes=1000)
model_state_dict = paddle.load("/home/aistudio/work/tnt.pdparams")
model.set_state_dict(model_state_dict)
run_model = paddle.Model(model)
# NOTE(review): evaluate() should not need an optimizer; prepare(optimizer=None, ...)
# likely suffices here — confirm before removing.
optim = paddle.optimizer.SGD(learning_rate=0.0001, weight_decay=6e-5, parameters=run_model.parameters())
run_model.prepare(optimizer= optim,
              loss=paddle.nn.CrossEntropyLoss(),
              metrics=paddle.metric.Accuracy())

# Evaluate loss and accuracy on the validation set
run_model.evaluate(val_dataset, batch_size=32, verbose=1)

Eval begin...
The loss value printed in the log is the current batch, and the metric is the average value of previous step.


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if data.dtype == np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if isinstance(slot[0], (np.ndarray, np.bool, numbers.Number)):


Eval samples: 50000


{'loss': [0.44473687], 'acc': 0.81408}

# 总结

因训练硬件资源和时间有限，本次复现过程还有很多缺失和不足，后续持续改进。

请点击[此处](https://ai.baidu.com/docs#/AIStudio_Project_Notebook/a38e5576)查看本环境基本用法.  <br>
Please click [here ](https://ai.baidu.com/docs#/AIStudio_Project_Notebook/a38e5576) for more detailed instructions. 