In [1]:
from __future__ import absolute_import, print_function
import torch
from torch import nn
import math
from torch.nn.parameter import Parameter
from torch.nn import functional as F
import numpy as np


In [None]:
# MemModule1_new  h*w h/window*w/window
# MemModule_window  widow*window*c  (widow*window*c)/c_size       window=2,c_size=2

In [114]:
#最终版
class MemModule1_new(nn.Module):
    def __init__(self,mem_dim,fea_dim,window,shrink_thres=0.0025, device='cuda'):
        super(MemModule1_new, self).__init__()
        self.mem_dim = mem_dim
        self.fea_dim = fea_dim
        self.shrink_thres = shrink_thres
        self.memory = MemoryUnit(self.mem_dim, self.fea_dim, self.shrink_thres)
        self.window=window

    def forward(self, x):
        s = x.data.shape
        x = x.view(s[1]*s[0]*self.window*self.window, -1)
        y_and = self.memory(x)
        y = y_and['output']
#         print('y',y.shape)
        att = y_and['att']
#         print('att',att.shape)
        
        y = y.view(s[0], s[1], s[2], s[3])
        att = att.view(s[0]*s[1],self.window,self.window,self.mem_dim)
        att = att.permute(0, 3, 1, 2)
        print('att',att.shape)
        
#         att = att.view(s[0]* s[1],self.window, self.window, self.mem_dim)
        return {'output': y, 'att': att}


In [116]:
# fea_dim=h/window*w/window
model=MemModule1_new(mem_dim=2000,fea_dim=256,window=2)
x=torch.rand((2,256,32,32))
out=model(x)
out['output'].shape,out['att'].shape,

att torch.Size([512, 2000, 2, 2])


(torch.Size([2, 256, 32, 32]), torch.Size([512, 2000, 2, 2]))

In [68]:
class MemModule1(nn.Module):
    def __init__(self, mem_dim, fea_dim, shrink_thres=0.0025, device='cuda'):
        super(MemModule1, self).__init__()
        self.mem_dim = mem_dim
        self.fea_dim = fea_dim
        self.shrink_thres = shrink_thres
        self.memory = MemoryUnit(self.mem_dim, self.fea_dim, self.shrink_thres)

    def forward(self, x):
        s = x.data.shape
#         x = input.permute(0, 2, 3, 1).contiguous()
#         print(x.shape)
        x = x.view(s[1]*s[0], -1)
#         print(x.shape)
#         print("self.fea_dim.shape",self.fea_dim)
        #
        y_and = self.memory(x)
        
        #
        y = y_and['output']
        print('y',y.shape)
        att = y_and['att']
        print('att',att.shape)
        
        y = y.view(s[0], s[1], s[2], s[3])
#         print( y.shape)
#         y = y.permute(0, 3, 1, 2)
#         print( y.shape)
        att = att.view(s[0],s[1],self.mem_dim)
#         print( att.shape)
#         att = att.permute(0, 3, 1, 2)
#         print( att.shape)

        return {'output': y, 'att': att}


In [69]:
model=MemModule1(mem_dim=2000,fea_dim=1024)
x=torch.rand((2,256,32,32))
out=model(x)
out['output'].shape,out['att'].shape,


y torch.Size([512, 1024])
att torch.Size([512, 2000])


(torch.Size([2, 256, 32, 32]), torch.Size([2, 256, 2000]))

In [2]:
def hard_shrink_relu(input, lambd=0, epsilon=1e-12):
    output = (F.relu(input-lambd) * input) / (torch.abs(input - lambd) + epsilon)
    return output


class MemoryUnit(nn.Module):
    def __init__(self, mem_dim, fea_dim, shrink_thres=0.0025):
        super(MemoryUnit, self).__init__()
        self.mem_dim = mem_dim
        self.fea_dim = fea_dim
        self.weight = Parameter(torch.Tensor(self.mem_dim, self.fea_dim))  # M x C
        self.bias = None
        self.shrink_thres= shrink_thres
        # self.hard_sparse_shrink_opt = nn.Hardshrink(lambd=shrink_thres)

        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input):
#         print(input.shape)
#         print(self.weight.shape)
        
        att_weight = F.linear(input, self.weight)  # Fea x Mem^T, (TxC) x (CxM) = TxM
        att_weight = F.softmax(att_weight, dim=1)  # TxM
        # ReLU based shrinkage, hard shrinkage for positive value
        if(self.shrink_thres>0):
            att_weight = hard_shrink_relu(att_weight, lambd=self.shrink_thres)
