In [None]:
import os

from google.colab import drive

drive.mount('/content/drive')

# Actually verify the mount before reporting success: the success message
# should only appear when the Drive folder is really there.
drive_root = '/content/drive/MyDrive'
if not os.path.isdir(drive_root):
    raise FileNotFoundError(f"Drive mount point not found: {drive_root}")
print("Drive mounted successfully!")
print("Contents of drive:", os.listdir(drive_root))

Mounted at /content/drive
Drive mounted successfully!
Contents of drive: ['processed']


In [None]:
!pip install torch torchvision torchaudio
!pip install opencv-python
!pip install tqdm

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torchvision.models as models
import cv2
import json
from tqdm import tqdm

In [None]:
# Root folder of the video dataset on Drive; later cells read its "train" subfolder.
VIDEO_BASE_DIR = "/content/drive/MyDrive/processed"  # Update this path
# JSON label file; a later cell loads it as a list of dicts and reads each item's "gloss" key.
LABEL_MAP_PATH = "/content/WLASL_v0.3.json"

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Fixed clip length: every video is represented by exactly this many frames.
MAX_FRAMES = 16

# Echo the active configuration so it is recorded in the notebook output.
print(f"Using device: {DEVICE}")
print(f"Video base directory: {VIDEO_BASE_DIR}")
print(f"Label map path: {LABEL_MAP_PATH}")

Using device: cuda
Video base directory: /content/drive/MyDrive/processed
Label map path: /content/WLASL_v0.3.json


In [None]:
# Report whether the video directory is reachable.
if os.path.exists(VIDEO_BASE_DIR):
    print(f"✅ Video base directory found: {VIDEO_BASE_DIR}")
else:
    print(f"❌ ERROR: Video base directory not found: {VIDEO_BASE_DIR}")
    print("Please update the VIDEO_BASE_DIR path in Section 3")

# Report whether the label map file is reachable.
if os.path.exists(LABEL_MAP_PATH):
    print(f"✅ Label map found: {LABEL_MAP_PATH}")
else:
    print(f"❌ ERROR: Label map not found: {LABEL_MAP_PATH}")
    print("Please update the LABEL_MAP_PATH in Section 3")

# Build the gloss <-> index lookup tables from the label file.
try:
    with open(LABEL_MAP_PATH, "r") as f:
        label_data = json.load(f)  # a list of dicts, one per entry

    # Distinct gloss names in alphabetical order; a gloss's position is its index.
    glosses = sorted({item["gloss"] for item in label_data})
    label_to_idx = {}
    idx_to_label = {}
    for idx, gloss in enumerate(glosses):
        label_to_idx[gloss] = idx
        idx_to_label[idx] = gloss
    print(f"✅ Label map loaded successfully. Number of classes: {len(label_to_idx)}")
except Exception as e:
    print(f"❌ Error loading label map: {e}")

✅ Video base directory found: /content/drive/MyDrive/processed
✅ Label map found: /content/WLASL_v0.3.json
✅ Label map loaded successfully. Number of classes: 2000


