In [33]:
import torch

tconv = torch.nn.ConvTranspose2d(1, 1, 2, stride=2)
tconv.weight.data = torch.tensor([[[[1, 0], [0, 1]]]], dtype=torch.float32)
tconv.bias.data = torch.tensor([0], dtype=torch.float32)
input = torch.tensor([[[[1, 2], [3, 4]]]], dtype=torch.float32)

print(f"input: {input.shape}") 
print(f"input: {input}")    
print(f"weight: {tconv.weight.shape}")
print(f"weight: {tconv.weight}")
output = tconv(input)
print(f"output: {output.shape}")
print(f"output: {output}")

input: torch.Size([1, 1, 2, 2])
input: tensor([[[[1., 2.],
          [3., 4.]]]])
weight: torch.Size([1, 1, 2, 2])
weight: Parameter containing:
tensor([[[[1., 0.],
          [0., 1.]]]], requires_grad=True)
output: torch.Size([1, 1, 4, 4])
output: tensor([[[[1., 0., 2., 0.],
          [0., 1., 0., 2.],
          [3., 0., 4., 0.],
          [0., 3., 0., 4.]]]], grad_fn=<SlowConvTranspose2DBackward>)


In [47]:


def get_frame_id_list(duration, indices, skip_offsets, skip_length=64, new_step=4):
        frame_id_list = []
        if duration < skip_length:
             step = duration // (skip_length // new_step)
             frame_id_list = [i for i in range(0, duration, step)]
             return frame_id_list
        for seg_ind in indices:
            offset = int(seg_ind)
            for i, _ in enumerate(range(0, skip_length, new_step)):
                if offset + skip_offsets[i] <= duration:
                    frame_id = offset + skip_offsets[i] - 1
                else:
                    frame_id = offset - 1
                frame_id_list.append(frame_id)
                if offset + new_step < duration:
                    offset += new_step
        return frame_id_list

# print(get_frame_id_list(80, [16], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
print(get_frame_id_list(50, [1], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))

[0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48]


In [44]:
import numpy as np

