In [None]:
# !pip install datasets evaluate torch torchvision 

import os

import evaluate
import numpy as np
import torch
from datasets import load_dataset
from matplotlib import pyplot as plt
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

from coco_hf_dataset import coco_hf_dataset, expand_gray_channel

# Read the Hugging Face access token from the environment instead of committing it to
# the notebook. The token that used to be hardcoded on this line has been published
# with the notebook and must be treated as leaked: revoke it and issue a new one.
# When HF_TOKEN is unset, None is passed and `datasets` applies its default token
# handling.
HF_TOKEN = os.environ.get("HF_TOKEN")
ds = load_dataset('CVdatasets/CocoSegmentationOnlyVal5000', use_auth_token=HF_TOKEN)
IMG_SIZE = 64  # Side length (pixels) that images and masks are resized to
NC = 21  # Number of classes

In [35]:

# Image pipeline: convert to a tensor, resize to IMG_SIZE x IMG_SIZE with bicubic
# interpolation, run the project helper `expand_gray_channel` (presumably expands
# single-channel images to 3 channels - confirm in coco_hf_dataset.py), then normalise
# each channel with the mean/std values commonly used for ImageNet-pretrained models.
img_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((IMG_SIZE, IMG_SIZE), interpolation=transforms.InterpolationMode.BICUBIC),
    expand_gray_channel(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Mask pipeline: nearest-neighbour resize so no new label values are interpolated.
# NOTE(review): ToTensor rescales uint8 / PIL inputs to [0, 1]; if the masks carry
# integer class ids this would alter them before `.long()` in the training loop -
# confirm how coco_hf_dataset produces the mask values.
mask_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((IMG_SIZE, IMG_SIZE), interpolation=transforms.InterpolationMode.NEAREST),
])

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Pretrained DeepLabV3 (ResNet-50 backbone) fetched through torch.hub and moved to `device`.
model = torch.hub.load('pytorch/vision:v0.10.0', 'deeplabv3_resnet50', pretrained=True).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = .00001)

# Wrap the Hugging Face 'train' split in the project dataset class and batch it.
# shuffle=False keeps the sample order fixed between runs.
# NOTE(review): pin_memory=True only helps host-to-GPU copies; confirm it is wanted
# when running on CPU.
coco_hf = coco_hf_dataset(ds['train'], mask_transform=mask_transforms, img_transform=img_transforms, size=IMG_SIZE)
train_loader = DataLoader(coco_hf, batch_size=2, shuffle=False, num_workers=0, pin_memory=True)

Using cache found in /Users/elliottchartock/.cache/torch/hub/pytorch_vision_v0.10.0


In [36]:
import torch
import numpy as np
from collections import defaultdict
import torch.nn.functional as F
import cv2
# edited mask to boundary function with np.where to avoid unexpected behaviour
import cv2 

def denorm(in_images):
    """Convert a batch of float images to uint8 pixel values in the range 0..255.

    The input is not modified. If the first image's maximum is <= 1 the whole batch
    is assumed to be scaled to [0, 1] and is multiplied by 255. Values are then
    saturated to [0, 255] and truncated to uint8.

    Args:
        in_images: torch.Tensor of shape (n, 3, h, w).

    Returns:
        np.ndarray of dtype uint8 and shape (n, h, w, 3).

    Raises:
        ValueError: if the input is not a tensor or does not have shape (n, 3, h, w).
    """
    if not isinstance(in_images, torch.Tensor):
        raise ValueError("Input must be tensor")
    # Validate the shape before touching the data so a malformed or empty batch
    # produces the documented ValueError rather than an IndexError.
    if in_images.dim() != 4 or in_images.size(1) != 3:
        raise ValueError("Input tensor must have shape (n, 3, w, h).")

    images = in_images.detach().clone().float()
    if images.numel() > 0 and images[0].max() <= 1:
        images *= 255  # de-normalise (optional)

    # Saturate, then cast straight to uint8. Going through int8 (as before) overflows
    # for every value above 127.
    images.clamp_(0, 255)
    return images.permute(0, 2, 3, 1).to(torch.uint8).cpu().numpy()

