In [None]:
pip install -U albumentations

In [None]:
pip install ptflops

# Model build

In [None]:
import torch
import torch.nn as nn

class STMEM(nn.Module):
    """Collapse every segment of stacked RGB frames into one 3-channel image.

    For each segment the module
      1. computes the differences between consecutive frames,
      2. fuses the raw frames and the differences with a 3x3 convolution
         (``m1``) into a 3-channel map,
      3. takes the element-wise maximum over the frame differences, passes it
         through a second 3x3 convolution (``m2``) and a sigmoid to obtain a
         mask, and
      4. multiplies the fused map by that mask.

    Args:
        num_segments: number of segments ``S`` per clip.
        new_length: number of RGB frames ``L`` per segment.
        img_size: expected ``(height, width)`` of the input frames.

    NOTE(review): with ``new_length == 1`` there are no frame pairs, so the
    max over the (empty) difference axis in ``forward`` is expected to fail;
    the module appears to be intended for ``new_length >= 2``.
    """

    def __init__(self, num_segments, new_length, img_size=(224, 224)):
        super(STMEM, self).__init__()
        self.num_segments = num_segments
        self.new_length = new_length
        self.height, self.width = img_size

        self.sigmoid = nn.Sigmoid()

        # Fuses L raw frames plus (L - 1) frame differences, 3 channels each.
        self.m1 = nn.Sequential(
            nn.Conv2d(
                in_channels=(self.new_length * 2 - 1) * 3,
                out_channels=3,
                kernel_size=3,
                stride=1,
                padding=1
            )
        )

        # Turns the max frame difference into mask logits.
        self.m2 = nn.Sequential(
            nn.Conv2d(
                in_channels=3,
                out_channels=3,
                kernel_size=3,
                stride=1,
                padding=1
            )
        )

    def forward(self, x):
        """Return the motion-masked fusion of every segment.

        Args:
            x: tensor of shape ``(B, S * L * 3, H, W)``.

        Returns:
            Tensor of shape ``(B * S, 3, H, W)``.
        """
        B, SLC, H, W = x.size()
        assert H == self.height and W == self.width, \
            f"Expected input height={self.height}, width={self.width} but got {H}x{W}"

        # (B, S * L * 3, H, W) -> (B * S, L * 3, H, W).
        # reshape (rather than view) also accepts non-contiguous inputs, e.g.
        # a channels-last batch that was permuted to NCHW.
        x = x.reshape(B * self.num_segments, self.new_length * 3, self.height, self.width)

        # Differences between consecutive frames: frame[i + 1] - frame[i].
        frame_diff = x[:, 3:] - x[:, : (self.new_length - 1) * 3]

        # Fuse the raw frames with the motion information.
        x_with_diff = torch.cat((x, frame_diff), dim=1)
        x_with_diff = self.m1(x_with_diff)

        # Element-wise maximum over the (L - 1) frame differences.
        frame_diff = frame_diff.reshape(
            B * self.num_segments, self.new_length - 1, 3, self.height, self.width)
        frame_diff = frame_diff.max(dim=1)[0]

        frame_diff = self.m2(frame_diff)
        motion_mask = self.sigmoid(frame_diff)

        # Gate the fused map with the motion mask.
        output = motion_mask * x_with_diff

        return output  # shape: (B * S, 3, H, W)

