In [None]:
import torch
from Loader_17 import DAVIS_Rawset, DAVIS_Infer, DAVIS_Dataset, normalize
from polygon import RasLoss, SoftPolygon
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import copy
import matplotlib.pyplot as plt
from einops import rearrange
import torch.nn.functional as F
from MyLoss import deviation_loss, total_len_loss
from torch.nn.init import xavier_uniform_
from ms_deform_attn import MSDeformAttn
from torch.utils.data import Dataset
import json
import random
import gc
import regex as re

In [None]:
# Raw DAVIS splits, loaded through the local Loader_17 module.
train_rawset, val_rawset = DAVIS_Rawset(is_train=True), DAVIS_Rawset(is_train=False)

In [None]:
def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu, not {activation}.")


def get_valid_ratio(mask):
    """Fraction of non-padded width and height for every image in a batch.

    Args:
        mask: boolean padding mask indexed as ``mask[b, y, x]``; ``True`` marks padding.
            Only the first column and first row are inspected.

    Returns:
        Tensor ``(B, 2)`` holding ``(valid_w / W, valid_h / H)`` per image.
    """
    height, width = mask.shape[1:]
    rows_valid = (~mask[:, :, 0]).sum(dim=1).float()
    cols_valid = (~mask[:, 0, :]).sum(dim=1).float()
    return torch.stack([cols_valid / width, rows_valid / height], dim=-1)


class DeformableTransformerEncoderLayer(nn.Module):
    """One post-norm deformable encoder layer.

    Multi-scale deformable self-attention (``MSDeformAttn``) followed by a two-layer
    FFN; each sub-block computes ``x = norm(x + dropout(sublayer(x)))``.
    """

    def __init__(
        self,
        d_model=256,
        d_ffn=1024,
        dropout=0.1,
        activation="relu",
        n_levels=4,
        n_heads=8,
        n_points=4,
    ):
        super().__init__()

        # self attention
        # NOTE: sub-module creation order fixes parameter init order and state_dict keys.
        self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout3 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        """Add the positional embedding ``pos`` to ``tensor`` unless it is ``None``."""
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, src):
        """Residual FFN sub-block: ``norm2(src + dropout3(linear2(dropout2(act(linear1(src))))))``."""
        src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
        src = src + self.dropout3(src2)
        src = self.norm2(src)
        return src

    def forward(
        self,
        src,
        pos,
        reference_points,
        spatial_shapes,
        level_start_index,
        padding_mask=None,
    ):
        """Apply deformable self-attention and the FFN to ``src``.

        Args:
            src: flattened multi-level features; used as both queries and values.
            pos: optional positional embedding, added to the queries only.
            reference_points: sampling reference points forwarded to ``MSDeformAttn``.
            spatial_shapes: per-level ``(H, W)`` of the flattened ``src``.
            level_start_index: start offset of each level inside ``src``.
            padding_mask: optional padding mask forwarded to ``MSDeformAttn``.

        Returns:
            The updated features, same shape as ``src`` (required by the residual adds).
        """
        # self attention: queries carry the positional embedding, values do not
        src2 = self.self_attn(
            self.with_pos_embed(src, pos),
            reference_points,
            src,
            spatial_shapes,
            level_start_index,
            padding_mask,
        )
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        # ffn
        src = self.forward_ffn(src)

        return src


class DeformableTransformerEncoder(nn.Module):
    """Stack of ``num_layers`` deep-copied deformable encoder layers.

    Every layer receives the same reference points, computed once per forward pass
    from the level shapes and valid ratios.
    """

    def __init__(self, encoder_layer, num_layers):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        """Normalized pixel-centre reference points for every feature level.

        Args:
            spatial_shapes: iterable of per-level ``(H, W)`` pairs.
            valid_ratios: tensor ``(B, n_levels, 2)`` of ``(w_ratio, h_ratio)``.
            device: device on which the coordinate grids are created.

        Returns:
            Tensor ``(B, sum(H * W), n_levels, 2)`` of ``(x, y)`` points; the points of
            all levels are concatenated, then scaled by the valid ratio of each level.
        """
        reference_points_list = []
        for lvl, (H_, W_) in enumerate(spatial_shapes):
            # indexing="ij" is the historical default; stating it explicitly removes
            # the torch>=1.10 deprecation warning without changing the result.
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
                torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
                indexing="ij",
            )
            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points

    def forward(
        self,
        src,
        spatial_shapes,
        level_start_index,
        valid_ratios,
        pos=None,
        padding_mask=None,
    ):
        """Run every encoder layer over the flattened multi-level features ``src``.

        Args:
            src: flattened multi-level features.
            spatial_shapes: per-level ``(H, W)`` layout of ``src``.
            level_start_index: start offset of each level inside ``src``.
            valid_ratios: ``(B, n_levels, 2)`` valid ratios.
            pos: optional positional embedding added to the queries in each layer.
            padding_mask: optional padding mask forwarded to each layer.

        Returns:
            The encoded features, same shape as ``src``.
        """
        output = src
        reference_points = self.get_reference_points(
            spatial_shapes, valid_ratios, device=src.device
        )
        for layer in self.layers:
            output = layer(
                output,
                pos,
                reference_points,
                spatial_shapes,
                level_start_index,
                padding_mask,
            )

        return output


def get_bou_feats(img_feats: torch.Tensor, boundary: torch.Tensor) -> torch.Tensor:
    """Gather the feature vector under each integer boundary point.

    Args:
        img_feats: feature map indexed as ``img_feats[b, c, y, x]``.
        boundary: integer points ``(B, N, 2)`` stored as ``(x, y)``.

    Returns:
        Tensor ``(B, N, C)``. The advanced indices are split by the channel slice,
        so the broadcast ``(B, N)`` dims come first.
    """
    xs = boundary[:, :, 0]
    ys = boundary[:, :, 1]
    batch_ids = torch.arange(boundary.shape[0]).unsqueeze(1)
    return img_feats[batch_ids, :, ys, xs]


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to a token sequence, then dropout.

    Even channels hold ``sin(pos * w_k)`` and odd channels ``cos(pos * w_k)`` with
    ``w_k = 10000 ** (-2k / embed_dim)``. Any ``embed_dim`` (odd included) is supported.
    """

    def __init__(self, embed_dim: int, dropout=0.1, max_seq_len=102400) -> None:
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_seq_len, embed_dim)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embed_dim, 2).float() * (-np.log(10000.0) / embed_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # For an odd embed_dim there is one fewer cosine column than sine column;
        # slicing keeps the even case unchanged and avoids a shape mismatch.
        pe[:, 1::2] = torch.cos(position * div_term[: embed_dim // 2])
        pe = pe.unsqueeze(0)
        # registered (persistent) buffer: the name "pe" appears in saved state_dicts
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the encoding to ``x`` of shape ``(B, S, embed_dim)`` (``S <= max_seq_len``)."""
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)


def get_img_tokens(img_features: torch.Tensor) -> torch.Tensor:
    """Flatten a ``(B, C, H, W)`` feature map into row-major ``(B, H*W, C)`` tokens."""
    return rearrange(img_features, "b c h w -> b (h w) c")


def get_extrame_4_points(batch_indices: torch.Tensor):
    """Four corners of the axis-aligned box around a set of pixel positions.

    Args:
        batch_indices: integer tensor ``(K, 2)`` of ``(row, col)`` positions.

    Returns:
        A new tensor ``(4, 2)`` of ``(x, y)`` corners in the order
        (x_min, y_min), (x_min, y_max), (x_max, y_max), (x_max, y_min).
        With no positions, all four corners are ``(112, 112)``, the centre of a
        224x224 image.
    """
    if not batch_indices.shape[0]:
        centre = 224 // 2
        return torch.tensor([[centre, centre]] * 4)
    rows, cols = batch_indices[:, 0], batch_indices[:, 1]
    left, right = cols.min(), cols.max()
    top, bottom = rows.min(), rows.max()
    return torch.tensor([[left, top], [left, bottom], [right, bottom], [right, top]])


def get_bounding_box(sgm: torch.Tensor):
    """Box corners around the foreground of every mask in a batch.

    Args:
        sgm: batched masks indexed as ``sgm[b, y, x]``; nonzero pixels are foreground.

    Returns:
        Tensor ``(B, 4, 2)`` stacking ``get_extrame_4_points`` for each batch item.
    """
    nonzero = torch.nonzero(sgm)
    boxes = [
        get_extrame_4_points(nonzero[nonzero[:, 0] == b][:, 1:])
        for b in range(sgm.shape[0])
    ]
    return torch.stack(boxes)


def add_mid_points(points: torch.Tensor) -> torch.Tensor:
    """Double a closed polygon's resolution by interleaving edge midpoints.

    For ``points`` of shape ``(B, N, 2)`` the result is ``(B, 2N, 2)``. Even slots hold
    the midpoint of vertex ``i - 1`` (cyclically) and vertex ``i``; odd slots hold
    vertex ``i`` itself. The output is a default-dtype tensor on ``points.device``.
    """
    batch, count = points.shape[0], points.shape[1]
    previous = torch.roll(points, 1, 1)
    out = torch.zeros((batch, count * 2, 2), device=points.device)
    out[:, 0::2] = (points + previous) / 2
    out[:, 1::2] = points
    return out