#             att_weight = F.softshrink(att_weight, lambd=self.shrink_thres)
            # normalize???
            att_weight = F.normalize(att_weight, p=1, dim=1)
            # att_weight = F.softmax(att_weight, dim=1)
            # att_weight = self.hard_sparse_shrink_opt(att_weight)
        mem_trans = self.weight.permute(1, 0)  # Mem^T, MxC
        output = F.linear(att_weight, mem_trans)  # AttWeight x Mem^T^T = AW x Mem, (TxM) x (MxC) = TxC
        return {'output': output, 'att': att_weight}  # output, att_weight

    def extra_repr(self):
        return 'mem_dim={}, fea_dim={}'.format(
            self.mem_dim, self.fea_dim is not None
        )


In [50]:
m=MemoryUnit(mem_dim=2000,fea_dim=1024, shrink_thres=0.0025)
x=torch.rand((512,1024))
out=m(x)

In [51]:
out['output'].shape,out['att'].shape,

(torch.Size([512, 1024]), torch.Size([512, 2000]))

In [61]:
class MemModule_w_new(nn.Module):
    def __init__(self, mem_dim, fea_dim, window,shrink_thres=0.0025, device='cuda'):
        super(MemModule_w_new, self).__init__()
        self.mem_dim = mem_dim
        self.fea_dim = fea_dim
        self.shrink_thres = shrink_thres
        self.memory = MemoryUnit(self.mem_dim, self.fea_dim, self.shrink_thres)
        self.window=window
    def forward(self, input1):
        s = input1.data.shape
        x = input1.permute(0, 2, 3, 1).contiguous()
        b,h,w,c=x.shape
        num_window=int((h/self.window)*(w/self.window))
        x=x.view(b*num_window,self.window,self.window,c)
        x=x.view(x.size(0), -1)
        
        y_and = self.memory(x)
        
        y = y_and['output']
        att = y_and['att']
        y = y.view(s[0], s[2], s[3], s[1])
        y = y.permute(0, 3, 1, 2)
        att = att.view(s[0], int(h/self.window), int(w/self.window), self.mem_dim)
        att = att.permute(0, 3, 1, 2)
#         print("att.shape",att.shape)
        return {'output': y, 'att': att}


In [62]:
window=2
fea_dim=window*window*c
model1=MemModule(mem_dim=2000, fea_dim=1024,window=2)
x1=torch.rand((2,256,32,32))
out1=model1(x1)

torch.Size([512, 1024])
torch.Size([512, 1024])
att.shape torch.Size([512, 2000])
torch.Size([2, 256, 32, 32])
att.shape torch.Size([2, 2000, 16, 16])


In [120]:
#最终版
class MemModule_w_new(nn.Module):
    def __init__(self, mem_dim, fea_dim, window,c_size=1,shrink_thres=0.0025, device='cuda'):
        super(MemModule_w_new, self).__init__()
        self.mem_dim = mem_dim
        self.fea_dim = fea_dim
        self.shrink_thres = shrink_thres
        self.memory = MemoryUnit(self.mem_dim, self.fea_dim, self.shrink_thres)
        self.window=window
        self.c_size=c_size
    def forward(self, input1):
        if self.c_size==1:
            s = input1.data.shape
            x = input1.permute(0, 2, 3, 1).contiguous()
            b,h,w,c=x.shape
            num_window=int((h/self.window)*(w/self.window))
            x=x.view(b*num_window,self.window,self.window,c)
            x=x.view(x.size(0), -1)

            y_and = self.memory(x)

            y = y_and['output']
            att = y_and['att']
            y = y.view(s[0], s[2], s[3], s[1])
            y = y.permute(0, 3, 1, 2)
            att = att.view(s[0], int(h/self.window), int(w/self.window), self.mem_dim)
            att = att.permute(0, 3, 1, 2)
    #         print("att.shape",att.shape)
            return {'output': y, 'att': att}
        else:
            B,C,H,W=input1.shape
            input1=input1.view(B*self.c_size,-1,H,W)
            s = input1.data.shape
            x = input1.permute(0, 2, 3, 1).contiguous()
            b,h,w,c=x.shape
            num_window=int((h/self.window)*(w/self.window))
            x=x.view(b*num_window,self.window,self.window,c)
            x=x.view(x.size(0), -1)
            y_and = self.memory(x)
            y = y_and['output']