if __name__ == '__main__':
    # Smoke test: 4 clips with 5 segments x 6 frames x 3 channels = 90 input channels.
    a = torch.rand([4,90,224,224])
    model = STMEM(num_segments=5,new_length=6)
    out = model(a)
    # STMEM returns one 3-channel image per segment, i.e. B * S = 20 images.
    print(out.size())

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class TemporalShift(nn.Module):
    """Apply a temporal channel shift to the input, then run the wrapped module.

    Args:
        net: module that receives the shifted tensor.
        n_segment: number of frames per clip in the flattened batch axis.
        n_div: the channels are split into ``n_div`` folds; one fold is
            shifted towards earlier frames and one towards later frames.
        inplace: requests the in-place variant, which ``shift`` does not
            implement (it raises ``NotImplementedError``).
    """

    def __init__(self, net, n_segment=3, n_div=8, inplace=False):
        super().__init__()
        self.net = net
        self.n_segment = n_segment
        self.fold_div = n_div
        self.inplace = inplace
        if inplace:
            print('=> Using in-place shift...')
        print('=> Using fold div: {}'.format(self.fold_div))

    def forward(self, x):
        shifted = self.shift(
            x, self.n_segment, fold_div=self.fold_div, inplace=self.inplace
        )
        return self.net(shifted)

    @staticmethod
    def shift(x, n_segment, fold_div=3, inplace=False):
        """Shift two channel folds of ``x`` along the time axis.

        ``x`` has shape ``(n_batch * n_segment, c, h, w)``. With
        ``fold = c // fold_div``: channels ``[0, fold)`` take the value of the
        next frame, channels ``[fold, 2 * fold)`` take the value of the
        previous frame, and the remaining channels are copied unchanged.
        Positions with no source frame stay zero.
        """
        total_frames, channels, height, width = x.size()
        clips = x.view(total_frames // n_segment, n_segment, channels, height, width)
        fold = channels // fold_div

        if inplace:
            # The in-place path had ordering problems under parallel execution
            # and would need a dedicated CUDA kernel, so it stays disabled.
            raise NotImplementedError

        shifted = torch.zeros_like(clips)
        # The three channel ranges are disjoint, so the write order is free.
        shifted[:, :, 2 * fold:] = clips[:, :, 2 * fold:]  # untouched channels
        shifted[:, 1:, fold: 2 * fold] = clips[:, :-1, fold: 2 * fold]  # from previous frame
        shifted[:, :-1, :fold] = clips[:, 1:, :fold]  # from next frame

        return shifted.view(total_frames, channels, height, width)


class InplaceShift(torch.autograd.Function):
    """Autograd function that performs the temporal shift directly on its input.

    Both passes work on ``.data`` and overwrite the tensor they receive, using
    a scratch buffer of ``fold`` channels. The tensor is expected to be
    5-dimensional and is unpacked as ``(n, t, c, h, w)``.
    """
    # Special thanks to @raoyongming for the help to this function
    @staticmethod
    def forward(ctx, input, fold):
        """Shift the first two channel folds of ``input`` in place and return it."""
        # not support higher order gradient
        # input = input.detach_()
        ctx.fold_ = fold  # remembered for the backward pass
        n, t, c, h, w = input.size()
        buffer = input.data.new(n, t, fold, h, w).zero_()
        # Channels [0, fold): each frame takes the value of the next frame.
        buffer[:, :-1] = input.data[:, 1:, :fold]
        input.data[:, :, :fold] = buffer
        buffer.zero_()
        # Channels [fold, 2 * fold): each frame takes the value of the previous frame.
        buffer[:, 1:] = input.data[:, :-1, fold: 2 * fold]
        input.data[:, :, fold: 2 * fold] = buffer
        return input

    @staticmethod
    def backward(ctx, grad_output):
        """Apply the opposite shifts to ``grad_output`` in place.

        Returns the modified gradient for ``input`` and ``None`` for ``fold``.
        """
        # grad_output = grad_output.detach_()
        fold = ctx.fold_
        n, t, c, h, w = grad_output.size()
        buffer = grad_output.data.new(n, t, fold, h, w).zero_()
        # Reverse of the forward shift for channels [0, fold).
        buffer[:, 1:] = grad_output.data[:, :-1, :fold]
        grad_output.data[:, :, :fold] = buffer
        buffer.zero_()
        # Reverse of the forward shift for channels [fold, 2 * fold).
        buffer[:, :-1] = grad_output.data[:, 1:, fold: 2 * fold]
        grad_output.data[:, :, fold: 2 * fold] = buffer
        return grad_output, None


class TemporalPool(nn.Module):
    """Max-pool the input along the time axis, then run the wrapped module.

    Args:
        net: module that receives the temporally pooled tensor.
        n_segment: number of frames per clip in the flattened batch axis.
    """

    def __init__(self, net, n_segment):
        super(TemporalPool, self).__init__()
        self.net = net
        self.n_segment = n_segment

    def forward(self, x):
        x = self.temporal_pool(x, n_segment=self.n_segment)
        return self.net(x)

    @staticmethod
    def temporal_pool(x, n_segment):
        """Pool ``x`` of shape ``(n_batch * n_segment, c, h, w)`` over time.

        A max pool with kernel 3, stride 2 and padding 1 is applied along the
        frame axis only, so every clip keeps ``(n_segment + 1) // 2`` frames.

        Returns:
            Tensor of shape ``(n_batch * pooled_frames, c, h, w)``.
        """
        nt, c, h, w = x.size()
        n_batch = nt // n_segment
        x = x.view(n_batch, n_segment, c, h, w).transpose(1, 2)  # n, c, t, h, w
        x = F.max_pool3d(x, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(1, 0, 0))
        # Infer the leading dimension: for odd n_segment the pool keeps
        # (n_segment + 1) // 2 frames, which is not nt // 2.
        x = x.transpose(1, 2).contiguous().view(-1, c, h, w)
        return x


def make_temporal_shift(net, n_segment, n_div=8, place='blockres', temporal_pool=False):
    """Insert ``TemporalShift`` modules into the four stages of a ResNet.

    Args:
        net: a ``torchvision.models.ResNet``; it is modified in place.
        n_segment: number of frames per clip.
        n_div: channel fold divisor handed to every ``TemporalShift``.
        place: ``'block'`` wraps every residual block; any value containing
            ``'blockres'`` wraps only ``conv1`` inside the blocks.
        temporal_pool: when true, stages 2-4 use ``n_segment // 2`` frames.

    Raises:
        NotImplementedError: if ``net`` is not a torchvision ResNet or
            ``place`` is not one of the supported placements.
    """
    if temporal_pool:
        n_segment_list = [n_segment, n_segment // 2, n_segment // 2, n_segment // 2]
    else:
        n_segment_list = [n_segment] * 4
    assert n_segment_list[-1] > 0
    print('=> n_segment per stage: {}'.format(n_segment_list))

    import torchvision
    if not isinstance(net, torchvision.models.ResNet):
        raise NotImplementedError(place)

    if place == 'block':
        def make_block_temporal(stage, this_segment):
            blocks = list(stage.children())
            print('=> Processing stage with {} blocks'.format(len(blocks)))
            for i, b in enumerate(blocks):
                blocks[i] = TemporalShift(b, n_segment=this_segment, n_div=n_div)
            return nn.Sequential(*(blocks))

    elif 'blockres' in place:
        # Stages with 23 or more blocks in layer3 only get a shift in every
        # second block.
        n_round = 1
        if len(list(net.layer3.children())) >= 23:
            n_round = 2
            print('=> Using n_round {} to insert temporal shift'.format(n_round))

        def make_block_temporal(stage, this_segment):
            blocks = list(stage.children())
            print('=> Processing stage with {} blocks residual'.format(len(blocks)))
            for i, b in enumerate(blocks):
                if i % n_round == 0:
                    blocks[i].conv1 = TemporalShift(b.conv1, n_segment=this_segment, n_div=n_div)
            return nn.Sequential(*blocks)

    else:
        # Previously an unknown placement silently left the network unshifted.
        raise NotImplementedError(place)

    for stage_name, stage_segments in zip(('layer1', 'layer2', 'layer3', 'layer4'), n_segment_list):
        setattr(net, stage_name, make_block_temporal(getattr(net, stage_name), stage_segments))


def make_temporal_pool(net, n_segment):
    """Wrap ``net.layer2`` of a torchvision ResNet in a ``TemporalPool``.

    The network is modified in place. Any other network type raises
    ``NotImplementedError``.
    """
    import torchvision

    if not isinstance(net, torchvision.models.ResNet):
        raise NotImplementedError

    print('=> Injecting nonlocal pooling')
    net.layer2 = TemporalPool(net.layer2, n_segment)


In [None]:
# Non-local block using embedded gaussian
# Code from
# https://github.com/AlexHex7/Non-local_pytorch/blob/master/Non-Local_pytorch_0.3.1/lib/non_local_embedded_gaussian.py
class _NonLocalBlockND(nn.Module):
    """Embedded-Gaussian non-local block for 1-D, 2-D or 3-D feature maps.

    Args:
        in_channels: channels of the input (and of the output).
        inter_channels: width of the internal embeddings; defaults to
            ``in_channels // 2`` with a minimum of 1.
        dimension: 1, 2 or 3, selecting the Conv/MaxPool/BatchNorm flavour.
        sub_sample: when true, ``g`` and ``phi`` are followed by a max pool.
        bn_layer: when true, the output projection ``W`` ends in batch norm.

    The last layer of ``W`` is initialised to zero, so a freshly constructed
    block returns its input unchanged.
    """

    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
        super().__init__()

        assert dimension in [1, 2, 3]

        self.dimension = dimension
        self.sub_sample = sub_sample
        self.in_channels = in_channels
        if inter_channels is None:
            # Half the input width, but never zero channels.
            inter_channels = in_channels // 2 or 1
        self.inter_channels = inter_channels

        # (convolution class, pooling class, pooling kernel, batch-norm class)
        one_d = (nn.Conv1d, nn.MaxPool1d, 2, nn.BatchNorm1d)
        layer_table = {
            3: (nn.Conv3d, nn.MaxPool3d, (1, 2, 2), nn.BatchNorm3d),
            2: (nn.Conv2d, nn.MaxPool2d, (2, 2), nn.BatchNorm2d),
        }
        conv_nd, pool_nd, pool_kernel, bn = layer_table.get(dimension, one_d)
        max_pool_layer = pool_nd(kernel_size=pool_kernel)

        def pointwise(src_channels, dst_channels):
            # 1x1 convolution that only remaps channels.
            return conv_nd(in_channels=src_channels, out_channels=dst_channels,
                           kernel_size=1, stride=1, padding=0)

        # Creation order (g, W, theta, phi) is kept so parameter
        # initialisation consumes random numbers in the same sequence.
        self.g = pointwise(self.in_channels, self.inter_channels)

        if bn_layer:
            self.W = nn.Sequential(
                pointwise(self.inter_channels, self.in_channels),
                bn(self.in_channels),
            )
            zero_initialised = self.W[1]
        else:
            self.W = pointwise(self.inter_channels, self.in_channels)
            zero_initialised = self.W
        nn.init.constant_(zero_initialised.weight, 0)
        nn.init.constant_(zero_initialised.bias, 0)

        self.theta = pointwise(self.in_channels, self.inter_channels)
        self.phi = pointwise(self.in_channels, self.inter_channels)

        if sub_sample:
            # Both branches share the same (parameter-free) pooling module.
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)

    def forward(self, x):
        '''
        :param x: (b, c, t, h, w) for the 3-D block; lower ranks drop the
                  leading spatial axes
        :return: tensor with the same shape as ``x``
        '''
        batch_size = x.size(0)
        embedded = (batch_size, self.inter_channels, -1)

        g_x = self.g(x).view(*embedded).transpose(1, 2)          # (b, N_g, c')
        theta_x = self.theta(x).view(*embedded).transpose(1, 2)  # (b, N, c')
        phi_x = self.phi(x).view(*embedded)                      # (b, c', N_g)

        # Pairwise similarities, normalised over the key positions.
        attention = F.softmax(torch.matmul(theta_x, phi_x), dim=-1)

        y = torch.matmul(attention, g_x).transpose(1, 2).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])

        # Residual connection around the projected attention output.
        return self.W(y) + x


class NONLocalBlock1D(_NonLocalBlockND):
    """Non-local block with ``dimension`` fixed to 1."""

    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super().__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=1,
            sub_sample=sub_sample,
            bn_layer=bn_layer,
        )


class NONLocalBlock2D(_NonLocalBlockND):
    """Non-local block with ``dimension`` fixed to 2."""

    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super().__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=2,
            sub_sample=sub_sample,
            bn_layer=bn_layer,
        )


