# Deep Feature Factorization

Introduced in the paper:<br>
> Edo Collins, Radhakrishna Achanta and Sabine Süsstrunk. _Deep Feature Factorization for Concept Discovery_.  European Conference on Computer Vision (ECCV) 2018.


**DFF** is the application of non-negative matrix factorization (NMF) to the ReLU feature activations of a deep neural network. In the case of CNNs trained on images, the resulting factors decompose an image batch into semantic parts with a high degree of invariance to complex transformations.

The geometry of ReLU activations plays a crucial role in creating favorable conditions for NMF. In particular, the "true" sparsity induced by ReLUs constrains the possible solutions to NMF and makes the factorization problem more well-posed.

The implementation below relies on PyTorch and includes a GPU implementation of NMF with multiplicative updates (Lee and Seung, 2001).

In [None]:
import os, time

import numpy as np
from matplotlib import pyplot as plt
%matplotlib inline

import torch
from torchvision import models
import torch.nn.functional as F
from nmf import NMF
from utils import imresize, show_heatmaps
#from kmeans_pytorch import kmeans
from sklearn.cluster import KMeans
!pip install git+https://github.com/wielandbrendel/bag-of-local-features-models.git
import bagnets.pytorchnet
import random

In [None]:
# CUDA flag. Speed-up due to CUDA is mostly noticeable for large batches.
# When True, the image tensor and the networks are moved to the GPU and the
# flag is forwarded to NMF through its `cuda` argument.
cuda = False

In [None]:
# Max-pooling over a 3x1 window (3 along the first spatial axis, 1 along the
# second) with stride 1; the (1, 0) padding keeps the spatial size unchanged.
maxpool = torch.nn.MaxPool2d(
    kernel_size=(3, 1),
    stride=(1, 1),
    padding=(1, 0),
)

## some flags

In [None]:
# NOTE(review): MAXPOOL is not read anywhere in the visible cells; presumably
# it was meant to toggle the use of `maxpool` -- confirm or remove.
MAXPOOL = True
# If True, load the already pre-processed image array from disk instead of
# rebuilding it from the raw image files.
PREPARED_IMGS = True
# Number of images (taken from the front of the array) kept for the analysis.
NUM_IMGS = 150

### Load and pre-process the data

In [None]:
# Root folder whose sub-folders hold the image files.
# NOTE(review): machine-specific path. The original comment referred to example
# data from the iCoseg dataset: http://chenlab.ece.cornell.edu/projects/touch-coseg/
data_path = '/scratch/local/hdd/suny/Dante_new_images_WB/'
folders = os.listdir(data_path)

# Collect the full path of every file found one level below data_path.
filenames = []
for folder in folders:
    folder_path = os.path.join(data_path, folder)
    if not os.path.isdir(folder_path):
        # A stray file next to the sub-folders would make os.listdir raise.
        continue
    files = os.listdir(folder_path)
    filenames.extend(os.path.join(folder_path, file) for file in files)

print(filenames)

In [None]:
# Target size: its two entries are passed to imresize when the image array is
# built from raw files, and heat maps are upsampled to this size before display.
# NOTE(review): the prepared array's file name suggests 250x250 images --
# confirm that this matches the images the heat maps are drawn over.
size = (350, 350)

In [None]:
# Per-channel statistics used for normalisation, shaped to broadcast over NxCxHxW.
channel_mean = np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))
channel_std = np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))

if PREPARED_IMGS:
    # Presumably a cached array produced by the `else` branch below (normalised,
    # NxCxHxW, float32) -- TODO confirm, the file names differ.
    images = np.load("/scratch/local/hdd/suny/np_imgs/500imgs_250x250.npy")
else:
    # Load a random sample of 500 images and resize them to `size`.
    raw_images = [plt.imread(filename) for filename in random.sample(filenames, 500)]
    raw_images = [imresize(img, size[0], size[1], crop=True) for img in raw_images]  # resize
    raw_images = np.stack(raw_images)

    # Preprocess
    # NOTE(review): the statistics assume pixel values in [0, 1] -- confirm
    # that imresize returns that range.
    images = raw_images.transpose((0, 3, 1, 2)).astype('float32')  # to numpy, NxCxHxW, float32
    images -= channel_mean  # zero mean
    images /= channel_std  # unit variance

    # NOTE(review): the file name says 450x450 while `size` is (350, 350), and
    # the branch above loads a "250x250" file -- confirm the intended name.
    np.save("/scratch/local/hdd/suny/np_imgs/500imgs_450x450", images)