In [None]:
class SignDataset(Dataset):
    """Video dataset read from a directory tree of ``.mp4`` files.

    Every ``.mp4`` found under ``split_dir`` becomes one sample. The name of
    the folder that directly contains the file is parsed as the integer class
    label (e.g. ``<split_dir>/96/clip.mp4`` -> label ``96``).

    Args:
        split_dir: Directory that is walked recursively for ``.mp4`` files.
        transform: Optional callable applied to each RGB frame (a 224x224
            array from OpenCV). When omitted, frames are converted to float
            ``(C, H, W)`` tensors scaled to ``[0, 1]``.

    Raises:
        RuntimeError: If no ``.mp4`` file is found under ``split_dir``.
        ValueError: If a video's parent folder name is not an integer.
    """

    def __init__(self, split_dir, transform=None):
        self.video_paths = []
        self.labels = []
        for root, _, files in os.walk(split_dir):
            for f in files:
                if not f.lower().endswith(".mp4"):
                    continue
                full_path = os.path.join(root, f)
                class_id = os.path.basename(os.path.dirname(full_path))
                try:
                    label = int(class_id)  # class_id like "0", "1", etc.
                except ValueError as exc:
                    raise ValueError(
                        f"Expected an integer class folder for {full_path}, got '{class_id}'"
                    ) from exc
                self.video_paths.append(full_path)
                self.labels.append(label)
        if not self.video_paths:
            raise RuntimeError(f"No videos found in {split_dir}")
        self.transform = transform
        print(f"Dataset initialized with {len(self.video_paths)} videos")

    def __len__(self):
        """Return the number of videos discovered at construction time."""
        return len(self.video_paths)

    def __getitem__(self, idx):
        """Return ``(frames, label)`` for the video at position ``idx``."""
        video_path = self.video_paths[idx]
        label = self.labels[idx]
        frames = self.load_video(video_path)
        return frames, label

    def load_video(self, path):
        """Decode the first ``MAX_FRAMES`` frames of a video into one tensor.

        Frames are resized to 224x224 and converted from BGR to RGB. Videos
        shorter than ``MAX_FRAMES`` are padded by repeating the last frame.

        Returns:
            The stacked frames with the first two axes swapped, i.e.
            ``(C, T, H, W)`` when each processed frame is ``(C, H, W)``.

        Raises:
            RuntimeError: If the video cannot be opened or yields no frames.
        """
        cap = cv2.VideoCapture(path)
        frames = []
        try:
            if not cap.isOpened():
                raise RuntimeError(f"Could not open video: {path}")
            # Only the first MAX_FRAMES frames are kept, so stop decoding as
            # soon as we have them instead of processing the whole video.
            while len(frames) < MAX_FRAMES:
                ret, frame = cap.read()
                if not ret:
                    break
                frame = cv2.resize(frame, (224, 224))
                # Convert BGR to RGB for proper color handling
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                if self.transform:
                    frame = self.transform(frame)
                else:
                    # torch.stack needs tensors; without a transform, convert
                    # the HxWxC uint8 array to a float CxHxW tensor in [0, 1].
                    frame = torch.from_numpy(frame).permute(2, 0, 1).float() / 255.0
                frames.append(frame)
        finally:
            # Always free the decoder handle, even if processing a frame fails.
            cap.release()

        if len(frames) == 0:
            raise RuntimeError(f"No frames extracted from {path}")

        if len(frames) < MAX_FRAMES:
            # Pad short clips by repeating the final frame.
            frames += [frames[-1]] * (MAX_FRAMES - len(frames))

        return torch.stack(frames).permute(1, 0, 2, 3)  # (C, T, H, W)