class NONLocalBlock3D(_NonLocalBlockND):
    """Non-local block with ``dimension`` fixed to 3."""

    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super().__init__(
            in_channels,
            inter_channels=inter_channels,
            dimension=3,
            sub_sample=sub_sample,
            bn_layer=bn_layer,
        )


class NL3DWrapper(nn.Module):
    """Run a residual block, then a 3-D non-local block over each clip.

    Args:
        block: wrapped block; its ``bn3.num_features`` sets the channel count
            of the non-local block.
        n_segment: number of frames per clip in the flattened batch axis.
    """

    def __init__(self, block, n_segment):
        super().__init__()
        self.block = block
        self.nl = NONLocalBlock3D(block.bn3.num_features)
        self.n_segment = n_segment

    def forward(self, x):
        features = self.block(x)

        nt, c, h, w = features.size()
        # (n * t, c, h, w) -> (n, c, t, h, w) so attention can span the clip.
        clips = features.view(nt // self.n_segment, self.n_segment, c, h, w).transpose(1, 2)
        clips = self.nl(clips)
        # Back to one frame per batch entry.
        return clips.transpose(1, 2).contiguous().view(nt, c, h, w)


def make_non_local(net, n_segment):
    """Insert non-local blocks into ``layer2`` and ``layer3`` of a ResNet.

    Blocks 0 and 2 of ``layer2`` and blocks 0, 2 and 4 of ``layer3`` are
    wrapped in ``NL3DWrapper``; every other block of those stages is kept as
    it is. The network is modified in place.

    Earlier the stages were rebuilt from a hard-coded list of 4 and 6 blocks,
    which silently discarded the remaining blocks of deeper networks (e.g.
    the 23-block ``layer3`` of ResNet-101).

    Args:
        net: a ``torchvision.models.ResNet``.
        n_segment: number of frames per clip, forwarded to ``NL3DWrapper``.

    Raises:
        NotImplementedError: if ``net`` is not a torchvision ResNet.
    """
    import torchvision
    if not isinstance(net, torchvision.models.ResNet):
        raise NotImplementedError

    def wrap_blocks(stage, wrap_indices):
        # Rebuild the stage with the selected blocks wrapped and all others
        # carried over untouched.
        blocks = list(stage.children())
        for idx in wrap_indices:
            if idx < len(blocks):
                blocks[idx] = NL3DWrapper(blocks[idx], n_segment)
        return nn.Sequential(*blocks)

    net.layer2 = wrap_blocks(net.layer2, (0, 2))
    net.layer3 = wrap_blocks(net.layer3, (0, 2, 4))


if __name__ == '__main__':
    # Smoke test: run each non-local block flavour on a small tensor and
    # print the output size.
    from torch.autograd import Variable
    import torch

    sub_sample = True
    bn_layer = True

    # 1-D input: (batch, channels, length)
    img = Variable(torch.zeros(2, 3, 20))
    net = NONLocalBlock1D(3, sub_sample=sub_sample, bn_layer=bn_layer)
    out = net(img)
    print(out.size())

    # 2-D input: (batch, channels, height, width)
    img = Variable(torch.zeros(2, 3, 20, 20))
    net = NONLocalBlock2D(3, sub_sample=sub_sample, bn_layer=bn_layer)
    out = net(img)
    print(out.size())

    # 3-D input: (batch, channels, frames, height, width)
    img = Variable(torch.randn(2, 3, 10, 20, 20))
    net = NONLocalBlock3D(3, sub_sample=sub_sample, bn_layer=bn_layer)
    out = net(img)
    print(out.size())

In [None]:
class Identity(torch.nn.Module):
    """Pass-through module: ``forward`` hands back its argument untouched."""

    def forward(self, input):
        return input


class SegmentConsensus(torch.nn.Module):
    """Aggregate per-segment predictions along one axis.

    Args:
        consensus_type: ``'avg'`` averages over ``dim`` (keeping the axis with
            size 1); ``'identity'`` returns the input unchanged.
        dim: axis that indexes the segments.
    """

    def __init__(self, consensus_type, dim=1):
        super(SegmentConsensus, self).__init__()
        self.consensus_type = consensus_type
        self.dim = dim
        self.shape = None  # size of the most recent input

    def forward(self, input_tensor):
        """Return the consensus of ``input_tensor``.

        Raises:
            ValueError: if ``consensus_type`` is neither ``'avg'`` nor
                ``'identity'``. Returning ``None`` here used to surface later
                as an unrelated attribute error in the caller.
        """
        self.shape = input_tensor.size()
        if self.consensus_type == 'avg':
            return input_tensor.mean(dim=self.dim, keepdim=True)
        if self.consensus_type == 'identity':
            return input_tensor
        raise ValueError(
            "Unsupported consensus type: {!r} (expected 'avg' or 'identity')".format(
                self.consensus_type))


class ConsensusModule(torch.nn.Module):
    """Module wrapper around ``SegmentConsensus``.

    The consensus type ``'rnn'`` is stored as ``'identity'``; every other
    value is stored as given.
    """

    def __init__(self, consensus_type, dim=1):
        super().__init__()
        if consensus_type == 'rnn':
            consensus_type = 'identity'
        self.consensus_type = consensus_type
        self.dim = dim

    def forward(self, input):
        # A fresh SegmentConsensus is built for every call.
        consensus = SegmentConsensus(self.consensus_type, self.dim)
        return consensus(input)

In [None]:
from torch import nn
from torch.nn.init import normal_, constant_


class TSN(nn.Module):
    """Segment-based video classifier built on a torchvision ResNet.

    The forward pass runs the input through an ``STMEM`` front-end (one
    3-channel image per segment), the ResNet backbone (optionally with
    temporal shift and non-local blocks), an optional final linear layer, and
    finally a consensus over the segments of every clip.
    """

    def __init__(self, num_class, num_segments, modality, img_size = (224, 224),
                 base_model='resnet101', new_length=None,
                 consensus_type='avg', before_softmax=True,
                 dropout=0.8, img_feature_dim=256,
                 crop_num=1, partial_bn=True, print_spec=True, pretrain='imagenet',
                 is_shift=False, shift_div=8, shift_place='blockres', fc_lr5=False,
                 temporal_pool=False, non_local=False, data_length=1):
        """Build the STMEM front-end, the backbone and the classification head.

        Args:
            num_class: number of output classes.
            num_segments: number of segments per clip.
            modality: input modality string; ``'Flow'`` and ``'RGBDiff'``
                trigger a replacement of the backbone's first convolution.
            img_size: ``(height, width)`` expected by the STMEM front-end.
            base_model: name of a ``torchvision.models`` ResNet constructor.
            new_length: frames per segment seen by the backbone's first
                convolution; when ``None`` it defaults to 1 for
                RGB/depth/motion/dense and 5 otherwise.
            consensus_type: consensus over segments (see ``ConsensusModule``).
            before_softmax: when false, a softmax is applied before consensus.
            dropout: when non-zero, the backbone's last layer becomes a
                ``Dropout`` and a separate ``new_fc`` linear layer is added.
            img_feature_dim: stored and printed only.
            crop_num: stored only.
            partial_bn: freeze all but the first BatchNorm2d of the backbone
                while training.
            print_spec: print the configuration summary.
            pretrain: ``'imagenet'`` requests pretrained backbone weights.
            is_shift, shift_div, shift_place: temporal-shift configuration.
            fc_lr5: put linear layers in the higher-learning-rate groups of
                ``get_optim_policies``.
            temporal_pool: forwarded to ``make_temporal_shift``; also changes
                the segment count used when reshaping in ``forward``.
            non_local: insert non-local blocks into the backbone.
            data_length: frames per segment handed to ``STMEM``.

        Raises:
            ValueError: if ``before_softmax`` is false and the consensus type
                is not ``'avg'``.
        """
        super(TSN, self).__init__()

        # Registered first, so its convolutions come first in self.modules().
        self.TSM_intrada = STMEM(num_segments=num_segments, new_length=data_length, img_size = img_size)
        self.modality = modality
        self.img_size = img_size
        self.num_segments = num_segments
        self.reshape = True
        self.before_softmax = before_softmax
        self.dropout = dropout
        self.crop_num = crop_num
        self.consensus_type = consensus_type
        self.img_feature_dim = img_feature_dim  # the dimension of the CNN feature to represent each frame
        self.pretrain = pretrain

        self.is_shift = is_shift
        self.shift_div = shift_div
        self.shift_place = shift_place
        self.base_model_name = base_model
        self.fc_lr5 = fc_lr5
        self.temporal_pool = temporal_pool
        self.non_local = non_local
        self.data_length = data_length

        if not before_softmax and consensus_type != 'avg':
            raise ValueError("Only avg consensus can be used after Softmax")

        if new_length is None:
            self.new_length = 1 if modality in ["RGB", 'depth', "motion", "dense"] else 5
        else:
            self.new_length = new_length
        if print_spec:
            print((""" 
    Initializing TSN with base model: {}.
    TSN Configurations:
        input_modality:     {}
        num_segments:       {}
        new_length:         {}
        STMEM_new-length:   {}
        consensus_module:   {}
        dropout_ratio:      {}
        img_feature_dim:    {}
            """.format(base_model, self.modality, self.num_segments, self.new_length, self.data_length, consensus_type, self.dropout, self.img_feature_dim)))

        self._prepare_base_model(base_model)

        # NOTE(review): the returned feature_dim is not used afterwards in this method.
        feature_dim = self._prepare_tsn(num_class)

        if self.modality == 'Flow':
            print("Converting the ImageNet model to a flow init model")
            self.base_model = self._construct_flow_model(self.base_model)
            print("Done. Flow model ready...")
        elif self.modality == 'RGBDiff':
            print("Converting the ImageNet model to RGB+Diff init model")
            self.base_model = self._construct_diff_model(self.base_model)
            print("Done. RGBDiff model ready.")

        self.consensus = ConsensusModule(consensus_type)

        if not self.before_softmax:
            # NOTE(review): Softmax is created without an explicit dim.
            self.softmax = nn.Softmax()

        self._enable_pbn = partial_bn
        if partial_bn:
            self.partialBN(True)

    def _prepare_tsn(self, num_class):
        """Replace the backbone's last layer and create the classifier head.

        With ``dropout == 0`` the last layer becomes a ``Linear`` to
        ``num_class`` and ``new_fc`` is ``None``; otherwise the last layer
        becomes a ``Dropout`` and ``new_fc`` is the ``Linear`` classifier.
        The classifier weights are drawn from N(0, 0.001) and the bias is set
        to zero.

        Returns:
            The input feature size of the original last layer.
        """
        feature_dim = getattr(self.base_model, self.base_model.last_layer_name).in_features
        if self.dropout == 0:
            setattr(self.base_model, self.base_model.last_layer_name, nn.Linear(feature_dim, num_class))
            self.new_fc = None
        else:
            setattr(self.base_model, self.base_model.last_layer_name, nn.Dropout(p=self.dropout))
            self.new_fc = nn.Linear(feature_dim, num_class)

        std = 0.001
        if self.new_fc is None:
            normal_(getattr(self.base_model, self.base_model.last_layer_name).weight, 0, std)
            constant_(getattr(self.base_model, self.base_model.last_layer_name).bias, 0)
        else:
            if hasattr(self.new_fc, 'weight'):
                normal_(self.new_fc.weight, 0, std)
                constant_(self.new_fc.bias, 0)
        return feature_dim

    def _prepare_base_model(self, base_model):
        """Create the ResNet backbone and apply the optional modifications.

        Sets ``base_model``, ``input_size``, ``input_mean`` and ``input_std``.

        Raises:
            ValueError: if ``base_model`` does not contain ``'resnet'``.
        """
        print('=> base model: {}'.format(base_model))

        if 'resnet' in base_model:
            import torchvision
            # The single positional argument requests pretrained weights.
            self.base_model = getattr(torchvision.models, base_model)(True if self.pretrain == 'imagenet' else False)
            if self.is_shift:
                print('Adding temporal shift...')
                make_temporal_shift(self.base_model, self.num_segments,
                                    n_div=self.shift_div, place=self.shift_place, temporal_pool=self.temporal_pool)

            if self.non_local:
                print('Adding non-local module...')
                make_non_local(self.base_model, self.num_segments)

            self.base_model.last_layer_name = 'fc'
            self.input_size = self.img_size[0]
            self.input_mean = [0.485, 0.456, 0.406]
            self.input_std = [0.229, 0.224, 0.225]

            self.base_model.avgpool = nn.AdaptiveAvgPool2d(1)

            # NOTE(review): `np` is not imported in this cell; these branches
            # rely on numpy having been imported as `np` elsewhere before the
            # model is built.
            if self.modality == 'Flow':
                self.input_mean = [0.5]
                self.input_std = [np.mean(self.input_std)]
            elif self.modality == 'RGBDiff':
                self.input_mean = [0.485, 0.456, 0.406] + [0] * 3 * self.new_length
                self.input_std = self.input_std + [np.mean(self.input_std) * 2] * 3 * self.new_length

        else:
            raise ValueError("Unknown base model: {}".format(base_model))

    def train(self, mode=True):
        """
        Override the default train() to freeze the BN parameters.

        When partial BN is enabled and ``mode`` is true, every BatchNorm2d of
        the backbone except the first one is put in eval mode and its
        weight/bias stop requiring gradients.
        :return: None
        """
        super(TSN, self).train(mode)
        count = 0
        if self._enable_pbn and mode:
            print("Freezing BatchNorm2D except the first one.")
            for m in self.base_model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    count += 1
                    if count >= (2 if self._enable_pbn else 1):
                        m.eval()
                        # shutdown update in frozen mode
                        m.weight.requires_grad = False
                        m.bias.requires_grad = False

    def partialBN(self, enable):
        """Enable or disable partial batch-norm freezing (see ``train``)."""
        self._enable_pbn = enable

    def get_optim_policies(self):
        """Group the parameters for the optimizer.

        Returns:
            A list of dicts with the keys ``'params'``, ``'lr_mult'``,
            ``'decay_mult'`` and ``'name'``.

        Raises:
            ValueError: if a leaf module with parameters is of a type that is
                not handled below.
        """
        first_conv_weight = []
        first_conv_bias = []
        normal_weight = []
        normal_bias = []
        lr5_weight = []
        lr10_bias = []
        bn = []
        custom_ops = []

        conv_cnt = 0
        bn_cnt = 0
        # self.modules() covers the whole model, including the STMEM front-end.
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv3d):
                ps = list(m.parameters())
                conv_cnt += 1
                if conv_cnt == 1:
                    first_conv_weight.append(ps[0])
                    if len(ps) == 2:
                        first_conv_bias.append(ps[1])
                else:
                    normal_weight.append(ps[0])
                    if len(ps) == 2:
                        normal_bias.append(ps[1])
            elif isinstance(m, torch.nn.Linear):
                ps = list(m.parameters())
                if self.fc_lr5:
                    lr5_weight.append(ps[0])
                else:
                    normal_weight.append(ps[0])
                if len(ps) == 2:
                    if self.fc_lr5:
                        lr10_bias.append(ps[1])
                    else:
                        normal_bias.append(ps[1])

            elif isinstance(m, torch.nn.BatchNorm2d):
                bn_cnt += 1
                # later BN's are frozen
                if not self._enable_pbn or bn_cnt == 1:
                    bn.extend(list(m.parameters()))
            elif isinstance(m, torch.nn.BatchNorm3d):
                bn_cnt += 1
                # later BN's are frozen
                if not self._enable_pbn or bn_cnt == 1:
                    bn.extend(list(m.parameters()))
            elif len(m._modules) == 0:
                # Leaf module of an unhandled type: it must not own parameters.
                if len(list(m.parameters())) > 0:
                    raise ValueError("New atomic module type: {}. Need to give it a learning policy".format(type(m)))

        return [
            {'params': first_conv_weight, 'lr_mult': 5 if self.modality == 'Flow' else 1, 'decay_mult': 1,
             'name': "first_conv_weight"},
            {'params': first_conv_bias, 'lr_mult': 10 if self.modality == 'Flow' else 2, 'decay_mult': 0,
             'name': "first_conv_bias"},
            {'params': normal_weight, 'lr_mult': 1, 'decay_mult': 1,
             'name': "normal_weight"},
            {'params': normal_bias, 'lr_mult': 2, 'decay_mult': 0,
             'name': "normal_bias"},
            {'params': bn, 'lr_mult': 1, 'decay_mult': 0,
             'name': "BN scale/shift"},
            {'params': custom_ops, 'lr_mult': 1, 'decay_mult': 1,
             'name': "custom_ops"},
            # for fc
            {'params': lr5_weight, 'lr_mult': 5, 'decay_mult': 1,
             'name': "lr5_weight"},
            {'params': lr10_bias, 'lr_mult': 10, 'decay_mult': 0,
             'name': "lr10_bias"},
        ]

    def forward(self, input, no_reshape=False):
        """Classify a batch of clips.

        ``input`` is passed to the STMEM front-end, whose output (one image
        per segment) goes through the backbone, the optional ``new_fc`` and
        softmax, and is then regrouped per clip for the consensus.

        NOTE(review): ``no_reshape`` is not used, and when ``self.reshape`` is
        false the method falls through and returns ``None``.
        """

        output_intrada = self.TSM_intrada(input)#
        base_out = self.base_model(output_intrada)

        # print('base out', base_out.size())

        if self.dropout > 0:
            base_out = self.new_fc(base_out)
            # print('after last', base_out.size())

        if not self.before_softmax:
            base_out = self.softmax(base_out)

        if self.reshape:
            # print('In reshape')
            if self.is_shift and self.temporal_pool:
                # print('use shift')
                base_out = base_out.view((-1, self.num_segments // 2) + base_out.size()[1:])
                # print('Out', base_out.size())
            else:
                # print('NO use shift')
                
                base_out = base_out.view((-1, self.num_segments) + base_out.size()[1:])
            output = self.consensus(base_out)
            return output.squeeze(1)

    def _construct_flow_model(self, base_model):
        """Replace the first Conv2d with one taking ``2 * new_length`` channels.

        The new kernels are the original kernels averaged over the input
        channels and repeated for every new input channel. Returns
        ``base_model``.
        """
        # modify the convolution layers
        modules = list(self.base_model.modules())
        first_conv_idx = list(filter(lambda x: isinstance(modules[x], nn.Conv2d), list(range(len(modules)))))[0]
        conv_layer = modules[first_conv_idx]
        # assumes the module listed just before the first conv is its parent — TODO confirm
        container = modules[first_conv_idx - 1]

        # modify parameters, assume the first blob contains the convolution kernels
        params = [x.clone() for x in conv_layer.parameters()]
        kernel_size = params[0].size()
        new_kernel_size = kernel_size[:1] + (2 * self.new_length, ) + kernel_size[2:]
        new_kernels = params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous()

        new_conv = nn.Conv2d(2 * self.new_length, conv_layer.out_channels,
                             conv_layer.kernel_size, conv_layer.stride, conv_layer.padding,
                             bias=True if len(params) == 2 else False)
        new_conv.weight.data = new_kernels
        if len(params) == 2:
            new_conv.bias.data = params[1].data # add bias if necessary
        layer_name = list(container.state_dict().keys())[0][:-7] # remove .weight suffix to get the layer name

        # replace the first convolution layer
        setattr(container, layer_name, new_conv)

        return base_model

    def _construct_diff_model(self, base_model, keep_rgb=False):
        """Replace the first Conv2d with one sized for RGB-difference input.

        Without ``keep_rgb`` the new convolution takes ``3 * new_length``
        channels initialised from the channel-averaged original kernels; with
        ``keep_rgb`` the original 3-channel kernels are kept in front of
        them, giving ``3 + 3 * new_length`` channels. Returns ``base_model``.
        """
        # modify the convolution layers
        modules = list(self.base_model.modules())
        first_conv_idx = list(filter(lambda x: isinstance(modules[x], nn.Conv2d), list(range(len(modules)))))[0]

        conv_layer = modules[first_conv_idx]
        # assumes the module listed just before the first conv is its parent — TODO confirm
        container = modules[first_conv_idx - 1]

        # modify parameters, assume the first blob contains the convolution kernels
        params = [x.clone() for x in conv_layer.parameters()]
        kernel_size = params[0].size()
        if not keep_rgb:
            new_kernel_size = kernel_size[:1] + (3 * self.new_length,) + kernel_size[2:]
            new_kernels = params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous()
        else:
            new_kernel_size = kernel_size[:1] + (3 * self.new_length,) + kernel_size[2:]
            new_kernels = torch.cat((params[0].data, params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous()),
                                    1)
            new_kernel_size = kernel_size[:1] + (3 + 3 * self.new_length,) + kernel_size[2:]

        new_conv = nn.Conv2d(new_kernel_size[1], conv_layer.out_channels,
                             conv_layer.kernel_size, conv_layer.stride, conv_layer.padding,
                             bias=True if len(params) == 2 else False)
        new_conv.weight.data = new_kernels
        if len(params) == 2:
            new_conv.bias.data = params[1].data  # add bias if necessary
        layer_name = list(container.state_dict().keys())[0][:-7]  # remove .weight suffix to get the layer name

        # replace the first convolution layer
        setattr(container, layer_name, new_conv)
        return base_model


# if __name__== '__main__':
#     from ptflops import get_model_complexity_info
#     import torch
#     x1 = torch.rand((4, 144, 224, 224))
#     net = TSN(num_class=60,num_segments=8,base_model='resnet50',modality='RGB',consensus_type='avg',
#                 dropout=0.8,
#                 img_feature_dim=256,
#                 partial_bn=False,
#                 pretrain='imagenet',
#                 is_shift=True, shift_div=8, shift_place='blockres',
#                 fc_lr5=True,
#                 temporal_pool=False,
#                 non_local=False, data_length=6)
#     # print(net)
#     output = net(x1)
#     print(output.size())
#     policies = net.get_optim_policies()
#     for group in policies:
#         print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
#             group['name'], len(group['params']), group['lr_mult'], group['decay_mult'])))

#     flops, params = get_model_complexity_info(net, (144, 224, 224), as_strings=True,
#                                               print_per_layer_stat=True)  # with as_strings=True the result is reported in G or M units; otherwise it is exact to the ones digit.
#     print("%s |%s |%s" % ('net', flops, params))#stmem |33.35 GMac |23.63 M   original TSM net |32.96 GMac |23.63 M      stiam  net |33.16 GMac |23.63 M   seg=4  net |16.68 GMac |23.63 M

In [None]:
# import torch


# def debug_model(model, device):
#     model.to(device)
#     model.eval()

#     # Simulate dummy input and label
#     batch_size = 64
#     num_segments = model.num_segments  # e.g., 6
#     stmem_len = model.data_length
#     C, H, W = 3, 128, 128
#     num_classes = model.num_classes if hasattr(model, "num_classes") else 10

#     # Shape: [B, num_segments * stmem_len * C, H, W]
#     video = torch.randn(batch_size, num_segments*stmem_len*C, H, W).to(device)
#     label = torch.randint(0, num_classes, (batch_size,)).to(device)

#     print("Input video shape:", video.shape)
#     print("Label shape:", label.shape)

#     # Forward pass
#     with torch.no_grad():
#         output = model(video)
#         print("Model output shape:", output.shape)

#         # Check loss computation
#         loss_fn = torch.nn.CrossEntropyLoss()
#         try:
#             loss = loss_fn(output, label)
#             print("✅ Loss computed successfully:", loss.item())
#         except Exception as e:
#             print("❌ Error during loss computation:", e)

# if __name__ == "__main__":
#     num_segments = 6
#     INPUT_DEPTH = 36

#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     model = TSN(num_class = 18, num_segments = num_segments, modality = "RGBDiff", img_size = (128, 128),
#                 new_length = 1,
#                 base_model="resnet50",
#                 dropout= 0.5,
#                 img_feature_dim=128,
#                 pretrain= "imagenet",
#                 is_shift= True,
#                 # fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
#                 fc_lr5=True,
#                 temporal_pool=False,
#                 non_local=True, data_length = int(INPUT_DEPTH/num_segments))

#     debug_model(model, device)


# Data frame processing

In [None]:
import pandas as pd
import os
import numpy as np
import pandas as pd
import cv2
from PIL import Image

from sklearn.preprocessing import OneHotEncoder
from torch.utils.data import Dataset

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
device

## DIR

In [None]:
# Despite the _DIR suffix, these two point at the annotation CSV files.
TRAIN_DATA_DIR = "/kaggle/input/20-bnjester-csv-files/Train.csv"
VAL_DATA_DIR = "/kaggle/input/20-bnjester-csv-files/Validation.csv"
# Directories whose entries are listed later and parsed as integer video ids.
Cropped_TRAIN_DIR = "/kaggle/input/hand-cropped-20jester-train-dataset/Cropped_Train_Data"
Cropped_VAL_DIR = "/kaggle/input/hand-cropped-20jester-validation-dataset/Cropped_Validation_Data"

In [None]:
# label_encoder = OneHotEncoder(sparse_output=False)

In [None]:
import numpy as np

# Gesture class names. A name's position in this list is its integer class id.
gesture_names = [
    "Doing other things",
    "No gesture",
    "Rolling Hand Backward",
    "Rolling Hand Forward",
    "Shaking Hand",
    "Sliding Two Fingers Down",
    "Sliding Two Fingers Left",
    "Sliding Two Fingers Right",
    "Sliding Two Fingers Up",
    "Stop Sign",
    "Swiping Down",
    "Swiping Left",
    "Swiping Right",
    "Swiping Up",
    "Thumb Down",
    "Thumb Up",
    "Turning Hand Clockwise",
    "Turning Hand Counterclockwise",
]

# Gesture name -> class id, i.e. the integer target expected by CrossEntropyLoss.
label_to_index = dict(zip(gesture_names, range(len(gesture_names))))

# Quick check shown by the notebook: "Swiping Left" maps to 11.
label_to_index["Swiping Left"]

In [None]:
# Ids of the videos that have a folder of cropped frames on disk
# (every entry name in the directory is parsed as an integer id).
train_video_id_ls = [int(entry) for entry in os.listdir(Cropped_TRAIN_DIR)]
val_video_id_ls = [int(entry) for entry in os.listdir(Cropped_VAL_DIR)]

# Full annotation tables.
train_df = pd.read_csv(TRAIN_DATA_DIR)
val_df = pd.read_csv(VAL_DATA_DIR)

# Keep only the annotation rows whose video has cropped frames available
# (this filters rows; it does not reorder them).
sort_train_df = train_df.loc[train_df["video_id"].isin(train_video_id_ls)]
sort_val_df = val_df.loc[val_df["video_id"].isin(val_video_id_ls)]

In [None]:
# Report how many annotation rows survive the "has cropped frames" filter.
for _caption, _frame in (
    ("train_df len: ", train_df),
    ("sort_train_df len: ", sort_train_df),
    ("val_df len: ", val_df),
    ("sort_val_df len: ", sort_val_df),
):
    print(_caption, len(_frame))

# Model

In [None]:
# Shell escape: show the GPUs visible to this session before building the model.
!nvidia-smi

In [None]:
# Number of temporal segments handed to TSN.
num_segments = 6
# Frames per clip; create_model passes INPUT_DEPTH / num_segments to TSN as data_length.
INPUT_DEPTH = 36

In [None]:
def create_model(device):
    """Build the TSN gesture classifier, wrap it in DataParallel and move it to ``device``.

    Args:
        device: ``torch.device`` that should hold the model parameters.

    Returns:
        The ``nn.DataParallel``-wrapped TSN model. The underlying network is
        reachable through ``.module``.
    """
    model = TSN(
        num_class=18,
        num_segments=num_segments,
        modality="RGBDiff",
        img_size=(128, 128),
        new_length=1,
        base_model="resnet50",
        dropout=0.5,
        img_feature_dim=128,
        pretrain="imagenet",
        is_shift=True,
        fc_lr5=True,
        temporal_pool=False,
        non_local=True,
        # Frames per segment; integer division keeps this an exact int.
        data_length=INPUT_DEPTH // num_segments,
    )

    # Use at most two GPUs (the original setup), but never request a GPU id
    # that does not exist: a hard-coded [0, 1] fails on single-GPU machines.
    # With no GPU the list is empty and DataParallel is left to its default.
    gpu_ids = list(range(min(2, torch.cuda.device_count())))
    model = nn.DataParallel(model, device_ids=gpu_ids or None)
    model = model.to(device)
    return model

In [None]:
# Instantiate the DataParallel-wrapped TSN network on the selected device.
model = create_model(device)

# Dataset

## parameter

In [None]:
# Videos per batch; used by both DataLoaders and logged to the wandb run config.
batch_size = 64

## Data

In [None]:
import albumentations as A
from albumentations.core.composition import ReplayCompose
from albumentations.pytorch.transforms import ToTensorV2

In [None]:
import numpy as np
import random
import cv2
from albumentations.core.transforms_interface import ImageOnlyTransform

class MultiScaleCropAlbumentations(ImageOnlyTransform):
    """Multi-scale crop at a sampled position, followed by a resize to ``input_size``.

    A crop width/height pair is drawn from ``scales`` (fractions of the shorter
    image side), a crop position is drawn either uniformly or from a fixed grid
    of offsets, and the crop is resized to ``input_size``.

    Args:
        input_size: Output size as ``[width, height]``, or one int for a square output.
        scales: Candidate crop scales; defaults to ``[1, .875, .75, .66]``.
        max_distort: Largest allowed index distance (within ``scales``) between
            the scale used for the width and the one used for the height.
        fix_crop: If true, pick the crop position from the fixed offset grid
            built by ``fill_fix_offset``; otherwise draw it uniformly.
        more_fix_crop: If true, use the extended 13-position grid instead of
            the 5-position one.
        interpolation: OpenCV interpolation flag used for the final resize.
        always_apply: Kept for backward compatibility; a true value forces ``p = 1``.
        p: Probability of applying the transform.

    NOTE(review): the crop is sampled with ``random`` inside ``apply`` rather
    than through albumentations' parameter hooks, so a ``ReplayCompose`` replay
    presumably draws a *new* crop for every frame. Confirm before re-enabling
    this transform for video clips.
    """

    def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True, interpolation=cv2.INTER_LINEAR, always_apply=False, p=1.0):
        # Forward only ``p``: recent albumentations releases no longer accept a
        # positional ``always_apply`` in the base constructor. A true
        # ``always_apply`` is honoured by forcing the probability to 1.
        super().__init__(p=1.0 if always_apply else p)
        self.scales = scales if scales is not None else [1, .875, .75, .66]
        self.max_distort = max_distort
        self.fix_crop = fix_crop
        self.more_fix_crop = more_fix_crop
        self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size]
        self.interpolation = interpolation

    def apply(self, img, **params):
        """Crop ``img`` (H x W [x C] array) at a sampled window and resize it to ``input_size``."""
        im_h, im_w = img.shape[:2]
        crop_w, crop_h, offset_w, offset_h = self._sample_crop_size((im_w, im_h))

        cropped = img[offset_h:offset_h + crop_h, offset_w:offset_w + crop_w]
        # cv2.resize takes the target size as (width, height).
        return cv2.resize(cropped, (self.input_size[0], self.input_size[1]), interpolation=self.interpolation)

    def _sample_crop_size(self, im_size):
        """Return ``(crop_w, crop_h, offset_w, offset_h)`` for an image of size ``(w, h)``."""
        image_w, image_h = im_size
        base_size = min(image_w, image_h)
        target_w, target_h = self.input_size[0], self.input_size[1]

        crop_sizes = [int(base_size * scale) for scale in self.scales]
        # Snap sizes within 3 px of the target to the target itself, then clamp
        # to the image so the snap can never yield a crop larger than the image
        # (which would produce negative offsets).
        crop_h = [min(target_h if abs(size - target_h) < 3 else size, image_h) for size in crop_sizes]
        crop_w = [min(target_w if abs(size - target_w) < 3 else size, image_w) for size in crop_sizes]

        # Allow width/height combinations whose scale indices differ by at most max_distort.
        pairs = [
            (w, h)
            for i, h in enumerate(crop_h)
            for j, w in enumerate(crop_w)
            if abs(i - j) <= self.max_distort
        ]
        crop_pair = random.choice(pairs)

        if not self.fix_crop:
            w_offset = random.randint(0, image_w - crop_pair[0])
            h_offset = random.randint(0, image_h - crop_pair[1])
        else:
            w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1])

        return crop_pair[0], crop_pair[1], w_offset, h_offset

    def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
        """Pick one ``(w_offset, h_offset)`` from the fixed offset grid."""
        offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h)
        return random.choice(offsets)

    @staticmethod
    def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
        """List candidate ``(w_offset, h_offset)`` positions on a quarter-step grid.

        The base list holds the four corners and the centre; ``more_fix_crop``
        adds the four edge midpoints and the four quarter positions.
        """
        w_step = (image_w - crop_w) // 4
        h_step = (image_h - crop_h) // 4

        ret = [
            (0, 0),                    # upper left
            (4 * w_step, 0),           # upper right
            (0, 4 * h_step),           # lower left
            (4 * w_step, 4 * h_step),  # lower right
            (2 * w_step, 2 * h_step),  # centre
        ]

        if more_fix_crop:
            ret += [
                (0, 2 * h_step),           # centre left
                (4 * w_step, 2 * h_step),  # centre right
                (2 * w_step, 4 * h_step),  # lower centre
                (2 * w_step, 0),           # upper centre
                (1 * w_step, 1 * h_step),  # upper-left quarter
                (3 * w_step, 1 * h_step),  # upper-right quarter
                (1 * w_step, 3 * h_step),  # lower-left quarter
                (3 * w_step, 3 * h_step),  # lower-right quarter
            ]
        return ret


