# Extract filters/patches from the training dataset

In [1]:
import os
import sys 
import inspect
import timeit
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Put this file's directory and its parent at the front of sys.path so that
# the project-local `dataset` package imported right below can be resolved.
# NOTE(review): inside a notebook kernel the "file" of the current frame is a
# synthetic per-cell name, so CURR_DIR may end up being the working directory
# or a temporary kernel directory rather than the notebook's folder -- confirm
# that the import below still resolves after Restart & Run All.
CURR_DIR = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
PARENT_DIR = os.path.dirname(CURR_DIR)
# PARENT_DIR is inserted last, so it ends up ahead of CURR_DIR in the search order.
sys.path.insert(0, CURR_DIR)
sys.path.insert(0, PARENT_DIR)

from dataset.pytorch_dataset import GeoLifeCLEF2022Dataset

In [2]:
# --- Run configuration -------------------------------------------------------
# REGION and BAND are forwarded to the dataset constructor; BAND is also the
# key used later to pull the image tensor out of each batch (batch[0][BAND]).
REGION = "both"
BAND = "rgb"
# Batch size handed to the DataLoader below.
bs = 64
# Width of the final classification layer in the baseline CNN defined later.
num_species = 17037
print("loading dataset ...")
# NOTE(review): the data root is a hard-coded absolute path on one user's
# scratch space; the notebook will not run elsewhere without editing it.
dataset = GeoLifeCLEF2022Dataset(
    "/network/scratch/s/sara.ebrahim-elkafrawy/",
    subset="train",
    region=REGION,
    patch_data=BAND,
    use_ffcv_loader=False,
    use_rasters=False,
    patch_extractor=None,
    transform=None,  # transform=transforms.Compose([transforms.ToTensor()])
    target_transform=None,
)
print("finished loading dataset")

# shuffle=True and no random seed is set in this notebook, so the images seen
# by the patch-extraction loop differ from run to run.
loader = DataLoader(dataset, batch_size=bs, num_workers=4, shuffle=True)


loading dataset ...
finished loading dataset


In [3]:
data = next(iter(loader))

In [4]:
data[0]['rgb'].shape  # [64, 3, 256, 256]
data[2]['obs_id']

tensor([10681314, 21744645, 20156896, 20259306, 20678203, 10696444, 21932948,
        20659169, 20040370, 21030916, 21001205, 20328023, 10269976, 10109820,
        21352089, 10303790, 20822713, 21592465, 21498629, 21293261, 20574827,
        10295205, 22061759, 10031782, 10673305, 20687463, 10059540, 20890999,
        10574562, 20115039, 20689416, 10239824, 21164737, 10702637, 20730004,
        10641335, 21111585, 10479052, 21545438, 10086212, 10438841, 10786869,
        20841231, 10113443, 10197934, 20405599, 20974419, 20876109, 20520917,
        10486957, 10087494, 10192432, 20065231, 10197973, 10769386, 22049851,
        10140024, 20566609, 10557822, 21961793, 20294930, 21363190, 21989492,
        20872072])

In [3]:
# number of features = num_filters
# shape of filters: (num_filters, num_channels, patch_size, patch_size)
ds_len = len(dataset)
img_dim = 256
num_ch = 3
patch_size = 3
num_feats = 512
TOT_PATCHES = int(1e5)
save_path = "/home/mila/s/sara.ebrahim-elkafrawy/scratch/ecosystem_project/ckpts/mosaik_geo_filters"

# Draw one random patch centre per filter. The images themselves are picked
# by the extraction loop, which walks the shuffled loader.
#
# Centres are sampled from [0, img_dim - patch_size - 2] and then shifted by
# ceil(patch_size / 2) exactly once. grab_patch_from_idx cuts the window
# [int(c - patch_size / 2), int(c + patch_size / 2)), so a single shift keeps
# that window inside the image for any patch_size. The shift used to be
# applied twice, which moved every centre further towards the far edge and
# pushed windows out of bounds for patch_size >= 5.
half_patch = int(np.ceil(patch_size / 2))
x_idxs = np.random.choice(img_dim - patch_size - 1, num_feats) + half_patch
y_idxs = np.random.choice(img_dim - patch_size - 1, num_feats) + half_patch


