In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import numpy as np

# Cross Channel Interaction

## attention  interaction

In [2]:
# def conv1d(ni: int, no: int, ks: int = 1, stride: int = 1, padding: int = 0, bias: bool = False):
#     """
#     Create and initialize a `nn.Conv1d` layer with spectral normalization.
#     """
#     conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias)
#     nn.init.kaiming_normal_(conv.weight)
#     if bias:
#         conv.bias.data.zero_()
#     return conv

# class SelfAttention_interaction(nn.Module):
#     """

#     """
#     def __init__(self, n_channels: int, div=1):
#         super(SelfAttention_interaction, self).__init__()

#         if n_channels > 1:
#             self.query = conv1d(n_channels, n_channels//div)
#             self.key = conv1d(n_channels, n_channels//div)
#         else:
#             self.query = conv1d(n_channels, n_channels)
#             self.key = conv1d(n_channels, n_channels)
#         self.value = conv1d(n_channels, n_channels)
#         self.gamma = nn.Parameter(torch.tensor([0.]))

#     def forward(self, x):
#         # Notation from https://arxiv.org/pdf/1805.08318.pdf
#         # 输入尺寸是 batch feature_dim sensor_channel


#         f, g, h = self.query(x), self.key(x), self.value(x)
        
#         beta = F.softmax(torch.bmm(f.permute(0, 2, 1).contiguous(), g), dim=1)
        
#         o = self.gamma * torch.bmm(h, beta) + x
#         # 输出的尺寸是 batch feature_dim sensor_channel 1
#         return o.unsqueeze(3)

In [3]:
class SelfAttention_interaction(nn.Module):
    """

    """
    def __init__(self, n_channels):
        super(SelfAttention_interaction, self).__init__()

        self.query = nn.Linear(n_channels, n_channels, bias=False)
        self.key = nn.Linear(n_channels, n_channels, bias=False)
        self.value = nn.Linear(n_channels, n_channels, bias=False)
        self.gamma = nn.Parameter(torch.tensor([0.]))

    def forward(self, x):

        # 输入尺寸是 batch  sensor_channel feature_dim
        #print(x.shape)

        f, g, h = self.query(x), self.key(x), self.value(x)
        
        beta = F.softmax(torch.bmm(f, g.permute(0, 2, 1).contiguous()), dim=1)

        o = self.gamma * torch.bmm(h.permute(0, 2, 1).contiguous(), beta) + x.permute(0, 2, 1).contiguous()
        # 输出是 batch  sensor_channel feature_dim 1 
        return o.permute(0, 2, 1).contiguous()

## transformer interaction

In [4]:

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 4, dim_head = 16, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class Transformer_interaction(nn.Module):
    def __init__(self, dim, depth=1, heads=4, dim_head=16, mlp_dim=16, dropout = 0.):
        super(Transformer_interaction,self).__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))

    def forward(self, x):


        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return x

## Identity

In [5]:
class Identity(nn.Module):
    def __init__(self, n_channels):
        super(Identity, self).__init__()

    def forward(self, x):
        return x

In [6]:
crosschannel_interaction = {"attn":SelfAttention_interaction,
                            "transformer": Transformer_interaction,
                            "identity": Identity}

# Cross Channel Aggregation

## FilterWeighted_Aggregation

In [7]:
class FilterWeighted_Aggregation(nn.Module):
    """

    """
    def __init__(self, n_channels):
        super(FilterWeighted_Aggregation, self).__init__()
        self.value_projection = nn.Linear(n_channels, n_channels)
        self.value_activation = nn.ReLU() 
        
        self.weight_projection = nn.Linear(n_channels, n_channels)
        self.weighs_activation = nn.Tanh() 
        self.softmatx = nn.Softmax(dim=1)        

        
        
    def forward(self, x):
        
        # 输入是 batch  sensor_channel feature_dim


        weights = self.weighs_activation(self.weight_projection(x))
        weights = self.softmatx(weights)
        
        values  = self.value_activation(self.value_projection(x))

        values  = torch.mul(values, weights)
        # 返回是 batch feature_dim
        return torch.sum(values,dim=1)

## NaiveWeighted_Aggregation

In [8]:
class NaiveWeighted_Aggregation(nn.Module):
    """
    Temporal attention module
    """
    def __init__(self, hidden_dim):
        super(NaiveWeighted_Aggregation, self).__init__()
        self.fc = nn.Linear(hidden_dim, 1)
        self.sm = torch.nn.Softmax(dim=1)

    def forward(self, x):

        # 输入是 batch  sensor_channel feature_dim
        #   B C F

        out = self.fc(x).squeeze(2)

        weights_att = self.sm(out).unsqueeze(2)
        context = torch.sum(weights_att * x, 1)
        return context

## Reshape FC

In [9]:
class FC(nn.Module):

    def __init__(self, channel_in, channel_out):
        super(FC, self).__init__()
        self.fc = nn.Linear(channel_in ,channel_out)

    def forward(self, x):
        x = self.fc(x)
        return(x)

In [10]:
crosschannel_aggregation = {"filter": FilterWeighted_Aggregation,
                            "naive" : NaiveWeighted_Aggregation,
                            "FC" : FC}

# Tempotal Info Interaction

## GRU

In [11]:
class temporal_GRU(nn.Module):
    """

    """
    def __init__(self, filter_num):
        super(temporal_GRU, self).__init__()
        self.rnn = nn.GRU(
            filter_num,
            filter_num,
            1,
            bidirectional=False,
            dropout=0.15,
            batch_first = True
        )
    def forward(self, x):
        # Batch length Filter
        outputs, h = self.rnn(x)
        return outputs

## LSTM

In [12]:
class temporal_LSTM(nn.Module):
    """

    """
    def __init__(self, filter_num):
        super(temporal_LSTM, self).__init__()
        self.lstm = nn.LSTM(filter_num, 
                            filter_num, 
                            batch_first =True)
    def forward(self, x):
        # Batch length Filter
        outputs, h = self.lstm(x)
        return outputs

In [13]:
temporal_interaction = {"gru": temporal_GRU,
                        "lstm": temporal_LSTM,
                        "attn"   :SelfAttention_interaction,
                        "transformer": Transformer_interaction,
                        "identity" : Identity}

# Temporal Aggregation

In [14]:
temmporal_aggregation = {"filter": FilterWeighted_Aggregation,
                         "naive" : NaiveWeighted_Aggregation,
                         "FC" : FC,
                         "identiry":Identity}

