In [None]:
import json
from pathlib import Path
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from torchvision.io import read_video


class VideoNormalize(torch.nn.Module):
    def __init__(self, mean, std):
        super().__init__()
        self.mean = torch.tensor(mean).view(3, 1, 1, 1)
        self.std = torch.tensor(std).view(3, 1, 1, 1)

    def forward(self, video):
        return (video - self.mean) / self.std


class VideoDataset(Dataset):
    def __init__(self, root_dir, split, transform=None, clip_duration=5.0, target_fps=30):
        self.root_dir = Path(root_dir) / split
        self.transform = transform
        self.clip_duration = clip_duration
        self.target_fps = target_fps
        self.target_frames = int(clip_duration * target_fps)
        self.video_files = []
        self.labels = {}

        # Load labels from labels.json
        labels_path = self.root_dir / 'labels.json'
        with open(labels_path, 'r') as f:
            self.labels = json.load(f)

        # Collect video file paths
        self.video_files = list(self.root_dir.glob('*.avi'))

    def __len__(self):
        return len(self.video_files)

    def __getitem__(self, idx):
        video_path = str(self.video_files[idx])
        video_name = self.video_files[idx].name
        label = self.labels[video_name]['graininess']

        # Read video using torchvision
        video, audio, meta = read_video(video_path, pts_unit='sec')

        # Extract frame rate from metadata
        fps = meta['video_fps']

        # Calculate the number of frames to sample based on the clip duration and video's fps
        num_frames_to_sample = min(int(self.clip_duration * fps), video.shape[0])

        # Sample frames
        if num_frames_to_sample < video.shape[0]:
            start_idx = torch.randint(0, video.shape[0] - num_frames_to_sample + 1, (1,)).item()
            video = video[start_idx:start_idx + num_frames_to_sample]

        # Resample to target FPS
        if fps != self.target_fps:
            indices = torch.linspace(0, video.shape[0] - 1, self.target_frames).long()
            video = video[indices]

        # Ensure we have exactly target_frames
        if video.shape[0] < self.target_frames:
            video = torch.cat([video, video[-1].unsqueeze(0).repeat(self.target_frames - video.shape[0], 1, 1, 1)])
        elif video.shape[0] > self.target_frames:
            video = video[:self.target_frames]

        # Change from (T, H, W, C) to (C, T, H, W)
        video = video.permute(3, 0, 1, 2)

        if self.transform:
            video = self.transform(video)

        return video, torch.tensor(label, dtype=torch.long)


