In [41]:
from slowfast.datasets.ptv_datasets import PTVDatasetWrapper, PackPathway, DictToTuple
from slowfast.config.defaults import assert_and_infer_cfg

from slowfast.utils.misc import launch_job
from slowfast.utils.parser import load_config, parse_args
from pytorchvideo.data import (
    LabeledVideoDataset,
    make_clip_sampler,
)
from torch.utils.data import (
    RandomSampler
)
from pytorchvideo.data.labeled_video_paths import LabeledVideoPaths
from pytorchvideo.transforms import (
    ApplyTransformToKey,
    RandomShortSideScale,
    ShortSideScale,
    UniformCropVideo,
    UniformTemporalSubsample,
)

from torchvision.transforms import Compose, Lambda, RandomApply
from torchvision.transforms._transforms_video import (
    NormalizeVideo,
    RandomCropVideo,
    RandomHorizontalFlipVideo,

)
import torch
from slowfast.utils.metrics import topk_accuracies,topks_correct
import torch.nn.functional as F
import slowfast.visualization.tensorboard_vis as tb
import os
from sklearn.metrics import confusion_matrix
from slowfast.utils import metrics
import numpy as np
from slowfast.datasets import loader
from slowfast.datasets.transform import RandomColorJitter, RandomGaussianBlur, RandomVerticalFlipVideo, RandomRot90Video, VarianceImageTransform


In [42]:
def div255(x):
    """
    Scale clip frames from [0, 255] to [0, 1].
    Args:
        x (Tensor): A tensor of the clip's RGB frames with shape:
            (channel, time, height, width).

    Returns:
        x (Tensor): Scaled tensor by divide 255.
    """
    return x / 255.0

def change_brightness(x,max_b=60):
    """
    Randomly changes the brightness by some delta_b.
    Args:
        x: A tensor of the clip's  frames with shape:
            (channel, time, height, width).
        max_b: maximum value of intensity to add/subtract from the clip

    Returns:
        x_hat (Tensor): clip with modified brightness

    """
    b = random.randint(-max_b,max_b)
    x_hat = x+b
    x_hat = x_hat.clip(0,255)
    return x_hat

def rgb2gray(x):
    """
    Convert clip frames from RGB mode to GRAYSCALE mode.
    Args:
        x (Tensor): A tensor of the clip's RGB frames with shape:
            (channel, time, height, width).

    Returns:
        x (Tensor): Converted tensor
    """
    return x[[0], ...]


def rgb2var(x,var_dim=1):
    assert var_dim in [1,2]
    gray = torch.squeeze(x[[0],...])
    var = gray.var(axis=0).numpy()
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(5,5))
    ekernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(7,7))
    opening = cv2.morphologyEx(var, cv2.MORPH_OPEN, kernel)
    erode = cv2.erode(opening,ekernel,iterations=2)
    dilate_var = torch.tensor(cv2.dilate(erode,kernel,iterations=10))
    if var_dim==2:
        var_array = torch.stack((gray,torch.stack([dilate_var]*gray.shape[0]),torch.stack([dilate_var]*gray.shape[0])))
    elif var_dim==1:
        var_array = torch.stack((gray, gray, torch.stack([dilate_var] * gray.shape[0])))
    return var_array