In [20]:
class Light_HAR_Model(nn.Module):
    def __init__(
        self,
        input_shape ,
        number_class , 

        filter_num = 16, # 所有层拥有同样的dim

        nb_conv_layers = 4,        
        filter_size = 5,
        
        cross_channel_interaction_type = "attn",    # attn  transformer  identity
        
        cross_channel_aggregation_type = "filter",  # filter  naive  FC
         
        temporal_info_interaction_type = "gru",     # gru  lstm  attn  transformer  identity
        
        temporal_info_aggregation_type = "FC",      # naive  filter  FC 

        dropout = 0.2,
        activation = "ReLU",

    ):
        super(Light_HAR_Model, self).__init__()
        
        
        self.cross_channel_interaction_type = cross_channel_interaction_type
        self.cross_channel_aggregation_type = cross_channel_aggregation_type
        self.temporal_info_interaction_type = temporal_info_interaction_type
        self.temporal_info_aggregation_type = temporal_info_aggregation_type
        
        
        """
        PART 1 , ============= Channel wise Feature Extraction =============================        
        输入的格式为  Batch, filter_num, length, Sensor_channel        
        输出格式为为  Batch, filter_num, downsampling_length, Sensor_channel
        """

        layers_conv = []
        for i in range(nb_conv_layers):
            if i == 0:
                in_channel = input_shape[1]
            else:
                in_channel = filter_num
    
            layers_conv.append(nn.Sequential(
                nn.Conv2d(in_channel, filter_num, (filter_size, 1),(2,1)),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(filter_num),

            ))
        self.layers_conv = nn.ModuleList(layers_conv)
        # 这是给最后时间维度 vectorize的时候用的
        downsampling_length = self.get_the_shape(input_shape)        
        

        """
        PART2 , ================ Cross Channel interaction  =================================
        这里可供选择的  attn   transformer  itentity
        输出格式为  Batch, filter_num, downsampling_length, Sensor_channel
        
        """

        self.channel_interaction = crosschannel_interaction[cross_channel_interaction_type](filter_num)
        # 这里还是 B F C L  需要permute++++++++++++++

        

        """
        PART3 , =============== Cross Channel Fusion  ====================================
        这里可供选择的  filter   naive  FC

        输出格式为  Batch, downsampling_length, filter_num
        """
        if cross_channel_aggregation_type == "FC":
            # 这里需要reshape为 B L C*F++++++++++++++
            self.channel_fusion = crosschannel_aggregation[cross_channel_aggregation_type](input_shape[3]*filter_num,filter_num)

        else:
            # 这里需要沿着时间轴走
            self.channel_fusion = crosschannel_aggregation[cross_channel_aggregation_type](filter_num)
            # --> B F L
            # 需要reshape++++++++++++++++++++++++++++++

            
        # BLF
        self.activation = nn.ReLU() 


        """
        PART4  , ============= Temporal information Extraction =========================
        这里可供选择的  gru lstm attn transformer   identity

        输出格式为  Batch, downsampling_length, filter_num
        """
        
        # ++++++++++++ 这里需要讨论
        self.temporal_interaction = temporal_interaction[temporal_info_interaction_type](filter_num)
        
        
        """
        PART 5 , =================== Temporal information Aggregation ================


        输出格式为  Batch, downsampling_length, filter_num
        """        

        self.dropout = nn.Dropout(dropout)
        
        if temporal_info_aggregation_type == "FC":
            self.flatten = nn.Flatten()
            self.temporal_fusion = temmporal_aggregation[temporal_info_aggregation_type](downsampling_length*filter_num,filter_num)
        else:
            self.temporal_fusion = temmporal_aggregation[temporal_info_aggregation_type](filter_num)
            
        #--> B F

        # PART 6 , ==================== Prediction ==============================
        self.prediction = nn.Linear(filter_num ,number_class)

    def get_the_shape(self, input_shape):
        x = torch.rand(input_shape)

        for layer in self.layers_conv:
            x = layer(x)    

        return x.shape[2]
        


    def forward(self, x):
        # B F L C   
        for layer in self.layers_conv:
            x = layer(x)

        x = x.permute(0,3,2,1) 
        # ------->  B x C x L* x F*       



        """ =============== cross channel interaction ==============="""
        x = torch.cat(
            [self.channel_interaction(x[:, :, t, :]).unsqueeze(3) for t in range(x.shape[2])],
            dim=-1,
        )
        # ------->  B x C x F* x L* 
        
        x = self.dropout(x)

        """=============== cross channel fusion ==============="""
        
        if self.cross_channel_aggregation_type == "FC":
            x = x.permute(0, 3, 1, 2)
            x = x.reshape(x.shape[0], x.shape[1], -1)
            x = self.activation(self.channel_fusion(x)) # B L C
        else:
            x = torch.cat(
                [self.channel_fusion(x[:, :, :, t]).unsqueeze(2) for t in range(x.shape[3])],
                dim=-1,
            )
            x = x.permute(0,2,1)
            x = self.activation(x)
        # ------->  B x L* x F*
            
            
        """cross temporal interaction """
        x = self.temporal_interaction(x)


        
        """cross temporal fusion """
        if self.temporal_info_aggregation_type == "FC":
            x = self.flatten(x)
            x = self.activation(self.temporal_fusion(x)) # B L C
        else:
            x = self.temporal_fusion(x)
        
        

        y = self.prediction(x)
        return y

In [21]:
batch         = 1
number_filter = 12
length        = 128
channel       = 6


model = Light_HAR_Model((batch, number_filter, length, channel),
                        
                        filter_num = 16, # 所有层拥有同样的dim

                        nb_conv_layers = 4,        
                        filter_size = 5,

                        cross_channel_interaction_type = "attn",    # attn  transformer  identity

                        cross_channel_aggregation_type = "FC",  # filter  naive  FC

                        temporal_info_interaction_type = "lstm",     # gru  lstm  attn  transformer  identity

                        temporal_info_aggregation_type = "FC",      # naive  filter  FC 
                        
                        number_class=6).double()
input = torch.rand(batch,number_filter ,length, channel).double()
model(input)

tensor([[ 0.1937,  0.0957,  0.1372,  0.2366, -0.2333, -0.1543]],
       dtype=torch.float64, grad_fn=<AddmmBackward>)

In [24]:
batch         = 1
number_filter = 12
length        = 128
channel       = 77

for a in ["attn" , "transformer" , "identity"]:
    for b in ["filter",  "naive",  "FC"]:
        for c in ["gru",  "lstm",  "attn",  "transformer" , "identity"]:
            for d in ["naive",  "filter",  "FC" ]:
                
                print(a," ", b ," ",c, " ",d)
    
                model = Light_HAR_Model((batch, number_filter, length, channel),

                                        filter_num = 32, # 所有层拥有同样的dim

                                        nb_conv_layers = 4,        
                                        filter_size = 5,

                                        cross_channel_interaction_type = a,    # attn  transformer  identity

                                        cross_channel_aggregation_type = b,  # filter  naive  FC

                                        temporal_info_interaction_type = c,     # gru  lstm  attn  transformer  identity

                                        temporal_info_aggregation_type = d,      # naive  filter  FC 

                                        number_class=6).double()
        
                input = torch.rand(batch,number_filter ,length, channel).double()
                out = model(input)
                print(np.sum([para.numel() for para in model.parameters()]))

