In [None]:
# default_exp models

# Image sequence models
> Models to predict the action class from a sequence of frames

We will build a set of models that read an `ImageTuple` and output the corresponding `Category`.

In [None]:
#export
from fastai.vision.all import *

In [None]:
# Use the second GPU only when the machine actually has one; otherwise keep
# PyTorch's default device so the notebook still runs on single-GPU / CPU hosts.
if torch.cuda.device_count() > 1: torch.cuda.set_device(1)
torch.cuda.get_device_name() if torch.cuda.is_available() else 'cpu'

'GeForce RTX 2070 SUPER'

## A resnet based Encoder
> Extracting features of images to latent variable space

Let's build a tensor representing a batch of images:
- `(batch_size, channels, height, width)`

In [None]:
x = torch.rand(8, 3, 64, 64)

We will build a basic Resnet based encoder:

In [None]:
#export
@delegates(create_cnn_model)
class Encoder(Module):
    "Encoder built from `create_cnn_model`; with `head=False` it outputs the pooled, flattened and batch-normed features"
    def __init__(self, arch=resnet34, n_in=3, weights_file=None, head=True, pretrained=True, **kwargs):
        """Build the `arch` body and head.

        - `n_in`: number of input channels
        - `weights_file`: optional file passed to `load_model` to overwrite the model weights
        - `head`: keep the full head (single output) when True, otherwise only its first three layers
        - `pretrained`: forwarded to `create_cnn_model`; exposed here so that it can be
          overridden without clashing with a hard-coded keyword
        - `kwargs`: forwarded to `create_cnn_model`
        """
        model = create_cnn_model(arch, n_out=1, n_in=n_in, pretrained=pretrained, **kwargs)
        if weights_file is not None: load_model(weights_file, model, opt=None)
        self.body = model[0]
        # head=False keeps only the first three head layers
        # (pooling, flatten and batchnorm with the default head)
        self.head = model[1] if head else nn.Sequential(*(model[1][0:3]))

    def forward(self, x):
        return self.head(self.body(x))

This encoder reduces each image to a vector in a latent space:

In [None]:
enc = Encoder(n_in=3, weights_file=None, head=False)

In [None]:
enc.head

