In [1]:
import os
import math

from tqdm.auto import tqdm, trange
import torch
import torchvision
import numpy as np
from PIL import Image
from torchsummary import summary
import matplotlib.pyplot as plt
from torch.utils.data import random_split
from torch.utils.data import Dataset, DataLoader
from transformers import VideoMAEImageProcessor, AutoModel, AutoConfig

In [2]:
configs = {
    "ROOT_DIR": "data/IPAD/R01/training/frames",
    "IMAGE_TENSOR_SHAPE": (3, 224, 224),
    "MAE_BACKBONE": "OpenGVLab/VideoMAEv2-Large",
    "SEQ_LEN": 16,
    "BATCH_SIZE": 12,
    "DATASET_SHUFFLE": True,
    "EPOCHS": 10
}

TRANSFORM = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor()
])

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

In [3]:
class CNNFrameReconstructor(torch.nn.Module):
    def __init__(self, embed_dim=1024, feature_dim=512, out_channels=3, img_size=224):
        super(CNNFrameReconstructor, self).__init__()
        self.img_size = img_size
        self.feature_dim = feature_dim
        self.out_channels = out_channels

        self.fc = torch.nn.Linear(embed_dim, feature_dim * (img_size // 16) * (img_size // 16))

        self.decoder = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(feature_dim, feature_dim // 2, kernel_size=4, stride=2, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.ConvTranspose2d(feature_dim // 2, feature_dim // 4, kernel_size=4, stride=2, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.ConvTranspose2d(feature_dim // 4, feature_dim // 8, kernel_size=4, stride=2, padding=1),
            torch.nn.ReLU(inplace=True),
            torch.nn.ConvTranspose2d(feature_dim // 8, out_channels, kernel_size=4, stride=2, padding=1),
            torch.nn.Sigmoid()
        )

    def forward(self, x):
        """
        x: (B, embed_dim) - Single embedding vector per batch item
        Output: (B, C, H, W) - Single frame per batch item
        """
        B, D = x.shape
        x = self.fc(x)
        x = x.view(B, self.feature_dim, self.img_size // 16, self.img_size // 16)
        x = self.decoder(x)
        return x


class FramePredictor(torch.nn.Module):
    def __init__(self):
        super(FramePredictor, self).__init__()

        config = AutoConfig.from_pretrained(configs["MAE_BACKBONE"], trust_remote_code=True)
        self.processor = VideoMAEImageProcessor.from_pretrained(configs["MAE_BACKBONE"])    
        self.video_mae = AutoModel.from_pretrained(configs["MAE_BACKBONE"], config=config, trust_remote_code=True)

        self.reconstructor = CNNFrameReconstructor()
    
    def forward(self, x):
        videos = [list(sequence) for sequence in x]
        x = self.processor(videos, return_tensors="pt")
        x['pixel_values'] = x['pixel_values'].permute(0, 2, 1, 3, 4).to(device)
        x = self.video_mae(**x)
        x = self.reconstructor(x)
        return x

In [4]:

def load_one_sequence(sequence_dir: str):
    frame_files = [file for file in os.listdir(sequence_dir) if file.endswith("jpg")]
    sequence = torch.zeros(len(frame_files), *configs["IMAGE_TENSOR_SHAPE"])
    for i, frame_file in enumerate(frame_files):
        frame_path = os.path.join(sequence_dir, frame_file)
        image = Image.open(frame_path).convert("RGB")
        sequence[i] = TRANSFORM(image)
    return sequence


def load_sequences(root_dir: str):
    sequences = []
    sequence_dirs = os.listdir(root_dir)
    for sequence_dir in sequence_dirs:
        sequence_dir_path = os.path.join(root_dir, sequence_dir)
        sequence = load_one_sequence(sequence_dir_path)
        sequences.append(sequence)
    return sequences


def visualize_sequence(sequence: torch.Tensor):
    """
    Expected input shape: (T, C, H, W)
    """
    images = sequence.permute(0, 2, 3, 1)
    grid_shape = math.ceil(math.sqrt(images.shape[0]))
    for i, image in enumerate(images):
        plt.subplot(grid_shape, grid_shape, i+1)
        plt.imshow(image)
        plt.axis('off')
    plt.show()

In [5]:
sequences = load_sequences(configs["ROOT_DIR"])
print(f"Number of sequences loaded: {len(sequences)}")

Number of sequences loaded: 34


In [6]:
class SequenceDataset(Dataset):
    def __init__(self, tensors, seq_len=10):
        self.tensors = tensors
        self.seq_len = seq_len
        
        self.cumulative_lengths = [0]
        for tensor in tensors:
            valid_indices = max(0, tensor.shape[0] - seq_len)
            self.cumulative_lengths.append(self.cumulative_lengths[-1] + valid_indices)
    
    def __len__(self):
        return self.cumulative_lengths[-1]
    
    def __getitem__(self, idx):
        tensor_idx = 0
        while idx >= self.cumulative_lengths[tensor_idx + 1]:
            tensor_idx += 1
        
        start_frame = idx - self.cumulative_lengths[tensor_idx]
        tensor = self.tensors[tensor_idx]
        input_sequence = tensor[start_frame:start_frame + self.seq_len]
        target_frame = tensor[start_frame + self.seq_len]
        
        return input_sequence, target_frame

In [7]:
generator = torch.Generator().manual_seed(0)

dataset = SequenceDataset(
    tensors=sequences,
    seq_len=configs["SEQ_LEN"]
)

train_ds, valid_ds = random_split(dataset, [0.9, 0.1], generator=generator)

train_dl = DataLoader(train_ds, batch_size=configs["BATCH_SIZE"], shuffle=configs["DATASET_SHUFFLE"])
valid_dl = DataLoader(valid_ds, batch_size=configs["BATCH_SIZE"], shuffle=False)

print(f"Number of training batches: {len(train_dl)}")
print(f"Number of validation batches: {len(valid_dl)}")

Number of training batches: 545
Number of validation batches: 61


In [11]:
def train_one_epoch(model):
    model.to(device)
    model.train()

    loss_func = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    total_loss = 0.0
    with tqdm(train_dl, desc="Training") as pb:
        for i, batch in enumerate(pb):
            sequence, next_frame = batch
            sequence, next_frame = sequence.to(device), next_frame.to(device)
    
            optimizer.zero_grad()
            predicted_next_frame = model(sequence)
            loss = loss_func(predicted_next_frame, next_frame)
            loss.backward()
            optimizer.step()
    
            total_loss += loss.item()
            pb.set_postfix(loss=total_loss/(i+1))
    return total_loss / len(train_dl)

def validate_model(model):
    model.eval()
    model.to(device)

    loss_func = torch.nn.MSELoss()
    total_loss = 0.0
    with tqdm(valid_dl, desc="Validation") as pb:
        for i, batch in enumerate(pb):
            sequence, next_frame = batch
            sequence, next_frame = sequence.to(device), next_frame.to(device)

            with torch.no_grad():
                predicted_next_frame = model(sequence)
                
            loss = loss_func(predicted_next_frame, next_frame)
            total_loss += loss.item()
            pb.set_postfix(loss=total_loss/(i+1))
    return total_loss / len(valid_dl)

In [None]:
model = FramePredictor()
for epoch in trange(configs["EPOCHS"]):
    print(f"Using device {device}")
    print(f"Epoch {epoch}")
    train_loss = train_one_epoch(model)
    valid_loss = validate_model(model)
    print(f"Epoch Train loss: {train_loss}")
    print(f"Epoch Validation loss: {valid_loss}")

In [12]:
valid_loss = validate_model(model)

Validation:   0%|          | 0/61 [00:00<?, ?it/s]