class DeformableTransformerDecoderLayer(nn.Module):
    """One post-norm deformable decoder layer.

    Order of sub-blocks in ``forward``: standard multi-head self-attention among the
    queries, then multi-scale deformable cross-attention into the encoder memory, then
    a two-layer FFN. Each sub-block computes ``x = norm(x + dropout(sublayer(x)))``.
    """

    def __init__(
        self,
        d_model=256,
        d_ffn=1024,
        dropout=0.1,
        activation="relu",
        n_levels=4,
        n_heads=8,
        n_points=4,
    ):
        super().__init__()

        # NOTE: sub-module creation order fixes parameter init order and state_dict keys.
        # cross attention (uses dropout1 / norm1)
        self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # self attention (uses dropout2 / norm2); default sequence-first layout
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation)
        self.dropout3 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout4 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        """Add the positional embedding ``pos`` to ``tensor`` unless it is ``None``."""
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, tgt):
        """Residual FFN sub-block: ``norm3(tgt + dropout4(linear2(dropout3(act(linear1(tgt))))))``."""
        tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout4(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward(
        self,
        tgt,
        query_pos,
        reference_points,
        src,
        src_spatial_shapes,
        level_start_index,
        src_padding_mask=None,
    ):
        """Update the queries ``tgt`` with self-attention, deformable cross-attention and FFN.

        Args:
            tgt: query features, batch-first (transposed below for ``self_attn``).
            query_pos: optional positional embedding added to queries/keys.
            reference_points: sampling reference points forwarded to ``MSDeformAttn``.
            src: flattened multi-level encoder memory (values for cross-attention).
            src_spatial_shapes: per-level ``(H, W)`` of ``src``.
            level_start_index: start offset of each level inside ``src``.
            src_padding_mask: optional padding mask for ``src``.

        Returns:
            The updated queries, same shape as ``tgt``.
        """
        # self attention
        # nn.MultiheadAttention was built without batch_first, so swap to
        # (seq, batch, dim) on the way in and back on the way out.
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(
            q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1)
        )[0].transpose(0, 1)
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)

        # cross attention
        tgt2 = self.cross_attn(
            self.with_pos_embed(tgt, query_pos),
            reference_points,
            src,
            src_spatial_shapes,
            level_start_index,
            src_padding_mask,
        )
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # ffn
        tgt = self.forward_ffn(tgt)

        return tgt


class DeformableTransformerDecoder(nn.Module):
    """Stack of ``num_layers`` deep-copied deformable decoder layers.

    Reference points are not refined between layers: every layer receives the same
    points, scaled by the per-level valid ratios.
    """

    def __init__(self, decoder_layer, num_layers, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.return_intermediate = return_intermediate

    def forward(
        self,
        tgt,
        reference_points,
        src,
        src_spatial_shapes,
        src_level_start_index,
        src_valid_ratios,
        query_pos=None,
        src_padding_mask=None,
    ):
        """Run all decoder layers over the queries ``tgt``.

        Args:
            tgt: query features passed to the first layer.
            reference_points: normalized points ``(..., 2)`` or boxes ``(..., 4)``.
            src: flattened multi-level encoder memory given to every layer.
            src_spatial_shapes: per-level ``(H, W)`` of ``src``.
            src_level_start_index: start offset of each level inside ``src``.
            src_valid_ratios: ``(B, n_levels, 2)`` valid ratios.
            query_pos: optional positional embedding for the queries.
            src_padding_mask: optional padding mask for ``src``.

        Returns:
            ``(output, reference_points)``; with ``return_intermediate`` both are stacked
            over layers (the reference points are repeated unchanged).
        """
        # The reference points never change across layers, so their level-scaled
        # version is computed once instead of once per layer.
        if reference_points.shape[-1] == 4:
            reference_points_input = (
                reference_points[:, :, None]
                * torch.cat([src_valid_ratios, src_valid_ratios], -1)[:, None]
            )
        else:
            assert reference_points.shape[-1] == 2
            reference_points_input = (
                reference_points[:, :, None] * src_valid_ratios[:, None]
            )

        output = tgt
        intermediate = []
        intermediate_reference_points = []
        for layer in self.layers:
            output = layer(
                output,
                query_pos,
                reference_points_input,
                src,
                src_spatial_shapes,
                src_level_start_index,
                src_padding_mask,
            )

            if self.return_intermediate:
                intermediate.append(output)
                intermediate_reference_points.append(reference_points)

        if self.return_intermediate:
            return torch.stack(intermediate), torch.stack(intermediate_reference_points)

        return output, reference_points


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN).

    ``num_layers`` linear layers ``input_dim -> hidden_dim -> ... -> output_dim`` with a
    ReLU after every layer except the last.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        hidden = [hidden_dim] * (num_layers - 1)
        in_dims = [input_dim] + hidden
        out_dims = hidden + [output_dim]
        # layers are created in order, so parameter init order is unchanged
        self.layers = nn.ModuleList(
            [nn.Linear(d_in, d_out) for d_in, d_out in zip(in_dims, out_dims)]
        )

    def forward(self, x):
        """Apply the layers in order; the final layer has no activation."""
        *hidden_layers, last_layer = self.layers
        for layer in hidden_layers:
            x = F.relu(layer(x))
        return last_layer(x)


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def inverse_sigmoid(x, eps=1e-5):
    """Numerically safe logit ``log(x / (1 - x))``.

    ``x`` is first clamped to ``[0, 1]``; numerator and denominator are then floored
    at ``eps`` so the result stays finite at the endpoints.
    """
    clipped = x.clamp(min=0, max=1)
    numerator = clipped.clamp(min=eps)
    denominator = (1 - clipped).clamp(min=eps)
    return (numerator / denominator).log()





def get_bou_iou(
    index: int,
    boundary: torch.Tensor,
    mask: torch.Tensor,
    rasterizer,
) -> torch.Tensor:
    """IoU between the rasterized polygon and the mask for one batch item.

    Args:
        index: batch position to evaluate.
        boundary: batched polygons, rasterized via ``rasterizer(boundary, 224, 224)``.
        mask: batched ground-truth masks aligned with the rasterized output.
        rasterizer: callable returning per-item masks; entries equal to ``-1`` count as 0.

    Returns:
        0-d tensor ``|pred ∩ mask| / |pred ∪ mask|`` (NaN for float inputs when both are empty).
    """
    rendered = rasterizer(boundary, 224, 224)
    rendered[rendered == -1] = 0
    pred = rendered[index]
    target = mask[index]
    overlap = (pred * target).sum()
    union = pred.sum() + target.sum() - overlap
    return overlap / union

def get_batch_average_bou_iou(
    boundary: torch.Tensor,
    mask: torch.Tensor,
    rasterizer,
) -> torch.Tensor:
    """Mean polygon-vs-mask IoU over a batch, computed without gradients.

    Args:
        boundary: batched polygons, rasterized via ``rasterizer(boundary, 224, 224)``.
        mask: batched ground-truth masks aligned with the rasterized output.
        rasterizer: callable returning per-item masks; entries equal to ``-1`` count as 0.

    Returns:
        0-d tensor: the per-item IoUs averaged over the batch.
    """
    with torch.no_grad():
        rendered = rasterizer(boundary, 224, 224)
        rendered[rendered == -1] = 0
        pred_flat = rendered.flatten(1)
        mask_flat = mask.flatten(1)
        overlap = (pred_flat * mask_flat).sum(-1)
        union = pred_flat.sum(-1) + mask_flat.sum(-1) - overlap
        return (overlap / union).mean()