In [4]:
def grab_patch_from_idx(img, idx_x, idx_y, patch_size, outpatch):
    """Copy a square window of ``img`` into ``outpatch`` and return it.

    The window spans ``[int(idx - patch_size / 2), int(idx + patch_size / 2))``
    along axis 1 (``idx_x``) and axis 2 (``idx_y``) of ``img``; axis 0 is
    copied in full.

    Args:
        img: 3-D array/tensor indexed as ``img[:, x, y]``.
        idx_x: Window centre along axis 1.
        idx_y: Window centre along axis 2.
        patch_size: Side length of the window.
        outpatch: 3-D destination buffer, overwritten in place.

    Returns:
        ``outpatch`` (the same object that was passed in).

    Raises:
        ValueError: If the window is not fully inside ``img``. Without this
            check a window clipped to a single row/column at the image edge
            would be silently broadcast across the whole patch.
    """
    sidx_x = int(idx_x - patch_size / 2)
    eidx_x = int(idx_x + patch_size / 2)
    sidx_y = int(idx_y - patch_size / 2)
    eidx_y = int(idx_y + patch_size / 2)

    dim_x, dim_y = img.shape[1], img.shape[2]
    # int() truncates towards zero, so a centre closer than patch_size / 2 to
    # the origin yields a window that is too short rather than a negative
    # start; the width comparison catches that case.
    inside = (
        sidx_x >= 0
        and sidx_y >= 0
        and eidx_x <= dim_x
        and eidx_y <= dim_y
        and eidx_x - sidx_x == patch_size
        and eidx_y - sidx_y == patch_size
    )
    if not inside:
        raise ValueError(
            f"patch of size {patch_size} centred at ({idx_x}, {idx_y}) does not fit "
            f"inside an image of spatial size ({dim_x}, {dim_y})"
        )

    outpatch[:, :, :] = img[:, sidx_x:eidx_x, sidx_y:eidx_y]
    return outpatch

In [5]:
# One output slot per sampled patch centre.
patches = np.zeros((len(x_idxs), num_ch, patch_size, patch_size), dtype=np.float32)

start = timeit.default_timer()
patch_idx = 0
is_done = patch_idx >= num_feats

for bs_idx, batch in enumerate(loader):
    if is_done:
        break
    batch_imgs = batch[0][BAND]
    # Iterate over the images actually present in this batch: the final batch
    # of the loader can hold fewer than `bs` items, and indexing up to `bs`
    # would raise IndexError there.
    for idx in range(len(batch_imgs)):
        if is_done:
            break
        # Coin flip: each image contributes a patch with probability 1/2.
        flip = np.random.randint(0, 2, size=1)
        if flip:
            idx_x = x_idxs[patch_idx]
            idx_y = y_idxs[patch_idx]

            # View into `patches`, filled in place by grab_patch_from_idx.
            out_patch = patches[patch_idx, :, :, :]
            img = batch_imgs[idx]
            grab_patch_from_idx(img, idx_x, idx_y, patch_size, out_patch)
            patch_idx += 1

            is_done = patch_idx >= num_feats

# If the loader ran out first, the remaining rows of `patches` are still all
# zeros; fail loudly instead of turning them into filters.
if patch_idx < num_feats:
    raise RuntimeError(
        f"loader exhausted after extracting {patch_idx} of {num_feats} patches"
    )


# Normalize patches

