In [30]:
%load_ext autoreload
%autoreload 2

import torch
import time
import torch.nn.functional as F

import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader
from diveslowlearnfast.datasets import Diving48Dataset
from diveslowlearnfast.egl.explainer import ExplainerStrategy
from diveslowlearnfast.train.helper import get_test_transform
from diveslowlearnfast.models import SlowFast, load_checkpoint
from diveslowlearnfast.config import Config
from diveslowlearnfast.visualise.gradcam import GradCAM
from diveslowlearnfast.train import helper as train_helper, StatsDB
from diveslowlearnfast.egl import helper as egl_helper
from diveslowlearnfast.loss.rrr import RRRLoss

cfg = Config()
cfg.DATA.DATASET_PATH = '/Users/youritomassen/Projects/xai/data/Diving48/'
cfg.TRAIN.BATCH_SIZE = 4
cfg.GRADCAM.TARGET_LAYERS = ['s5/pathway0_res2', 's5/pathway0_res2']
device = torch.device('cpu')
model = SlowFast(cfg)
_, optimiser, *_ = train_helper.get_train_objects(cfg, model)
model, *_ = load_checkpoint(model, optimiser, '../misc/checkpoint.pth', device)
explainer = ExplainerStrategy.get_explainer(model, cfg=cfg, device=device)

train_dataset = Diving48Dataset(
    cfg.DATA.DATASET_PATH,
    cfg.DATA.NUM_FRAMES,
    dataset_type='train',
    transform_fn=get_test_transform(cfg),  # use test_transform instead
    use_decord=cfg.DATA_LOADER.USE_DECORD,
    temporal_random_jitter=cfg.DATA.TEMPORAL_RANDOM_JITTER,
    temporal_random_offset=cfg.DATA.TEMPORAL_RANDOM_OFFSET,
    multi_thread_decode=cfg.DATA.MULTI_THREAD_DECODE,
    threshold=cfg.DATA.THRESHOLD,
    use_dynamic_temporal_stride=cfg.DATA.USE_DYNAMIC_TEMPORAL_STRIDE,
    masks_cache_dir='../results/run11/.masks'
)

train_loader = DataLoader(
    train_dataset,
    batch_size=cfg.TRAIN.BATCH_SIZE,
    pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
    num_workers=cfg.DATA_LOADER.NUM_WORKERS,
    shuffle=True,
)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload



In [80]:
xb, yb, _, _, _, masks = next(iter(train_loader))
masks[0, :, :, :, :] = 1

In [11]:
criterion = RRRLoss()

In [75]:
xb_fast = xb[:].to(device)
# reduce the number of frames by the alpha ratio
# B x C x T / alpha x H x W
xb_slow = xb[:, :, ::cfg.SLOWFAST.ALPHA].to(device)
xb_fast.requires_grad = True
logits = model([xb_slow, xb_fast])

In [42]:
loss, loss_components = criterion(logits, yb, model, xb_fast, masks)
start = time.time()
loss.backward()
print('backward took', time.time() - start)

cross entropy loss  0.0002338886260986328
True
log probs took  0.0002498626708984375
gradients took 4.902899742126465
backward took 22.167354822158813


In [23]:
from torch import nn

ce_loss = nn.CrossEntropyLoss()

In [76]:
loss = ce_loss(logits, yb)
start = time.time()
loss.backward()
print('backward took', time.time() - start)

backward took 5.390617847442627


In [83]:
xb_fast = xb[:].to(device)
# reduce the number of frames by the alpha ratio
# B x C x T / alpha x H x W
xb_slow = xb[:, :, ::cfg.SLOWFAST.ALPHA].to(device)
xb_fast.requires_grad = True
logits = model([xb_slow, xb_fast])

In [87]:
def rrr(logits, targets, inputs, masks, cross_entropy):
    """
    Compute the RRR loss.

    Args:
        logits: Model output logits (B x num_classes)
        targets: Ground truth labels (B)
        model: The model being trained
        inputs: Input data (B x ...)
        masks: Binary masks indicating where gradients should be small (B x ...)

    Returns:
        total_loss: Combined loss value
        losses: Dictionary containing individual loss components
    """
    ce_loss = cross_entropy(logits, targets)
    # boolean mask to select indices for which there is a mask available
    # this little optimisation ensures that we do not use autograd on inputs that will
    # be ignored anyway
    mask = (masks.sum(dim=(1, 2, 3, 4)) > 0)
    masked_elements = inputs[mask]

    if masked_elements.shape[0] > 0:
        log_probs = F.softmax(masked_elements, dim=1)
        summed_log_probs = log_probs.sum()
        gradients = torch.autograd.grad(summed_log_probs, inputs, create_graph=True, retain_graph=True)[0]
        gradient_loss = (masks * gradients).pow(2).mean()
        gradient_loss_item = gradient_loss.item()
    else:
        gradient_loss = 0
        gradient_loss_item = gradient_loss

    total_loss = ce_loss + 1000.0 * gradient_loss

    losses = {
        'total_loss': total_loss.item(),
        'ce_loss': ce_loss.item(),
        'gradient_loss': gradient_loss_item
    }

    return total_loss, losses

In [88]:
print('masks count =', (masks.sum(dim=(1, 2, 3, 4)) > 0).sum().item())
loss, all_losses = rrr(logits, yb, xb_fast, masks, nn.CrossEntropyLoss())
start = time.time()
loss.backward()
print('backward took', time.time() - start)

masks count= 1
backward took 7.548271179199219