images = images[:NUM_IMGS]

# Rebuild displayable images by undoing the normalisation. This must happen on
# a COPY: `raw_images = images` would only alias the array, and the in-place
# operations below would then de-normalise the network input as well.
raw_images = images.copy()
raw_images *= channel_std
raw_images += channel_mean
raw_images = raw_images.transpose((0, 2, 3, 1)) * 255  # NxCxHxW -> NxHxWxC, scaled by 255

images = torch.from_numpy(images)  # convert to Pytorch tensor
if cuda:
    images = images.cuda()

images.shape

### Load network and extract features

# VGG19

In [None]:
# Load the pre-trained VGG-19 and move it to the GPU when requested.
net = models.vgg19(pretrained=True)
net = net.cuda() if cuda else net
# Drop the module stored under key '36' of the feature extractor, i.e. the
# max-pooling after the final conv layer.
del net.features._modules['36']

In [None]:
# Print the name of every parameter of the network, one per line.
for param_name, _ in net.named_parameters():
    print(param_name)

In [None]:
# Position in f_layers of the conv layer at which the network is cut.
num_f_layer = 9

# Indices of the 16 conv layers inside net.features of VGG-19.
# The fourth conv layer sits at index 7; index 6 is the ReLU following the
# conv at index 5, so listing 6 here would cut the network right after a conv
# (before its ReLU) and hand possibly negative activations to NMF.
f_layers = [0, 2, 5, 7, 10, 12, 14, 16, 19, 21, 23, 25, 28, 30, 32, 34]
layers = [str(i) for i in range(35)]
num_layers = f_layers[num_f_layer]


with torch.no_grad():
    # Apply modules '0' .. str(num_layers - 1): the forward pass stops right
    # before the conv layer at index num_layers.
    features = net.features._modules['0'](images)
    for i in range(1, num_layers):
        features = net.features._modules[layers[i]](features)
        print(layers[i])
    flat_features = features.permute(0, 2, 3, 1).contiguous().view((-1, features.size(1)))  # NxCxHxW -> (N*H*W)xC

print('Reshaped features from {0}x{1}x{2}x{3} to ({0}*{2}*{3})x{1} = {4}x{1}'.format(*features.shape, flat_features.size(0)))

In [None]:
# Display the first `show` input images without any overlay.
show = 5
show_heatmaps(raw_images[:show], None, 0, enhance=1)

# Factorize with K = 1, 2, 3, 4 factors.
# Note: the permutation of factors is random and is not consistent across iterations
for K in range(1, 5):
    with torch.no_grad():
        W, _ = NMF(flat_features, K, random_seed=0, cuda=cuda, max_iter=50)

    # (N*H*W)xK -> NxHxWxK -> NxKxHxW
    factor_maps = W.cpu().view(features.size(0), features.size(2), features.size(3), K)
    factor_maps = factor_maps.permute(0, 3, 1, 2)
    print(factor_maps.shape)

    # Keep only the maps of the images that are displayed.
    factor_maps = factor_maps[:show]
    print(factor_maps.shape)

    # Upsample from the feature-map resolution to `size`.
    factor_maps = torch.nn.functional.interpolate(factor_maps, size=size, mode='nearest', align_corners=None)

    # Scale every map (one per image and factor) by its own spatial maximum.
    peak = factor_maps.max(dim=3, keepdim=True)[0].max(dim=2, keepdim=True)[0]
    heatmaps = (factor_maps / peak).cpu().numpy()

    # Show heat maps
    show_heatmaps(raw_images[:show], heatmaps, K, title='$k$ = {}'.format(K), enhance=0.3)

# BagNet33

In [None]:
net = bagnets.pytorchnet.bagnet33(pretrained=True)  # load pre-trained BagNet-33
if cuda:
    net = net.cuda()
# Switch to inference mode: otherwise the BatchNorm layers normalise with the
# statistics of the current batch (and update their running averages) instead
# of using the statistics stored with the pre-trained weights.
net.eval()
for name, parameter in net.named_parameters():
    print(name)
    

## w/ Spatial features