def dep_cls(probs, y, logits=False):
    """DEP for image-level classification.

    Computes ``1 - margin`` per row, where the margin is the value at the ground-truth
    class minus the highest value among the remaining classes.

    Args:
        probs: float tensor of shape (bs, n_rows, n_classes) holding probabilities.
        y: integer tensor with bs * n_rows ground-truth class indices, e.g. (bs, n_rows).
        logits: if True, the margin is computed on log-probabilities (probabilities
            are clipped to 1e-8 first to avoid taking the log of 0); otherwise it is
            computed directly on ``probs``.

    Returns:
        Float tensor of shape (bs, n_rows).
    """
    bs = probs.shape[0]
    # One index per row is enough for gather / scatter along the class axis.
    y_indices = y.reshape((bs, -1, 1)).long()

    def _to_values(p):
        # Clip before the log so zero probabilities stay finite.
        return torch.log(torch.clamp(p, min=1e-8)) if logits else p

    value_at_ground_truth = torch.gather(_to_values(probs), 2, y_indices)[:, :, 0]

    # Suppress the ground-truth class in probability space (probability 0) and only
    # then convert. Writing 0 into log-probabilities, as before, made the masked slot
    # the maximum, because every log-probability is <= 0.
    masked_probs = probs.clone()
    masked_probs.scatter_(2, y_indices, 0)
    next_highest = _to_values(masked_probs).max(dim=2).values

    return 1 - (value_at_ground_truth - next_highest)

# edited mask to boundary function with np.where to avoid unexpected behaviour
def mask_to_boundary(mask, dilation_ratio=0.02):
    """
    Convert a batch of masks to boundary masks.

    Each mask is zero-padded by one pixel, eroded ``dilation`` times with a 3x3 kernel
    and cropped back. Pixels whose eroded value still equals the original value are
    treated as interior and become 0; all other pixels keep their original value.
    The input array is not modified.

    NOTE(review): erosion is a minimum filter, so on multi-class label maps a pixel is
    only marked as boundary when a smaller label (or the image border) lies within the
    erosion radius - confirm this is the intended behaviour for non-binary masks.

    :param mask (numpy array): masks of shape (n, h, w) or (n, 1, h, w); cast to uint8
    :param dilation_ratio (float): ratio to calculate dilation = dilation_ratio * image_diagonal
    :return: boundary masks (numpy array, uint8) of shape (n, h, w)
    :raises ValueError: if the mask is not 3-D after dropping a singleton channel axis
    """
    mask = np.asarray(mask)
    # Only drop axis 1 when it really is a channel axis; a 3-D (n, 1, w) batch must
    # keep its height dimension.
    if mask.ndim == 4 and mask.shape[1] == 1:
        mask = mask.squeeze(1)
    if mask.ndim != 3:
        raise ValueError(
            "mask must have shape (n, h, w) or (n, 1, h, w), got %s" % (mask.shape,)
        )

    # astype returns a copy, so writing into `mask` below never touches the caller's array.
    mask = mask.astype(np.uint8)
    n, h, w = mask.shape

    # Every image in the batch shares h and w, so these are computed once.
    img_diag = np.sqrt(h ** 2 + w ** 2)
    dilation = max(1, int(round(dilation_ratio * img_diag)))
    kernel = np.ones((3, 3), dtype=np.uint8)

    for im in range(n):
        # Pad image so mask truncated by the image border is also considered as boundary.
        padded = cv2.copyMakeBorder(mask[im], 1, 1, 1, 1, cv2.BORDER_CONSTANT, value=0)
        eroded = cv2.erode(padded, kernel, iterations=dilation)[1 : h + 1, 1 : w + 1]
        # Where erosion produced a different value than the original, treat the eroded
        # mask as empty (0) so the subtraction below keeps the original value there.
        eroded = np.where(eroded != mask[im], 0, eroded)
        # G_d intersects G in the paper.
        mask[im] = mask[im] - eroded
    return mask

    

class StoreHook:
    """Forward-hook helper that remembers the most recent forward pass.

    `hook` has the (module, input, output) signature of a forward hook. It stores the
    module, its input and the 'out' entry of its output on the instance, then calls
    `on_finish`, which can be replaced per instance.
    """

    def on_finish(*args, **kwargs):
        """Default callback: accepts anything and does nothing."""
        return None

    def hook(self, model, model_input, model_output):
        """Record the forward pass, then notify `on_finish`.

        `model_output` must be subscriptable with the key 'out'. The callback receives
        the input and the complete, unindexed output.
        """
        self.model, self.model_input = model, model_input
        self.model_output = model_output['out']
        self.on_finish(model_input, model_output)

class BatchLogger:
    """Callable that keeps a reference to the last batch it was given."""

    def __call__(self, batch):
        """Store `batch` on the instance and hand the same object back."""
        self.batch = batch
        return batch

