In [1]:
from train_ssl_medsam_backbone import get_dataloaders, ModelFactory

# Unpack the three objects returned by get_dataloaders() (called with its defaults) —
# presumably the train / validation / test loaders; confirm in train_ssl_medsam_backbone.
train, val, test = get_dataloaders()

In [2]:
# Take the first element of the first batch yielded by `train` and move it to the GPU.
# NOTE(review): a new iterator is built on every run of this cell, so which batch comes
# back depends on the loader's shuffling/seed — not reproducible as written.
image = next(iter(train))[0]
image = image.cuda()

In [3]:
from medAI.utils.masking_generator import MaskingGenerator
# Notebook-level mask generator over a (64, 64) grid; the second argument evaluates to
# int(64 * 64 * 0.3) == 1228.
# NOTE(review): presumably that is the target number of masked cells and
# min/max_num_patches bound the size of each masked block — confirm against MaskingGenerator.
mask_gen = MaskingGenerator((64, 64), int(64 * 64 * 0.3), min_num_patches=16, max_num_patches=100)
import torch 

def generate_masks(image):
    """Build one boolean mask per entry along ``image``'s first dimension.

    The notebook-level ``mask_gen`` is called once per entry; each NumPy result is
    converted to a bool tensor, moved to ``image.device``, and the results are
    stacked along a new leading dimension.
    """
    batch_size = image.shape[0]
    per_sample_masks = [
        torch.from_numpy(mask_gen()).bool().to(image.device)
        for _ in range(batch_size)
    ]
    return torch.stack(per_sample_masks)

# Shape check: one mask per image in the batch (saved output: [4, 64, 64]).
generate_masks(image).shape

torch.Size([4, 64, 64])

In [4]:
import torch
from torch import nn
from segment_anything.modeling.common import LayerNorm2d
from medAI.utils.masking_generator import MaskingGenerator
from medAI.modeling.swav import sinkhorn_knopp
from copy import deepcopy   


@torch.no_grad()
def do_ema_update(teacher, student, alpha=0.999):
    """Blend the student's parameters into the teacher as an exponential moving average.

    Every teacher parameter is overwritten in place with
    ``alpha * teacher + (1 - alpha) * student``.

    Args:
        teacher: module whose parameters are updated in place.
        student: module supplying the new values; it is left unchanged.
        alpha: fraction of the teacher's current value that is kept.

    Parameters are paired positionally with ``zip``, so both modules must yield the
    same parameters in the same order (as a ``deepcopy`` does). Buffers are not updated.
    """
    for teacher_param, student_param in zip(teacher.parameters(), student.parameters()):
        # Use the keyword form add_(tensor, alpha=...): the positional
        # add_(scalar, tensor) overload is deprecated in PyTorch.
        teacher_param.mul_(alpha).add_(student_param.detach(), alpha=1 - alpha)


