In [None]:
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import glob
import torch
from PIL import Image
import re

In [None]:
import matplotlib.pyplot as plt
# Notebook-wide plotting defaults: wide figures and a larger base font.
font = dict(family='DejaVu Sans', weight='normal', size=18)
plt.rcParams.update({'figure.figsize': (20, 8)})
plt.rc('font', **font)

In [None]:
%matplotlib inline

In [None]:
# Preprocessing pipeline handed to MyDataset and applied to each frame
# before the frames of a window are stacked.
# NOTE(review): this assignment rebinds the name `transform`, hiding the
# `transform` module imported from skimage in the first cell.
transform = transforms.Compose([
    # torchvision's Resize takes the target size as (height, width)
    transforms.Resize((576, 704)),
    transforms.ToTensor(),
    # presumably the ImageNet channel statistics - TODO confirm against the model used
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [None]:
def parse_img_num(img):
    """Return the frame number embedded in a frame file name, zero-padded.

    Parameters
    ----------
    img : str
        File name or path containing ``frame<digits>.png``.

    Returns
    -------
    str
        The frame number formatted with ``06d`` (e.g. ``"000012"``).

    Raises
    ------
    ValueError
        If ``img`` does not contain ``frame<digits>.png``.
    """
    # The dot is escaped so that only a literal ".png" suffix matches;
    # an unescaped "." would accept any character in that position.
    m = re.search(r"frame(\d+)\.png", img)
    if m is None:
        # Carry the offending name in the exception instead of printing it.
        raise ValueError(f"Couldn't parse number from {img!r}")
    return f"{int(m.group(1)):06d}"

In [None]:
class MySampler(torch.utils.data.Sampler):
    """Randomly samples window start positions that stay inside one video.

    Parameters
    ----------
    end_idx : sequence or 1-D tensor of int
        Cumulative frame counts: video ``i`` occupies positions
        ``end_idx[i]`` up to (not including) ``end_idx[i + 1]``.
    seq_length : int
        Number of consecutive frames that make up one window.

    Attributes
    ----------
    indices : torch.Tensor
        Every position from which a full window fits inside its video.
    """
    def __init__(self, end_idx, seq_length):
        indices = []
        for i in range(len(end_idx) - 1):
            start = int(end_idx[i])
            # Last position from which seq_length frames still fit in video i.
            last_start = int(end_idx[i + 1]) - seq_length
            if last_start < start:
                # The video is shorter than one window, so it has no valid
                # start. Emitting arange(last_start, start) here would yield
                # positions that belong to the previous video (or negative ones).
                continue
            # arange excludes its stop value, hence the + 1 to keep last_start.
            indices.append(torch.arange(start, last_start + 1))
        # torch.cat rejects an empty list; fall back to an empty index tensor.
        self.indices = torch.cat(indices) if indices else torch.empty(0, dtype=torch.long)

    def __iter__(self):
        # A fresh random permutation of the valid starts on every pass.
        indices = self.indices[torch.randperm(len(self.indices))]
        return iter(indices.tolist())

    def __len__(self):
        return len(self.indices)

In [None]:
class MyDataset(Dataset):
    """Dataset of fixed-length windows of consecutive frames.

    Item ``index`` is the stack of ``seq_length`` transformed images taken
    from ``image_paths[index]`` onwards, paired with the label stored next
    to the first frame of the window.

    Parameters
    ----------
    image_paths : sequence of (path, label) pairs
        Frame file paths with their label.
    seq_length : int
        Number of consecutive frames per item.
    transform : callable or None
        Applied to every loaded image. ``torch.stack`` is called on the
        results, so it must produce tensors.
    length : int
        Value reported by ``len()``.
    """
    def __init__(self, image_paths, seq_length, transform, length): #csv_file, 
        self.image_paths = image_paths
        self.seq_length = seq_length
        self.transform = transform
        self.length = length
        
    def __getitem__(self, index):
        """Return ``(x, y)``: the stacked frames of the window and its label."""
        start = index
        end = index + self.seq_length
        images = []
        # Explicit indexing (not slicing) so a window that runs past the end
        # of image_paths raises IndexError instead of being silently shortened.
        for i in range(start, end):
            image_path = self.image_paths[i][0]
            # Image.open is lazy and keeps the file open; the context manager
            # closes it, and convert() reads the pixel data first. Converting
            # to RGB also gives every frame the same channel layout.
            with Image.open(image_path) as img:
                image = img.convert('RGB')
            if self.transform:
                image = self.transform(image)
            images.append(image)
        x = torch.stack(images)
        # The label is wrapped in a list, so y gains a leading dimension of 1.
        y = torch.tensor([self.image_paths[start][1]], dtype=torch.long)
        
        return x, y
    
    def __len__(self):
        return self.length

In [None]:
root_dir = '/media/scratch/astamoulakatos/nsea_video_jpegs/'
# One sub-directory of root_dir per class.
# is_dir must be *called*: the bare bound method is always truthy, so the
# uncalled form let plain files through as well.
# NOTE(review): os.scandir yields entries in arbitrary order, while later cells
# assign labels by position in class_paths - confirm the order matches class_names.
with os.scandir(root_dir) as entries:
    class_paths = [d.path for d in entries if d.is_dir()]

In [None]:
class_paths

In [None]:
one_hot_classes = [[1,0,1,0,0],[1,0,0,0,1],[1,0,0,0,0],[1,0,0,1,0],[0,1,0,0,0]]

In [None]:
df = pd.read_csv('../train-valid-splits-video-scratch/valid.csv')
df.head()

In [None]:
class_names = ['exp_and','exp_fs','exp','exp_fj','bur']
one_hot_classes = [[1,0,1,0,0],[1,0,0,0,1],[1,0,0,0,0],[1,0,0,1,0],[0,1,0,0,0]]
# Videos listed in the split CSV; a set makes each membership test O(1)
# instead of scanning the whole column for every directory.
split_videos = set(df.videos.values)
class_image_paths = []
end_idx = []
for c, class_path in enumerate(class_paths):
    for d in os.scandir(class_path):
        # is_dir must be called: the bare bound method is always truthy.
        if not (d.is_dir() and d.path in split_videos):
            continue
        paths = sorted(glob.glob(os.path.join(d.path, '*.png')))
        # Pair every frame path with the label vector of its class
        paths = [(p, one_hot_classes[c]) for p in paths]
        class_image_paths.extend(paths)
        # Per-video frame count; turned into cumulative offsets in a later cell
        end_idx.append(len(paths))

In [None]:
end_idx[:10]

In [None]:
len(paths)

In [None]:
len(class_image_paths)

In [None]:
len(end_idx)

In [None]:
sum(end_idx)

In [None]:
end_idx = [0, *end_idx]
end_idx = torch.cumsum(torch.tensor(end_idx), 0)

In [None]:
end_idx[:10]

In [None]:
# class_image_paths = []
# end_idx = []
# for c, class_path in enumerate(class_paths):
#     for d in os.scandir(class_path):
#         if d.is_dir:
#             paths = sorted(glob.glob(os.path.join(d.path, '*.png')))
#             # Add class idx to paths
#             paths = [(p, one_hot_classes[c]) for p in paths]
#             class_image_paths.extend(paths)
#             end_idx.extend([len(paths)])
            
# end_idx = [0, *end_idx]
# end_idx = torch.cumsum(torch.tensor(end_idx), 0)

In [None]:
class_image_paths = []
end_idx = []
# Videos listed in the split CSV; a set makes each membership test O(1)
# instead of scanning the whole column for every directory.
split_videos = set(df.videos.values)
for c, class_path in enumerate(class_paths):
    for d in os.scandir(class_path):
        # is_dir must be called: the bare bound method is always truthy.
        if not (d.is_dir() and d.path in split_videos):
            continue
        paths = sorted(glob.glob(os.path.join(d.path, '*.png')))
        # Pair every frame path with the label vector of its class
        paths = [(p, one_hot_classes[c]) for p in paths]
        # Keep only videos with at least 16 frames
        # (the same value as seq_length, which is set in a later cell).
        if len(paths)>=16:
            class_image_paths.extend(paths)
            end_idx.append(len(paths))

# Prefix sums: video i occupies class_image_paths[end_idx[i]:end_idx[i + 1]]
end_idx = [0, *end_idx]
end_idx = torch.cumsum(torch.tensor(end_idx), 0)

In [None]:
end_idx

In [None]:
seq_length = 16

In [None]:
# Scratch check of the window-start computation used by the sampler.
indices = []
for i in range(len(end_idx) - 1):
    start = end_idx[i]
    print(start)
    # Last position from which seq_length frames still fit inside video i
    end = end_idx[i + 1] - seq_length
    print(end)
    # arange excludes its stop value, so use end + 1 to keep that last window
    indices.append(torch.arange(start, end + 1))
    # Show only the range just added; printing the whole accumulated list
    # on every iteration makes the output grow quadratically.
    print(indices[-1])
indices = torch.cat(indices)

In [None]:
indices[10:20]

In [None]:
indices = indices[torch.randperm(len(indices))]
l = iter(indices.tolist())
    

In [None]:
indices

In [None]:
next(iter(l))

In [None]:
class MySampler_test(torch.utils.data.Sampler):
    """Debug variant of the sampler that prints each video's start/end bounds.

    Parameters
    ----------
    end_idx : sequence or 1-D tensor of int
        Cumulative frame counts: video ``i`` occupies positions
        ``end_idx[i]`` up to (not including) ``end_idx[i + 1]``.
    seq_length : int
        Number of consecutive frames that make up one window.
    """
    def __init__(self, end_idx, seq_length):
        indices = []
        for i in range(len(end_idx) - 1):
            start = end_idx[i]
            print(start)
            # Last position from which seq_length frames still fit in video i
            end = end_idx[i + 1] - seq_length
            print(end)
            if end < start:
                # Video shorter than one window: no valid start, and
                # torch.arange would raise on a descending range.
                continue
            # arange excludes its stop value, hence the + 1 to keep `end`
            indices.append(torch.arange(start, end + 1))
        # torch.cat rejects an empty list; fall back to an empty index tensor
        indices = torch.cat(indices) if indices else torch.empty(0, dtype=torch.long)
        self.indices = indices
        
    def __iter__(self):
        # A fresh random permutation of the valid starts on every pass
        indices = self.indices[torch.randperm(len(self.indices))]
        return iter(indices.tolist())
    
    def __len__(self):
        return len(self.indices) 

In [None]:
# NOTE(review): `sampler` is first assigned in the cells that follow, so on a
# fresh kernel (Restart Kernel -> Run All) this cell raises NameError.
sampler.indices

In [None]:
sampler = MySampler_test(end_idx, seq_length)

In [None]:
sampler = MySampler(end_idx, seq_length)

In [None]:
len(sampler)

In [None]:
dataset = MyDataset(
    image_paths=class_image_paths,
    seq_length=seq_length,
    transform=transform,
    length=len(sampler))


In [None]:
loader = DataLoader(
    dataset,
    batch_size=1,
    sampler=sampler,
    drop_last=True,
    num_workers=0
)

In [None]:
# for data, target in loader:
#     print(data.shape)
#     print(target.shape)

In [None]:
t = next(iter(loader))

In [None]:
import torchvision

In [None]:
class_names = ['exp_and','exp_fs','exp','exp_fj','bur']

In [None]:
def imshow(inp, title=None):
    """Undo the channel normalisation of an image tensor and display it.

    Parameters
    ----------
    inp : torch.Tensor
        Image tensor whose first axis is moved to the end before plotting.
    title : str, optional
        Figure title; omitted when ``None``.
    """
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])
    # Move the first axis to the end so the 3-element statistics broadcast over it
    pixels = inp.numpy().transpose((1, 2, 0))
    # Invert (x - mean) / std, then clamp into the displayable [0, 1] range
    pixels = np.clip(channel_std * pixels + channel_mean, 0, 1)
    plt.figure(figsize=(30,30))
    plt.imshow(pixels)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated


In [None]:
# Get one batch from the loader
inputs, classes = next(iter(loader))
# Drop the batch dimension (the loader uses batch_size=1) so that the
# frames of the sequence become the images of the grid
inputs = inputs.squeeze(dim = 0)

# Make a grid from batch
out = torchvision.utils.make_grid(inputs)

# Start with no title: if no label vector matches, `title` would otherwise be
# undefined (NameError) or silently keep the value from a previous run.
title = None
label = classes[0][0].tolist()
for i, f in enumerate(one_hot_classes):
    if label == f:
        title = class_names[i]
        break

imshow(out, title=title)

In [None]:
classes

In [None]:
inputs