class Manager:
    """Collects per-batch segmentation outputs through a forward hook on the model.

    A `StoreHook` is registered on `model`. After every forward pass
    `_after_pred_step` reads the batch stored in `self.bl.batch` and the model's 'out'
    tensor, and stores numpy copies of the logits, the gold masks, the gold and
    predicted boundary masks and the sample ids on the instance.

    `self.bl.batch` must be set to the current batch before the model is called. The
    batch must provide the keys 'mask' and 'idx'.
    """

    def __init__(self, model, num_classes: int = 10):
        """
        Args:
            model: module whose forward output is subscriptable with the key 'out'.
            num_classes: number of classes; used to recognise channels-first logits.
        """
        self.step_pred = StoreHook()
        self.bl = BatchLogger()
        self.split = 'Train' # hard coded for now
        self.number_classes = num_classes
        self.step_pred.on_finish = self._after_pred_step
        self.hooked = False
        self.register_hooks(model)

    def _after_pred_step(self, *args, **kwargs):
        """Callback run by the forward hook; stores the outputs of the latest batch."""
        with torch.no_grad():
            logging_data = self.bl.batch
            # detach so .numpy() also works when no permute creates a new tensor below.
            preds = self.step_pred.model_output.detach()

            # checks whether the model is (n, classes, h, w), or (n, h, w, classes)
            # NOTE(review): ambiguous when the height equals the number of classes.
            if preds.shape[1] == self.number_classes:
                preds = preds.permute(0, 2, 3, 1)

            # preds is channels-last here, so the class axis is the last one. The
            # previous code permuted a second time and took the argmax over a
            # spatial axis instead.
            argmax = torch.argmax(preds, dim=-1)  # (bs, h, w)
            self.logits = preds.cpu().numpy()  # (bs, h, w, classes)

            # Use a local view rather than rewriting the caller's batch dict in place.
            gold = logging_data['mask']
            if gold.dim() == 4 and gold.shape[1] == 1:
                gold = gold.squeeze(1)  # (bs, h, w)
            gold_np = gold.cpu().numpy()

            self.boundary_gold_masks = mask_to_boundary(gold_np)  # (bs, h, w)
            self.boundary_pred_masks = mask_to_boundary(argmax.cpu().numpy())  # (bs, h, w)
            self.gold_mask = gold_np  # (bs, h, w)
            self.img_ids = logging_data['idx'].cpu().numpy()  # np.ndarray (bs,)
            # TODO: dq.log_model_outputs

    def register_hooks(self, model):
        """Attach the prediction hook to `model`, replacing any previous registration."""
        # The previous implementation referenced `self.step_embs`, which is never
        # created, so calling this method raised AttributeError.
        previous = getattr(self.step_pred, 'h', None)
        if previous is not None:
            previous.remove()
        self.step_pred.h = model.register_forward_hook(self.step_pred.hook)
        self.hooked = True
        


# Attach the Manager's forward hook to the model; NC tells it how many classes the
# logits have.
m = Manager(model, num_classes=NC)

In [37]:

epochs = 1
# Debug cap on the number of batches per epoch. The notebook currently trains on a
# single batch to exercise the Manager hook; set to None to iterate the whole loader.
max_batches_per_epoch = 1

# Mixed precision is only enabled on CUDA. On CPU both autocast and the scaler are
# disabled, so the loop runs in plain float32.
use_amp = device.type == 'cuda'
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

model.train()
for epoch in range(epochs):
    for j, sample in enumerate(tqdm(train_loader)):
        imgs, masks = sample['image'], sample['mask']
        # log our batch - will need to do this automatically somehow
        m.bl.batch = sample

        # autocast wraps only the forward pass and the loss; backward and the
        # optimizer step run outside of it.
        with torch.autocast(device_type=device.type, enabled=use_amp):
            out = model(imgs.to(device))

            # reshape to have loss for each pixel (bs * h * w, n_classes)
            logits = out['out']
            pred = logits.permute(0, 2, 3, 1).reshape(-1, logits.shape[1])
            msks_for_loss = masks.long().reshape(-1).to(device)

            loss = criterion(pred, msks_for_loss)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Checked at the end of the iteration so no extra batch is fetched.
        if max_batches_per_epoch is not None and j + 1 >= max_batches_per_epoch:
            break



  0%|          | 0/2016 [00:00<?, ?it/s]

> [0;32m/var/folders/pd/6x8lc6xj7w74_4vxq68yz2gm0000gn/T/ipykernel_86143/1898291775.py[0m(118)[0;36m_after_pred_step[0;34m()[0m
[0;32m    116 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    117 [0;31m            [0;31m# dq.log_model_outputs[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 118 [0;31m            [0mprint[0m[0;34m([0m[0;34m"cool"[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    119 [0;31m[0;34m[0m[0m
[0m[0;32m    120 [0;31m        [0;31m# dq log model output[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m
(2, 64, 64, 21)
(2, 64, 64, 21)
(2, 64, 64, 21)
cool


  0%|          | 0/2016 [1:15:05<?, ?it/s]


In [15]:
# Inspect the shape of the predicted boundary masks stored by the Manager hook during
# the most recent forward pass.
m.boundary_pred_masks.shape

(2, 64, 64)