In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import numpy as np
from utils.dataloader import get_dataloaders
import matplotlib
from utils.mlp_mixer import create_binary_mask

  from .autonotebook import tqdm as notebook_tqdm


In [43]:
from types import SimpleNamespace

In [44]:
# Load the attention maps saved for the training set. Further down, `data` is
# indexed with str(sample index) and each value is converted to a numeric array.
# The encoding is pinned because JSON text is UTF-8, whereas open() would
# otherwise fall back to the platform's locale encoding.
with open('avg_attns_trainset.json', encoding='utf-8') as json_file:
    data = json.load(json_file)

In [45]:
def get_mask_batch(image, idx, attn_dict, drop_lambda):
    """Build per-sample binary keep-masks from stored attention maps.

    For every sample in the batch, the patches whose attention value is at or
    above that sample's ``drop_lambda`` quantile are set to 1 and all others
    to 0. The patch-level mask is then upsampled with nearest-neighbour
    interpolation by the integer factor ``image.shape[2] // side``, where
    ``side`` is the edge length of the (square) patch grid.

    Args:
        image: Batch tensor; only ``image.shape[2]`` is read, to derive the
            upsampling factor.
        idx: 1-D tensor of sample indices; ``str(index)`` is used as the
            lookup key into ``attn_dict``.
        attn_dict: Mapping from ``str(index)`` to a flat sequence of attention
            values. All looked-up sequences must share one length, and that
            length must be a perfect square.
        drop_lambda: Quantile in ``[0, 1]`` passed to ``torch.quantile``.

    Returns:
        A float32 tensor of shape ``(batch, 1, side * scale, side * scale)``
        containing zeros and ones.

    Raises:
        ValueError: If ``idx`` is empty, the attention maps are not flat
            sequences of equal length, or their length is not a perfect square.
        KeyError: If an index in ``idx`` has no entry in ``attn_dict``.
    """
    sample_keys = [str(sample_idx) for sample_idx in idx.tolist()]
    if not sample_keys:
        raise ValueError("idx must contain at least one sample index")

    # Stack once in NumPy, then wrap as a tensor; float64 matches what the
    # stored Python floats would produce and is accepted by torch.quantile.
    attn = torch.as_tensor(
        np.stack([np.asarray(attn_dict[key], dtype=np.float64) for key in sample_keys])
    )
    if attn.dim() != 2:
        raise ValueError("each attention map must be a flat sequence of values")

    # Take the grid size from the maps actually requested rather than from a
    # fixed key, so the dict does not have to contain an entry for sample 0.
    num_patches = attn.shape[1]
    side = int(round(float(np.sqrt(num_patches))))
    if side == 0 or side * side != num_patches:
        raise ValueError(
            f"attention maps must have a perfect-square length, got {num_patches}"
        )
    scale = image.shape[2] // side  # integer factor handed to interpolate

    # Comparing the unsorted values against the per-row quantile gives the
    # mask directly in original patch order; no sort/unsort round trip needed.
    threshold = torch.quantile(attn, drop_lambda, dim=1, keepdim=True)
    keep = (attn >= threshold).float()  # bool -> float

    patch_mask = keep.reshape(-1, 1, side, side)
    return torch.nn.functional.interpolate(patch_mask, scale_factor=scale, mode="nearest")

In [46]:
def gen_mask(masks, unfold_fn, in_chan=3):
    """Tile a mask across channels and cut it into flattened patches.

    Args:
        masks: Mask tensor that is repeated with the pattern
            ``(1, in_chan, 1, 1)``; presumably ``(batch, 1, H, W)`` as
            returned by ``get_mask_batch`` — confirm against the caller.
        unfold_fn: Callable applied to the tiled mask and expected to return
            a 3-D tensor (e.g. an ``nn.Unfold`` module).
        in_chan: Number of times the channel dimension is repeated. Defaults
            to 3, the value that used to be hard-coded here.

    Returns:
        The output of ``unfold_fn`` with its last two dimensions swapped.
    """
    tiled = masks.repeat(1, in_chan, 1, 1)
    return unfold_fn(tiled).permute(0, 2, 1)

In [47]:
# Settings namespace handed to get_dataloaders below; built in a single call
# so every option is visible at a glance.
args = SimpleNamespace(
    dataset='c10',
    model='mlp_mixer',
    batch_size=10,
    eval_batch_size=10,
    num_workers=4,
    seed=0,
    epochs=300,
    patch_size=4,
    autoaugment=False,
    use_cuda=False,
    size=224,
    split='index',
)

In [48]:
# Build the train and test loaders from the settings in `args`
# (get_dataloaders is the project helper imported from utils.dataloader).
train_dataloader, test_dataloader = get_dataloaders(args)

Files already downloaded and verified
Files already downloaded and verified


In [49]:
# Pull one training batch to experiment with; each batch unpacks into three items.
# NOTE(review): presumably `index` holds the per-sample dataset indices (enabled
# by args.split = 'index') — confirm against utils.dataloader.
image, label, index = next(iter(train_dataloader))

In [50]:
def project_bin_mask(image, index, data, lambda_drop, ps, in_chan, hidden):
    """Project an attention-derived patch mask through a random linear map.

    Steps: build the pixel-level mask with ``get_mask_batch``, split it into
    ``ps`` x ``ps`` patches with ``gen_mask``, multiply each flattened patch by
    a random ``(hidden, in_chan * ps * ps)`` matrix, and pass the result to
    ``create_binary_mask`` (helper from utils.mlp_mixer, not shown here).

    Args:
        image: Batch of images, forwarded to ``get_mask_batch``.
        index: Sample indices of the batch, forwarded to ``get_mask_batch``.
        data: Attention-map dictionary, forwarded to ``get_mask_batch``.
        lambda_drop: Drop quantile, forwarded to ``get_mask_batch``.
        ps: Patch edge length; used as both kernel size and stride.
        in_chan: Channel count used to size the projection matrix.
        hidden: Output feature count of the projection.

    Returns:
        Whatever ``create_binary_mask`` returns for the projected patches.

    NOTE(review): ``gen_mask(mask, unfold_fn)`` tiles the mask to 3 channels,
    so the projection shapes only line up when ``in_chan == 3``.
    """
    # Fresh Gaussian weights on every call, drawn from torch's global RNG.
    weight = torch.randn(hidden, in_chan * ps * ps)
    patchify = nn.Unfold(kernel_size=(ps, ps), stride=ps)

    pixel_mask = get_mask_batch(image, index, data, lambda_drop)
    mask_patches = gen_mask(pixel_mask, patchify)

    projected = F.linear(mask_patches, weight, bias=None)
    return create_binary_mask(projected)

In [51]:
# Same call as before, with the tuning values spelled out by name.
mask = project_bin_mask(
    image,
    index,
    data,
    lambda_drop=0.1,
    ps=16,
    in_chan=3,
    hidden=512,
)

In [53]:
# Total of all entries in the first sample's projected mask.
mask[0].sum()

tensor(90112.)

In [56]:
# Sanity check against the sum printed above (both give 90112).
# NOTE(review): presumably 100352 = 196 patches * 512 hidden features and
# 20 * 512 accounts for the 20 dropped patches — confirm.
100352 - (20*512)

90112

: 