In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to them are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch

from video_dataset.video import VideoFromVideoFile

from tas_helpers.visualization import SegmentationVisualizer

from bouldering_video_segmentation.models import VideoSegmentMlp
from bouldering_video_segmentation.extractors import ResNet3DFeatureExtractor

  from .autonotebook import tqdm as notebook_tqdm


<div class="alert alert-warning">

**Information:** The constants below (`VIDEO_PATH` and `VIDEO_SEGMENT_MLP_MODEL_WEIGHTS_PATH`) must be set before running the notebook. You can find an example video and the model weights in the GitHub repository.

</div>

In [None]:
# --- Paths to fill in before running the notebook ---
# Path of the video file to segment.
VIDEO_PATH = "..."
# Path of the saved weights for the VideoSegmentMlp model.
VIDEO_SEGMENT_MLP_MODEL_WEIGHTS_PATH = "..."

# --- Fixed settings ---
# Value passed as `segment_size` when iterating over the video segments.
SEGMENT_SIZE = 32
# Output size of the classifier; the provided weights were trained on 5 classes.
NUMBER_OF_CLASSES = 5

In [None]:
# Split VIDEO_PATH into the directory, base name and extension that are later
# handed to VideoFromVideoFile.
video_dir_path = VIDEO_PATH.rpartition("/")[0]
video_file_name = VIDEO_PATH.rpartition("/")[2]

# Split on the LAST dot only, so file names that contain extra dots
# (e.g. "climb.v2.mp4") still yield exactly one name and one extension.
video_name, extension_separator, video_extension = video_file_name.rpartition(".")

# Fail early with an explicit message when the path has no usable name or
# extension (this also catches the unset "..." placeholder).
if not (video_name and video_extension):
    raise ValueError(
        f"VIDEO_PATH must point to a file with a name and an extension, got {VIDEO_PATH!r}"
    )

In [None]:
# Feature extractor whose output shape defines the classifier's input size.
extractor = ResNet3DFeatureExtractor()

model = VideoSegmentMlp(
    input_size=extractor.get_features_shape(),
    # NOTE: the model has been trained on 5 classes, thus the output size is 5 and can't be changed when used with the provided weights
    output_size=NUMBER_OF_CLASSES
)

# `load_state_dict` copies the weights into `model` in place and returns a
# report of missing/unexpected keys, NOT the module. Rebinding `model` to its
# result would make `model(features)` fail later, so the result is discarded.
model.load_state_dict(torch.load(VIDEO_SEGMENT_MLP_MODEL_WEIGHTS_PATH))

# The notebook only runs inference: switch layers such as dropout/batch-norm
# (if the model has any) to evaluation behaviour.
model.eval()

# Handle on the input video, built from the directory, name and extension
# parsed out of VIDEO_PATH.
video = VideoFromVideoFile(
    videos_dir_path=video_dir_path,
    id=video_name,
    video_extension=video_extension
)

In [None]:
# One model output per video segment, in the order the segments are yielded.
predictions = []

# Inference only: disable autograd so the stored predictions do not keep a
# computation graph (and the intermediate activations) alive in memory.
with torch.no_grad():
    for segment in video.get_segments(segment_size=SEGMENT_SIZE):
        # Turn the raw segment into the feature representation the MLP expects.
        features = extractor.transform_and_extract(segment)

        prediction = model(features)

        predictions.append(prediction)

In [None]:
# Build the visualizer and render it.
# NOTE(review): this passes only the `segment` and `prediction` left over from
# the last loop iteration, while the accumulated `predictions` list is never
# used — confirm against the SegmentationVisualizer API that this is intended.
visualizer = SegmentationVisualizer(segment, prediction)
visualizer.show()