In [6]:
def normalize_patches(
    patches, min_divisor=1e-8, zca_bias=0.001, mean_rgb=np.array([0, 0, 0])
):
    """Mean-centre, unit-normalise and ZCA-whiten a stack of patches.

    Each patch is flattened, has its own mean subtracted and is scaled to unit
    L2 norm. The covariance of the resulting rows is eigendecomposed and the
    regularised ZCA matrix ``W = V diag((E + zca_bias) ** -0.5) V.T`` is
    applied twice (``patches @ W @ W.T``), i.e. the rows are multiplied by
    ``V diag(1 / (E + zca_bias)) V.T``.

    Args:
        patches: Array of shape ``(n_patches, ...)``. ``uint8`` input is
            converted to float64 and divided by 255 first.
        min_divisor: Patches whose norm is below this value are left unscaled
            (their norm is replaced by 1) to avoid dividing by ~0.
        zca_bias: Value added to every eigenvalue before inversion.
        mean_rgb: Unused; kept so existing callers do not break.

    Returns:
        float32 array with the same shape as ``patches``.

    Raises:
        ValueError: If ``patches`` contains no patches.
    """
    if patches.dtype == "uint8":
        patches = patches.astype("float64")
        patches /= 255.0
    print("zca bias", zca_bias)
    n_patches = patches.shape[0]
    if n_patches == 0:
        raise ValueError("normalize_patches needs at least one patch")
    orig_shape = patches.shape
    patches = patches.reshape(n_patches, -1)

    # Subtract each patch's own mean (the mean runs over axis 1, i.e. per row).
    patches = patches - np.mean(patches, axis=1)[:, np.newaxis]

    # Scale every patch to unit norm, skipping (near-)constant patches.
    patch_norms = np.linalg.norm(patches, axis=1)
    patch_norms[patch_norms < min_divisor] = 1
    patches = patches / patch_norms[:, np.newaxis]

    patchesCovMat = 1.0 / n_patches * patches.T.dot(patches)

    # The covariance is symmetric, so use eigh: it always returns real
    # eigenvalues and orthonormal eigenvectors, whereas the general-purpose
    # eig gives no such guarantee (notably for repeated eigenvalues, which a
    # rank-deficient covariance always has).
    E, V = np.linalg.eigh(patchesCovMat)

    # A covariance matrix has no negative eigenvalues; clamp round-off
    # negatives so the square root below cannot produce NaN for tiny zca_bias.
    E = np.clip(E, 0.0, None) + zca_bias
    inv_sqrt_zca_eigs = np.diag(1.0 / np.sqrt(E))
    global_ZCA = V.dot(inv_sqrt_zca_eigs).dot(V.T)
    patches_normalized = patches.dot(global_ZCA).dot(global_ZCA.T)

    return patches_normalized.reshape(orig_shape).astype("float32")

In [7]:
filter_scale = 1e-3
normalized_patches = normalize_patches(patches, zca_bias=filter_scale)

zca bias 0.001


In [8]:
normalized_patches.shape

(512, 3, 3, 3)

In [9]:
np.save(os.path.join(save_path, 'mosaik_512_filters_geo'), normalized_patches)

In [10]:
patches[0], normalized_patches[0]

(array([[[203., 216., 213.],
         [207., 211., 208.],
         [201., 211., 208.]],
 
        [[195., 209., 205.],
         [201., 206., 202.],
         [197., 205., 201.]],
 
        [[182., 190., 182.],
         [189., 187., 180.],
         [185., 189., 182.]]], dtype=float32),
 array([[[  4.2722588 ,   5.5827312 ,   7.720792  ],
         [  0.9486916 ,  -7.2913423 ,   0.55563956],
         [-17.787073  ,   3.924743  ,   6.780162  ]],
 
        [[-20.366072  ,   9.646288  ,   2.8795617 ],
         [ -7.166939  ,  12.9466095 ,   4.3600044 ],
         [  2.2949765 ,   4.005745  ,  -6.27462   ]],
 
        [[  2.290555  ,   0.8417432 , -13.909103  ],
         [ 18.62511   , -12.784602  ,  -8.748385  ],
         [  7.8612976 ,   0.26827443,  -1.4780316 ]]], dtype=float32))