class IBotStyleModel(nn.Module):
    """Student/teacher masked-token self-distillation wrapper.

    A teacher (a deep copy of the student) scores the unmasked image, a student scores
    the same image with some patch positions replaced by a mask token, and the loss is
    the cross-entropy between the two at the masked positions. The teacher scores are
    turned into target distributions with ``sinkhorn_knopp``.

    Args:
        encoder_transformer_dim: forwarded to ``MaskableMedSAMWithProjection``.
        proj_dim: forwarded to ``MaskableMedSAMWithProjection``.
        num_classes: number of score channels (passed as ``ntokens``).
        feature_map_size: side length of the square grid the masks are drawn on.
        min_num_patches, max_num_patches: forwarded to ``MaskingGenerator``.
        mask_ratio: fraction of grid cells passed to ``MaskingGenerator`` as its
            second argument (``int(size * size * mask_ratio)``).
        ema_alpha: decay used by ``ema_update``.
        lambda_: factor the teacher scores are multiplied by before ``exp``.
    """

    def __init__(
        self,
        encoder_transformer_dim=768,
        proj_dim=512,
        num_classes=1024,
        feature_map_size=64,
        min_num_patches=16,
        max_num_patches=100,
        mask_ratio=0.3,
        ema_alpha=0.999,
        lambda_=20,
    ):
        super().__init__()
        self.mask_gen = MaskingGenerator(
            (feature_map_size, feature_map_size),
            int(feature_map_size * feature_map_size * mask_ratio),
            min_num_patches=min_num_patches,
            max_num_patches=max_num_patches,
        )
        self.student = MaskableMedSAMWithProjection(encoder_transformer_dim, proj_dim, num_classes)
        self.teacher = deepcopy(self.student)
        # The teacher is only ever changed through ema_update(); freezing it keeps its
        # weights out of gradient-based training (forward() already runs it under no_grad).
        for teacher_param in self.teacher.parameters():
            teacher_param.requires_grad_(False)
        self.ema_alpha = ema_alpha
        self.lambda_ = lambda_

    def generate_masks(self, image):
        """Return a stacked bool tensor with one generated mask per batch entry of ``image``."""
        masks = [
            torch.from_numpy(self.mask_gen()).bool().to(image.device)
            for _ in range(image.shape[0])
        ]
        return torch.stack(masks)

    def forward(self, image):
        """Compute the masked-token distillation loss for one batch.

        Returns a scalar tensor: the mean over masked positions of the cross-entropy
        between the teacher's target distribution and the student's log-softmax scores.
        """
        mask = self.generate_masks(image)
        with torch.no_grad():
            # Teacher sees the unmasked image. Scores are moved to channels-last so the
            # (B, H, W) boolean mask selects one score vector per masked position.
            target_token_scores = self.teacher(image, mask=None)
            target_token_scores = target_token_scores.permute(0, 2, 3, 1)
            target_token_scores = target_token_scores[mask]

            # compute target distributions - sinkhorn_knopp centering
            n_targets, n_sources = target_token_scores.shape
            # Create the marginals on the same device as the scores instead of
            # hard-coding .cuda(), so CPU and non-default GPUs work too.
            device = target_token_scores.device
            target_row_sum = torch.ones((n_targets, 1), device=device)
            target_col_sum = torch.ones((n_sources, 1), device=device) * n_targets / n_sources
            K = torch.exp(target_token_scores * self.lambda_)
            target_dist = sinkhorn_knopp(K, target_row_sum, target_col_sum, n_iters=3)

        # Student sees the image with the masked positions replaced by its mask token.
        student_token_scores = self.student(image, mask=mask)
        student_token_scores = student_token_scores.permute(0, 2, 3, 1)
        student_token_scores = student_token_scores[mask]

        loss = torch.sum(-target_dist * torch.log_softmax(student_token_scores, dim=-1), dim=-1).mean()

        return loss

    def ema_update(self):
        """Move the teacher's parameters towards the student's using ``ema_alpha``."""
        do_ema_update(self.teacher, self.student, self.ema_alpha)

    @property
    def image_encoder(self):
        """The student's encoder module."""
        return self.student.encoder


