In [1]:
import numpy as np 
import torch
from torch import nn
import torch.nn.functional as F
import sys
import time

sys.path.append('/home/lugeon/eeg_project/scripts')
from training.representation.cl_transforms import identity, reverse, add_gaussian_noise, permute_frames, jitter_channels
from training.representation.losses import ContrastiveLoss

  from .autonotebook import tqdm as notebook_tqdm


In [13]:
def permute_frames2(array: np.array):
    """ Permute the frames independently for each video.

    Args:
        array (np.array): Input array of shape (batch_size x n_frames x n_channels x img_dim x img_dim)
    """
    print(array.shape)
    batch_size, n_frames, _, _, _ = array.shape

    # one permutation per video
    ix = np.arange(batch_size * n_frames).reshape(batch_size, n_frames)
    np.apply_along_axis(np.random.shuffle, 1, ix)
    print(ix.shape)
    
    # reshape and permute frames
    per = array.reshape(batch_size * n_frames, -1)[ix.flatten()]
    print(per.shape)
    
    return per.reshape(array.shape)

In [16]:
loss = ContrastiveLoss()

batch_size = 512
n_frames = 10
n_channels = 5
image_dim = 32
encoding_size = 1024

video = np.random.rand(batch_size, n_frames, n_channels, image_dim, image_dim)

t = time.time()
identity(video)
print(f'Identity transform time: {time.time() - t:.3f} seconds')

t = time.time()
reverse(video)
print(f'Reverse transform time: {time.time() - t:.3f} seconds')

t = time.time()
add_gaussian_noise(video)
print(f'Add noise transform time: {time.time() - t:.3f} seconds')

t = time.time()
permute_frames2(video)
print(f'Permute transform time: {time.time() - t:.3f} seconds')

t = time.time()
jitter_channels(video)
print(f'Jitter channels transform time: {time.time() - t:.3f} seconds')

first = torch.rand(batch_size, encoding_size)
second = torch.rand(batch_size, encoding_size)

t = time.time()
loss(first, second)
print(f'Compute loss time: {time.time() - t:.3f} seconds')

Identity transform time: 0.000 seconds
Reverse transform time: 0.077 seconds
Add noise transform time: 0.660 seconds
(512, 10, 5, 32, 32)
(512, 10)
(5120, 5120)
Permute transform time: 0.077 seconds
Jitter channels transform time: 0.099 seconds
Compute loss time: 0.068 seconds


In [3]:
permute_frames(video).shape

(256, 20, 5, 32, 32)
(256, 20)
(5120, 5120)


(256, 20, 5, 32, 32)