In [None]:
import time
from datetime import datetime
import numpy as np
import torch
import torchvision
from matplotlib import pyplot as plt
from matplotlib import cm
from matplotlib.animation import FuncAnimation

from sparse_coding_torch.conv_sparse_model import ConvSparseLayer
from sparse_coding_torch.small_data_classifier import SmallDataClassifierConv3d

from sklearn.model_selection import train_test_split

from sparse_coding_torch.utils import plot_filters
from sparse_coding_torch.utils import plot_video

from sparse_coding_torch.load_data import load_yolo_clips

from IPython.display import HTML

from tqdm import tqdm


In [None]:
# Preferred GPU on the shared machine. A hard-coded "cuda:2" raises an
# "invalid device ordinal" error on hosts with fewer than three GPUs, so clamp
# the index to the GPUs actually present and fall back to CPU when none exist.
GPU_INDEX = 2
if torch.cuda.is_available():
    device = torch.device(f"cuda:{min(GPU_INDEX, torch.cuda.device_count() - 1)}")
else:
    device = torch.device("cpu")

batch_size = 1

# Load every clip into a single training loader; the second return value is
# not used in this notebook. NOTE(review): semantics of mode='all_train' /
# n_splits=1 live in load_yolo_clips — confirm there.
train_loader, _ = load_yolo_clips(batch_size, mode='all_train', device=device, n_splits=1, sparse_model=None)
# len(train_loader) counts batches; with batch_size = 1 that equals the clip count.
print('Loaded', len(train_loader), 'train examples')

# Smoke test: pull one batch so loader problems surface before the long loop below.
example_data = next(iter(train_loader))

# 3-D convolutional sparse-coding layer. Its shape hyperparameters must match
# the checkpoint restored with load_state_dict in the next cell.
sparse_layer = ConvSparseLayer(in_channels=1,
                               out_channels=64,
                               kernel_size=(5, 15, 15),
                               stride=1,
                               padding=(0, 7, 7),
                               convo_dim=3,
                               rectifier=True,
                               lam=0.05,
                               max_activation_iter=200,
                               activation_lr=1e-1)
sparse_layer.to(device)

In [None]:
# Restore pretrained sparse-coding weights into `sparse_layer`.
# map_location=device remaps tensors that were saved on a GPU onto whatever
# device this session selected (another GPU or the CPU fallback). Without it,
# torch.load tries to restore each tensor onto the device it was saved from.
CHECKPOINT_PATH = "/home/dwh48@drexel.edu/sparse_coding_torch/sparse_conv3d_model-pleural_clips2_5x15x15-11-14-21.pt"
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
sparse_layer.load_state_dict(checkpoint['model_state_dict'])

In [None]:
# Clip ids the downstream classifier got wrong: false positives and false negatives.
fp_ids = ['image_24164968068436_CLEAN', 'image_73815992352100_clean', 'image_74132233134844_clean']
fn_ids = ['image_610066411380_CLEAN', 'image_634125159704_CLEAN', 'image_588695055398_clean', 'image_584357289931_clean', 'Image_262499828648_clean', 'image_267456908021_clean', 'image_2743083265515_CLEAN', 'image_1749559540112_clean']
incorrect_ids = set(fp_ids) | set(fn_ids)

NUM_FILTERS = 64  # must equal out_channels of sparse_layer

incorrect_sparsity = []
correct_sparsity = []
incorrect_filter_act = torch.zeros(NUM_FILTERS)
correct_filter_act = torch.zeros(NUM_FILTERS)

for labels, local_batch, vid_f in tqdm(train_loader):
    activations = sparse_layer(local_batch.to(device))

    # Fraction of nonzero coefficients in the sparse code, stored as a plain
    # float so the list does not hold device tensors.
    sparsity = (torch.count_nonzero(activations) / torch.numel(activations)).item()

    # Total activation per filter: reduce every dim except the channel dim (1).
    # squeeze() + sum(dim=[1, 2]) only gave a per-filter vector when the batch
    # and output-time dims were both 1.
    reduce_dims = tuple(d for d in range(activations.dim()) if d != 1)
    filter_act = activations.sum(dim=reduce_dims).detach().cpu()

    # Normalise by the strongest filter. Skip when nothing fired: 0/0 would
    # produce NaNs that permanently poison the running accumulator.
    peak = filter_act.max()
    if peak != 0:
        filter_act = filter_act / peak

    # Only the first id in the batch is checked; correct while batch_size == 1.
    # NOTE(review): assumes vid_f holds one id per batch element — confirm in load_yolo_clips.
    if vid_f[0] in incorrect_ids:
        incorrect_sparsity.append(sparsity)
        incorrect_filter_act += filter_act
    else:
        correct_sparsity.append(sparsity)
        correct_filter_act += filter_act

print('Mean sparsity (correct clips):  ', torch.mean(torch.tensor(correct_sparsity)))
print('Mean sparsity (incorrect clips):', torch.mean(torch.tensor(incorrect_sparsity)))
    

In [None]:
# Render the learned dictionary atoms, ordered from the one that fired most on
# misclassified clips to the one that fired least.
learned_filters = sparse_layer.filters.detach().cpu()
print(learned_filters.size())

# Advanced indexing with the ranking permutes dim 0, the same result as
# stacking the filters one at a time in ranked order.
rank_order = incorrect_filter_act.argsort(descending=True)
filters = learned_filters[rank_order]
print(filters.size())

ani = plot_filters(filters)
# Inline preview alternative: HTML(ani.to_html5_video())
ani.save("/home/dwh48@drexel.edu/sparse_coding_torch/incorrect_vis.mp4")