#             print(y.shape)
            att = y_and['att']
            y = y.view(s[0], s[2], s[3], s[1])
            y = y.permute(0, 3, 1, 2).contiguous()
            y = y.view(B,C,H,W)
            print(y.shape)
            
#             print(att.shape)
            
            att = att.view(B*self.c_size,int(h/self.window),int(w/self.window), self.mem_dim)
            att = att.permute(0, 3, 1, 2)
            print(att.shape)
            return {'output': y, 'att': att}

In [122]:
# window=2
# fea_dim=window*window*c
model1=MemModule_w_new(mem_dim=2000, fea_dim=1024,window=4,c_size=4)
x1=torch.rand((2,256,32,32))
out1=model1(x1)

torch.Size([2, 256, 32, 32])
torch.Size([8, 2000, 8, 8])


In [20]:
x1=torch.rand((1,256,32,32))
x=Flatten(x1,window=2)

torch.Size([2048, 2, 2, 32])
torch.Size([2048, 128])


In [55]:
class MemModule_ori(nn.Module):
    def __init__(self, mem_dim, fea_dim, shrink_thres=0.0025, device='cuda'):
        super(MemModule_ori, self).__init__()
        self.mem_dim = mem_dim
        self.fea_dim = fea_dim
        self.shrink_thres = shrink_thres
        self.memory = MemoryUnit(self.mem_dim, self.fea_dim, self.shrink_thres)

    def forward(self, input):
        s = input.data.shape
        x = input.permute(0, 2, 3, 1)
        print(x.shape)
            
        x = x.contiguous()
        print(x.shape)
        x = x.view(-1, s[1])
        print(x.shape)
        #
        y_and = self.memory(x)
        
        #
        y = y_and['output']
        print('y',y.shape)
        att = y_and['att']
        print('att',att.shape)
        
        y = y.view(s[0], s[2], s[3], s[1])
#         print( y.shape)
        y = y.permute(0, 3, 1, 2)
#         print( y.shape)
        att = att.view(s[0], s[2], s[3], self.mem_dim)
#         print( att.shape)
        att = att.permute(0, 3, 1, 2)
#         print( att.shape)

        return {'output': y, 'att': att}



In [56]:
model2=MemModule_ori(mem_dim=2000, fea_dim=256)
x2=torch.rand((2,256,32,32))
out2=model2(x2)

torch.Size([2, 32, 32, 256])
torch.Size([2, 32, 32, 256])
torch.Size([2048, 256])
y torch.Size([2048, 256])
att torch.Size([2048, 2000])


In [54]:
out2['output'].shape,out2['att'].shape,

(torch.Size([2, 256, 32, 32]), torch.Size([2, 2000, 32, 32]))

In [None]:
def window_partition_c(x, window_size,c_size):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C//c_size,c_size)
    windows = x.permute(0, 1, 3,5, 2, 4, 6).contiguous().view(-1, window_size, window_size, c_size)
    return windows


