In [None]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import imageio
import math

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision import transforms
from tqdm import tqdm

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

In [None]:
# Character inventory the model can emit: lowercase letters, ', ?, !, digits 1-9 and space.
vocab = list("abcdefghijklmnopqrstuvwxyz'?!123456789 ")

def get_word2idx_idx2word(vocab):
    """Build the token<->index lookup tables for ``vocab``.

    Tokens are numbered from 1 in the order they appear in ``vocab``;
    index 0 is reserved for the ``'<PAD>'`` token.

    Args:
        vocab: Iterable of tokens (single characters in this notebook).

    Returns:
        Tuple ``(word2idx, idx2word)`` of dicts mapping token -> index and
        index -> token respectively.
    """
    word2idx = {}
    idx2word = {}
    for position, token in enumerate(vocab, start=1):
        word2idx[token] = position
        idx2word[position] = token

    # Padding entries are added last, after every real token.
    word2idx['<PAD>'] = 0
    idx2word[0] = '<PAD>'
    return word2idx, idx2word

def char_to_num(texts, word2idx):
    """Encode an iterable of characters as a list of vocabulary indices.

    Characters that are not keys of ``word2idx`` are skipped silently.
    """
    encoded = []
    for char in texts:
        if char not in word2idx:
            continue
        encoded.append(word2idx[char])
    return encoded

def num_to_char(nums, idx2word):
    """Decode an iterable of indices into a list of tokens.

    An index missing from ``idx2word`` yields an empty string, so the output
    always has one entry per input index.
    """
    decoded = []
    for num in nums:
        decoded.append(idx2word.get(num, ''))
    return decoded