attn   filter   gru   naive
29416
attn   filter   gru   filter
31495
attn   filter   gru   FC
34535
attn   filter   lstm   naive
31528
attn   filter   lstm   filter
33607
attn   filter   lstm   FC
36647
attn   filter   attn   naive
26153
attn   filter   attn   filter
28232
attn   filter   attn   FC
31272
attn   filter   transformer   naive
32504
attn   filter   transformer   filter
34583
attn   filter   transformer   FC
37623
attn   filter   identity   naive
23080
attn   filter   identity   filter
25159
attn   filter   identity   FC
28199
attn   naive   gru   naive
27337
attn   naive   gru   filter
29416
attn   naive   gru   FC
32456
attn   naive   lstm   naive
29449
attn   naive   lstm   filter
31528
attn   naive   lstm   FC
34568
attn   naive   attn   naive
24074
attn   naive   attn   filter
26153
attn   naive   attn   FC
29193
attn   naive   transformer   naive
30425
attn   naive   transformer   filter
32504
attn   naive   transformer   FC
35544
attn   naive   identity   naive
21001

# import numpy as np
np.sum([para.numel() for para in model.parameters()])

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def conv1d(ni: int, no: int, ks: int = 1, stride: int = 1, padding: int = 0, bias: bool = False):
    """
    Create and initialize a `nn.Conv1d` layer with spectral normalization.
    """
    conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias)
    nn.init.kaiming_normal_(conv.weight)
    if bias:
        conv.bias.data.zero_()
    # return spectral_norm(conv)
    return conv