In [None]:
def augmentation(img, input_shape):
    """Apply the training-time photometric jitter and resize to one frame.

    Args:
        img: Frame to augment, passed to the pipeline as ``image``.
        input_shape: ``(height, width)`` of the resized output.

    Returns:
        The dict produced by the ``ReplayCompose`` pipeline; callers read the
        augmented frame from ``"image"`` and the recorded parameters from
        ``"replay"``.
    """
    target_h, target_w = input_shape[0], input_shape[1]

    pipeline = ReplayCompose(
        [
            A.RandomBrightnessContrast(p=0.2),
            A.ColorJitter(p=0.2),
            A.Resize(height=target_h, width=target_w),
            # Disabled: MultiScaleCropAlbumentations(input_size=input_shape, scales = [1, .875, .75]),
        ]
    )
    return pipeline(image=img)

In [None]:
# input_mean = model.module.input_mean
# input_std = model.module.input_std

# print(input_mean)
# print(input_std)


# def normalize(image):

#     transform = ReplayCompose([
#         A.Normalize(mean=input_mean, std=input_std),
#     ])

#     aug = transform(image=image)

#     return aug

In [None]:
class VideoDatasetTorch(Dataset):
    """Dataset of fixed-length frame folders, returned as channel-stacked clips.

    Each row of ``data_frame`` names one video (``video_id``) and its class
    (``label``). Frames are read from ``<video_dir>/<video_id>/00001.jpg``
    onward, one file per frame.

    Args:
        usage: ``"train"`` enables ``transform``; any other value only resizes.
        data_frame: Table with ``video_id`` and ``label`` columns.
        video_dir: Root folder containing one sub-folder per video id.
        input_shape: ``(D, H, W, C)``; ``D`` is the number of frames loaded per
            video and ``(H, W)`` the output frame size.
        label_to_index: Mapping from label string to integer class id.
        transform: Callable ``transform(img=..., input_shape=(H, W))`` returning
            a dict with ``"image"`` and ``"replay"`` entries (train only).
    """

    def __init__(self, usage, data_frame, video_dir, input_shape, label_to_index, transform=None):
        self.usage = usage
        self.data = data_frame.reset_index(drop=True)
        self.video_dir = video_dir
        self.input_shape = input_shape  # (D, H, W, C)
        # Frames per video come from input_shape instead of a hard-coded 36.
        self.num_frames = input_shape[0]
        self.transform = transform
        self.label_to_index = label_to_index

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(video, label)`` for row ``idx``.

        ``video`` is a float32 tensor of shape ``(D * 3, H, W)`` holding pixel
        values divided by 255; ``label`` is a scalar long tensor.
        """
        row = self.data.iloc[idx]
        video_id = row.video_id
        label_str = row.label
        folder = os.path.join(self.video_dir, str(video_id))
        height, width = self.input_shape[1], self.input_shape[2]
        augment = self.usage == "train" and bool(self.transform)

        frames = []
        replay = None  # augmentation parameters recorded on the first frame

        for frame_no in range(1, self.num_frames + 1):
            img_path = os.path.join(folder, f"{frame_no:05d}.jpg")
            # The context manager closes the image file as soon as it is decoded.
            with Image.open(img_path) as pil_img:
                img = np.array(pil_img.convert("RGB"))

            if augment:
                if replay is None:
                    aug = self.transform(img=img, input_shape=(height, width))
                    replay = aug["replay"]
                    aug_img = aug["image"]
                else:
                    # Reuse the first frame's parameters so the clip is augmented consistently.
                    aug_img = ReplayCompose.replay(replay, image=img)["image"]
            else:
                # cv2.resize takes the target size as (width, height).
                aug_img = cv2.resize(img, (width, height))

            frames.append(aug_img.astype(np.float32) / 255.0)

        # (T, H, W, C) -> (T, C, H, W)
        video = np.stack(frames, axis=0).transpose(0, 3, 1, 2)

        # Merge time and channel axes: (T * C, H, W), i.e. (S * L * 3, H, W).
        video = video.reshape(-1, height, width)
        # The array is already float32 and owned by this call, so share its memory.
        video_tensor = torch.from_numpy(video)

        # Label string -> integer class id.
        label_index = self.label_to_index[label_str]
        label_tensor = torch.tensor(label_index, dtype=torch.long)

        return video_tensor, label_tensor

In [None]:
# Training split: frames go through the `augmentation` pipeline.
train_dataset = VideoDatasetTorch(
    "train",
    sort_train_df,
    Cropped_TRAIN_DIR,
    (36, 128, 128, 3),
    label_to_index,
    transform=augmentation,
)

# Validation split: no transform, frames are only resized.
val_dataset = VideoDatasetTorch(
    "val",
    sort_val_df,
    Cropped_VAL_DIR,
    (36, 128, 128, 3),
    label_to_index,
    transform=None,
)

In [None]:
from torch.utils.data import DataLoader

# Training batches are reshuffled every epoch.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)
# Validation keeps the dataset order.
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

In [None]:
# Pull one training batch to sanity-check the tensor shapes and label values.
data = next(iter(train_loader))
video, abs_ = data  # abs_ holds the batch of integer labels
print(video.shape, abs_.shape)
print(abs_)

# Number of batches per epoch for each split.
print(len(train_loader), len(val_loader))

## Hyperparameters

In [None]:
from torch import optim

import wandb

In [None]:
import os

# Weights & Biases project name and resume policy for the run.
PROJECT = "HandActionReg"
RESUME = "allow"
# SECURITY: never commit an API key to the notebook. The key is read from the
# WANDB_API_KEY environment variable; when it is unset this is None.
# NOTE(review): the key previously hard-coded here should be treated as leaked
# and rotated; confirm how credentials are supplied when the variable is unset.
WANDB_KEY = os.environ.get("WANDB_API_KEY")

In [None]:
# Authenticate this session with Weights & Biases.
wandb.login(key=WANDB_KEY)

In [None]:
# Base optimisation settings; the per-group multipliers from the model's
# optimisation policies are applied on top of learning_rate and weight_decay.
learning_rate = 1e-4
weight_decay = 1e-5
# Upper bound on training epochs (early stopping may end training sooner).
epochs = 50

In [None]:
# Open the tracking run and record the main hyperparameters with it.
wandb.init(
    name="STMEM_hand_TSM_init",
    project=PROJECT,
    resume=RESUME,
    config={"learning_rate": learning_rate, "epochs": epochs, "batch_size": batch_size},
)
# Register the model with the run.
wandb.watch(model)

## Set up

In [None]:
# `.module` unwraps the nn.DataParallel wrapper to reach the TSN model, which
# supplies the parameter groups used to build the optimizer.
policies = model.module.get_optim_policies()

In [None]:
# Turn each policy into an optimizer parameter group, scaling the base learning
# rate and weight decay by the policy's multipliers (a missing multiplier counts as 1).
param_groups = [
    {
        'params': policy['params'],
        'lr': learning_rate * policy.get('lr_mult', 1),
        'weight_decay': weight_decay * policy.get('decay_mult', 1),
    }
    for policy in policies
]

# print(param_groups)

In [None]:
# Multi-class classification loss over integer class targets.
loss_fn = nn.CrossEntropyLoss()
# Every entry of param_groups carries its own 'lr' and 'weight_decay', so the
# keyword values here only act as defaults.
optimizer = optim.Adam(param_groups, lr=learning_rate, weight_decay = weight_decay)

# Training

In [None]:
# Number of non-improving validation checks tolerated by early stopping.
patience = 7

In [None]:
class EarlyStopping:
    """Signal when the validation loss has stopped improving.

    The instance remembers the best loss seen so far and counts consecutive
    calls that fail to beat it by more than ``min_delta``.
    """

    def __init__(self, patience=5, min_delta=0.0):
        """
        Args:
            patience (int): Consecutive non-improving calls tolerated before stopping.
            min_delta (float): Margin by which the loss must drop below the best
                value to count as an improvement.
        """
        self.patience, self.min_delta = patience, min_delta
        self.best_loss = float('inf')
        self.counter = 0

    def __call__(self, val_loss):
        """Record one validation result and report whether to stop.

        Args:
            val_loss (float): Validation loss of the epoch that just finished.

        Returns:
            bool: True once the non-improvement counter has reached ``patience``.
        """
        improved = val_loss < self.best_loss - self.min_delta
        if not improved:
            # No sufficient improvement: one more strike.
            self.counter += 1
        else:
            # New best: remember it and start counting afresh.
            self.best_loss = val_loss
            self.counter = 0

        should_stop = self.counter >= self.patience
        return should_stop
        

# Stop after `patience` validation checks that fail to improve the best loss by more than 0.001.
early_stopping = EarlyStopping(patience=patience, min_delta=0.001)

In [None]:
# Multiply the learning rate by 0.1 (never going below 1e-6) once the monitored
# value has failed to decrease for 5 scheduler steps; the training loop steps
# this scheduler with the validation loss once per epoch.
# NOTE(review): `verbose` is reportedly deprecated in recent PyTorch releases - confirm against the installed version.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.1, min_lr=1e-6, verbose=True)

In [None]:
from sklearn.metrics import precision_score, recall_score
from tqdm import tqdm

# Lowest validation loss seen so far; the training loop saves a checkpoint whenever it is beaten.
best_val_loss = float('inf')

def train_epoch(model, dataloader, optimizer, loss_fn, DEVICE):
    """Run one optimisation pass over ``dataloader`` and log the epoch metrics.

    Args:
        model: Network to train; it is switched to training mode.
        dataloader: Iterable yielding ``(video, label)`` batches.
        optimizer: Optimizer stepped once per batch.
        loss_fn: Callable ``loss_fn(output, label)`` returning the batch loss.
        DEVICE: Device the batches are moved to.

    Returns:
        Tuple ``(loss, accuracy, precision, recall)``: the loss is the mean of
        the per-batch losses, precision/recall are macro-averaged.
    """
    model.train()

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    targets, predictions = [], []

    for video, label in tqdm(dataloader, desc="Training", leave=False):
        video, label = video.to(DEVICE), label.to(DEVICE)

        optimizer.zero_grad()
        logits = model(video)
        loss = loss_fn(logits, label)
        loss.backward()

        # Cap the global gradient norm at 1.0 before the update.
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Index of the highest score per sample is the predicted class.
        predicted = logits.data.max(dim=1)[1]

        loss_sum += loss.item()
        n_seen += label.size(0)
        n_correct += (predicted == label).sum().item()
        targets.extend(label.cpu().numpy())
        predictions.extend(predicted.cpu().numpy())

    metrics = (
        loss_sum / len(dataloader),
        n_correct / n_seen,
        precision_score(targets, predictions, average='macro', zero_division=0),
        recall_score(targets, predictions, average='macro', zero_division=0),
    )

    wandb.log(dict(zip(("loss", "accuracy", "precision", "recall"), metrics)))

    return metrics

def validate_epoch(model, dataloader, loss_fn, DEVICE):
    """Evaluate ``model`` on ``dataloader`` without gradients and log the metrics.

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        dataloader: Iterable yielding ``(video, label)`` batches.
        loss_fn: Callable ``loss_fn(output, label)`` returning the batch loss.
        DEVICE: Device the batches are moved to.

    Returns:
        Tuple ``(loss, accuracy, precision, recall)``: the loss is the mean of
        the per-batch losses, precision/recall are macro-averaged.
    """
    model.eval()

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    targets, predictions = [], []

    with torch.no_grad():
        for video, label in tqdm(dataloader, desc="Validation", leave=False):
            video, label = video.to(DEVICE), label.to(DEVICE)

            logits = model(video)
            loss_sum += loss_fn(logits, label).item()

            # Index of the highest score per sample is the predicted class.
            predicted = logits.data.max(dim=1)[1]
            n_seen += label.size(0)
            n_correct += (predicted == label).sum().item()
            targets.extend(label.cpu().numpy())
            predictions.extend(predicted.cpu().numpy())

    metrics = (
        loss_sum / len(dataloader),
        n_correct / n_seen,
        precision_score(targets, predictions, average='macro', zero_division=0),
        recall_score(targets, predictions, average='macro', zero_division=0),
    )

    wandb.log(dict(zip(("val_loss", "val_accuracy", "val_precision", "val_recall"), metrics)))

    return metrics

# Training loop with best-checkpoint saving, ReduceLROnPlateau and early stopping.
for epoch in range(epochs):
    train_loss, train_acc, train_prec, train_rec = train_epoch(model, train_loader, optimizer, loss_fn, device)
    val_loss, val_acc, val_prec, val_rec = validate_epoch(model, val_loader, loss_fn, device)

    # The scheduler monitors the validation loss.
    scheduler.step(val_loss)

    print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.4f} Acc: {train_acc:.2f} | Val Loss: {val_loss:.4f} Acc: {val_acc:.2f}")

    # Keep the weights of the best epoch so far. `model` is the DataParallel
    # wrapper, so the saved keys carry the "module." prefix.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "/kaggle/working/STMEM_TSM_RestNet50.pth")

    # Patience is tracked solely by `early_stopping`. The former manual
    # `patience_counter` was never initialised (NameError when the first
    # validation loss was not an improvement, e.g. NaN) and was never read.
    if early_stopping(val_loss):
        print("Early stopping triggered.")
        break

# Save model at the last epoch
torch.save(model.state_dict(), "/kaggle/working/STMEM_TSM_RestNet50_last.pth")
wandb.finish()