In [1]:
import time
from datetime import datetime
import numpy as np
import torch
import torchvision
from matplotlib import pyplot as plt
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import pandas as pd

from sparse_coding_torch.conv_sparse_model import ConvSparseLayer
from sparse_coding_torch.small_data_classifier import SmallDataClassifier
from sparse_coding_torch.utils import plot_filters
from sparse_coding_torch.utils import plot_video

from sparse_coding_torch.BamcPreprocessor import BamcPreprocessor
from sparse_coding_torch.video_loader import MinMaxScaler
from sparse_coding_torch.video_loader import VideoGrayScaler
from sparse_coding_torch.video_loader import VideoLoader
from sparse_coding_torch.video_loader import VideoClipLoader

from sparse_coding_torch.load_data import load_bamc_data

from IPython.display import HTML

In [2]:
# Select the compute device and a batch size suited to it.
# NOTE(review): "cuda:1" is hard-coded as the primary GPU, so that index must
# exist whenever CUDA is available.
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

# Compare on `device.type`: a `torch.device` does not compare equal to a plain
# string, so the previous `device == "cpu"` check was always False and the CPU
# branch could never run.
if device.type == "cpu":
    batch_size = 1
else:
    # presumably 8 clips for each of 4 GPUs — TODO confirm
    batch_size = 4 * 8

In [None]:
# Root directory of the scaled and cropped videos (frame size is 400x700).
# The original cell also carried a commented-out alternative root,
# "/shared_data/bamc_data", and a commented-out `batch_size = 62` override.
video_path = "/shared_data/bamc_data_scale_cropped"

# Per-clip preprocessing pipeline.  BamcPreprocessor() and
# torchvision.transforms.Resize(size=(172, 300)) are deliberately not part of
# it (both were commented out in the original cell).
transforms = torchvision.transforms.Compose(
    [
        VideoGrayScaler(),
        MinMaxScaler(0, 255),
    ]
)

# Clips of 4 frames each (a 60-frame variant was commented out originally).
dataset = VideoClipLoader(
    video_path,
    transform=transforms,
    frames_between_clips=1,
    num_frames=4,
)

targets = dataset.get_labels()

# Shuffled loader over the clip dataset; `batch_size` is set in an earlier cell.
data_loader = torch.utils.data.DataLoader(
    dataset,
    shuffle=True,
    batch_size=batch_size,
)

  0%|          | 0/4 [00:00<?, ?it/s]

  3%|▎         | 101/3534 [00:36<20:21,  2.81it/s]

In [None]:
# Show the shape of element 1 of the first dataset item.
# presumably element 1 is the clip tensor, mirroring the
# (labels, clips, filenames) unpacking used by the training loop — confirm.
first_item = dataset[0]
first_item[1].shape

In [None]:
# Pull one batch from the loader and preview a single clip as an inline video.
example_data = next(iter(data_loader))

# The training loop unpacks batches as (labels, clips, filenames), so index 1
# is the clip tensor.
example_clips = example_data[1]

# A bare expression in the middle of a cell is evaluated and thrown away, so
# the shape has to be printed explicitly to be visible.
print(example_clips.shape)

# Prefer the third clip of the batch, but fall back to the last available one
# when the batch holds fewer than three clips (e.g. batch_size == 1 on CPU);
# a fixed index of 2 would raise an IndexError in that case.
clip_index = min(2, len(example_clips) - 1)
ani = plot_video(example_clips[clip_index])
HTML(ani.to_html5_video())

In [None]:
# Settings for the convolutional sparse-coding layer (convo_dim=3, with a
# (4, 16, 16) kernel).
sparse_layer_config = dict(
    in_channels=1,
    out_channels=64,
    kernel_size=(4, 16, 16),
    stride=2,
    padding=0,
    convo_dim=3,
    rectifier=True,
    lam=0.01,
    max_activation_iter=200,
    activation_lr=1e-2,
)
sparse_layer = ConvSparseLayer(**sparse_layer_config)

# Wrap the layer in DataParallel.  GPU 1 is listed first in device_ids, which
# lines up with the "cuda:1" device selected in the earlier cell.
model = torch.nn.DataParallel(sparse_layer, device_ids=[1, 0, 2, 3])
model.to(device)

# Adam optimises the parameters of the underlying (unwrapped) layer.
learning_rate = 3e-4
optimizer = torch.optim.Adam(sparse_layer.parameters(), lr=learning_rate)

# NOTE(review): `criterion` is not referenced by the training loop visible in
# this notebook, which uses `sparse_layer.loss` instead — confirm whether it
# is still needed.
criterion = torch.nn.BCEWithLogitsLoss()

In [None]:
# Train the sparse layer for 500 epochs and keep the per-sample average loss of
# every epoch in `loss_log` so it can be plotted afterwards.
loss_log = []

for epoch in tqdm(range(500)):
    epoch_loss = 0
    for labels, local_batch, filenames in data_loader:
        local_batch = local_batch.to(device)

        # Forward pass runs through the DataParallel wrapper; the loss is
        # evaluated on the underlying layer.
        activations = model(local_batch)
        loss = sparse_layer.loss(local_batch, activations)
        # Scale by the batch size so the running total is per sample
        # (assumes `loss` is a per-batch mean — TODO confirm).
        epoch_loss += loss.item() * local_batch.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Re-normalise the layer's weights after every optimiser step.
        sparse_layer.normalize_weights()

    epoch_loss /= len(data_loader.sampler)
    # Record the epoch's loss.  It used to be computed and then discarded,
    # which left `loss_log` empty and the loss plot blank.
    loss_log.append(epoch_loss)

In [None]:
# Plot the values collected in `loss_log` against their index.
fig, ax = plt.subplots()
ax.plot(loss_log)

In [None]:
# Render the layer's filters as an inline HTML5 video.
# The tensor is detached from autograd and copied to host memory first.
filters_cpu = sparse_layer.filters.detach().cpu()
ani = plot_filters(filters_cpu)
HTML(ani.to_html5_video())

In [None]:
# Write a checkpoint holding the state of the unwrapped layer (`model.module`)
# and of the optimizer.  The target path is a fixed name with no timestamp
# component, so each run overwrites the previous file.
checkpoint_path = "saved_models/sparse_conv3d_model-best.pt"
checkpoint = {
    'model_state_dict': model.module.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}
torch.save(checkpoint, checkpoint_path)