Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Addition of More Pooling Methods #2048

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions timm/layers/gen_maxpool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import torch
import torch.nn as nn

class GeneralizedMP(nn.Module):
"""
Implements Generalized Max Pooling (GMP), a global pooling operation that
generalizes the concept of max pooling to capture more complex and discriminative
features from the input tensor.

The class operates by computing a linear kernel based on the input tensor,
then solving a linear system to obtain the pooling coefficients. These coefficients
are used to weigh and aggregate the input features, resulting in a pooled feature vector.

Parameters:
lamb (float, optional): A regularization parameter used in the linear system
to ensure numerical stability. Default value is 1e3.

Note:
- The input tensor is expected to be in the format (B, D, H, W), where B is batch size,
D is depth (channels), H is height, and W is width.
- The implementation uses PyTorch's linear algebra functions to solve the linear system.
"""
def __init__(self, lamb = 1e3):
super().__init__()
self.lamb = nn.Parameter(lamb * torch.ones(1))
#self.inv_lamb = nn.Parameter((1./lamb) * torch.ones(1))

def forward(self, x):
B, D, H, W = x.shape
N = H * W
identity = torch.eye(N).cuda()
# reshape x, s.t. we can use the gmp formulation as a global pooling operation
x = x.view(B, D, N)
x = x.permute(0, 2, 1)
# compute the linear kernel
K = torch.bmm(x, x.permute(0, 2, 1))
# solve the linear system (K + lambda * I) * alpha = ones
A = K + self.lamb * identity
o = torch.ones(B, N, 1).cuda()
#alphas, _ = torch.gesv(o, A) # tested using pytorch 1.0.1
alphas = torch.linalg.solve(A,o) # TODO check it again
alphas = alphas.view(B, 1, -1)
xi = torch.bmm(alphas, x)
xi = xi.view(B, -1)
return xi
97 changes: 97 additions & 0 deletions timm/layers/how_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

class HOWPooling(nn.Module):
"""
Implements HOW, as described in the paper
'Learning and Aggregating Deep Local Descriptors for Instance-Level Recognition'.
This pooling method focuses on aggregating deep local descriptors
for enhanced instance-level recognition.

The class includes functions for L2-based attention, smoothing average pooling,
L2 normalization (l2n), and a forward method that integrates these components.
It applies dimensionality reduction to the input features before the pooling operation.

Parameters:
input_dim (int): Dimension of the input features.
dim_reduction (int): Target dimension after reduction.
kernel_size (int): Size of the kernel used in smoothing average pooling.
"""
def __init__(self, input_dim = 512, dim_reduction = 128, kernel_size = 3):
super(HOWPooling, self).__init__()
self.kernel_size = kernel_size
self.dimreduction = ConvDimReduction(input_dim, dim_reduction)

def L2Attention(self, x):
return (x.pow(2.0).sum(1) + 1e-10).sqrt().squeeze(0)

def smoothing_avg_pooling(self, feats):
"""Smoothing average pooling
:param torch.Tensor feats: Feature map
:param int kernel_size: kernel size of pooling
:return torch.Tensor: Smoothend feature map
"""
pad = self.kernel_size // 2
return F.avg_pool2d(feats, (self.kernel_size, self.kernel_size), stride=1, padding=pad,
count_include_pad=False)

def l2n(self, x, eps=1e-6):
return x / (torch.norm(x, p=2, dim=1, keepdim=True) + eps).expand_as(x)

def forward(self, x):

weights = self.L2Attention(x)
x = self.smoothing_avg_pooling(x)
x = self.dimreduction(x)
x = (x * weights.unsqueeze(1)).sum((-2, -1))
return self.l2n(x)

