## Recurrent Attention Model

In [2]:
import torch
import torch.nn as nn
from torchxai.base import XaiBase, XaiHook, XaiModel
from collections import OrderedDict
import matplotlib.pyplot as plt

In [None]:
class GlimpseSensor(nn.Module):
    """Glimpse Sensor (work in progress: only the constructor is implemented)"""
    def __init__(self, g, k, s):
        """
        retina and location encoding
        ----
        The retina encoding `ρ(x, l)` extracts `k` square
        patches centered at location `l`, with the first patch
        being `g` × `g` pixels in size, and each successive
        patch being `s` times the width of the previous one
        (the original RAM formulation uses `s = 2`).
        The k patches are then all resized to `g × g`
        and concatenated.
        Glimpse locations `l` are encoded
        as real-valued (x, y) coordinates with
        (0, 0) being the center of the image x and
        (−1, −1) being the top left corner of x.

        args:
        - g: size of first square patches
        - k: number of patches
        - s: scaling factor between the widths of successive patches
        """
        super(GlimpseSensor, self).__init__()
        self.g = g
        self.k = k
        self.s = s

    def extract_patch(self, x, l):
        """
        Intended to crop the `k` square patches around `l` from `x`.

        args:
        - x: input data, presumably (B, C, H, W) — TODO confirm
        - l: glimpse location, presumably (B, 2) in [-1, 1] — TODO confirm

        raises:
        - NotImplementedError: always; the patch extraction is not written yet
        """
        raise NotImplementedError("GlimpseSensor.extract_patch is not implemented yet")

    def encode_coordinates(self, l):
        """
        encode coordinates to (-1, 1) range
        - center: (0, 0)
        - topleft: (-1, -1)

        raises:
        - NotImplementedError: always; the encoding is not written yet
        """
        raise NotImplementedError("GlimpseSensor.encode_coordinates is not implemented yet")

    def forward(self, x, l):
        """
        args:
        - x: current time step input data, (B, C, H, W)
        - l: previous time step location, (B, 2)

        returns:
        - rho: retina-like representation

        raises:
        - NotImplementedError: always; the sensor is not written yet, so fail
          loudly instead of silently returning `None`
        """
        raise NotImplementedError("GlimpseSensor.forward is not implemented yet")

In [8]:
# Scratch: map normalized glimpse coordinates in [-1, 1] to pixel indices
# via 0.5 * (coord + 1) * H, truncated to integers by `.long()`.
# NOTE(review): the sample values (19, 8, 20) are outside [-1, 1], so the
# resulting indices (e.g. 0.5 * 20 * 28 = 280) lie far outside a 28-pixel image.
coord = torch.FloatTensor([[19, 8],[20, 20]])
H = 28
coords_normed = (0.5 * ((coord + 1.0)*H)).long()

