In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
import glob
import cv2
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from tqdm import tqdm
import time

# Set up DataLoader

### Define the text-to-integer encoding function

In [2]:
# Vocabulary: a space followed by the 26 uppercase letters.
alphabet = " ABCDEFGHIJKLMNOPQRSTUVWXYZ"
# Index 0 is reserved for the CTC blank symbol, so real characters start at 1.
blank_idx = 0
# Forward table: character -> label (1-based, in alphabet order).
char2idx = dict(zip(alphabet, range(1, len(alphabet) + 1)))
# Reverse table: label -> character.
idx2char = dict(enumerate(alphabet, start=1))

def text_to_int_sequence(text, char2idx):
    """Encode ``text`` as a list of integer labels.

    The text is upper-cased first. Characters without an entry in
    ``char2idx`` are dropped silently (they are not mapped to any
    placeholder label).

    :param text: string to encode
    :param char2idx: mapping from a single character to its integer label
    :return: list of integer labels, in the order the characters appear
    """
    return [char2idx[symbol] for symbol in text.upper() if symbol in char2idx]

### Define dataset class

In [3]:
class BBCNewsVideoDataset(Dataset):
    """
    A dataset that reads (video, transcript) pairs from 'pretrain' or 'main' directories.

    Each ID folder contains matching mp4/txt files named e.g. "00001.mp4" and "00001.txt".
    The transcript is parsed from the "Text:" line of the .txt file and the frames
    are decoded from the corresponding .mp4 with OpenCV.

    Each item is a tuple ``(video_tensor, transcript)`` where ``video_tensor`` is a
    float tensor of shape (T, C, H, W) and ``transcript`` is a plain string.
    """
    def __init__(self,
                 root_dir,        # e.g. "/kaggle/input/my_data"
                 mode='pretrain', # or 'main'
                 transform=None,  # optional transforms on frames
                 max_frames=75):
        """
        :param root_dir: path to the folder containing 'pretrain' and 'main' subdirs
        :param mode: which subdir to read from ('pretrain' or 'main')
        :param transform: optional callable applied to every RGB frame (H, W, 3 array);
                          it must return a tensor so the frames can be stacked
        :param max_frames: maximum number of frames decoded per clip (None = no limit).
                           NOTE: only the video is truncated; the transcript is returned
                           in full, so long clips can end up with more target characters
                           than frames.
        """
        super().__init__()
        self.root_dir = os.path.join(root_dir, mode)
        self.transform = transform
        self.max_frames = max_frames

        # Gather all mp4 files one level below the mode directory:
        # root_dir/mode/<id>/<clip>.mp4
        self.video_paths = sorted(glob.glob(os.path.join(self.root_dir, '*', '*.mp4')))

        # Keep only the videos that have a transcript sitting next to them.
        self.data = []
        for vp in self.video_paths:
            # Swap only the file extension. A plain str.replace('.mp4', '.txt') would
            # also rewrite any '.mp4' occurring in a directory name and silently drop
            # otherwise valid pairs.
            txt_path = os.path.splitext(vp)[0] + '.txt'
            if os.path.exists(txt_path):
                self.data.append((vp, txt_path))

    def __len__(self):
        """Number of (video, transcript) pairs found on disk."""
        return len(self.data)

    def parse_transcript(self, txt_path):
        """
        Reads the .txt file and returns the text of the first line starting with 'Text:'.

        A trailing 'Conf:' field on that same line is cut off. All other lines
        (Conf, WORD timings, ...) are ignored.

        :param txt_path: path to the transcript file (read as UTF-8)
        :return: the transcript string, or "" when the file has no 'Text:' line
        """
        prefix = "Text:"
        with open(txt_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line.startswith(prefix):
                    # Drop just the leading marker, then anything from "Conf:" onwards.
                    return line[len(prefix):].split("Conf:")[0].strip()
        return ""

    def read_video(self, video_path):
        """
        Reads frames from an mp4 using OpenCV (cv2) into a list of RGB frames (H, W, 3).

        Decoding stops at end of stream, or once ``self.max_frames`` frames have been
        collected when that limit is set.

        :param video_path: path to the video file
        :return: list of RGB frames; empty if nothing could be decoded
        """
        frames = []
        cap = cv2.VideoCapture(video_path)
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                # OpenCV decodes to BGR; the rest of the pipeline expects RGB.
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                if self.max_frames is not None and len(frames) >= self.max_frames:
                    break
        finally:
            # Always free the decoder handle, even if a frame conversion fails.
            cap.release()
        return frames

    def __getitem__(self, idx):
        """
        :param idx: index into the list of discovered (video, transcript) pairs
        :return: (video_tensor, transcript) with video_tensor of shape (T, C, H, W), float
        :raises RuntimeError: if no frame could be decoded from the video
        """
        video_path, txt_path = self.data[idx]

        # Parse transcript
        transcript = self.parse_transcript(txt_path)

        # Read frames: list of arrays, shape (H, W, 3)
        frames = self.read_video(video_path)
        if not frames:
            # torch.stack on an empty list fails with an opaque message; name the file instead.
            raise RuntimeError(f"No frames could be decoded from video: {video_path}")

        if self.transform:
            frames = [self.transform(img) for img in frames]
        else:
            # No transform: convert each frame to a tensor laid out as (C, H, W).
            # Pixel values are left in their original range (no rescaling here).
            frames = [torch.from_numpy(img).permute(2, 0, 1) for img in frames]

        # Stack along a new leading time axis => (T, C, H, W)
        video_tensor = torch.stack(frames, dim=0).float()

        return video_tensor, transcript

# Collate function: pads the videos in a batch to the same number of frames and converts transcripts to integer labels for CTC
def collate_fn_ctc(batch):
    """
    Collate a list of (video_tensor, transcript_str) samples into CTC-ready tensors.

    Steps:
        1) Order samples by descending number of frames
        2) Convert each transcript to integer labels via text_to_int_sequence/char2idx
        3) Zero-pad every video along the time dimension to the longest clip
        4) Concatenate all label sequences into one flat 1D tensor

    :param batch: list of (video_tensor, transcript_str); video_tensor is (T, C, H, W)
                  and all clips must agree on (C, H, W). The list itself is not modified.
    :return: (frames_tensor, concat_targets, input_lengths, target_lengths)
             frames_tensor   -> (B, max_T, C, H, W)
             concat_targets  -> 1D long tensor, all targets back to back
             input_lengths   -> (B,) long tensor, unpadded frame count per clip
             target_lengths  -> (B,) long tensor, label count per transcript
    """
    # sorted() builds a new list, so the caller's batch is not reordered in place.
    ordered = sorted(batch, key=lambda sample: sample[0].shape[0], reverse=True)

    frames_list = [video for video, _ in ordered]
    targets_list = [
        torch.tensor(text_to_int_sequence(txt, char2idx), dtype=torch.long)
        for _, txt in ordered
    ]

    input_lengths = torch.tensor([video.shape[0] for video in frames_list], dtype=torch.long)
    target_lengths = torch.tensor([labels.numel() for labels in targets_list], dtype=torch.long)

    # pad_sequence zero-pads along the first (time) axis up to the longest clip
    # and stacks the result => (B, max_T, C, H, W).
    frames_tensor = torch.nn.utils.rnn.pad_sequence(frames_list, batch_first=True)
    concat_targets = torch.cat(targets_list, dim=0)

    return frames_tensor, concat_targets, input_lengths, target_lengths

### Test the dataset class

In [4]:
# Smoke test for the dataset + collate function: build the 'pretrain' dataset,
# inspect one sample, then pull a single batch through a DataLoader.
# Directory expected to contain the 'pretrain' and 'main' subfolders.
root_dir = "mvlrs_v1"

# Per-frame preprocessing: array -> PIL image -> resize -> tensor.
# The target size (50, 100) matches the default img_h/img_w of the model below.
transform = T.Compose([
    T.ToPILImage(),
    T.Resize((50,100)),  # (H, W)
    T.ToTensor()
])

# Create a dataset for the 'pretrain' directory
pretrain_dataset = BBCNewsVideoDataset(root_dir, mode='pretrain', transform=transform)
print("Pretrain dataset size:", len(pretrain_dataset))

# Example item: a (T, C, H, W) video tensor and its transcript string
sample_video_tensor, sample_transcript = pretrain_dataset[0]
print("sample_video_tensor shape:", sample_video_tensor.shape)
print("sample_transcript:", sample_transcript)

# Create a DataLoader; collate_fn_ctc pads the clips and flattens the targets
pretrain_loader = DataLoader(pretrain_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn_ctc)

# Example iteration: look at the first batch only
# NOTE(review): in the recorded output the transcript lengths (86, 130) exceed the
# 75 frames kept per clip; CTC presumably cannot align such samples — confirm how
# 'pretrain' clips should be cropped before training on them.
for batch in pretrain_loader:
    # frames_tensor is a single padded tensor (B, max_T, C, H, W), not a list of clips
    frames_tensor, concat_targets, input_lengths, target_lengths = batch
    # len() of the padded tensor is its first dimension, i.e. the batch size
    print("Batch video frames:", len(frames_tensor))
    # total number of target labels across the whole batch
    print("Batch transcript char length:", len(concat_targets))
    print("Batch frame nums:", input_lengths)
    print("Batch transcript lengths:", target_lengths)
    break

Pretrain dataset size: 96318
sample_video_tensor shape: torch.Size([75, 3, 50, 100])
sample_transcript: THESE DAYS WHEN YOU'RE COOKING CHIPS AT HOME THE TRADITIONAL CHIP PAN OFTEN STAYS ON THE SHELF IN FAVOUR OF A BAKING TRAY AND A BAG OF FROZEN OVEN
Batch video frames: 2
Batch transcript char length: 216
Batch frame nums: tensor([75, 75])
Batch transcript lengths: tensor([ 86, 130])


# Define Model Architecture

In [5]:
class LipNetPyTorch(nn.Module):
    """
    LipNet-style network: three Conv3d/MaxPool3d/Dropout blocks, a 2-layer
    bidirectional GRU over the time axis, and a per-frame linear classifier.

    Input:  (batch, img_c, frames, img_h, img_w)
    Output: (batch, frames, output_size) raw logits (no softmax applied).
    """
    def __init__(self,
                 img_c=3,
                 img_w=100,
                 img_h=50,
                 frames_n=75,
                 output_size=28,        # number of characters + 1 for blank
                 absolute_max_string_len=32  # not strictly needed in PyTorch model??
                 ):
        """
        :param img_c: number of input channels
        :param img_w: input frame width; used to size the GRU input
        :param img_h: input frame height; used to size the GRU input
        :param frames_n: accepted for API compatibility; not used (the conv stack
                         preserves the time dimension, so any clip length works)
        :param output_size: number of output classes (characters + CTC blank)
        :param absolute_max_string_len: accepted for API compatibility; not used
        """
        super(LipNetPyTorch, self).__init__()

        # 3D Conv block #1
        self.conv1 = nn.Conv3d(in_channels=img_c,
                               out_channels=32,
                               kernel_size=(3, 5, 5),
                               stride=(1, 2, 2),
                               padding=(1, 2, 2))
        self.pool1 = nn.MaxPool3d(kernel_size=(1, 2, 2),
                                  stride=(1, 2, 2))
        self.drop1 = nn.Dropout(0.5)

        # 3D Conv block #2
        self.conv2 = nn.Conv3d(in_channels=32,
                               out_channels=64,
                               kernel_size=(3, 5, 5),
                               stride=(1, 1, 1),
                               padding=(1, 2, 2))
        self.pool2 = nn.MaxPool3d(kernel_size=(1, 2, 2),
                                  stride=(1, 2, 2))
        self.drop2 = nn.Dropout(0.5)

        # 3D Conv block #3
        self.conv3 = nn.Conv3d(in_channels=64,
                               out_channels=96,
                               kernel_size=(3, 3, 3),
                               stride=(1, 1, 1),
                               padding=(1, 1, 1))
        self.pool3 = nn.MaxPool3d(kernel_size=(1, 2, 2),
                                  stride=(1, 2, 2))
        self.drop3 = nn.Dropout(0.5)

        # The time kernels use stride 1 with padding 1 and the pools use a time
        # kernel of 1, so the number of frames is unchanged; only H and W shrink.
        # For the default 50x100 input:
        #   conv1+pool1 => W 100 -> 50 -> 25, H 50 -> 25 -> 12
        #   conv2+pool2 => W 25 -> 12,        H 12 -> 6
        #   conv3+pool3 => W 12 -> 6,         H 6 -> 3
        #   => (N, 96, T, 3, 6) => 96*3*6 = 1728 features per frame
        # The size is derived from img_c/img_h/img_w rather than hard-coded so that
        # other input resolutions work too.
        gru_input_size = self._infer_feature_size(img_c, img_h, img_w)

        self.gru_hidden_size = 256
        self.num_gru_layers = 2

        # Bi-directional GRU
        self.gru = nn.GRU(input_size=gru_input_size,
                          hidden_size=self.gru_hidden_size,
                          num_layers=self.num_gru_layers,
                          batch_first=True,
                          bidirectional=True)

        # Final linear layer to project onto output_size
        # Because it's bidirectional, the GRU output size is 2 * gru_hidden_size
        self.fc = nn.Linear(self.gru_hidden_size * 2, output_size)

    def _infer_feature_size(self, img_c, img_h, img_w):
        """Return channels*height*width left per frame after the conv/pool stack."""
        # A single all-zero frame is pushed through the shape-changing layers only
        # (ReLU/Dropout do not alter shapes); no gradients or random numbers are involved.
        with torch.no_grad():
            probe = torch.zeros(1, img_c, 1, img_h, img_w)
            probe = self.pool1(self.conv1(probe))
            probe = self.pool2(self.conv2(probe))
            probe = self.pool3(self.conv3(probe))
        _, channels, _, height, width = probe.shape
        return channels * height * width

    def forward(self, x):
        """
        :param x: tensor of shape (batch, channels, frames, height, width),
                  e.g. (batch, 3, 75, 50, 100) with the default settings.
                  If your data is (batch, frames, height, width, channels),
                  permute it before calling forward: x.permute(0,4,1,2,3)
        :return: logits of shape (batch, frames, output_size). For CTC, apply
                 F.log_softmax(logits, dim=2) (and move time first) afterwards.
        """
        # (1) 3D conv/pool #1
        x = self.conv1(x)  # => (batch, 32, frames, ~H/2, ~W/2)
        x = F.relu(x)
        x = self.pool1(x)  # => (batch, 32, frames, ~H/4, ~W/4)
        x = self.drop1(x)

        # (2) 3D conv/pool #2
        x = self.conv2(x)  # => (batch, 64, frames, ~H/4, ~W/4)
        x = F.relu(x)
        x = self.pool2(x)  # => (batch, 64, frames, ~H/8, ~W/8)
        x = self.drop2(x)

        # (3) 3D conv/pool #3
        x = self.conv3(x)  # => (batch, 96, frames, ~H/8, ~W/8)
        x = F.relu(x)
        x = self.pool3(x)  # => (batch, 96, frames, ~H/16, ~W/16)
        x = self.drop3(x)

        # Flatten channel and spatial dims per frame, keeping the time dim:
        # (batch, 96, T, H', W') -> (batch, T, 96*H'*W')
        b, c, t, h, w = x.size()
        x = x.permute(0, 2, 1, 3, 4)
        x = x.reshape(b, t, c*h*w)

        # (4) Bi-GRU over time: (batch, time, features) -> (batch, time, 2*gru_hidden_size)
        x, _ = self.gru(x)

        # (5) FC => (batch, time, output_size)
        logits = self.fc(x)

        return logits

In [6]:
# Quick shape check: run one random batch through a freshly built model.
# This only verifies tensor shapes; the weights are untrained.
model = LipNetPyTorch()
# dummy input: batch=2, channels=3, frames=75, H=50, W=100
dummy_input = torch.randn(2, 3, 75, 50, 100)
out = model(dummy_input)
print("Output shape:", out.shape)
# Should be (2, 75, 28): one 28-way logit vector per frame
# NOTE: `model` is rebound to a new instance in the training cell further down.

Output shape: torch.Size([2, 75, 28])


# Define Train Loop

In [7]:
#   frames   => (batch, 3, T, 50, 100)   # videos
#   targets  => 1D Tensor of all targets concatenated
#   input_lengths => lengths of each sequence in frames
#   target_lengths => lengths of each transcription

def train_step(frames, targets, input_lengths, target_lengths,
               model=None, ctc_loss=None, optimizer=None, device=None):
    """
    Run one optimisation step (forward, CTC loss, backward, parameter update).

    :param frames: video batch shaped as the model expects, (batch, C, T, H, W)
    :param targets: 1D tensor with all target label sequences concatenated
    :param input_lengths: number of valid frames per sample
    :param target_lengths: number of labels per sample
    :param model: network to train; when omitted, the notebook-level `model` is used
    :param ctc_loss: loss callable taking (log_probs, targets, input_lengths,
                     target_lengths); when omitted, the notebook-level `ctc_loss` is used
    :param optimizer: optimizer to step; when omitted, the notebook-level `optimizer` is used
    :param device: device for frames/targets; when omitted, the device of the model's
                   first parameter is used (CUDA if the model has no parameters)
    :return: the loss of this step as a Python float
    :raises NameError: if a collaborator is omitted and no notebook-level fallback exists
    """
    # Backward compatibility: callers that pass only the four tensors keep working
    # by falling back to the module-level objects this function used to read.
    notebook_scope = globals()
    collaborators = {"model": model, "ctc_loss": ctc_loss, "optimizer": optimizer}
    for name, value in collaborators.items():
        if value is None:
            if name not in notebook_scope:
                raise NameError(f"name '{name}' is not defined")
            collaborators[name] = notebook_scope[name]
    model = collaborators["model"]
    ctc_loss = collaborators["ctc_loss"]
    optimizer = collaborators["optimizer"]

    if device is None:
        # Follow the model instead of assuming CUDA, so CPU training works as well.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cuda")

    frames = frames.to(device)
    targets = targets.to(device)

    optimizer.zero_grad()

    # (batch, time, output_size)
    logits = model(frames)
    # CTC loss expects time first => (T, N, C), as log-probabilities over classes
    log_probs = F.log_softmax(logits.permute(1, 0, 2), dim=2)

    loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
    loss.backward()
    optimizer.step()
    return loss.item()

In [8]:
def train_model(model, train_loader, ctc_loss, optimizer, num_epochs=10, device='cuda'):
    """
    Train `model` with CTC loss for `num_epochs` passes over `train_loader`.

    The optimisation step is performed here with the objects passed in, so the
    `model`, `ctc_loss`, `optimizer` and `device` arguments are the ones actually
    used (no reliance on notebook-level globals or on CUDA being available).

    Args:
        model: your LipNetPyTorch model; takes (B, C, T, H, W), returns (B, T, classes)
        train_loader: iterable yielding (frames, targets, input_lengths, target_lengths)
                      with frames shaped (B, T, C, H, W)
        ctc_loss: nn.CTCLoss (or similar)
        optimizer: e.g. torch.optim.Adam(model.parameters())
        num_epochs: total epochs to train
        device: 'cuda' or 'cpu'

    Returns:
        None. The per-epoch wall time and mean batch loss are printed.
    """
    model.to(device)
    model.train()

    for epoch in range(num_epochs):
        epoch_start = time.time()
        total_loss = 0.0
        num_batches = 0

        for frames, targets, input_lengths, target_lengths in tqdm(
            train_loader, desc=f"Epoch {epoch+1}", unit="batch"
        ):
            # The loader yields (B, T, C, H, W); the model needs (B, C, T, H, W).
            frames = frames.permute(0, 2, 1, 3, 4).to(device)
            targets = targets.to(device)

            optimizer.zero_grad()

            logits = model(frames)  # => (B, T, classes)
            # CTC loss expects time first => (T, B, classes), as log-probabilities
            log_probs = F.log_softmax(logits.permute(1, 0, 2), dim=2)

            loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        # Count batches as they are seen: avoids dividing by zero on an empty
        # loader and does not require the loader to support len().
        avg_loss = total_loss / num_batches if num_batches else 0.0
        epoch_time = time.time() - epoch_start
        print(f"Epoch [{epoch+1}/{num_epochs}] took {epoch_time:.2f} seconds, Loss: {avg_loss:.4f}")


# Train

In [9]:
# Build a fresh model; wrap it in DataParallel when more than one GPU is visible.
model = LipNetPyTorch()
if torch.cuda.device_count() > 1:
    print("Using DataParallel on", torch.cuda.device_count(), "GPUs!")
    model = torch.nn.DataParallel(model)
# NOTE(review): this cell requires a CUDA device (.cuda() here, device='cuda' below).
model = model.cuda()
# blank=0 matches blank_idx in the vocabulary cell. zero_infinity=True presumably
# zeroes the loss of samples CTC cannot align (e.g. more labels than frames), which
# would hide such samples rather than fail on them — confirm this is intended.
ctc_loss = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
# Create the optimizer after the model has been moved/wrapped so it tracks the final parameters.
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Set up dataloader
root_dir = "mvlrs_v1"

# Same per-frame preprocessing as in the dataset test cell: resize to 50x100 (H, W).
transform = T.Compose([
    T.ToPILImage(),
    T.Resize((50,100)),  # (H, W)
    T.ToTensor()
])

main_dataset = BBCNewsVideoDataset(root_dir, mode='main', transform=transform)
print("Main dataset size:", len(main_dataset))
# NOTE(review): no num_workers is given, so video decoding presumably runs in the
# main process; the recorded run shows ~1.75 s/batch — consider worker processes.
main_loader = DataLoader(main_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn_ctc)

# Single epoch over the 'main' split (the recorded run was interrupted manually).
train_model(model, main_loader, ctc_loss, optimizer, num_epochs=1, device='cuda')

Main dataset size: 48165


Epoch 1:   4%|▍         | 65/1506 [01:53<42:06,  1.75s/batch]


KeyboardInterrupt: 