def Ptvfishbase(cfg, mode):
    """
    Construct the Fishbase video loader with a directory, each directory is split into modes ('train', 'val', 'test')
    and inside each mode are subdirectories for each label class.
    For `train` and `val` mode, a single clip is randomly sampled from every video
    with random cropping, scaling, and flipping. For `test` mode, multiple clips are
    uniformaly sampled from every video with center cropping.
    Args:
        cfg (CfgNode): configs.
        mode (string): Options includes `train`, `val`, or `test` mode.
            For the train and val mode, the data loader will take data
            from the train or val set, and sample one clip per video.
            For the test mode, the data loader will take data from test set,
            and sample multiple clips per video.
    """
    # Only support train, val, and test mode.
    assert mode in [
        "train",
        "val",
        "test",
        'train_eval',
        'val_eval',
    ], "Split '{}' not supported".format(mode)

    clip_duration = (
        cfg.DATA.NUM_FRAMES * cfg.DATA.SAMPLING_RATE / cfg.DATA.TARGET_FPS
    )
    path_to_dir = os.path.join(
        cfg.DATA.PATH_TO_DATA_DIR, mode.split('_')[0] #added split to deal with the case of train_eval and val_eval
    )

    labeled_video_paths = LabeledVideoPaths.from_directory(path_to_dir)
    num_videos = len(labeled_video_paths)
    labeled_video_paths.path_prefix = cfg.DATA.PATH_PREFIX
    if mode in ["train", "val"]:
        num_clips = 1
        num_crops = 1

        transform = Compose(
            [
                ApplyTransformToKey(
                    key="video",
                    transform=Compose(
                        [
                            UniformTemporalSubsample(cfg.DATA.NUM_FRAMES),
                            Lambda(div255),
                            RandomColorJitter(brightness_ratio=cfg.DATA.BRIGHTNESS_RATIO, p=cfg.DATA.BRIGHTNESS_PROB), #first trial 0.3
                            RandomGaussianBlur(kernel=13, sigma=(6.0,10.0), p=cfg.DATA.BLUR_PROB), # first trial 0.2
                            NormalizeVideo(cfg.DATA.MEAN, cfg.DATA.STD),
                            ShortSideScale(cfg.DATA.TRAIN_JITTER_SCALES[0]),
                        ]
                        + (
                            [Lambda(rgb2gray)]
                            if cfg.DATA.INPUT_CHANNEL_NUM[0] == 1
                            else []
                        )
                        + (
                            [VarianceImageTransform(var_dim=cfg.DATA.VAR_DIM)]
                            if cfg.DATA.VARIANCE_IMG
                            else []
                        )
                        + (
                            [RandomHorizontalFlipVideo(p=0.5),
                             RandomVerticalFlipVideo(p=0.5),
                             RandomRot90Video(p=0.5)]
                            if cfg.DATA.RANDOM_FLIP
                            else []
                        )
                        + [PackPathway(cfg)]
                    ),
                ),
                DictToTuple(num_clips, num_crops),
            ]
        )

        clip_sampler = make_clip_sampler("random", clip_duration)
        if cfg.NUM_GPUS > 1:
            video_sampler = DistributedSampler
        else:
            video_sampler = (
                RandomSampler if mode == "train" else SequentialSampler
            )
    else:
        num_clips = cfg.TEST.NUM_ENSEMBLE_VIEWS
        num_crops = cfg.TEST.NUM_SPATIAL_CROPS

        transform = Compose(
            [
                ApplyTransformToKey(
                    key="video",
                    transform=Compose(
                        [
                            UniformTemporalSubsample(cfg.DATA.NUM_FRAMES),
                            Lambda(div255),
                            NormalizeVideo(cfg.DATA.MEAN, cfg.DATA.STD),
                            ShortSideScale(
                                size=cfg.DATA.TRAIN_JITTER_SCALES[0]
                            ),
                        ]
                        + (
                            [Lambda(rgb2gray)]
                            if cfg.DATA.INPUT_CHANNEL_NUM[0] == 1
                            else []
                        )
                        + (
                            [VarianceImageTransform(var_dim=cfg.DATA.VAR_DIM)]
                            if cfg.DATA.VARIANCE_IMG
                            else []
                        )
                    ),
                ),
                ApplyTransformToKey(key="video", transform=PackPathway(cfg)),
                DictToTuple(num_clips, num_crops),
            ]
        )
        clip_sampler = make_clip_sampler(
            "constant_clips_per_video",
            clip_duration,
            num_clips,
            num_crops,
        )
        video_sampler = (
            DistributedSampler if cfg.NUM_GPUS > 1 else SequentialSampler
        )

    return PTVDatasetWrapper(
        num_videos=num_videos,
        clips_per_video=num_clips,
        crops_per_clip=num_crops,
        dataset=LabeledVideoDataset(
            labeled_video_paths=labeled_video_paths,
            clip_sampler=clip_sampler,
            video_sampler=video_sampler,
            transform=transform,
            decode_audio=False,
        ),
    )