class SelfAttention_crosschannel_interaction(nn.Module):
    """

    """
    def __init__(self, n_channels: int, div):
        super(SelfAttention_crosschannel_interaction, self).__init__()

        if n_channels > 1:
            self.query = conv1d(n_channels, n_channels//div)
            self.key = conv1d(n_channels, n_channels//div)
        else:
            self.query = conv1d(n_channels, n_channels)
            self.key = conv1d(n_channels, n_channels)
        self.value = conv1d(n_channels, n_channels)
        self.gamma = nn.Parameter(torch.tensor([0.]))

    def forward(self, x):
        # Notation from https://arxiv.org/pdf/1805.08318.pdf
        size = x.size()
        #print("size+",size)
        x = x.view(*size[:2], -1)
        #print("size-",x.size())
        f, g, h = self.query(x), self.key(x), self.value(x)
        beta = F.softmax(torch.bmm(f.permute(0, 2, 1).contiguous(), g), dim=1)
        o = self.gamma * torch.bmm(h, beta) + x
        return o.view(*size).contiguous()

    
class SelfAttention_Aggregation(nn.Module):
    """

    """
    def __init__(self, n_channels):
        super(SelfAttention_Aggregation, self).__init__()
        self.value_projection = nn.Linear(n_channels, n_channels)
        self.weight_projection = nn.Linear(n_channels, n_channels)
        self.softmatx = nn.Softmax(dim=1)
        
        
    def forward(self, x):
        weights = self.weight_projection(x)
        weights = self.softmatx(weights)
        values  = self.value_projection(x)
        values  = torch.mul(values, weights)
        return torch.sum(values,dim=1).unsqueeze(2)

        
        
        
class TemporalAttention(nn.Module):
    """
    Temporal attention module
    """
    def __init__(self, hidden_dim):
        super(TemporalAttention, self).__init__()
        self.fc = nn.Linear(hidden_dim, 1)
        self.sm = torch.nn.Softmax(dim=0)

    def forward(self, x):
        out = self.fc(x).squeeze(2)
        weights_att = self.sm(out).unsqueeze(2)
        context = torch.sum(weights_att * x, 0)
        return context
    
class CFC_V4_Model(nn.Module):
    def __init__(
        self,
        input_shape ,
        number_class , 
        filter_num = 16,
        filter_size = 5,
        nb_conv_layers = 4,
        dropout = 0.2,
        hidden_dim = 16,
        activation = "ReLU",
        sa_div= 1,
    ):
        super(CFC_V4_Model, self).__init__()
        
        # PART 1 , ============= Channel wise Feature Extraction =============================
        
        layers_conv = []
        for i in range(nb_conv_layers):
            if i == 0:
                in_channel = input_shape[1]
            else:
                in_channel = filter_num
    
            layers_conv.append(nn.Sequential(
                nn.Conv2d(in_channel, filter_num, (filter_size, 1),(2,1)),#(2,1)
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(filter_num),

            ))
        
        self.layers_conv = nn.ModuleList(layers_conv)

        # PART2 , ================ Cross Channel interaction  =================================
        self.dropout = nn.Dropout(dropout)
        self.channel_interaction = SelfAttention_crosschannel_interaction(filter_num, sa_div)
        

        # PART3 , =============== Cross Channel Fusion  ====================================

        
        self.channel_fusion = SelfAttention_Aggregation(filter_num)
    
        # PART4  , ============= Temporal information Extraction =========================

        self.rnn = nn.GRU(
            filter_num,
            hidden_dim,
            2,
            bidirectional=False,
            dropout=0.15,
        )
        # PART 5 , =================== Temporal information Aggregation ================
        self.temporal_fusion = TemporalAttention(hidden_dim)

        # PART 6 , ==================== Prediction ==============================
        self.prediction = nn.Linear(hidden_dim ,number_class)


        


    def forward(self, x):
        # B F L C   F==1 or F==Nds+1
        for layer in self.layers_conv:
            x = layer(x)      

        # apply self-attention on each temporal dimension (along sensor and feature dimensions)

        x = torch.cat(
            [self.channel_interaction(torch.unsqueeze(x[:, :, t, :], dim=3)) for t in range(x.shape[2])],
            dim=-1,
        )
        # x B F C L

        x = x.permute(0,2,1,3)
        # refined B C F L

        x = torch.cat(
            [self.channel_fusion(x[:, :, :,t]) for t in range(x.shape[3])],
            dim=-1,
        )
        # refined B F L
        

        x = self.dropout(x)
        x = x.permute(2,0,1)
        x, h = self.rnn(x) # L B  F

        x = self.temporal_fusion(x)
        y = self.prediction(x)
        return y

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
m = nn.Softmax(dim=1)
input = torch.randn(1, 4, 2)
output = m(input)
#torch.sum(output,dim=1)
output

tensor([[[0.1149, 0.1393],
         [0.1138, 0.7365],
         [0.4094, 0.1016],
         [0.3620, 0.0226]]])

In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention_Aggregation(nn.Module):
    """

    """
    def __init__(self, n_channels):
        super(SelfAttention_Aggregation, self).__init__()
        self.value_projection = nn.Linear(n_channels, n_channels)
        self.weight_projection = nn.Linear(n_channels, n_channels)
        self.softmatx = nn.Softmax(dim=1)
        
        
    def forward(self, x):
        weights = self.weight_projection(x)
        weights = self.softmatx(weights)
        values  = self.value_projection(x)
        values  = torch.mul(values, weights)
        return torch.sum(values,dim=1)
a = SelfAttention_Aggregation(2).double()
input = torch.rand(1,10,2).double()
a(input)

tensor([[ 1.1948, -0.1908]], dtype=torch.float64, grad_fn=<SumBackward1>)

In [26]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def conv1d(ni: int, no: int, ks: int = 1, stride: int = 1, padding: int = 0, bias: bool = False):
    """
    Create and initialize a `nn.Conv1d` layer with spectral normalization.
    """
    conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias)
    nn.init.kaiming_normal_(conv.weight)
    if bias:
        conv.bias.data.zero_()
    # return spectral_norm(conv)
    return conv

class SelfAttention_crosschannel_interaction(nn.Module):
    """

    """
    def __init__(self, n_channels: int, div):
        super(SelfAttention_crosschannel_interaction, self).__init__()

        if n_channels > 1:
            self.query = conv1d(n_channels, n_channels//div)
            self.key = conv1d(n_channels, n_channels//div)
        else:
            self.query = conv1d(n_channels, n_channels)
            self.key = conv1d(n_channels, n_channels)
        self.value = conv1d(n_channels, n_channels)
        self.gamma = nn.Parameter(torch.tensor([0.]))

    def forward(self, x):
        # Notation from https://arxiv.org/pdf/1805.08318.pdf
        size = x.size()
        #print("size+",size)
        x = x.view(*size[:2], -1)
        #print("size-",x.size())
        f, g, h = self.query(x), self.key(x), self.value(x)
        beta = F.softmax(torch.bmm(f.permute(0, 2, 1).contiguous(), g), dim=1)
        o = self.gamma * torch.bmm(h, beta) + x
        return o.view(*size).contiguous()

    
class SelfAttention_Aggregation(nn.Module):
    """

    """
    def __init__(self, n_channels):
        super(SelfAttention_Aggregation, self).__init__()
        self.value_projection = nn.Linear(n_channels, n_channels)
        self.weight_projection = nn.Linear(n_channels, n_channels)
        self.softmatx = nn.Softmax(dim=1)
        
        
    def forward(self, x):
        weights = self.weight_projection(x)
        weights = self.softmatx(weights)
        values  = self.value_projection(x)
        values  = torch.mul(values, weights)
        return torch.sum(values,dim=1).unsqueeze(2)

        
        
        
class TemporalAttention(nn.Module):
    """
    Temporal attention module
    """
    def __init__(self, hidden_dim):
        super(TemporalAttention, self).__init__()
        self.fc = nn.Linear(hidden_dim, 1)
        self.sm = torch.nn.Softmax(dim=0)

    def forward(self, x):
        out = self.fc(x).squeeze(2)
        weights_att = self.sm(out).unsqueeze(2)
        context = torch.sum(weights_att * x, 0)
        return context
    
class CFC_V4_Model(nn.Module):
    def __init__(
        self,
        input_shape ,
        number_class , 
        filter_num = 16,
        filter_size = 5,
        nb_conv_layers = 4,
        dropout = 0.2,
        hidden_dim = 16,
        activation = "ReLU",
        sa_div= 1,
    ):
        super(CFC_V4_Model, self).__init__()
        
        # PART 1 , ============= Channel wise Feature Extraction =============================
        
        layers_conv = []
        for i in range(nb_conv_layers):
            if i == 0:
                in_channel = input_shape[1]
            else:
                in_channel = filter_num
    
            layers_conv.append(nn.Sequential(
                nn.Conv2d(in_channel, filter_num, (filter_size, 1),(2,1)),#(2,1)
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(filter_num),

            ))
        
        self.layers_conv = nn.ModuleList(layers_conv)

        # PART2 , ================ Cross Channel interaction  =================================
        self.dropout = nn.Dropout(dropout)
        self.channel_interaction = SelfAttention_crosschannel_interaction(filter_num, sa_div)
        

        # PART3 , =============== Cross Channel Fusion  ====================================

        
        self.channel_fusion = SelfAttention_Aggregation(filter_num)
    
        # PART4  , ============= Temporal information Extraction =========================

        self.rnn = nn.GRU(
            filter_num,
            hidden_dim,
            2,
            bidirectional=False,
            dropout=0.15,
        )
        # PART 5 , =================== Temporal information Aggregation ================
        self.temporal_fusion = TemporalAttention(hidden_dim)

        # PART 6 , ==================== Prediction ==============================
        self.prediction = nn.Linear(hidden_dim ,number_class)


        


    def forward(self, x):
        # B F L C   F==1 or F==Nds+1
        for layer in self.layers_conv:
            x = layer(x)      

        # apply self-attention on each temporal dimension (along sensor and feature dimensions)

        x = torch.cat(
            [self.channel_interaction(torch.unsqueeze(x[:, :, t, :], dim=3)) for t in range(x.shape[2])],
            dim=-1,
        )
        # x B F C L

        x = x.permute(0,2,1,3)
        # refined B C F L

        x = torch.cat(
            [self.channel_fusion(x[:, :, :,t]) for t in range(x.shape[3])],
            dim=-1,
        )
        # refined B F L
        

        x = self.dropout(x)
        x = x.permute(2,0,1)
        x, h = self.rnn(x) # L B  F

        x = self.temporal_fusion(x)
        y = self.prediction(x)
        return y

In [27]:
Batch = 2
F_in  = 50
Leng  = 128
C_in = 6
number_class = 6
model = CFC_V4_Model((Batch,F_in,Leng, C_in),number_class,filter_num = 16).double()
input = torch.rand(Batch,F_in,Leng, C_in).double()
model(input)

torch.Size([2, 6, 16])
torch.Size([2, 6, 16])
torch.Size([2, 6, 16])
torch.Size([2, 6, 16])
torch.Size([2, 6, 16])


tensor([[-0.0572,  0.1141, -0.0173,  0.2453,  0.0325, -0.1014],
        [-0.0833,  0.1265, -0.0186,  0.2403,  0.0312, -0.1002]],
       dtype=torch.float64, grad_fn=<AddmmBackward>)

In [28]:
import numpy as np
np.sum([para.numel() for para in model.parameters()])

12728

In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def conv1d(ni: int, no: int, ks: int = 1, stride: int = 1, padding: int = 0, bias: bool = False):
    """
    Create and initialize a `nn.Conv1d` layer with spectral normalization.
    """
    conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias)
    nn.init.kaiming_normal_(conv.weight)
    if bias:
        conv.bias.data.zero_()
    # return spectral_norm(conv)
    return conv

class SelfAttention(nn.Module):
    """

    """
    def __init__(self, n_channels: int, div):
        super(SelfAttention, self).__init__()

        if n_channels > 1:
            self.query = conv1d(n_channels, n_channels//div)
            self.key = conv1d(n_channels, n_channels//div)
        else:
            self.query = conv1d(n_channels, n_channels)
            self.key = conv1d(n_channels, n_channels)
        self.value = conv1d(n_channels, n_channels)
        self.gamma = nn.Parameter(torch.tensor([0.]))

    def forward(self, x):
        # Notation from https://arxiv.org/pdf/1805.08318.pdf
        size = x.size()
        #print("size+",size)
        x = x.view(*size[:2], -1)
        #print("size-",x.size())
        f, g, h = self.query(x), self.key(x), self.value(x)
        beta = F.softmax(torch.bmm(f.permute(0, 2, 1).contiguous(), g), dim=1)
        o = self.gamma * torch.bmm(h, beta) + x
        return o.view(*size).contiguous()
    
class CFC_V3_Model(nn.Module):
    def __init__(
        self,
        input_shape ,
        number_class , 
        filter_num = 16,
        hidden_dim = 16,
        filter_size = 5,
        nb_conv_layers = 4,
        dropout = 0.2,
        activation = "ReLU",
        sa_div= 1,
    ):
        super(CFC_V3_Model, self).__init__()
        
        # PART 1 , Channel wise Feature Extraction
        
        layers_conv = []
        for i in range(nb_conv_layers):
        
            if i == 0:
                in_channel = input_shape[1]
            else:
                in_channel = filter_num
    
            layers_conv.append(nn.Sequential(
                nn.Conv2d(in_channel, filter_num, (filter_size, 1),(2,1)),#(2,1)
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(filter_num),

            ))
        
        self.layers_conv = nn.ModuleList(layers_conv)

        # PART2 , Cross Channel Fusion through Attention
        self.dropout = nn.Dropout(dropout)

        self.sa = SelfAttention(filter_num, sa_div)
        


        # PART 3 , Prediction 
        
        self.activation = nn.ReLU() 
        self.fc1 = nn.Linear(input_shape[3]*filter_num ,filter_num)

        self.rnn = nn.GRU(
            filter_num,
            hidden_dim,
            2,
            bidirectional=False,
            dropout=0.15,
        )

        self.prediction = nn.Linear(hidden_dim ,number_class)


        


    def forward(self, x):
        # B ? L C
        # x = x.unsqueeze(1)
        
        
        for layer in self.layers_conv:
            x = layer(x)      


        batch, filter, length, channel = x.shape


        # apply self-attention on each temporal dimension (along sensor and feature dimensions)
        refined = torch.cat(
            [self.sa(torch.unsqueeze(x[:, :, t, :], dim=3)) for t in range(x.shape[2])],
            dim=-1,
        )


       # print(refined.shape)

        x = refined.permute(0, 3, 1, 2)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = self.dropout(x)
        
        x = self.activation(self.fc1(x)) # B L F
        x = x.permute(1,0,2)

        outputs, h = self.rnn(x) # L B  F
        x = outputs[-1, :, :]
        y = self.prediction(x)    
        return y

In [19]:
Batch = 2
F_in  = 50
Leng  = 128
C_in = 6
number_class = 6
model = CFC_V3_Model((Batch,F_in,Leng, C_in),number_class,filter_num = 16).double()
input = torch.rand(Batch,F_in,Leng, C_in).double()
model(input)

tensor([[ 0.2748, -0.1322, -0.0927,  0.0486,  0.0153,  0.2842],
        [ 0.2229, -0.1758, -0.1539,  0.0666,  0.0259,  0.3083]],
       dtype=torch.float64, grad_fn=<AddmmBackward>)

In [20]:
import numpy as np
np.sum([para.numel() for para in model.parameters()])

13719

In [29]:
import torch
import torch.nn as nn
import torch.nn.functional as F


from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# -------------- Transformer Encoder -----------
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 4, dim_head = 16, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class AggregationAttention(nn.Module):
    """
    Temporal attention module
    """
    def __init__(self, hidden_dim):
        super(AggregationAttention, self).__init__()
        self.fc = nn.Linear(hidden_dim, 1)
        self.sm = torch.nn.Softmax(dim=0)

    def forward(self, x):
        out = self.fc(x).squeeze(2)
        weights_att = self.sm(out).unsqueeze(2)
        context = torch.sum(weights_att * x, 0)
        return context

class Transformer(nn.Module):
    def __init__(self, dim, depth=1, heads=4, dim_head=16, mlp_dim=16, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
        self.aggregation = AggregationAttention(dim)
    def forward(self, x):
        # B F C 1
        size = x.size()
        # --> B C F
        x = x.view(*size[:2], -1).permute(0,2,1)

        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        # --- B C F 
        #x = self.aggregation(x.permute(1,0,2))
        return x.permute(0,2,1).unsqueeze(3)





    
class CFC_V2_Model(nn.Module):
    def __init__(
        self,
        input_shape ,
        number_class , 
        filter_num = 32,
        filter_size = 5,
        nb_conv_layers = 4,
        dropout = 0.2,
        hidden_dim = 32,
        activation = "ReLU",
        sa_div= 1,
    ):
        super(CFC_V2_Model, self).__init__()
        
        # PART 1 , Channel wise Feature Extraction
        
        layers_conv = []
        for i in range(nb_conv_layers):
        
            if i == 0:
                in_channel = input_shape[1]
            else:
                in_channel = filter_num
            if i%2==0:
                stride = 2
            else:
                stride = 2
            layers_conv.append(nn.Sequential(
                nn.Conv2d(in_channel, filter_num, (filter_size, 1),(stride,1)),#(2,1)
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(filter_num),
            ))
        
        self.layers_conv = nn.ModuleList(layers_conv)

        # PART2 , Cross Channel Fusion through Attention
        self.dropout = nn.Dropout(dropout)
        #self.sa = SelfAttention(filter_num, sa_div)
        self.channel_aggregation = Transformer(dim=filter_num, depth=1, heads=4, dim_head=16, mlp_dim=filter_num, dropout = 0.)



        self.rnn = nn.GRU(
            filter_num*input_shape[3],
            hidden_dim,
            2,
            bidirectional=False,
            dropout=0.15,
        )
        
        #self.temporal_aggregation = AggregationAttention(hidden_dim)

#         # PART 3 , Prediction 
        self.prediction = nn.Linear(hidden_dim, number_class)


    def forward(self, x):
        # B F , L C  F =1 or F =filter * scale + 1

        
        for layer in self.layers_conv:
            x = layer(x)      
        # B filter_num  L  C

        batch, filter, length, channel = x.shape
        # 每次进去的都是 B filter_num C 1
        refined = torch.cat(
            [self.channel_aggregation(torch.unsqueeze(x[:, :, t, :], dim=3)) for t in range(x.shape[2])],
            dim=-1,
        )
        # B F C L

        print(refined.shape)
        x = refined.permute(0, 3, 1, 2)

        x = x.reshape(x.shape[0], x.shape[1], -1)
        print(x.shape)
        # B L  F*C
        x = self.dropout(x)
        x = x.permute(1,0,2)
        print(x.shape)
        outputs, h = self.rnn(x)

        # L B F
        x = outputs[-1, :, :]
        print(outputs.shape)
        # B  F 

        y = self.prediction(x)    
        return y

In [30]:
Batch = 2
F_in  = 100
Leng  = 256
C_in = 6
number_class = 6
model = CFC_V2_Model((Batch,F_in,Leng, C_in),number_class,filter_num = 31).double()
input = torch.rand(Batch,F_in,Leng, C_in).double()
model(input)

torch.Size([2, 31, 6, 13])
torch.Size([2, 13, 186])
torch.Size([13, 2, 186])
torch.Size([13, 2, 32])


tensor([[ 0.0557, -0.4428, -0.1118, -0.0205,  0.1939, -0.1840],
        [ 0.1468,  0.0677, -0.0788,  0.1941,  0.1573,  0.0115]],
       dtype=torch.float64, grad_fn=<AddmmBackward>)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


from einops import rearrange, repeat
from einops.layers.torch import Rearrange

In [2]:
# -------------- Transformer Encoder -----------
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 4, dim_head = 16, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class AggregationAttention(nn.Module):
    """
    Temporal attention module
    """
    def __init__(self, hidden_dim):
        super(AggregationAttention, self).__init__()
        self.fc = nn.Linear(hidden_dim, 1)
        self.sm = torch.nn.Softmax(dim=0)

    def forward(self, x):
        out = self.fc(x).squeeze(2)
        weights_att = self.sm(out).unsqueeze(2)
        context = torch.sum(weights_att * x, 0)
        return context

class Transformer(nn.Module):
    def __init__(self, dim, depth=1, heads=4, dim_head=16, mlp_dim=16, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
        self.aggregation = AggregationAttention(dim)
    def forward(self, x):
        # B F C 1
        size = x.size()
		# --> B C F
        x = x.view(*size[:2], -1).permute(0,2,1)

        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        # --- B C F 
        x = self.aggregation(x.permute(1,0,2))
        return x.unsqueeze(2)

In [3]:
class CFC_V1_Model(nn.Module):
    def __init__(
        self,
        input_shape ,
        number_class , 
        filter_num = 32,
        filter_size = 5,
        nb_conv_layers = 4,
        dropout = 0.2,
        hidden_dim = 32,
        activation = "ReLU",
        sa_div= 1,
    ):
        super(CFC_V1_Model, self).__init__()
        
        # PART 1 , Channel wise Feature Extraction
        
        layers_conv = []
        for i in range(nb_conv_layers):
        
            if i == 0:
                in_channel = input_shape[1]
            else:
                in_channel = filter_num
            if i%2==0:
                stride = 2
            else:
                stride = 2
            layers_conv.append(nn.Sequential(
                nn.Conv2d(in_channel, filter_num, (filter_size, 1),(stride,1)),#(2,1)
                nn.ReLU(inplace=True),
                #nn.BatchNorm2d(filter_num),
            ))
        
        self.layers_conv = nn.ModuleList(layers_conv)

        # PART2 , Cross Channel Fusion through Attention
        self.dropout = nn.Dropout(dropout)
        #self.sa = SelfAttention(filter_num, sa_div)
        self.channel_aggregation = Transformer(dim=filter_num, depth=1, heads=4, dim_head=16, mlp_dim=filter_num, dropout = 0.)



        self.rnn = nn.GRU(
            filter_num,
            hidden_dim,
            2,
            bidirectional=False,
            dropout=0.15,
        )
        
        self.temporal_aggregation = AggregationAttention(hidden_dim)

#         # PART 3 , Prediction 
        self.prediction = nn.Linear(hidden_dim, number_class)


    def forward(self, x):
        # B F , L C  F =1 or F =filter * scale + 1

        
        for layer in self.layers_conv:
            x = layer(x)      
        # B filter_num  L  C

        batch, filter, length, channel = x.shape
        # 每次进去的都是 B filter_num C 1
        refined = torch.cat(
            [self.channel_aggregation(torch.unsqueeze(x[:, :, t, :], dim=3)) for t in range(x.shape[2])],
            dim=-1,
        )
        # B F L
        print(refined.shape)

        x = refined.permute(2,0,1)
        # L B F
        x = self.dropout(x)
        outputs, h = self.rnn(x)
	
        # L B F
        x = self.temporal_aggregation(outputs)
        # B  F 

        y = self.prediction(x)    
        return y

In [6]:
Batch = 2
F_in  = 100
Leng  = 128
C_in = 6
number_class = 6
model = CFC_V1_Model((Batch,F_in,Leng, C_in),number_class,filter_num = 32).double()
input = torch.rand(Batch,F_in,Leng, C_in).double()
model(input)

torch.Size([2, 32, 5])


tensor([[ 0.0680,  0.1128, -0.1493,  0.1220,  0.1361, -0.1117],
        [ 0.0524,  0.1450, -0.1422,  0.1411,  0.1466, -0.1003]],
       dtype=torch.float64, grad_fn=<AddmmBackward>)

In [7]:
import numpy as np
np.sum([para.numel() for para in model.parameters()])

54888

In [50]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def conv1d(ni: int, no: int, ks: int = 1, stride: int = 1, padding: int = 0, bias: bool = False):
    """
    Create and initialize a `nn.Conv1d` layer with spectral normalization.
    """
    conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias)
    nn.init.kaiming_normal_(conv.weight)
    if bias:
        conv.bias.data.zero_()
    # return spectral_norm(conv)
    return conv

class AggregationAttention(nn.Module):
    """
    Temporal attention module
    """
    def __init__(self, hidden_dim):
        super(AggregationAttention, self).__init__()
        self.fc = nn.Linear(hidden_dim, 1)
        self.sm = torch.nn.Softmax(dim=0)

    def forward(self, x):
        out = self.fc(x).squeeze(2)
        weights_att = self.sm(out).unsqueeze(2)
        context = torch.sum(weights_att * x, 0)
        return context

class SelfAttention(nn.Module):
    """

    """
    def __init__(self, n_channels: int, div):
        super(SelfAttention, self).__init__()

        if n_channels > 1:
            self.query = conv1d(n_channels, n_channels//div)
            self.key = conv1d(n_channels, n_channels//div)
        else:
            self.query = conv1d(n_channels, n_channels)
            self.key = conv1d(n_channels, n_channels)
        self.value = conv1d(n_channels, n_channels)
        self.gamma = nn.Parameter(torch.tensor([0.]))
        self.ta = AggregationAttention(n_channels)

    def forward(self, x):
        # Notation from https://arxiv.org/pdf/1805.08318.pdf
        #print(" 1 : ", x.shape)
        size = x.size()
        #print("size+",size)
        x = x.view(*size[:2], -1)
        #print(" 2 : ", x.shape)
        #print("size-",x.size())
        f, g, h = self.query(x), self.key(x), self.value(x)
        beta = F.softmax(torch.bmm(f.permute(0, 2, 1).contiguous(), g), dim=1)
        o = self.gamma * torch.bmm(h, beta) + x
        #print(" 3 : ", o.shape)
        ta = self.ta(o.permute(2,0,1))
        #print("5 :" ,ta.shape)
        #print(" 4 : ", o.view(*size).shape)
        return o.view(*size).contiguous(),ta.unsqueeze(2)

    
class CFC_Model(nn.Module):
    def __init__(
        self,
        input_shape ,
        number_class , 
        filter_num = 16,
        filter_size = 5,
        nb_conv_layers = 4,
        dropout = 0.2,
        activation = "ReLU",
        sa_div= 1,
    ):
        super(CFC_Model, self).__init__()
        
        # PART 1 , Channel wise Feature Extraction
        
        layers_conv = []
        for i in range(nb_conv_layers):
        
            if i == 0:
                in_channel = 1
            else:
                in_channel = filter_num
            if i%2==0:
                stride = 2
            else:
                stride = 1
            layers_conv.append(nn.Sequential(
                nn.Conv2d(in_channel, filter_num, (filter_size, 1),(stride,1)),#(2,1)
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(filter_num),

            ))
        
        self.layers_conv = nn.ModuleList(layers_conv)

        # PART2 , Cross Channel Fusion through Attention
        self.dropout = nn.Dropout(dropout)

        self.sa = SelfAttention(filter_num, sa_div)
        
#         shape = self.get_the_shape(input_shape)

#         # PART 3 , Prediction 
        
#         self.activation = nn.ReLU() 
#         self.fc1 = nn.Linear(input_shape[2]*filter_num ,filter_num)
#         self.flatten = nn.Flatten()
#         self.fc2 = nn.Linear(shape[1]*filter_num ,filter_num)
#         self.fc3 = nn.Linear(filter_num ,number_class)


        
    def get_the_shape(self, input_shape):
        x = torch.rand(input_shape)
        x = x.unsqueeze(1)
        for layer in self.layers_conv:
            x = layer(x)    
        atten_x = torch.cat(
            [self.sa(torch.unsqueeze(x[:, :, t, :], dim=3))[0] for t in range(x.shape[2])],
            dim=-1,
        )
        atten_x = atten_x.permute(0, 3, 1, 2)
        return atten_x.shape

    def forward(self, x):
        # B L C
        x = x.unsqueeze(1)
        print("begin:" , x.shape)
        
        for layer in self.layers_conv:
            x = layer(x)      


        batch, filter, length, channel = x.shape

        print("conv : ", x.shape)
        # apply self-attention on each temporal dimension (along sensor and feature dimensions)
#         refined = torch.cat(
#             [self.sa(torch.unsqueeze(x[:, :, t, :], dim=3))[0] for t in range(x.shape[2])],
#             dim=-1,
#         )
        refined = torch.cat(
            [self.sa(torch.unsqueeze(x[:, :, t, :], dim=3))[1] for t in range(x.shape[2])],
            dim=-1,
        )
        print("attn : ",refined.shape)
       # print(refined.shape)

#         x = refined.permute(0, 3, 1, 2)
#         x = x.reshape(x.shape[0], x.shape[1], -1)
#         print("reshape : ",x.shape)
#         x = self.dropout(x)
        
#         x = self.activation(self.fc1(x)) # B L C
#         x = self.flatten(x)
#         x = self.activation(self.fc2(x)) # B L C
#         y = self.fc3(x)    
        return None

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def conv1d(ni: int, no: int, ks: int = 1, stride: int = 1, padding: int = 0, bias: bool = False):
    """
    Create and initialize a `nn.Conv1d` layer with spectral normalization.
    """
    conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias)
    nn.init.kaiming_normal_(conv.weight)
    if bias:
        conv.bias.data.zero_()
    # return spectral_norm(conv)
    return conv

class AggregationAttention(nn.Module):
    """
    Temporal attention module
    """
    def __init__(self, hidden_dim):
        super(AggregationAttention, self).__init__()
        self.fc = nn.Linear(hidden_dim, 1)
        self.sm = torch.nn.Softmax(dim=0)

    def forward(self, x):
        print("x, ",x.shape)
        out = self.fc(x).squeeze(2)
        print("out, ",out.shape)
        weights_att = self.sm(out).unsqueeze(2)
        print("weights_att, ",weights_att.shape)
        context = torch.sum(weights_att * x, 0)
        print("context, ",context.shape)
        return context

class SelfAttention(nn.Module):
    """

    """
    def __init__(self, n_channels: int, div):
        super(SelfAttention, self).__init__()

        if n_channels > 1:
            self.query = conv1d(n_channels, n_channels//div)
            self.key = conv1d(n_channels, n_channels//div)
        else:
            self.query = conv1d(n_channels, n_channels)
            self.key = conv1d(n_channels, n_channels)
        self.value = conv1d(n_channels, n_channels)
        self.gamma = nn.Parameter(torch.tensor([0.]))
        self.ta = AggregationAttention(n_channels)

    def forward(self, x):
        # Notation from https://arxiv.org/pdf/1805.08318.pdf
        #print(" 1 : ", x.shape)
        size = x.size()
        #print("size+",size)
        x = x.view(*size[:2], -1)
        #print(" 2 : ", x.shape)
        #print("size-",x.size())
        f, g, h = self.query(x), self.key(x), self.value(x)
        beta = F.softmax(torch.bmm(f.permute(0, 2, 1).contiguous(), g), dim=1)
        o = self.gamma * torch.bmm(h, beta) + x
        print(" 3 : ", o.shape)
        ta = self.ta(o.permute(2,0,1))
        print("5 :" ,ta.shape)
        #print(" 4 : ", o.view(*size).shape)
        return o.view(*size).contiguous(),ta.unsqueeze(2)

    
class CFC_Model(nn.Module):
    def __init__(
        self,
        input_shape ,
        number_class , 
        filter_num = 16,
        filter_size = 5,
        nb_conv_layers = 4,
        dropout = 0.2,
        hidden_dim = 32,
        activation = "ReLU",
        sa_div= 1,
    ):
        super(CFC_Model, self).__init__()
        
        # PART 1 , Channel wise Feature Extraction
        
        layers_conv = []
        for i in range(nb_conv_layers):
        
            if i == 0:
                in_channel = 1
            else:
                in_channel = filter_num
            if i%2==0:
                stride = 2
            else:
                stride = 1
            layers_conv.append(nn.Sequential(
                nn.Conv2d(in_channel, filter_num, (filter_size, 1),(stride,1)),#(2,1)
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(filter_num),

            ))
        
        self.layers_conv = nn.ModuleList(layers_conv)

        # PART2 , Cross Channel Fusion through Attention
        self.dropout = nn.Dropout(dropout)

        self.sa = SelfAttention(filter_num, sa_div)
        self.rnn = nn.GRU(
            filter_num,
            hidden_dim,
            2,
            bidirectional=False,
            dropout=0.15,
        )
        
        self.ta = AggregationAttention(hidden_dim)
#         shape = self.get_the_shape(input_shape)

#         # PART 3 , Prediction 
        self.fc = nn.Linear(hidden_dim, number_class)
#         self.activation = nn.ReLU() 
#         self.fc1 = nn.Linear(input_shape[2]*filter_num ,filter_num)
#         self.flatten = nn.Flatten()
#         self.fc2 = nn.Linear(shape[1]*filter_num ,filter_num)
#         self.fc3 = nn.Linear(filter_num ,number_class)


        
    def get_the_shape(self, input_shape):
        x = torch.rand(input_shape)
        x = x.unsqueeze(1)
        for layer in self.layers_conv:
            x = layer(x)    
        atten_x = torch.cat(
            [self.sa(torch.unsqueeze(x[:, :, t, :], dim=3))[0] for t in range(x.shape[2])],
            dim=-1,
        )
        atten_x = atten_x.permute(0, 3, 1, 2)
        return atten_x.shape

    def forward(self, x):
        # B L C
        x = x.unsqueeze(1)
        print("begin:" , x.shape)
        
        for layer in self.layers_conv:
            x = layer(x)      


        batch, filter, length, channel = x.shape

        print("conv : ", x.shape)
        # apply self-attention on each temporal dimension (along sensor and feature dimensions)
#         refined = torch.cat(
#             [self.sa(torch.unsqueeze(x[:, :, t, :], dim=3))[0] for t in range(x.shape[2])],
#             dim=-1,
#         )
        refined = torch.cat(
            [self.sa(torch.unsqueeze(x[:, :, t, :], dim=3))[1] for t in range(x.shape[2])],
            dim=-1,
        )
        print("attn : ",refined.shape)
       # print(refined.shape)

        x = refined.permute(2,0,1)

#         print("reshape : ",x.shape)
        x = self.dropout(x)
        outputs, h = self.rnn(x)
        x = self.ta(outputs)
        print("out : ",x.shape)
#         x = self.activation(self.fc1(x)) # B L C
#         x = self.flatten(x)
#         x = self.activation(self.fc2(x)) # B L C
        y = self.fc(x)    
        return y

In [12]:



model = CFC_Model((2,256,600),6,filter_num = 16).double()
input = torch.rand(2,256,600).double()
model(input)

begin: torch.Size([2, 1, 256, 600])
conv :  torch.Size([2, 16, 55, 600])
 3 :  torch.Size([2, 16, 600])
x,  torch.Size([600, 2, 16])
out,  torch.Size([600, 2])
weights_att,  torch.Size([600, 2, 1])
context,  torch.Size([2, 16])
5 : torch.Size([2, 16])
 3 :  torch.Size([2, 16, 600])
x,  torch.Size([600, 2, 16])
out,  torch.Size([600, 2])
weights_att,  torch.Size([600, 2, 1])
context,  torch.Size([2, 16])
5 : torch.Size([2, 16])
 3 :  torch.Size([2, 16, 600])
x,  torch.Size([600, 2, 16])
out,  torch.Size([600, 2])
weights_att,  torch.Size([600, 2, 1])
context,  torch.Size([2, 16])
5 : torch.Size([2, 16])
 3 :  torch.Size([2, 16, 600])
x,  torch.Size([600, 2, 16])
out,  torch.Size([600, 2])
weights_att,  torch.Size([600, 2, 1])
context,  torch.Size([2, 16])
5 : torch.Size([2, 16])
 3 :  torch.Size([2, 16, 600])
x,  torch.Size([600, 2, 16])
out,  torch.Size([600, 2])
weights_att,  torch.Size([600, 2, 1])
context,  torch.Size([2, 16])
5 : torch.Size([2, 16])
 3 :  torch.Size([2, 16, 600])
x

tensor([[ 0.1087, -0.0541, -0.0789, -0.1395, -0.0465,  0.0397],
        [ 0.1144, -0.0556, -0.0767, -0.1404, -0.0492,  0.0441]],
       dtype=torch.float64, grad_fn=<AddmmBackward>)

In [65]:
import numpy as np
np.sum([para.numel() for para in model.parameters()])

16265

In [None]:
class CFC_New(nn.Module):
    def __init__(
        self,
        input_shape ,
        number_class , 
        filter_num = 16,
        filter_size = 5,
        nb_conv_layers = 4,
        dropout = 0.2,
        activation = "ReLU",
        sa_div= 1,
    ):
        super(CFC_New, self).__init__()
        
        # PART 1 , Channel wise Feature Extraction
        
        layers_conv = []
        for i in range(nb_conv_layers):
        
            if i == 0:
                in_channel = 1
            else:
                in_channel = filter_num
    
            layers_conv.append(nn.Sequential(
                nn.Conv2d(in_channel, filter_num, (filter_size, 1),(2,1)),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(filter_num),

            ))
        
        self.layers_conv = nn.ModuleList(layers_conv)

        # PART2 , Cross Channel Fusion through Attention
        self.dropout = nn.Dropout(dropout)

        self.sa = SelfAttention(filter_num, sa_div)
        
        shape = self.get_the_shape(input_shape)

        # PART 3 , Prediction 
        
        self.activation = nn.ReLU() 
        self.fc1 = nn.Linear(input_shape[2]*filter_num ,filter_num)
        self.flatten = nn.Flatten()
        self.fc2 = nn.Linear(shape[1]*filter_num ,filter_num)
        self.fc3 = nn.Linear(filter_num ,number_class)


        
    def get_the_shape(self, input_shape):
        x = torch.rand(input_shape)
        x = x.unsqueeze(1)
        for layer in self.layers_conv:
            x = layer(x)    
        atten_x = torch.cat(
            [self.sa(torch.unsqueeze(x[:, :, t, :], dim=3)) for t in range(x.shape[2])],
            dim=-1,
        )
        atten_x = atten_x.permute(0, 3, 1, 2)
        return atten_x.shape

    def forward(self, x):
        # B L C
        x = x.unsqueeze(1)
        
        
        for layer in self.layers_conv:
            x = layer(x)      


        batch, filter, length, channel = x.shape


        # apply self-attention on each temporal dimension (along sensor and feature dimensions)
        refined = torch.cat(
            [self.sa(torch.unsqueeze(x[:, :, t, :], dim=3)) for t in range(x.shape[2])],
            dim=-1,
        )


       # print(refined.shape)

        x = refined.permute(0, 3, 1, 2)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = self.dropout(x)
        
        x = self.activation(self.fc1(x)) # B L C
        x = self.flatten(x)
        x = self.activation(self.fc2(x)) # B L C
        y = self.fc3(x)    
        return y