class MaskableMedSAMWithProjection(nn.Module):
    """MedSAM image encoder with optional patch masking and a per-position score head.

    The encoder's patch embeddings can be overwritten at selected grid positions with
    a learned mask token before the transformer blocks run. The encoder neck output
    is then projected to ``ntokens`` score channels.

    Args:
        encoder_transformer_dim: length of the learned mask token; it must match the
            channel size of the patch embeddings.
        proj_dim: hidden channel count of the projection head.
        ntokens: number of output score channels.
    """

    def __init__(self, encoder_transformer_dim=768, proj_dim=512, ntokens=1024):
        super().__init__()
        self.encoder = ModelFactory._medsam_image_encoder()

        # 256 is the channel count expected from the encoder neck.
        # proj_dim drives both 1x1 convs, so non-default values build a consistent head.
        self.proj = nn.Sequential(
            nn.Conv2d(256, proj_dim, 1),
            LayerNorm2d(proj_dim),
            nn.Conv2d(proj_dim, ntokens, 1),
        )

        self.mask_token = torch.nn.Parameter(torch.randn(encoder_transformer_dim))

    def forward(self, image, mask=None):
        """Encode ``image`` and return per-position scores shaped (B, ntokens, H, W).

        Args:
            image: input batch passed to the encoder's patch embedding.
            mask: optional boolean tensor over the (B, H, W) patch grid; positions set
                to True have their embedding replaced by the mask token.
        """
        embed = self.encoder.patch_embed(image)  # B, H, W, C (channels last)

        if mask is not None:
            # Out-of-place replacement: leaves the patch-embedding output untouched, and
            # the cast keeps the token in the embeddings' dtype (e.g. under autocast).
            mask_token = self.mask_token.to(embed.dtype)
            embed = torch.where(mask.unsqueeze(-1), mask_token, embed)

        # do the rest of the forward pass
        x = embed

        if self.encoder.pos_embed is not None:
            x = x + self.encoder.pos_embed

        for blk in self.encoder.blocks:
            x = blk(x)

        # The neck and the projection head are convolutional, so switch to channels-first.
        x = self.encoder.neck(x.permute(0, 3, 1, 2))
        x = self.proj(x)
        return x


# Smoke test: build the model with its default hyper-parameters on the GPU and run a
# single forward pass, which returns the scalar loss.
# NOTE(review): masks are drawn randomly inside forward(), so the loss value varies
# between runs unless the mask generator's RNG is seeded.
model = IBotStyleModel().cuda()
loss = model(image)

In [5]:
# Display the loss tensor produced by the forward pass.
loss

tensor(6.8620, device='cuda:0', grad_fn=<MeanBackward0>)

In [5]:
# NOTE(review): `out_scores` is not assigned in any visible cell — this only works with
# leftover kernel state and will fail under Restart & Run All.
out_scores.shape

torch.Size([4908, 1024])

In [13]:
# Scratch run of the target-distribution step on `out_scores`: uniform row marginals,
# column marginals of n_targets / n_sources, and an exp(score * 20) kernel.
# NOTE(review): `out_scores` is not assigned in any visible cell (hidden kernel state).
# NOTE(review): the marginals are created with a hard-coded .cuda().
n_targets, n_sources = out_scores.shape
target_row_sum = torch.ones((n_targets, 1)).cuda()
target_col_sum = torch.ones((n_sources, 1)).cuda() * n_targets / n_sources

K = torch.exp(out_scores * 20)

# NOTE(review): last_norm='row' is passed here but not in IBotStyleModel.forward —
# confirm which variant is intended.
out = sinkhorn_knopp(K, target_row_sum, target_col_sum, n_iters=3, last_norm='row')

In [14]:
# Sum each row over the last dimension; the saved output shows 1.0000 for the displayed rows.
out.sum(-1)

tensor([1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000], device='cuda:0',
       grad_fn=<SumBackward1>)

In [12]:
import matplotlib.pyplot as plt
# Show the Sinkhorn output as an image, scaled by 10000.
# NOTE(review): the saved output of this cell is a NameError for `out` — its execution
# count (12) precedes the cell that assigns `out` (13), so it was run out of order.
plt.imshow(out.detach().cpu().numpy() * 10000)

NameError: name 'out' is not defined

In [15]:
# Element-wise comparison of the first state_dict entry of the student and the teacher.
# NOTE(review): the saved output contains False entries, so the two had already diverged
# when this ran; the step that changed them is not visible in this notebook.
next(iter(model.student.state_dict().values())) == next(iter(model.teacher.state_dict().values()))

