In [1]:
import torch
from torch import nn
from torch.nn import functional as F



def pad_image_to_divisible(image, p):
    """Zero-pad the last two dimensions of ``image`` up to multiples of ``p``.

    Padding is added only at the end of each dimension (bottom rows, right
    columns), so the original values stay anchored at index (0, 0) and can be
    recovered by slicing ``[..., :h, :w]``.

    Args:
        image: tensor with at least two dimensions; the last two are treated
            as (height, width). Typically (batch, channels, height, width).
        p: positive integer patch size.

    Returns:
        Tensor of the same rank whose last two sizes are the smallest
        multiples of ``p`` that are >= the original sizes.

    Raises:
        ValueError: if ``p`` is not positive or ``image`` has fewer than two dims.
    """
    if p <= 0:
        raise ValueError(f"patch size p must be positive, got {p}")
    if image.dim() < 2:
        raise ValueError(
            f"image must have at least 2 dimensions, got shape {tuple(image.shape)}"
        )

    # Only the trailing (height, width) dims are padded, so any leading
    # batch/channel dims are allowed.
    h, w = image.shape[-2:]

    # Amount needed to reach the next multiple of p (0 when already divisible).
    pad_h = (p - h % p) % p
    pad_w = (p - w % p) % p

    # F.pad reads the padding tuple from the LAST dim backwards:
    # (width_left, width_right, height_top, height_bottom).
    padding = (0, pad_w, 0, pad_h)
    return F.pad(image, padding, mode='constant', value=0)




class MLP(nn.Module):
    """Feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    The output's last dimension equals the input's, so the result can be added
    back onto the input as a residual.

    Args:
        num_features: size of the last (input and output) dimension.
        num_hidden: width of the hidden layer.
        dropout: dropout probability applied after the activation and after
            the output projection.
    """

    def __init__(self, num_features, num_hidden, dropout):
        super().__init__()
        in_dim, hid_dim = num_features, num_hidden
        # Submodules are registered in the order fc1, dropout1, fc2, dropout2;
        # this order determines parameter initialisation and state_dict keys.
        self.fc1 = nn.Linear(in_dim, hid_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hid_dim, in_dim)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.dropout1(F.gelu(self.fc1(x)))
        return self.dropout2(self.fc2(hidden))





