In [None]:
# default_exp models.dino

In [None]:
## models.dino

> Dataset and inference helpers for saving self-attention maps from a DINO vision transformer.

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#skip
!pip install nbdev

import os
import yaml
from google.colab import drive
drive.mount('/content/drive')

with open("drive/MyDrive/config/secrets.yaml", 'r') as stream:
    secrets = yaml.safe_load(stream)

!export AWS_SHARED_CREDENTIALS_FILE=/content/drive/MyDrive/config/awscli.ini
path = "/content/drive/My Drive/config/awscli.ini"
os.environ['AWS_SHARED_CREDENTIALS_FILE'] = path

!git clone 'https://{secrets['ACCESS_TOKEN']}@github.com/willkunz13/synthetic_im'
!git checkout iter
%cd synthetic_im/
!pip install -e .


In [4]:
#export
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from skimage import io, transform
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import sys
import copy
from pathlib import Path
import synthetic_im.vision_transformer as vision_transformer

In [10]:
#deits8 = torch.hub.load('facebookresearch/dino:main', 'dino_deits8')
#torch.save(deits8, 'test_data/models/dino_small.pt')

Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main


In [7]:
#export

class DinoDataset(Dataset):
    """Unlabelled image dataset made of every ``*.jpg`` file below a directory.

    Files are discovered recursively and stored in sorted order, so a given
    index always maps to the same file regardless of the order in which the
    filesystem happens to list directory entries.

    Args:
        path: Root directory that is searched recursively for ``*.jpg`` files.
        transform: Optional callable applied to each loaded image before it is
            returned.
    """

    def __init__(self, path, transform=None):
        self.path = path
        # rglob yields entries in filesystem order, which is not guaranteed to
        # be stable; sorting makes the index -> file mapping reproducible.
        self.files = sorted(Path(self.path).rglob('*.jpg'))
        self.transform = transform

    def __len__(self):
        """Return the number of image files that were found."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and apply ``transform`` if set.

        Raises:
            IndexError: If ``idx`` is outside the range of discovered files.
        """
        # Samplers may hand over a tensor index; convert it to a plain Python value.
        if torch.is_tensor(idx):
            idx = idx.tolist()

        image = io.imread(self.files[idx])
        if self.transform:
            image = self.transform(image)

        return image

In [51]:
# Preprocessing pipeline: array -> tensor, resize to 512x512, per-channel normalisation.
# NOTE(review): the mean/std values look like the usual ImageNet statistics — confirm.
data_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(size=(512, 512)),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Small fixture directory used by the checks in this notebook.
data_dir = 'test_data/cpu_test'
image_datasets = DinoDataset(path=data_dir, transform=data_transforms)

# Batches of two, served in dataset order (no shuffling) by two worker processes.
dataloader = DataLoader(
    dataset=image_datasets,
    batch_size=2,
    shuffle=False,
    num_workers=2,
)

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [52]:
# Sanity checks for the dataset, dataloader and device.
assert isinstance(dataloader, DataLoader)
assert isinstance(dataloader.dataset, DinoDataset)
# Compare the device *type*: an expression such as `device == "x" or "cpu"` is
# always truthy because the non-empty string on the right of `or` is truthy.
assert torch.device(device).type in ("cuda", "cpu")

# Every sample must come out as a 3-channel 512x512 tensor after the transforms.
assert dataloader.dataset[0].shape == torch.Size([3, 512, 512])
assert dataloader.dataset[1].shape == torch.Size([3, 512, 512])

# Indexing one past the end must raise IndexError. Only that exception is
# accepted, so an unexpected failure (or no failure at all) is reported.
out_of_range = len(dataloader.dataset)
try:
    dataloader.dataset[out_of_range]
except IndexError:
    pass
else:
    raise AssertionError("expected IndexError for an out-of-range dataset index")

In [6]:
# Make these directories importable before the checkpoint is unpickled.
# NOTE(review): presumably the checkpoint is a whole pickled model whose class
# lives in the project's vision_transformer module — confirm.
sys.path.insert(0, './test_data/models')
sys.path.insert(0, './synthetic_im')

# NOTE(review): torch.load unpickles arbitrary Python objects — only load
# checkpoints from a trusted source. On PyTorch releases where torch.load
# defaults to weights_only=True, a whole-model checkpoint additionally needs
# weights_only=False — confirm against the installed version.
# map_location="cpu" lets a checkpoint that holds CUDA tensors load on a
# CPU-only host; predict_images moves the model to the requested device.
model = torch.load('./test_data/models/dino_small.pt', map_location="cpu")