In [32]:
def construct_loader(cfg, split, is_precise_bn=False):
    """
    Constructs the data loader for the given dataset.
    Args:
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        split (str): the split of the data loader. Options include `train`,
            `val`, and `test`.
    """
    assert split in ["train", "val", "test","train_eval","val_eval"]
    if split in ["train"]:
        dataset_name = cfg.TRAIN.DATASET
        batch_size = int(cfg.TRAIN.BATCH_SIZE / max(1, cfg.NUM_GPUS))
        shuffle = True
        drop_last = True
    elif split in ["val", "val_eval","train_eval"]:
        dataset_name = cfg.TRAIN.DATASET
        batch_size = int(cfg.TRAIN.BATCH_SIZE / max(1, cfg.NUM_GPUS))
        shuffle = False
        drop_last = False
    elif split in ["test"]:
        dataset_name = cfg.TEST.DATASET
        batch_size = int(cfg.TEST.BATCH_SIZE / max(1, cfg.NUM_GPUS))
        shuffle = False
        drop_last = False

    # Construct the dataset
    dataset = Ptvfishbase(cfg, split)

    loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=cfg.DATA_LOADER.NUM_WORKERS,
            pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
            drop_last=drop_last,
            collate_fn=detection_collate if cfg.DETECTION.ENABLE else None,
            worker_init_fn=utils.loader_worker_init_fn(dataset),
        )
    return loader


In [33]:
def eval_epoch(model,loader,cfg):
    model.eval()
    num_correct = 0
    y_hats = []
    all_labels = []
    stats = {}
    running_err = 0
    all_preds = []
    all_file_names = []
    file_list = loader.dataset.dataset._labeled_videos._paths_and_labels
    with torch.no_grad():
        for cur_iter, (inputs, labels, indices, meta) in enumerate(loader):
            if isinstance(inputs, (list,)):
                for i in range(len(inputs)):
                    inputs[i] = inputs[i].cuda(non_blocking=True)
            else:
                inputs = inputs.cuda(non_blocking=True)
            labels = labels.cuda()
            preds = model(inputs)
            preds = torch.nn.functional.softmax(preds, dim=1)
            file_names = [file_list[f][0] for f in indices]
            all_file_names+=file_names
            y_hat = preds.max(axis=1).indices
            all_preds.append(preds)
            y_hats.append(y_hat)
            all_labels.append(labels)
            num_correct += (labels == y_hat).sum()
            k = min(cfg.MODEL.NUM_CLASSES, 5)  # in case there aren't at least 5 classes in the dataset
            num_topks_correct = metrics.topks_correct(preds, labels, (1, k))
            # Combine the errors across the GPUs.
            top1_err, _ = [
                (1.0 - x / preds.size(0)) * 100.0 for x in num_topks_correct
            ]
            top1_err = top1_err.item()
            running_err += top1_err*preds.size(0)
        all_labels = torch.hstack(all_labels)
        y_hats = torch.hstack(y_hats)
        all_preds = torch.vstack(all_preds)
        
        tn, fp, fn, tp = confusion_matrix(1-all_labels.cpu(),1-y_hats.cpu()).ravel()
        stats['fns'] = fn
        stats['fps'] = fp
        stats['tns'] = tn
        stats['tps'] = tp
        stats['top1_err'] = (running_err/len(loader.dataset))
        stats['accuracy'] = (num_correct/len(loader.dataset)).item()
    return all_labels, y_hats, stats,all_preds, all_file_names


