```python
103: # for image
104: _visual_mask = torch.zeros((batch_size, visual_token_num), dtype=torch.float32, device=device)
105: # need to mask token content in selected_idx for prediction/generation
106: num_masks = random.randint(max(1, int(0.1 * visual_token_num)), visual_token_num)
107: selected_idx = random.sample(range(visual_token_num), num_masks)
108: _visual_mask[:, selected_idx] = 1
109: mask_position = (_visual_mask == 1).to(torch.long).view(-1)
110: mask_position = mask_position.nonzero().squeeze()
```
In `train.py` (lines 103-110), I think `_visual_mask = 1` means the model can see a token and `_visual_mask = 0` means it cannot. The code above randomly samples mask positions, which selects the positions in the 8×8 grid that the model can see (`_visual_mask = 1`). The positions that actually need to be masked are those where `_visual_mask` equals 0, so line 109 should be changed to:

```python
mask_position = (_visual_mask == 0).to(torch.long).view(-1)
```

Is this right?
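To make the two readings concrete, here is a minimal pure-Python sketch that mirrors lines 104-110 on a toy batch. Assumptions: plain lists stand in for the torch tensors, a toy size of 2 samples × 8 tokens replaces the real 8×8 grid, and `build_visual_mask` and `mask_positions` are hypothetical helpers, not functions in `train.py`. It shows that `== 1` picks exactly the sampled `selected_idx` (repeated for each sample after flattening), while `== 0` picks the complement.

```python
import random


def build_visual_mask(batch_size, visual_token_num, rng):
    """Pure-Python mirror of lines 104-108 (hypothetical helper): one shared
    set of sampled token indices is set to 1 for every sample in the batch."""
    num_masks = rng.randint(max(1, int(0.1 * visual_token_num)), visual_token_num)
    selected_idx = rng.sample(range(visual_token_num), num_masks)
    visual_mask = [[0] * visual_token_num for _ in range(batch_size)]
    for row in visual_mask:
        for j in selected_idx:
            row[j] = 1
    return visual_mask, sorted(selected_idx)


def mask_positions(visual_mask, value):
    """Mirror of lines 109-110 (hypothetical helper): flatten the
    (batch, tokens) mask row-major like .view(-1), then return the flat
    indices whose entry equals `value`, like .nonzero().squeeze()."""
    flat = [v for row in visual_mask for v in row]
    return [i for i, v in enumerate(flat) if v == value]


rng = random.Random(0)
visual_mask, selected_idx = build_visual_mask(batch_size=2, visual_token_num=8, rng=rng)

ones = mask_positions(visual_mask, 1)   # current line 109: _visual_mask == 1
zeros = mask_positions(visual_mask, 0)  # proposed change: _visual_mask == 0
print("selected_idx:", selected_idx)
print("== 1 ->", ones)
print("== 0 ->", zeros)
```

Only the comparison value differs between the two calls, so the choice at line 109 decides whether the sampled positions or their complement are treated as masked. Which set is correct depends on how `mask_position` is used after line 110.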