# RRRLossV2
This notebook contains some improvements that hopefully will stabilise the erratic behaviour of our original implementation.

## Imports & Initialisation

In [1]:
%load_ext autoreload
%autoreload 2

import os
import torch
import time

import torch.nn.functional as F

import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import diveslowlearnfast as dlf

from torch.utils.data import DataLoader
from diveslowlearnfast.datasets import Diving48Dataset
from diveslowlearnfast.egl.explainer import ExplainerStrategy
from diveslowlearnfast.train.helper import get_test_transform
from diveslowlearnfast.models import SlowFast, load_checkpoint
from diveslowlearnfast.config import Config
from diveslowlearnfast.visualise.gradcam import GradCAM
from diveslowlearnfast.train import helper as train_helper, StatsDB
from diveslowlearnfast.egl import helper as egl_helper

cfg = Config()
cfg.DATA.DATASET_PATH = '/Users/youritomassen/Projects/xai/data/Diving48/'
cfg.TRAIN.BATCH_SIZE = 4
cfg.GRADCAM.TARGET_LAYERS = ['s5/pathway0_res2', 's5/pathway0_res2']
device = torch.device('cpu')
model = SlowFast(cfg)
_, optimiser, *_ = train_helper.get_train_objects(cfg, model)
model, *_ = load_checkpoint(model, '../misc/checkpoint.pth', optimiser,  device)




## The `Code`
These code changes were obtained by integrating the author's original implementation. **Note** this implementation has an additional gradient penalty which may be ignored. Key differences are calculating the gradients over the log probalities

```python
# RRRLoss by the original author
def loss_function(self, l2_grads=1000, l1_grads=0, l2_params=0.0001):
    right_answer_loss = tf.reduce_sum(tf.multiply(self.y, -self.log_prob_ys))

    gradXes = tf.gradients(self.log_prob_ys, self.X)[0]
    A_gradX = tf.multiply(self.A, gradXes)
    right_reason_loss = 0
    if l1_grads > 0:
      right_reason_loss += l1_grads * tf.reduce_sum(tf.abs(A_gradX))
    if l2_grads > 0:
      right_reason_loss += l2_grads * tf.nn.l2_loss(A_gradX)

    small_params_loss = l2_params * tf.add_n([tf.nn.l2_loss(p) for p in self.W + self.b])

    return right_answer_loss + right_reason_loss + small_params_loss
```

In [11]:
class DualPathRRRLossV2(nn.Module):
    def __init__(self,
                 lambdas=None,
                 normalise_gradients=False,
                 skip_zero_masks=False):
        super().__init__()

        if lambdas is None:
            lambdas = [1000.0, 1000.0]

        self.lambdas = lambdas
        self.normalise_gradients = normalise_gradients
        self.skip_zero_masks = skip_zero_masks

    def forward(self, logits, targets, inputs, masks, warmup=False):
        batch_size = logits.size(0)

        # Convert targets to one-hot encoding for proper calculation
        num_classes = logits.size(1)
        target_one_hot = torch.zeros_like(logits).scatter_(1, targets.unsqueeze(1), 1)

        # Calculate the log probabilities
        log_probs = F.log_softmax(logits, dim=1)

        # Calculate cross-entropy loss (right answer loss)
        # This matches tf.reduce_sum(tf.multiply(self.y, -self.log_prob_ys))
        right_answer_loss = -torch.sum(target_one_hot * log_probs) / batch_size

        total_loss = right_answer_loss
        losses = {'ce_loss': right_answer_loss.item()}

        if warmup or (self.skip_zero_masks and all(torch.sum(mask) == 0 for mask in masks)):
            losses['total_loss'] = total_loss.item()
            for idx in range(len(inputs)):
                losses[f'gradient_loss_path_{idx}'] = 0
            return total_loss, losses

        # Calculate gradient penalties for each input path
        for idx, (inp, mask) in enumerate(zip(inputs, masks)):
            if torch.sum(mask) == 0 and self.skip_zero_masks:
                losses[f'gradient_loss_path_{idx}'] = 0
                continue

            # Calculate gradients of log_probs with respect to input
            # This matches tf.gradients(self.log_prob_ys, self.X)[0]
            gradients = torch.autograd.grad(
                log_probs,
                inp,
                grad_outputs=torch.ones_like(log_probs),
                create_graph=True,
                retain_graph=True
            )[0]

            if self.normalise_gradients:
                gradients = gradients / (torch.norm(gradients, dim=1, keepdim=True) + 1e-10)

            # Apply mask to gradients (A_gradX = tf.multiply(self.A, gradXes))
            masked_gradients = mask * gradients

            # L2 gradient penalty (l2_grads * tf.nn.l2_loss(A_gradX))
            n_frames = inp.size(2)
            l2_grad_loss = self.lambdas[idx] * torch.sum(masked_gradients**2) / (batch_size * n_frames)


            gradient_loss = l2_grad_loss
            total_loss += gradient_loss

            losses[f'gradient_loss_path_{idx}'] = gradient_loss.item()

        losses['total_loss'] = total_loss.item()
        return total_loss, losses