def train_one_epoch(model,optim,loader_train,loss_func):
    model.train()
    train_loss = 0
    num_correct = 0
    train_stats = {'fps':0,'tns':0,'fns':0,'tps':0,'accuracy':0}
    y_hats = []
    all_labels = []
    running_err = 0
    for cur_iter, (inputs, labels, _, meta) in enumerate(loader_train):
        if isinstance(inputs, (list,)):
            for i in range(len(inputs)):
                inputs[i] = inputs[i].cuda(non_blocking=True)
        else:
            inputs = inputs.cuda(non_blocking=True)
        labels = labels.cuda()
        preds = model(inputs)
        preds = torch.nn.functional.softmax(preds,dim=1)
        y_hat = preds.max(axis=1).indices
        y_hats.append(y_hat)
        all_labels.append(labels)
        num_correct += (labels==y_hat).sum()
        loss = loss_func(preds,labels)
        optim.zero_grad()
        loss.backward()
        optim.step()
        k = min(cfg.MODEL.NUM_CLASSES, 5)  # in case there aren't at least 5 classes in the dataset
        num_topks_correct = metrics.topks_correct(preds, labels, (1, k))
        top1_err, _ = [
            (1.0 - x / preds.size(0)) * 100.0 for x in num_topks_correct]
        running_err += top1_err * preds.size(0)
        train_loss += loss/preds.shape[0]
    all_labels = torch.hstack(all_labels)
    y_hats = torch.stack(y_hats).ravel()
    tn, fp, fn, tp = confusion_matrix(1 - all_labels.cpu(),
                                      1 - y_hats.cpu()).ravel()  # since feed is label 0 and we want it as label 1,
    train_stats['loss'] = train_loss
    train_stats['fns'] = fn
    train_stats['fps'] = fp
    train_stats['tns'] = tn
    train_stats['tps'] = tp
    train_stats['top1_err'] = (running_err / len(loader_train.dataset))
    train_stats['accuracy'] = num_correct / len(loader_train.dataset)
    return model, optim, train_stats