In [None]:
with torch.no_grad():
    features = net.conv1(images)
    features = net.conv2(features)
    features = net.bn1(features)
    # BagNet's own forward pass applies this ReLU between bn1 and layer1.
    features = net.relu(features)
    features = net.layer1(features)
    features = net.layer2(features)
    features = net.layer3._modules["0"](features)
    features = net.layer3._modules["1"](features)
    #features = net.layer3._modules["2"](features)
    #features = net.layer3(features)
    #features = net.layer4(features)
print(features.shape)

# NxCxHxW -> NxHxWxC, so that extra channels can be appended along the last axis.
permuted_features = features.permute(0, 2, 3, 1)
n_imgs, height, width, _ = permuted_features.shape

# Value attached to each row / column index. Currently all zeros, so the two
# extra channels carry no position information. Sized from the feature map
# instead of a hard-coded length.
vals = np.zeros(max(height, width))
coords = torch.as_tensor(vals, dtype=permuted_features.dtype, device=permuted_features.device)

# HxWx2 grid: channel 0 holds the row value, channel 1 the column value.
# Created on the same device as the features so that it also works when
# `cuda` is False.
spatial_features = torch.empty(height, width, 2, dtype=permuted_features.dtype, device=permuted_features.device)
spatial_features[:, :, 0] = coords[:height].unsqueeze(1)
spatial_features[:, :, 1] = coords[:width].unsqueeze(0)

# Append the same grid to every image: NxHxWxC -> NxHxWx(C+2).
new_features = torch.cat(
    [permuted_features, spatial_features.unsqueeze(0).expand(n_imgs, -1, -1, -1)],
    dim=-1,
)
print(new_features.shape)

print(permuted_features.shape)
flat_features = new_features.contiguous().view((-1, new_features.size(-1)))  # NxHxWx(C+2) -> (N*H*W)x(C+2)
print(new_features.shape)
print(flat_features)


# Note: {1} below is the channel count of `features`, i.e. before the two
# spatial channels were appended.
print('Reshaped features from {0}x{1}x{2}x{3} to ({0}*{2}*{3})x{1} = {4}x{1}'.format(*features.shape, flat_features.size(0)))

## Without spatial features

In [None]:
with torch.no_grad():
    features = net.conv1(images)
    features = net.conv2(features)
    features = net.bn1(features)
    # BagNet's own forward pass applies this ReLU between bn1 and layer1;
    # skipping it feeds layer1 inputs the pre-trained weights never saw.
    features = net.relu(features)
    features = net.layer1(features)
    features = net.layer2(features)
    # Cut the network after the first block of layer3; the commented lines
    # below are alternative, deeper cut points.
    features = net.layer3._modules["0"](features)
    #features = net.layer3._modules["1"](features)
    #features = net.layer3._modules["2"](features)
    #features = net.layer3(features)
    #features = net.layer4(features)

    flat_features = features.permute(0, 2, 3, 1).contiguous().view((-1, features.size(1)))  # NxCxHxW -> (N*H*W)xC

print('Reshaped features from {0}x{1}x{2}x{3} to ({0}*{2}*{3})x{1} = {4}x{1}'.format(*features.shape, flat_features.size(0)))

In [None]:
# Display the first `show` input images without any overlay.
show = 5
show_heatmaps(raw_images[:show], None, 0, enhance=1)

# Factorize with K = 1, 2, 3, 4 factors.
# Note: the permutation of factors is random and is not consistent across iterations
for K in range(1, 5):
    with torch.no_grad():
        W, H = NMF(flat_features, K, random_seed=0, cuda=cuda, max_iter=100)

    # (N*H*W)xK -> NxHxWxK -> NxKxHxW
    factor_maps = W.cpu().view(features.size(0), features.size(2), features.size(3), K)
    factor_maps = factor_maps.permute(0, 3, 1, 2)
    print(factor_maps.shape)

    # Keep only the maps of the displayed images and upsample them to `size`.
    factor_maps = torch.nn.functional.interpolate(factor_maps[:show], size=size, mode='nearest', align_corners=None)

    # Scale every map (one per image and factor) by its own spatial maximum.
    peak = factor_maps.max(dim=3, keepdim=True)[0].max(dim=2, keepdim=True)[0]
    heatmaps = (factor_maps / peak).cpu().numpy()

    # Show heat maps
    show_heatmaps(raw_images[:show], heatmaps, K, title='$k$ = {}'.format(K), enhance=0.3)

# ResNet50

In [None]:
net = models.resnet50(pretrained=True)  # load pre-trained ResNet-50
if cuda:
    net = net.cuda()
