In [1]:
# Step up one directory; the recorded output shows the kernel ending up in the
# project directory, from which the later cells import main_ibot and src.*.
# NOTE(review): not idempotent — re-running this cell moves up one more level.
%cd ..

/fs01/home/pwilson/projects/ibot


In [2]:

%load_ext autoreload
%autoreload

from omegaconf import OmegaConf
from main_ibot import build_models
conf = OmegaConf.load("conf_new/main_ibot.yaml")

student, teacher = build_models(conf)



In [3]:


# Display the teacher backbone's module tree.
teacher.backbone


VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=1536, out_features=384, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
  (head): Identity()
  (f

In [4]:
import os
import dotenv
from torchvision.datasets import ImageFolder
# python-dotenv: read variables from a .env file (if one is found) into the process environment.
dotenv.load_dotenv()

# Short alias for the environment mapping; a later cell reads env['NCT_PATCHES'] from it.
env = os.environ


In [5]:


import torch
from torch.utils.data import DataLoader, dataloader
from torch.utils.data.distributed import Dataset
from warnings import warn 
import torchvision
from torchvision.transforms import transforms
from src.transform import NormalizeToTensor
from torchvision import transforms as T 


# Each image is resized to 224x224 and then passed through the project-defined
# NormalizeToTensor transform.
transform = T.Compose([T.Resize((224, 224)), NormalizeToTensor()])


def _nct_split(split):
    """Build the ImageFolder for one split sub-directory under env['NCT_PATCHES']."""
    return ImageFolder(
        os.path.join(env['NCT_PATCHES'], split),
        transform=transform,
        # class index -> integer (long) tensor
        target_transform=lambda l: torch.tensor(l).long(),
    )


train_ds = _nct_split('train')
train_loader = DataLoader(train_ds, batch_size=8)
val_ds = _nct_split('val')
# No batch_size given, so the DataLoader default applies.
val_loader = DataLoader(val_ds)
# Images (first element) of the first training batch, kept for quick shape checks.
im = next(iter(train_loader))[0]

In [8]:
# Sanity check: shape of the input batch, then shape of the backbone output
# when all tokens are requested.
print(im.shape)
teacher.backbone(im, return_all_tokens=True).shape

torch.Size([8, 3, 224, 224])


torch.Size([8, 197, 384])

In [9]:
# Inspect the teacher backbone's masked_im_modeling flag.
teacher.backbone.masked_im_modeling

False

In [10]:
# List the local variable names of forward() (argument names come first) to
# see which keyword arguments it accepts.
teacher.backbone.forward.__code__.co_varnames

('self', 'x', 'return_all_tokens', 'mask', 'blk')

In [11]:
# NOTE(review): repeats the earlier `teacher.backbone` inspection cell and shows
# the same module tree — candidate for removal.
teacher.backbone

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=1536, out_features=384, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
  (head): Identity()
  (f

In [13]:
# Manually trigger a reload of already-imported modules before importing FineTuning.
%autoreload

from src.ssl_evaluation import FineTuning

# NOTE(review): bare expression in the middle of the cell — it is not the last
# expression, so nothing is displayed and the line has no effect.
teacher.backbone

# in_features=384 matches the feature width in the backbone repr printed earlier.
# NOTE(review): the recorded run of this cell ended with
# "RuntimeError: Expected target size [8, 2], got [8]". The cause is inside
# FineTuning, which is not visible here; presumably the logits still carry the
# token dimension when the loss is computed — TODO verify before relying on this cell.
ft = FineTuning(in_features=384)
ft.run(teacher.backbone, train_loader, val_loader)

Train epoch 0:   0%|          | 0/250 [00:00<?, ?it/s]

Train epoch 0:   0%|          | 0/250 [00:00<?, ?it/s]


RuntimeError: Expected target size [8, 2], got [8]

In [None]:
# NOTE(review): `image` is first assigned in the following cell, so on a fresh
# kernel (Restart & Run All) this raises NameError. The recorded shape
# presumably comes from earlier kernel state — confirm or remove this cell.
image.shape

torch.Size([3, 512, 512])

In [None]:
# Take the first training batch for a forward-pass smoke test.
image, label = next(iter(train_loader))
# Switch the masked_im_modeling flag off on the student backbone (mutates the model in place).
student.backbone.masked_im_modeling = False
# NOTE(review): the result of this call is discarded; only the cell's last expression is displayed.
student.backbone(image)
# Shape of the first element of the student's output when given a one-element list of batches.
student([image])[0].shape

torch.Size([8, 8192])

In [None]:
import torch.distributed
from tqdm import tqdm 
import torch
from torch import nn
import numpy as np 
from sklearn.metrics import *
import matplotlib.pyplot as plt
import warnings 
import wandb 


class IBOTModule(nn.Module):
    """Interface for models whose call maps an image batch to a token sequence.

    Used as the type hint for models passed to the evaluation helpers.
    Subclasses implement ``forward``; calling the module documents the
    expected input/output contract.
    """

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Run forward pass on the model.

        Args:
            x: Input tensor. should be an image tensor of shape B, C, H, W

        Returns:
            Output tensor. This will be a tensor of shape B, N, C, where B is the batch size
            and N is the number of tokens (one token per patch plus one for the cls token).
            cls token is the 0'th token.
        """
        # Delegate to nn.Module.__call__ so that forward() (and any registered
        # hooks) actually run. A docstring-only body here would shadow
        # nn.Module.__call__ and make every instance return None.
        return super().__call__(x)


def compute_binary_classification_metrics(y_score, y_true, log_images=False):
    """Calculate metrics for the cancer classification problem.

    Args:
        y_score (np.array or torch.Tensor) - A column vector of predicted probabilities for
            cancer (1) or benign(0)
        y_true (np.array or torch.Tensor) - A column vector of true labels for cancer (1) or benign(0)
        log_images (bool) - If True, log images of the histogram of predictions and the ROC curve to
            wandb. Default is False.

    Returns:
        dict mapping metric name to value, with keys "auc", "sens_at_20_spe",
        "sens_at_40_spe", "sens_at_60_spe", "sens_at_80_spe", "f1",
        "balanced_accuracy" and "auprc". When ``log_images`` is True the dict
        also holds "histogram" and "roc_curve" as ``wandb.Image`` objects.
    """

    if isinstance(y_score, torch.Tensor):
        y_score = y_score.cpu().numpy()
    if isinstance(y_true, torch.Tensor):
        y_true = y_true.cpu().numpy()

    # augmentations can cause NaNs; drop those predictions and their labels
    valid = ~np.isnan(y_score)
    y_score = y_score[valid]
    y_true = y_true[valid]

    metrics = {}

    try:
        metrics["auc"] = roc_auc_score(y_true, y_score)
    except ValueError:
        warnings.warn("ROC AUC score could not be calculated. Setting to 0.5")
        metrics["auc"] = 0.5

    # find the sensitivity at fixed specificities
    fpr, tpr, thresholds = roc_curve(y_true, y_score)

    for specificity in [0.20, 0.40, 0.60, 0.80]:
        # first ROC point whose false positive rate exceeds 1 - specificity
        sensitivity = tpr[np.argmax(fpr > 1 - specificity)]
        metrics[f"sens_at_{specificity*100:.0f}_spe"] = sensitivity

    # choose the threshold that maximizes balanced accuracy
    # (argmax of tpr - fpr is also the argmax of (tpr + (1 - fpr)) / 2)
    best_threshold = thresholds[np.argmax(tpr - fpr)]
    # roc_curve defines tpr[i]/fpr[i] for predictions `score >= thresholds[i]`,
    # so the comparison must be inclusive. A strict `>` drops the samples
    # sitting exactly on the threshold and evaluates a different operating
    # point than the one that was selected.
    y_pred = y_score >= best_threshold
    metrics["f1"] = f1_score(y_true, y_pred)

    if log_images:
        # Start from a fresh figure so the histogram never lands on (and the
        # close() below never discards) a figure the caller already had open.
        plt.figure()
        plt.hist(y_score[y_true == 0], bins=100, alpha=0.5, density=True)
        plt.hist(y_score[y_true == 1], bins=100, alpha=0.5, density=True)
        plt.legend(["Benign", "Cancer"])
        plt.xlabel("Probability of cancer")
        plt.ylabel("Density")
        plt.title(f"AUC: {metrics['auc']:.3f}")
        metrics["histogram"] = wandb.Image(plt, caption="Histogram of core predictions")
        plt.close()

        plt.figure()
        plt.plot(fpr, tpr)
        plt.xlabel("False positive rate")
        plt.ylabel("True positive rate")
        plt.title("ROC curve")
        metrics["roc_curve"] = wandb.Image(plt, caption="ROC curve")
        plt.close()

    metrics['balanced_accuracy'] = balanced_accuracy_score(y_true, y_pred)
    metrics['auprc'] = average_precision_score(y_true, y_score)

    return metrics


class LinearProbing:
    """Linear-probe evaluation: logistic regression fitted on cls-token features.

    Args:
        train_loader: Iterable of (image, label) batches used to fit the classifier.
        val_loader: Iterable of (image, label) batches used to evaluate it.
        device: Device the model and every batch are moved to during feature extraction.
    """

    def __init__(self, train_loader, val_loader, device):
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device

    @staticmethod
    def _is_distributed() -> bool:
        """Return True when a torch.distributed process group has been initialised."""
        return torch.distributed.is_available() and torch.distributed.is_initialized()

    @classmethod
    def _gather_across_ranks(cls, tensor: torch.Tensor) -> torch.Tensor:
        """Concatenate ``tensor`` from all ranks along dim 0.

        Returns ``tensor`` unchanged when no process group is initialised, so
        the class also works in a plain single-process notebook session.
        """
        if not cls._is_distributed():
            return tensor

        # The cls-token slice is a strided view; collectives expect contiguous buffers.
        tensor = tensor.contiguous()
        gathered = [
            torch.zeros_like(tensor)
            for _ in range(torch.distributed.get_world_size())
        ]
        # all_gather fills one same-shaped buffer per rank.
        torch.distributed.all_gather(gathered, tensor)
        return torch.cat(gathered, dim=0)

    @torch.no_grad()
    def _extract_features(self, loader, model: IBOTModule, desc: str = None):
        """Run ``model`` over ``loader`` and collect cls-token features with their labels.

        Args:
            loader: Iterable yielding (image, label) batches.
            model: Module whose output is indexed as [batch, token, channel];
                token 0 is taken as the feature vector.
            desc: Optional progress-bar description.

        Returns:
            Tuple ``(features, labels)``, each concatenated over all batches
            (and over all ranks when running distributed).

        Note:
            The model is switched to eval mode and moved to ``self.device``;
            neither change is reverted afterwards.
        """
        model.eval().to(self.device)

        features = []
        labels = []
        for (image, label) in tqdm(loader, desc=desc):
            image = image.to(self.device)
            label = label.to(self.device)
            outputs = model(image)
            cls_tokens = outputs[:, 0, :]
            features.append(self._gather_across_ranks(cls_tokens))
            labels.append(self._gather_across_ranks(label))

        features = torch.cat(features, dim=0)
        labels = torch.cat(labels, dim=0)

        return features, labels

    def run_probing(self, model: IBOTModule):
        """Returns the metrics for the linear probing task.

        Features are extracted on every rank (the gather is a collective), but
        the classifier is only fitted on rank 0.

        Returns:
            dict with keys 'train' and 'val', each holding the output of
            ``compute_binary_classification_metrics``; ``None`` on non-zero
            ranks when running distributed.
        """

        metrics = {}

        X_train, y_train = self._extract_features(
            self.train_loader, model, desc="Extracting train features"
        )
        X_train = X_train.cpu().numpy()
        y_train = y_train.cpu().numpy()
        X_val, y_val = self._extract_features(
            self.val_loader, model, desc="Extracting val features"
        )
        X_val = X_val.cpu().numpy()
        y_val = y_val.cpu().numpy()

        if self._is_distributed() and torch.distributed.get_rank() != 0:
            return

        from sklearn.linear_model import LogisticRegression
        clf = LogisticRegression(max_iter=5000, class_weight='balanced')
        clf.fit(X_train, y_train)

        # Last predict_proba column = probability of the highest class label.
        y_pred_train = clf.predict_proba(X_train)[:, -1]
        y_pred_val = clf.predict_proba(X_val)[:, -1]

        metrics['train'] = compute_binary_classification_metrics(y_pred_train, y_train)
        metrics['val'] = compute_binary_classification_metrics(y_pred_val, y_val)

        return metrics


# Fit and evaluate the linear probe on features from the student backbone.
# NOTE(review): the device is hard-coded to 'cuda'.
lp = LinearProbing(train_loader, val_loader, device='cuda')
metrics = lp.run_probing(student.backbone)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:19<00:00, 12.94it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 885/885 [00:14<00:00, 61.81it/s]


In [None]:
import wandb
# Start a W&B run in the "ibot" project and log the probing metrics to it.
# NOTE(review): the run is never finished in any visible cell — consider
# calling wandb.finish() once logging is done.
wandb.init(project="ibot", job_type="linear-probing")
wandb.log(metrics)

[34m[1mwandb[0m: Currently logged in as: [33mpfrwilson[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
# NOTE(review): `features` is not assigned at notebook level in any visible cell
# (it only appears as a local inside LinearProbing._extract_features), so this
# relies on hidden kernel state and fails on Restart & Run All.
features.shape

torch.Size([1998, 2048])

In [None]:
# NOTE(review): `labels` is not assigned at notebook level in any visible cell;
# this relies on hidden kernel state. The bare `labels.shape` is also not
# displayed, because it is not the cell's last expression.
labels.shape

# Mean label value (the positive-class fraction if labels are 0/1 — TODO confirm).
labels.float().mean()

tensor(0.1426, device='cuda:0')