Sequential(
  (0): AdaptiveConcatPool2d(
    (ap): AdaptiveAvgPool2d(output_size=1)
    (mp): AdaptiveMaxPool2d(output_size=1)
  )
  (1): Flatten(full=False)
  (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In this case, the latent dimension is 1024:

In [None]:
encoded_var = enc(x)
encoded_var.shape

torch.Size([8, 1024])

In [None]:
test_eq(encoded_var.shape, [8,1024])

## Simple Model
> A very basic CNN model

This network just uses a plain old resnet and folds the sequence dimension into the batch dimension. It is not optimal.

In [None]:
#export
class SimpleModel(Module):
    "Encode every frame with a CNN, pool the sequence with learned attention weights, then classify"
    def __init__(self, arch=resnet34, weights_file=None, num_classes=30, seq_len=40, debug=False):
        "Build the `arch` encoder, the attention scorer and the classification head (`seq_len` is currently unused)"
        self.encoder = Encoder(arch, 3, weights_file, head=False)
        # the encoder output is twice the body's feature count
        nf = 2 * num_features_model(nn.Sequential(*self.encoder.body.children()))
        hidden = nf // 2
        self.head = nn.Sequential(LinBnDrop(nf, hidden, p=0.2, act=nn.ReLU()),
                                  LinBnDrop(hidden, num_classes, p=0.05))
        self.attention_layer = nn.Linear(nf, 1)
        self.debug = debug

    def forward(self, x):
        "`x` is a sequence of per-frame batches; returns one row of class scores per sample"
        if self.debug:  print(f' input len:   {len(x), x[0].shape}')
        frames = torch.stack(x, dim=1)
        if self.debug:  print(f' after stack:   {frames.shape}')
        bs, n_frames, ch, height, width = frames.shape
        # fold the sequence into the batch so the encoder sees plain images, then unfold
        feats = self.encoder(frames.view(bs * n_frames, ch, height, width)).view(bs, n_frames, -1)
        if self.debug:  print(f' encoded shape: {feats.shape}')
        # one score per frame, normalised over the sequence
        scores = self.attention_layer(feats).squeeze(-1)
        weights = F.softmax(scores, dim=-1)
        pooled = (weights.unsqueeze(-1) * feats).sum(dim=1)
        if self.debug:  print(f' after attention shape: {pooled.shape}')
        return self.head(pooled)

A splitter function to train the parameters of the encoder and of the head separately; the `Learner` needs this argument to be able to call `Learner.freeze()`.

In [None]:
#export
def simple_splitter(model):
    "Two parameter groups: the encoder, then the attention layer followed by the head"
    encoder_group = params(model.encoder)
    head_group = params(model.attention_layer) + params(model.head)
    return [encoder_group, head_group]

A sequence of 10 images:

In [None]:
# a list of seq_len=10 frames, each one a batch of shape (bs, ch, h, w)
inp = [torch.rand(64, 3, 64, 64) for _ in range(10)]

In [None]:
sm = SimpleModel(debug=True, seq_len=10)
out = sm(inp)
test_eq(out.shape, [64, 30])

 input len:   (10, torch.Size([64, 3, 64, 64]))
 after stack:   torch.Size([64, 10, 3, 64, 64])
 encoded shape: torch.Size([64, 10, 1024])
 after attention shape: torch.Size([64, 1024])


## ConvLSTM
> An LSTM encoded image model

First the LSTM wrapper, with the `reset` method to erase hidden state before each epoch.

In [None]:
#export
class LSTM(Module):
    "`nn.LSTM` (batch first) followed by dropout, remembering the last hidden state between calls"
    def __init__(self, input_dim, n_hidden, n_layers, bidirectional=False, p=0.5):
        self.lstm = nn.LSTM(input_dim, n_hidden, n_layers, batch_first=True, bidirectional=bidirectional)
        self.drop = nn.Dropout(p)
        self.h = None

    def reset(self):
        "Forget the stored hidden state"
        self.h = None

    def forward(self, x):
        "Return the dropped-out sequence output and the new (non-detached) hidden state"
        # the stored state has the batch on axis 1: drop it when the batch size
        # changes (e.g. the last, smaller batch of the validation set)
        batch_changed = self.h is not None and self.h[0].shape[1] != x.shape[0]
        if batch_changed: self.h = None
        raw, h = self.lstm(x, self.h)
        # keep a detached copy so the next call does not backpropagate into this one
        self.h = [state.detach() for state in h]
        return self.drop(raw), h

We will take as `input_size` the output of the encoder, i.e. the `latent_dimension`; `num_layers` is how many `nn.LSTMCell` are stacked, and the hidden dim is the same as before.

Let's build a single-layer LSTM:

In [None]:
lstm = LSTM(512, 512, 1, bidirectional=False)

In [None]:
# bs, seq_len, input_dim
y = torch.rand(32, 10,  512)

We get the input sequence back, encoded to `hidden_dim` features per step:

In [None]:
out, (h,c) = lstm(y)
out.shape, h.shape, c.shape

(torch.Size([32, 10, 512]), torch.Size([1, 32, 512]), torch.Size([1, 32, 512]))

It can deal with different batch sizes now:

In [None]:
out, (h,c) = lstm(torch.rand(16,10,512))
out.shape, h.shape, c.shape

(torch.Size([16, 10, 512]), torch.Size([1, 16, 512]), torch.Size([1, 16, 512]))

In [None]:
lstm = LSTM(512, 512, 3, bidirectional=True)

In [None]:
out, (h,c) = lstm(torch.rand(16,10,512))
out.shape,  h.shape, c.shape

(torch.Size([16, 10, 1024]),
 torch.Size([6, 16, 512]),
 torch.Size([6, 16, 512]))

In [None]:
#export
class ConvLSTM(Module):
    "CNN encoder applied to every frame, an `LSTM` over the sequence, then attention pooling or the last hidden state"
    def __init__(self, arch=resnet34, weights_file=None, num_classes=30, lstm_layers=1, hidden_dim=1024,
                 bidirectional=True, attention=True, debug=False):
        "`attention=True` pools the LSTM outputs with learned weights, otherwise the final hidden states are flattened"
        self.encoder = Encoder(arch, 3, weights_file, head=False)
        # the encoder output is twice the body's feature count
        nf = 2 * num_features_model(nn.Sequential(*self.encoder.body.children()))
        self.lstm = LSTM(nf, hidden_dim, lstm_layers, bidirectional)
        self.attention = attention
        # width of one LSTM output step (both directions when bidirectional)
        lstm_width = 2 * hidden_dim if bidirectional else hidden_dim
        self.attention_layer = nn.Linear(lstm_width, 1)
        # without attention the head reads the hidden state of every layer
        head_in = lstm_width if attention else lstm_layers * lstm_width
        self.head = nn.Sequential(
            LinBnDrop(head_in, hidden_dim, p=0.2, act=nn.ReLU()),
            nn.Linear(hidden_dim, num_classes),
        )
        self.debug = debug

    def forward(self, x):
        "`x` is a sequence of per-frame batches; returns one row of class scores per sample"
        frames = torch.stack(x, dim=1)
        if self.debug:  print(f' after stack:   {frames.shape}')
        bs, n_frames, ch, height, width = frames.shape
        # fold the sequence into the batch so the encoder sees plain images
        feats = self.encoder(frames.view(bs * n_frames, ch, height, width))
        if self.debug:  print(f' after encode:   {feats.shape}')
        feats = feats.view(bs, n_frames, -1)
        if self.debug:  print(f' before lstm:   {feats.shape}')
        seq_out, (h_n, _) = self.lstm(feats)
        if self.debug:  print(f' after lstm:   {seq_out.shape}')
        if not self.attention:
            if self.debug: print(f' hidden state: {h_n.shape}')
            # put the batch first and concatenate the states of all layers/directions
            out = h_n.permute(1,0,2).flatten(1)
            if self.debug: print(f' hidden state flat: {out.shape}')
            return self.head(out)
        # one weight per time step, normalised over the sequence
        weights = F.softmax(self.attention_layer(seq_out).squeeze(-1), dim=-1)
        if self.debug: print(f' attention_w: {weights.shape}')
        out = torch.sum(weights.unsqueeze(-1) * seq_out, dim=1)
        if self.debug: print(f' after attention: {out.shape}')
        return self.head(out)

    def reset(self):
        "Clear the hidden state stored by the inner `LSTM`"
        self.lstm.reset()

In [None]:
#export
def convlstm_splitter(model):
    "Two parameter groups: the encoder, then the lstm, attention layer and head together"
    encoder_group = params(model.encoder)
    rest_group = params(model.lstm) + params(model.attention_layer) + params(model.head)
    return [encoder_group, rest_group]

In [None]:
# a list of seq_len=10 frames, each one a batch of shape (bs, ch, h, w)
inp = [torch.rand(32, 3, 64, 64) for _ in range(10)]

In [None]:
clstm = ConvLSTM(attention=False, bidirectional=False, lstm_layers=2, debug=True)
test_eq(clstm(inp).shape, [32, 30])

 after stack:   torch.Size([32, 10, 3, 64, 64])
 after encode:   torch.Size([320, 1024])
 before lstm:   torch.Size([32, 10, 1024])
 after lstm:   torch.Size([32, 10, 1024])
 hidden state: torch.Size([2, 32, 1024])
 hidden state flat: torch.Size([32, 2048])


In [None]:
clstm = ConvLSTM(lstm_layers=3, debug=True)
test_eq(clstm(inp).shape, [32, 30])

 after stack:   torch.Size([32, 10, 3, 64, 64])
 after encode:   torch.Size([320, 1024])
 before lstm:   torch.Size([32, 10, 1024])
 after lstm:   torch.Size([32, 10, 2048])
 attention_w: torch.Size([32, 10])
 after attention: torch.Size([32, 2048])


In [None]:
clstm = ConvLSTM(lstm_layers=1, debug=True)
test_eq(clstm(inp).shape, [32, 30])

 after stack:   torch.Size([32, 10, 3, 64, 64])
 after encode:   torch.Size([320, 1024])
 before lstm:   torch.Size([32, 10, 1024])
 after lstm:   torch.Size([32, 10, 2048])
 attention_w: torch.Size([32, 10])
 after attention: torch.Size([32, 2048])


## TimeSformer
> thanks LucidRains https://github.com/lucidrains/TimeSformer-pytorch

In [None]:
#export
from torch import nn, einsum
from einops import rearrange, repeat

In [None]:
#export
class PreNorm(nn.Module):
    "Apply `nn.LayerNorm(dim)` to the input, then call `fn` on the normalized tensor"
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, *args, **kwargs):
        "Extra positional and keyword arguments are passed through to `fn` unchanged"
        normed = self.norm(x)
        return self.fn(normed, *args, **kwargs)

In [None]:
#export
class FeedForward(nn.Module):
    "MLP on the last dimension: `dim -> dim*mult -> dim`, with GELU and dropout after the first linear layer"
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        hidden = dim * mult
        stages = [
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)

In [None]:
#export
def attn(q, k, v):
    "Softmax dot-product attention over batched `(b, n, d)` tensors; `q` is used as given (no scaling is applied here)"
    scores = einsum('b i d, b j d -> b i j', q, k)
    weights = scores.softmax(dim = -1)
    return einsum('b i j, b j d -> b i d', weights, v)

In [None]:
#export
class Attention(nn.Module):
    "Multi-head attention whose non-classification tokens are regrouped by an einops pattern before attending"
    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        dropout = 0.
    ):
        "`dim` is the token width; `heads` heads of width `dim_head` are used; `dropout` is applied to the output"
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        # a single projection produces queries, keys and values at once
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, einops_from, einops_to, **einops_dims):
        """Attend over `x`, whose first token (index 0) is the classification token.

        `einops_from` and `einops_to` are einops patterns: the remaining tokens are rearranged
        from the first to the second before attention and back afterwards; `einops_dims`
        supplies the axis sizes those patterns need.
        """
        h = self.heads
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        # fold the heads into the batch axis
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))

        q *= self.scale

        # splice out classification token at index 0 (the first token)
        (cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, 0:1], t[:, 1:]), (q, k, v))

        # let classification token attend to key / values of all patches across time and space
        cls_out = attn(cls_q, k, v)

        # rearrange across time or space
        q_, k_, v_ = map(lambda t: rearrange(t, f'{einops_from} -> {einops_to}', **einops_dims), (q_, k_, v_))

        # expand cls token keys and values across time or space and concat
        # r is how many regrouped batch entries correspond to one original entry
        r = q_.shape[0] // cls_k.shape[0]
        cls_k, cls_v = map(lambda t: repeat(t, 'b () d -> (b r) () d', r = r), (cls_k, cls_v))

        k_ = torch.cat((cls_k, k_), dim = 1)
        v_ = torch.cat((cls_v, v_), dim = 1)

        # attention
        out = attn(q_, k_, v_)

        # merge back time or space
        out = rearrange(out, f'{einops_to} -> {einops_from}', **einops_dims)

        # concat back the cls token
        out = torch.cat((cls_out, out), dim = 1)

        # merge back the heads
        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)

        # combine heads out
        return self.to_out(out)

