In [1]:
# default_exp models

# Image sequence models
> Models to predict the action class from a sequence of frames

We will build several models that read an `ImageTuple` and output the corresponding `Category`.

In [2]:
#export
from fastai.vision.all import *

In [3]:
# The original run used the second GPU (index 1); fall back gracefully on machines that
# have a single GPU or none, instead of crashing on an invalid device ordinal.
GPU_ID = 1
if torch.cuda.is_available() and GPU_ID < torch.cuda.device_count():
    torch.cuda.set_device(GPU_ID)
torch.cuda.get_device_name() if torch.cuda.is_available() else 'cpu'

'GeForce RTX 2070 SUPER'

## A resnet based Encoder
> Extracting features of images to latent variable space

Let's build a tensor representing a batch of images:
- `(batch_size, channels, height, width)`

In [4]:
x = torch.rand((8, 3, 64, 64))  # (batch_size, channels, height, width)

We will build a basic Resnet based encoder:

In [5]:
#export
@delegates(create_cnn_model)
class Encoder(Module):
    "Resnet-based image encoder built with fastai's `create_cnn_model`"
    def __init__(self, arch=resnet34, n_in=3, weights_file=None, head=True, pretrained=True, **kwargs):
        """Build the encoder.

        Args:
            arch: backbone architecture passed to `create_cnn_model`.
            n_in: number of input channels.
            weights_file: optional path of saved weights. Only the model weights are
                loaded (`opt=None`, `with_opt=False`), so no spurious optimizer-state warning.
            head: if True keep the whole fastai head (a single output, `n_out=1`);
                if False keep only its first three layers (concat pooling, flatten,
                batchnorm), so `forward` returns one flat feature vector per image.
            pretrained: forwarded to `create_cnn_model`. It used to be hard-coded, so
                passing `pretrained=...` through `**kwargs` raised a duplicate-keyword TypeError.
            kwargs: extra arguments forwarded to `create_cnn_model`.
        """
        model = create_cnn_model(arch, n_out=1, n_in=n_in, pretrained=pretrained, **kwargs)
        if weights_file is not None: load_model(weights_file, model, opt=None, with_opt=False)
        self.body = model[0]
        # model[1] is the fastai head; its first three layers turn the feature map into a vector
        if head: self.head = model[1]
        else:    self.head = nn.Sequential(*(model[1][0:3]))

    def forward(self, x):
        "Run the backbone, then the (possibly truncated) head"
        return self.head(self.body(x))

This encoder reduces each image to a vector in a latent space:

In [6]:
enc = Encoder(arch=resnet34, n_in=3, head=False)  # no weights_file: default None

In [7]:
# with head=False only the pooling, flatten and batchnorm layers of the fastai head are kept
enc.head

Sequential(
  (0): AdaptiveConcatPool2d(
    (ap): AdaptiveAvgPool2d(output_size=1)
    (mp): AdaptiveMaxPool2d(output_size=1)
  )
  (1): Flatten(full=False)
  (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

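In this case the latent dimension is 1024: `AdaptiveConcatPool2d` concatenates the average-pooled and max-pooled features of the body, doubling its channel count.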

In [8]:
# one feature vector per image: (batch_size, latent_dim)
encoded_var = enc(x)
encoded_var.shape

torch.Size([8, 1024])

In [9]:
test_eq(encoded_var.shape, [8,1024])  # 8 images -> 8 feature vectors of size 1024

## Simple Model
> A very basic CNN model

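This network simply uses a resnet encoder on every frame by folding the sequence dimension into the batch dimension, then pools the per-frame features with a learned attention layer before classifying. It is not optimal.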

In [10]:
#export
class SimpleModel(Module):
    "A simple CNN model"
    def __init__(self, arch=resnet34, weights_file=None, num_classes=30, seq_len=40, debug=False):
        "Create a simple arch based model: per-frame encoder, attention pooling over frames, MLP head"
        encoder = Encoder(arch, 3, weights_file, head=False)
        # concat pooling (avg + max) doubles the channel count of the body
        n_feats = 2 * num_features_model(nn.Sequential(*encoder.body.children()))
        self.encoder = encoder
        hidden = n_feats // 2
        self.head = nn.Sequential(LinBnDrop(n_feats, hidden, p=0.2, act=nn.ReLU()),
                                  LinBnDrop(hidden, num_classes, p=0.05))
        # one attention score per frame
        self.attention_layer = nn.Linear(n_feats, 1)
        self.debug = debug

    def forward(self, x):
        "Encode all frames as one big batch, pool them over time with attention, then classify"
        if self.debug:  print(f' input len:   {len(x), x[0].shape}')
        frames = torch.stack(x, dim=1)
        if self.debug:  print(f' after stack:   {frames.shape}')
        bs, n_frames, ch, height, width = frames.shape
        # fold the sequence dimension into the batch so the 2D encoder sees single images
        feats = self.encoder(frames.view(bs * n_frames, ch, height, width))
        feats = feats.view(bs, n_frames, -1)
        if self.debug:  print(f' encoded shape: {feats.shape}')
        weights = F.softmax(self.attention_layer(feats).squeeze(-1), dim=-1)
        pooled = (weights.unsqueeze(-1) * feats).sum(dim=1)
        if self.debug:  print(f' after attention shape: {pooled.shape}')
        return self.head(pooled)

A splitter function that separates the encoder parameters from the attention and head parameters so they can be trained separately; the `Learner` needs it to be able to call `Learner.freeze()`.

In [11]:
#export
def simple_splitter(model):
    "Two parameter groups: the encoder first, then the attention layer and the head"
    encoder_group = params(model.encoder)
    head_group = params(model.attention_layer) + params(model.head)
    return [encoder_group, head_group]

A sequence of 10 frames, each frame being a batch of 64 images:

In [12]:
# seq_len=10 frames, each a batch of images of shape (bs=64, ch=3, h=64, w=64)
inp = [torch.rand(64, 3, 64, 64) for _ in range(10)]

In [13]:
# debug=True prints the tensor shapes at each stage of the forward pass
sm = SimpleModel(debug=True, seq_len=10)
out = sm(inp)
test_eq(out.shape, [64, 30])  # (bs, num_classes)

 input len:   (10, torch.Size([64, 3, 64, 64]))
 after stack:   torch.Size([64, 10, 3, 64, 64])
 encoded shape: torch.Size([64, 10, 1024])
 after attention shape: torch.Size([64, 1024])


## ConvLSTM
> An LSTM encoded image model

First, the LSTM wrapper, with a `reset` method that erases the hidden state before each epoch.

In [14]:
#export
class LSTM(Module):
    "Batch-first `nn.LSTM` with output dropout that keeps its (detached) hidden state between calls"
    def __init__(self, input_dim, n_hidden, n_layers, bidirectional=False, p=0.5):
        """
        Args:
            input_dim: number of features of each sequence element.
            n_hidden: hidden size of every LSTM layer.
            n_layers: number of stacked LSTM layers.
            bidirectional: if True the output features are `2 * n_hidden`.
            p: dropout probability applied to the LSTM output sequence.
        """
        self.lstm = nn.LSTM(input_dim, n_hidden, n_layers, batch_first=True, bidirectional=bidirectional)
        self.drop = nn.Dropout(p)
        self.h = None

    def reset(self):
        "Forget the stored hidden state so the next call starts from zeros"
        self.h = None

    def forward(self, x):
        """Run the LSTM on `x` (batch first, presumably `(bs, seq_len, input_dim)`).

        Returns `(out, (h, c))`: the dropped-out output sequence and this call's final
        hidden/cell states (not detached). A detached copy of `(h, c)` is stored and used
        as the initial state of the next call.
        """
        # The stored state is only reusable if it matches this batch: a different batch size
        # (e.g. the last validation batch) or a model moved to another device would make
        # nn.LSTM fail, so start from zeros instead.
        if self.h is not None and (self.h[0].shape[1] != x.shape[0] or self.h[0].device != x.device):
            self.h = None
        raw, h = self.lstm(x, self.h)
        out = self.drop(raw)
        # detach so gradients do not flow back into previous batches
        self.h = tuple(h_.detach() for h_ in h)
        return out, h

We take the output size of the encoder (the latent dimension) as `input_dim`. `n_layers` is the number of stacked LSTM layers, and `n_hidden` is the hidden dimension, as before.

Let's build a single-layer LSTM:

In [15]:
lstm = LSTM(input_dim=512, n_hidden=512, n_layers=1, bidirectional=False)

In [16]:
# bs, seq_len, input_dim
y = torch.rand(32, 10,  512)

The output keeps the batch and sequence dimensions, with features of size `n_hidden`:

In [17]:
# out: (bs, seq_len, n_hidden); h and c: (n_layers * directions, bs, n_hidden)
out, (h,c) = lstm(y)
out.shape, h.shape, c.shape

(torch.Size([32, 10, 512]), torch.Size([1, 32, 512]), torch.Size([1, 32, 512]))

It also handles a change of batch size; the stored hidden state is discarded when the batch size differs:

In [18]:
# a batch of 16 right after a batch of 32: the stored hidden state is dropped instead of raising
out, (h,c) = lstm(torch.rand(16,10,512))
out.shape, h.shape, c.shape

(torch.Size([16, 10, 512]), torch.Size([1, 16, 512]), torch.Size([1, 16, 512]))

In [19]:
lstm = LSTM(input_dim=512, n_hidden=512, n_layers=3, bidirectional=True)

In [20]:
# bidirectional doubles the output features (2 * 512); h/c hold 3 layers x 2 directions = 6 states
out, (h,c) = lstm(torch.rand(16,10,512))
out.shape,  h.shape, c.shape

(torch.Size([16, 10, 1024]),
 torch.Size([6, 16, 512]),
 torch.Size([6, 16, 512]))

In [21]:
#export
class ConvLSTM(Module):
    "Per-frame CNN encoder followed by an LSTM, pooled either by attention or by the final hidden states"
    def __init__(self, arch=resnet34, weights_file=None, num_classes=30, lstm_layers=1, hidden_dim=1024, 
                 bidirectional=True, attention=True, debug=False):
        encoder = Encoder(arch, 3, weights_file, head=False)
        # concat pooling (avg + max) doubles the channel count of the body
        n_feats = 2 * num_features_model(nn.Sequential(*encoder.body.children()))
        self.encoder = encoder
        self.lstm = LSTM(n_feats, hidden_dim, lstm_layers, bidirectional)
        self.attention = attention
        # width of one LSTM output step (both directions are concatenated when bidirectional)
        step_dim = hidden_dim * 2 if bidirectional else hidden_dim
        self.attention_layer = nn.Linear(step_dim, 1)
        # attention pools to a single step; otherwise the final hidden state of every layer is used
        head_in = step_dim if attention else lstm_layers * step_dim
        self.head = nn.Sequential(
            LinBnDrop(head_in, hidden_dim, p=0.2, act=nn.ReLU()),
            nn.Linear(hidden_dim, num_classes),
        )
        self.debug = debug

    def forward(self, x):
        "Encode each frame, run the LSTM over the sequence, pool it and classify"
        frames = torch.stack(x, dim=1)
        if self.debug:  print(f' after stack:   {frames.shape}')
        bs, n_frames, ch, height, width = frames.shape
        feats = self.encoder(frames.view(bs * n_frames, ch, height, width))
        if self.debug:  print(f' after encode:   {feats.shape}')
        feats = feats.view(bs, n_frames, -1)
        if self.debug:  print(f' before lstm:   {feats.shape}')
        seq_out, (hidden, _cell) = self.lstm(feats)
        if self.debug:  print(f' after lstm:   {seq_out.shape}')
        if not self.attention:
            if self.debug: print(f' hidden state: {hidden.shape}')
            # (layers * directions, bs, hidden_dim) -> (bs, layers * directions * hidden_dim)
            pooled = hidden.permute(1, 0, 2).flatten(1)
            if self.debug: print(f' hidden state flat: {pooled.shape}')
            return self.head(pooled)
        weights = F.softmax(self.attention_layer(seq_out).squeeze(-1), dim=-1)
        if self.debug: print(f' attention_w: {weights.shape}')
        pooled = (weights.unsqueeze(-1) * seq_out).sum(dim=1)
        if self.debug: print(f' after attention: {pooled.shape}')
        return self.head(pooled)

    def reset(self):
        "Clear the LSTM hidden state"
        self.lstm.reset()

In [22]:
#export
def convlstm_splitter(model):
    "Two parameter groups: the CNN encoder, then LSTM + attention layer + head"
    encoder_group = params(model.encoder)
    rest_group = sum((params(m) for m in (model.lstm, model.attention_layer, model.head)), [])
    return [encoder_group, rest_group]

In [23]:
# seq_len=10 frames, each a batch of images of shape (bs=32, ch=3, h=64, w=64)
inp = [torch.rand(32, 3, 64, 64) for _ in range(10)]

In [24]:
# without attention the head reads the final hidden state of each of the 2 layers: 2 * 1024 features
clstm = ConvLSTM(attention=False, bidirectional=False, lstm_layers=2, debug=True)
test_eq(clstm(inp).shape, [32, 30])

 after stack:   torch.Size([32, 10, 3, 64, 64])
 after encode:   torch.Size([320, 1024])
 before lstm:   torch.Size([32, 10, 1024])
 after lstm:   torch.Size([32, 10, 1024])
 hidden state: torch.Size([2, 32, 1024])
 hidden state flat: torch.Size([32, 2048])


In [25]:
# default: bidirectional LSTM (2 * 1024 features per step) pooled over time with attention
clstm = ConvLSTM(lstm_layers=3, debug=True)
test_eq(clstm(inp).shape, [32, 30])

 after stack:   torch.Size([32, 10, 3, 64, 64])
 after encode:   torch.Size([320, 1024])
 before lstm:   torch.Size([32, 10, 1024])
 after lstm:   torch.Size([32, 10, 2048])
 attention_w: torch.Size([32, 10])
 after attention: torch.Size([32, 2048])


In [26]:
# with attention the head input size does not depend on the number of LSTM layers
clstm = ConvLSTM(lstm_layers=1, debug=True)
test_eq(clstm(inp).shape, [32, 30])

 after stack:   torch.Size([32, 10, 3, 64, 64])
 after encode:   torch.Size([320, 1024])
 before lstm:   torch.Size([32, 10, 1024])
 after lstm:   torch.Size([32, 10, 2048])
 attention_w: torch.Size([32, 10])
 after attention: torch.Size([32, 2048])


## TimeSformer
> thanks LucidRains https://github.com/lucidrains/TimeSformer-pytorch

In [27]:
# !pip install timesformer-pytorch

In [28]:
#export
from timesformer_pytorch import TimeSformer as _TimeSformer

In [29]:
#export
class TimeSformer(_TimeSformer):
    "`timesformer_pytorch.TimeSformer` that accepts a tuple of frame batches (an `ImageTuple` batch)"
    def forward(self, video):
        "Stack the frames on dim 1 (presumably batch x frames x channels x height x width) and run the parent model"
        stacked = torch.stack(video, dim=1)
        return super().forward(stacked)

In [30]:
model = TimeSformer(
    dim=128, image_size=64, patch_size=16, num_frames=8, num_classes=10,
    depth=12, heads=8, dim_head=64, attn_dropout=0.1, ff_dropout=0.1,
)

In [31]:
video = tuple(torch.randn(2, 3, 64, 64) for _ in range(8)) # 8 frames, each (batch x channels x height x width)
test_eq(model(video).shape, (2,10))

## STAM - Pytorch

Implementation of <a href="https://arxiv.org/abs/2103.13915">STAM (Space Time Attention Model)</a>, yet another pure and simple SOTA attention model that bests all previous models in video classification. This corroborates the finding of <a href="https://github.com/lucidrains/TimeSformer-pytorch">TimeSformer</a>. Attention is all we need.

In [None]:
# !pip install stam-pytorch

In [40]:
#export
from stam_pytorch import STAM as _STAM

In [41]:
#export
class STAM(_STAM):
    "`stam_pytorch.STAM` that accepts a tuple of frame batches (an `ImageTuple` batch)"
    def forward(self, video):
        "Stack the frames on dim 1 (presumably batch x frames x channels x height x width) and run the parent model"
        stacked = torch.stack(video, dim=1)
        return super().forward(stacked)

In [46]:
model = STAM(
    dim=128,
    image_size=64,       # size of image
    patch_size=16,       # patch size
    num_frames=8,        # number of image frames, selected out of video
    space_depth=12,      # depth of vision transformer
    space_heads=8,       # heads of vision transformer
    space_mlp_dim=2048,  # feedforward hidden dimension of vision transformer
    time_depth=6,        # depth of time transformer (in paper, it was shallower, 6)
    time_heads=8,        # heads of time transformer
    time_mlp_dim=2048,   # feedforward hidden dimension of time transformer
    num_classes=10,      # number of output classes
    space_dim_head=64,   # space transformer head dimension
    time_dim_head=64,    # time transformer head dimension
    dropout=0.,          # dropout
    emb_dropout=0.,      # embedding dropout
)

In [47]:
video = tuple(torch.randn(2, 3, 64, 64) for _ in range(8)) # 8 frames, each (batch x channels x height x width)
test_eq(model(video).shape, (2,10))

# Export -

In [48]:
#hide
from nbdev.export import *
notebook2script()

Converted 00_core.ipynb.
Converted 01_models.ipynb.
Converted index.ipynb.
