In [108]:
import torchvision
from torchvision.io import read_video
import lightning
import torchmetrics
import timm
import os
import glob
import numpy as np
import av
import cv2
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import ast
import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import WeightedRandomSampler

In [89]:
def set_seed(seed = 42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    L.seed_everything(seed)

In [90]:
set_seed(42)

Seed set to 42


In [54]:

train_dir = "data_train_short\data_train_short"

video_train_paths = glob.glob(os.path.join(train_dir, "*", "*.mp4"))
print(video_train_paths)

['data_train_short\\data_train_short\\-220020068_456239859\\-220020068_456239859.mp4', 'data_train_short\\data_train_short\\-220020068_456241671\\-220020068_456241671.mp4', 'data_train_short\\data_train_short\\-220020068_456241672\\-220020068_456241672.mp4', 'data_train_short\\data_train_short\\-220020068_456241673\\-220020068_456241673.mp4', 'data_train_short\\data_train_short\\-220020068_456241682\\-220020068_456241682.mp4', 'data_train_short\\data_train_short\\-220020068_456241755\\-220020068_456241755.mp4', 'data_train_short\\data_train_short\\-220020068_456241756\\-220020068_456241756.mp4', 'data_train_short\\data_train_short\\-220020068_456241758\\-220020068_456241758.mp4', 'data_train_short\\data_train_short\\-220020068_456241844\\-220020068_456241844.mp4', 'data_train_short\\data_train_short\\-220020068_456241845\\-220020068_456241845.mp4', 'data_train_short\\data_train_short\\-220020068_456241846\\-220020068_456241846.mp4', 'data_train_short\\data_train_short\\-220020068_45624

In [55]:
def read_video_frame_by_frame(path):
    container = av.open(path)
    for frame in container.decode(video=0):
        yield frame.to_ndarray(format="rgb24")

In [56]:
def read_video_safe(path, target_size=(112, 112)):
    cap = cv2.VideoCapture(path)
    frames = []
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break 
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(frame).resize(target_size)
            frames.append(np.array(frame))

    except Exception as e:
        print(f"error {path}")
        if len(frames) > 0:
            return frames
        else:
            return None
    finally:
        cap.release()
    if len(frames)>0:
        return frames
    else:
        return None



In [57]:
valid_videos = []
for path in video_train_paths[0:2]:
    video = read_video_safe(path)
    if video is not None:
        valid_videos.append(video)
video_train = valid_videos

In [58]:
def time_to_frames(str):
    h, m, s = map(int, str.split(':'))
    return (h*3600 + m*60 + s)*24

In [70]:
with open(r"C:\UCHYOBA\VK\labels_json\labels_json\train_labels.json", "r", encoding='utf-8') as f:
    content = f.read()
    data = ast.literal_eval(content)
    result = {key: (min(time_to_frames(val["start"]), time_to_frames(val["end"])),
                    max(time_to_frames(val["start"]), time_to_frames(val["end"])))
    for key, val in data.items()
}

sorted_data = {k: result[k] for k in sorted(result)}
print(sorted_data)

        

{'-220020068_456239859': (360, 672), '-220020068_456241671': (4560, 5448), '-220020068_456241672': (4680, 5544), '-220020068_456241673': (3288, 4128), '-220020068_456241682': (3144, 3720), '-220020068_456241755': (4512, 4608), '-220020068_456241756': (1584, 1680), '-220020068_456241758': (1488, 2832), '-220020068_456241844': (3936, 4032), '-220020068_456241845': (1608, 1680), '-220020068_456241846': (1800, 1896), '-220020068_456241847': (9744, 9840), '-220020068_456241849': (4392, 5736), '-220020068_456241850': (1464, 1536), '-220020068_456241851': (360, 1560), '-220020068_456248657': (3576, 3600), '-220020068_456249667': (144, 240), '-220020068_456249692': (144, 240), '-220020068_456249693': (144, 240), '-220020068_456249716': (144, 216), '-220020068_456249719': (144, 240), '-220020068_456249720': (144, 216), '-220020068_456249732': (144, 240), '-220020068_456249733': (144, 216), '-220020068_456249739': (144, 240), '-220020068_456252055': (264, 11040), '-220020068_456253855': (11808, 

In [96]:
class IntroDataset(Dataset):
    def __init__(self, videos, labels_dict, filenames, clip_len=16, fps=24, transform=None):
        self.clip_len = clip_len
        self.fps = fps
        self.transform = transform

        self.samples = []

        for idx, video_frames in enumerate(videos):
            name = filenames[idx]
            if name not in labels_dict:
                continue
            label_range = labels_dict[name]
            frames = np.array(video_frames) 
            for start_idx in range(0, len(frames) - clip_len + 1):
                clip = frames[start_idx:start_idx + clip_len]
                overlap = sum([
                    1 for i in range(start_idx, start_idx + clip_len)
                    if label_range[0] <= i <= label_range[1]
                ])
                label = 1.0 if overlap / clip_len >= 0.5 else 0.0

                self.samples.append((clip, label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        clip, label = self.samples[idx]
        clip = clip.transpose(3, 0, 1, 2)
        clip = torch.from_numpy(clip).float()
        if self.transform:
            clip = self.transform(clip)
        return clip, torch.tensor(label, dtype=torch.float32)


In [None]:
filenames = [os.path.splitext(os.path.basename(path))[0] for path in video_train_paths[0:2]]


dataset = IntroDataset(
    videos=video_train,
    labels_dict=sorted_data,
    filenames=filenames,
    clip_len=16,
    fps=24
)

<__main__.IntroDataset object at 0x000002BD4AB6AC10>


In [109]:

intro_frames = sum((sorted_data[name][1]-sorted_data[name][0]) for name in filenames)
print(intro_frames)
all_frames = sum(len(video) for video in video_train)
print(all_frames)
weights = intro_frames/all_frames, 1-intro_frames/all_frames 
print(weights)
sampler = WeightedRandomSampler(weights=weights, num_samples=len(dataset))

1200
64409
(0.018630936670341103, 0.9813690633296589)


In [114]:
train_loader = torch.utils.data.DataLoader(dataset, batch_size=4, sampler=sampler)

In [115]:
model = torchvision.models.video.r3d_18
model.fc = nn.Linear(512, 1)