# EGL Train Epoch Tests


## Imports

In [1]:
%load_ext autoreload
%autoreload 2

import os
import torch
import random

import torch.nn as nn
import numpy as np
import diveslowlearnfast as dlf
import matplotlib.pyplot as plt

from diveslowlearnfast.config import Config
from diveslowlearnfast.datasets import Diving48Dataset
from diveslowlearnfast.train import helper as train_helper
from diveslowlearnfast import load_checkpoint
from diveslowlearnfast.train.helper import get_train_transform, get_test_transform, get_mask_transform, get_base_transform
from diveslowlearnfast.models import SlowFast
from diveslowlearnfast.egl import ExplainerStrategy
from diveslowlearnfast.egl.run_train_epoch import get_loss_params, get_mask_indices
from diveslowlearnfast.loss.rrr import DualPathRRRLossV2
from diveslowlearnfast.loss.dice import DiceLoss


# Base configuration shared by every test cell below.
# The dataset and checkpoint locations can be overridden via environment
# variables so the notebook is runnable outside the original author's machine;
# the defaults are the original hard-coded paths.
cfg = Config()
cfg.DATA.DATASET_PATH = os.environ.get(
    'DIVING48_DATASET_PATH',
    '/Users/youritomassen/Projects/xai/data/Diving48/',
)
cfg.TRAIN.BATCH_SIZE = 2
cfg.DATA.TEST_CROP_SIZE = 224

# All tests run on CPU.
device = torch.device('cpu')
model = SlowFast(cfg)
# Only the optimiser is needed from the train objects: load_checkpoint takes it
# as an argument alongside the model.
_, optimiser, *_ = train_helper.get_train_objects(cfg, model)
model, *_ = load_checkpoint(
    model,
    os.environ.get('EGL_CHECKPOINT_PATH', '../misc/checkpoint.pth'),
    optimiser,
    device,
)

# Explainer built from the base cfg; the cells below build their own (`exp`)
# after changing cfg.EGL.METHOD.
explainer = ExplainerStrategy.get_explainer(model, cfg, device)




## Masks Cache
Evaluate the masks cache functionality.

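1. Construct a dataset and dataloader that read masks from the on-disk cache (`mask_type='cache'`).
2. Forward a batch through the explainer to obtain the explanation maps and logits.
3. Build the loss parameters from these outputs and apply the loss function to them.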

In [2]:
# Clips used for the mask-cache test (and, in the original run order, the GradCAM test).
vidnames = ['3PLiUG_DuC8_00208', '5i1begTTucc_00043', 'cYkUl8MrXgA_00252']
# Dataset built with loader_mode='mp4' and mask_type='cache', so masks are taken
# from the '.masks' cache directory; frames use the test-time transform.
dataset = Diving48Dataset(
    cfg.DATA.DATASET_PATH,
    cfg.DATA.NUM_FRAMES,
    cfg.SLOWFAST.ALPHA,
    transform_fn=get_test_transform(cfg),
    video_ids=vidnames,
    loader_mode='mp4',
    mask_type='cache',
    masks_cache_dir='.masks'
)

# shuffle=False keeps batches in dataset order so results are reproducible.
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=cfg.TRAIN.BATCH_SIZE,
    shuffle=False,
    num_workers=2
)

In [3]:
# Index one sample directly (bypassing the DataLoader); the last element of the
# returned tuple is the mask entry.
x, y, _, _, _, masks = dataset[0]
# With mask_type='cache' a mask must be present and must hold two entries
# (presumably one per SlowFast pathway, judging by the batch shapes printed in
# the next cell — confirm).
assert masks is not None
assert len(masks) == 2

In [4]:
# Take the first batch from a fresh iterator over the loader (get_batch is given
# `device`, presumably to move the tensors there — confirm).
xb, yb, _, masks = dlf.get_batch(iter(loader), device)
# Display the shapes of the two mask entries in the batch.
masks[0].shape, masks[1].shape

(torch.Size([2, 1, 4, 224, 224]), torch.Size([2, 1, 32, 224, 224]))

In [5]:
# Explainer that uses the cached masks, paired with the dual-path RRR v2 loss.
cfg.EGL.METHOD = 'cache'
cfg.EGL.LOSS_FUNC = 'rrr_v2'
cfg.GRADCAM.TARGET_LAYERS = ['s5/pathway0_res2', 's5/pathway1_res2']
loss_fn = DualPathRRRLossV2()
exp = ExplainerStrategy.get_explainer(model, cfg, device)
# requires_grad=True on the inputs — presumably so gradients w.r.t. the input
# are available to the RRR-style loss (confirm).
inputs = dlf.to_slowfast_inputs(xb, cfg.SLOWFAST.ALPHA, requires_grad=True)
# The explainer returns the explanation maps and the model logits.
maps, logits = exp(inputs)
# Assemble the keyword arguments for the loss, including the cached masks.
loss_params = get_loss_params(cfg, maps, inputs, yb, logits, masks=masks)
# The loss returns the total loss tensor and a dict of per-term values.
loss, losses = loss_fn(**loss_params)
loss, losses

hi m0m
torch.Size([2, 1, 4, 224, 224]) torch.Size([2, 3, 4, 224, 224])
torch.Size([2, 1, 32, 224, 224]) torch.Size([2, 3, 32, 224, 224])


(tensor(3.1905, grad_fn=<AddBackward0>),
 {'ce_loss': 3.1726107597351074,
  'gradient_loss_path_0': 0.01279267854988575,
  'gradient_loss_path_1': 0.005137650761753321,
  'total_loss': 3.1905410289764404})

## GradCAM Masks

In [12]:
# Same three clips as the Masks Cache section. They are redefined here so this
# section does not depend on whichever cell last assigned `vidnames` (the
# Segment Masks section overwrites it with different clips).
vidnames = ['3PLiUG_DuC8_00208', '5i1begTTucc_00043', 'cYkUl8MrXgA_00252']

# No mask_type is given, so this dataset is built without precomputed masks;
# the GradCAM explainer produces the maps instead.
dataset = Diving48Dataset(
    cfg.DATA.DATASET_PATH,
    cfg.DATA.NUM_FRAMES,
    cfg.SLOWFAST.ALPHA,
    transform_fn=get_test_transform(cfg),
    video_ids=vidnames,
    loader_mode='mp4',
)

# A single batch holds every clip, so the hard-video selection done with
# get_mask_indices in the next cells sees all of them.
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=len(vidnames),
    shuffle=False,
    num_workers=2
)

x, y, _, _, _, masks = dataset[0]
# Without mask_type, the mask slot holds a falsy placeholder (False) rather
# than None, so the default collate function can still batch it.
assert masks is not None
assert not masks

In [13]:
# Take the first batch, keeping the video ids so the loss cell can select which
# samples are treated as "hard" videos.
xb, yb, video_ids, *_ = dlf.get_batch(iter(loader), device)

In [14]:
# GradCAM explainer paired with the dual-path RRR v2 loss. LOSS_FUNC is set
# explicitly so that re-running this cell after the Segment Masks section
# (which sets 'dice') does not hand a mismatched cfg to the RRR loss.
cfg.EGL.METHOD = 'gradcam'
cfg.EGL.LOSS_FUNC = 'rrr_v2'
cfg.GRADCAM.TARGET_LAYERS = ['s5/pathway0_res2', 's5/pathway1_res2']
loss_fn = DualPathRRRLossV2()
# requires_grad=True on the inputs — presumably needed for the input-gradient
# term of the RRR loss (confirm).
inputs = dlf.to_slowfast_inputs(xb, cfg.SLOWFAST.ALPHA, requires_grad=True)
exp = ExplainerStrategy.get_explainer(model, cfg, device)
maps, logits = exp(inputs)
# Per-sample boolean selector: True only for clips in the given "hard" set
# (see the printed output, e.g. [False False  True]).
indices = get_mask_indices(video_ids, {'3PLiUG_DuC8_00208'})
loss_params = get_loss_params(cfg, maps, inputs, yb, logits, indices=indices)
loss, losses = loss_fn(**loss_params)
loss, losses

('5i1begTTucc_00043', 'cYkUl8MrXgA_00252', '3PLiUG_DuC8_00208') [False False  True]


(tensor(4.1357, grad_fn=<AddBackward0>),
 {'ce_loss': 3.1364586353302,
  'gradient_loss_path_0': 0.9957987666130066,
  'gradient_loss_path_1': 0.0034821194130927324,
  'total_loss': 4.135739803314209})

## Segment Masks

In [7]:
# Clips used for the segment-mask test (note: this rebinds the global `vidnames`).
vidnames = ['3qq031609lA_00002', 'iv0Gu1VXAgc_00167', 'aektxm8cLdo_00000']

# Dataset built with loader_mode='jpg' and mask_type='segments'; the masks get
# their own transform via mask_transform_fn.
dataset = Diving48Dataset(
    cfg.DATA.DATASET_PATH,
    cfg.DATA.NUM_FRAMES,
    cfg.SLOWFAST.ALPHA,
    transform_fn=get_test_transform(cfg),
    mask_transform_fn=get_mask_transform(cfg),
    video_ids=vidnames,
    loader_mode='jpg',
    mask_type='segments'
)

loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=cfg.TRAIN.BATCH_SIZE,
    shuffle=False,
    num_workers=2
)

# Check a single sample: a segment mask must be returned; display its shape.
x, y, _, _, _, masks = dataset[0]
assert masks is not None
masks.shape

torch.Size([1, 32, 224, 224])

In [8]:
# Take the first batch together with its segment masks.
xb, yb, _, masks = dlf.get_batch(iter(loader), device)

In [20]:
# 'ogl' explainer paired with the Dice loss, which wraps a cross-entropy loss
# (the printed result reports ce_loss and dice_loss separately).
cfg.EGL.METHOD = 'ogl'
cfg.EGL.LOSS_FUNC = 'dice'
cfg.GRADCAM.TARGET_LAYERS = ['s5/pathway0_res2', 's5/pathway1_res2']
loss_fn = DiceLoss(nn.CrossEntropyLoss())
inputs = dlf.to_slowfast_inputs(xb, cfg.SLOWFAST.ALPHA, requires_grad=True)
exp = ExplainerStrategy.get_explainer(model, cfg, device)
maps, logits = exp(inputs)
# Assemble the loss keyword arguments, passing the segment masks from the batch.
loss_params = get_loss_params(cfg, maps, inputs, yb, logits, masks=masks)
loss, losses = loss_fn(**loss_params)
loss, losses

(tensor(4.3445, grad_fn=<AddBackward0>),
 {'ce_loss': tensor(3.4035, grad_fn=<MulBackward0>),
  'dice_loss': tensor(0.9411),
  'total_loss': tensor(4.3445, grad_fn=<AddBackward0>)})

In [10]:
# Sanity check of the per-video membership mask (the same kind of boolean array
# get_mask_indices printed above): each id maps to True iff it is in the hard
# set, and the result keeps the order of `video_ids`.
video_ids = ['a', 'b', 'k']
hard_video_ids = {'a', 'c', 'd', 'b'}

indices = np.array([video_id in hard_video_ids for video_id in video_ids])
# Assert instead of only displaying, so the cell fails loudly on a regression.
assert indices.dtype == bool
assert indices.tolist() == [True, True, False]

indices

array([ True,  True, False])