In [None]:
class DeformLearnImage(nn.Module):
    """Image-level polygon predictor on top of FeatUp features.

    Pipeline, as implemented in ``forward``:
      1. ``self.featup`` (FeatUp "dino16", fetched with ``torch.hub``) extracts a feature map.
      2. The map is bilinearly resized to every size in ``medium_level_size``; together
         with the original map this gives ``len(medium_level_size) + 1`` levels for a
         deformable transformer encoder.
      3. A 4-point polygon is initialised from the bounding box of ``mask``. The first
         decoder predicts offsets for it; each of the ``up_scale_num`` further decoders
         doubles the point count with ``add_mid_points`` and predicts new offsets.

    NOTE(review): coordinates are normalised with a hard-coded 224, so inputs are
    presumably 224x224 images -- confirm against the data loader. Sub-modules and
    intermediate tensors are moved to CUDA explicitly, so a GPU is required.
    """

    def __init__(
        self,
        layer_num=1,
        up_scale_num=4,
        head_num=6,
        medium_level_size=[14, 28, 56, 112],
        offset_limit=56,
        n_points=4,
        freeze_backbone=True,
    ) -> None:
        """
        Args:
            layer_num: number of layers in the encoder and in each decoder.
            up_scale_num: number of point-doubling stages after the first decoder; the
                final polygon has ``4 * 2**up_scale_num`` points.
            head_num: attention heads; ``nn.MultiheadAttention`` requires it to divide
                ``d_model`` (384).
            medium_level_size: side lengths of the extra resized feature levels.
            offset_limit: per stage, each coordinate moves by less than
                ``offset_limit / 2`` pixels (224-pixel scale).
            n_points: sampling points per head and level in ``MSDeformAttn``.
            freeze_backbone: if True, FeatUp parameters get ``requires_grad = False``.
        """
        super(DeformLearnImage, self).__init__()
        self.up_scale_num = up_scale_num
        self.offset_limit = offset_limit
        self.medium_level_size = medium_level_size
        # downloads/loads the FeatUp backbone through torch.hub (network on first use)
        self.featup = torch.hub.load(
            "mhamilton723/FeatUp",
            "dino16",
            use_norm=True,
        ).cuda()
        if freeze_backbone:
            for param in self.featup.parameters():
                param.requires_grad = False
        d_model = 384
        d_ffn = 1024
        # resized levels plus the original FeatUp map
        n_levels = len(medium_level_size) + 1
        self.pos_enoc = PositionalEncoding(d_model)
        self.layer_norm = nn.LayerNorm(d_model)
        deform_encoder_layer = DeformableTransformerEncoderLayer(
            d_model=d_model,
            d_ffn=d_ffn,
            n_levels=n_levels,
            n_heads=head_num,
            n_points=n_points,
        )
        self.deform_encoder = DeformableTransformerEncoder(
            deform_encoder_layer,
            num_layers=layer_num,
        )
        # one learned query per final polygon point: 4 * 2**up_scale_num
        query_num = 4
        for i in range(up_scale_num):
            query_num *= 2
        self.query_embed = nn.Embedding(query_num, d_model)
        # init the query embedding
        xavier_uniform_(self.query_embed.weight)
        deform_decoder_layer = DeformableTransformerDecoderLayer(
            d_model=d_model,
            d_ffn=d_ffn,
            n_levels=n_levels,
            n_heads=head_num,
            n_points=n_points,
        )
        deform_decoder = DeformableTransformerDecoder(
            deform_decoder_layer,
            num_layers=layer_num,
        )
        # one decoder and one offset head per stage: the initial 4-point stage plus
        # up_scale_num doubling stages
        self.deform_decoders = _get_clones(deform_decoder, up_scale_num + 1)
        xy_fc = MLP(d_model, d_model, 2, 3).cuda()
        # xy_fc = nn.Linear(d_model, 2).cuda()
        self.xy_fc = _get_clones(xy_fc, up_scale_num + 1)

    def get_valid_ratio(self, mask: torch.Tensor):
        """Fraction of non-padded width/height per image, as ``(B, 2)`` ``(w, h)`` ratios.

        ``mask`` is a boolean padding mask indexed as ``mask[b, y, x]`` (True = padding).
        """
        _, H, W = mask.shape
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
        return valid_ratio

    def forward(
        self,
        img: torch.Tensor,
        mask: torch.Tensor,
    ):
        """Predict a polygon for each image in the batch.

        Args:
            img: image batch passed to FeatUp (assumed ``(B, 3, 224, 224)`` -- TODO confirm).
            mask: object masks indexed as ``mask[b, y, x]``; only their bounding boxes
                are used, to initialise the polygon.

        Returns:
            Training mode: list of ``up_scale_num + 1`` tensors ``(B, 4 * 2**k, 2)`` of
            ``(x, y)`` pixel coordinates (normalised coordinates times 224).
            Eval mode: only the final polygon, clamped to ``[0, 223]``.
        """
        feats = self.featup(img)
        # prepare the input for the MSDeformAttn module
        srcs = []
        padding_masks = []
        for low_res in self.medium_level_size:
            srcs.append(
                F.interpolate(
                    feats,
                    size=(low_res, low_res),
                    mode="bilinear",
                ),
            )
        # the full-resolution FeatUp map is the last (finest) level
        srcs.append(feats)
        # nothing is padded: all-False masks, so every valid ratio is 1
        for src in srcs:
            padding_masks.append(torch.zeros_like(src[:, 0:1, :, :]).squeeze(1).bool())
        src_flatten = []
        spatial_shapes = []
        for src in srcs:
            src_flatten.append(
                rearrange(src, "b c h w -> b (h w) c"),
            )
            spatial_shapes.append(src.shape[-2:])
        # offset of each level's first token inside the concatenated sequence
        level_start_index = torch.cat(
            (
                torch.tensor([0]),
                torch.cumsum(
                    torch.tensor([x.shape[1] for x in src_flatten]),
                    0,
                )[:-1],
            )
        ).cuda()
        src_flatten = torch.cat(src_flatten, 1).cuda()
        valid_ratios = torch.stack(
            [self.get_valid_ratio(mask) for mask in padding_masks],
            1,
        ).cuda()
        spatial_shapes = torch.as_tensor(
            spatial_shapes,
            dtype=torch.long,
            device=src_flatten.device,
        )
        src_flatten = self.layer_norm(src_flatten)
        src_flatten = self.pos_enoc(src_flatten)
        src_flatten = self.deform_encoder(
            src=src_flatten,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
        )

        B, S, C = src_flatten.shape
        queries = self.query_embed.weight.unsqueeze(0).repeat(B, 1, 1).cuda()
        # initial 4-point polygon from the mask's bounding box, normalised to [0, 1]
        init_bou = get_bounding_box(mask).cuda() / 224
        current_query_num = init_bou.shape[1]
        current_query = queries[:, :current_query_num]
        decode_output, _ = self.deform_decoders[0](
            current_query,
            init_bou,
            src_flatten,
            spatial_shapes,
            level_start_index,
            valid_ratios,
        )

        # (sigmoid - 0.5) is in (-0.5, 0.5): each coordinate moves by less than
        # offset_limit / 2 pixels in the 224-pixel scale
        xy_offset = (
            (self.xy_fc[0](decode_output).sigmoid() - 0.5) * self.offset_limit / 224
        )
        init_bou += xy_offset

        # xy_offset = self.xy_fc[0](decode_output)
        # xy_offset = xy_offset.sigmoid()
        # init_bou = inverse_sigmoid(init_bou) + xy_offset
        # init_bou = init_bou.sigmoid().clone()

        # init_bou = init_bou.clamp(0, 1)

        results = [init_bou]
        for i in range(self.up_scale_num):
            # take the next unused block of learned queries (as many as already in use)
            new_query = queries[:, current_query_num : current_query_num * 2]
            current_query_num *= 2

            # interleave to match add_mid_points: fresh queries at even slots
            # (midpoints), previous decoder outputs at odd slots (existing vertices)
            current_query = torch.zeros((B, current_query_num, C)).to(
                src_flatten.device
            )
            current_query[:, ::2] = new_query
            current_query[:, 1::2] = decode_output
            # current_query = torch.cat([new_query, decode_output], 1)

            cur_bou = add_mid_points(results[-1])
            decode_output, _ = self.deform_decoders[i + 1](
                current_query,
                cur_bou,
                src_flatten,
                spatial_shapes,
                level_start_index,
                valid_ratios,
            )

            xy_offset = (
                (self.xy_fc[i + 1](decode_output).sigmoid() - 0.5)
                * self.offset_limit
                / 224
            )
            cur_bou += xy_offset

            # xy_offset = self.xy_fc[i + 1](decode_output)
            # cur_bou = inverse_sigmoid(cur_bou) + xy_offset
            # cur_bou = cur_bou.sigmoid().clone()

            # cur_bou = cur_bou.clamp(0, 1)

            results.append(cur_bou)
        # back to pixel coordinates
        results = [result * 224 for result in results]

        if self.training:
            # every stage is returned so all of them can be supervised
            return results
        else:
            result = results[-1]
            result = result.clamp(0, 223)
            return result

In [None]:
# Build the model once to check it instantiates (FeatUp is fetched through torch.hub).
model = DeformLearnImage().cuda()
model.train();