# Create Mosaik's network

In [4]:
class MosaikConv(nn.Module):
    """Convolve images with a supplied filter bank and average-pool the result.

    All image inputs in torch must be C, H, W (batched as N, C, H, W).

    The forward pass convolves the input with the filters, then average-pools
    both ``relu(conv - bias)`` and ``relu(-conv - bias)`` and concatenates the
    two along the channel axis, so the output has ``2 * num_filters`` channels.

    Args:
        patches_np: NumPy filter bank, either ``(num_filters, C, kH, kW)`` or
            flattened as ``(num_filters, in_channels * patch_size ** 2)``.
            The convolution weight is created with ``torch.from_numpy`` and
            therefore shares memory with this array; its dtype must match the
            dtype of the images passed to ``forward``.
        patch_size: Filter side length; used to un-flatten 2-D filter banks.
        in_channels: Image channels; used to un-flatten 2-D filter banks.
        pool_size: Side length of the average-pooling window.
        pool_stride: Stride of the average-pooling window.
        bias: Scalar subtracted from the responses before the ReLU.
        filter_batch_size: Stored on the instance; not used by this class.

    Raises:
        ValueError: If ``patches_np`` is neither 4-D nor a 2-D array whose
            second dimension equals ``in_channels * patch_size ** 2``.
    """

    def __init__(
        self,
        patches_np,
        patch_size=3,
        in_channels=3,
        pool_size=256,
        pool_stride=256,
        bias=0.0,
        filter_batch_size=1024,
    ):
        super().__init__()
        self.pool_size = pool_size
        self.pool_stride = pool_stride
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.bias = bias
        self.filter_batch_size = filter_batch_size
        # Bookkeeping attributes kept for compatibility; unused in this class.
        self.active_filter_set = []
        self.start = None
        self.end = None
        self.gpu = False

        filters = torch.from_numpy(patches_np)
        if filters.dim() == 2:
            # Flattened filter bank: restore (N, C, k, k).
            flat_len = in_channels * patch_size * patch_size
            if filters.shape[1] != flat_len:
                raise ValueError(
                    f"flattened filters have {filters.shape[1]} values each, expected "
                    f"in_channels * patch_size**2 = {flat_len}"
                )
            filters = filters.reshape(-1, in_channels, patch_size, patch_size)
        elif filters.dim() != 4:
            raise ValueError(
                f"patches_np must be 2-D or 4-D, got shape {tuple(filters.shape)}"
            )

        # Size the layer from the filters themselves instead of relying on a
        # notebook-level `num_feats` global that may be stale or undefined.
        num_filters, filter_channels, kernel_h, kernel_w = filters.shape
        self.conv = nn.Conv2d(
            filter_channels, num_filters, (kernel_h, kernel_w), bias=False
        )
        self.conv.weight = nn.Parameter(filters)

    def _rectified_pool(self, response):
        """Average-pool ``relu(response - bias)`` with the configured window."""
        return F.avg_pool2d(
            F.relu(response - self.bias),
            [self.pool_size, self.pool_size],
            stride=[self.pool_stride, self.pool_stride],
            ceil_mode=True,
        )

    def forward(self, x):
        conv = self.conv(x)

        # Keep positive and negative responses as separate feature channels.
        x_pos = self._rectified_pool(conv)
        x_neg = self._rectified_pool(-1 * conv)
        return torch.cat((x_pos, x_neg), dim=1)