In [None]:
class SignLanguageModel(nn.Module):
    """Per-frame ResNet-18 features, a GRU over time, then a linear classifier.

    Args:
        hidden_dim: Size of the GRU hidden state.
        num_classes: Number of output logits.
    """

    def __init__(self, hidden_dim=256, num_classes=300):
        super().__init__()
        # `pretrained=True` is deprecated in torchvision; the weights enum
        # below requests the ImageNet checkpoint explicitly.
        resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        self.cnn = nn.Sequential(*list(resnet.children())[:-1])  # remove final FC
        self.rnn = nn.GRU(input_size=512, hidden_size=hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Classify a batch of clips.

        Args:
            x: Tensor of shape ``(B, C, T, H, W)``.

        Returns:
            Logits of shape ``(B, num_classes)``.
        """
        B, C, T, H, W = x.size()
        x = x.permute(0, 2, 1, 3, 4)  # (B, T, C, H, W)
        cnn_feats = []
        for t in range(T):
            # flatten(1) keeps the batch axis; a bare squeeze() would also drop
            # it when B == 1 and hand the GRU a wrongly shaped input.
            out = self.cnn(x[:, t]).flatten(1)  # (B, 512)
            cnn_feats.append(out)
        cnn_feats = torch.stack(cnn_feats, dim=1)  # (B, T, 512)
        _, h_n = self.rnn(cnn_feats)
        # h_n is (num_layers, B, hidden_dim); take the last layer's final state.
        out = self.fc(h_n[-1])
        return out

In [None]:
# Smoke test: build the training dataset and decode a single sample.
try:
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    train_split_path = os.path.join(VIDEO_BASE_DIR, "train")
    print(f"Looking for training data in: {train_split_path}")

    if not os.path.exists(train_split_path):
        print(f"❌ Training directory not found: {train_split_path}")
    else:
        dataset = SignDataset(train_split_path, transform=transform)
        print(f"✅ Dataset loaded successfully with {len(dataset)} videos")

        # Decode the first entry to confirm the tensor layout and label.
        sample_video, sample_label = dataset[0]
        print(f"Sample video shape: {sample_video.shape}")
        print(f"Sample label: {sample_label}")

except Exception as e:
    print(f"❌ Error loading dataset: {e}")

Looking for training data in: /content/drive/MyDrive/processed/train
Dataset initialized with 1897 videos
✅ Dataset loaded successfully with 1897 videos
Sample video shape: torch.Size([3, 16, 224, 224])
Sample label: 96


In [None]:
def train(num_epochs=20, batch_size=4, lr=1e-4):
    """Train ``SignLanguageModel`` on the "train" split and save its weights.

    Args:
        num_epochs: Number of passes over the training set.
        batch_size: Number of clips per batch.
        lr: Learning rate for the Adam optimizer.

    Raises:
        RuntimeError: If the dataset is too small to form a single full batch.
        Exception: Any error raised while processing a batch is logged and
            re-raised, so a failed run never saves or reports a "trained" model.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    train_split_path = os.path.join(VIDEO_BASE_DIR, "train")
    dataset = SignDataset(train_split_path, transform=transform)

    # drop_last=True skips a trailing short batch (a final batch of one sample
    # is a common source of shape problems).
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True)
    if len(loader) == 0:
        raise RuntimeError(
            f"Dataset has {len(dataset)} videos, fewer than one full batch of {batch_size}"
        )

    # Labels are raw class ids, which need not be contiguous. The classifier
    # must cover the largest id, otherwise CrossEntropyLoss receives
    # out-of-range targets.
    num_classes = max(dataset.labels) + 1
    print(f"Number of classes detected: {len(set(dataset.labels))} (output size: {num_classes})")

    model = SignLanguageModel(num_classes=num_classes).to(DEVICE)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    print("Starting training...")
    print(f"Total batches per epoch: {len(loader)}")

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        correct = 0
        total = 0

        progress = tqdm(loader, desc=f"Epoch {epoch+1}")
        for batch_idx, (videos, labels) in enumerate(progress):
            try:
                videos, labels = videos.to(DEVICE), labels.to(DEVICE)

                preds = model(videos)
                loss = loss_fn(preds, labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            except Exception as e:
                # Log the failing batch, then propagate: carrying on would
                # report metrics for (and save) a model that did not train.
                print(f"Error in batch {batch_idx}: {e}")
                print(f"Video shape: {videos.shape}")
                raise

            total_loss += loss.item()
            predicted = preds.detach().argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            # Show progress on the bar instead of printing one line per batch.
            progress.set_postfix(loss=f"{loss.item():.4f}")

        avg_loss = total_loss / len(loader)
        accuracy = 100 * correct / total if total > 0 else 0
        print(f"✅ Epoch {epoch+1}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")

    # Save model
    model_save_path = "/content/drive/MyDrive/sign_language_model.pth"
    torch.save(model.state_dict(), model_save_path)
    print(f"✅ Model saved to: {model_save_path}")

# ============================================================================
# SECTION 9: RUN TRAINING
# ============================================================================

# Print a banner, then start the training run.
print("🚀 Starting Sign Language Model Training...")
print("=" * 50)

# Report the outcome: a failure message if train() raises, success otherwise.
try:
    train()
except Exception as e:
    print(f"❌ Training failed with error: {e}")
    print("Please check your data paths and try again.")
else:
    print("🎉 Training completed successfully!")