def train(cfg, pretrained=True):
    print('starting train')
    np.random.seed(cfg.RNG_SEED)
    torch.manual_seed(cfg.RNG_SEED)
    loader_train = construct_loader(cfg, 'train')
    loader_val = construct_loader(cfg, 'val')
    model_name = "slowfast_r50"
    model = torch.hub.load("facebookresearch/pytorchvideo:main", model=model_name, pretrained=pretrained)
    model.blocks[6].proj = torch.nn.Linear(in_features=2304, out_features=cfg.MODEL.NUM_CLASSES)
    model.cuda()
    if cfg.TENSORBOARD.ENABLE:
        writer = tb.TensorboardWriter(cfg)
    else:
        writer = None
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum = cfg.SOLVER.MOMENTUM)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
    loss_func = torch.nn.CrossEntropyLoss(reduction='mean')
    exp_dir = cfg.OUTPUT_DIR
    os.makedirs(os.path.join(exp_dir,'checkpoints'), exist_ok=True)

    train_f1s = []
    val_f1s = []
    train_recall = []
    val_recall = []
    train_losses = []
    prec_func = lambda st: st['tps'] / (st['tps'] + st['fps'])
    rec_func = lambda st: st['tps'] / (st['tps'] + st['fns'])
    f1_func = lambda st: 2 * (st['precision'] * st['recall']) / (st['precision'] + st['recall'])
    for cur_epoch in range(cfg.SOLVER.MAX_EPOCH):
        loader.shuffle_dataset(loader_train, cur_epoch)
        model, optimizer, train_stats = train_one_epoch(model, optimizer, loader_train, loss_func)
        scheduler.step(train_stats['loss'])
        train_labels,train_y_hats,train_eval_stats,_,_ = eval_epoch(model,loader_train,cfg)
        val_labels,val_y_hats,val_stats,_,_ = eval_epoch(model,loader_val,cfg)
        train_stats['precision'] = prec_func(train_stats)
        train_stats['recall'] = rec_func(train_stats)
        train_stats['f1'] = f1_func(train_stats)
        train_eval_stats['precision'] = prec_func(train_eval_stats)
        train_eval_stats['recall'] = rec_func(train_eval_stats)
        train_eval_stats['f1'] = f1_func(train_eval_stats)
        val_stats['precision'] = prec_func(val_stats)
        val_stats['recall'] = rec_func(val_stats)
        val_stats['f1'] = f1_func(val_stats)
        train_f1s.append(train_stats['f1'])
        val_f1s.append(val_stats['f1'])
        train_recall.append(train_stats['recall'])
        val_recall.append(val_stats['recall'])
        train_losses.append(train_stats['loss'])
        if writer is not None:
            writer.add_scalars(
                {
                    "Train/epoch_loss": train_stats['loss'],
                    "Train/epoch_top1_err": train_stats['top1_err'],
                    "Train_eval/epoch_top1_err": train_eval_stats['top1_err'],
                    "Train_eval/epoch_accuracy": train_eval_stats['accuracy'],
                    "Train_eval/epoch_precision": train_eval_stats['precision'],
                    "Train_eval/epoch_recall": train_eval_stats['recall'],
                    "Val/epoch_top1_err": val_stats['top1_err'],
                    "Val/epoch_precision": train_eval_stats['precision'],
                    "Val/epoch_recall": train_eval_stats['recall']
                },
                global_step=cur_epoch,
            )
        print(f'{cur_epoch}/{cfg.SOLVER.MAX_EPOCH}: loss {train_stats["loss"]} '
              f'Train F1 {train_stats["f1"]:.2f}, acc {train_stats["accuracy"]:.2f}, '
              f'recall {train_stats["recall"]:.2f}')
        print(f'Val F1 {val_stats["f1"]:.2f}, acc {val_stats["accuracy"]:.2f}, recall {val_stats["recall"]:.2f}')
        torch.save({'model_state': model.state_dict(), 'optimizer_state': optimizer.state_dict(),
                    'train_losses':train_losses,
                    'train_labels': train_labels,
                    'train_y_hats':train_y_hats,
                    'val_labels':val_labels,
                    'val_y_hats':val_y_hats,
                    'scheduler_state': scheduler.state_dict(),
                    'train_stats': train_stats,
                    'train_eval_stats':train_eval_stats,
                    'val_stats': val_stats},
                   os.path.join(exp_dir, 'checkpoints',f'pretrained_epoch{cur_epoch}.pt'))
    if writer is not None:
        writer.close()


In [43]:
class Args:
    def __init__(self,cfg_file):
        self.cfg_file = cfg_file
        self.shard_id = 0
        self.num_shards = 1
        self.init_method = 'tcp://localhost:9999'
        self.opts = None
        
args = Args('/media/shirbar/DATA/codes/SlowFast/configs/FishBase/SLOWFAST_8x8_R50_feed_pretrained.yaml')
cfg = load_config(args)
cfg = assert_and_infer_cfg(cfg)
#cfg.DATA.PATH_TO_DATA_DIR =  '/home/shirbar/data/strike_ds_dist_cropped_equal' 
#cfg.OUTPUT_DIR = '/mnt/slowfast_results/pretrained_sgd_ablation'
train(cfg, pretrained=True)

starting train


Using cache found in /home/shirbar/.cache/torch/hub/facebookresearch_pytorchvideo_main


0/50: loss 3.4249308109283447 Train F1 0.10, acc 0.53, recall 0.05
Val F1 0.50, acc 0.64, recall 0.36
1/50: loss 3.3388326168060303 Train F1 0.54, acc 0.66, recall 0.41
Val F1 0.62, acc 0.55, recall 0.73
2/50: loss 3.1456546783447266 Train F1 0.72, acc 0.78, recall 0.59
Val F1 0.69, acc 0.59, recall 0.91
3/50: loss 2.957608222961426 Train F1 0.80, acc 0.81, recall 0.77
Val F1 0.74, acc 0.68, recall 0.91


KeyboardInterrupt: 