def window_reverse_c(windows, window_size,c_size,B, H, W,C):
    x = windows.view(B, H // window_size, W // window_size, C // c_size, window_size, window_size,  c_size)
    x = x.permute(0, 1, 4, 2, 5, 3,6).contiguous().view(B, H, W, C)
    return x

In [1]:
def window_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    print(x.shape)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


In [2]:
x=torch.rand((2,32,32,256))
windows=window_partition(x,2)

NameError: name 'torch' is not defined

In [10]:
windows.shape

torch.Size([4096, 2, 2, 32])

In [11]:
from torch import nn
import torch
import math
 
class test(nn.Module):
    def __init__(self):
        super(test, self).__init__()
#         self.up = nn.ConvTranspose2d(in_channels=1024, out_channels=512, 
#                                  kernel_size=2, stride=2)
        self.down0 = nn.MaxPool2d(kernel_size=2, stride=2)
    def forward(self, input):
        x = self.down0(input)
        
        print(x.shape)
        
        return x

In [12]:
model=test()
y1=torch.rand((2,1024,32,32))
res=model(y1)

torch.Size([2, 1024, 16, 16])


In [5]:
from torch import nn
import torch
import math
 
class Flatten(nn.Module):
    def forward(self, input,window):
        s = input.data.shape
        x = input.permute(0, 2, 3, 1).contiguous()
        b,h,w,c=x.shape
        num_window=int((h/window)*(w/window))
        x=x.view(b*num_window,window,window,c)
        print(x.shape)
        
        x=x.view(x.size(0), -1)
        
        print(x.shape)
        
        return x

In [7]:
y=Flatten()
y1=torch.rand((2,256,32,32))
y2=y(y1,2)
y3=y(y1,4)

torch.Size([512, 2, 2, 256])
torch.Size([512, 1024])
torch.Size([128, 4, 4, 256])
torch.Size([128, 4096])


In [22]:
y2.shape

torch.Size([512, 1024])

In [4]:
a = torch.rand(2,2,4,4)

In [5]:
a

tensor([[[[0.8896, 0.8661, 0.4522, 0.0104],
          [0.5813, 0.4983, 0.4351, 0.7579],
          [0.3916, 0.9485, 0.3603, 0.1219],
          [0.2412, 0.6517, 0.1153, 0.8062]],

         [[0.3015, 0.7950, 0.1755, 0.8463],
          [0.3675, 0.7933, 0.8266, 0.1903],
          [0.1023, 0.0681, 0.2987, 0.1832],
          [0.7480, 0.1329, 0.4124, 0.9361]]],


        [[[0.0334, 0.8408, 0.6470, 0.8342],
          [0.9630, 0.2236, 0.4736, 0.0899],
          [0.6655, 0.0656, 0.5195, 0.2043],
          [0.3681, 0.4142, 0.2742, 0.6806]],

         [[0.6196, 0.1570, 0.0636, 0.9813],
          [0.2829, 0.5748, 0.5487, 0.7993],
          [0.0043, 0.8250, 0.7426, 0.7172],
          [0.3696, 0.5365, 0.0301, 0.0208]]]])

In [6]:
a[:,:,:2,:2]

tensor([[[[0.8896, 0.8661],
          [0.5813, 0.4983]],

         [[0.3015, 0.7950],
          [0.3675, 0.7933]]],


        [[[0.0334, 0.8408],
          [0.9630, 0.2236]],

         [[0.6196, 0.1570],
          [0.2829, 0.5748]]]])

In [2]:
def hard_shrink_relu(input, lambd=0, epsilon=1e-12):
    output = (F.relu(input-lambd) * input) / (torch.abs(input - lambd) + epsilon)
    return output


class MemoryUnit(nn.Module):
    def __init__(self, mem_dim, fea_dim, shrink_thres=0.0025):
        super(MemoryUnit, self).__init__()
        self.mem_dim = mem_dim
        self.fea_dim = fea_dim
        self.weight = Parameter(torch.Tensor(self.mem_dim, self.fea_dim))  # M x C
        self.bias = None
        self.shrink_thres= shrink_thres
        # self.hard_sparse_shrink_opt = nn.Hardshrink(lambd=shrink_thres)

        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input):
        att_weight = F.linear(input, self.weight)  # Fea x Mem^T, (TxC) x (CxM) = TxM
        att_weight = F.softmax(att_weight, dim=1)  # TxM
        # ReLU based shrinkage, hard shrinkage for positive value
        if(self.shrink_thres>0):
            att_weight = hard_shrink_relu(att_weight, lambd=self.shrink_thres)
#             att_weight = F.softshrink(att_weight, lambd=self.shrink_thres)
            # normalize???
            att_weight = F.normalize(att_weight, p=1, dim=1)
            # att_weight = F.softmax(att_weight, dim=1)
            # att_weight = self.hard_sparse_shrink_opt(att_weight)
        mem_trans = self.weight.permute(1, 0)  # Mem^T, MxC
        output = F.linear(att_weight, mem_trans)  # AttWeight x Mem^T^T = AW x Mem, (TxM) x (MxC) = TxC
        return {'output': output, 'att': att_weight}  # output, att_weight

    def extra_repr(self):
        return 'mem_dim={}, fea_dim={}'.format(
            self.mem_dim, self.fea_dim is not None
        )

In [None]:
# NxCxHxW -> (NxHxW)xC -> addressing Mem, (NxHxW)xC -> NxCxHxW
class MemModule1(nn.Module):
    def __init__(self, mem_dim, fea_dim, shrink_thres=0.0025, device='cuda'):
        super(MemModule1, self).__init__()
        self.mem_dim = mem_dim
        self.fea_dim = fea_dim
        self.shrink_thres = shrink_thres
        self.memory = MemoryUnit(self.mem_dim, self.fea_dim, self.shrink_thres)

    def forward(self, input):
        s = input.data.shape
        x = input.permute(0, 2, 3, 1).contiguous()
        print(x.shape)
        x = x.view(-1, s[1])
        print(x.shape)
        #
        y_and = self.memory(x)
        
        #
        y = y_and['output']
