In [None]:
import torch
from models.network import *

In [None]:
# Use the GPU when one is present; fall back to CPU so the demo still runs on
# machines without CUDA (a hard-coded "cuda" raises there).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the global torch RNG so weight initialisation and the random demo
# inputs created later are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

num_frame = 10    # number of frames (time steps) per sample
num_channel = 6   # channels per frame
img_size = 120    # height and width of each (square) frame

# Every *_list argument has three entries — presumably one per scale of the
# multi-scale network; confirm against models.network.
# NOTE(review): presumably image_size must be divisible by each patch size
# (120 is divisible by 12, 10 and 8) and dim_list[i] should equal
# heads_list[i] * dim_head_list[i] (8*48=384, 8*32=256, 8*24=192) — confirm.
model = MS2TAN(
    dim_list=[384, 256, 192],
    num_frame=num_frame,
    image_size=img_size,
    patch_list=[12, 10, 8],
    # NOTE(review): presumably one extra input channel for the mask/hint — confirm
    in_chans=num_channel + 1,
    out_chans=num_channel,
    depth_list=[4, 4, 4],
    heads_list=[8, 8, 8],
    dim_head_list=[48, 32, 24],
).to(device)
init_weights(model)

In [None]:
# Count every parameter element in the model (trainable or not) and report it
# in millions, rounded to two decimals.
total = sum(param.numel() for param in model.parameters())
print(f"Total params: {total / 1e6:.2f}M")

In [None]:
batch_size = 1

# Shapes used by the demo tensors: (batch, frames, channels, height, width).
# Image tensors carry `num_channel` channels; the mask/hint tensors carry one.
data_shape = (batch_size, num_frame, num_channel, img_size, img_size)
mask_shape = (batch_size, num_frame, 1, img_size, img_size)

# input and output time-series images (random stand-ins), allocated directly
# on `device` to avoid a host allocation followed by a copy
X = torch.randn(data_shape, device=device)
y = torch.randn(data_shape, device=device)

# artificial masked pixels in trainset
# NOTE(review): a mask is presumably binary (0/1); randn yields Gaussian values,
# which is fine for a shape check but not representative input — confirm.
artificial = torch.randn(mask_shape, device=device)

# hint tensor for each missing pixel (both artificial and real)
hint_tensor = torch.randn(mask_shape, device=device)

In [None]:
# Forward pass. The model returns a dict; this cell reads the 'hist_list' and
# 'replace_out' entries from it.
# NOTE(review): mode='each' appears to request the per-stage outputs stored in
# 'hist_list' — confirm against MS2TAN.forward.
out = model(X, (hint_tensor, artificial), y, mode='each')

# each intermediate result
out_list = out['hist_list']
for idx, res in enumerate(out_list):
    print(f'Intermediate result {idx}:', res.shape)

# final result after replacement
final_result = out['replace_out']
print('Final result:', final_result.shape)