class ConvDimReduction(nn.Conv2d):
"""
Implements dimensionality reduction using a convolutional layer. This layer is
designed for reducing the dimensions of input features, particularly for use in
aggregation and pooling operations like in the HOWPooling class.

The class also includes methods for learning and applying PCA whitening with shrinkage,
which is a technique to reduce dimensionality while preserving important feature variations.

Parameters:
input_dim (int): The input dimension (number of channels) of the network.
dim (int): The target output dimension for the whitening process.
"""
def __init__(self, input_dim, dim):
super().__init__(input_dim, dim, (1, 1), padding=0, bias=True)

def pcawhitenlearn_shrinkage(X, s=1.0):
"""Learn PCA whitening with shrinkage from given descriptors"""
N = X.shape[0]

# Learning PCA w/o annotations
m = X.mean(axis=0, keepdims=True)
Xc = X - m
Xcov = np.dot(Xc.T, Xc)
Xcov = (Xcov + Xcov.T) / (2*N)
eigval, eigvec = np.linalg.eig(Xcov)
order = eigval.argsort()[::-1]
eigval = eigval[order]
eigvec = eigvec[:, order]

eigval = np.clip(eigval, a_min=1e-14, a_max=None)
P = np.dot(np.linalg.inv(np.diag(np.power(eigval, 0.5*s))), eigvec.T)

return m, P.T

def initialize_pca_whitening(self, des):
"""Initialize PCA whitening from given descriptors. Return tuple of shift and projection."""
m, P = self.pcawhitenlearn_shrinkage(des)
m, P = m.T, P.T

projection = torch.Tensor(P[:self.weight.shape[0], :]).unsqueeze(-1).unsqueeze(-1)
self.weight.data = projection.to(self.weight.device)

projected_shift = -torch.mm(torch.FloatTensor(P), torch.FloatTensor(m)).squeeze()
self.bias.data = projected_shift[:self.weight.shape[0]].to(self.bias.device)
return m.T, P.T
36 changes: 36 additions & 0 deletions timm/layers/lse_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class LSEPool(nn.Module):
"""
Implements LogSumExp (LSE) pooling, an advanced pooling technique that provides
a smooth approximation to the max pooling operation. This pooling method is useful
for capturing the global distribution of features across spatial dimensions (height and width)
of the input tensor, while still maintaining differentiability.

The class supports learnable pooling behavior with an optional learnable parameter 'r'.
When 'r' is large, LSE pooling closely approximates max pooling, and when 'r' is small,
it behaves more like average pooling. The 'r' parameter can either be a fixed value or
learned during training.

Parameters:
r (float, optional): The initial value of the pooling parameter. Default is 10.
learnable (bool, optional): If True, 'r' is a learnable parameter. Default is True.
"""

def __init__(self, r=10, learnable=True):
super(LSEPool, self).__init__()
if learnable:
self.r = nn.Parameter(torch.ones(1) * r)
else:
self.r = r

def forward(self, x):
s = (x.size(2) * x.size(3))
x_max = F.adaptive_max_pool2d(x, 1)
exp = torch.exp(self.r * (x - x_max))
sumexp = 1 / s * torch.sum(exp, dim=(2, 3))
sumexp = sumexp.view(sumexp.size(0), -1, 1, 1)
logsumexp = x_max + 1 / self.r * torch.log(sumexp)
return logsumexp
101 changes: 101 additions & 0 deletions timm/layers/simpool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import torch
import torch.nn as nn

class SimPool(nn.Module):
"""
Implements SimPool as described in the ICCV 2023 paper
"Keep It SimPool: Who Said Supervised Transformers Suffer from Attention Deficit?".
This class is designed to provide an efficient and effective pooling strategy
for both Transformer and CNN architectures.

SimPool applies a global average pooling (GAP) operation as an initial step
and then utilizes a simple but powerful attention mechanism to refine the pooled features.
The attention mechanism uses linear transformations for queries and keys, followed by
softmax normalization to compute attention scores.

Parameters:
dim (int): Dimension of the input features.
num_heads (int, optional): Number of attention heads. Default is 1.
qkv_bias (bool, optional): If True, adds a learnable bias to query, key, value projections. Default is False.
qk_scale (float, optional): Scaling factor for query-key dot product. Default is None, which uses the inverse square root of head dimensions.
gamma (float, optional): Scaling parameter for value vectors, used if not None. Default is None.
use_beta (bool, optional): If True, adds a learnable translation to the value vectors after applying gamma. Default is False.
"""
def __init__(self, dim, num_heads=1, qkv_bias=False, qk_scale=None, gamma=None, use_beta=False):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5