# Example usage
transform = transforms.Compose([
    transforms.Lambda(lambda x: x.float() / 255.0),  # Normalize to [0, 1]
    VideoNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [7]:
# Path to the dataset
data_root = Path('/Users/sv7/Projects/mtl-video-classification/data/graininess_100_balanced_subset_split')

train_dataset = VideoDataset(data_root,
                             split='train',
                             transform=transform)

test_dataset = VideoDataset(data_root,
                            split='test',
                            transform=transform)

val_dataset = VideoDataset(data_root,
                           split='val',
                           transform=transform)

In [9]:
# DataLoader example
from torch.utils.data import DataLoader
import os

batch_size = 4
num_workers = os.cpu_count()

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=num_workers)

test_loader = DataLoader(test_dataset,
                         batch_size=batch_size,
                         shuffle=False,
                         num_workers=num_workers)

val_loader = DataLoader(val_dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=num_workers)

In [26]:
import json
from pathlib import Path

train_data_path = Path('/Users/sv7/Projects/mtl-video-classification/data/graininess_100_balanced_subset_split/train')
labels_path = train_data_path / 'labels.json'

# /Users/sv7/Projects/mtl-video-classification/data/graininess_100_balanced_subset_split/train/labels.json
video_files = list(train_data_path.glob('*.avi'))
with open(labels_path) as f:
    labels = json.load(f)

video_path = str(video_files[5])
video, audio, meta = read_video(video_path, pts_unit='sec')

In [32]:
clip_duration = 5.0

# Extract frame rate from metadata
fps = meta['video_fps']

# Calculate the number of frames to sample based on the clip duration and video's fps
num_frames_to_sample = min(int(clip_duration * fps), video.shape[0])

In [37]:
num_frames_to_sample

300

In [41]:
# Cell 1: Import necessary libraries
import os
import json
import random
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_video
from transformers import VivitImageProcessor, VivitForVideoClassification, TrainingArguments, Trainer


# Cell 2: Set random seed for reproducibility
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


set_seed(42)


# Cell 3: Define custom dataset class
# Cell 3: Define custom dataset class
class VideoDataset(Dataset):
    def __init__(self, data_dir, split, processor, max_frames=32):
        self.data_dir = os.path.join(data_dir, split)
        self.processor = processor
        self.max_frames = max_frames
        
        with open(os.path.join(self.data_dir, 'labels.json'), 'r') as f:
            self.labels = json.load(f)
        
        self.video_files = list(self.labels.keys())
    
    def __len__(self):
        return len(self.video_files)
    
    def __getitem__(self, idx):
        video_file = self.video_files[idx]
        video_path = os.path.join(self.data_dir, video_file)
        
        # Read video
        video, _, _ = read_video(video_path, pts_unit='sec')
        
        # Sample frames
        num_frames = video.shape[0]
        if num_frames > self.max_frames:
            start = random.randint(0, num_frames - self.max_frames)
            video = video[start:start+self.max_frames]
        else:
            video = video[:self.max_frames]
        
        # Ensure we have 3 channels (RGB)
        if video.shape[-1] != 3:
            video = video.expand(-1, -1, -1, 3)
        
        # Convert to numpy array and ensure correct shape
        video = video.numpy()
        
        # Ensure the video has the correct shape (num_frames, height, width, channels)
        if video.shape[1] == 3:  # If channels are in the second dimension
            video = np.transpose(video, (0, 2, 3, 1))
        
        # Process frames
        pixel_values = self.processor(
            list(video),
            return_tensors="pt",
            do_resize=True,
            size={"shortest_edge": 224},  # Adjust this size as needed
            do_center_crop=True,
            crop_size={"height": 224, "width": 224},  # Adjust this size as needed
        ).pixel_values
        
        # Get label
        label = self.labels[video_file]['graininess']
        
        return {'pixel_values': pixel_values.squeeze(), 'label': torch.tensor(label)}


# Cell 4: Initialize ViViT model and processor
model_name = "google/vivit-b-16x2-kinetics400"
processor = VivitImageProcessor.from_pretrained(model_name,
                                                ignore_mismatched_sizes=True)
model = VivitForVideoClassification.from_pretrained(model_name, num_labels=2,
                                                    ignore_mismatched_sizes=True)

# Cell 5: Prepare datasets and dataloaders
data_dir = "/Users/sv7/Projects/mtl-video-classification/data/graininess_100_balanced_subset_split"
batch_size = 4

train_dataset = VideoDataset(data_dir, 'train', processor)
val_dataset = VideoDataset(data_dir, 'val', processor)
test_dataset = VideoDataset(data_dir, 'test', processor)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

# Cell 6: Define training arguments
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)

# Cell 7: Define Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

# Cell 8: Train the model
trainer.train()

# Cell 9: Evaluate on test set
test_results = trainer.evaluate(test_dataset)
print(test_results)

# Cell 10: Save the model
model.save_pretrained("./vivit_graininess_classifier")
processor.save_pretrained("./vivit_graininess_classifier")

Some weights of VivitForVideoClassification were not initialized from the model checkpoint at google/vivit-b-16x2-kinetics400 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([400, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([400]) in the checkpoint and torch.Size([2]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  return torch.tensor(value)


RuntimeError: MPS backend out of memory (MPS allocated: 17.77 GB, other allocations: 40.66 MB, max allowed: 18.13 GB). Tried to allocate 1.76 GB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

In [None]:

# Cell 11: Inference example
def predict_video(video_path):
    video, _, _ = read_video(video_path, pts_unit='sec')
    inputs = processor(list(video.permute(0, 2, 3, 1).numpy()), return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        predicted_class = logits.argmax(-1).item()

    return "Grainy" if predicted_class == 1 else "Not Grainy"


# Example usage
example_video_path = "path/to/example/video.avi"
prediction = predict_video(example_video_path)
print(f"The video is predicted to be: {prediction}")