# Two-layer custom CNN from MOSAIKS

## Prepare GeoLife data loader

In [1]:
import os
import sys
import time
import inspect
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import timm

from sklearn.cluster import MiniBatchKMeans
from sklearn.feature_extraction.image import extract_patches_2d

# Directory the notebook is run from. Inside a Jupyter kernel the current frame's
# "file" is a synthetic per-cell name rather than the notebook itself, so
# `inspect.getfile(inspect.currentframe())` does not identify the notebook's
# location; the working directory is used instead.
CURR_DIR = os.getcwd()
PARENT_DIR = os.path.dirname(CURR_DIR)
# Make the project packages (`dataset`, `mosaiks_utils`) importable.
sys.path.insert(0, "/home/mila/s/sara.ebrahim-elkafrawy/scratch/ecosystem_project/remote_sensing")

from dataset.pytorch_dataset import GeoLifeCLEF2022Dataset
from torch.utils.data import random_split, DataLoader
from mosaiks_utils import visualize_3d_patches, visTensor, normalize_patches, DBN

# Seeded NumPy RNG. The hyperparameter cell further down re-creates it with the
# same seed before it is handed to MiniBatchKMeans and extract_patches_2d.
random_state = np.random.RandomState(0)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Dataset location and contents ---
# Alternative root used for quick experiments:
#   "/network/scratch/s/sara.ebrahim-elkafrawy/small_geo_data"
data_dir = "/network/scratch/s/sara.ebrahim-elkafrawy/"
split = "train"
bands = ["rgb"]
num_species = 17037

# --- Loading options ---
use_ffcv_loader = False
# The k-means cell squeezes the batch axis of every batch, so it expects one image per batch.
batch_size, num_workers = 1, 2

In [3]:
# Dataset for the configured split, restricted to the bands listed in `bands`;
# no rasters, patch extractor or transforms are attached at this level.
geo_train_dataset = GeoLifeCLEF2022Dataset(
    root=data_dir,
    subset=split,
    region="both",
    patch_data=bands,
    use_rasters=False,
    use_ffcv_loader=use_ffcv_loader,
    patch_extractor=None,
    transform=None,
    target_transform=None,
    opts=None,
)

# Shuffled loader over the dataset (the k-means cell builds its own loader again).
geo_train_loader = DataLoader(
    geo_train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
)

# Per-channel mean / std used for standardization. The values are on a 0-255 scale;
# presumably these are the GeoLife RGB statistics -- confirm against the dataset.
rgb_mean = (106.9413, 114.8733, 104.5285)
rgb_std = (51.0005, 44.8595, 43.2014)