self.norm_patches = nn.LayerNorm(dim, eps=1e-6)

self.wq = nn.Linear(dim, dim, bias=qkv_bias)
self.wk = nn.Linear(dim, dim, bias=qkv_bias)

if gamma is not None:
self.gamma = torch.tensor([gamma], device='cuda')
if use_beta:
self.beta = nn.Parameter(torch.tensor([0.0], device='cuda'))
self.eps = torch.tensor([1e-6], device='cuda')

self.gamma = gamma
self.use_beta = use_beta

def prepare_input(self, x):
if len(x.shape) == 3: # Transformer
# Input tensor dimensions:
# x: (B, N, d), where B is batch size, N are patch tokens, d is depth (channels)
B, N, d = x.shape
gap_cls = x.mean(-2) # (B, N, d) -> (B, d)
gap_cls = gap_cls.unsqueeze(1) # (B, d) -> (B, 1, d)
return gap_cls, x
if len(x.shape) == 4: # CNN
# Input tensor dimensions:
# x: (B, d, H, W), where B is batch size, d is depth (channels), H is height, and W is width
B, d, H, W = x.shape
gap_cls = x.mean([-2, -1]) # (B, d, H, W) -> (B, d)
x = x.reshape(B, d, H*W).permute(0, 2, 1) # (B, d, H, W) -> (B, d, H*W) -> (B, H*W, d)
gap_cls = gap_cls.unsqueeze(1) # (B, d) -> (B, 1, d)
return gap_cls, x
else:
raise ValueError(f"Unsupported number of dimensions in input tensor: {len(x.shape)}")

def forward(self, x):
# Prepare input tensor and perform GAP as initialization
gap_cls, x = self.prepare_input(x)

# Prepare queries (q), keys (k), and values (v)
q, k, v = gap_cls, self.norm_patches(x), self.norm_patches(x)

# Extract dimensions after normalization
Bq, Nq, dq = q.shape
Bk, Nk, dk = k.shape
Bv, Nv, dv = v.shape

# Check dimension consistency across batches and channels
assert Bq == Bk == Bv
assert dq == dk == dv

# Apply linear transformation for queries and keys then reshape
qq = self.wq(q).reshape(Bq, Nq, self.num_heads, dq // self.num_heads).permute(0, 2, 1, 3) # (Bq, Nq, dq) -> (B, num_heads, Nq, dq/num_heads)
kk = self.wk(k).reshape(Bk, Nk, self.num_heads, dk // self.num_heads).permute(0, 2, 1, 3) # (Bk, Nk, dk) -> (B, num_heads, Nk, dk/num_heads)

vv = v.reshape(Bv, Nv, self.num_heads, dv // self.num_heads).permute(0, 2, 1, 3) # (Bv, Nv, dv) -> (B, num_heads, Nv, dv/num_heads)

# Compute attention scores
attn = (qq @ kk.transpose(-2, -1)) * self.scale
# Apply softmax for normalization
attn = attn.softmax(dim=-1)

# If gamma scaling is used
if self.gamma is not None:
# Apply gamma scaling on values and compute the weighted sum using attention scores
x = torch.pow(attn @ torch.pow((vv - vv.min() + self.eps), self.gamma), 1/self.gamma) # (B, num_heads, Nv, dv/num_heads) -> (B, 1, 1, d)
# If use_beta, add a learnable translation
if self.use_beta:
x = x + self.beta
else:
# Compute the weighted sum using attention scores
x = (attn @ vv).transpose(1, 2).reshape(Bq, Nq, dq)

return x.squeeze()
Loading
Oops, something went wrong.