def _sample_train_indices(num_frames, skip_length=64, num_segments=1, new_step=4, temporal_jitter=False):
    # (total frames - sample number + 1) // 1
    average_duration = (num_frames - skip_length +
                        1) // num_segments
    if average_duration > 0:
        offsets = np.multiply(
            list(range(num_segments)), average_duration)
        offsets = offsets + np.random.randint(
            average_duration, size=num_segments)
    elif num_frames > max(num_segments, skip_length):
        offsets = np.sort(
            np.random.randint(
                num_frames - skip_length + 1, size=num_segments))
    else:
        offsets = np.zeros((num_segments, ))

    if temporal_jitter:
        skip_offsets = np.random.randint(
            new_step, size=skip_length // new_step)
    else:
        skip_offsets = np.zeros(
            skip_length // new_step, dtype=int)
    return offsets + 1, skip_offsets

print(_sample_train_indices(50))

(array([1.]), array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))


In [5]:
import torch

def round_tensor(x, decimals=0):
    scale = 10 ** decimals
    return torch.round(x * scale) / scale


def closest_odd_numbers(num):
    num = round_tensor(num, decimals=3)
    assert num >= 1, "Number must be greater than or equal to 1, num = " + str(num)
    base = torch.floor(num).int().item()

    lower = base if base % 2 != 0 else base - 1
    num_temp = num.to(dtype=torch.float32)
    lower = torch.where(base <= num_temp, lower, lower - 2)
    higher = lower + 2

    higher_weight = (num - lower) / 2
    lower_weight = 1 - higher_weight

    return lower.to(num.device), higher.to(num.device), lower_weight.to(num.device), higher_weight.to(num.device)


for i in range(10, 100):
    num = torch.tensor(i/10)
    print(num, closest_odd_numbers(num))

tensor(1.) (tensor(1), tensor(3), tensor(1.), tensor(0.))
tensor(1.1000) (tensor(1), tensor(3), tensor(0.9500), tensor(0.0500))
tensor(1.2000) (tensor(1), tensor(3), tensor(0.9000), tensor(0.1000))
tensor(1.3000) (tensor(1), tensor(3), tensor(0.8500), tensor(0.1500))
tensor(1.4000) (tensor(1), tensor(3), tensor(0.8000), tensor(0.2000))
tensor(1.5000) (tensor(1), tensor(3), tensor(0.7500), tensor(0.2500))
tensor(1.6000) (tensor(1), tensor(3), tensor(0.7000), tensor(0.3000))
tensor(1.7000) (tensor(1), tensor(3), tensor(0.6500), tensor(0.3500))
tensor(1.8000) (tensor(1), tensor(3), tensor(0.6000), tensor(0.4000))
tensor(1.9000) (tensor(1), tensor(3), tensor(0.5500), tensor(0.4500))
tensor(2.) (tensor(1), tensor(3), tensor(0.5000), tensor(0.5000))
tensor(2.1000) (tensor(1), tensor(3), tensor(0.4500), tensor(0.5500))
tensor(2.2000) (tensor(1), tensor(3), tensor(0.4000), tensor(0.6000))
tensor(2.3000) (tensor(1), tensor(3), tensor(0.3500), tensor(0.6500))
tensor(2.4000) (tensor(1), tensor(3)

In [6]:
import torch

finetune_path = "model_zoo/vit_g_hybrid_pt_1200e.pth"
checkpoint = torch.load(finetune_path, map_location='cpu')

# print the keys of the checkpoint
print(checkpoint['model'].keys())


odict_keys(['mask_token', 'encoder.patch_embed.proj.weight', 'encoder.patch_embed.proj.bias', 'encoder.blocks.0.norm1.weight', 'encoder.blocks.0.norm1.bias', 'encoder.blocks.0.attn.q_bias', 'encoder.blocks.0.attn.v_bias', 'encoder.blocks.0.attn.qkv.weight', 'encoder.blocks.0.attn.proj.weight', 'encoder.blocks.0.attn.proj.bias', 'encoder.blocks.0.norm2.weight', 'encoder.blocks.0.norm2.bias', 'encoder.blocks.0.mlp.fc1.weight', 'encoder.blocks.0.mlp.fc1.bias', 'encoder.blocks.0.mlp.fc2.weight', 'encoder.blocks.0.mlp.fc2.bias', 'encoder.blocks.1.norm1.weight', 'encoder.blocks.1.norm1.bias', 'encoder.blocks.1.attn.q_bias', 'encoder.blocks.1.attn.v_bias', 'encoder.blocks.1.attn.qkv.weight', 'encoder.blocks.1.attn.proj.weight', 'encoder.blocks.1.attn.proj.bias', 'encoder.blocks.1.norm2.weight', 'encoder.blocks.1.norm2.bias', 'encoder.blocks.1.mlp.fc1.weight', 'encoder.blocks.1.mlp.fc1.bias', 'encoder.blocks.1.mlp.fc2.weight', 'encoder.blocks.1.mlp.fc2.bias', 'encoder.blocks.2.norm1.weight', '

In [5]:
import torch

finetune_path = "model_zoo/pytorch_model.bin"
# finetune_path = "model_zoo/vit_g_hybrid_pt_1200e.pth"
checkpoint = torch.load(finetune_path, map_location='cpu')

# print the keys of the checkpoint
# print(checkpoint['model'].keys())
print(checkpoint['state_dict'].keys())

odict_keys(['backbone.patch_embed.proj.weight', 'backbone.patch_embed.proj.bias', 'backbone.blocks.0.norm1.weight', 'backbone.blocks.0.norm1.bias', 'backbone.blocks.0.attn.q_bias', 'backbone.blocks.0.attn.v_bias', 'backbone.blocks.0.attn.qkv.weight', 'backbone.blocks.0.attn.proj.weight', 'backbone.blocks.0.attn.proj.bias', 'backbone.blocks.0.norm2.weight', 'backbone.blocks.0.norm2.bias', 'backbone.blocks.0.mlp.fc1.weight', 'backbone.blocks.0.mlp.fc1.bias', 'backbone.blocks.0.mlp.fc2.weight', 'backbone.blocks.0.mlp.fc2.bias', 'backbone.blocks.1.norm1.weight', 'backbone.blocks.1.norm1.bias', 'backbone.blocks.1.attn.q_bias', 'backbone.blocks.1.attn.v_bias', 'backbone.blocks.1.attn.qkv.weight', 'backbone.blocks.1.attn.proj.weight', 'backbone.blocks.1.attn.proj.bias', 'backbone.blocks.1.norm2.weight', 'backbone.blocks.1.norm2.bias', 'backbone.blocks.1.mlp.fc1.weight', 'backbone.blocks.1.mlp.fc1.bias', 'backbone.blocks.1.mlp.fc2.weight', 'backbone.blocks.1.mlp.fc2.bias', 'backbone.blocks.2.n

In [26]:
import torch

frame_diff = torch.rand(8, 6, 2, 2)
frame_diff[:, -4:] = 0

# print(frame_diff)

zero_diff = torch.sum(frame_diff == 0.0, dim=(2, 3))
# print(zero_diff)

B, T, H, W = frame_diff.shape
for i in range(B):
    matching_mask = (zero_diff[i] == H * W)
    matching_indices = torch.nonzero(matching_mask).flatten()
    for j in matching_indices:
        if j > 0:
            frame_diff[i, j] = frame_diff[i, j - 1]

print(frame_diff)

tensor([[[[0.5156, 0.9380],
          [0.2992, 0.4276]],

         [[0.5195, 0.0078],
          [0.2015, 0.7975]],

         [[0.5195, 0.0078],
          [0.2015, 0.7975]],

         [[0.5195, 0.0078],
          [0.2015, 0.7975]],

         [[0.5195, 0.0078],
          [0.2015, 0.7975]],

         [[0.5195, 0.0078],
          [0.2015, 0.7975]]],


        [[[0.4275, 0.0850],
          [0.2917, 0.6724]],

         [[0.6563, 0.4740],
          [0.3515, 0.2606]],

         [[0.6563, 0.4740],
          [0.3515, 0.2606]],

         [[0.6563, 0.4740],
          [0.3515, 0.2606]],

         [[0.6563, 0.4740],
          [0.3515, 0.2606]],

         [[0.6563, 0.4740],
          [0.3515, 0.2606]]],


        [[[0.7610, 0.6424],
          [0.5726, 0.8674]],

         [[0.4713, 0.3859],
          [0.9365, 0.9999]],

         [[0.4713, 0.3859],
          [0.9365, 0.9999]],

         [[0.4713, 0.3859],
          [0.9365, 0.9999]],

         [[0.4713, 0.3859],
          [0.9365, 0.9999]],

         [