# Resize every image to 224x224 (nearest-neighbour) and standardize it per channel.
trf = torch.nn.Sequential(
    transforms.Resize(size=(224, 224), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.Normalize(rgb_mean, rgb_std),
)

## Define the model: two-layer MOSAIKS CNN

In [4]:
# hyperparameters for the model
# --- Convolution geometry ---
kernel_size = 7
conv1_num_filters, conv2_num_filters = 100, 64

# --- Whitening / k-means settings ---
whiten = True        # forwarded to normalize_patches
zca_bias = 0.001     # forwarded to normalize_patches
max_iter = 6         # number of passes over the training loader per conv layer
# Free-form run label: in this notebook it only appears in output file names.
# (Originally documented options: 'whiten', 'no_whiten'.)
mode = 'whiten_minibatch_allGeo'

# Fresh seeded RNG so that re-running from this cell restarts the random stream.
random_state = np.random.RandomState(0)

# Checkpoint written after each conv layer has been re-initialised.
save_path = f"/home/mila/s/sara.ebrahim-elkafrawy/scratch/ecosystem_project/ckpts/two_layer_mosaiks_kmeans_{kernel_size}_{mode}_zcaBias_{zca_bias}_minibatch.pt"

In [5]:
# Side length of the square feature map produced by the adaptive average pooling.
pooled_size = 9
# Inputs of the first linear layer = channels * height * width after pooling
# (64 * 9 * 9 = 5184 with the defaults). Deriving it keeps the classifier head
# valid when `conv2_num_filters` or `pooled_size` is changed.
flat_features = conv2_num_filters * pooled_size * pooled_size

# Two conv stages (Conv -> LeakyReLU -> 2x2 MaxPool), adaptive pooling to a fixed
# grid, then a two-layer classifier head with one logit per species.
# Layer positions matter: later cells address layers as model[i].
model = nn.Sequential(
      nn.Conv2d(in_channels=3, out_channels=conv1_num_filters, kernel_size=kernel_size, padding='same', bias=True),
      nn.LeakyReLU(),
      nn.MaxPool2d(2, stride=2),

      nn.Conv2d(in_channels=conv1_num_filters, out_channels=conv2_num_filters, kernel_size=kernel_size, padding='same', bias=True),
      nn.LeakyReLU(),
      nn.MaxPool2d(2, stride=2),

      nn.AdaptiveAvgPool2d(pooled_size),

      nn.Flatten(),
      nn.Dropout(0.5),
      nn.Linear(flat_features, 512),
      nn.ReLU(),
      nn.Linear(512, num_species)
      )
# Smoke test: one random 224x224 RGB image must map to (1, num_species).
model(torch.rand((1, 3, 224, 224))).shape

torch.Size([1, 17037])

## Custom Mosaiks

In [6]:
# model = nn.Sequential(
#       nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, padding='same', bias=True),
#       nn.LeakyReLU(),
#       nn.MaxPool2d(2, stride=2),

#       nn.Conv2d(in_channels=64, out_channels=64, kernel_size=7, padding='same', bias=True),
#       nn.LeakyReLU(),
#       nn.MaxPool2d(2, stride=2),

#       nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding='same', bias=True),
#       nn.LeakyReLU(),
#       nn.MaxPool2d(2, stride=2),

#       nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding='same', bias=True),
#       nn.LeakyReLU(),
#       nn.MaxPool2d(2, stride=2),
    
#       nn.AdaptiveAvgPool2d(9),
    
#       nn.Flatten(),
#       nn.Dropout(0.5),
#       nn.Linear(20736, 512), #50176
#       nn.ReLU(),
#       nn.Linear(512, num_species)
#       ) 
# model(torch.rand((1, 3, 224, 224))).shape

In [7]:
# Display the layer layout; the numeric keys are the indices used as model[i] below.
model

Sequential(
  (0): Conv2d(3, 100, kernel_size=(7, 7), stride=(1, 1), padding=same)
  (1): LeakyReLU(negative_slope=0.01)
  (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (3): Conv2d(100, 64, kernel_size=(7, 7), stride=(1, 1), padding=same)
  (4): LeakyReLU(negative_slope=0.01)
  (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (6): AdaptiveAvgPool2d(output_size=9)
  (7): Flatten(start_dim=1, end_dim=-1)
  (8): Dropout(p=0.5, inplace=False)
  (9): Linear(in_features=5184, out_features=512, bias=True)
  (10): ReLU()
  (11): Linear(in_features=512, out_features=17037, bias=True)
)

## Manually set the indices of all convolution layers and of the activation layers that follow them

In [8]:
# for 2-conv layers
conv_lyrs = [0, 3]
act_lyrs = [1, 4]

# for 4-conv layers
# conv_lyrs = [0, 2, 4, 6]
# act_lyrs = [1, 4, 7, 10]

In [9]:
# List every learnable tensor with its shape; names are '<layer index>.<weight|bias>'.
for param_name, tensor in model.named_parameters():
    print(param_name, '---------------\t', tensor.shape)

0.weight ---------------	 torch.Size([100, 3, 7, 7])
0.bias ---------------	 torch.Size([100])
3.weight ---------------	 torch.Size([64, 100, 7, 7])
3.bias ---------------	 torch.Size([64])
9.weight ---------------	 torch.Size([512, 5184])
9.bias ---------------	 torch.Size([512])
11.weight ---------------	 torch.Size([17037, 512])
11.bias ---------------	 torch.Size([17037])


In [10]:
# The same layers as a plain list, in the order they are indexed by model[i].
list(model.children())

[Conv2d(3, 100, kernel_size=(7, 7), stride=(1, 1), padding=same),
 LeakyReLU(negative_slope=0.01),
 MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
 Conv2d(100, 64, kernel_size=(7, 7), stride=(1, 1), padding=same),
 LeakyReLU(negative_slope=0.01),
 MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
 AdaptiveAvgPool2d(output_size=9),
 Flatten(start_dim=1, end_dim=-1),
 Dropout(p=0.5, inplace=False),
 Linear(in_features=5184, out_features=512, bias=True),
 ReLU(),
 Linear(in_features=512, out_features=17037, bias=True)]

## Hook the activation layers

In [11]:
# Most recent output captured for each hooked layer, keyed by the hook's name.
features_dim = {}


def get_features(name):
    """Build a forward hook that records a layer's output.

    The returned callable has the (module, inputs, output) signature expected by
    ``register_forward_hook``. Each time it fires it stores ``output.detach()``
    in ``features_dim[name]``, overwriting the previous entry, and returns None.
    """
    def _capture(module, inputs, output):
        features_dim[name] = output.detach()

    return _capture

In [12]:
# for 2-conv layers
model[1].register_forward_hook(get_features('relu_layer_1'))
model[4].register_forward_hook(get_features('relu_layer_4'))

# for 4-conv layers (without nn.AdaptivePool layer)
# model[1].register_forward_hook(get_features('relu_layer_1'))
# model[4].register_forward_hook(get_features('relu_layer_4'))
# model[7].register_forward_hook(get_features('relu_layer_7'))
# model[10].register_forward_hook(get_features('relu_layer_10'))

<torch.utils.hooks.RemovableHandle at 0x7f6a14320a90>

## Apply MiniBatch K-Means to the two conv layers

In [None]:
# Learn the filters of each conv layer with mini-batch k-means.
#
# For every (conv layer, activation layer) pair:
#   1. stream the training set `max_iter` times and sample random patches from the
#      standardized image (first layer) or from the hooked activation of the
#      previous layer (later layers);
#   2. every `fit_every` images, normalize the buffered patches with
#      `normalize_patches` and feed them to `MiniBatchKMeans.partial_fit`;
#   3. write the resulting centroids into the conv layer's weight and save a checkpoint.
#
# NOTE(review): each chunk is normalized on its own, so successive `partial_fit`
#   calls presumably see data whitened with different statistics -- confirm intended.
# NOTE(review): for later layers the patches come from the activation output, i.e.
#   before the MaxPool layer that actually feeds the next convolution -- confirm intended.

# Kept only so that code referring to `all_patches` still finds the name. Patches are
# no longer accumulated: storing every chunk of every pass grew without bound and
# nothing in this notebook reads the list.
all_patches = []
params = []   # k-means centroids of each processed conv layer, in layer order
# dbn = DBN(3, 3, affine=False, momentum=1.)

# zip() pairs each conv layer with its activation and stops at the shorter list.
for layer_idx, (curr_param_idx, relu_idx) in enumerate(zip(conv_lyrs, act_lyrs)):
    print(f'for conv layer#{curr_param_idx}')
    print(f'kmeans for output of relu act layer#{relu_idx}')

    curr_param_sz = model[curr_param_idx].weight.data.shape # or list(model.children())[curr_param_idx].weight

    print(f'current parameter size: {curr_param_sz}')

    # Conv weight layout is (out_channels, in_channels, kernel_h, kernel_w).
    num_feats = curr_param_sz[0]
    num_ch = curr_param_sz[1]
    patch_size = (curr_param_sz[2], curr_param_sz[3])
    num_iters = max_iter   # The online learning part: cycle over the whole dataset max_iter times
    max_patches = int(num_feats/4)   # patches sampled per image

    print(f'num_feats:{num_feats}, num_ch:{num_ch}, patch_size:{patch_size}')

    # One centroid per output filter of the layer.
    kmeans = MiniBatchKMeans(n_clusters=num_feats,
                             random_state=random_state,
                             verbose=False)

    geo_train_loader = DataLoader(
                    geo_train_dataset,
                    batch_size=batch_size,
                    num_workers=num_workers,
                    shuffle=True,
                )

    # Fit roughly 30 times per pass over the loader. The max() guard avoids a
    # modulo-by-zero when the loader holds fewer than 30 batches.
    # NOTE: the first partial_fit still needs at least `num_feats` patches, and
    # patches left in the buffer after the last full chunk are not fitted.
    fit_every = max(1, len(geo_train_loader) // 30)

    index = 0
    buffer = []

    for _ in range(num_iters):
        for batch in geo_train_loader:
            patches, target, meta = batch

            # this is trf_stand: means standardization with the GeoLife stats
            patches['rgb'] = trf(patches['rgb'])

            if layer_idx == 0:
                # The first layer learns from the input image itself; no forward pass needed.
                curr_feats = patches['rgb'].numpy()

            else:
                # Only the hooked activations are needed, so skip autograd bookkeeping.
                with torch.no_grad():
                    model(patches['rgb'])
                curr_feats = features_dim[f'relu_layer_{act_lyrs[layer_idx-1]}'].numpy()

            # Drop the batch axis (only has an effect when the batch holds one image),
            # then move channels last.
            curr_feats = curr_feats.squeeze(0)
            curr_feats = curr_feats.transpose((1,2,0))

            # expects image shape of (height, width, n_channels)
            data = extract_patches_2d(curr_feats,
                                      patch_size,
                                      max_patches=max_patches,
                                      random_state=random_state)

            # Flatten each patch to one row: (n_patches, kernel_h * kernel_w * channels).
            data = np.reshape(data, (len(data), -1))
            buffer.append(data)

            index += 1
            if index % fit_every == 0:
                data = np.concatenate(buffer, axis=0)
                if np.any(np.isnan(data)):
                    data = np.nan_to_num(data)

                data, means, zca_mat = normalize_patches(data, zca_bias=zca_bias, whiten=whiten)
                print(f'whiten_patches [min],[max],[mean],[std]: {data.min():.3f}, {data.max():.3f}, {data.mean():.3f}, {data.std():.3f}')

                print(f'Running Kmeans for {len(data)} patch')
                kmeans.partial_fit(data)

                buffer = []

    # change the weights of the corresponding conv layer
    print(f'Updating parameter#{curr_param_idx} with size: {curr_param_sz}')
    params.append(kmeans.cluster_centers_)

    # Rows are flattened (kernel_h, kernel_w, channels) patches; restore that shape
    # and move channels first to match the conv weight layout.
    new_weight = kmeans.cluster_centers_.reshape(
                                            num_feats,
                                            patch_size[0],
                                            patch_size[1],
                                            num_ch,).transpose(0, 3, 1, 2)
    # copy_ keeps the parameter's own dtype and memory layout, so float64 centroids
    # cannot break the next float32 forward pass.
    with torch.no_grad():
        model[curr_param_idx].weight.copy_(torch.from_numpy(np.ascontiguousarray(new_weight)))

    # save the model
    torch.save(model.state_dict(), save_path)


for conv layer#0
kmeans for output of relu act layer#1
current parameter size: torch.Size([100, 3, 7, 7])
num_feats:100, num_ch:3, patch_size:(7, 7)
zca bias 0.001
negatives eigen values: (0,)
whiten_patches [min],[max],[mean],[std]: -12.696, 10.032, -0.000, 0.585
Running Kmeans for 1322825 patch
zca bias 0.001
negatives eigen values: (0,)
whiten_patches [min],[max],[mean],[std]: -10.776, 11.342, -0.000, 0.585
Running Kmeans for 1322825 patch
zca bias 0.001
negatives eigen values: (0,)
whiten_patches [min],[max],[mean],[std]: -10.479, 13.238, -0.000, 0.584
Running Kmeans for 1322825 patch
zca bias 0.001
negatives eigen values: (0,)
whiten_patches [min],[max],[mean],[std]: -16.121, 11.034, -0.000, 0.585
Running Kmeans for 1322825 patch
zca bias 0.001
negatives eigen values: (0,)
whiten_patches [min],[max],[mean],[std]: -10.589, 10.844, -0.000, 0.585
Running Kmeans for 1322825 patch
zca bias 0.001
negatives eigen values: (0,)
whiten_patches [min],[max],[mean],[std]: -10.137, 10.725, -0.0

# Visualize the filters of the first layer

In [None]:
def check_weights(weights_data):
    """Print summary statistics of a weight array.

    Prints the minimum, maximum, mean and standard deviation of ``weights_data``
    followed by the number of entries that have a non-zero imaginary part.

    Args:
        weights_data: NumPy array of weights (for example ``layer.weight.data.numpy()``).

    Returns:
        None. The statistics are only printed.
    """
    print(f'min: {weights_data.min()}')
    print(f'max: {weights_data.max()}')
    print(f'mean: {weights_data.mean()}')
    print(f'std: {weights_data.std()}')
    # Count complex entries element-wise. Testing the sum instead would miss
    # imaginary parts that cancel out and would print a bool rather than a count.
    print(f'num_complex numbers: {np.iscomplex(weights_data).sum()}')
# Inspect the first conv layer's weights after the k-means update.
check_weights(model[0].weight.data.numpy())

In [None]:
# The output file name encodes the settings that produced these filters.
filters_png = f'conv1_{str(kernel_size)}_{mode}_zcaBias_{zca_bias}_minibatch.png'

# Draw the first conv layer's filters (visTensor comes from mosaiks_utils;
# presumably it renders a grid of kernels -- confirm there).
visTensor(model[0].weight.data, ch=0, allkernels=False)
plt.axis('off')
plt.ioff()

# Save on a black background, then show the figure.
plt.rcParams['savefig.facecolor'] = 'black'
plt.savefig(filters_png)
plt.show()

In [None]:
# model_path = "/home/mila/s/sara.ebrahim-elkafrawy/scratch/ecosystem_project/ckpts/two_layer_mosaiks_kmeans_7_whiten_minibatch.pt"
# model.load_state_dict(torch.load(model_path))

In [None]:
# visTensor(model[0].weight.data, ch=0, allkernels=False)
# plt.axis('off')
# plt.show()