In [36]:
import torch
import torch.nn as nn
import wandb

# TODO:
# 1. input length & norm mapping order

def rearrange_tensor(input_tensor, order):
    """Permute a 5-D tensor from the axis layout named by ``order`` into ``BTCHW``.

    Args:
        input_tensor: 5-D tensor whose axes are laid out as described by ``order``.
        order: Case-insensitive string that names every axis exactly once using
            the letters ``B``, ``C``, ``H``, ``W`` and ``T`` (e.g. ``"BCTHW"``).

    Returns:
        ``input_tensor`` permuted so that its axes are ordered ``B, T, C, H, W``.

    Raises:
        AssertionError: If ``order`` is not a permutation of ``"BCHWT"``.
    """
    order = order.upper()
    # The set size alone is not enough: "BCHWTT" has five distinct letters but
    # six characters, so the length has to be checked as well.
    assert len(order) == 5 and len(set(order)) == 5, "Order must be a 5 unique character string"
    assert all(dim in order for dim in "BCHWT"), "Order must contain all of BCHWT"
    assert all(dim in "BCHWT" for dim in order), "Order must not contain any characters other than BCHWT"

    # Output axis k is read from wherever the k-th letter of "BTCHW" sits in ``order``.
    return input_tensor.permute([order.index(dim) for dim in "BTCHW"])

def reverse_rearrange_tensor(input_tensor, order):
    """Permute a 5-D tensor from ``BTCHW`` into the axis layout named by ``order``.

    Args:
        input_tensor: 5-D tensor whose axes are ordered ``B, T, C, H, W``.
        order: Case-insensitive string that names every axis exactly once using
            the letters ``B``, ``C``, ``H``, ``W`` and ``T`` (e.g. ``"BCTHW"``).

    Returns:
        ``input_tensor`` permuted so that its axes follow ``order``.

    Raises:
        AssertionError: If ``order`` is not a permutation of ``"BCHWT"``.
    """
    order = order.upper()
    # The set size alone is not enough: "BCHWTT" has five distinct letters but
    # six characters, which would otherwise surface as an opaque permute error.
    assert len(order) == 5 and len(set(order)) == 5, "Order must be a 5 unique character string"
    assert all(dim in order for dim in "BCHWT"), "Order must contain all of BCHWT"
    assert all(dim in "BCHWT" for dim in order), "Order must not contain any characters other than BCHWT"

    # Output axis k is read from wherever the k-th letter of ``order`` sits in "BTCHW".
    return input_tensor.permute(["BTCHW".index(dim) for dim in order])