In [55]:
#export
def _resolve_default(name, value):
    """Return ``value``, or the module-level object called ``name`` when ``value`` is None."""
    if value is not None:
        return value
    try:
        return globals()[name]
    except KeyError:
        raise ValueError(
            f"`{name}` was not passed and no module-level `{name}` is defined"
        ) from None


def predict_images(model=None, patch_size=8, threshold=.6, output_dir=Path('data/output'),
                   device=None, dataloader=None):
    """Save one summed self-attention map per input image as ``attn-<name>`` JPEGs.

    For every batch the images are cropped so both spatial sizes are multiples
    of ``patch_size``, passed through ``model.forward_selfattention``, and the
    attention from token 0 to all remaining tokens is reshaped into a patch
    grid, upsampled by ``patch_size`` (nearest neighbour), summed over heads
    and written with the "inferno" colormap.

    Args:
        model: Module exposing ``forward_selfattention``. When None, the
            module-level ``model`` is looked up at call time (not at definition
            time, so importing this module does not require it to exist).
        patch_size: Side length in pixels of one transformer patch.
        threshold: Accepted for backward compatibility. It does not influence
            the saved images: the mass-threshold mask it used to parametrise
            was computed but never written anywhere.
        output_dir: Directory receiving the images; created if missing.
        device: Device used for inference. When None, the module-level
            ``device`` is used if defined, otherwise CUDA if available, else CPU.
        dataloader: Iterable of image batches shaped ``(batch, channels, H, W)``.
            It must iterate in dataset order (no shuffling) and its ``dataset``
            must expose a ``files`` list of paths, which provides the output
            names. When None, the module-level ``dataloader`` is used.

    Raises:
        ValueError: If ``model`` or ``dataloader`` is neither passed nor
            defined at module level.
    """
    model = _resolve_default("model", model)
    dataloader = _resolve_default("dataloader", dataloader)
    if device is None:
        device = globals().get("device")
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    output_dir = Path(output_dir)
    os.makedirs(output_dir, exist_ok=True)

    # torch.no_grad() below already disables gradient tracking, so the caller's
    # parameters are not frozen; only the train/eval mode is switched, and it is
    # restored afterwards.
    was_training = model.training
    model.eval()
    model.to(device)

    # Images written so far, i.e. the dataset index of the next batch's first
    # image. A running count stays correct when the last batch is smaller.
    n_done = 0
    try:
        with torch.no_grad():
            for inputs in dataloader:
                # Crop so both spatial sizes are exact multiples of the patch size.
                height = inputs.shape[2] - inputs.shape[2] % patch_size
                width = inputs.shape[3] - inputs.shape[3] % patch_size
                inputs = inputs[:, :, :height, :width]

                rows = height // patch_size
                cols = width // patch_size

                attentions = model.forward_selfattention(inputs.to(device))
                bs = attentions.shape[0]  # batch size
                nh = attentions.shape[1]  # number of heads

                # Keep the attention of token 0 towards every other token
                # (presumably the class token attending to the patches — confirm)
                # and lay it out as a (rows, cols) grid per head.
                attentions = attentions[:, :, 0, 1:].reshape(bs, nh, rows, cols)
                # Upsample the patch grid back to pixel resolution.
                attentions = nn.functional.interpolate(
                    attentions, scale_factor=patch_size, mode="nearest"
                ).cpu().numpy()

                for batch in range(bs):
                    out_img = attentions[batch].sum(0)  # sum over heads
                    source = dataloader.dataset.files[n_done + batch]
                    fname = output_dir / f"attn-{source.name}"
                    plt.imsave(
                        fname=fname,
                        arr=out_img,
                        cmap="inferno",
                        format="jpg"
                    )
                    print(f"{fname} saved.")
                n_done += bs
    finally:
        model.train(was_training)

In [62]:
# Run inference on the fixture images and validate what was written. The
# outputs are removed afterwards even when a check fails, so a re-run always
# starts from a clean directory.
test_output_dir = 'test_data/output'
try:
    predict_images(output_dir=test_output_dir, device="cpu")

    produced = os.listdir(test_output_dir)
    assert produced != []
    assert len(produced) == 2
    assert plt.imread('test_data/output/attn-1.jpg').shape == (512, 512, 3)
finally:
    if os.path.isdir(test_output_dir):
        for f in os.listdir(test_output_dir):
            os.remove(os.path.join(test_output_dir, f))