In [None]:
class LipDataset(Dataset):
    """Video clips paired with character-level transcripts.

    Each item is a ``(frames, label)`` tuple: ``frames`` is the stack of
    cropped, transformed and z-score-normalised frames of one clip, and
    ``label`` is a 1-D ``torch.long`` tensor of character indices built from
    the ``.align`` file that shares the clip's base name.

    Args:
        data_dir: Directory containing the video files.
        label_dir: Directory containing the ``.align`` transcript files.
        vocab: List of characters the labels may contain.
        word2idx: Character -> index mapping used to encode labels.
        idx2word: Index -> character mapping (stored for convenience).
        transform: Callable applied to every cropped RGB frame
            (``H x W x 3`` uint8 array); should return a tensor.
    """

    # File names skipped when listing ``data_dir``. The original notebook
    # removed this clip unconditionally; the reason is not recorded here.
    EXCLUDED_FILES = frozenset({'sgib8n.mpg'})

    def __init__(self, data_dir: str, label_dir: str, vocab: list, word2idx: dict, idx2word: dict, transform=transforms.ToTensor()) -> None:
        self.data_dir = data_dir
        self.label_dir = label_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> clip mapping (and therefore any seeded split) reproducible.
        # Filtering instead of list.remove() avoids a ValueError when the
        # excluded clip is not present in the directory.
        self.data = sorted(
            name for name in os.listdir(data_dir) if name not in self.EXCLUDED_FILES
        )
        self.label = sorted(os.listdir(label_dir))
        self.vocab = vocab
        self.word2idx = word2idx
        self.idx2word = idx2word

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load clip ``idx`` and its transcript.

        Returns:
            ``(frames, label)`` as described in the class docstring.

        Raises:
            FileNotFoundError: If the video or its ``.align`` file is missing.
            IOError, ValueError: Propagated from :meth:`load_video` when the
                clip cannot be opened or decoded. Errors are no longer
                swallowed: returning ``None`` only deferred the failure to the
                collate function, where the offending sample was unknown.
        """
        file_name = self.data[idx]
        data_path = os.path.join(self.data_dir, file_name)
        label_file = os.path.splitext(file_name)[0] + ".align"
        label_path = os.path.join(self.label_dir, label_file)

        if not os.path.exists(data_path):
            raise FileNotFoundError(f"Data path {data_path} does not exist")
        if not os.path.exists(label_path):
            raise FileNotFoundError(f"Label path {label_path} does not exist")

        frames = self.load_video(data_path)
        label = self.load_alignment(label_path)
        return frames, label

    def load_video(self, path: str) -> torch.Tensor:
        """Decode a clip, crop the mouth region and normalise it.

        Returns:
            Tensor whose first axis indexes frames. With a ``ToTensor``-style
            transform every frame is ``(C, H, W)``, so the result is
            ``(T, C, H, W)``.

        Raises:
            IOError: If OpenCV cannot open ``path``.
            ValueError: If no frame could be decoded.
        """
        cap = cv2.VideoCapture(path)
        frames = []
        try:
            if not cap.isOpened():
                raise IOError(f"Could not open video {path}")

            num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            for _ in range(num_frames):
                ret, frame = cap.read()
                # The reported frame count is only an estimate; stop cleanly
                # when the decoder runs out of frames instead of indexing None.
                if not ret or frame is None:
                    break

                # TODO: Make it dynamic using dlib
                # Take only the lip part of the frame (46 x 140 pixels).
                frame = frame[190:236, 80:220, :]
                # OpenCV decodes to BGR while torchvision transforms (e.g.
                # Grayscale) assume RGB channel order.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

                if self.transform:
                    frame = self.transform(frame)

                frames.append(frame)
        finally:
            # Release the decoder even when reading or transforming fails.
            cap.release()

        if not frames:
            raise ValueError(f"No frames could be read from {path}")

        frames = torch.stack(frames)

        # Per-clip z-score normalisation; the clamp guards against a
        # zero std (e.g. a constant clip) producing NaN/inf values.
        mean = torch.mean(frames)
        std = torch.std(frames).clamp_min(1e-8)
        return (frames - mean) / std

    def load_alignment(self, path: str) -> torch.Tensor:
        """Read an ``.align`` file and encode its words as character indices.

        The third whitespace-separated field of every line is taken as the
        word; ``"sil"`` entries are dropped and the remaining words are joined
        with single spaces before being encoded with ``char_to_num``.
        """
        words = []
        with open(path, "r") as f:
            for line in f:
                fields = line.split()
                # Skip blank or malformed lines rather than raising IndexError.
                if len(fields) < 3:
                    continue
                if fields[2] != "sil":
                    words.append(fields[2])

        # Joining avoids emitting (and then slicing off) a leading space,
        # which would drop a real character if ' ' were not in word2idx.
        token_nums = char_to_num(' '.join(words), self.word2idx)
        return torch.tensor(token_nums, dtype=torch.long)

In [None]:
word2idx, idx2word = get_word2idx_idx2word(vocab)

# Per-frame preprocessing: ndarray -> PIL image -> float tensor -> single channel.
# No fixed mean/std Normalize step is applied here.
frame_pipeline = [
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Grayscale(num_output_channels=1),
]
data_transform = transforms.Compose(frame_pipeline)

data_dir = "/kaggle/input/lipnet-videos/s1"
label_dir = "/kaggle/input/lipnet-videos/alignments/s1"

dataset = LipDataset(
    data_dir,
    label_dir,
    vocab=vocab,
    word2idx=word2idx,
    idx2word=idx2word,
    transform=data_transform,
)

In [None]:
def collate_fn(batch, pad_value=0):
    """Collate variable-length ``(frames, label)`` samples into one batch.

    Clips are right-padded with zeros along their first (frame) axis and
    labels are right-padded with ``pad_value``, each up to the longest item
    in the batch.

    Args:
        batch: Sequence of ``(frames, label)`` tuples.
        pad_value: Fill value used for the label padding.

    Returns:
        ``(frames, labels)`` where both tensors have the batch as axis 0.
    """
    clips, targets = zip(*batch)
    pad_sequence = torch.nn.utils.rnn.pad_sequence

    # pad_sequence pads along the first axis and stacks in one step.
    padded_clips = pad_sequence(list(clips), batch_first=True, padding_value=0)
    padded_targets = pad_sequence(list(targets), batch_first=True, padding_value=pad_value)
    return padded_clips, padded_targets

In [None]:
# Sanity check: load the first sample and print the clip shape, the encoded
# label and the label's shape.
frames, label = dataset[0]
print(frames.shape, label, label.shape)

In [None]:
# Display frame 23 of the sample clip; permute(1,2,0) moves the frame's first
# axis to the end (presumably channels-first -> channels-last for imshow).
plt.imshow(frames[23].permute(1,2,0))

In [None]:
# Map the label indices back to characters and print them as one string.
print(''.join(num_to_char(label.tolist(), idx2word)))

In [None]:
def split_dataset(dataset, val_split=0.2, seed=None):
    """Randomly partition ``dataset`` into train and validation subsets.

    Args:
        dataset: Any object supporting ``len()`` and integer indexing.
        val_split: Fraction of samples assigned to validation, in ``[0, 1]``.
            The validation size is rounded down.
        seed: Optional integer. When given, the split uses its own seeded
            generator and is therefore reproducible regardless of the state
            of the global RNG; when ``None`` the global RNG is used.

    Returns:
        ``[train_subset, val_subset]`` as returned by ``random_split``.

    Raises:
        ValueError: If ``val_split`` is outside ``[0, 1]``. Without this
            check a value such as 1.5 yields a negative train length that
            still sums to ``len(dataset)`` and slips through silently.
    """
    if not 0.0 <= val_split <= 1.0:
        raise ValueError(f"val_split must be in [0, 1], got {val_split}")

    n_val = int(len(dataset) * val_split)
    n_train = len(dataset) - n_val
    lengths = [n_train, n_val]

    if seed is None:
        return torch.utils.data.random_split(dataset, lengths)
    generator = torch.Generator().manual_seed(seed)
    return torch.utils.data.random_split(dataset, lengths, generator=generator)

# The split draws from torch's global RNG, which this notebook only seeds
# later (just before the model is built). Seed it here as well, with the same
# value, so the train/validation partition is identical on every run.
torch.manual_seed(86)
train_dataset, val_dataset = split_dataset(dataset)
print(len(train_dataset), len(val_dataset))

In [None]:
# Settings shared by both loaders; only the training loader shuffles.
loader_options = dict(batch_size=10, collate_fn=collate_fn, pin_memory=True, num_workers=4)

train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_dataloader = DataLoader(val_dataset, shuffle=False, **loader_options)

In [None]:
class LipNet(nn.Module):
    """3D-CNN front end followed by two bidirectional LSTMs and a classifier.

    Input:  ``(B, input_channels, T, 46, 140)`` video tensor.
    Output: ``(B, T, vocab_size + 1)`` un-normalised per-frame class scores.

    Args:
        vocab_size: Number of vocabulary entries; the classifier emits one
            extra class on top of this.
        input_size: Unused; kept so existing call sites keep working.
        hidden_size: Hidden units per LSTM direction.
        dropout: Dropout probability applied after each LSTM.
        input_channels: Channels of the input frames.
    """

    def __init__(self, vocab_size, input_size, hidden_size=128, dropout=0.5, input_channels=1):
        super().__init__()

        # Three conv blocks; every pool halves H and W and leaves T untouched.
        self.conv = nn.Sequential(
            nn.Conv3d(in_channels=input_channels, out_channels=128, kernel_size=(3,3,3), stride=(1, 1, 1), padding=(1, 1, 1)),
            nn.BatchNorm3d(128),
            nn.ReLU(True),
            nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)),

            nn.Conv3d(in_channels=128, out_channels=256, kernel_size=(3,3,3), stride=(1, 1, 1), padding=(1, 1, 1)),
            nn.BatchNorm3d(256),
            nn.ReLU(True),
            nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)),

            nn.Conv3d(in_channels=256, out_channels=75, kernel_size=(3,3,3), stride=(1, 1, 1), padding=(1, 1, 1)),
            nn.BatchNorm3d(75),
            nn.ReLU(True),
            nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)),
        )

        # 75 channels x (46 // 8) x (140 // 8) features per frame after the
        # three spatial poolings of a 46 x 140 crop.
        self.lstm1 = nn.LSTM(input_size=75 * (46 // 8) * (140 // 8), hidden_size=hidden_size,
                             num_layers=1, batch_first=True, bidirectional=True)
        self.dropout1 = nn.Dropout(dropout)

        # The second LSTM consumes both directions of the first one, so its
        # input width follows hidden_size instead of being fixed at 256.
        self.lstm2 = nn.LSTM(input_size=hidden_size * 2, hidden_size=hidden_size,
                             num_layers=1, batch_first=True, bidirectional=True)
        self.dropout2 = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size * 2, vocab_size+1)

        self.initialize_weights()

    def forward(self, x):
        """Map ``(B, C, T, H, W)`` frames to ``(B, T, vocab_size + 1)`` scores."""
        x = self.conv(x)

        # (B, C, T, H, W) -> (B, T, C, H, W) -> (B, T, C*H*W).
        # Both LSTMs are batch_first, so the batch axis must come first;
        # feeding (T, B, F) would make them recur over the samples of the
        # batch instead of over the frames of each clip.
        x = x.permute(0, 2, 1, 3, 4).contiguous()
        x = x.view(x.size(0), x.size(1), -1)

        self.lstm1.flatten_parameters()
        self.lstm2.flatten_parameters()

        x, _ = self.lstm1(x)
        x = self.dropout1(x)

        x, _ = self.lstm2(x)
        x = self.dropout2(x)

        return self.fc(x)

    def initialize_weights(self):
        """Kaiming init for conv/linear weights, orthogonal init for LSTM
        weights, zeros for LSTM and linear biases."""
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.LSTM):
                for name, param in m.named_parameters():
                    if 'weight' in name:
                        nn.init.orthogonal_(param)
                    elif 'bias' in name:
                        nn.init.constant_(param, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

        print('Model weights initialized.')

In [None]:
 # deterministic training
seed = 86
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# torch.backends.cudnn.deterministic = True

# Build the network for single-channel input, then move it to the target device.
model = LipNet(vocab_size=len(word2idx), input_size=75, input_channels=1)
model = model.to(device)

In [None]:
# torch.cuda.empty_cache()
# output = model(frames.permute(0,2,1,3,4))

In [None]:
def ctc_greedy_decode(y_pred, idx2word, blank_index=40, batch_first=True):
    """Best-path CTC decoding: collapse repeats, then drop blanks.

    Args:
        y_pred: Class scores (logits or probabilities) of shape
            ``(batch, time, classes)``; pass ``batch_first=False`` for
            ``(time, batch, classes)``.
        idx2word: Unused; kept for interface compatibility.
        blank_index: Index of the CTC blank class.
        batch_first: Layout of ``y_pred``. Defaults to the batch-first layout
            that this notebook's model output and call sites use.

    Returns:
        List with one list of class indices per batch element. The result is
        also printed, which the training loop relies on for progress output.
    """
    # argmax over raw scores equals argmax over softmax(scores).
    best_path = torch.argmax(y_pred, dim=2)
    if not batch_first:
        best_path = best_path.permute(1, 0)

    decoded_sequences = []
    for path in best_path.tolist():
        decoded_seq = []
        prev_index = None
        for index in path:
            if index != prev_index and index != blank_index:
                decoded_seq.append(index)
            # Track every step, blanks included, so that the same character
            # on both sides of a blank is emitted twice ("a _ a" -> "aa").
            prev_index = index
        decoded_sequences.append(decoded_seq)
    print(decoded_sequences)
    return decoded_sequences

In [None]:
criterion = nn.CTCLoss(blank=40, reduction='mean', zero_infinity=True)
optimizer = optim.Adam(model.parameters(), lr=1e-4)


def lambda_lr(epoch):
    """Learning-rate multiplier: 1.0 for epochs 0-39, then exponential decay."""
    if epoch < 40:
        return 1.0
    return math.exp(-0.1 * (epoch - 39))


# Use LambdaLR with the custom schedule above
lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_lr)

In [None]:

def ctc_loss_fn(y_true, y_pred, ctc_loss):
    """Compute the CTC loss for a batch of padded targets.

    Args:
        y_true: ``(batch, max_target_len)`` integer targets, padded with 0.
        y_pred: ``(batch, time, classes)`` un-normalised scores.
        ctc_loss: Callable taking ``(log_probs, targets, input_lengths,
            target_lengths)``, e.g. an ``nn.CTCLoss`` instance.

    Returns:
        Whatever ``ctc_loss`` returns for the batch.
    """
    num_sequences = y_true.size(0)
    num_steps = y_pred.size(1)

    # Zero marks padding; everything else is a real target symbol.
    non_pad = y_true != 0

    # Time-major log-probabilities, the layout the loss is called with.
    log_probs = y_pred.permute(1, 0, 2).log_softmax(dim=-1)

    # Every sequence is given the full (padded) number of time steps.
    input_lengths = torch.full((num_sequences,), num_steps, dtype=torch.int32)
    target_lengths = torch.tensor(non_pad.sum(dim=1).tolist(), dtype=torch.int32)

    # Boolean indexing concatenates the un-padded targets row by row.
    flat_targets = y_true[non_pad]

    return ctc_loss(log_probs, flat_targets, input_lengths, target_lengths)



def save_checkpoint(model, optimizer, epoch, loss, checkpoint_path, is_best=False):
    """Serialise model and optimizer state to ``checkpoint_path``.

    Args:
        model: Object exposing ``state_dict()``.
        optimizer: Object exposing ``state_dict()``.
        epoch: Epoch number stored alongside the weights.
        loss: Loss value stored alongside the weights.
        checkpoint_path: Destination file; missing parent directories are
            created.
        is_best: When true, a second copy is written next to the checkpoint
            with ``_best`` inserted before the file extension
            (``check.pt`` -> ``check_best.pt``).
    """
    # dirname is '' for a bare file name, and os.makedirs('') raises.
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    checkpoint = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'loss': loss
    }
    torch.save(checkpoint, checkpoint_path)
    print(f'Checkpoint saved at {checkpoint_path}')

    if is_best:
        # splitext only touches the final extension; str.replace('.pt', ...)
        # rewrote every '.pt' in the path (directories included) and left the
        # path unchanged when the file had a different extension.
        root, ext = os.path.splitext(checkpoint_path)
        best_path = f'{root}_best{ext}'
        torch.save(checkpoint, best_path)
        print(f'Best model saved at {best_path}')

In [None]:
def train(model, dataloader, criterion, optimizer, device, lr_scheduler,print_every=40):
    """Run one training epoch and step the learning-rate scheduler once.

    Args:
        model: Network called with frames permuted to ``(B, C, T, H, W)``.
        dataloader: Yields ``(frames, labels)`` batches; frames are
            ``(B, T, C, H, W)`` before the permute.
        criterion: CTC loss passed through to ``ctc_loss_fn``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        lr_scheduler: Scheduler stepped once at the end of the epoch.
        print_every: Print decoded predictions and the loss every this many
            batches.

    Returns:
        Mean per-batch loss over the epoch.
    """
    model.train()

    total_loss = 0.0

    for i, (frames, labels) in enumerate(dataloader):
        # non_blocking lets the copy overlap with compute when the loader
        # pins memory; it is a no-op otherwise.
        frames = frames.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()

        output = model(frames.permute(0,2,1,3,4))

        loss = ctc_loss_fn(labels, output, criterion)

        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss

        # The batch is not copied back to the CPU: those copies were
        # discarded immediately and only cost a device-to-host transfer.

        if (i+1) % print_every == 0:
            # detach() is enough for logging; no graph-attached clone needed.
            ctc_greedy_decode(output.detach(), idx2word)
            print(f'Batch {i+1}/{len(dataloader)} - Loss: {batch_loss}')

    lr_scheduler.step()
    print(f'Learning rate: {lr_scheduler.get_last_lr()}')

    return total_loss / len(dataloader)


def evaluate(model, dataloader, criterion, device, print_every=10):
    """Compute the mean per-batch loss over ``dataloader`` without gradients.

    Args:
        model: Network called with frames permuted to ``(B, C, T, H, W)``.
        dataloader: Yields ``(frames, labels)`` batches.
        criterion: CTC loss passed through to ``ctc_loss_fn``.
        device: Device the batches are moved to.
        print_every: Print decoded predictions and the loss every this many
            batches.

    Returns:
        Mean per-batch loss.
    """
    model.eval()

    total_loss = 0.0

    with torch.inference_mode():
        for i, (frames, labels) in enumerate(dataloader):
            frames = frames.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            output = model(frames.permute(0,2,1,3,4))

            loss = ctc_loss_fn(labels, output, criterion)

            batch_loss = loss.item()
            total_loss += batch_loss

            # No copy of the batch back to the CPU and no clone of the
            # output: both were discarded right after being made.
            if (i+1) % print_every == 0:
                ctc_greedy_decode(output, idx2word)
                print(f'Batch {i+1}/{len(dataloader)} - Loss: {batch_loss}')

    return total_loss / len(dataloader)

In [None]:
def train_model(model, train_dataloader, val_dataloader, criterion, optimizer, lr_scheduler, num_epochs, device, checkpoint_path='/kaggle/working/check.pt'):
    """Train for ``num_epochs`` epochs, validating and checkpointing each one.

    A checkpoint is written after every epoch; whenever the validation loss
    improves on the best value seen so far, ``save_checkpoint`` is also asked
    to write its "best" copy.

    Returns:
        Dict with per-epoch loss lists under the keys ``'train'`` and ``'val'``.
    """
    loss_history = {'train': [], 'val': []}
    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        # One-based for display only; the stored epoch index stays zero-based.
        print(f'Epoch {epoch + 1}/{num_epochs}')

        train_loss = train(model, train_dataloader, criterion, optimizer, device, lr_scheduler)
        loss_history['train'].append(train_loss)

        val_loss = evaluate(model, val_dataloader, criterion, device)
        loss_history['val'].append(val_loss)

        print(f'Train Loss: {train_loss} - Val Loss: {val_loss}')

        # Previously is_best was always False, so the best model was never kept.
        is_best = val_loss < best_val_loss
        if is_best:
            best_val_loss = val_loss

        save_checkpoint(model, optimizer, epoch, val_loss, checkpoint_path, is_best)

    return loss_history

# Launch training; loss_history receives the dict of per-epoch train/val
# losses returned by train_model.
num_epochs = 100
loss_history = train_model(model, train_dataloader, val_dataloader, criterion, optimizer, lr_scheduler , num_epochs, device)

In [None]:
# Put the model in eval mode and take sample 10 of the training split for a
# qualitative check.
model.eval()
frames, labels = train_dataset[10]

In [None]:
# Add a batch axis, reorder to (B, C, T, H, W) and run a forward pass.
# no_grad avoids building an autograd graph that is never used here.
with torch.no_grad():
    output = model(frames.unsqueeze(0).permute(0,2,1,3,4).to(device))

In [None]:
# output = output.permute(1, 0, 2)
# Print the output shape, then greedy-decode it into lists of class indices
# (ctc_greedy_decode also prints the decoded lists itself).
print('Output shape:', output.shape)
ctcts = ctc_greedy_decode(output, idx2word)


In [None]:
# Print each decoded index sequence as a list of characters.
for i in ctcts:
    print(num_to_char(i, idx2word))