class MotionPrompt(torch.nn.Module):
    """Re-weight video frames with a learnable attention map built from frame differences.

    ``forward`` converts the clip to grayscale, differences consecutive frames,
    passes the differences through ``attention_map`` (parameterised by the
    learnable tensors ``a`` and ``b``) and multiplies the result onto every
    frame except the first.

    Args:
        exp_name: Experiment tag; the configuration is selected by substring:

            * a leading ``"exp1"`` / ``"exp2"`` picks the attention function;
            * ``"_1"`` makes ``a`` and ``b`` single-element parameters instead
              of 224x224 maps;
            * ``"_loss"`` enables the temporal loss with weight 1, refined to
              0.1 / 0.5 / 2 by ``"_loss_0.1"`` / ``"_loss_0.5"`` / ``"_loss_2"``.

            Without a recognised prefix ``attention_map`` stays ``None`` and has
            to be assigned before ``forward`` is called.
    """

    def __init__(self, exp_name = ""):
        super().__init__()
        # default configs
        self.input_permutation = "BCTHW"   # Input format from demo video reader
        self.input_color_order = "BGR"     # Input format from demo video reader
        self.gray_scale = {"B": 0.114, "G": 0.587, "R": 0.299}  # per-channel grayscale weights
        self.visual = False                # when True, statistics are sent to wandb
        self.verbose = True                # when True, forward prints per-stage shape/min/max
        self.attention_map = None

        # experimental configs
        if exp_name.startswith("exp1"):
            self.attention_map = exp1
        elif exp_name.startswith("exp2"):
            self.attention_map = exp2

        if "_1" in exp_name:
            self.a = nn.Parameter(torch.zeros(1)+1e-1)
            self.b = nn.Parameter(torch.zeros(1))
        else:
            self.a = nn.Parameter(torch.nn.init.normal_(torch.empty(224,224), 0, 1))
            self.b = nn.Parameter(torch.nn.init.normal_(torch.empty(224,224), 0, 1))

        # Weight of the temporal-smoothness loss; 0 disables it.
        self.lambda1 = 0
        if "_loss" in exp_name:
            self.lambda1 = 1
        if "_loss_0.1" in exp_name:
            self.lambda1 = 0.1
        elif "_loss_0.5" in exp_name:
            self.lambda1 = 0.5
        elif "_loss_2" in exp_name:
            self.lambda1 = 2

    def _report(self, shape_label, range_label, tensor):
        """Print the shape and value range of ``tensor`` when ``verbose`` is on."""
        # getattr keeps instances created before ``verbose`` existed working.
        if not getattr(self, "verbose", True):
            return
        print(f"{shape_label} shape: {tensor.shape}")
        print(f"{range_label} min: {tensor.min()}, max: {tensor.max()}\n")

    @staticmethod
    def _count_repeated_frames(norm_seq):
        """Count duplicated frames in a ``BTCHW`` batch.

        For every sample with ``k > 0`` pairs of identical consecutive frames the
        count grows by ``k + 1``; samples without such pairs contribute 0.
        """
        same_as_next = (norm_seq[:, 1:] == norm_seq[:, :-1]).flatten(2).all(dim=2)
        repeats = same_as_next.sum(dim=1)
        return int((repeats + (repeats > 0)).sum().item())

    def forward(self, video_seq):
        """Apply the motion attention to a clip.

        Args:
            video_seq: 5-D tensor laid out as ``self.input_permutation`` with three
                colour channels ordered as ``self.input_color_order``. The values
                are mapped through ``x * 0.225 + 0.45`` before the grayscale
                conversion, i.e. the input is expected to be normalised with
                mean 0.45 and std 0.225.

        Returns:
            A tuple ``(prompted, loss)``. ``prompted`` has the input layout with
            one frame fewer (the first frame has no predecessor to diff against).
            ``loss`` is ``self.lambda1`` times the mean squared difference between
            consecutive attention maps when gradients are enabled, and the
            integer 0 otherwise.

        Raises:
            RuntimeError: If no attention function has been configured.
            ValueError: If the clip has fewer than two frames.
        """
        if self.attention_map is None:
            raise RuntimeError(
                "MotionPrompt.attention_map is not set: build the module with an exp_name "
                "starting with 'exp1' or 'exp2', or assign attention_map before calling forward"
            )

        self._report("video_seq", "video_seq", video_seq)
        video_seq = rearrange_tensor(video_seq, self.input_permutation)
        if video_seq.shape[1] < 2:
            raise ValueError("MotionPrompt needs at least two frames to compute frame differences")
        loss = 0

        # undo the input normalisation so the values are back in [0, 1]
        norm_seq = video_seq * 0.225 + 0.45

        # transform the input tensor to grayscale
        weights = torch.tensor([self.gray_scale[idx] for idx in self.input_color_order],
                               dtype=norm_seq.dtype, device=norm_seq.device)
        grayscale_video_seq = torch.einsum("btchw,c->bthw", norm_seq, weights)
        self._report("grayscale_video_seq", "grayscale_video_seq", grayscale_video_seq)

        ### frame difference ###
        B, T, H, W = grayscale_video_seq.shape
        frame_diff = grayscale_video_seq[:,1:] - grayscale_video_seq[:,:-1]
        self._report("frame_diff", "frame diff", frame_diff)

        ### power normalization ###
        norm_attention = self.attention_map(frame_diff, self.a, self.b).unsqueeze(2)
        self._report("norm_attention", "norm attention", norm_attention)
        # Broadcast the single-channel attention over the colour channels without copying.
        pad_norm_attention = norm_attention.expand(-1, -1, video_seq.shape[2], -1, -1)
        self._report("pad_norm_attention", "pad norm attention", pad_norm_attention)

        if torch.is_grad_enabled():
            if T > 2:
                temp_diff = norm_attention[:, 1:] - norm_attention[:, :-1]
                temporal_loss = torch.sum(temp_diff.pow(2)) / (H*W*(T-2)*B)
            else:
                # A two-frame clip yields a single attention map, so there is no
                # temporal difference; dividing by (T-2) == 0 would give NaN.
                # Multiplying by 0 keeps the result attached to the autograd graph.
                temporal_loss = norm_attention.sum() * 0
            loss = self.lambda1 * temporal_loss
            if self.visual:
                wandb.log({
                    "temporal_loss": loss
                })

        if self.visual:
            # Statistics cover the whole parameter (indexing [0] would only look
            # at the first row of a 224x224 map). The duplicate-frame scan is only
            # needed for this log entry, so it is skipped otherwise.
            wandb.log({
            "self.a.mean": self.a.data.mean(), "self.a.std": self.a.data.std(),
            "self.b.mean": self.b.data.mean(), "self.b.std": self.b.data.std(),
            "frame_check": self._count_repeated_frames(norm_seq)
            })

        return reverse_rearrange_tensor((pad_norm_attention * video_seq[:,1:]), self.input_permutation), loss