## Data Preparation

In [7]:
cfg.EGL.MASKS_CACHE_DIR = './.masks'
slow_masks_dir = os.path.join(cfg.EGL.MASKS_CACHE_DIR, 'slow')

stats_db = StatsDB('./data/stats.db')
difficult_samples = stats_db.get_below_median_samples(
    epoch_start=90,
    run_id='/home/s2871513/Projects/diveslowlearnfast/results/run18',
    split='train'
)

non_mask_video_ids = list(map(lambda x: x[0], difficult_samples))[:8]
mask_npys = os.listdir(slow_masks_dir)
masked_video_ids = list(map(lambda x: x.replace('.npy', ''), mask_npys))[:8]
mixed_video_ids = masked_video_ids[:1] + non_mask_video_ids[:7]


def get_batch(video_ids=None):
    cfg.DATA.TEST_CROP_SIZE = 224
    dataset = Diving48Dataset(
        cfg.DATA.DATASET_PATH,
        cfg.DATA.NUM_FRAMES,
        alpha=cfg.SLOWFAST.ALPHA,
        dataset_type='train',
        transform_fn=get_test_transform(cfg),  # use test_transform instead
        use_decord=cfg.DATA_LOADER.USE_DECORD,
        temporal_random_jitter=cfg.DATA.TEMPORAL_RANDOM_JITTER,
        temporal_random_offset=cfg.DATA.TEMPORAL_RANDOM_OFFSET,
        multi_thread_decode=cfg.DATA.MULTI_THREAD_DECODE,
        threshold=cfg.DATA.THRESHOLD,
        use_sampling_ratio=cfg.DATA.USE_SAMPLING_RATIO,
        video_ids=video_ids,
        mask_type='gradcam',
        masks_cache_dir=cfg.EGL.MASKS_CACHE_DIR,
    )
    loader = DataLoader(
        dataset,
        batch_size=8,
        pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
        num_workers=cfg.DATA_LOADER.NUM_WORKERS,
        shuffle=True,
    )
    return next(iter(loader))


non_mask_video_ids, masked_video_ids, mixed_video_ids

(['3N1kUtqJ25A_00065',
  '3N1kUtqJ25A_00185',
  '3PLiUG_DuC8_00167',
  '9BC6ssCjyfg_00232',
  '9jZYYtzYqwE_00081',
  'Bb0ZiYVNtDs_00072',
  'D6zILEKIJbk_00190',
  'JzOshOJgofw_00022'],
 ['JzOshOJgofw_00176',
  'nOlRwoxsDJ0_00609',
  'db7DmpqzGmc_00121',
  'RWNrARSbRCY_00021',
  'iv0Gu1VXAgc_00067',
  'fohMq9tOn6E_00135',
  'JzOshOJgofw_00006',
  'uDESPUxbjnI_00186'],
 ['JzOshOJgofw_00176',
  '3N1kUtqJ25A_00065',
  '3N1kUtqJ25A_00185',
  '3PLiUG_DuC8_00167',
  '9BC6ssCjyfg_00232',
  '9jZYYtzYqwE_00081',
  'Bb0ZiYVNtDs_00072',
  'D6zILEKIJbk_00190'])

In [8]:
xb, yb, _, _, _, mslow, mfast = get_batch(video_ids=masked_video_ids)
xb.shape, yb.shape, mslow.shape, mfast.shape

(torch.Size([8, 3, 32, 224, 224]),
 torch.Size([8]),
 torch.Size([8, 1, 4, 224, 224]),
 torch.Size([8, 1, 32, 224, 224]))

In [9]:
inputs = dlf.to_slowfast_inputs(xb, alpha=cfg.SLOWFAST.ALPHA, requires_grad=True)
logits = model(inputs)
nn.CrossEntropyLoss()(logits, yb)

tensor(0.0366, grad_fn=<NllLossBackward0>)

In [12]:
criterion = DualPathRRRLossV2(lambdas=[.01, .01])
total_loss, all_losses = criterion(logits, yb, inputs, [mslow, mfast])
total_loss, all_losses

4
32


(tensor(2.4581, grad_fn=<AddBackward0>),
 {'ce_loss': 0.036621104925870895,
  'gradient_loss_path_0': 1.0924240350723267,
  'gradient_loss_path_1': 1.3290122747421265,
  'total_loss': 2.458057403564453})

In [30]:
all_losses

{'ce_loss': 1.531362771987915,
 'gradient_loss_path_0': 351848.9375,
 'gradient_loss_path_1': 2585570.75,
 'total_loss': 2937421.25}