In [None]:
class DAVIS_IMG_Dataset(Dataset):
    """Frame-level DAVIS dataset with empty-mask frames removed.

    Training: one item per video; each access returns a uniformly random frame of
    that video. Validation: ``val_sample_num`` items per video, taken at evenly
    spaced frame positions so evaluation is deterministic.
    """

    def __init__(self, rawset: DAVIS_Rawset, is_train: bool, val_sample_num: int = 4):
        self.is_train = is_train
        if not is_train:
            self.val_sample_num = val_sample_num
        # Drop frames whose mask is empty, in a single pass. The previous version
        # collected (video, frame) pairs in a list and scanned that list for every
        # frame, i.e. O(frames * empty_frames).
        # One list is kept per video (even if it ends up empty), so dataset video
        # indices match rawset.data_set.
        # NOTE(review): a video whose frames are all empty would make __getitem__
        # fail (random.randint(0, -1) / index 0 of an empty list) -- confirm none exist.
        self.data_set = [
            [(img, mask) for img, mask in video_data if mask.sum() != 0]
            for video_data in rawset.data_set
        ]

    def __len__(self):
        """One item per video for training, ``val_sample_num`` per video for validation."""
        if self.is_train:
            return len(self.data_set)
        return len(self.data_set) * self.val_sample_num

    def __getitem__(self, idx: int):
        """Return an ``(img, mask)`` pair.

        Training: ``idx`` is a video index and a random frame of it is returned.
        Validation: ``idx // val_sample_num`` selects the video and
        ``idx % val_sample_num`` selects one of the evenly spaced frames.
        """
        if self.is_train:
            video_data = self.data_set[idx]
            # random select one frame
            frame_idx = random.randint(0, len(video_data) - 1)
            return video_data[frame_idx]
        video_data = self.data_set[idx // self.val_sample_num]
        # Evenly spaced frames. With fewer frames than val_sample_num the step is 0
        # and every sample of that video is its first frame.
        frame_step = len(video_data) // self.val_sample_num
        frame_idx = (idx % self.val_sample_num) * frame_step
        return video_data[frame_idx]

In [None]:
# Frame-level datasets: random frame per video for training, fixed frames for validation.
img_train_dataset, img_val_dataset = (
    DAVIS_IMG_Dataset(train_rawset, is_train=True),
    DAVIS_IMG_Dataset(val_rawset, is_train=False),
)

In [None]:
# training: one item per video; validation: val_sample_num items per video
(len(img_train_dataset), len(img_val_dataset))

In [None]:
# NOTE(review): `index` is not read in this cell; later cells reassign it before use.
index = 9
# peek at the first validation sample
first_frame, mask = img_val_dataset[0]

In [None]:
# re-check the dataset sizes
(len(img_train_dataset), len(img_val_dataset))

In [None]:
# Sanity-check the training split: one random frame next to its mask.
train_loader = DataLoader(img_train_dataset, batch_size=1, shuffle=True)
first_frame, mask = next(iter(train_loader))
fig, (ax_img, ax_mask) = plt.subplots(1, 2, figsize=(10, 10))
ax_img.imshow(normalize(first_frame[0]).permute(1, 2, 0))
ax_img.axis('off')
ax_mask.imshow(mask[0])
ax_mask.axis('off')
plt.show()

In [None]:
# Sanity-check the validation split: first item of a random batch next to its mask.
val_loader = DataLoader(img_val_dataset, batch_size=2, shuffle=True)
first_frame, mask = next(iter(val_loader))
fig, (ax_img, ax_mask) = plt.subplots(1, 2, figsize=(10, 10))
ax_img.imshow(normalize(first_frame[0]).permute(1, 2, 0))
ax_img.axis('off')
ax_mask.imshow(mask[0])
ax_mask.axis('off')
plt.show()

In [None]:
# training loss from the local `polygon` module
ras_loss = RasLoss().cuda()
# rasterizer used only for IoU metrics; "hard_mask" mode presumably yields binary masks
gt_rasterizer = SoftPolygon(1, "hard_mask").cuda()

In [None]:
# Smoke test: in training mode the model returns one polygon per refinement stage.
model = DeformLearnImage().cuda().train()
results = model(first_frame.cuda(), mask.cuda())
results[-1].shape

In [None]:
# batch size 1 for both splits; validation order does not affect the mean IoU
train_loader = DataLoader(img_train_dataset, shuffle=True, batch_size=1)
val_loader = DataLoader(img_val_dataset, shuffle=True, batch_size=1)

In [None]:
# Fresh model and training state for the run below.
model = DeformLearnImage().cuda()
loss_dict, iou_train_dict, iou_val_dict = {}, {}, {}
optimizer = optim.Adam(model.parameters(), lr=1e-4)
epoch_num = 20
eval_period = 5  # validate every `eval_period` epochs (and after the last epoch)

In [None]:
# Train for `epoch_num` epochs; validate every `eval_period` epochs and after the last one.
# NOTE(review): the metric dicts are only kept in memory -- this cell writes no checkpoint
# or JSON log, so the files loaded by the following cells come from a separate run.
for e in range(epoch_num):
    model.train()
    mean_loss = 0
    train_mean_iou = 0
    for first_frame, mask in tqdm(train_loader):
        # move the batch to the GPU once instead of once per use
        img_gpu, mask_gpu = first_frame.cuda(), mask.cuda()
        optimizer.zero_grad()
        results = model(img_gpu, mask_gpu)
        # supervise every refinement stage, averaged over stages
        loss = sum(ras_loss(result, mask_gpu) for result in results) / len(results)
        loss.backward()
        optimizer.step()
        mean_loss += loss.item()
        iou = get_batch_average_bou_iou(results[-1], mask_gpu, gt_rasterizer)
        train_mean_iou += iou.item()
    mean_loss /= len(train_loader)
    train_mean_iou /= len(train_loader)
    loss_dict[e] = mean_loss
    iou_train_dict[e] = train_mean_iou
    print(f"Epoch {e} train loss: {mean_loss:.4f}, iou: {train_mean_iou:.4f}")
    if e % eval_period == 0 or e == epoch_num - 1:
        val_mean_iou = 0
        model.eval()
        # Evaluation needs no gradients. Without no_grad each forward pass builds an
        # autograd graph and keeps activations alive for a backward that never runs.
        with torch.no_grad():
            for first_frame, mask in tqdm(val_loader):
                img_gpu, mask_gpu = first_frame.cuda(), mask.cuda()
                result = model(img_gpu, mask_gpu)
                iou = get_batch_average_bou_iou(result, mask_gpu, gt_rasterizer)
                val_mean_iou += iou.item()
        val_mean_iou /= len(val_loader)
        iou_val_dict[e] = val_mean_iou
        print(f"Epoch {e} val iou: {val_mean_iou:.4f}")

In [None]:
model_name = "deform_img_davis"
model_path = f"./model/{model_name}_best.pth"
log_dir = f"./log/{model_name}"


def _load_epoch_log(path):
    """Read an ``{epoch: value}`` JSON log written by a previous training run.

    JSON object keys are always strings, so they are converted back to ``int``.
    Otherwise matplotlib would plot the epochs as categorical labels.
    """
    with open(path, "r") as f:
        return {int(epoch): value for epoch, value in json.load(f).items()}


# load the log
loss_dict = _load_epoch_log(f"{log_dir}/loss.json")
iou_train_dict = _load_epoch_log(f"{log_dir}/iou_train.json")
iou_val_dict = _load_epoch_log(f"{log_dir}/iou_val.json")
# About 15 ticks over the logged epoch range (every 100 epochs for a 1500-epoch log)
# instead of a range hard-coded to 1500.
last_epoch = max(loss_dict, default=0)
epoch_ticks = np.arange(0, last_epoch + 1, max(1, (last_epoch + 1) // 15))
# plot the loss
plt.figure(figsize=(10, 5))
plt.plot(list(loss_dict.keys()), list(loss_dict.values()), label="train loss")
plt.xlabel("epoch")
plt.xticks(epoch_ticks)
plt.ylabel("loss")
plt.title("train loss")
plt.legend()
plt.show()
# plot the iou
plt.figure(figsize=(10, 5))
plt.plot(list(iou_train_dict.keys()), list(iou_train_dict.values()), label="train iou")
plt.plot(list(iou_val_dict.keys()), list(iou_val_dict.values()), label="val iou")
plt.xlabel("epoch")
plt.xticks(epoch_ticks)
plt.ylabel("iou")
plt.title("iou")
plt.legend()
plt.show()
# print the best iou in train and val
best_train_iou = max(iou_train_dict.values())
best_val_iou = max(iou_val_dict.values())
print(f"best train iou: {best_train_iou:.4f}, best val iou: {best_val_iou:.4f}")


In [None]:
# Rebuild the architecture, then restore the best checkpoint from `model_path`.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
model = DeformLearnImage().cuda()
state_dict = torch.load(model_path)
model.load_state_dict(state_dict)

In [None]:
# Visualize the predicted polygon for one random validation batch.
first_frame, mask = next(iter(val_loader))
index = 0  # which item of the batch to display
model.eval()
mask_gpu = mask.cuda()
# inference only: no autograd graph is needed
with torch.no_grad():
    pred_bou = model(first_frame.cuda(), mask_gpu)
    # calculate the iou
    iou = get_bou_iou(index, pred_bou, mask_gpu, gt_rasterizer)
plt.figure(figsize=(10, 10))
plt.subplot(1, 2, 1)
plt.title(f"iou: {iou:.4f}")
# show the same batch item as the polygon and mask (previously always item 0)
plt.imshow(normalize(first_frame[index]).permute(1, 2, 0))
pred_bou_np = pred_bou[index].cpu().numpy()
plt.plot(pred_bou_np[:, 0], pred_bou_np[:, 1], "r")
plt.axis('off')
plt.subplot(1, 2, 2)
plt.imshow(mask[index])
plt.axis('off')
plt.show()

In [None]:
# number of validation samples (videos x val_sample_num)
(len(img_val_dataset))

In [None]:
# One validation sample per video: stepping by the dataset's own val_sample_num
# (4 by default, the value previously hard-coded) hits the first sampled frame of
# every video.
val_text_idxs = np.arange(0, len(img_val_dataset), img_val_dataset.val_sample_num)
# get the test data
val_text_data = [img_val_dataset[idx] for idx in val_text_idxs]
# show the prediction next to the ground truth for each sample
model.eval()
plt.figure(figsize=(10, 5 * len(val_text_data)))
# inference only: no autograd graph is needed
with torch.no_grad():
    for i, (first_frame, mask) in enumerate(val_text_data):
        plt.subplot(len(val_text_data), 2, 2 * i + 1)
        plt.imshow(normalize(first_frame).permute(1, 2, 0))
        plt.axis('off')
        pred_bou = model(first_frame.unsqueeze(0).cuda(), mask.unsqueeze(0).cuda())
        iou = get_bou_iou(0, pred_bou, mask.unsqueeze(0).cuda(), gt_rasterizer)
        plt.title(f"iou: {iou:.4f}")
        pred_bou_np = pred_bou[0].cpu().numpy()
        plt.plot(pred_bou_np[:, 0], pred_bou_np[:, 1], "r")
        plt.scatter(pred_bou_np[:, 0], pred_bou_np[:, 1], c="r", s=5)
        plt.subplot(len(val_text_data), 2, 2 * i + 2)
        plt.imshow(mask)
        plt.title(f"ground truth, {i}")
        plt.axis('off')
plt.show()

## Video part: datasets of consecutive frame pairs with precomputed boundary points

In [None]:
# Precomputed boundary samples, indexed as points[video][frame][str(point_num)]["boundary"].
train_point_path, val_point_path = (
    "sample_results/train_256_uniform.json",
    "sample_results/val_256_uniform.json",
)
with open(train_point_path, "r") as f:
    train_points = json.load(f)
with open(val_point_path, "r") as f:
    val_points = json.load(f)

In [None]:
# number of frames with sampled points in the first training video
(len(train_points[0]))

In [None]:
# 4-point boundary of frame 0 in training video 100
torch.tensor(train_points[100][0]["4"]["boundary"])

In [None]:
class DAVIS_withPoint(Dataset):
    """Consecutive-frame DAVIS samples with precomputed boundary points.

    Videos containing any empty-mask frame are dropped. Each item is
    ``(video_idx, frame_idx, first_frame, previous_frame, current_frame)``:

    * every frame is an ``(img, mask, point)`` tuple, ``point`` being the
      ``point_num``-point boundary read from the JSON sample file;
    * ``video_idx`` indexes the filtered video list (``self.raw_data_set``);
    * ``previous_frame`` is frame ``frame_idx`` and ``current_frame`` is frame
      ``frame_idx + 1``.
    """

    def __init__(
        self,
        raw_set: DAVIS_Rawset,
        point_num: int,
        is_train: bool,
        point_path=None,
    ) -> None:
        """
        Args:
            raw_set: raw DAVIS split whose ``data_set`` is a list of videos of
                ``(img, mask)`` frames.
            point_num: number of boundary points to load (key into the JSON file).
            is_train: selects the default point file when ``point_path`` is None.
            point_path: optional JSON file overriding the default train/val sample file.
        """
        super().__init__()
        self.point_num = point_num
        if point_path is None:
            point_path = (
                "sample_results/train_256_uniform.json"
                if is_train
                else "sample_results/val_256_uniform.json"
            )
        # remove all the video with empty frame (set: O(1) lookups below)
        empty_video_idx = {
            video_idx
            for video_idx, video_data in enumerate(raw_set.data_set)
            if any(mask.sum() == 0 for _, mask in video_data)
        }
        with open(point_path, "r") as f:
            points = json.load(f)
        self.raw_data_set = []
        for video_idx, video_data in enumerate(raw_set.data_set):
            if video_idx in empty_video_idx:
                continue
            # the point file is indexed by the ORIGINAL (unfiltered) video index
            self.raw_data_set.append(
                [
                    (
                        img,
                        mask,
                        torch.tensor(
                            points[video_idx][frame_idx][str(point_num)]["boundary"]
                        ),
                    )
                    for frame_idx, (img, mask) in enumerate(video_data)
                ]
            )
        # one sample per consecutive frame pair, always paired with the video's first frame
        self.data = []
        for video_idx, video_data in enumerate(self.raw_data_set):
            for frame_idx in range(len(video_data) - 1):
                self.data.append(
                    (
                        video_idx,
                        frame_idx,
                        video_data[0],
                        video_data[frame_idx],
                        video_data[frame_idx + 1],
                    )
                )

    def __len__(self):
        """Total number of consecutive frame pairs over all kept videos."""
        return len(self.data)

    def __getitem__(self, idx: int):
        """Return ``(video_idx, frame_idx, first_frame, previous_frame, current_frame)``."""
        video_idx, frame_idx, first_frame, previous_frame, current_frame = self.data[
            idx
        ]
        return video_idx, frame_idx, first_frame, previous_frame, current_frame

In [None]:
# Consecutive-frame datasets using the 64-point boundary samples.
video_train_dataset, video_val_dataset = (
    DAVIS_withPoint(train_rawset, point_num=64, is_train=True),
    DAVIS_withPoint(val_rawset, point_num=64, is_train=False),
)

In [None]:
# Small debug subset: a shallow copy of the training rawset restricted to its first
# 10 videos (slicing builds a new list, so train_rawset.data_set is left untouched).
test_rawset = copy.copy(train_rawset)
test_rawset.data_set = train_rawset.data_set[0:10]
(len(test_rawset.data_set), test_rawset.data_set[-1][-1])

In [None]:
# is_train=True selects the training point file; the subset keeps the first videos in
# their original order, so the per-video point indices still line up.
video_test_dataset = DAVIS_withPoint(test_rawset, point_num=64, is_train=True)

In [None]:
video_test_loader = DataLoader(video_test_dataset, shuffle=True, batch_size=1)  # debug loader

In [None]:
# each frame entry of a batch is a collated [img, mask, point] list
(video_idx, frame_idx, first_frame, previous_frame, current_frame) = next(
    iter(video_test_loader)
)
first_frame[0].shape

In [None]:
# Show first / previous / current frame of one debug batch: image, 50% mask overlay,
# and the boundary points drawn as a red line.
video_idx, frame_idx, first_frame, previous_frame, current_frame = next(
    iter(video_test_loader)
)
plt.figure(figsize=(10, 10))
for col, (img, frame_mask, points) in enumerate(
    (first_frame, previous_frame, current_frame), start=1
):
    plt.subplot(1, 3, col)
    plt.imshow(normalize(img[0]).permute(1, 2, 0))
    plt.imshow(frame_mask[0], alpha=0.5)
    plt.plot(points[0][:, 0], points[0][:, 1], "r")
    plt.axis('off')
plt.show()

In [None]:

# number of consecutive-frame pairs per split
(len(video_train_dataset), len(video_val_dataset), len(video_test_dataset))

In [None]:
# number of videos kept per split (videos with an empty-mask frame are dropped)
(
    len(video_train_dataset.raw_data_set),
    len(video_val_dataset.raw_data_set),
    len(video_test_dataset.raw_data_set),
)

In [None]:
# one consecutive-frame sample per batch for both splits
video_train_loader = DataLoader(video_train_dataset, shuffle=True, batch_size=1)
video_val_loader = DataLoader(video_val_dataset, shuffle=True, batch_size=1)

In [None]:
# inspect one training batch
(video_idx, frame_idx, first_frame, previous_frame, current_frame) = next(
    iter(video_train_loader)
)
first_frame[0].shape

In [None]:
def show_frame_triplet(first_frame, previous_frame, current_frame):
    """Plot the first, previous and current frame of a batch side by side.

    Each argument is a collated ``[img, mask, point]`` batch; only batch item 0 is
    drawn, with the mask overlaid at 50% opacity and the points as a red line.
    """
    plt.figure(figsize=(10, 10))
    for col, (img, frame_mask, points) in enumerate(
        (first_frame, previous_frame, current_frame), start=1
    ):
        plt.subplot(1, 3, col)
        plt.imshow(normalize(img[0]).permute(1, 2, 0))
        plt.axis('off')
        plt.imshow(frame_mask[0], alpha=0.5)
        plt.plot(points[0][:, 0], points[0][:, 1], "r")
    plt.show()


# test the video train dataset
video_idx, frame_idx, first_frame, previous_frame, current_frame = next(iter(video_train_loader))
show_frame_triplet(first_frame, previous_frame, current_frame)
# test the video val dataset
video_idx, frame_idx, first_frame, previous_frame, current_frame = next(iter(video_val_loader))
show_frame_triplet(first_frame, previous_frame, current_frame)

In [None]:
class DeformableTransformerExtraDecoderLayer(nn.Module):
    def __init__(
        self,
        d_model=256,
        d_ffn=1024,
        dropout=0.1,
        activation="relu",
        n_levels=4,
        n_heads=8,
        n_points=4,
    ):
        super().__init__()

        # extra cross attention
        self.extra_cross_attn = nn.MultiheadAttention(
            d_model,
            n_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.extra_norm = nn.LayerNorm(d_model)
        self.extra_dropout = nn.Dropout(dropout)

        # cross attention
        self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # self attention
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation)
        self.dropout3 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout4 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, tgt):
        tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout4(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward(
        self,
        tgt,
        query_pos,
        reference_points,
        src,
        src_spatial_shapes,
        level_start_index,
        extra_memory,
        src_padding_mask=None,
    ):
        # self attention
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(
            q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1)
        )[0].transpose(0, 1)
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)

        # extra cross attention
        tgt_extra = self.extra_cross_attn(
            self.with_pos_embed(tgt, query_pos),
            extra_memory,
            extra_memory,
        )[0]
        tgt = tgt + self.extra_dropout(tgt_extra)
        tgt = self.extra_norm(tgt)

        # cross attention
        tgt2 = self.cross_attn(
            self.with_pos_embed(tgt, query_pos),
            reference_points,
            src,
            src_spatial_shapes,
            level_start_index,
            src_padding_mask,
        )
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # ffn
        tgt = self.forward_ffn(tgt)

        return tgt


class DeformableTransformerExtraDecoder(nn.Module):
    """Stack of ``DeformableTransformerExtraDecoderLayer`` clones.

    The reference points are never refined here: every layer attends around
    the same ``reference_points``, scaled by the per-level valid ratios.
    """

    def __init__(self, decoder_layer, num_layers, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.return_intermediate = return_intermediate

    @staticmethod
    def _scale_reference_points(reference_points, src_valid_ratios):
        """Add a feature-level axis to the reference points and scale by valid ratios.

        Args:
            reference_points: points with last dim 2 (x, y) or boxes with last
                dim 4; indexed as (batch, query, coord).
            src_valid_ratios: per-level valid (w, h) ratios, indexed as
                (batch, level, 2).

        Returns:
            Tensor indexed as (batch, query, level, coord).
        """
        if reference_points.shape[-1] == 4:
            ratios = torch.cat([src_valid_ratios, src_valid_ratios], -1)
        else:
            assert reference_points.shape[-1] == 2
            ratios = src_valid_ratios
        return reference_points[:, :, None] * ratios[:, None]

    def forward(
        self,
        tgt,
        reference_points,
        src,
        src_spatial_shapes,
        src_level_start_index,
        src_valid_ratios,
        extra_memory,
        query_pos=None,
        src_padding_mask=None,
    ):
        """Pass ``tgt`` through every decoder layer.

        Returns:
            ``(output, reference_points)``. With ``return_intermediate=True``
            both are stacked over layers, giving a leading layer axis.
        """
        # reference_points do not change between layers, so the scaled input is
        # computed once instead of being rebuilt on every iteration.
        reference_points_input = self._scale_reference_points(
            reference_points, src_valid_ratios
        )

        output = tgt
        intermediate = []
        intermediate_reference_points = []
        for layer in self.layers:
            output = layer(
                output,
                query_pos,
                reference_points_input,
                src,
                src_spatial_shapes,
                src_level_start_index,
                extra_memory,
                src_padding_mask,
            )
            if self.return_intermediate:
                intermediate.append(output)
                intermediate_reference_points.append(reference_points)

        if self.return_intermediate:
            return torch.stack(intermediate), torch.stack(intermediate_reference_points)
        return output, reference_points

In [None]:
# Re-sample one training clip and show its first / previous / current frames
# (image, mask overlay, boundary polygon in red).
video_idx, frame_idx, first_frame, previous_frame, current_frame = next(iter(video_train_loader))
plt.figure(figsize=(10, 10))
for panel, frame in enumerate((first_frame, previous_frame, current_frame), start=1):
    image, mask, points = frame
    plt.subplot(1, 3, panel)
    plt.imshow(normalize(image[0]).permute(1, 2, 0))
    plt.axis('off')
    plt.imshow(mask[0], alpha=0.5)
    plt.plot(points[0][:, 0], points[0][:, 1], "r")
plt.show()

In [None]:
# Split each (image, mask, points) frame triple into named tensors.
(
    (fir_img, fir_mask, fir_point),
    (pre_img, pre_mask, pre_point),
    (cur_img, cur_mask, cur_point),
) = (first_frame, previous_frame, current_frame)

In [None]:
class DeformVideo(nn.Module):
    """Video object boundary tracker on FeatUp features with deformable attention.

    Three frames are encoded by separate deformable encoders: the first frame,
    the previous frame and the current frame. Learned queries decode a memory
    from the first frame (reference points = its boundary) and from the
    previous frame (reference points = its boundary). The two memories are
    concatenated, normalized and position-encoded into ``extra_memory``.

    The current frame is decoded coarse-to-fine. An initial polygon from
    ``get_bounding_box(pre_sgm)`` is shifted by bounded offsets. Then, for each
    of ``up_scale_num`` stages, ``add_mid_points`` is applied to the last
    result and the points are refined again. Each stage doubles the query
    count, starting from 4.

    Args:
        layer_num: number of layers in every encoder / decoder stack.
        up_scale_num: number of refinement stages after the initial one.
        head_num: attention heads used by every attention module.
        medium_level_size: side lengths the FeatUp feature map is resized to;
            each is one feature level, plus the native-resolution map.
        offset_limit: range of a per-stage point offset in pixels; offsets lie
            in (-offset_limit / 2, offset_limit / 2).
        n_points: sampling points per head and level in MSDeformAttn.
        mem_point_num: number of learned queries for the first / previous
            frame memories. NOTE(review): presumably must equal the number of
            points in ``fir_bou`` / ``pre_bou``, since those are used directly
            as reference points for these queries — confirm.
        freeze_backbone: if True, FeatUp parameters get ``requires_grad=False``.

    NOTE: attribute names (including the ``pos_enoc`` spelling) are state_dict
    keys; renaming them breaks loading saved checkpoints. The device is
    hard-coded to CUDA throughout.
    """

    def __init__(
        self,
        layer_num=1,
        up_scale_num=4,
        head_num=6,
        medium_level_size=[14, 28, 56, 112],
        offset_limit=56,
        n_points=4,
        mem_point_num=64,
        freeze_backbone=True,
    ) -> None:
        super(DeformVideo, self).__init__()
        self.up_scale_num = up_scale_num
        self.offset_limit = offset_limit
        self.medium_level_size = medium_level_size
        # FeatUp "dino16" backbone from torch hub (downloaded on first use).
        self.featup = torch.hub.load(
            "mhamilton723/FeatUp",
            "dino16",
            use_norm=True,
        ).cuda()
        if freeze_backbone:
            for param in self.featup.parameters():
                param.requires_grad = False
        # NOTE(review): presumably the channel width of FeatUp dino16 features — confirm.
        d_model = 384
        d_ffn = 1024
        # One level per resized map plus the native-resolution FeatUp map.
        n_levels = len(medium_level_size) + 1
        self.pos_enoc = PositionalEncoding(d_model)

        # Template layers; the encoder/decoder wrappers clone them (presumably
        # via _get_clones, as the extra decoder does), so streams share no weights.
        enc_layer = DeformableTransformerEncoderLayer(
            d_model=d_model,
            d_ffn=d_ffn,
            n_levels=n_levels,
            n_heads=head_num,
            n_points=n_points,
        )
        dec_layer = DeformableTransformerDecoderLayer(
            d_model=d_model,
            d_ffn=d_ffn,
            n_levels=n_levels,
            n_heads=head_num,
            n_points=n_points,
        )

        # First-frame stream: queries, input LayerNorm, encoder, decoder.
        self.first_query_embed = nn.Embedding(mem_point_num, d_model)
        xavier_uniform_(self.first_query_embed.weight)
        self.first_layer_norm = nn.LayerNorm(d_model)
        self.fir_enc = DeformableTransformerEncoder(
            enc_layer,
            num_layers=layer_num,
        )
        self.fir_dec = DeformableTransformerDecoder(
            dec_layer,
            num_layers=layer_num,
        )
        # Previous-frame stream: same structure as the first-frame stream.
        self.previous_query_embed = nn.Embedding(mem_point_num, d_model)
        xavier_uniform_(self.previous_query_embed.weight)
        self.previous_layer_norm = nn.LayerNorm(d_model)
        self.pre_enc = DeformableTransformerEncoder(
            enc_layer,
            num_layers=layer_num,
        )
        self.pre_dec = DeformableTransformerDecoder(
            dec_layer,
            num_layers=layer_num,
        )
        # Normalizes the concatenated first/previous memories.
        self.extra_layer_norm = nn.LayerNorm(d_model)

        # Final number of boundary queries: 4 * 2 ** up_scale_num.
        query_num = 4
        for _ in range(up_scale_num):
            query_num *= 2
        self.current_query_embed = nn.Embedding(query_num, d_model)
        xavier_uniform_(self.current_query_embed.weight)
        self.current_layer_norm = nn.LayerNorm(d_model)
        self.cur_enc = DeformableTransformerEncoder(
            enc_layer,
            num_layers=layer_num,
        )
        extra_dec_layer = DeformableTransformerExtraDecoderLayer(
            d_model=d_model,
            d_ffn=d_ffn,
            n_levels=n_levels,
            n_heads=head_num,
            n_points=n_points,
        )
        extra_dec = DeformableTransformerExtraDecoder(
            extra_dec_layer,
            num_layers=layer_num,
        )
        # One decoder and one (x, y) offset head per refinement stage.
        self.extra_decs = _get_clones(extra_dec, up_scale_num + 1)
        xy_fc = MLP(d_model, d_model, 2, 3)
        self.xy_fcs = _get_clones(xy_fc, up_scale_num + 1)

    def forward(
        self,
        fir_img: torch.Tensor,
        fir_bou: torch.Tensor,
        pre_img: torch.Tensor,
        pre_bou: torch.Tensor,
        pre_sgm: torch.Tensor,
        cur_img: torch.Tensor,
    ):
        """Predict the current frame's boundary from first/previous frame cues.

        Args:
            fir_img, pre_img, cur_img: image batches for the first, previous
                and current frames.
            fir_bou, pre_bou: boundary points in pixel coordinates, (x, y) in
                the last dimension (as plotted elsewhere in this notebook).
            pre_sgm: previous-frame mask, passed to ``get_bounding_box`` for
                the initial polygon.

        NOTE(review): coordinates are normalized by a hard-coded 224, which
        assumes 224x224 inputs.

        Returns:
            In training mode, a list of ``up_scale_num + 1`` boundaries in pixel
            coordinates, coarse to fine. In eval mode, only the last boundary,
            clamped to [0, 223].
        """
        # Normalize boundaries to [0, 1] for use as reference points.
        fir_bou = fir_bou / 224
        pre_bou = pre_bou / 224
        (
            fir_img_srcs_flatten,
            fir_spatial_shapes,
            fir_level_start_index,
            fir_valid_ratios,
        ) = self._get_enced_img_scrs(
            fir_img,
            self.fir_enc,
            self.first_layer_norm,
        )
        (
            pre_img_srcs_flatten,
            pre_spatial_shapes,
            pre_level_start_index,
            pre_valid_ratios,
        ) = self._get_enced_img_scrs(
            pre_img,
            self.pre_enc,
            self.previous_layer_norm,
        )
        (
            cur_img_srcs_flatten,
            cur_spatial_shapes,
            cur_level_start_index,
            cur_valid_ratios,
        ) = self._get_enced_img_scrs(
            cur_img,
            self.cur_enc,
            self.current_layer_norm,
        )

        # B: batch size, C: channel width of the flattened features.
        B, S, C = fir_img_srcs_flatten.shape
        first_queries = (
            self.first_query_embed.weight.unsqueeze(0).repeat(B, 1, 1).cuda()
        )
        first_memory, _ = self.fir_dec(
            first_queries,
            fir_bou,
            fir_img_srcs_flatten,
            fir_spatial_shapes,
            fir_level_start_index,
            fir_valid_ratios,
        )
        previous_queries = (
            self.previous_query_embed.weight.unsqueeze(0).repeat(B, 1, 1).cuda()
        )
        previous_memory, _ = self.pre_dec(
            previous_queries,
            pre_bou,
            pre_img_srcs_flatten,
            pre_spatial_shapes,
            pre_level_start_index,
            pre_valid_ratios,
        )
        # Memory attended to by every current-frame decoder stage.
        extra_memory = torch.cat([first_memory, previous_memory], 1)
        extra_memory = self.extra_layer_norm(extra_memory)
        extra_memory = self.pos_enoc(extra_memory)

        cur_queries = (
            self.current_query_embed.weight.unsqueeze(0).repeat(B, 1, 1).cuda()
        )
        # Initial polygon from the previous mask, normalized to [0, 1].
        init_bou = get_bounding_box(pre_sgm).cuda() / 224
        current_query_num = init_bou.shape[1]
        current_query = cur_queries[:, :current_query_num]
        decode_output, _ = self.extra_decs[0](
            current_query,
            init_bou,
            cur_img_srcs_flatten,
            cur_spatial_shapes,
            cur_level_start_index,
            cur_valid_ratios,
            extra_memory,
        )

        # sigmoid - 0.5 lies in (-0.5, 0.5), so each offset is bounded by
        # offset_limit / 2 pixels (expressed in normalized units).
        xy_offset = (
            (self.xy_fcs[0](decode_output).sigmoid() - 0.5) * self.offset_limit / 224
        )
        init_bou += xy_offset

        results = [init_bou]
        for i in range(self.up_scale_num):
            # Next slice of learned queries; the query count doubles each stage.
            new_query = cur_queries[:, current_query_num : current_query_num * 2]
            current_query_num *= 2

            # Interleave fresh queries (even slots) with the previous stage's
            # decoded features (odd slots).
            # NOTE(review): verify this matches add_mid_points' ordering; if
            # that function keeps old points at even indices, these two
            # assignments should be swapped.
            current_query = torch.zeros((B, current_query_num, C)).to(
                cur_img_srcs_flatten.device
            )
            current_query[:, ::2] = new_query
            current_query[:, 1::2] = decode_output

            cur_bou = add_mid_points(results[-1])
            decode_output, _ = self.extra_decs[i + 1](
                current_query,
                cur_bou,
                cur_img_srcs_flatten,
                cur_spatial_shapes,
                cur_level_start_index,
                cur_valid_ratios,
                extra_memory,
            )

            xy_offset = (
                (self.xy_fcs[i + 1](decode_output).sigmoid() - 0.5)
                * self.offset_limit
                / 224
            )
            cur_bou += xy_offset

            results.append(cur_bou)
        # Back to pixel coordinates.
        results = [result * 224 for result in results]

        if self.training:
            # Every stage is returned so the loss can supervise all of them.
            return results
        else:
            result = results[-1]
            result = result.clamp(0, 223)
            return result

    def _get_img_scrs(self, img: torch.Tensor, layernorm: nn.LayerNorm):
        """Build the flattened multi-scale feature sequence for one image batch.

        FeatUp features are resized to each ``medium_level_size`` and the
        native map is appended as the last level. Padding masks are all False,
        so every valid ratio is 1.

        Returns:
            src_flatten: features flattened to (batch, sum of H*W, channels),
                LayerNorm-ed and position-encoded.
            spatial_shapes: (levels, 2) long tensor of (H, W).
            level_start_index: start offset of each level in src_flatten.
            valid_ratios: (batch, levels, 2) valid (w, h) ratios.
        """
        feats = self.featup(img)
        srcs = []
        padding_masks = []
        for low_res in self.medium_level_size:
            srcs.append(
                F.interpolate(
                    feats,
                    size=(low_res, low_res),
                    mode="bilinear",
                ),
            )
        srcs.append(feats)
        # No padding anywhere: all-False masks.
        for src in srcs:
            padding_masks.append(torch.zeros_like(src[:, 0:1, :, :]).squeeze(1).bool())
        src_flatten = []
        spatial_shapes = []
        for src in srcs:
            src_flatten.append(
                rearrange(src, "b c h w -> b (h w) c"),
            )
            spatial_shapes.append(src.shape[-2:])
        # Exclusive prefix sum of per-level token counts.
        level_start_index = torch.cat(
            (
                torch.tensor([0]),
                torch.cumsum(
                    torch.tensor([x.shape[1] for x in src_flatten]),
                    0,
                )[:-1],
            )
        ).cuda()
        src_flatten = torch.cat(src_flatten, 1).cuda()
        valid_ratios = torch.stack(
            [get_valid_ratio(mask) for mask in padding_masks],
            1,
        ).cuda()
        spatial_shapes = torch.as_tensor(
            spatial_shapes,
            dtype=torch.long,
            device=src_flatten.device,
        )
        src_flatten = layernorm(src_flatten)
        src_flatten = self.pos_enoc(src_flatten)
        return src_flatten, spatial_shapes, level_start_index, valid_ratios

    def _get_enced_img_scrs(
        self,
        img: torch.Tensor,
        encoder: DeformableTransformerEncoder,
        layernorm: nn.LayerNorm,
    ):
        """Build multi-scale features for ``img`` and run them through ``encoder``.

        Returns the same 4-tuple as ``_get_img_scrs``, with ``src_flatten``
        replaced by the encoder output.
        """
        src_flatten, spatial_shapes, level_start_index, valid_ratios = (
            self._get_img_scrs(img, layernorm)
        )
        src_flatten = encoder(
            src=src_flatten,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
        )
        return src_flatten, spatial_shapes, level_start_index, valid_ratios


# Smoke test: one eval-mode forward pass on the sample unpacked above.
model = DeformVideo().cuda()
model.eval()
# Pure inference: no_grad skips building the autograd graph and its GPU memory.
with torch.no_grad():
    results = model(
        fir_img.cuda(),
        fir_point.cuda(),
        pre_img.cuda(),
        pre_point.cuda(),
        pre_mask.cuda(),
        cur_img.cuda(),
    )

In [None]:
# Smoke test in training mode: the model returns one boundary per refinement stage.
model.train()
sample_inputs = (fir_img, fir_point, pre_img, pre_point, pre_mask, cur_img)
results = model(*(tensor.cuda() for tensor in sample_inputs))

In [None]:
# Shape of the last (finest) stage's boundary prediction.
results[-1].shape

In [None]:
model = DeformVideo().cuda()
loss_dict = {}
iou_test_dict = {}
optimizer = optim.Adam(model.parameters(), lr=1e-4)
epoch_num = 20

In [None]:
for e in range(epoch_num):
    mean_loss = 0
    train_mean_iou = 0
    for video_idx, frame_idx, first_frame, previous_frame, current_frame in tqdm(
        video_test_loader
    ):
        model.train()
        optimizer.zero_grad()
        fir_img, fir_mask, fir_point = first_frame
        pre_img, pre_mask, pre_point = previous_frame
        cur_img, cur_mask, cur_point = current_frame
        results = model(
            fir_img.cuda(),
            fir_point.cuda(),
            pre_img.cuda(),
            pre_point.cuda(),
            pre_mask.cuda(),
            cur_img.cuda(),
        )
        loss = 0
        for result in results:
            loss += ras_loss(result, cur_mask.cuda())
        loss /= len(results)
        loss.backward()
        optimizer.step()
        mean_loss += loss.item()
        iou = get_batch_average_bou_iou(results[-1], cur_mask.cuda(), gt_rasterizer)
        train_mean_iou += iou.item()
    mean_loss /= len(video_test_loader)
    train_mean_iou /= len(video_test_loader)
    loss_dict[e] = mean_loss
    iou_train_dict[e] = train_mean_iou
    print(f"Epoch {e} train loss: {mean_loss:.4f}, iou: {train_mean_iou:.4f}")

In [None]:
# check the results
model.eval()
video_idx, frame_idx, first_frame, previous_frame, current_frame = next(
    iter(video_test_loader)
)
fir_img, fir_mask, fir_point = first_frame
pre_img, pre_mask, pre_point = previous_frame
cur_img, cur_mask, cur_point = current_frame
result = model(
    fir_img.cuda(),
    fir_point.cuda(),
    pre_img.cuda(),
    pre_point.cuda(),
    pre_mask.cuda(),
    cur_img.cuda(),
)
np_result = result[0].detach().cpu().numpy()
plt.figure(figsize=(10, 10))
plt.subplot(1, 2, 1)
plt.imshow(normalize(cur_img[0]).permute(1, 2, 0))
plt.imshow(cur_mask[0], alpha=0.5)
plt.plot(np_result[:, 0], np_result[:, 1], "r")
plt.scatter(np_result[:, 0], np_result[:, 1], c="r", s=5)
plt.axis("off")
plt.subplot(1, 2, 2)
plt.imshow(cur_mask[0])
plt.axis("off")
plt.show()

In [None]:
class VideoInferer:
    """Run a model frame-by-frame over whole videos and collect boundary IoUs.

    Inference is autoregressive. Frame 0 is the ground-truth "first" frame and
    also the initial "previous" frame. For every later frame, the previous
    frame's *predicted* boundary and rasterized mask are fed back in.
    """

    def __init__(
        self,
        dataset: "DAVIS_withPoint",
        gt_rasterizer: "SoftPolygon",
    ) -> None:
        self.data_set = dataset.raw_data_set
        self.gt_rasterizer = gt_rasterizer

    def infer_one_video(self, video_idx: int, model: nn.Module):
        """Predict boundaries for every frame of one video after the first.

        Returns:
            A list aligned with the video's frames. Entry 0 is None, because
            the first frame is given as ground truth. Every other entry is
            ``(pred_bou, pred_mask, iou)``; the tensors are CPU copies so that
            results for many videos do not pile up in GPU memory.
        """
        infer_results = []
        video_data = self.data_set[video_idx]
        model.eval()
        fir_img, fir_mask, fir_point = video_data[0]
        pre_img, pre_mask, pre_point = video_data[0]
        # Add the batch dimension the model expects.
        fir_img = fir_img.unsqueeze(0)
        pre_img = pre_img.unsqueeze(0)
        fir_mask = fir_mask.unsqueeze(0)
        pre_mask = pre_mask.unsqueeze(0)
        fir_point = fir_point.unsqueeze(0)
        pre_point = pre_point.unsqueeze(0)
        infer_results.append(None)
        with torch.no_grad():
            for i in range(1, len(video_data)):
                cur_img, cur_mask, _ = video_data[i]
                cur_img = cur_img.unsqueeze(0)
                cur_mask = cur_mask.unsqueeze(0)
                pred_bou = model(
                    fir_img.cuda(),
                    fir_point.cuda(),
                    pre_img.cuda(),
                    pre_point.cuda(),
                    pre_mask.cuda(),
                    cur_img.cuda(),
                )
                pred_mask = self.gt_rasterizer(pred_bou, 224, 224)
                pred_mask[pred_mask == -1] = 0
                iou = get_batch_average_bou_iou(
                    pred_bou, cur_mask.cuda(), self.gt_rasterizer
                )
                # Store CPU copies: keeping every frame's mask on the GPU for
                # every video grows GPU memory with the size of the dataset.
                infer_results.append((pred_bou.cpu(), pred_mask.cpu(), iou.item()))
                # This frame's prediction becomes the next "previous" frame.
                pre_img, pre_mask, pre_point = cur_img, pred_mask, pred_bou
        return infer_results

    def infer_all_videos(self, model: nn.Module):
        """Run ``infer_one_video`` on every video; results go to ``self.infer_results``."""
        self.infer_results = []
        for video_idx in tqdm(range(len(self.data_set))):
            infer_results = self.infer_one_video(video_idx, model)
            self.infer_results.append(infer_results)

    def compute_video_iou(self, video_idx: int):
        """Mean IoU over the predicted frames of one video.

        Returns NaN (without a warning) for a video that has no predicted
        frames, i.e. only a first frame.
        """
        infer_results = self.infer_results[video_idx]
        ious = [result[-1] for result in infer_results[1:]]
        if not ious:
            return np.nan
        return np.mean(ious)

    def compute_all_videos_iou(self):
        """Compute per-video IoUs into ``self.video_ious`` and return their mean.

        Videos without predicted frames (NaN) are ignored, so one single-frame
        video cannot turn the whole average into NaN.
        """
        self.video_ious = [
            self.compute_video_iou(video_idx) for video_idx in range(len(self.data_set))
        ]
        return np.nanmean(self.video_ious)

    def show_video_results(
        self,
        video_idx: int,
        mask_alpha=0.2,
        img_per_line=5,
    ):
        """Plot every frame of a video with its boundary.

        Frame 0 shows the ground-truth boundary; the other frames show the
        predicted boundary with its IoU in the title. The ground-truth mask is
        overlaid with ``mask_alpha`` opacity.
        """
        video_data = self.data_set[video_idx]
        pred_results = self.infer_results[video_idx]
        frame_num = len(video_data)
        # Ceil division: exactly enough rows, without an empty trailing row.
        line_num = max(1, -(-frame_num // img_per_line))
        plt.figure(figsize=(20, 4 * line_num))
        for i, pred_data in enumerate(pred_results):
            plt.subplot(line_num, img_per_line, i + 1)
            cur_img, cur_mask, cur_point = video_data[i]
            plt.imshow(normalize(cur_img).permute(1, 2, 0))
            plt.imshow(cur_mask, alpha=mask_alpha)
            plt.axis("off")
            if pred_data is None:
                plt.title("ground truth")
                points = cur_point
            else:
                pred_bou, _, iou = pred_data
                plt.title(f"iou: {iou:.4f}")
                points = pred_bou[0].detach().cpu().numpy()
            plt.plot(points[:, 0], points[:, 1], "r")
            plt.scatter(points[:, 0], points[:, 1], c="r", s=5)
        # Render now so that calling this in a loop does not accumulate open figures.
        plt.show()

# video_inferer = VideoInferer(video_test_dataset, gt_rasterizer)
# video_inferer.infer_all_videos(model)

In [None]:
# Small training subset (first 5 raw entries) for quick experiments.
# copy.copy is shallow: only `data_set` is replaced on the copy.
short_train_rawset = copy.copy(train_rawset)
short_train_rawset.data_set = train_rawset.data_set[:5]
short_video_train_dataset = DAVIS_withPoint(
    short_train_rawset,
    point_num=64,
    is_train=True,
)
len(short_video_train_dataset.raw_data_set)

In [None]:
# Small validation subset (first 2 raw entries), built the same way.
short_val_rawset = copy.copy(val_rawset)
short_val_rawset.data_set = val_rawset.data_set[:2]
short_video_val_dataset = DAVIS_withPoint(
    short_val_rawset,
    point_num=64,
    is_train=False,
)
# short_video_val_dataset.raw_data_set = short_video_val_dataset.raw_data_set[6:9]
# len(short_video_val_dataset.raw_data_set)

In [None]:
# Sizes of the full and short datasets.
len(video_train_dataset), len(video_val_dataset), len(short_video_train_dataset), len(short_video_val_dataset)

In [None]:
# Validate the short run on the held-out short *validation* split; the training
# loop reports this inferer's result as "val iou".
val_inferer = VideoInferer(short_video_val_dataset, gt_rasterizer)
short_train_loader = DataLoader(short_video_train_dataset, batch_size=1, shuffle=True)

In [None]:
# Fresh model, optimizer and metric logs for the short run; evaluate every
# `eval_period` epochs and after the last one.
model = DeformVideo().cuda()
loss_dict, iou_train_dict, iou_val_dict = {}, {}, {}
optimizer = optim.Adam(model.parameters(), lr=1e-4)
epoch_num, eval_period = 9, 3

In [None]:
# Short training run with periodic whole-video evaluation via `val_inferer`.
for e in range(epoch_num):
    mean_loss = 0
    train_mean_iou = 0
    for video_idx, frame_idx, first_frame, previous_frame, current_frame in tqdm(
        short_train_loader
    ):
        # Reset train mode every step: val_inferer.infer_one_video switches to eval.
        model.train()
        optimizer.zero_grad()
        fir_img, fir_mask, fir_point = first_frame
        pre_img, pre_mask, pre_point = previous_frame
        cur_img, cur_mask, cur_point = current_frame
        results = model(
            fir_img.cuda(),
            fir_point.cuda(),
            pre_img.cuda(),
            pre_point.cuda(),
            pre_mask.cuda(),
            cur_img.cuda(),
        )
        # Average the rasterization loss over every refinement stage.
        loss = 0
        for result in results:
            loss += ras_loss(result, cur_mask.cuda())
        loss /= len(results)
        loss.backward()
        optimizer.step()
        mean_loss += loss.item()
        # IoU is measured on the finest stage only.
        iou = get_batch_average_bou_iou(results[-1], cur_mask.cuda(), gt_rasterizer)
        train_mean_iou += iou.item()
    mean_loss /= len(short_train_loader)
    train_mean_iou /= len(short_train_loader)
    loss_dict[e] = mean_loss
    iou_train_dict[e] = train_mean_iou
    print(f"Epoch {e} train loss: {mean_loss:.4f}, iou: {train_mean_iou:.4f}")
    # Evaluate every eval_period epochs and always after the final epoch.
    if e % eval_period == 0 or e == epoch_num - 1:
        val_inferer.infer_all_videos(model)
        val_iou = val_inferer.compute_all_videos_iou()
        iou_val_dict[e] = val_iou
        print(f"Epoch {e} val iou: {val_iou:.4f}")

In [None]:
# Re-run whole-video inference with the final weights.
val_inferer.infer_all_videos(model)

In [None]:
# Mean IoU over videos.
val_inferer.compute_all_videos_iou()

In [None]:
# Per-video mean IoU (filled by compute_all_videos_iou).
val_inferer.video_ious

In [None]:
# Frame-by-frame predictions for video 1.
val_inferer.show_video_results(1)

### Load the model and check the results

In [None]:
# Rebuild the training curves of the "deform_video" run from its text log.
model_name = "deform_video"
log_dir = f"./log/{model_name}"
log_path = f"{log_dir}/{model_name}.log"
model_path = f"./model/{model_name}_best.pth"
with open(log_path, "r") as f:
    logs = f.readlines()
loss_dict = {}
iou_train_dict = {}
iou_val_dict = {}
# Each metric is read from any log line that also contains "Epoch <n>".
metric_patterns = (
    (r"train loss: ([\d.]+)", loss_dict),
    (r"train iou: ([\d.]+)", iou_train_dict),
    (r"val iou: ([\d.]+)", iou_val_dict),
)
for log in logs:
    epoch_match = re.search(r"Epoch (\d+)", log)
    if epoch_match is None:
        continue
    epoch = int(epoch_match.group(1))
    for pattern, target in metric_patterns:
        metric_match = re.search(pattern, log)
        if metric_match is not None:
            target[epoch] = float(metric_match.group(1))
# Top: training loss. Bottom: train and val IoU.
plt.figure(figsize=(10, 10))
plt.subplot(2, 1, 1)
plt.grid()
plt.plot(loss_dict.keys(), loss_dict.values())
plt.xlabel("epoch")
plt.ylabel("loss")
plt.title("train loss")
plt.subplot(2, 1, 2)
plt.grid()
plt.plot(iou_train_dict.keys(), iou_train_dict.values(), label="train iou")
plt.plot(iou_val_dict.keys(), iou_val_dict.values(), label="val iou")
plt.xlabel("epoch")
plt.ylabel("iou")
plt.title("iou")
plt.legend()
plt.show()


In [None]:
# Plot the training curves of the "def_davis_std" run. Metrics come from the
# JSON dumps stored in the run's log directory.
model_name = "def_davis_std"
log_dir = f"./log/{model_name}"
log_path = f"{log_dir}/{model_name}.log"
model_path = f"./model/{model_name}_best.pth"


def _load_epoch_metric(metric_dir, name):
    """Load ``{metric_dir}/{name}.json`` as an epoch-sorted ``{int: value}`` dict.

    JSON object keys are always strings. Converting them back to ``int`` lets
    matplotlib draw a numeric epoch axis instead of a categorical one, and
    sorting keeps the curve in epoch order. Assumes every key is an integer
    epoch number.
    """
    with open(f"{metric_dir}/{name}.json", "r") as f:
        raw = json.load(f)
    return {int(epoch): raw[epoch] for epoch in sorted(raw, key=int)}


dif_loss_dict = _load_epoch_metric(log_dir, "dif_loss")
std_loss_dict = _load_epoch_metric(log_dir, "std_loss")
iou_train_dict = _load_epoch_metric(log_dir, "iou_train")
iou_val_dict = _load_epoch_metric(log_dir, "iou_val")
# Top: the two training losses. Bottom: train and val IoU.
plt.figure(figsize=(10, 10))
plt.subplot(2, 1, 1)
plt.grid()
plt.plot(list(dif_loss_dict.keys()), list(dif_loss_dict.values()), label="dif loss")
plt.plot(list(std_loss_dict.keys()), list(std_loss_dict.values()), label="std loss")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.title("train loss")
plt.subplot(2, 1, 2)
plt.grid()
plt.plot(list(iou_train_dict.keys()), list(iou_train_dict.values()), label="train iou")
plt.plot(list(iou_val_dict.keys()), list(iou_val_dict.values()), label="val iou")
plt.xlabel("epoch")
plt.ylabel("iou")
plt.title("iou")
plt.legend()
plt.show()

In [None]:
# Rebuild the architecture, then restore the best checkpoint of `model_name`.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model = DeformVideo().cuda()
checkpoint_state = torch.load(model_path)
model.load_state_dict(checkpoint_state)

In [None]:
# Whole-video inferers for the full train and val splits.
train_inferer = VideoInferer(video_train_dataset, gt_rasterizer)
val_inferer = VideoInferer(video_val_dataset, gt_rasterizer)

In [None]:
# Autoregressive inference over every video (slow: one forward pass per frame).
train_inferer.infer_all_videos(model)
val_inferer.infer_all_videos(model)

In [None]:
# Mean per-video IoU: (train, val).
train_inferer.compute_all_videos_iou(), val_inferer.compute_all_videos_iou()

In [None]:
# print all the ious (per-video mean IoU, filled by compute_all_videos_iou)
for i, iou in enumerate(train_inferer.video_ious):
    print(f"train video {i} iou: {iou:.4f}")
for i, iou in enumerate(val_inferer.video_ious):
    print(f"val video {i} iou: {iou:.4f}")

In [None]:
# Distribution of per-video mean IoU for each split. Fixing the histogram
# range to [0, 1] makes the 10 bins line up with the 0.1-spaced ticks and puts
# train and val on the same scale (by default the bins span only the data's
# min..max).
for split, video_ious in (
    ("train", train_inferer.video_ious),
    ("val", val_inferer.video_ious),
):
    plt.figure(figsize=(10, 5))
    plt.hist(video_ious, bins=10, range=(0, 1))
    plt.title(f"{split} iou distribution")
    plt.xlabel("mean iou per video")
    plt.ylabel("number of videos")
    plt.xticks(np.arange(0, 1.1, 0.1))
    plt.show()

In [None]:
# Hand-picked train videos to inspect frame by frame.
train_check_idxs = [0, 100, 46, 55, 96, 1, 41, 21, 28, 51, 76, 93]
for idx in train_check_idxs:
    train_inferer.show_video_results(idx)

In [None]:
# Hand-picked val videos to inspect frame by frame.
val_check_idxs = [2, 7, 9, 10, 18, 24, 13, 41, 46, 50, 54, 39]
for idx in val_check_idxs:
    val_inferer.show_video_results(idx)