## Before Running:
Please install all dependencies listed in requirements.txt (`pip install -r requirements.txt`).
Download the pre-trained Swin Transformer (SwinT) weights from https://drive.google.com/drive/folders/1HBw5NGGw8DjkyNurksCP5v8a5f0FG7zU and put them in this folder.

## Set Hyper Parameters

In [32]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Transformer encoder/decoder shape
encoder_decoder_layers = 3   # number of stacked layers
encoder_decoder_heads = 8    # attention heads per layer
embedded_dim_size = 512      # model (embedding) width
max_length = 32              # captions are padded/truncated to this many tokens

# Data
coco_dataset_ratio = 50      # percent of the COCO train/validation splits to load
coco_dataset_dir = "./coco"  # Hugging Face datasets cache directory

# Size of the GPT-2 tokenizer vocabulary (token ids 0..50256). Captions are
# tokenized with GPT-2 and padded with its end-of-text token (id 50256), so any
# embedding / output layer sized from this constant must cover all 50257 ids.
# The previous value, 30522, is the BERT-base vocabulary size and is too small.
vocab_size = 50257

# Optimisation
batch_size = 32
num_epochs = 10
learning_rate = 1e-3
patience = 3                 # presumably early-stopping patience in epochs — TODO confirm
weight_decay = 1e-5
att_feats_dim = 2048         # NOTE(review): not referenced by any cell shown here

# Pretrained checkpoints
encoder_model = "microsoft/swin-base-patch4-window7-224-in22k"  # image-processor config
decoder_model = "gpt2"       # tokenizer

## Downloading and Format datasets
This will take some time to finish running the first time; it took me roughly 40 minutes.

This section does the following actions:
1. Downloads the Dataset
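2. Keeps only images whose pixel array has 3 or 4 dimensions (this drops 2-D images, e.g. grayscale)
3. Applies preprocessing transforms to the dataset (image normalisation and caption tokenisation)
4. Wraps the datasets in PyTorch DataLoaders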


In [12]:
import numpy as np
from datasets import load_dataset
from transformers import ViTImageProcessor, GPT2TokenizerFast
from torch.utils.data import DataLoader
import torch

# Download (or reuse from the `coco_dataset_dir` cache) the HuggingFaceM4/COCO splits.
# Train and validation keep only the first `coco_dataset_ratio` percent of rows;
# the test split is loaded in full.
train_ds = load_dataset("HuggingFaceM4/COCO", split=f"train[:{coco_dataset_ratio}%]", cache_dir=coco_dataset_dir)
valid_ds = load_dataset("HuggingFaceM4/COCO", split=f"validation[:{coco_dataset_ratio}%]", cache_dir=coco_dataset_dir)
test_ds = load_dataset("HuggingFaceM4/COCO", split="test", cache_dir=coco_dataset_dir)


# Keep only examples whose image, converted with np.array, has 3 or 4 dimensions.
# This removes 2-D images (e.g. grayscale, which become (H, W) arrays).
# Every image is decoded to build the array, which is presumably a large part of
# the first-run time.
# num_proc=1: the original author reported numpy errors with more processes.
train_ds = train_ds.filter(lambda item: np.array(item["image"]).ndim in [3, 4], num_proc=1)
valid_ds = valid_ds.filter(lambda item: np.array(item["image"]).ndim in [3, 4], num_proc=1)
test_ds = test_ds.filter(lambda item: np.array(item["image"]).ndim in [3, 4], num_proc=1)


# Preprocessing, applied on the fly by `with_transform` whenever examples are read:
#   * images   -> normalised pixel tensors from the pretrained image processor
#   * captions -> GPT-2 token ids padded/truncated to `max_length`
# NOTE(review): the original author was unsure whether the PureT paper uses this
# exact preprocessing — confirm against the paper.
tokenizer = GPT2TokenizerFast.from_pretrained(decoder_model)
# padding="max_length" below needs a pad token; reuse the end-of-text token.
tokenizer.pad_token = tokenizer.eos_token
image_processor = ViTImageProcessor.from_pretrained(encoder_model)


def preprocess(items):
    """Turn a batch of raw COCO examples into model inputs.

    Args:
        items: batch dict passed by a Hugging Face dataset transform. Reads
            ``items["image"]`` (assumed to be decoded PIL images — the datasets
            Image feature's default) and ``items["sentences"]`` (dicts whose
            ``"raw"`` entry is the caption text).

    Returns:
        dict with ``pixel_values`` (image-processor output) and ``labels``
        (token ids padded/truncated to ``max_length``), both on ``device``.
    """
    # The ndim filter keeps every 3-D array, including 4-channel RGBA/CMYK images.
    # The processor's 3-channel normalisation cannot handle those. Converting to
    # RGB gives every image 3 channels and is a no-op for images already in RGB.
    images = [image.convert("RGB") for image in items["image"]]
    pixel_values = image_processor(images, return_tensors="pt").pixel_values.to(device)
    captions = [sentence["raw"] for sentence in items["sentences"]]
    targets = tokenizer(captions, max_length=max_length, padding="max_length",
                        truncation=True, return_tensors="pt").to(device)
    # NOTE(review): tensors are moved to `device` inside the transform. This prevents
    # using DataLoader(num_workers>0) with CUDA; consider moving in the training loop.
    return {'pixel_values': pixel_values, 'labels': targets["input_ids"]}


train_dataset = train_ds.with_transform(preprocess)
valid_dataset = valid_ds.with_transform(preprocess)
test_dataset = test_ds.with_transform(preprocess)


# Batch assembly for the DataLoaders below.
def collate_fn(batch):
    """Stack per-example tensors into batch tensors.

    Args:
        batch: list of dicts, each holding 'pixel_values' and 'labels' tensors.

    Returns:
        dict with the same two keys, each value stacked along a new leading dim.
    """
    keys = ('pixel_values', 'labels')
    return {key: torch.stack([example[key] for example in batch]) for key in keys}

# Shared DataLoader settings; only the training loader shuffles.
_loader_kwargs = {"collate_fn": collate_fn, "batch_size": batch_size}
train_dataset_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
valid_dataset_loader = DataLoader(valid_dataset, shuffle=False, **_loader_kwargs)
test_dataset_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


## Creating the Model
Creates the PureT model from the paper

This section does the following actions:
1. Creates the Swin Transformer backbone used by PureT
2. Creates the PureT encoder
3. Creates the PureT decoder
4. Creates the PureT model

Download the pre-trained SwinT weights from here https://drive.google.com/drive/folders/1HBw5NGGw8DjkyNurksCP5v8a5f0FG7zU and put them in this folder before running

In [28]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import torchvision
from torch import nn, einsum
import numpy as np
from einops import rearrange, repeat
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
import math
from functools import reduce
import torch.utils.checkpoint as checkpoint
from torch.autograd import Variable

##############################
#
# SWIN Transformer code
# from https://github.com/232525/PureT/blob/5581b5d10ae3bb9f9c859f6644e90db8beaf992b/models/backbone/swin_transformer_backbone.py#L458
#
##############################
def window_partition(x, window_size):
    """Split a batch of feature maps into non-overlapping square windows.

    Args:
        x: tensor of shape (B, H, W, C); H and W are expected to be multiples
            of ``window_size`` (the reshape requires it).
        window_size (int): side length of each window.

    Returns:
        Tensor of shape (num_windows*B, window_size, window_size, C). Windows are
        ordered image by image, and row-major within each image.
    """
    batch, height, width, channels = x.shape
    rows, cols = height // window_size, width // window_size
    grid = x.view(batch, rows, window_size, cols, window_size, channels)
    # put the two window-index axes next to each other before flattening them
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(-1, window_size, window_size, channels)


# Merges windows back into full feature maps (the inverse of window_partition).
def window_reverse(windows, window_size, H, W):
    """Reassemble windows into a batch of (B, H, W, C) feature maps.

    Args:
        windows: tensor of shape (num_windows*B, window_size, window_size, C),
            windows ordered image by image and row-major within each image.
        window_size (int): side length of each window.
        H (int): height of the reassembled feature map.
        W (int): width of the reassembled feature map.

    Returns:
        Tensor of shape (B, H, W, C).
    """
    windows_per_image = H * W / window_size / window_size
    batch = int(windows.shape[0] / windows_per_image)
    rows, cols = H // window_size, W // window_size
    grid = windows.view(batch, rows, cols, window_size, window_size, -1)
    # interleave window-row / in-window-row axes (and likewise columns) back into H and W
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(batch, H, W, -1)


class Mlp(nn.Module):
    """Two-layer feed-forward network: fc1 -> activation -> dropout -> fc2 -> dropout.

    Args:
        in_features (int): input width.
        hidden_features (int | None): hidden width; a falsy value falls back to ``in_features``.
        out_features (int | None): output width; a falsy value falls back to ``in_features``.
        act_layer (type[nn.Module]): activation class, instantiated with no arguments.
        drop (float): dropout probability applied after the activation and after fc2.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden = hidden_features or in_features
        out = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # the same Dropout module is reused after both stages
        for stage in (self.fc1, self.act, self.drop, self.fc2, self.drop):
            x = stage(x)
        return x


# Window attention layer used by the Swin backbone (W-MSA / SW-MSA)
class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # a falsy qk_scale (None or 0) falls back to the standard 1/sqrt(head_dim)
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias: one learnable value per
        # (relative offset, head), with (2*Wh-1) * (2*Ww-1) possible offsets
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        # explicit "ij" indexing gives the same grid as the legacy default, without the
        # warning newer PyTorch emits when `indexing` is omitted
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        # flatten the 2-D offset into a single row index of the bias table
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C), N = Wh*Ww
            mask: additive mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None;
                0 for allowed token pairs, a large negative value for blocked pairs.

        Returns:
            Tensor of shape (num_windows*B, N, C).
        """
        B_, N, C = x.shape
        # [3, B_, nH, N, head_dim]
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # B_, nH, N, N

        # look up the learned bias for every (query, key) offset inside the window
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # windows are stored image by image, so split B_ into (B, nW) to broadcast
            # the per-window mask over the batch and the heads
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Pre-norm residual block on token sequences of shape (B, H*W, C). It applies a
    window-attention residual step (norm1 -> window attention -> drop_path), then
    an MLP residual step (norm2 -> mlp -> drop_path). Attention runs inside
    non-overlapping windows. When ``shift_size > 0`` the map is cyclically rolled
    before windowing and rolled back afterwards (SW-MSA). A precomputed additive
    mask stops tokens from attending across the seams the roll introduces.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution (H, W); forward() asserts L == H*W.
        num_heads (int): Number of attention heads.
        window_size (int): Window size. If min(input_resolution) <= window_size, the window is
            shrunk to min(input_resolution) and shifting is disabled.
        shift_size (int): Shift size for SW-MSA. Must end up in [0, window_size).
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
        fused_window_process (bool, optional): Accepted for signature compatibility only; this
            implementation never reads it. Default: False
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 fused_window_process=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if the input is no larger than the window, use a single window covering
            # the whole map and disable shifting
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        # stochastic depth on both residual branches (identity when the rate is 0)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA:
            # label the 3x3 grid of regions produced by the cyclic shift with ids 0..8
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            # within each window, token pairs with different region ids get -100 added
            # to their attention logits (near-zero weight after softmax); pairs from
            # the same region get 0
            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        # registered as a buffer (None when not shifting) so it follows module.to(device)
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        """x: (B, H*W, C) -> (B, H*W, C), with (H, W) = self.input_resolution."""
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # residual around attention, then residual around the FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x


class PatchMerging(nn.Module):
    r""" Patch Merging Layer: halves the spatial resolution and doubles the channel count.

    Each 2x2 neighbourhood of tokens is concatenated (4*C channels), normalised
    with ``norm`` and projected to 2*C channels by the bias-free ``reduction``.

    Args:
        input_resolution (tuple[int]): Resolution (H, W) of the input feature map.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """x: (B, H*W, C) -> (B, H/2*W/2, 2*C)."""
        height, width = self.input_resolution
        batch, length, channels = x.shape
        assert length == height * width, "input feature has wrong size"
        assert height % 2 == 0 and width % 2 == 0, f"x size ({height}*{width}) are not even."

        grid = x.view(batch, height, width, channels)
        # Concatenation order must match what the reduction weights expect:
        # (row 0, col 0), (row 1, col 0), (row 0, col 1), (row 1, col 1).
        quadrants = [grid[:, row::2, col::2, :] for col in (0, 1) for row in (0, 1)]
        merged = torch.cat(quadrants, -1).view(batch, -1, 4 * channels)
        return self.reduction(self.norm(merged))


class BasicLayer(nn.Module):
    """One Swin Transformer stage: ``depth`` blocks followed by an optional downsample.

    Blocks alternate between regular windows (even index, shift 0) and shifted
    windows (odd index, shift ``window_size // 2``).

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | list[float], optional): Stochastic depth rate, either shared by all
            blocks or given per block as a list. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer class, built as
            ``downsample(input_resolution, dim=dim, norm_layer=norm_layer)``. Default: None
        use_checkpoint (bool): Whether to use gradient checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        stage_blocks = []
        for index in range(depth):
            shift = window_size // 2 if index % 2 else 0
            rate = drop_path[index] if isinstance(drop_path, list) else drop_path
            stage_blocks.append(
                SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                     num_heads=num_heads, window_size=window_size,
                                     shift_size=shift,
                                     mlp_ratio=mlp_ratio,
                                     qkv_bias=qkv_bias, qk_scale=qk_scale,
                                     drop=drop, attn_drop=attn_drop,
                                     drop_path=rate,
                                     norm_layer=norm_layer))
        self.blocks = nn.ModuleList(stage_blocks)

        # optional patch-merging layer applied after all blocks
        self.downsample = (downsample(input_resolution, dim=dim, norm_layer=norm_layer)
                           if downsample is not None else None)

    def forward(self, x):
        for blk in self.blocks:
            x = checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)
        return x if self.downsample is None else self.downsample(x)
    

class PatchEmbed(nn.Module):
    r"""Image to Patch Embedding.

    A strided convolution (kernel == stride == patch_size) projects each
    non-overlapping patch to ``embed_dim`` channels. The result is flattened into
    a token sequence and optionally normalised.

    Args:
        img_size (int): Image size.  Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size, patch_size = to_2tuple(img_size), to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x):
        """x: (B, C, H, W) with (H, W) == img_size -> (B, Ph*Pw, embed_dim)."""
        _, _, height, width = x.shape
        # NOTE: only the configured image size is accepted (upstream FIXME: relax this)
        assert height == self.img_size[0] and width == self.img_size[1], \
            f"Input image size ({height}*{width}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        tokens = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        return tokens if self.norm is None else self.norm(tokens)


class SwinTransformer(nn.Module):
    r""" Swin Transformer backbone (classification head removed)
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
          https://arxiv.org/pdf/2103.14030

    ``forward`` returns per-token features of shape [B, L, num_features].
    ``forward_features`` returns the same features mean-pooled over tokens,
    with shape [B, num_features].

    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Stored as an attribute only; no classification head is built. Default: 1000
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (sequence of int): Depth of each Swin Transformer stage. Default: (2, 2, 6, 2)
        num_heads (sequence of int): Number of attention heads in each stage. Default: (3, 6, 12, 24)
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        **kwargs: accepted and ignored.
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, **kwargs):
        # depths / num_heads default to tuples (immutable); lists are still accepted.
        super().__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        # channel width doubles at every stage, so the last stage has embed_dim * 2**(num_layers-1)
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth: the drop-path rate grows linearly from 0 across all blocks
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers; every stage except the last ends with PatchMerging
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        # No pooling module or classification head is built: this variant is a feature extractor.

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init for Linear weights; zero biases; unit/zero LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_features(self, x):
        """Return one pooled feature vector per image, shape [B, num_features].

        Token features from ``forward`` are averaged over the token axis. This equals
        the upstream ``AdaptiveAvgPool1d(1)`` + flatten. That module is never
        registered in this headless variant, so the previous ``self.avgpool`` call
        always raised AttributeError.
        """
        x = self.forward(x)  # B L C
        return x.mean(dim=1)  # B C

    # forward w/o head
    def forward(self, x):
        """Extract per-token image features: [B, 3, H, W] -> [B, L, num_features]."""
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)  # B L C
        return x

    def load_weights(self, pretrained_model):
        """Load a state dict from ``pretrained_model`` (a path or file-like object) non-strictly.

        Returns:
            The ``load_state_dict`` result (``missing_keys`` / ``unexpected_keys``).
            With ``strict=False`` key mismatches are otherwise ignored silently, so
            callers can inspect this to confirm the weights actually matched.
        """
        # SECURITY NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(pretrained_model, map_location='cpu')
        return self.load_state_dict(state_dict, strict=False)


##############################
#
# PureT Encoder code
# from https://github.com/232525/PureT/blob/5581b5d10ae3bb9f9c859f6644e90db8beaf992b/models/encoder_decoder/PureT_encoder.py
#
##############################
class FeedForward(nn.Module):
    """Position-wise feed-forward network of the PureT encoder.

    Computes fc1 -> ReLU -> dropout -> fc2 -> dropout. The same dropout module
    (probability ``relu_dropout``) is used at both positions.

    Args:
        embed_dim (int): input and output width.
        ffn_embed_dim (int): hidden width.
        relu_dropout (float): dropout probability. Default: 0.1
    """

    def __init__(self, embed_dim, ffn_embed_dim, relu_dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, ffn_embed_dim)
        self.act = nn.ReLU()    # ReLU / GELU / CELU
        self.fc2 = nn.Linear(ffn_embed_dim, embed_dim)
        self.dropout = nn.Dropout(relu_dropout)

    def forward(self, x):
        hidden = self.dropout(self.act(self.fc1(x)))
        return self.dropout(self.fc2(hidden))


class WindowAttentionEncoder(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias,
    extended with an optional global-feature token (PureT encoder).
    It supports both of shifted and non-shifted window.

    Each window holds N tokens. N == Wh*Ww means grid features only. Otherwise the
    last token is treated as an image-level global feature ("gx"), assumed to be
    appended once per window (N == Wh*Ww + 1). When ``ind_gx`` is set, that token's
    output is computed by one softmax over all windows of the image (see forward).

    Args:
        embed_dim (int): Number of input/output channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        nW (int): Number of windows per image; used to regroup the B*nW window
            batch by image for the global-feature computation.
        ind_gx (bool): Whether to compute the global feature separately inside the attention.
    """

    def __init__(self, embed_dim=512, window_size=(12, 12), num_heads=8, 
                 nW=4, ind_gx=True):
        super().__init__()
        self.embed_dim = embed_dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.nW = nW
        self.ind_gx = ind_gx # whether the global feature is computed separately inside the attention

        # Relative position encoding, applied within each window of grid features.
        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # [2*Wh-1 * 2*Ww-1, nH]

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        # NOTE(review): newer PyTorch warns when meshgrid is called without `indexing=`;
        # the current implicit behaviour corresponds to indexing="ij".
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)
        
        trunc_normal_(self.relative_position_bias_table, std=.02)
        
        self.q_linear = nn.Linear(embed_dim, embed_dim)
        self.k_linear = nn.Linear(embed_dim, embed_dim)
        self.v_linear = nn.Linear(embed_dim, embed_dim)
        self.o_linear = nn.Linear(embed_dim, embed_dim)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: additive mask with shape of (num_windows, N, N) or None; it
                distinguishes SW-MSA (shifted windows) from W-MSA

        Returns:
            Tensor of shape (num_windows*B, N, C) after the output projection.
        """
        # N = Ww * Wh      grid features of the window only (Ww, Wh = window size)
        # N = Ww * Wh + 1  window grid features plus the image's global feature (last token)
        B_, N, C = x.size()
        # (The triple-quoted blocks in this method are disabled debug printing;
        #  they are no-op string expressions.)
        """
        print('*'*30)
        print('raw gx', x[:, -1, :].min(), x[:, -1, :].max(), x[:, -1, :].mean())
        print('raw all', x.min(), x.max(), x.mean())
        # """
        
        # [B*nW, nH, N, C//nH], where nW = number of windows, nH = num_heads
        q = self.q_linear(x).view(B_, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_linear(x).view(B_, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_linear(x).view(B_, -1, self.num_heads, self.head_dim).transpose(1, 2)
        
        """
        print('gx q', q[:, :, -1, :].min(), q[:, :, -1, :].max(), q[:, :, -1, :].mean())
        print('gx k', k[:, :, -1, :].min(), k[:, :, -1, :].max(), k[:, :, -1, :].mean())
        print('gx v', v[:, :, -1, :].min(), v[:, :, -1, :].max(), v[:, :, -1, :].mean())
        print('all q', q.min(), q.max(), q.mean())
        print('all k', k.min(), k.max(), k.mean())
        print('all v', v.min(), v.max(), v.mean())
        # """
        
        # [B*nW, nH, N, N], where nW = number of windows, nH = num_heads
        attn = (q @ k.transpose(-2, -1)) * self.scale
        # print(attn.min(), attn.max(), attn.mean())

        # Relative position bias: only between grid features inside the window
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # [Wh*Ww, Wh*Ww, nH]
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # [nH, Wh*Ww, Wh*Ww]
        # If a global feature is present, the relative position bias does not apply to it
        if N == self.window_size[0] * self.window_size[1]:
            attn = attn + relative_position_bias.unsqueeze(0)
        else:
            # Add the bias only to the grid-feature part of the window
            # (this slice assumes exactly one extra token, i.e. N == Wh*Ww + 1)
            attn[:, :, :-1, :-1] = attn[:, :, :-1, :-1] + relative_position_bias.unsqueeze(0)
        
        """
        print(relative_position_bias.min(), relative_position_bias.max(), relative_position_bias.mean())
        print(attn.min(), attn.max(), attn.mean())
        # """
                    
        # The mask distinguishes SW-MSA from W-MSA
        # mask: [nW, N, N],
        # where nW = number of windows, N = Ww*Wh or Ww*Wh+1 (Ww, Wh = window size)
        if mask is not None:
            # mask = mask.masked_fill(mask == float(-100), float(-1e9))
            nW = mask.shape[0]
            # attn: [B*nW, nH, N, N] --> [B, nW, nH, N, N]
            # mask: [nW, N, N]       --> [1, nW,  1, N, N]
            # print('IN', attn.view(B_ // nW, nW, self.num_heads, N, N)[0, 2, 0, 0, :])
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            # [B*nW, nH, N, N]
            attn = attn.view(-1, self.num_heads, N, N)
        else:
            attn = attn
        
        # TODO: simplify this code
        # Handle the global feature separately: take its attention row (last query)
        # out of attn and compute its weighted sum on its own, across all windows of an image.
        # [B*nW, nH, N]
        if self.ind_gx and N != self.window_size[0] * self.window_size[1]:
            # [B*nW, nH, N] -> [B, nW, nH, N] -> [B, nH, nW, N]
            gx_attn = attn[:, :, -1, :]
            gx_attn = gx_attn.view(B_ // self.nW, self.nW, self.num_heads, -1).permute(0, 2, 1, 3)
            # [B, nH, nW, N-1] --> [B, nH, nW * (N-1)], i.e. [B, nH, H * W]
            gx_attn_1 = gx_attn[:, :, :, :-1].contiguous().view(B_ // self.nW, self.num_heads, -1)
            # gx-to-gx logit, averaged over the windows: [B, nH, nW, 1] --> [B, nH, 1]
            gx_attn_2 = gx_attn[:, :, :, -1:].mean(-2)
            # [B, nH, nW * (N-1) + 1], i.e. [B, nH, H * W + 1]
            gx_attn = torch.cat([gx_attn_1, gx_attn_2], -1)
            # Global-feature attention weights: one softmax over every grid position of
            # the image plus the global token itself, [B, nH, H * W + 1]
            gx_attn = self.softmax(gx_attn)
            """
            if mask is None:
                # [B, 8, 145]
                # print(gx_attn.size())
                # print(gx_attn[0, 0, :-1].view(12, 12))
                print(gx_attn[:12, :, :].max(-1))
                if gx_attn[:, :, -1].max() > 0.001:
                    print('>>> gx alpha:', gx_attn[:, :, -1].max())
            # """
            # Split the weights back per window so they can multiply each window's v.
            # [B, nH, nW * (N-1)]
            gx_attn_1 = gx_attn[:, :, :-1] 
            # [B, nH, nW * (N-1)] --> [B, nH, nW, N-1] --> [B, nW, nH, N-1] --> [B*nW, nH, N-1]
            gx_attn_1 = gx_attn_1.view(B_ // self.nW, self.num_heads, self.nW , -1).permute(0, 2, 1, 3).contiguous().view(B_, self.num_heads, -1)
            # [B, nH, 1]
            gx_attn_2 = gx_attn[:, :, -1:] 
            # [B, nH, 1] --> [B, nH, nW, 1] --> [B, nW, nH, 1] --> [B*nW, nH, 1]
            gx_attn_2 = gx_attn_2.unsqueeze(-1).repeat(1, 1, self.nW, 1).permute(0, 2, 1, 3).contiguous().view(B_, self.num_heads, -1)
            # [B*nW, nH, N] --> [B*nW, nH, 1, N]
            gx_attn = torch.cat([gx_attn_1, gx_attn_2], -1).unsqueeze(-2)
            gx = (gx_attn @ v).transpose(1, 2).reshape(B_, C)
            # print(gx.size())
            # Sum the per-window partial results into one global feature per image ...
            gx = gx.view(B_ // self.nW, self.nW, -1).sum(1)
            # print(gx.size())
            # ... and copy it back to every window: [B*nW, C]
            gx = gx.unsqueeze(1).repeat(1, self.nW, 1).view(B_, C)
            # print(gx.size())
            
        # Softmax to obtain the per-window attention weights
        attn = self.softmax(attn)
        """
        # [B*nW, 8, Ww*Wh, Ww*Wh]
        # print(attn.size())
        if attn[:12, :, :-1, -1].max() > 0.1:
            print(attn.size())
            print(attn[:12, :, 0, :].max(-1))
            # print(attn[:12, :, :-1, -1].max())
            # print(attn[:12, :, :-1, -1].max(-1))
            # print(attn[:8, :, :-1, -1].argmax(-1, keepdim=True))
        # """
        
        # Weighted sum, [B*nW, N, C]
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        # Overwrite the global-feature token in x with the separately computed gx, [B*nW, C]
        if self.ind_gx and N != self.window_size[0] * self.window_size[1]:
            x[:, -1, :] = gx
        x = self.o_linear(x)
        return x


class EncoderLayer(nn.Module):
    """Swin-style window self-attention layer (W-MSA or SW-MSA) followed by an FFN.

    Grid features ``[B, H*W, C]`` are split into non-overlapping windows of
    ``window_size x window_size`` tokens and self-attention is computed inside
    each window. When ``shift_size > 0`` the grid is cyclically shifted before
    partitioning (SW-MSA) and an additive attention mask prevents tokens that
    were not adjacent before the shift from attending to each other.

    When ``use_gx`` is True the input carries one extra trailing token (a
    global feature, ``[B, H*W+1, C]``). A copy of it is appended to every
    window, and the updated global feature is the mean of its per-window
    attention outputs.

    If ``window_size`` is >= the smaller side of ``input_resolution`` the layer
    uses a single window spanning the whole grid (plain MSA) and no shift.
    """

    def __init__(
        self, 
        embed_dim=512, 
        input_resolution=(12, 12), 
        num_heads=8, 
        window_size=12,    # window size; if it equals the input size this is plain MSA
        shift_size=0,      # 0 (W-MSA) or window_size // 2 (SW-MSA)
        mlp_ratio=4,       # FeedForward hidden size = embed_dim * mlp_ratio
        dropout=0.1,
        use_gx=False
    ):
        super(EncoderLayer, self).__init__()
        self.embed_dim = embed_dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size   # > 0 distinguishes SW-MSA from W-MSA
        self.mlp_ratio = mlp_ratio
        self.use_gx = use_gx

        # If the window does not fit strictly inside the input, don't
        # partition: use one window spanning the grid and disable shifting.
        # window_size should also divide input_resolution so windows tile it.
        if min(self.input_resolution) <= self.window_size:
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        # Number of windows per image. Computed AFTER the clamp above so it
        # matches the window size actually used. Computing it from the raw
        # window_size gives 0 when window_size > input_resolution. Both sides
        # are used so non-square grids are counted correctly.
        H, W = self.input_resolution
        self.nW = (H // self.window_size) * (W // self.window_size)

        # Core (shifted-)window attention.
        self.encoder_attn = WindowAttentionEncoder(
            embed_dim=embed_dim, 
            window_size=to_2tuple(self.window_size), 
            num_heads=num_heads,
            nW = self.nW
        )
        # This dropout is applied to both the attention output and the FFN output.
        self.dropout = nn.Dropout(dropout) 
        self.layer_norm1 = nn.LayerNorm(embed_dim)

        # Position-wise FeedForward.
        ffn_embed_dim = int(embed_dim * mlp_ratio)
        self.ff_layer = FeedForward(
            embed_dim = embed_dim, 
            ffn_embed_dim = ffn_embed_dim, 
            relu_dropout = dropout
        )
        self.layer_norm2 = nn.LayerNorm(embed_dim)

        # Attention mask used only by SW-MSA: [nW, ws*ws, ws*ws], where nW is
        # the number of windows and ws the window size, e.g. [4, 36, 36] for
        # input_resolution=(12, 12) and window_size=6.
        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # [1, H, W, 1]
            # Label regions of the grid. After the shift one window may
            # contain several regions that were not adjacent before it.
            # Equal labels mean the tokens were adjacent before the shift.
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            # (An equivalent partition would be
            #  (slice(0, -shift_size), slice(-shift_size, None)) on each axis.)

            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            # Token pairs with different region labels get a large negative bias.
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-1e9)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x, att_mask=None):
        """Apply (S)W-MSA plus FFN, each with residual, dropout and LayerNorm.

        Args:
            x: ``[B, H*W, C]`` grid features, or ``[B, H*W+1, C]`` with the
                global feature as the last token when ``use_gx`` is True.
            att_mask: unused. Grid features have the same length for every
                sample, so no padding mask is needed. Kept for interface
                compatibility.

        Returns:
            Tensor with the same shape as ``x``.
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        short_cut = x

        # Separate the global token (if any) from the grid tokens.
        if self.use_gx:
            assert L == (H * W +1), "input feature has wrong size"
            gx = x[:, -1, :]   # [B, C]
            x  = x[:, :-1, :]  # [B, H * W, C]
        else:
            assert L == H * W, "input feature has wrong size"
            gx = None

        x = x.view(B, H, W, C)

        # Cyclic shift, the core SW-MSA operation; no-op for W-MSA.
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # Partition into windows, e.g. a 12x12 grid into four 6x6 windows:
        # [B, H, W, C] --> [nW*B, window_size * window_size, C]
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # With a global feature, append one copy of it to every window
        # (nW copies in total) so each window's attention can see it.
        if self.use_gx:
            # [B, C] -> [B, 1, C] -> [B, nW, C] -> [B*nW, C]
            gx_ = gx.unsqueeze(1).repeat(1, x_windows.size()[0]//gx.size()[0], 1).view(x_windows.size()[0], -1)
            x_windows = torch.cat([x_windows, gx_.unsqueeze(1)], 1) # [B*nW, window_size*window_size + 1, C]

            # Extend the SW-MSA mask with a zero row/column for the global token.
            if self.attn_mask is None:
                _mask = self.attn_mask
            else:
                _mask = F.pad(
                    self.attn_mask, 
                    pad=(0, 1, 0, 1, 0, 0),
                    mode='constant', 
                    value=0.0
                ) # [nW, window_size*window_size + 1, window_size*window_size + 1]
        else:
            _mask = self.attn_mask

        # W-MSA/SW-MSA
        # x_windows: [B*nW, Ww*Wh(+1), C]
        # mask: [nW, Ww*Wh(+1), Ww*Wh(+1)] or None
        attn_windows = self.encoder_attn(x_windows, mask=_mask)  # nW*B, window_size*window_size(+1), C

        # Split the global-token outputs back out of the windows.
        if self.use_gx:
            # New gx = mean over the nW windows of the global-token outputs
            # (this could also be handled inside the attention layer via its
            # ind_gx option). Uses C rather than a hard-coded width so that
            # embed_dim != 512 works.
            # [B*nW, C] --> [B, nW, C] --> [B, C]
            gx = attn_windows[:, -1, :].view(-1, self.nW, C).mean(1)
            # [B*nW, Ww*Wh, C]
            attn_windows = attn_windows[:, :-1, :]

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # Re-attach the global feature as the last token.
        if self.use_gx:
            # [B, H*W+1, C]
            x = torch.cat([x, gx.unsqueeze(1)], dim=1)

        # Attention residual + LayerNorm.
        x = self.dropout(x)
        x = self.layer_norm1(x + short_cut)

        # FeedForward with dropout, residual and LayerNorm.
        short_cut = x
        x = self.ff_layer(x)
        x = self.dropout(x)
        x = self.layer_norm2(x + short_cut)

        return x


class Encoder(nn.Module):
    """Stack of ``EncoderLayer`` blocks alternating W-MSA and SW-MSA.

    Layers at even indices (0, 2, ...) use plain window attention (shift 0).
    Layers at odd indices use shifted windows with ``shift_size``. With
    ``window_size`` equal to the grid side every layer degenerates to ordinary
    MSA. For a 12x12 grid, SW-MSA needs a smaller window (e.g. 6) and
    ``shift_size > 0``.
    """

    def __init__(
        self, 
        embed_dim=512, 
        input_resolution=(12, 12), 
        depth=3, 
        num_heads=8, 
        window_size=12,  # 12 -> plain MSA over the whole 12x12 grid
        shift_size=6,    # 0 -> no SW-MSA, only W-MSA
        mlp_ratio=4,
        dropout=0.1,
        use_gx=False
    ):
        super(Encoder, self).__init__()
        self.embed_dim = embed_dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_gx = use_gx

        stack = []
        for idx in range(self.depth):
            # odd index -> shifted windows, even index -> regular windows
            layer_shift = shift_size if idx % 2 else 0
            stack.append(
                EncoderLayer(
                    embed_dim=embed_dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=layer_shift,
                    mlp_ratio=mlp_ratio,
                    dropout=dropout,
                    use_gx=use_gx,
                )
            )
        self.layers = nn.ModuleList(stack)

    def forward(self, x, att_mask=None):
        """Encode grid features ``x`` of shape ``[B, H*W, C]``.

        The initial global feature is the (masked, if ``att_mask`` is given)
        mean of ``x`` over positions. When ``use_gx`` is set, it is appended as
        an extra token and refined by every layer.

        Returns:
            ``(gx, x)``: the global feature ``[B, C]`` and grid features
            ``[B, H*W, C]``. Without ``use_gx``, ``gx`` is the mean of the
            final grid features.
        """
        if att_mask is None:
            gx = x.mean(1)
        else:
            weights = att_mask.unsqueeze(-1)
            gx = torch.sum(x * weights, 1) / torch.sum(weights, 1)

        # [B, H*W+1, C] when the global token is used, else unchanged
        feats = torch.cat([x, gx.unsqueeze(1)], dim=1) if self.use_gx else x

        for layer in self.layers:
            feats = layer(feats, att_mask)

        if not self.use_gx:
            return feats.mean(1), feats
        return feats[:, -1, :], feats[:, :-1, :]


##############################
#
# PureT Decoder code
# from https://github.com/232525/PureT/blob/5581b5d10ae3bb9f9c859f6644e90db8beaf992b/models/encoder_decoder/PureT_decoder.py
#
##############################
def position_embedding(input, d_model):
    """Return a sinusoidal encoding ``[N, d_model]`` for the positions in ``input``.

    ``input`` is flattened to N positions. For each ``i < d_model // 2``,
    column ``2i`` holds ``sin(pos / 10000**(2i/d_model))`` and column
    ``2i+1`` the matching cosine. ``d_model`` must be even, otherwise the
    interleaved column assignment has mismatched widths.
    """
    positions = input.view(-1, 1)
    freq_idx = torch.arange(d_model // 2, dtype=torch.float32, device=positions.device).view(1, -1)
    denom = 10000 ** (2 * freq_idx / d_model)
    angles = positions / denom

    table = torch.zeros((positions.shape[0], d_model), device=positions.device)
    table[:, 0::2] = torch.sin(angles)
    table[:, 1::2] = torch.cos(angles)
    return table


def sinusoid_encoding_table(max_len, d_model, padding_idx=None):
    """Build a ``[max_len, d_model]`` sinusoidal table for positions ``0..max_len-1``.

    If ``padding_idx`` is given, that row is zeroed (suitable as the
    padding row of an ``nn.Embedding``).
    """
    table = position_embedding(torch.arange(max_len, dtype=torch.float32), d_model)
    if padding_idx is None:
        return table
    table[padding_idx] = 0
    return table


class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention used by the decoder.

    Used for two kinds of attention:
      1) masked self-attention over word embeddings, with q/k/v all ``[B, L, D]``;
      2) cross-attention from words to image features (optionally including
         the global feature), with q ``[B, L, D]`` and k/v ``[B, M(+1), D]``.
    The output shape follows ``q``.

    For incremental decoding a key/value buffer can be enabled with
    :meth:`init_buffer`. Each forward call then appends the new keys/values
    to the buffer and attends over everything buffered so far.
    """

    def __init__(self, embed_dim=512, num_heads=8):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.scale = self.head_dim ** -0.5

        self.q_linear = nn.Linear(embed_dim, embed_dim)
        self.k_linear = nn.Linear(embed_dim, embed_dim)
        self.v_linear = nn.Linear(embed_dim, embed_dim)
        self.o_linear = nn.Linear(embed_dim, embed_dim)

        self.softmax = nn.Softmax(-1)

        self.clear_buffer()

    def init_buffer(self, batch_size):
        """Start empty key/value buffers of shape ``[B, nH, 0, head_dim]``.

        The buffers are allocated on the same device and dtype as the layer's
        parameters, so incremental decoding also works on CPU. Previously they
        were hard-coded to ``'cuda'``.
        """
        ref = self.q_linear.weight
        shape = (batch_size, self.num_heads, 0, self.head_dim)
        self.buffer_key = torch.zeros(shape, device=ref.device, dtype=ref.dtype)
        self.buffer_value = torch.zeros(shape, device=ref.device, dtype=ref.dtype)

    def clear_buffer(self):
        """Disable the key/value buffer (training / non-incremental mode)."""
        self.buffer_key = None
        self.buffer_value = None

    def apply_to_states(self, fn):
        """Replace both buffers with ``fn(buffer)``.

        Presumably used to reorder/expand buffers during beam search; verify
        against callers.
        """
        self.buffer_key = fn(self.buffer_key)
        self.buffer_value = fn(self.buffer_value)

    def forward(self, q, k, v, mask):
        """Attend from ``q`` to ``k``/``v``.

        Args:
            q: ``[B, N, C]`` queries.
            k, v: ``[B, S, C]`` keys/values.
            mask: ``None``, or a tensor broadcastable to ``[B, L, S]`` after it
                is unsqueezed at dim 1 for the heads. Positions where
                ``mask == 0`` are excluded.

        Returns:
            ``[B, N, C]`` attended output after the output projection.
        """
        B_, N, C = q.size()
        # Project and split heads: [B, S, C] -> [B, nH, S, C/nH]
        q = self.q_linear(q).view(B_, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_linear(k).view(B_, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_linear(v).view(B_, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # Incremental decoding: append to and attend over the buffered history.
        if self.buffer_key is not None and self.buffer_value is not None:
            self.buffer_key = torch.cat([self.buffer_key, k], dim=2)
            self.buffer_value = torch.cat([self.buffer_value, v], dim=2)
            k = self.buffer_key
            v = self.buffer_value

        # [B, nH, L, L] or [B, nH, L, M+1]
        attn = (q @ k.transpose(-2, -1)) * self.scale

        if mask is not None:
            mask = mask.unsqueeze(1)  # broadcast over heads
            attn = attn.masked_fill(mask == 0, -1e9)
        attn = self.softmax(attn)

        out = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        out = self.o_linear(out)
        return out


class DecoderLayer(nn.Module):
    """One decoder layer: masked word self-attention, cross-attention to image features, FFN.

    With ``use_gx`` the image's global feature is (a) fused into every word
    embedding before self-attention, by concatenation followed by
    Linear + ReLU with a residual and LayerNorm, and (b) appended as an extra
    key/value token for the cross-attention.
    """

    def __init__(self, embed_dim=512, num_heads=8, dropout=0.1, ff_dropout=0.1, use_gx=False):
        super(DecoderLayer, self).__init__()
        self.word_attn = MultiHeadSelfAttention(
            embed_dim = embed_dim, 
            num_heads = num_heads
        )
        self.layer_norm1 = nn.LayerNorm(embed_dim)

        self.cross_att = MultiHeadSelfAttention(
            embed_dim = embed_dim, 
            num_heads = num_heads
        )
        self.layer_norm2 = nn.LayerNorm(embed_dim)

        self.ff_layer = FeedForward(
            embed_dim = embed_dim, 
            ffn_embed_dim = embed_dim * 4, 
            relu_dropout = ff_dropout
        )
        self.layer_norm3 = torch.nn.LayerNorm(embed_dim)

        self.dropout = nn.Dropout(dropout)

        self.use_gx = use_gx
        if self.use_gx:
            # Fusion "option 2": concat(word, gx) -> Linear (+ ReLU, Dropout)
            self.fuse_layer = nn.Sequential(
                nn.Linear(embed_dim*2, embed_dim),
                nn.ReLU(),
                nn.Dropout(0.1)
            )
            self.fuse_layer_norm = nn.LayerNorm(embed_dim)

    def apply_to_states(self, fn):
        """Apply ``fn`` to the self-attention key/value buffers."""
        self.word_attn.apply_to_states(fn)

    def init_buffer(self, batch_size):
        """Enable incremental decoding buffers in the word self-attention."""
        self.word_attn.init_buffer(batch_size)

    def clear_buffer(self):
        """Disable the word self-attention buffers."""
        self.word_attn.clear_buffer()

    def precompute(self, encoder_out):
        # Stub: cross-attention key/value precomputation is not implemented,
        # so this returns None.
        # key, value2 = self.cross_att.precompute(encoder_out, encoder_out)
        # return key, value2
        pass

    def forward(self, gx, x, encoder_out, seq_mask, att_mask=None):
        """Run the layer.

        Args:
            gx: global image feature ``[B, C]`` (used only with ``use_gx``).
            x: word embeddings ``[B, L, C]``.
            encoder_out: grid image features ``[B, M, C]``.
            seq_mask: self-attention mask forwarded to ``word_attn``.
            att_mask: optional ``[B, 1, M]`` mask over image features.

        Returns:
            ``[B, L, C]`` updated word representations.
        """
        # Fuse the global image feature into each word embedding.
        if self.use_gx:
            x_cat = torch.cat([x, gx.unsqueeze(1).expand_as(x)], dim=-1)
            x = self.fuse_layer(x_cat) + x
            x = self.fuse_layer_norm(x)
        short_cut = x

        # Word self-attention.
        x = self.word_attn(
            q = x,
            k = x,
            v = x,
            mask = seq_mask
        )
        x = self.dropout(x)
        x = self.layer_norm1(x + short_cut)

        # Cross-attention between words and image features (optionally plus gx).
        short_cut = x
        if self.use_gx:
            kv = torch.cat([encoder_out, gx.unsqueeze(1)], 1)
            if att_mask is not None:
                # Append a "valid" column for the global token:
                # [B, 1, M] -> [B, 1, M+1]. new_ones keeps att_mask's device
                # (previously hard-coded to 'cuda').
                gx_valid = att_mask.new_ones((att_mask.size(0), 1, 1))
                _att_mask = torch.cat([att_mask, gx_valid], 2).long()
            else:
                _att_mask = None
        else:
            kv = encoder_out
            _att_mask = att_mask

        x = self.cross_att(
            q = x,
            k = kv,
            v = kv,
            mask = _att_mask,
        )
        x = self.dropout(x)
        x = self.layer_norm2(x + short_cut)

        # Feedforward
        short_cut = x
        x = self.ff_layer(x)
        x = self.dropout(x)
        x = self.layer_norm3(x + short_cut)

        return x


class Decoder(nn.Module):
    """Transformer caption decoder: scaled word + sinusoidal position embeddings,
    a stack of ``DecoderLayer``s, and a linear projection to vocabulary logits.

    In training mode (``seq_len is None``) positions for a ``[B, L]`` input are
    ``1..L``. After :meth:`init_buffer` the decoder runs incrementally: each
    call advances a running token count and uses that count as the single
    position index, while each layer's self-attention buffers the keys and
    values of earlier steps.
    """

    def __init__(
        self, 
        vocab_size, 
        embed_dim=512, 
        depth=3,
        num_heads=8,
        dropout=0.1, 
        ff_dropout=0.1, 
        use_gx=False
    ):
        super(Decoder, self).__init__()
        self.vocab_size = vocab_size
        self.num_heads = num_heads
        self.layers = nn.ModuleList([])
        self.embed_dim = embed_dim
        self.use_gx = use_gx
        for i in range(depth):
            sublayer = DecoderLayer( 
                embed_dim = embed_dim, 
                num_heads = num_heads, 
                dropout = dropout, 
                ff_dropout = ff_dropout,
                use_gx = use_gx
            )
            self.layers.append(sublayer)

        self.dropout = nn.Dropout(0.0)

        self.word_embed = nn.Embedding(self.vocab_size, self.embed_dim)
        self.embed_scale = math.sqrt(self.embed_dim)
        # Fixed table for 100 positions; row 0 is zeroed and positions start at 1.
        self.pos_embed = nn.Embedding.from_pretrained(
            sinusoid_encoding_table(100, self.embed_dim, 0), freeze=True
        )

        self.generator = nn.Linear(self.embed_dim, self.vocab_size, bias=True)

        self.clear_buffer()

    def init_buffer(self, batch_size):
        """Switch to incremental decoding and reset the running position."""
        self.seq_len = 0
        for layer in self.layers:
            layer.init_buffer(batch_size)

    def clear_buffer(self):
        """Switch back to full-sequence (training) mode."""
        self.seq_len = None
        for layer in self.layers:
            layer.clear_buffer()

    def apply_to_states(self, fn):
        """Apply ``fn`` to every layer's self-attention buffers."""
        for layer in self.layers:
            layer.apply_to_states(fn)

    def precompute(self, encoder_out):
        # NOTE: DecoderLayer.precompute in this file is a stub returning None,
        # so the tuple unpacking below raises TypeError if this is called.
        p_att_feats = []
        for layer in self.layers:
            key, value2 = layer.precompute(encoder_out)
            p_att_feats.append((key, value2))
        return p_att_feats

    def forward(self, gx, seq, encoder_out, seq_mask=None, att_mask=None):
        """Return vocabulary logits ``[B, L, vocab_size]`` for token ids ``seq`` ``[B, L]``.

        ``att_mask`` (``[B, M]``), if given, is expanded to ``[B, 1, M]`` for
        cross-attention.
        """
        if att_mask is not None:
            att_mask = att_mask.unsqueeze(1)  # [B, 1, M]

        # Build position indices on the input's device (was hard-coded 'cuda').
        seq_len = seq.size()[1]
        pos_indx = torch.arange(1, seq_len + 1, device=seq.device).view(1, -1)
        if self.seq_len is not None:
            # Incremental mode: a single position equal to the running count.
            seq_len = self.seq_len + seq_len
            self.seq_len = seq_len
            pos_indx = torch.arange(seq_len, seq_len + 1, device=seq.device).view(1, -1)

        # Word + position embeddings:
        # [B, seq_len, C] for training or [B, 1, C] for inference.
        x = self.embed_scale * self.word_embed(seq) + self.pos_embed(pos_indx)

        for layer in self.layers:
            x = layer(gx, x, encoder_out, seq_mask, att_mask)

        x = self.dropout(x)
        out = self.generator(x)
        return out



##############################
#
# PureT model code
# from  https://github.com/232525/PureT/blob/5581b5d10ae3bb9f9c859f6644e90db8beaf992b/models/basic_model.py
# and   https://github.com/232525/PureT/blob/5581b5d10ae3bb9f9c859f6644e90db8beaf992b/models/pure_transformer.py
#
##############################
class BasicModel(nn.Module):
    """Base captioning model providing (diverse) beam search decoding.

    Subclasses must implement ``get_logprobs_state(**kwargs)``. It receives the
    current step's words in ``kwargs['WT']`` and the decoding state in
    ``kwargs['STATE']``, and returns ``(logprobs, new_state)``. The decoding
    length comes from the module-level ``max_length`` setting.
    """

    def __init__(self):
        super(BasicModel, self).__init__()

    def select(self, batch_size, beam_size, t, candidate_logprob):
        """Return ``(indices, logprobs)`` of the ``beam_size`` best candidates per batch item.

        ``candidate_logprob`` is flattened to ``[batch_size, -1]`` and sorted in
        descending order. ``t`` is unused.
        """
        selected_logprob, selected_idx = torch.sort(candidate_logprob.view(batch_size, -1), -1, descending=True)
        selected_logprob, selected_idx = selected_logprob[:, :beam_size], selected_idx[:, :beam_size]
        return selected_idx, selected_logprob

    def beam_search(self, init_state, init_logprobs, **kwargs):
        """Run (group-diverse) beam search and return the finished beams.

        Args:
            init_state: list of state tensors. Beams are indexed along dim 1
                of each tensor after stacking and splitting into groups.
            init_logprobs: ``[beam_size, vocab]`` log-probs for the first step.
            **kwargs: must contain ``BEAM_SIZE``. Passed through (with ``WT``
                and ``STATE`` added) to ``get_logprobs_state``.

        Returns:
            List of dicts with keys ``seq``, ``logps``, ``unaug_p`` and ``p``,
            best first within each group. Token id 0 is treated as the end
            token, and the last vocabulary index is penalised by 1000.
        """
        # Tokens fed back into the model must be on the same device as its
        # outputs. Derive it from the log-probs instead of hard-coding CUDA,
        # so decoding also works on CPU.
        device = init_logprobs.device

        # Penalise words already chosen by earlier groups at the same time step.
        # Modifies logprobsf in place and returns an unmodified copy.
        def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda, bdash):
            local_time = t - divm
            unaug_logprobsf = logprobsf.clone()
            for prev_choice in range(divm):
                prev_decisions = beam_seq_table[prev_choice][local_time]
                for sub_beam in range(bdash):
                    for prev_labels in range(bdash):
                        logprobsf[sub_beam][prev_decisions[prev_labels]] = logprobsf[sub_beam][prev_decisions[prev_labels]] - diversity_lambda
            return unaug_logprobsf

        def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state):
            # INPUTS:
            # logprobsf: log-probabilities after the diversity penalty
            # beam_size: number of beams kept
            # t        : time step
            # beam_seq : tensor containing the beams
            # beam_seq_logprobs: tensor containing the beam logprobs
            # beam_logprobs_sum: tensor containing joint logprobs
            # OUTPUTS:
            # beam_seq : tensor containing the word indices of the decoded captions
            # beam_seq_logprobs : log-probability of each decision made, same size as beam_seq
            # beam_logprobs_sum : joint log-probability of each beam

            ys,ix = torch.sort(logprobsf,1,True)
            candidates = []
            cols = min(beam_size, ys.size(1))
            rows = beam_size
            if t == 0:
                # all beams are identical at the first step; expand only one
                rows = 1
            for c in range(cols): # for each column (word, essentially)
                for q in range(rows): # for each beam expansion
                    # logprob of expanding beam q with the word in sorted position c
                    local_logprob = ys[q,c].item()
                    candidate_logprob = beam_logprobs_sum[q] + local_logprob
                    local_unaug_logprob = unaug_logprobsf[q,ix[q,c]]
                    candidates.append({'c':ix[q,c], 'q':q, 'p':candidate_logprob, 'r':local_unaug_logprob})
            candidates = sorted(candidates,  key=lambda x: -x['p'])

            new_state = [_.clone() for _ in state]
            # beam_seq_prev, beam_seq_logprobs_prev
            if t >= 1:
                # needed as reference when forking beams around
                beam_seq_prev = beam_seq[:t].clone()
                beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone()
            for vix in range(beam_size):
                v = candidates[vix]
                # fork beam index q into index vix
                if t >= 1:
                    beam_seq[:t, vix] = beam_seq_prev[:, v['q']]
                    beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:, v['q']]
                # rearrange states: copy beam q's state to new beam vix (beam axis is dim 1)
                for state_ix in range(len(new_state)):
                    new_state[state_ix][:, vix] = state[state_ix][:, v['q']]
                # append the chosen word to this beam
                beam_seq[t, vix] = v['c'] # c'th word is the continuation
                beam_seq_logprobs[t, vix] = v['r'] # the raw logprob here
                beam_logprobs_sum[vix] = v['p'] # the new (sum) logprob along this beam
            state = new_state
            return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state,candidates

        beam_size = kwargs['BEAM_SIZE']
        group_size = 1 #kwargs['GROUP_SIZE']
        diversity_lambda = 0.5 #kwargs['DIVERSITY_LAMBDA']
        constraint = False #kwargs['CONSTRAINT']
        max_ppl = False #kwargs['MAX_PPL']
        bdash = beam_size // group_size

        beam_seq_table = [torch.LongTensor(max_length, bdash).zero_() for _ in range(group_size)]
        beam_seq_logprobs_table = [torch.FloatTensor(max_length, bdash).zero_() for _ in range(group_size)]
        beam_logprobs_sum_table = [torch.zeros(bdash) for _ in range(group_size)]

        # logprobs predicted in the last time step, shape (beam_size, vocab_size)
        done_beams_table = [[] for _ in range(group_size)]
        state_table = [list(torch.unbind(_)) for _ in torch.stack(init_state).chunk(group_size, 2)]
        logprobs_table = list(init_logprobs.chunk(group_size, 0))
        # END INIT

        for t in range(max_length + group_size - 1):
            for divm in range(group_size): 
                if t >= divm and t <= max_length + divm - 1:
                    # add diversity
                    logprobsf = logprobs_table[divm].data.float()
                    # suppress previous word
                    if constraint and t-divm > 0:
                        prev_words = beam_seq_table[divm][t-divm-1].unsqueeze(1).to(logprobsf.device)
                        logprobsf.scatter_(1, prev_words, float('-inf'))
                    # suppress UNK tokens (last vocabulary index) in the decoding
                    logprobsf[:,logprobsf.size(1)-1] -= 1000  
                    # add_diversity modifies logprobsf in place and returns the
                    # unaugmented values, used for the recorded per-step logprobs
                    unaug_logprobsf = add_diversity(beam_seq_table,logprobsf,t,divm,diversity_lambda,bdash)

                    # infer new beams
                    beam_seq_table[divm],\
                    beam_seq_logprobs_table[divm],\
                    beam_logprobs_sum_table[divm],\
                    state_table[divm],\
                    candidates_divm = beam_step(logprobsf,
                                                unaug_logprobsf,
                                                bdash,
                                                t-divm,
                                                beam_seq_table[divm],
                                                beam_seq_logprobs_table[divm],
                                                beam_logprobs_sum_table[divm],
                                                state_table[divm])

                    # if time is up or the end token (0) was produced, copy the beam out
                    for vix in range(bdash):
                        if beam_seq_table[divm][t-divm,vix] == 0 or t == max_length + divm - 1:
                            final_beam = {
                                'seq': beam_seq_table[divm][:, vix].clone(), 
                                'logps': beam_seq_logprobs_table[divm][:, vix].clone(),
                                'unaug_p': beam_seq_logprobs_table[divm][:, vix].sum().item(),
                                'p': beam_logprobs_sum_table[divm][vix].item()
                            }
                            if max_ppl:
                                final_beam['p'] = final_beam['p'] / (t-divm+1)
                            done_beams_table[divm].append(final_beam)
                            # don't continue beams from finished sequences
                            beam_logprobs_sum_table[divm][vix] = -1000

                    # move the current group one step forward in time
                    wt = beam_seq_table[divm][t-divm]
                    kwargs['WT'] = wt.to(device)
                    kwargs['STATE'] = state_table[divm]
                    logprobs_table[divm], state_table[divm] = self.get_logprobs_state(**kwargs)

        # keep the best bdash finished beams per group, sorted by log-probability
        done_beams_table = [sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)]
        # flatten the groups (replaces functools.reduce concatenation)
        done_beams = [beam for group in done_beams_table for beam in group]
        return done_beams


def subsequent_mask(size):
    "Mask out subsequent positions."
    # Boolean [1, size, size]: True where key position <= query position.
    causal = torch.ones((1, size, size)).tril()
    return causal == 1


def expand_tensor(tensor, size, dim=1):
    """Repeat each slice along axis ``dim - 1`` ``size`` times, consecutively.

    For the default ``dim=1`` a ``[B, ...]`` tensor becomes ``[B*size, ...]``
    with rows ``b0, b0, ..., b1, b1, ...``. Returns the input unchanged when
    ``size == 1`` or ``tensor`` is None.
    """
    if size == 1 or tensor is None:
        return tensor
    unsqueezed = tensor.unsqueeze(dim)
    head = list(unsqueezed.shape[:dim])
    tail = list(unsqueezed.shape[dim + 1:])
    tiled = unsqueezed.expand(head + [size] + tail).contiguous()
    merged_shape = list(tiled.shape[:dim - 1]) + [-1] + list(tiled.shape[dim + 1:])
    return tiled.view(merged_shape)


class PureT(BasicModel):
    """PureT captioning model: frozen Swin backbone + encoder + decoder.

    Images pass through ``self.backbone`` (a frozen SwinTransformer). They are
    projected from ``att_feats_dim`` to ``embedded_dim_size`` by
    ``self.att_embed`` and encoded by ``self.encoder``, which returns ``gx``
    (passed on as ``GV_FEAT``) and the encoded features. Finally they are
    decoded by ``self.decoder``. All sizes come from the notebook's
    hyper-parameter globals.

    Token id 0 plays two roles. It is the first token fed when decoding starts,
    and a sequence counts as finished once it emits 0.

    Tensors created while decoding are placed on the device of the model's
    outputs, so the model runs on CPU as well as CUDA.
    """

    def __init__(self):
        super(PureT, self).__init__()
        # One extra id on top of the global tokenizer vocabulary size.
        self.vocab_size = vocab_size + 1

        # Backbone configuration; should match the
        # swin_large_patch4_window12_384 checkpoint loaded below.
        self.backbone = SwinTransformer(
            img_size=384,
            embed_dim=192,
            depths=[2, 2, 18, 2],
            num_heads=[6, 12, 24, 48],
            window_size=12,
            num_classes=1000
        )
        print('load pretrained weights!')
        self.backbone.load_weights(
            './swin_large_patch4_window12_384_22kto1k_no_head.pth'
        )
        # Freeze the backbone: it receives no gradient updates.
        for _name, _weight in self.backbone.named_parameters():
            _weight.requires_grad = False

        # Project raw backbone features to the model width.
        # NOTE(review): att_feats_dim must equal the channel size the backbone
        # actually returns. For a Swin with embed_dim=192 and four stages the
        # last stage is usually 192 * 8 = 1536 wide, while the hyper-parameter
        # cell sets 2048. Confirm.
        use_att_norm = False  # optional LayerNorm after the projection (disabled)
        if att_feats_dim == embedded_dim_size:
            self.att_embed = nn.Identity()
        else:
            self.att_embed = nn.Sequential(
                nn.Linear(att_feats_dim, embedded_dim_size),
                nn.ReLU(inplace=True),
                nn.LayerNorm(embedded_dim_size) if use_att_norm else nn.Identity(),
                nn.Dropout(0.0)
            )

        use_gx = True
        self.encoder = Encoder(
            embed_dim=embedded_dim_size,
            input_resolution=(12, 12),
            depth=encoder_decoder_layers,
            num_heads=encoder_decoder_heads,
            window_size=6,
            shift_size=3,
            mlp_ratio=4,
            dropout=0.1,
            use_gx=use_gx
        )

        self.decoder = Decoder(
            vocab_size=self.vocab_size,
            embed_dim=embedded_dim_size,
            depth=encoder_decoder_layers,
            num_heads=encoder_decoder_heads,
            dropout=0.0,
            ff_dropout=0.1,
            use_gx=use_gx
        )

    def forward(self, **kwargs):
        """Teacher-forced forward pass.

        Keyword Args:
            ATT_FEATS: image batch fed to the backbone.
            INPUT_SENT: input token ids. Id 0 is padding, except that column 0
                is always attended to.
            ATT_FEATS_MASK: feature mask for the encoder/decoder, or None.
            SEQ_PER_IMG: number of consecutive INPUT_SENT rows per image
                (default 5). Image features and mask are repeated this many
                times so they line up with the captions.

        Returns:
            ``log_softmax`` of the decoder output over its last dimension.
        """
        att_feats = kwargs['ATT_FEATS']
        seq = kwargs['INPUT_SENT']
        seq_per_img = kwargs.get('SEQ_PER_IMG', 5)

        # backbone forward
        att_feats = self.backbone(att_feats)

        # Repeat features and mask once per caption of the same image.
        att_mask = expand_tensor(kwargs['ATT_FEATS_MASK'], seq_per_img)
        att_feats = expand_tensor(att_feats, seq_per_img)

        # words mask [B, L, L]: position i sees non-padding tokens j <= i.
        # Column 0 is forced visible. Assigning 1 (rather than adding 1)
        # keeps it visible even when column 0 holds a non-zero id: 2 & 1 == 0
        # would otherwise mask it out.
        seq_mask = (seq > 0).int()
        seq_mask[:, 0] = 1
        seq_mask = seq_mask.unsqueeze(-2)
        seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask)
        seq_mask = seq_mask.float()

        att_feats = self.att_embed(att_feats)
        gx, encoder_out = self.encoder(att_feats, att_mask)
        decoder_out = self.decoder(gx, seq, encoder_out, seq_mask, att_mask)
        return F.log_softmax(decoder_out, dim=-1)

    def get_logprobs_state(self, **kwargs):
        """Run a single decoding step.

        Keyword Args:
            WT: token ids fed at this step, shape [B].
            STATE: None on the first step, otherwise the state returned by the
                previous call. ``STATE[0][0]`` holds the tokens fed so far.
            ATT_FEATS: encoder output.
            ATT_FEATS_MASK: feature mask, or None.
            GV_FEAT: ``gx`` returned by the encoder.

        Returns:
            ``(logprobs, [ys.unsqueeze(0)])``. ``logprobs`` is the log_softmax
            of the decoder output for the newest token, and ``ys`` ([B, steps])
            holds every token fed so far.
        """
        wt = kwargs['WT']
        state = kwargs['STATE']
        encoder_out = kwargs['ATT_FEATS']
        att_mask = kwargs['ATT_FEATS_MASK']
        gx = kwargs['GV_FEAT']

        # ys: [B, seq_len] -- previously generated words plus the current one
        if state is None:
            ys = wt.unsqueeze(1)
        else:
            ys = torch.cat([state[0][0], wt.unsqueeze(1)], dim=1)

        # Causal-mask row of the newest position only: [1, 1, seq_len].
        seq_mask = subsequent_mask(ys.size(1)).to(
            device=encoder_out.device, dtype=torch.float32)[:, -1, :].unsqueeze(1)

        # Only the newest token is given to the decoder.
        # NOTE(review): earlier steps are presumably served from the decoder's
        # buffer (init_buffer / clear_buffer); confirm.
        # [B, 1, Vocab_Size] --> [B, Vocab_Size]
        decoder_out = self.decoder(gx, ys[:, -1].unsqueeze(-1), encoder_out, seq_mask, att_mask).squeeze(1)

        logprobs = F.log_softmax(decoder_out, dim=-1)
        return logprobs, [ys.unsqueeze(0)]

    def _expand_state(self, batch_size, beam_size, cur_beam_size, selected_beam):
        """Return ``fn(s)`` that reorders a cached state tensor by beam.

        ``fn`` views ``s`` as [batch_size, cur_beam_size, ...] and gathers the
        beams listed in ``selected_beam`` ([batch_size, beam_size]). It then
        flattens the result back to [batch_size * beam_size, ...].
        """
        def fn(s):
            shape = [int(sh) for sh in s.shape]
            beam = selected_beam
            for _ in shape[1:]:
                beam = beam.unsqueeze(-1)
            s = torch.gather(s.view(*([batch_size, cur_beam_size] + shape[1:])), 1,
                             beam.expand(*([batch_size, beam_size] + shape[1:])))
            s = s.view(*([-1, ] + shape[1:]))
            return s
        return fn

    # the beam search code is inspired by https://github.com/aimagelab/meshed-memory-transformer
    def decode_beam(self, **kwargs):
        """Beam-search decoding.

        Keyword Args:
            ATT_FEATS: image batch fed to the backbone.
            ATT_FEATS_MASK: feature mask, or None.
            BEAM_SIZE: number of beams kept per image.

        Returns:
            ``(outputs, log_probs)``, each [B, max_length]. They hold the token
            ids of the highest-scoring beam and the per-token log-probabilities
            along it.
        """
        att_feats = kwargs['ATT_FEATS']
        att_mask = kwargs['ATT_FEATS_MASK']
        beam_size = kwargs['BEAM_SIZE']
        batch_size = att_feats.size(0)

        att_feats = self.backbone(att_feats)
        att_feats = self.att_embed(att_feats)
        gx, encoder_out = self.encoder(att_feats, att_mask)
        # Create decoding tensors where the model computes, instead of
        # hard-coding CUDA.
        dev = encoder_out.device

        seq_logprob = torch.zeros((batch_size, 1, 1), device=dev)
        log_probs = []
        selected_words = None
        seq_mask = torch.ones((batch_size, beam_size, 1), device=dev)

        state = None
        wt = torch.zeros(batch_size, dtype=torch.long, device=dev)
        kwargs['ATT_FEATS'] = encoder_out
        kwargs['GV_FEAT'] = gx

        outputs = []
        self.decoder.init_buffer(batch_size)
        for t in range(max_length):
            cur_beam_size = 1 if t == 0 else beam_size

            kwargs['WT'] = wt
            kwargs['STATE'] = state
            word_logprob, state = self.get_logprobs_state(**kwargs)
            # [B*cur_beam_size, Vocab_size] --> [B, cur_beam_size, Vocab_size]
            word_logprob = word_logprob.view(batch_size, cur_beam_size, -1)
            # cumulative log-prob of every (beam, next word) candidate
            candidate_logprob = seq_logprob + word_logprob

            # Beams whose last word was 0 (end) are frozen. Only their
            # "emit 0 again" candidate keeps the old score; all others get -999.
            if t > 0:
                mask = (selected_words.view(batch_size, cur_beam_size) != 0).float().unsqueeze(-1)
                seq_mask = seq_mask * mask
                word_logprob = word_logprob * seq_mask.expand_as(word_logprob)
                old_seq_logprob = seq_logprob.expand_as(candidate_logprob).contiguous()
                old_seq_logprob[:, :, 1:] = -999
                candidate_logprob = seq_mask * candidate_logprob + old_seq_logprob * (1 - seq_mask)

            # [B, beam_size], [B, beam_size]
            selected_idx, selected_logprob = self.select(batch_size, beam_size, t, candidate_logprob)
            # Flat candidate index -> (beam, word); explicit floor division.
            vocab_dim = candidate_logprob.shape[-1]
            selected_beam = torch.div(selected_idx, vocab_dim, rounding_mode='floor')
            selected_words = selected_idx - selected_beam * vocab_dim

            # Reorder decoder buffers, scores, masks and history by chosen beam.
            self.decoder.apply_to_states(self._expand_state(batch_size, beam_size, cur_beam_size, selected_beam))
            seq_logprob = selected_logprob.unsqueeze(-1)
            seq_mask = torch.gather(seq_mask, 1, selected_beam.unsqueeze(-1))
            outputs = [torch.gather(o, 1, selected_beam.unsqueeze(-1)) for o in outputs]
            outputs.append(selected_words.unsqueeze(-1))

            this_word_logprob = torch.gather(word_logprob, 1,
                selected_beam.unsqueeze(-1).expand(batch_size, beam_size, word_logprob.shape[-1]))
            this_word_logprob = torch.gather(this_word_logprob, 2, selected_words.unsqueeze(-1))
            log_probs = [
                torch.gather(o, 1, selected_beam.unsqueeze(-1).expand(batch_size, beam_size, 1))
                for o in log_probs
            ]
            log_probs.append(this_word_logprob)
            selected_words = selected_words.view(-1, 1)
            wt = selected_words.squeeze(-1)

            if t == 0:
                # After the first step every image carries beam_size beams,
                # so repeat the per-image inputs and the token history.
                encoder_out = expand_tensor(encoder_out, beam_size)
                gx = expand_tensor(gx, beam_size)
                att_mask = expand_tensor(att_mask, beam_size)
                state[0] = state[0].squeeze(0)
                state[0] = expand_tensor(state[0], beam_size)
                state[0] = state[0].unsqueeze(0)

                kwargs['ATT_FEATS'] = encoder_out
                kwargs['GV_FEAT'] = gx
                kwargs['ATT_FEATS_MASK'] = att_mask

        # Keep the beam with the highest cumulative log-probability.
        seq_logprob, sort_idxs = torch.sort(seq_logprob, 1, descending=True)
        outputs = torch.cat(outputs, -1)
        outputs = torch.gather(outputs, 1, sort_idxs.expand(batch_size, beam_size, max_length))
        log_probs = torch.cat(log_probs, -1)
        log_probs = torch.gather(log_probs, 1, sort_idxs.expand(batch_size, beam_size, max_length))

        outputs = outputs.contiguous()[:, 0]
        log_probs = log_probs.contiguous()[:, 0]

        self.decoder.clear_buffer()
        return outputs, log_probs

    def decode(self, **kwargs):
        """Greedy or sampled decoding, one token per step.

        Keyword Args:
            ATT_FEATS: image batch fed to the backbone.
            ATT_FEATS_MASK: feature mask, or None.
            GREEDY_DECODE: True for argmax decoding, False to sample from the
                predicted distribution.
            BEAM_SIZE: accepted for interface compatibility; unused here.

        Returns:
            ``(sents, logprobs)``, both [B, max_length]. Once a sequence emits
            0, every later position of it in ``sents`` is 0.
        """
        greedy_decode = kwargs['GREEDY_DECODE']
        att_feats = kwargs['ATT_FEATS']
        att_mask = kwargs['ATT_FEATS_MASK']

        batch_size = att_feats.size(0)
        att_feats = self.backbone(att_feats)
        att_feats = self.att_embed(att_feats)
        gx, encoder_out = self.encoder(att_feats, att_mask)
        self.decoder.init_buffer(batch_size)
        # Create decoding tensors where the model computes, instead of
        # hard-coding CUDA.
        dev = encoder_out.device

        state = None
        sents = torch.zeros((batch_size, max_length), dtype=torch.long, device=dev)
        logprobs = torch.zeros(batch_size, max_length, device=dev)
        wt = torch.zeros(batch_size, dtype=torch.long, device=dev)  # start token 0
        unfinished = torch.ones_like(wt, dtype=torch.bool)
        kwargs['ATT_FEATS'] = encoder_out
        kwargs['GV_FEAT'] = gx

        # inference word by word
        for t in range(max_length):
            kwargs['WT'] = wt
            kwargs['STATE'] = state
            logprobs_t, state = self.get_logprobs_state(**kwargs)

            if greedy_decode:
                logP_t, wt = torch.max(logprobs_t, 1)
            else:
                probs_t = torch.exp(logprobs_t)
                wt = torch.multinomial(probs_t, 1)
                logP_t = logprobs_t.gather(1, wt)
            wt = wt.view(-1).long()
            # A sequence finishes at its first 0 and emits 0 from then on.
            unfinished = unfinished & (wt > 0)
            wt = wt * unfinished.type_as(wt)
            sents[:, t] = wt
            logprobs[:, t] = logP_t.view(-1)

            if unfinished.sum() == 0:
                break
        self.decoder.clear_buffer()
        return sents, logprobs

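# Example data-loader batch shape: torch.Size([32, 3, 224, 224]).
# NOTE: PureT's Swin backbone above is built with img_size=384 -- confirm the input resolution matches.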


## Training loop
Trains the model and saves the best (lowest validation loss) and the last model.

This section does the following actions:
1. Creates the PureT model
2. Sets up the optimizer, the learning-rate scheduler and the counters used during training
3. Trains for num_epochs epochs
4. Computes the validation accuracy after each epoch (TODO)
5. Saves the model with the best validation accuracy (TODO)
6. Saves the model once the maximum number of epochs has been reached (TODO)

In [33]:
import torch
import torch.nn.functional as F
import numpy as np
import os
import argparse

from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
from transformers import VisionEncoderDecoderModel
from torchvision import transforms


# PureT.forward repeats every image's features 5 times by default and
# expects that many consecutive caption rows per image.
seq_per_img = 5


def caption_loss(model, pixel_vals, captions):
    """Return the teacher-forced NLL loss of ``captions`` given ``pixel_vals``.

    The model reads tokens [0, L-1) and is scored on tokens [1, L). PureT
    treats id 0 as input padding (its ``seq > 0`` mask) and as the end token
    in ``decode()``. So the first 0 after a caption is kept as a target (the
    model must learn to stop) and later padding positions are ignored.

    NOTE(review): assumes ``captions`` is a [B, L] LongTensor with one
    0-padded caption per image. Confirm against the preprocessing cell. Each
    caption is repeated ``seq_per_img`` times to line up with the repeated
    image features, which wastes compute when there is a single caption per
    image.
    """
    captions = captions.repeat_interleave(seq_per_img, dim=0)
    input_sent = captions[:, :-1]
    target_sent = captions[:, 1:].clone()

    # Ignore targets whose input position is padding. Column 0 is always
    # scored, mirroring the model's own mask, which always attends to column 0.
    ignore = input_sent <= 0
    ignore[:, 0] = False
    target_sent[ignore] = -1

    logprobs = model(
        ATT_FEATS=pixel_vals,
        INPUT_SENT=input_sent,
        # NOTE(review): no feature mask; assumes the encoder/decoder accept
        # None (expand_tensor passes None through unchanged).
        ATT_FEATS_MASK=None,
    )
    return F.nll_loss(
        logprobs.reshape(-1, logprobs.size(-1)),
        target_sent.reshape(-1),
        ignore_index=-1,
    )


# Model setup
# num_heads and window_size are from Table 6 of the paper https://arxiv.org/pdf/2203.15350
# NOTE(review): vocab_size (30522) is BERT's vocabulary size, while the data
# cell imports GPT2TokenizerFast. Confirm the caption token ids fit.
model = PureT().to(device)

# Only trainable parameters go to the optimizer (the Swin backbone is frozen).
optimizer = AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=learning_rate,
    weight_decay=weight_decay,
)
scheduler = ReduceLROnPlateau(optimizer, patience=patience)

# Values to remember training performance
stop_counter = 0  # epochs since the validation loss last improved
train_losses = []
best_val_loss = float('inf')
val_losses = []

# Training loop
for epoch in range(num_epochs):
    # --- training ---
    model.train()
    train_loss = 0.0

    train_dataloader_iter = tqdm(train_dataset_loader, desc=f'Epoch {epoch + 1}/{num_epochs}', leave=False)
    for data in train_dataloader_iter:
        pixel_vals = data["pixel_values"].to(device)
        captions = data["labels"].to(device)

        optimizer.zero_grad()
        loss = caption_loss(model, pixel_vals, captions)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_dataloader_iter.set_postfix(loss=f"{loss.item():.4f}")

    train_losses.append(train_loss / max(len(train_dataset_loader), 1))

    # --- validation ---
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for data in valid_dataset_loader:
            pixel_vals = data["pixel_values"].to(device)
            captions = data["labels"].to(device)
            val_loss += caption_loss(model, pixel_vals, captions).item()
    val_loss /= max(len(valid_dataset_loader), 1)
    val_losses.append(val_loss)
    scheduler.step(val_loss)

    # Keep the model with the lowest validation loss.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        stop_counter = 0
        torch.save(model.state_dict(), "best_model.pt")
    else:
        stop_counter += 1

    print(f"Epoch {epoch + 1}/{num_epochs}: train loss {train_losses[-1]:.4f}, "
          f"val loss {val_loss:.4f}, epochs without improvement {stop_counter}")

# Save the model reached after the last epoch.
torch.save(model.state_dict(), "last_model.pt")


load pretrained weights!


RuntimeError: Trying to create tensor with negative dimension -1: [-1, -1]

## Post Training Metrics
TODO (all steps)
Evaluates the best model with BLEU, ROUGE, and SPICE.

This section does the following actions:
1. Loads the model with the highest validation accuracy
2. Predicts captions with the best model
3. Calculates the ROUGE score
4. Calculates the BLEU score
5. Calculates the SPICE score

In [29]:
import evaluate
from transformers import EvalPrediction

# Generate captions for the validation set with the model currently in memory.
# NOTE(review): loading the best checkpoint is still TODO.
model.eval()
predictions = []
labels = []  # reference token ids; kept separate from the per-batch tensor
with torch.no_grad():
    for data in valid_dataset_loader:
        # get data from batch
        pixel_vals = data["pixel_values"].to(device)
        batch_labels = data["labels"]

        # Free-running greedy decoding: the references are never fed in.
        sents, _ = model.decode(
            ATT_FEATS=pixel_vals,
            # NOTE(review): assumes the encoder/decoder accept a None mask.
            ATT_FEATS_MASK=None,
            BEAM_SIZE=1,
            GREEDY_DECODE=True,
        )

        # PureT uses id 0 as padding / end of sentence. Drop it (and any
        # negative ignore-index values) before the ids reach the tokenizer.
        # NOTE(review): assumes labels are [B, L] token ids -- confirm.
        predictions.extend([tok for tok in sent if tok > 0] for sent in sents.tolist())
        labels.extend([tok for tok in ref if tok > 0] for ref in batch_labels.tolist())

# Decode predicted and reference token ids back into text
predictions_str = tokenizer.batch_decode(predictions, skip_special_tokens=True)
labels_str = tokenizer.batch_decode(labels, skip_special_tokens=True)


# Load test evaluators
rouge = evaluate.load("rouge")
bleu = evaluate.load("bleu")

# Compute and print ROUGE-1, ROUGE-2, ROUGE-L (scores are numbers, so format
# them instead of concatenating them to strings)
rouge_result = rouge.compute(predictions=predictions_str, references=labels_str)
rouge_result = {k: round(v * 100, 4) for k, v in rouge_result.items()}
print(f"ROUGE Metrics: \nROUGE-1: {rouge_result.get('rouge1', 0)}"
      f"\nROUGE-2: {rouge_result.get('rouge2', 0)}"
      f"\nROUGE-L: {rouge_result.get('rougeL', 0)}")


# Compute and print BLEU metrics
bleu_result = bleu.compute(predictions=predictions_str, references=labels_str)
bleu_score = round(bleu_result["bleu"] * 100, 4)
print(f"BLEU Metrics: {bleu_score}")

NameError: name 'model' is not defined