<a href="https://colab.research.google.com/github/soumyadip1995/language-models/blob/main/Notebook/Spatialtransformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
import torch
from torch import nn
import math
import numpy as np 
import torch.nn.functional  as F
from math import sqrt


# Load the corpus and split it on whitespace: every token is a whole word.
with open(r"/content/text.txt", 'r', encoding='utf-8') as f:
  words = f.read().split()
# words[:20]


# Vocabulary: the sorted set of distinct words. (The name `chars` is kept for
# compatibility, but each entry is a word, not a character.)
chars = sorted(set(words))
string2integer = {ch: i for i, ch in enumerate(chars)}
# print(string2integer)

integer2string = {i: ch for ch, i in string2integer.items()}


def encode(s):
  """Map a sequence of words to their integer ids (KeyError on unknown words)."""
  return [string2integer[c] for c in s]


def decode(l):
  """Map a sequence of integer ids back to text, one space between words."""
  # Tokens are words, so they must be re-joined with a separator;
  # ''.join would glue every word together.
  return ' '.join(integer2string[i] for i in l)


data = torch.tensor(encode(words), dtype=torch.long)
# print(data)
# data.size()

# block_size and batch_size were reduced from 64 and 512 to 32 and 128.
block_size = 32
batch_size = 128
# torch.randint needs a positive upper bound; fail with a readable message
# instead of an obscure error when the corpus is too short.
if len(data) <= block_size:
  raise ValueError(
      f"corpus has {len(data)} tokens; need more than block_size={block_size}")
# Random start offsets of the `batch_size` training windows.
ix = torch.randint(len(data) - block_size, (batch_size,))

# Hidden dimensionality was reduced from 512 to 128.
vocab_size = len(chars)
d_k = 128
token_emb = nn.Embedding(vocab_size, d_k)

# Stack batch_size windows of block_size token ids, then embed each id into a
# d_k-dimensional vector.
x = torch.stack([data[i:i + block_size] for i in ix])
input_embeds = token_emb(x)
# input_embeds.size()


def scaled_dot_product(query, key, value, mask=None):
  """Compute scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

  Args:
    query: tensor of shape (..., seq_q, d_k).
    key:   tensor of shape (..., seq_k, d_k).
    value: tensor of shape (..., seq_k, d_v).
    mask:  optional tensor broadcastable to (..., seq_q, seq_k). Positions
      where it is 0/False are excluded from attention. A row that is fully
      masked yields NaN weights.

  Returns:
    Tensor of shape (..., seq_q, d_v).
  """
  dim_k = query.size(-1)
  # torch.matmul (unlike torch.bmm) accepts any number of leading batch
  # dimensions, so the same helper serves (batch, seq, d) and
  # (batch, heads, seq, d) inputs.
  scores = torch.matmul(query, key.transpose(-2, -1)) / sqrt(dim_k)
  if mask is not None:
    scores = scores.masked_fill(mask == 0, float('-inf'))
  weights = F.softmax(scores, dim=-1)
  return torch.matmul(weights, value)

# Plain self-attention: query, key and value are all the raw token embeddings
# (no learned projections yet).
key = input_embeds
query = input_embeds
value = input_embeds

# sdp = scaled_dot_product(query, key, value)
# print(sdp.size())

### Multi headed attention

# Having many heads allows the model to focus on different parts of the
# sentence at once. The softmax in a single head tends to focus on one aspect
# of similarity, for example subject-verb interaction.
## A single attention head

class AttentionHead(nn.Module):
  """One attention head: project the input to query/key/value, then attend.

  Args:
    embedded_dim: size of the last dimension of the input.
    head_dim: output size of each of the query, key and value projections.
  """

  def __init__(self, embedded_dim, head_dim):
    super().__init__()
    # Independent learned projections for queries, keys and values.
    self.q = nn.Linear(embedded_dim, head_dim)
    self.k = nn.Linear(embedded_dim, head_dim)
    self.v = nn.Linear(embedded_dim, head_dim)

  def forward(self, x):
    """Return scaled dot-product self-attention over the projections of x."""
    queries, keys, values = self.q(x), self.k(x), self.v(x)
    return scaled_dot_product(queries, keys, values)

# MultiHeadAttention arguments:
#   embedded_dim -- width of the token embeddings (input and output size)
#   num_heads    -- number of parallel attention heads


class MultiHeadAttention(nn.Module):
  """Run `num_heads` attention heads in parallel and mix their outputs.

  Each head projects to `embedded_dim // num_heads` dimensions. The head
  outputs are concatenated back to width `embedded_dim` and passed through a
  final linear layer.

  Raises:
    ValueError: if `num_heads` is not positive or does not divide
      `embedded_dim`.
  """

  def __init__(self, embedded_dim, num_heads):
    super().__init__()
    # Each head gets an equal slice of the embedding. If the split is not
    # exact, the concatenated head outputs would be narrower than
    # `output_linear` expects and forward() would fail with a shape error.
    if num_heads <= 0 or embedded_dim % num_heads != 0:
      raise ValueError(
          f"embedded_dim ({embedded_dim}) must be divisible by a positive "
          f"num_heads ({num_heads})")
    self.embedded_dim = embedded_dim
    self.num_heads = num_heads
    head_dim = embedded_dim // num_heads

    self.heads = nn.ModuleList(
        [AttentionHead(embedded_dim, head_dim) for _ in range(num_heads)])
    self.output_linear = nn.Linear(embedded_dim, embedded_dim)

  def forward(self, x):
    """Concatenate every head's output on the last dim, then project it."""
    out = torch.cat([h(x) for h in self.heads], dim=-1)
    return self.output_linear(out)

# multihead_attention = MultiHeadAttention(128, 8)
# # multihead_attention

# attention_outputs =  multihead_attention(input_embeds)
# # print(attention_outputs.size())


# Partially adapted from Karpathy.
dropout = 0.2


class FeedForward(nn.Module):
  """Position-wise feed-forward network: Linear -> GELU -> Linear -> Dropout.

  Args:
    embedded_dim: input and output width; the hidden layer is 4x wider.
    dropout_rate: dropout probability applied to the output. Defaults to the
      module-level `dropout` value, read at construction time.
  """

  def __init__(self, embedded_dim, dropout_rate=None):
    super().__init__()
    if dropout_rate is None:
      dropout_rate = dropout
    # The non-linearity must sit *between* the two Linear layers. With
    # Linear -> Linear -> GELU the two projections compose into a single
    # linear map and the 4x hidden layer adds no expressive power.
    self.net = nn.Sequential(
        nn.Linear(embedded_dim, 4 * embedded_dim),
        nn.GELU(),
        nn.Linear(4 * embedded_dim, embedded_dim),
        nn.Dropout(dropout_rate),
    )

  def forward(self, x):
    """Apply the feed-forward network to the last dimension of x."""
    return self.net(x)


### A simple Transformer Block
class Transformer(nn.Module):
  """Pre-LayerNorm transformer block: residual attention, then residual MLP.

  Args:
    embedded_dim: width of the token representations.
    num_heads: number of attention heads in the attention sub-layer.
  """

  def __init__(self, embedded_dim, num_heads):
    super().__init__()
    self.attention = MultiHeadAttention(embedded_dim, num_heads)
    self.feed_forward = FeedForward(embedded_dim)
    self.layer_norm_1 = nn.LayerNorm(embedded_dim)
    self.layer_norm_2 = nn.LayerNorm(embedded_dim)

  def forward(self, x):
    # Each sub-layer sees a normalised copy of its input; its output is added
    # back onto the un-normalised residual stream.
    attended = x + self.attention(self.layer_norm_1(x))
    return attended + self.feed_forward(self.layer_norm_2(attended))

# Smoke test: run one 8-head block over the token embeddings and show the shape.
btt = Transformer(embedded_dim=128, num_heads=8)
to = btt(input_embeds)
print(to.shape)



torch.Size([128, 32, 128])


The spatial attention below is adapted from the attention module of [ldm](https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/attention.py) (Stable Diffusion).

In [9]:
!pip install einops

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


### Spatial Attention Head

In [7]:
import torch
from torch import nn
import math
import numpy as np 
import torch.nn.functional  as F
from math import sqrt
from einops import rearrange, repeat


class SpatialAttentionHead(nn.Module):
  """Single-head self-attention over the spatial positions of a feature map.

  Every pixel attends to every pixel of the same image. Queries, keys, values
  and the output projection are 1x1 convolutions. The result is added back to
  the input (residual connection), so the output has the input's shape
  (batch, in_channels, height, width).

  Args:
    in_channels: number of channels of the input feature map.
    num_groups: number of GroupNorm groups; must divide `in_channels`.
      Defaults to 32 (the previously hard-coded value).
  """

  def __init__(self, in_channels, num_groups=32):
    super().__init__()
    self.in_channels = in_channels
    self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
    self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
    self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
    self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
    self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

  def forward(self, x):
    """Return x + proj_out(attention(norm(x))); the shape is preserved."""
    h_ = self.norm(x)
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)

    # Flatten the h x w grid into a sequence of h*w positions (row-major).
    b, c, h, w = q.shape
    q = q.reshape(b, c, h * w).permute(0, 2, 1)  # (b, hw, c)
    k = k.reshape(b, c, h * w)                    # (b, c, hw)
    v = v.reshape(b, c, h * w)                    # (b, c, hw)

    # w_[b, i, j] = weight of query position i on key position j.
    w_ = torch.bmm(q, k) * (int(c) ** (-0.5))
    w_ = torch.nn.functional.softmax(w_, dim=2)

    # Attend to values: out[b, :, i] = sum_j v[b, :, j] * w_[b, i, j].
    h_ = torch.bmm(v, w_.permute(0, 2, 1))        # (b, c, hw)
    h_ = h_.reshape(b, c, h, w)
    h_ = self.proj_out(h_)

    return x + h_


In [8]:
# Demo: a random 3-image, 32-channel, 64x64 feature map through one spatial
# head (each image yields a 4096 x 4096 attention matrix).
x = torch.randn(3, 32, 64, 64)
stn = SpatialAttentionHead(in_channels=32)
st = stn(x)
st[:5]  # the batch has only 3 images, so every one is shown

tensor([[[[ 2.4144e+00, -6.2536e-01, -4.5577e-01,  ...,  9.1269e-01,
           -5.6247e-01, -1.6154e+00],
          [ 3.5866e-01, -3.4775e-01,  5.0385e-01,  ..., -3.4040e-01,
           -2.5032e-01, -1.6667e+00],
          [-1.3572e+00,  1.3564e+00,  1.2945e+00,  ..., -8.7441e-03,
            1.8555e+00,  1.2009e+00],
          ...,
          [-1.7396e+00,  4.1048e-02, -1.3828e-01,  ..., -9.7439e-01,
           -8.5793e-01,  6.3062e-01],
          [-1.5686e+00,  1.2455e+00, -5.5006e-01,  ...,  4.4212e-01,
            8.0098e-01,  3.0116e-01],
          [-1.4802e-01,  2.1312e-01, -1.3096e+00,  ..., -9.9902e-01,
            5.3561e-02, -2.5475e+00]],

         [[-1.3619e+00, -8.7990e-02,  1.6321e+00,  ...,  4.3035e-01,
            8.3480e-01, -9.4382e-01],
          [-3.1477e-01, -6.3020e-01, -7.7731e-01,  ..., -8.1120e-01,
           -5.2567e-01, -2.4942e-02],
          [-1.2581e+00, -2.6787e+00,  8.3537e-01,  ..., -1.5512e+00,
            3.9807e-01,  1.4524e-01],
          ...,
     

#### Spatial Transformer

In [9]:

class SpatialTransformer(nn.Module):
  """Apply a stack of transformer blocks to the pixels of a feature map.

  The input (batch, in_channels, h, w) is group-normalised, projected to
  `embedded_dim` channels with a 1x1 conv, and flattened into a sequence of
  h*w tokens. The tokens pass through the transformer blocks and are reshaped
  back, projected to `in_channels`, and added to the input (residual).

  Args:
    in_channels: channels of the input feature map.
    embedded_dim: token width inside the transformer blocks.
    num_heads: attention heads per transformer block.
    depth: number of stacked transformer blocks. Defaults to `num_heads`,
      which reproduces the original behaviour, where the block count was
      tied to the head count.
    num_groups: number of GroupNorm groups; must divide `in_channels`.
      Defaults to 32 (the previously hard-coded value).
  """

  def __init__(self, in_channels, embedded_dim, num_heads, depth=None, num_groups=32):
    super().__init__()
    if depth is None:
      depth = num_heads
    self.in_channels = in_channels
    self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
    self.proj_in = nn.Conv2d(in_channels, embedded_dim, kernel_size=1, stride=1, padding=0)
    self.transformer_blocks = nn.ModuleList(
        [Transformer(embedded_dim, num_heads) for _ in range(depth)])
    self.proj_out = nn.Conv2d(embedded_dim, in_channels, kernel_size=1, stride=1, padding=0)

  def forward(self, x):
    """Return x plus the projected transformer output; shape is preserved."""
    # Self-attention only: this module takes no context/cross-attention input.
    _, _, h, w = x.shape
    x_in = x
    x = self.norm(x)
    x = self.proj_in(x)
    x = rearrange(x, 'b c h w -> b (h w) c')
    for block in self.transformer_blocks:
      x = block(x)
    x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
    x = self.proj_out(x)
    return x + x_in

In [10]:
# Demo: 32-channel 48x48 maps, projected to 16-dim tokens, 8 heads per block.
x = torch.randn(3, 32, 48, 48)
stn = SpatialTransformer(in_channels=32, embedded_dim=16, num_heads=8)
st = stn(x)
st[:5]  # the batch has only 3 images, so every one is shown

tensor([[[[-7.3886e-01,  6.3820e-01,  4.4061e-01,  ...,  1.2937e+00,
           -1.8474e+00,  2.9752e+00],
          [ 7.4824e-01, -4.1092e-01,  3.0088e-01,  ...,  1.1771e+00,
            5.1551e-01, -1.0566e-01],
          [-2.4397e-01,  7.3126e-01, -5.7814e-01,  ..., -1.1877e+00,
           -6.8078e-01, -5.3009e-01],
          ...,
          [-2.8724e-01, -1.2830e+00,  9.1008e-01,  ...,  6.8512e-01,
            1.3750e+00,  6.0598e-01],
          [ 4.5885e-01,  6.3332e-01,  3.6134e-02,  ..., -1.8209e+00,
           -1.2418e-01,  1.7355e+00],
          [ 1.9166e+00,  8.2430e-01,  2.6554e-01,  ..., -5.1318e-01,
           -6.7094e-01, -5.3458e-01]],

         [[ 2.0268e-01, -6.4168e-02, -4.4359e-01,  ...,  5.4447e-01,
            4.5099e-01,  5.6532e-01],
          [-4.3931e-01,  1.1415e+00,  9.5912e-01,  ..., -5.3576e-01,
            1.7886e+00,  1.6422e+00],
          [-7.3811e-02,  9.1319e-01,  1.4463e+00,  ...,  5.1679e-01,
            3.6386e-02,  4.2644e-01],
          ...,
     

In [3]:
# Shape check for a random (3, 32, 64, 64) tensor.
x = torch.randn(3, 32, 64, 64)
x.size()

torch.Size([3, 32, 64, 64])

In [6]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F

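# NOTE(review): this draft does not work as a spatial transformer network.
# `forward` passes the 12-channel *localization features* into `stn`, so
# `grid_sample` warps those features instead of the input image. That is why
# the demo below reports torch.Size([1, 12, 128, 128]) for a 3-channel input.
# If uncommented, it would also shadow the SpatialTransformer class above.
# class SpatialTransformer(nn.Module):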
#     def __init__(self, input_size, output_size):
#         super(SpatialTransformer, self).__init__()

#         self.input_size = input_size
#         self.output_size = output_size

#         # Localization network
#         self.localization = nn.Sequential(
#             nn.Conv2d(input_size[0], 8, kernel_size=7),
#             nn.MaxPool2d(2, stride=2),
#             nn.ReLU(True),
#             nn.Conv2d(8, 10, kernel_size=5),
#             nn.MaxPool2d(2, stride=2),
#             nn.ReLU(True),
#             nn.Conv2d(10, 12, kernel_size=3),
#             nn.MaxPool2d(2, stride=2),
#             nn.ReLU(True)
#         )

#         # Output size of the localization network
#         out_size = self._get_output_size(input_size)

#         # Affine transformation matrix theta
#         self.fc_loc = nn.Sequential(
#             nn.Linear(out_size, 32),
#             nn.ReLU(True),
#             nn.Linear(32, 3 * 2)
#         )

#         # Initialize theta to identity transformation
#         self.fc_loc[2].weight.data.zero_()
#         self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

#     def _get_output_size(self, input_size):
#         input_tensor = torch.zeros(1, *input_size)
#         output_tensor = self.localization(input_tensor)
#         out_size = output_tensor.data.size()[1:]
#         return int(torch.prod(torch.tensor(out_size)))

#     def stn(self, x):
#         # Get theta
#         theta = self.fc_loc(x.view(-1, self._get_output_size(self.input_size)))
#         theta = theta.view(-1, 2, 3)

#         # Generate grid
#         grid = F.affine_grid(theta, torch.Size((x.size(0), self.input_size[0], self.output_size[1], self.output_size[2])))

#         # Apply transformation
#         x = F.grid_sample(x, grid)

#         return x

#     def forward(self, x):
#         # Apply localization network
#         x = self.localization(x)

#         # Apply spatial transformation
#         x = self.stn(x)

#         return x

In [None]:
# input_size = (3, 64, 64)
# output_size = (3, 128, 128)

# st = SpatialTransformer(input_size, output_size)

# input_data = torch.randn(1, 3, 64, 64)
# output_data = st(input_data)
# output_data.size()



torch.Size([1, 12, 128, 128])