In [5]:
# Load a saved filter bank from `save_path` and smoke-test MosaikConv on a
# random batch. This rebinds `patches`, replacing the array extracted earlier
# in the notebook.
# Alternative filter bank (the ZCA-whitened patches saved above):
# patches = np.load(os.path.join(save_path, 'mosaik_512_filters_geo.npy'))
patches = np.load(os.path.join(save_path, 'kmeans_512_small.npy'))
# NOTE(review): a later cell reports patches.shape == (512, 27) for this file;
# confirm MosaikConv receives the filters in the layout it expects.
# patches = patches.reshape((-1, 3, 3, 3))
# patches.shape
net = MosaikConv(patches)
# Random stand-in for a batch of 32 RGB images of size 256x256.
img = torch.rand((32, 3, 256, 256))
net(img).shape

torch.Size([32, 1024, 1, 1])

In [15]:
patches.shape

(512, 27)

In [6]:
net = MosaikConv(patches)
data_batch = next(iter(loader))
data_batch[0]['rgb'].shape
net(data_batch[0]['rgb']).shape

torch.Size([64, 64, 1, 1])

In [34]:
# Baseline CNN: four conv -> ReLU -> 2x2 max-pool stages followed by a
# two-layer classifier head with `num_species` outputs.
_stage_widths = (3, 32, 64, 128, 256)
_feature_stages = []
for _c_in, _c_out in zip(_stage_widths, _stage_widths[1:]):
    _feature_stages.extend(
        [
            nn.Conv2d(in_channels=_c_in, out_channels=_c_out, kernel_size=3, padding='same', bias=True),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
        ]
    )

net = nn.Sequential(
    *_feature_stages,
    nn.Flatten(),
    nn.Dropout(0.5),
    # 50176 = 256 channels * 14 * 14: a 224x224 input is halved by each of
    # the four pooling stages (224 -> 112 -> 56 -> 28 -> 14).
    nn.Linear(50176, 512),
    nn.ReLU(),
    nn.Linear(512, num_species),
)
# Shape check on a single random 224x224 RGB image.
net(torch.rand((1, 3, 224, 224))).shape

torch.Size([1, 17037])

In [37]:
# Show which container class `net` is. `nn.Sequential` has no attribute
# `ch`, so the previous `net.ch` raised AttributeError on a fresh run.
type(net)

torch.nn.modules.container.Sequential

In [49]:
# Walk every module of `net` (the container itself included) and, for the
# convolution layers, report whether the kernel is trainable and its shape.
for layer in net.modules():
    print(type(layer))
    if not isinstance(layer, nn.Conv2d):
        continue
    kernel = layer.weight
    print(kernel.requires_grad)
    print(kernel.shape)
        

<class 'torch.nn.modules.container.Sequential'>
<class 'torch.nn.modules.conv.Conv2d'>
True
torch.Size([32, 3, 3, 3])
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.pooling.MaxPool2d'>
<class 'torch.nn.modules.conv.Conv2d'>
True
torch.Size([64, 32, 3, 3])
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.pooling.MaxPool2d'>
<class 'torch.nn.modules.conv.Conv2d'>
True
torch.Size([128, 64, 3, 3])
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.pooling.MaxPool2d'>
<class 'torch.nn.modules.conv.Conv2d'>
True
torch.Size([256, 128, 3, 3])
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.pooling.MaxPool2d'>
<class 'torch.nn.modules.flatten.Flatten'>
<class 'torch.nn.modules.dropout.Dropout'>
<class 'torch.nn.modules.linear.Linear'>
<class 'torch.nn.modules.activation.ReLU'>
<class 'torch.nn.modules.linear.Linear'>


# Using TIMM

In [None]:
import timm

from pprint import pprint
model_names = timm.list_models(pretrained=True)
# pprint(model_names)

In [None]:
# Build a ViT-B/16 classifier with timm. The original arguments referenced
# `self.opts.module.pretrained` and `self.target_size`, which do not exist at
# notebook top level (there is no `self`), so the cell could not run.
# NOTE(review): pretrained=True and num_classes=num_species are assumed to be
# the intended values -- confirm against the training configuration.
model = timm.create_model(
    "vit_base_patch16_224",
    pretrained=True,
    num_classes=num_species,
)