class TokenMixer(nn.Module):
    """Pre-norm residual block that mixes information across patches (tokens).

    LayerNorm normalises each patch embedding. The tensor is then transposed
    so the MLP runs along the patch axis, and the result is added back to the
    un-normalised input.

    Args:
        num_patches: number of patches per sample (the MLP's in/out width).
        embedding_dim: size of each patch embedding (normalised by LayerNorm).
        patch_dim: hidden width of the token-mixing MLP.
        dropout: dropout probability inside the MLP.

    Expects x of shape (batch, num_patches, embedding_dim); returns the same shape.
    """

    def __init__(self, num_patches, embedding_dim, patch_dim, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(embedding_dim)
        self.mlp = MLP(num_patches, patch_dim, dropout)

    def forward(self, x):
        # (batch, num_patches, embedding_dim) -> (batch, embedding_dim, num_patches)
        mixed = self.norm(x).transpose(1, 2)
        # Mix along patches, then restore (batch, num_patches, embedding_dim).
        mixed = self.mlp(mixed).transpose(1, 2)
        return mixed + x


class ChannelMixer(nn.Module):
    """Pre-norm residual block that mixes features within each patch.

    Args:
        embedding_dim: size of each patch embedding (LayerNorm and MLP width).
        filter_dim: hidden width of the channel-mixing MLP.
        dropout: dropout probability inside the MLP.

    Expects x of shape (batch_size, num_patches, embedding_dim); returns the
    same shape.
    """

    def __init__(self, embedding_dim, filter_dim, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(embedding_dim)
        self.mlp = MLP(embedding_dim, filter_dim, dropout)

    def forward(self, x):
        # The MLP acts on the last (feature) axis, so no transpose is needed.
        update = self.mlp(self.norm(x))
        return update + x



class MixerLayer(nn.Module):
    """One Mixer block: token mixing followed by channel mixing.

    Args:
        num_patches: number of patches (tokens) per sample.
        embedding_dim: size of each patch embedding.
        patch_dim: hidden width of the token-mixing MLP.
        filter_dim: hidden width of the channel-mixing MLP.
        dropout: dropout probability inside both MLPs.

    Input and output shape: (batch, num_patches, embedding_dim).
    """

    def __init__(self, num_patches, embedding_dim, patch_dim, filter_dim, dropout):
        super().__init__()
        self.token_mixer = TokenMixer(num_patches, embedding_dim, patch_dim, dropout)
        self.channel_mixer = ChannelMixer(embedding_dim, filter_dim, dropout)

    def forward(self, x):
        for sublayer in (self.token_mixer, self.channel_mixer):
            x = sublayer(x)
        return x
    
class Vision_MIXER(nn.Module):
    """MLP-Mixer classifier for 2-D inputs such as (window_length x sensor_channel) windows.

    Pipeline:
    1. Zero-pad the last two dims up to multiples of ``patch_size``.
    2. Embed patches with a strided Conv2d.
    3. Apply ``layer_nr`` MixerLayers.
    4. Average over patches and apply a linear classifier.

    Args:
        input_shape: (batch, in_channels, window_length, sensor_channel_nr).
            The batch entry is ignored. The remaining entries fix the patch
            count and the patch-embedding input channels.
        patch_size: side length of the square, non-overlapping patches.
        number_class: number of output logits.
        embedding_dim: width of each patch embedding.
        patch_dim: hidden width of the token-mixing MLP.
        filter_dim: hidden width of the channel-mixing MLP.
        layer_nr: number of stacked MixerLayers.
        dropout: dropout probability inside every MLP.
    """

    def __init__(
        self,
        input_shape,
        patch_size,
        number_class,
        embedding_dim = 512,
        patch_dim = 256,
        filter_dim = 2048,
        layer_nr = 10,
        dropout = 0.1
    ):
        super().__init__()
        _, in_channels, window_length, sensor_channel_nr = input_shape

        # Round each spatial size up to the next multiple of patch_size, the
        # same size pad_image_to_divisible produces in forward(). Integer
        # arithmetic avoids allocating a dummy tensor and leaves the global RNG
        # untouched, so seeded weight initialisation stays reproducible.
        padded_window_length = -(-window_length // patch_size) * patch_size
        padded_sensor_channel = -(-sensor_channel_nr // patch_size) * patch_size
        self.num_patches = (padded_window_length // patch_size) * (padded_sensor_channel // patch_size)

        self.padded_window_length     = padded_window_length
        self.padded_sensor_channel    = padded_sensor_channel
        self.in_channels              = in_channels
        self.patch_size               = patch_size
        self.number_class             = number_class
        self.embedding_dim            = embedding_dim
        self.patch_dim                = patch_dim
        self.filter_dim               = filter_dim
        self.layer_nr                 = layer_nr

        # A conv with kernel == stride == patch_size is a per-patch linear embedding.
        self.patcher = nn.Conv2d(
            in_channels, self.embedding_dim, kernel_size=self.patch_size, stride=self.patch_size
        )

        self.mixers = nn.Sequential(
            *[
                MixerLayer(self.num_patches, self.embedding_dim, self.patch_dim, self.filter_dim, dropout)
                for _ in range(self.layer_nr)
            ]
        )

        self.classifier = nn.Linear(embedding_dim, number_class)

    def forward(self, x):
        x = pad_image_to_divisible(x, self.patch_size)
        x = self.patcher(x)  # (batch, embedding_dim, grid_h, grid_w)
        batch_size, num_features, _, _ = x.shape
        # Flatten the patch grid row-major into a token sequence.
        x = x.permute(0, 2, 3, 1)
        x = x.view(batch_size, -1, num_features)
        # x.shape == (batch_size, num_patches, embedding_dim)
        if x.shape[1] != self.num_patches:
            raise ValueError(
                f"input yields {x.shape[1]} patches but the model was built for "
                f"{self.num_patches}; check that the input size matches input_shape"
            )

        x = self.mixers(x)
        x = x.mean(dim=1)  # average over patches -> (batch_size, embedding_dim)
        logits = self.classifier(x)
        return logits



In [2]:
# Configuration for a quick smoke test of Vision_MIXER on one random window.
window_length = 128
sensor_channel = 9
input_shape = (1, 1, window_length, sensor_channel)  # (batch, channels, window_length, sensor_channel)
patch_size = 3
number_class = 10
embedding_dim = 512
patch_dim = 256
filter_dim = 2048
layer_nr = 10

torch.manual_seed(0)  # reproducible weights and input

# Pass every hyper-parameter explicitly so the values configured above are the ones used.
model = Vision_MIXER(
    input_shape,
    patch_size,
    number_class,
    embedding_dim=embedding_dim,
    patch_dim=patch_dim,
    filter_dim=filter_dim,
    layer_nr=layer_nr,
)

x = torch.randn(input_shape)
with torch.no_grad():
    y = model(x)
assert y.shape == (input_shape[0], number_class)

In [2]:
from torch import nn
from torch.nn import functional as F


class MLP(nn.Module):
    """Feed-forward block whose hidden width is ``num_features * expansion_factor``.

    Computes Linear -> GELU -> Dropout -> Linear -> Dropout. The last dimension
    of the output equals the input's.

    Args:
        num_features: size of the last (input and output) dimension.
        expansion_factor: multiplier giving the hidden width.
        dropout: dropout probability after the activation and after the output
            projection.
    """

    def __init__(self, num_features, expansion_factor, dropout):
        super().__init__()
        width = num_features * expansion_factor
        self.fc1 = nn.Linear(num_features, width)
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(width, num_features)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        h = F.gelu(self.fc1(x))
        h = self.dropout1(h)
        h = self.fc2(h)
        return self.dropout2(h)


class TokenMixer(nn.Module):
    """Pre-norm residual block that mixes information across patches.

    Args:
        num_features: size of each patch embedding (normalised by LayerNorm).
        num_patches: number of patches per sample (the MLP's in/out width).
        expansion_factor: hidden-width multiplier of the MLP.
        dropout: dropout probability inside the MLP.

    Expects x of shape (batch_size, num_patches, num_features); returns the
    same shape.
    """

    def __init__(self, num_features, num_patches, expansion_factor, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(num_features)
        self.mlp = MLP(num_patches, expansion_factor, dropout)

    def forward(self, x):
        # Put patches on the last axis: (batch_size, num_features, num_patches).
        tokens_last = self.norm(x).transpose(1, 2)
        # Mix along patches and return to (batch_size, num_patches, num_features).
        mixed = self.mlp(tokens_last).transpose(1, 2)
        return mixed + x


class ChannelMixer(nn.Module):
    """Pre-norm residual block that mixes features within each patch.

    Args:
        num_features: size of each patch embedding (LayerNorm and MLP width).
        num_patches: accepted but unused here; presumably kept so the
            signature mirrors TokenMixer.
        expansion_factor: hidden-width multiplier of the MLP.
        dropout: dropout probability inside the MLP.

    Expects x of shape (batch_size, num_patches, num_features); returns the
    same shape.
    """

    def __init__(self, num_features, num_patches, expansion_factor, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(num_features)
        self.mlp = MLP(num_features, expansion_factor, dropout)

    def forward(self, x):
        update = self.mlp(self.norm(x))
        return update + x


class MixerLayer(nn.Module):
    """One Mixer block: token mixing followed by channel mixing.

    Args:
        num_patches: number of patches (tokens) per sample.
        num_features: size of each patch embedding.
        expansion_factor: hidden-width multiplier for both MLPs.
        dropout: dropout probability inside both MLPs.

    The positional order (num_patches, num_features, ...) is the order in
    which MLPMixer constructs this layer. The sub-blocks are built with
    keyword arguments so that each size reaches the matching parameter
    regardless of how TokenMixer/ChannelMixer order their own parameters.

    Input and output shape: (batch_size, num_patches, num_features).
    """

    def __init__(self, num_patches, num_features, expansion_factor, dropout):
        super().__init__()
        self.token_mixer = TokenMixer(
            num_features=num_features,
            num_patches=num_patches,
            expansion_factor=expansion_factor,
            dropout=dropout,
        )
        self.channel_mixer = ChannelMixer(
            num_features=num_features,
            num_patches=num_patches,
            expansion_factor=expansion_factor,
            dropout=dropout,
        )

    def forward(self, x):
        # x.shape == (batch_size, num_patches, num_features)
        x = self.token_mixer(x)
        x = self.channel_mixer(x)
        # x.shape == (batch_size, num_patches, num_features)
        return x


def check_sizes(image_size, patch_size):
    """Return the number of patches that tile a square image exactly.

    Args:
        image_size: side length of the square image.
        patch_size: side length of the square patches.

    Returns:
        (image_size // patch_size) ** 2.

    Raises:
        AssertionError: if image_size is not a multiple of patch_size.
    """
    patches_per_side = image_size // patch_size
    leftover = image_size % patch_size
    assert leftover == 0, "`image_size` must be divisibe by `patch_size`"
    return patches_per_side ** 2


class MLPMixer(nn.Module):
    """MLP-Mixer classifier for square images.

    Pipeline:
    1. Embed patches with a strided Conv2d.
    2. Apply ``num_layers`` MixerLayers.
    3. Average over patches and apply a linear classifier.

    Args:
        image_size: side length of the square input image; must be a multiple
            of patch_size (checked by check_sizes).
        patch_size: side length of the square, non-overlapping patches.
        in_channels: number of input image channels.
        num_features: width of each patch embedding.
        expansion_factor: hidden-width multiplier of every MLP.
        num_layers: number of stacked MixerLayers.
        num_classes: number of output logits.
        dropout: dropout probability inside every MLP.
    """

    def __init__(
        self,
        image_size=256,
        patch_size=16,
        in_channels=3,
        num_features=128,
        expansion_factor=2,
        num_layers=8,
        num_classes=10,
        dropout=0.5,
    ):
        num_patches = check_sizes(image_size, patch_size)
        super().__init__()
        # per-patch fully-connected is equivalent to strided conv2d
        self.patcher = nn.Conv2d(
            in_channels, num_features, kernel_size=patch_size, stride=patch_size
        )
        layers = [
            MixerLayer(num_patches, num_features, expansion_factor, dropout)
            for _ in range(num_layers)
        ]
        self.mixers = nn.Sequential(*layers)
        self.classifier = nn.Linear(num_features, num_classes)

    def forward(self, x):
        feature_map = self.patcher(x)
        batch_size, channels, _, _ = feature_map.shape
        # (batch, C, H, W) -> (batch, H*W, C): one token per patch, row-major.
        tokens = feature_map.permute(0, 2, 3, 1).view(batch_size, -1, channels)
        pooled = self.mixers(tokens).mean(dim=1)
        return self.classifier(pooled)

In [3]:

import torch

# Random single-channel 224x224 input, shape (batch, channels, height, width),
# with entries drawn from a standard normal distribution.
x = torch.randn(1, 1, 224, 224)

num_features = 32
patch_size = 16
# kernel_size == stride == patch_size embeds each non-overlapping patch
# independently, i.e. a per-patch linear projection.
patcher = nn.Conv2d(
    1, num_features, kernel_size=patch_size, stride=patch_size
)

y = patcher(x)  # (1, num_features, 224 // patch_size, 224 // patch_size) == (1, 32, 14, 14)

batch_size, num_features, _, _ = y.shape
patches = y.permute(0, 2, 3, 1)  # channels-last: (1, 14, 14, 32)
patches = patches.view(batch_size, -1, num_features)  # token sequence: (1, 196, 32)
assert patches.shape == (batch_size, (224 // patch_size) ** 2, num_features)

In [4]:
def pad_image_to_divisible(image, p):
    """Zero-pad the last two dimensions of ``image`` up to multiples of ``p``.

    Padding is added only at the end of each dimension (bottom rows, right
    columns), so the original values stay anchored at index (0, 0) and can be
    recovered by slicing ``[..., :h, :w]``.

    Args:
        image: tensor with at least two dimensions; the last two are treated
            as (height, width). Typically (batch, channels, height, width).
        p: positive integer patch size.

    Returns:
        Tensor of the same rank whose last two sizes are the smallest
        multiples of ``p`` that are >= the original sizes.

    Raises:
        ValueError: if ``p`` is not positive or ``image`` has fewer than two dims.
    """
    if p <= 0:
        raise ValueError(f"patch size p must be positive, got {p}")
    if image.dim() < 2:
        raise ValueError(
            f"image must have at least 2 dimensions, got shape {tuple(image.shape)}"
        )

    # Only the trailing (height, width) dims are padded, so any leading
    # batch/channel dims are allowed.
    h, w = image.shape[-2:]

    # Amount needed to reach the next multiple of p (0 when already divisible).
    pad_h = (p - h % p) % p
    pad_w = (p - w % p) % p

    # F.pad reads the padding tuple from the LAST dim backwards:
    # (width_left, width_right, height_top, height_bottom).
    padding = (0, pad_w, 0, pad_h)
    return F.pad(image, padding, mode='constant', value=0)

# Example tensor of shape (batch_size, channels, height, width). The height (8)
# is already a multiple of the patch size; the width (5) is not.
image = torch.randn(1, 1, 8, 5)

# Patch size.
p = 4

y = pad_image_to_divisible(image, p)
# Width 5 is zero-padded on the right up to 8; the original values are untouched.
assert y.shape == (1, 1, 8, 8)
assert torch.equal(y[..., :5], image)

In [14]:
def pad_image_to_divisible(image, p):
    """Zero-pad the last two dimensions of ``image`` up to multiples of ``p``.

    Padding is added only at the end of each dimension (bottom rows, right
    columns), so the original values stay anchored at index (0, 0) and can be
    recovered by slicing ``[..., :h, :w]``.

    Args:
        image: tensor with at least two dimensions; the last two are treated
            as (height, width). Typically (batch, channels, height, width).
        p: positive integer patch size.

    Returns:
        Tensor of the same rank whose last two sizes are the smallest
        multiples of ``p`` that are >= the original sizes.

    Raises:
        ValueError: if ``p`` is not positive or ``image`` has fewer than two dims.
    """
    if p <= 0:
        raise ValueError(f"patch size p must be positive, got {p}")
    if image.dim() < 2:
        raise ValueError(
            f"image must have at least 2 dimensions, got shape {tuple(image.shape)}"
        )

    # Only the trailing (height, width) dims are padded, so any leading
    # batch/channel dims are allowed.
    h, w = image.shape[-2:]

    # Amount needed to reach the next multiple of p (0 when already divisible).
    pad_h = (p - h % p) % p
    pad_w = (p - w % p) % p

    # F.pad reads the padding tuple from the LAST dim backwards:
    # (width_left, width_right, height_top, height_bottom).
    padding = (0, pad_w, 0, pad_h)
    return F.pad(image, padding, mode='constant', value=0)




class MLP(nn.Module):
    """Feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    The output's last dimension equals the input's, so the result can be added
    back onto the input as a residual.

    Args:
        num_features: size of the last (input and output) dimension.
        num_hidden: width of the hidden layer.
        dropout: dropout probability applied after the activation and after
            the output projection.
    """

    def __init__(self, num_features, num_hidden, dropout):
        super().__init__()
        in_dim, hid_dim = num_features, num_hidden
        # Submodules are registered in the order fc1, dropout1, fc2, dropout2;
        # this order determines parameter initialisation and state_dict keys.
        self.fc1 = nn.Linear(in_dim, hid_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hid_dim, in_dim)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.dropout1(F.gelu(self.fc1(x)))
        return self.dropout2(self.fc2(hidden))





class TokenMixer(nn.Module):
    """Pre-norm residual block that mixes information across patches (tokens).

    LayerNorm normalises each patch embedding. The tensor is then transposed
    so the MLP runs along the patch axis, and the result is added back to the
    un-normalised input.

    Args:
        num_patches: number of patches per sample (the MLP's in/out width).
        embedding_dim: size of each patch embedding (normalised by LayerNorm).
        patch_dim: hidden width of the token-mixing MLP.
        dropout: dropout probability inside the MLP.

    Expects x of shape (batch, num_patches, embedding_dim); returns the same shape.
    """

    def __init__(self, num_patches, embedding_dim, patch_dim, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(embedding_dim)
        self.mlp = MLP(num_patches, patch_dim, dropout)

    def forward(self, x):
        # (batch, num_patches, embedding_dim) -> (batch, embedding_dim, num_patches)
        mixed = self.norm(x).transpose(1, 2)
        # Mix along patches, then restore (batch, num_patches, embedding_dim).
        mixed = self.mlp(mixed).transpose(1, 2)
        return mixed + x


class ChannelMixer(nn.Module):
    """Pre-norm residual block that mixes features within each patch.

    Args:
        embedding_dim: size of each patch embedding (LayerNorm and MLP width).
        filter_dim: hidden width of the channel-mixing MLP.
        dropout: dropout probability inside the MLP.

    Expects x of shape (batch_size, num_patches, embedding_dim); returns the
    same shape.
    """

    def __init__(self, embedding_dim, filter_dim, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(embedding_dim)
        self.mlp = MLP(embedding_dim, filter_dim, dropout)

    def forward(self, x):
        # The MLP acts on the last (feature) axis, so no transpose is needed.
        update = self.mlp(self.norm(x))
        return update + x



class MixerLayer(nn.Module):
    """One Mixer block: token mixing followed by channel mixing.

    Args:
        num_patches: number of patches (tokens) per sample.
        embedding_dim: size of each patch embedding.
        patch_dim: hidden width of the token-mixing MLP.
        filter_dim: hidden width of the channel-mixing MLP.
        dropout: dropout probability inside both MLPs.

    Input and output shape: (batch, num_patches, embedding_dim).
    """

    def __init__(self, num_patches, embedding_dim, patch_dim, filter_dim, dropout):
        super().__init__()
        self.token_mixer = TokenMixer(num_patches, embedding_dim, patch_dim, dropout)
        self.channel_mixer = ChannelMixer(embedding_dim, filter_dim, dropout)

    def forward(self, x):
        for sublayer in (self.token_mixer, self.channel_mixer):
            x = sublayer(x)
        return x
    
class Vision_MIXER(nn.Module):
    """MLP-Mixer classifier for 2-D inputs such as (window_length x sensor_channel) windows.

    Pipeline:
    1. Zero-pad the last two dims up to multiples of ``patch_size``.
    2. Embed patches with a strided Conv2d.
    3. Apply ``layer_nr`` MixerLayers.
    4. Average over patches and apply a linear classifier.

    Args:
        input_shape: (batch, in_channels, window_length, sensor_channel_nr).
            The batch entry is ignored. The remaining entries fix the patch
            count and the patch-embedding input channels.
        patch_size: side length of the square, non-overlapping patches.
        number_class: number of output logits.
        embedding_dim: width of each patch embedding.
        patch_dim: hidden width of the token-mixing MLP.
        filter_dim: hidden width of the channel-mixing MLP.
        layer_nr: number of stacked MixerLayers.
        dropout: dropout probability inside every MLP.
    """

    def __init__(
        self,
        input_shape,
        patch_size,
        number_class,
        embedding_dim = 512,
        patch_dim = 256,
        filter_dim = 2048,
        layer_nr = 10,
        dropout = 0.1
    ):
        super().__init__()
        _, in_channels, window_length, sensor_channel_nr = input_shape

        # Round each spatial size up to the next multiple of patch_size, the
        # same size pad_image_to_divisible produces in forward(). Integer
        # arithmetic avoids allocating a dummy tensor and leaves the global RNG
        # untouched, so seeded weight initialisation stays reproducible.
        padded_window_length = -(-window_length // patch_size) * patch_size
        padded_sensor_channel = -(-sensor_channel_nr // patch_size) * patch_size
        self.num_patches = (padded_window_length // patch_size) * (padded_sensor_channel // patch_size)

        self.padded_window_length     = padded_window_length
        self.padded_sensor_channel    = padded_sensor_channel
        self.in_channels              = in_channels
        self.patch_size               = patch_size
        self.number_class             = number_class
        self.embedding_dim            = embedding_dim
        self.patch_dim                = patch_dim
        self.filter_dim               = filter_dim
        self.layer_nr                 = layer_nr

        # A conv with kernel == stride == patch_size is a per-patch linear embedding.
        self.patcher = nn.Conv2d(
            in_channels, self.embedding_dim, kernel_size=self.patch_size, stride=self.patch_size
        )

        self.mixers = nn.Sequential(
            *[
                MixerLayer(self.num_patches, self.embedding_dim, self.patch_dim, self.filter_dim, dropout)
                for _ in range(self.layer_nr)
            ]
        )

        self.classifier = nn.Linear(embedding_dim, number_class)

    def forward(self, x):
        x = pad_image_to_divisible(x, self.patch_size)
        x = self.patcher(x)  # (batch, embedding_dim, grid_h, grid_w)
        batch_size, num_features, _, _ = x.shape
        # Flatten the patch grid row-major into a token sequence.
        x = x.permute(0, 2, 3, 1)
        x = x.view(batch_size, -1, num_features)
        # x.shape == (batch_size, num_patches, embedding_dim)
        if x.shape[1] != self.num_patches:
            raise ValueError(
                f"input yields {x.shape[1]} patches but the model was built for "
                f"{self.num_patches}; check that the input size matches input_shape"
            )

        x = self.mixers(x)
        x = x.mean(dim=1)  # average over patches -> (batch_size, embedding_dim)
        logits = self.classifier(x)
        return logits

In [18]:
# Configuration for a quick smoke test of Vision_MIXER on one random window.
window_length = 128
sensor_channel = 9
input_shape = (1, 1, window_length, sensor_channel)  # (batch, channels, window_length, sensor_channel)
patch_size = 3
number_class = 10
embedding_dim = 512
patch_dim = 256
filter_dim = 2048
layer_nr = 10

torch.manual_seed(0)  # reproducible weights and input

model = Vision_MIXER(
    input_shape,
    patch_size,
    number_class,
    embedding_dim,
    patch_dim,
    filter_dim,
    layer_nr,
)

x = torch.randn(input_shape)
with torch.no_grad():
    y = model(x)
assert y.shape == (input_shape[0], number_class)

In [19]:
from ptflops import get_model_complexity_info

In [21]:
# Count MACs and parameters of `model`. The second argument is the per-sample
# input shape (channels, window_length, sensor_channel), i.e. without the
# batch dimension (ptflops is presumed to build the batch itself — see its docs).
macs, params = get_model_complexity_info(
    model,
    (1, window_length, sensor_channel),
    as_strings=True,
    print_per_layer_stat=False,
    verbose=False,
)

In [22]:
# Parameter count reported by get_model_complexity_info (a formatted string, since as_strings=True).
params

'21.69 M'

In [23]:
# Multiply-accumulate count reported by get_model_complexity_info (a formatted string, since as_strings=True).
macs

'3.05 GMac'

In [31]:
# Ratio of window_length to sensor_channel (128 / 9 with the configuration above).
window_length / sensor_channel

14.222222222222221