#         print('y',y.shape)
        att = y_and['att']
        print('att',att.shape)
        
        y = y.view(s[0], s[2], s[3], s[1])
#         print( y.shape)
        y = y.permute(0, 3, 1, 2)
#         print( y.shape)
        att = att.view(s[0], s[2], s[3], self.mem_dim)
        print( att.shape)
        att = att.permute(0, 3, 1, 2)
        print( att.shape)

        return {'output': y, 'att': att}

In [None]:
model1=MemModule1(mem_dim=512, fea_dim=256)
x1=torch.rand((1,256,32,32))
out1=model1(x1)

In [None]:
out1['output'].shape

In [None]:
out1['att'].shape

In [None]:
# NxCxHxW -> (NxHxW)xC -> addressing Mem, (NxHxW)xC -> NxCxHxW
class MemModule_s(nn.Module):
    def __init__(self, mem_dim, fea_dim, shrink_thres=0.0025, device='cuda'):
        super(MemModule_s, self).__init__()
        self.mem_dim = mem_dim
        self.fea_dim = fea_dim
        self.shrink_thres = shrink_thres
        self.memory = MemoryUnit(self.mem_dim, self.fea_dim, self.shrink_thres)

    def forward(self, x):
        s = x.data.shape
#         x = input.permute(0, 2, 3, 1).contiguous()
        print(x.shape)
        x = x.view(s[1]*s[0], -1)
        print(x.shape)
        #
        y_and = self.memory(x)
        
        #
        y = y_and['output']
#         print('y',y.shape)
        att = y_and['att']
        print('att',att.shape)
        
        y = y.view(s[0], s[1], s[2], s[3])
#         print( y.shape)
#         y = y.permute(0, 3, 1, 2)
#         print( y.shape)
        att = att.view(s[0],s[1],self.mem_dim)
#         print( att.shape)
#         att = att.permute(0, 3, 1, 2)
#         print( att.shape)

        return {'output': y, 'att': att}

In [None]:
model=MemModule_s(mem_dim=512, fea_dim=64)
x=torch.rand((32,1024,8,8))
out=model(x)

In [None]:
model=MemModule(mem_dim=512, fea_dim=64)
x=torch.rand((1,1024,8,8))
out=model(x)

In [None]:
out['output'].shape

In [None]:
out['att'].shape

In [5]:
def window_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

In [6]:
y=torch.rand((256, 100))
y=window_reverse(y,2,32,32)
print(y.shape)

torch.Size([1, 32, 32, 25])


In [10]:
# NxCxHxW -> (NxHxW)xC -> addressing Mem, (NxHxW)xC -> NxCxHxW
class MemModule_w(nn.Module):
    def __init__(self, mem_dim, fea_dim, shrink_thres=0.0025,window_size=4, device='cuda'):
        super(MemModule_w, self).__init__()
        self.mem_dim = mem_dim
        self.fea_dim = fea_dim
        self.shrink_thres = shrink_thres
        self.memory = MemoryUnit(self.mem_dim, self.fea_dim, self.shrink_thres)
        self.window_size=window_size
    def forward(self, input):
        s = input.data.shape
#         print(s)
        x = input.permute(0, 2, 3, 1).contiguous()
#         print(x.shape)
        x=window_partition(x, self.window_size)###########
#         print(x.shape)
        b,w,w,c=x.shape
        x = x.view(b, -1)
#         x = x.view(-1, s[1])
#         print(x.shape)
        #
        y_and = self.memory(x)
        y = y_and['output']
#         print('y',y.shape)
        att = y_and['att']
#         print('att',att.shape)
        y=window_reverse(y,self.window_size,s[2], s[3])########
#         print( y.shape)
        y = y.permute(0, 3, 1, 2)
#         print( y.shape)
        att=window_reverse(att,self.window_size,s[2], s[3])
        att = att.permute(0, 3, 1, 2)
#         print( att.shape)

        return {'output': y, 'att': att}

In [11]:
model1=MemModule_w(mem_dim=200, fea_dim=1024)
x1=torch.rand((1,64,64,64))
out1=model1(x1)

RuntimeError: shape '[1, 16, 16, 4, 4, -1]' is invalid for input of size 51200

In [None]:
model1=MemModule1(mem_dim=100, fea_dim=256)
x1=torch.rand((1,256,16,16))
out1=model1(x1)