tensor([ True,  True,  True,  True,  True,  True, False,  True,  True,  True,
         True,  True,  True,  True, False,  True,  True,  True,  True,  True,
        False,  True, False,  True,  True,  True,  True,  True, False,  True,
         True, False,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True, False,  True,
         True,  True,  True,  True, False,  True,  True, False,  True,  True,
         True,  True,  True,  True,  True, False,  True,  True,  True,  True,
         True,  True, False,  True,  True,  True, False,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True, False,
         True,  True,  True,  True,  True,  True, False, False,  True,  True,
         True,  True,  True,  True, False,  True,  True,  True,  True,  True,
         True,  True, False, False,  True, False,  True,  True,  True,  True,
        False, False, False,  True,  True,  True,  True,  True, 

In [7]:
# NOTE(review): `_` is IPython's "last output" variable, so this depends on whichever
# cell ran immediately before it — bind the tensor to a name instead.
_.shape

torch.Size([4, 1024, 64, 64])

In [18]:
from copy import deepcopy

ImageEncoderViT(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  )
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (proj): Linear(in_features=768, out_features=768, bias=True)
      )
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): MLPBlock(
        (lin1): Linear(in_features=768, out_features=3072, bias=True)
        (lin2): Linear(in_features=3072, out_features=768, bias=True)
        (act): GELU(approximate='none')
      )
    )
  )
  (neck): Sequential(
    (0): Conv2d(768, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): LayerNorm2d()
    (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (3): LayerNorm2d()
  )
)

In [17]:
# Shape of the patch embeddings; the saved output [4, 64, 64, 768] is channels-last.
# NOTE(review): `encoder` is not assigned in any visible cell (hidden kernel state).
encoder.patch_embed(image).shape

torch.Size([4, 64, 64, 768])

torch.Size([4, 3, 1024, 1024])

In [None]:
from mas

In [10]:
# NOTE(review): this dumps the whole output tensor into the notebook; prefer `.shape`
# or a small slice. `encoder` is also not assigned in any visible cell.
encoder(image)

tensor([[[[-2.2027e-02, -8.9647e-03, -1.7505e-02,  ..., -2.8102e-02,
           -2.8458e-02, -2.1480e-02],
          [-1.4018e-02, -1.3767e-02, -2.2677e-02,  ..., -1.2591e-02,
           -1.0554e-02, -6.5527e-03],
          [-8.7758e-03, -2.1652e-02, -1.2533e-02,  ..., -2.1075e-02,
           -6.1827e-03, -1.4495e-02],
          ...,
          [-1.2857e-02, -1.1874e-02, -1.3135e-02,  ..., -1.6245e-02,
           -9.1269e-03, -1.9114e-02],
          [-6.8033e-03, -9.6007e-04, -1.0340e-02,  ..., -8.4152e-03,
           -8.4766e-03, -1.1661e-02],
          [-1.8129e-02, -1.1930e-02, -1.7203e-02,  ..., -1.8661e-02,
           -2.0748e-02, -2.2367e-02]],

         [[-1.1878e-01, -5.9102e-02, -7.6664e-02,  ...,  9.4831e-03,
           -9.0445e-02,  6.7099e-03],
          [-8.2192e-02, -1.0537e-01, -4.9334e-02,  ..., -1.1747e-02,
           -4.1459e-02, -7.9624e-02],
          [-4.2098e-02, -8.5422e-02, -1.2441e-01,  ..., -5.0162e-03,
           -1.1572e-01,  2.7252e-02],
          ...,
     

In [6]:
from tqdm import tqdm 
import time 

# Progress-bar check: 100 iterations sleeping 1.5 s each (about 150 s in total), with
# tqdm's mininterval set to 10.
# NOTE(review): unrelated to the model code and slow — drop it or shorten the sleep
# before sharing the notebook.
for _ in tqdm(range(100), mininterval=10): 
    time.sleep(1.5)

100%|██████████| 100/100 [02:30<00:00,  1.50s/it]


In [1]:
import os 
# Set the W&B run id through the environment before wandb is imported and initialised.
# NOTE(review): "12345678" is a hard-coded placeholder, presumably for testing a fixed
# run id — confirm it should not come from configuration instead.
os.environ['WANDB_RUN_ID'] = "12345678"

import wandb
# Start a run in the 'test' project.
wandb.init(project='test')

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mpfrwilson[0m. Use [1m`wandb login --relogin`[0m to force relogin