In [None]:
#export
class TimeSformer(nn.Module):
    "Video transformer on patch tokens: every layer attends over time, then over space, then applies a feed-forward"
    def __init__(
        self,
        *,
        dim,
        num_frames,
        num_classes,
        image_size = 224,
        patch_size = 16,
        channels = 3,
        depth = 12,
        heads = 8,
        dim_head = 64,
        attn_dropout = 0.,
        ff_dropout = 0.
    ):
        "`dim` is the token width, `depth` the number of (time attention, space attention, feed-forward) layers"
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'

        patches_per_frame = (image_size // patch_size) ** 2
        # one position per patch of every frame, plus one for the classification token
        n_positions = num_frames * patches_per_frame + 1

        self.patch_size = patch_size
        self.to_patch_embedding = nn.Linear(channels * patch_size ** 2, dim)
        self.pos_emb = nn.Embedding(n_positions, dim)
        self.cls_token = nn.Parameter(torch.randn(1, dim))

        def make_attention():
            return PreNorm(dim, Attention(dim, dim_head = dim_head, heads = heads, dropout = attn_dropout))

        self.layers = nn.ModuleList([
            nn.ModuleList([
                make_attention(),   # time attention
                make_attention(),   # spatial attention
                PreNorm(dim, FeedForward(dim, dropout = ff_dropout)),
            ])
            for _ in range(depth)
        ])

        self.to_out = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))

    def forward(self, video):
        "`video` is a sequence of per-frame batches (e.g. an `ImageTuple`); returns one row of `num_classes` scores per sample"
        video = torch.stack(video, dim=1)  #to deal with the ImageTuple
        b, f, _, h, w, *_ = video.shape
        p = self.patch_size
        assert h % p == 0 and w % p == 0, f'height {h} and width {w} of video must be divisible by the patch size {p}'

        n = (h // p) * (w // p)   # patches per frame

        # cut every frame into p x p patches and flatten each patch into one token
        patches = rearrange(video, 'b f c (h p1) (w p2) -> b (f h w) (p1 p2 c)', p1 = p, p2 = p)
        tokens = self.to_patch_embedding(patches)

        # prepend the classification token, then add the position embeddings
        cls = repeat(self.cls_token, 'n d -> b n d', b = b)
        x = torch.cat((cls, tokens), dim = 1)
        positions = torch.arange(x.shape[1], device = video.device)
        x = x + self.pos_emb(positions)

        for time_attn, spatial_attn, ff in self.layers:
            x = time_attn(x, 'b (f n) d', '(b n) f d', n = n) + x
            x = spatial_attn(x, 'b (f n) d', '(b f) n d', f = f) + x
            x = ff(x) + x

        # classify from the classification token only
        return self.to_out(x[:, 0])

In [None]:
model = TimeSformer(
    dim = 128,
    image_size = 64,
    patch_size = 16,
    num_frames = 8,
    num_classes = 10,
    depth = 12,
    heads = 8,
    dim_head =  64,
    attn_dropout = 0.1,
    ff_dropout = 0.1
)

In [None]:
video = tuple(torch.randn(2, 3, 64, 64) for _ in range(8)) # 8 frames, each (batch x channels x height x width)
test_eq(model(video).shape, (2,10))

# Export -

In [None]:
###### hide
from nbdev.export import *
notebook2script()

Converted 00_core.ipynb.
Converted 01_models.ipynb.
Converted index.ipynb.