In [6]:
from torchvision import models
model = models.resnet50(pretrained=True)

In [10]:
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [13]:
# Global average pooling demo: a (1, 1) target size collapses every feature
# map to a single value regardless of the input's spatial size.
m = nn.AdaptiveAvgPool2d((1, 1))
# Named `pool_input` rather than `input` so the builtin `input()` is not
# shadowed for the rest of the kernel session.
pool_input = torch.randn(1, 64, 8, 9)
output = m(pool_input)
# Other target sizes accepted by nn.AdaptiveAvgPool2d:
#   nn.AdaptiveAvgPool2d(7)          -> 7x7 output (square)
#   nn.AdaptiveAvgPool2d((None, 7))  -> input height kept, width 7
#                                       (e.g. 10x9 input -> 10x7 output)

In [14]:
output.shape

torch.Size([1, 64, 1, 1])

In [17]:
for n,p in model.named_parameters():
    print(n, p.shape)

conv1.weight torch.Size([64, 3, 7, 7])
bn1.weight torch.Size([64])
bn1.bias torch.Size([64])
layer1.0.conv1.weight torch.Size([64, 64, 1, 1])
layer1.0.bn1.weight torch.Size([64])
layer1.0.bn1.bias torch.Size([64])
layer1.0.conv2.weight torch.Size([64, 64, 3, 3])
layer1.0.bn2.weight torch.Size([64])
layer1.0.bn2.bias torch.Size([64])
layer1.0.conv3.weight torch.Size([256, 64, 1, 1])
layer1.0.bn3.weight torch.Size([256])
layer1.0.bn3.bias torch.Size([256])
layer1.0.downsample.0.weight torch.Size([256, 64, 1, 1])
layer1.0.downsample.1.weight torch.Size([256])
layer1.0.downsample.1.bias torch.Size([256])
layer1.1.conv1.weight torch.Size([64, 256, 1, 1])
layer1.1.bn1.weight torch.Size([64])
layer1.1.bn1.bias torch.Size([64])
layer1.1.conv2.weight torch.Size([64, 64, 3, 3])
layer1.1.bn2.weight torch.Size([64])
layer1.1.bn2.bias torch.Size([64])
layer1.1.conv3.weight torch.Size([256, 64, 1, 1])
layer1.1.bn3.weight torch.Size([256])
layer1.1.bn3.bias torch.Size([256])
layer1.2.conv1.weight tor

In [32]:
from torch.nn import Module

class Net(nn.Module):
    """Small two-stage CNN for single-channel 28x28 images with 10 outputs.

    ``forward`` prints the tensor shape on entry, after the convolutional
    stack and after flattening, so the shape flow can be followed.
    """

    def __init__(self):
        super().__init__()

        # Stage 1: 1 -> 4 channels, then halve the spatial size.
        stem = [
            nn.Conv2d(1, 4, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(4),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        # Stage 2: 4 -> 4 channels, then halve the spatial size again.
        body = [
            nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(4),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.cnn_layers = nn.Sequential(*stem, *body)

        # 4 channels * 7 * 7: a 28x28 input is 7x7 after the two pooling steps.
        flat_features = 4 * 7 * 7
        self.linear_layers = nn.Sequential(nn.Linear(flat_features, 10))

    def forward(self, x):
        """Run the conv stack, flatten per sample and apply the linear head."""
        print(x.shape)
        features = self.cnn_layers(x)
        print(f'cnn layers: {features.shape}')
        flat = features.view(features.size(0), -1)
        print(f'cnn after view: {flat.shape}')
        return self.linear_layers(flat)
new_net = Net()

In [33]:
data_batch = next(iter(loader))

img = torch.rand((32, 1, 28, 28))
new_net(img).shape

torch.Size([32, 1, 28, 28])
cnn layers: torch.Size([32, 4, 7, 7])
cnn after view: torch.Size([32, 196])


torch.Size([32, 10])