In [9]:
# Shift the first coordinate column back by half of a 5-pixel patch width
# (5 // 2 = 2), i.e. the starting index of a 5-wide patch centred there.
x = coords_normed[:, 0] - (5 // 2)

In [10]:
x  # display the shifted indices computed in the previous cell

tensor([278, 292])

In [11]:
# Scratch arithmetic, intent not recorded — presumably exploring a
# pixel-centre coordinate mapping for a 28-pixel side (unverified).
(28 - 0.5) / 0.5

55.0

In [17]:
# Pixel indices 0..27 scaled by 1/28; `dtype=float` yields torch.float64.
torch.arange(28, dtype=float) / 28.0

tensor([0.0000, 0.0357, 0.0714, 0.1071, 0.1429, 0.1786, 0.2143, 0.2500, 0.2857,
        0.3214, 0.3571, 0.3929, 0.4286, 0.4643, 0.5000, 0.5357, 0.5714, 0.6071,
        0.6429, 0.6786, 0.7143, 0.7500, 0.7857, 0.8214, 0.8571, 0.8929, 0.9286,
        0.9643], dtype=torch.float64)

## AnR (Attend and Rectify)

In [3]:
class AttentionHead(nn.Module):
    """AttentionHead (AnR section 3.2 Attention Head)

    Produces K spatial attention masks from a feature map with a single 3x3
    convolution followed by a softmax across the K heads.
    """
    def __init__(self, in_c, n_head):
        """
        args:
        - in_c: number of input channels, C
        - n_head: number of attention heads, K
        """
        super(AttentionHead, self).__init__()
        self.n_head = n_head
        self.conv = nn.Conv2d(in_c, n_head, kernel_size=3, padding=1, bias=False)
        # (K, K) matrix of ones with a zero diagonal: keeps only the
        # cross-head overlap terms in `reg_loss`
        self.diag = (1 - torch.eye(n_head, n_head))

    def forward(self, x):
        """
        args:
        - x: feature activations, (B, C, H, W)

        returns:
        - Tensor, K-attention masks (softmax over the head dimension), (B, K, H, W)
        """
        B = x.size(0)
        conv_heads = self.conv(x)  # (B, C, H, W) > (B, K, H, W)
        masks = torch.softmax(conv_heads, dim=1)  # K-attention masks
        # cached flattened masks, (B, K, H*W); consumed by `reg_loss`
        self.masks_score = masks.view(B, self.n_head, -1)
        return masks  # (B, K, H, W)

    def reg_loss(self):
        """
        calculate reg_loss from the masks of the most recent `forward` call

        returns:
        - Tensor (B, K, K): pairwise dot products between the flattened masks
          of different heads; the diagonal (self-overlap) is zeroed.
          No reduction is applied; callers reduce it to a scalar themselves.

        raises:
        - RuntimeError: if called before `forward`
        """
        masks_score = getattr(self, "masks_score", None)
        if masks_score is None:
            raise RuntimeError("AttentionHead.reg_loss requires a prior call to forward")
        # (B, K, H*W) x (B, H*W, K) > (B, K, K)
        overlap = torch.bmm(masks_score, masks_score.transpose(1, 2))
        # match both the device and dtype of the cached masks
        return self.diag.to(masks_score) * overlap

class AttentionOut(nn.Module):
    """AttentionOut (AnR 3.3 Output head / 3.4 Layered attention gates)

    For every attention head, a convolutional map is summed spatially under
    that head's attention mask. With `gate=False` this yields L scores per
    head; with `gate=True` (L = 1) it yields one gate weight per head.
    """
    def __init__(self, in_c, n_head, n_label=1, gate=False):
        """
        args:
        - in_c: C
        - n_head: K
        - n_label: L
        - gate: if gate is `True`, `forward` returns attention gates
          (requires `n_label = 1`)
        """
        super(AttentionOut, self).__init__()
        self.n_head = n_head
        self.n_label = n_label
        self.gate = gate
        if gate:
            assert self.n_label == 1, "Gate must set `n_label = 1`"
        # K*L output channels: one L-channel map per head
        self.conv = nn.Conv2d(in_c, n_head*n_label, kernel_size=3, padding=1, bias=False)

    def forward(self, x, masks):
        """
        args:
        - x: feature activations, (B, C, H, W)
        - masks: masks from `AttentionHead`, (B, K, H, W)

        returns:
        - scores: when `self.gate=False`, (B, K, L)
        - gates: when `self.gate=True`, (B, K, 1), softmax-normalized over K
        """
        B = x.size(0)
        conv_outputs = self.conv(x)  # (B, C, H, W) > (B, K*L, H, W)
        outputs = conv_outputs.view(B, self.n_head, self.n_label, -1)  # (B, K, L, H*W)
        # reshape rather than view so non-contiguous masks are accepted
        flat_masks = masks.reshape(B, self.n_head, 1, -1)  # (B, K, 1, H*W)
        # (B, K, L, H*W) * (B, K, 1, H*W) > (B, K, L)
        scores = (outputs * flat_masks).sum(-1)
        if not self.gate:
            return scores
        # L = 1: squash with tanh, then normalize across the K heads
        gates = torch.softmax(torch.tanh(scores), dim=1)
        return gates  # (B, K, 1)

In [4]:
class AttentionModule(nn.Module):
    """AttentionModule

    Combines an `AttentionHead` (K masks), per-head output heads and per-head
    attention gates into a single gated prediction vector.
    """
    def __init__(self, in_c, n_head, n_label, reg_weight=0.0):
        """
        calculate outputs of attention module

        args:
        - in_c: the number of input channels, C
        - n_head: the attention width, i.e. the number of attention heads K
          inside this module
        - n_label: the number of class labels, L
        - reg_weight: multiplier applied to the attention head's
          regularization term returned by `reg_loss`
        """
        super(AttentionModule, self).__init__()
        self.reg_weight = reg_weight
        self.attn_heads = AttentionHead(in_c, n_head)
        self.output_heads = AttentionOut(in_c, n_head, n_label=n_label, gate=False)
        self.attn_gates = AttentionOut(in_c, n_head, gate=True)

    def forward(self, x):
        """
        args:
        - x: feature activations, (B, C, H, W)

        returns:
        - outputs_vectors: per-head scores weighted by the per-head gates
          and summed over heads, (B, L)
        """
        masks = self.attn_heads(x)  # (B, K, H, W)  softmax: dim=1
        outputs = self.output_heads(x, masks)  # (B, K, L)
        gates = self.attn_gates(x, masks)  # (B, K, 1)  softmax: dim=1
        outputs_vectors = (outputs * gates).sum(1)  # (B, L)
        return outputs_vectors

    def reg_loss(self):
        """
        returns:
        - the attention head's `reg_loss()` scaled by `reg_weight`.
          The head computes it from masks cached during `forward`, so call
          `forward` first.
        """
        return self.attn_heads.reg_loss() * self.reg_weight


class GlobalAttentionGate(nn.Module):
    """GlobalAttentionGate

    Mixes several class-score hypotheses (e.g. the network's own output plus
    the attention-module outputs) using input-dependent gates computed from
    the last feature activations.
    """
    def __init__(self, in_c, n_hypothesis, gate_fn="softmax"):
        """
        args:
        - in_c: input size of the gate layer. `x` is flattened to (B, -1)
          before the linear layer, so this must equal the number of features
          in one flattened sample (C*H*W for a (B, C, H, W) input)
        - n_hypothesis: the total number of hypotheses to combine, i.e. the
          length of the `hypothesis` list given to `forward` (N+1 below)
        - gate_fn: "softmax" (gates sum to 1 over hypotheses) or
          "sigmoid" (independent gates)

        raises:
        - ValueError: if `gate_fn` is not "softmax" or "sigmoid"
        """
        super(GlobalAttentionGate, self).__init__()
        if gate_fn not in ("softmax", "sigmoid"):
            raise ValueError(f"gate_fn must be 'softmax' or 'sigmoid', got {gate_fn!r}")
        self.n_hypothesis = n_hypothesis
        self.gate_fn = gate_fn
        # one gate logit per hypothesis
        self.gate_layer = nn.Linear(in_c, n_hypothesis, bias=False)

    def cal_global_gates(self, x):
        """
        returns:
        - global gates, (B, N+1, 1)
        """
        x = torch.flatten(x, 1)  # (B, C, H, W) > (B, C*H*W)
        c = torch.tanh(self.gate_layer(x))  # (B, N+1)
        if self.gate_fn == "softmax":
            global_gates = torch.softmax(c, dim=1)
        elif self.gate_fn == "sigmoid":
            global_gates = torch.sigmoid(c)
        else:
            raise ValueError(f"unsupported gate_fn {self.gate_fn!r}")
        return global_gates.unsqueeze(-1)  # (B, N+1, 1)

    def forward(self, x, hypothesis):
        """
        args:
        - x: last feature activations, (B, C, H, W)
        - hypothesis: all hypotheses, list of N+1 Tensors of size (B, 1, L)

        returns:
        - outputs_net: gate-weighted sum of the per-hypothesis log-probabilities, (B, L)

        raises:
        - ValueError: if `len(hypothesis)` differs from `n_hypothesis`
        """
        if len(hypothesis) != self.n_hypothesis:
            # a single hypothesis would otherwise silently broadcast against the gates
            raise ValueError(
                f"expected {self.n_hypothesis} hypotheses, got {len(hypothesis)}"
            )
        global_gates = self.cal_global_gates(x)  # (B, N+1, 1)
        outputs = torch.cat(hypothesis, dim=1)  # (B, N+1, L)
        outputs = torch.log_softmax(outputs, dim=2)  # calculate log probs
        outputs_net = (outputs * global_gates).sum(1)  # (B, L)
        return outputs_net

In [13]:
x = torch.randn(1, 3, 28, 28)
in_c = 3
n_head = 5
n_label = 10
attn_heads = AttentionHead(in_c, n_head)
output_heads = AttentionOut(in_c, n_head, n_label)
attn_gates = AttentionOut(in_c, n_head, gate=True)
attn_m = AttentionModule(in_c, n_head, n_label)
n_gates = 4
# GlobalAttentionGate flattens its input before the linear gate layer, so its
# input size is the flattened per-sample size (3*28*28), not the channel count.
attn_gg = GlobalAttentionGate(x[0].numel(), n_gates)

In [14]:
# Run the pieces separately: K masks, then per-head scores and per-head gates.
H = attn_heads(x)
O, G = output_heads(x, H), attn_gates(x, H)

print(*(t.size() for t in (H, O, G)))

torch.Size([1, 5, 28, 28]) torch.Size([1, 5, 10]) torch.Size([1, 5, 1])


In [15]:
OV = attn_m(x)  # AttentionModule output, (B, L) per its docstring
# NOTE(review): GlobalAttentionGate.forward takes (x, hypothesis), so this
# argument-less call raises a TypeError. The recorded traceback mentions only
# 'x', so it appears to come from an earlier forward signature.
attn_gg()

TypeError: forward() missing 1 required positional argument: 'x'

In [19]:
# 2x2 average pooling (kernel 2, stride 2) halves the spatial size of the
# attention masks H: (1, 5, 28, 28) -> (1, 5, 14, 14).
nn.functional.avg_pool2d(H, 2, 2).size()

torch.Size([1, 5, 14, 14])