def exp1(input, a, b):
    """Sigmoid attention with a bounded slope and a bounded threshold.

    Computes ``sigmoid(k * (input - t))`` with ``k = 5 / (0.45 * |tanh(a)| + 0.1)``
    and ``t = 0.6 * tanh(b)``.

    Args:
        input: Tensor the attention is evaluated on.
        a: Tensor controlling the slope; must broadcast against ``input``.
        b: Tensor controlling the threshold; must broadcast against ``input``.

    Returns:
        Tensor of values in [0, 1] with the broadcast shape of the arguments.
    """
    slope = 5 / (0.45 * torch.abs(torch.tanh(a)) + 1e-1)
    threshold = 0.6 * torch.tanh(b)
    # torch.sigmoid stays finite in the backward pass, whereas the hand-written
    # 1 / (1 + exp(-z)) yields NaN gradients once exp(-z) overflows.
    return torch.sigmoid(slope * (input - threshold))

def exp2(input, a, b):
    """Sigmoid attention with a non-negative slope ``relu(a + 1e-3)`` and threshold ``b``.

    Args:
        input: Tensor the attention is evaluated on.
        a: Tensor controlling the slope; must broadcast against ``input``.
        b: Threshold tensor; must broadcast against ``input``.

    Returns:
        ``sigmoid(relu(a + 1e-3) * (input - b))`` with the broadcast shape of the
        arguments.
    """
    # torch.sigmoid stays finite in the backward pass, whereas the hand-written
    # 1 / (1 + exp(-z)) yields NaN gradients once exp(-z) overflows. The
    # functional relu also avoids building a new nn.ReLU module on every call.
    return torch.sigmoid(torch.relu(a + 1e-3) * (input - b))

In [37]:
# Seed first: both the randomly initialised a/b maps inside MotionPrompt and the
# synthetic clip below come from the global RNG, so without a seed the printed
# statistics change on every Restart & Run All.
torch.manual_seed(0)
model = MotionPrompt(exp_name="exp2_loss_0.5")

# Synthetic clip laid out as (B, C, T, H, W).
# NOTE: `input` shadows the Python builtin; the name is kept in case other cells read it.
input = torch.randn(8, 3, 9, 224, 224)
# Rescale to [0, 1], then apply the mean/std normalisation that forward() undoes.
input = (input - input.min()) / (input.max() - input.min())
input = (input - 0.45) / 0.225
output, loss = model(input)
# The first frame has no predecessor to diff against, so T shrinks from 9 to 8.
assert output.shape == (8, 3, 8, 224, 224)
print(f"output shape: {output.shape}")
print(f"output min: {output.min()}, max: {output.max()}")

video_seq shape: torch.Size([8, 3, 9, 224, 224])
video_seq min: -2.0, max: 2.4444446563720703

grayscale_video_seq shape: torch.Size([8, 9, 224, 224])
grayscale_video_seq min: 0.15650144219398499, max: 0.8219897747039795

frame_diff shape: torch.Size([8, 8, 224, 224])
frame diff min: -0.4549407958984375, max: 0.45778119564056396

norm_attention shape: torch.Size([8, 8, 1, 224, 224])
norm attention min: 1.2509172847785521e-05, max: 0.9997380375862122

pad_norm_attention shape: torch.Size([8, 8, 3, 224, 224])
pad norm attention min: 1.2509172847785521e-05, max: 0.9997380375862122

output shape: torch.Size([8, 3, 8, 224, 224])
output min: -1.615512490272522, max: 2.0868873596191406