# Switch to inference mode: otherwise the BatchNorm layers normalise with the
# statistics of the current batch (and update their running averages) instead
# of using the statistics stored with the pre-trained weights.
net.eval()
for name, parameter in net.named_parameters():
    print(name)

In [None]:
with torch.no_grad():
    features = net.conv1(images)
    features = net.bn1(features)
    # torchvision's ResNet forward pass applies this ReLU right after bn1.
    features = net.relu(features)
    # NOTE(review): the reference forward pass also applies net.maxpool here.
    # It is left out as in the original cell (this keeps the feature maps at
    # twice the resolution) -- confirm that this is intentional.
    features = net.layer1(features)
    features = net.layer2(features)
    # Cut the network after the second block of layer3; the commented lines
    # below are alternative, deeper cut points.
    features = net.layer3._modules["0"](features)
    features = net.layer3._modules["1"](features)
    #features = net.layer3._modules["2"](features)
    #features = net.layer3(features)
    #features = net.layer4(features)

    flat_features = features.permute(0, 2, 3, 1).contiguous().view((-1, features.size(1)))  # NxCxHxW -> (N*H*W)xC

print('Reshaped features from {0}x{1}x{2}x{3} to ({0}*{2}*{3})x{1} = {4}x{1}'.format(*features.shape, flat_features.size(0)))

### Factorize activations with NMF

In [None]:
# Show original images
show = 5
show_heatmaps(raw_images[:show], None, 0, enhance=1)

# Run NMF with K=(1,2,3,4) Note: the permutation of factors is random and is not consistent across iterations
for K in range(1, 5):
    with torch.no_grad():
        W, _ = NMF(flat_features, K, random_seed=0, cuda=cuda, max_iter=50)

    heatmaps = W.cpu().view(features.size(0), features.size(2), features.size(3), K).permute(0, 3, 1, 2)  # (N*H*W)xK -> NxKxHxW
    # Keep only the maps of the displayed images, so that the maps match
    # raw_images[:show] one-to-one and the rest is not upsampled for nothing.
    heatmaps = heatmaps[:show]
    heatmaps = torch.nn.functional.interpolate(heatmaps, size=size, mode='bilinear', align_corners=False)  # feature-map resolution -> `size`
    heatmaps /= heatmaps.max(dim=3, keepdim=True)[0].max(dim=2, keepdim=True)[0]  # normalize by factor (i.e., 1 of K)
    heatmaps = heatmaps.cpu().numpy()

    # Show heat maps
    show_heatmaps(raw_images[:show], heatmaps, K, title='$k$ = {}'.format(K), enhance=0.3)


# K-means

In [None]:
show = 5
start = 5  # index of the first displayed image
# Show original images
show_heatmaps(raw_images[start:start + show], None, 0, enhance=1)

# Run k-means with K=(2,3,4,5,6) on the per-location feature vectors and draw
# the hard cluster assignments as heat maps.
for K in range(2, 7):
    kmeans = KMeans(n_clusters=K, random_state=0).fit(flat_features.cpu())
    targets = kmeans.labels_
    #targets, _ = kmeans(X=flat_features, num_clusters=K, distance='cosine', device=torch.device('cuda:0'))
    print(targets)

    # One column per cluster: 0.2 where a location belongs to it, 0 elsewhere.
    W = torch.zeros([len(targets), K])
    W[range(W.shape[0]), targets] = 0.2

    heatmaps = W.cpu().view(features.size(0), features.size(2), features.size(3), K).permute(0, 3, 1, 2)  # (N*H*W)xK -> NxKxHxW
    heatmaps = heatmaps[start:start + show]
    heatmaps = torch.nn.functional.interpolate(heatmaps, size=size, mode='nearest', align_corners=None)  # feature-map resolution -> `size`

    # Normalize each map by its own maximum. A cluster that does not occur in
    # an image has an all-zero map; clamping the divisor keeps that map at
    # zero instead of turning it into NaN (0 / 0).
    peak = heatmaps.max(dim=3, keepdim=True)[0].max(dim=2, keepdim=True)[0]
    heatmaps /= torch.clamp(peak, min=1e-12)
    heatmaps = heatmaps.cpu().numpy()

    # Show heat maps over the same images they were computed from.
    show_heatmaps(raw_images[start:start + show], heatmaps, K, title='$k$ = {}'.format(K), enhance=0.3)