In [1]:
!pip install av
!pip install rarfile
!pip install wget



In [4]:
import os
import ssl
import wget
import urllib.request
import rarfile
import zipfile

context = ssl._create_unverified_context()
def download_url(url, path):
    print(f"downloading {url}...")
    with urllib.request.urlopen(url, context=context) as response:
        with open(path, 'wb') as f:
            f.write(response.read())

def main():
    data_dir = "data"
    os.makedirs(data_dir, exist_ok=True)

    video_url = "https://www.crcv.ucf.edu/data/UCF101/UCF101.rar"
    annotation_url = "https://www.crcv.ucf.edu/data/UCF101/UCF101TrainTestSplits-RecognitionTask.zip"

    download_url(video_url, os.path.join(data_dir, "UCF101.rar"))
    download_url(annotation_url, os.path.join(data_dir, "UCF101_annotations.zip"))

    print("extracting video data...")
    rf = rarfile.RarFile(os.path.join(data_dir, "UCF101.rar"))
    rf.extractall(data_dir)

    print("extracting annotations...")
    with zipfile.ZipFile(os.path.join(data_dir, "UCF101_annotations.zip")) as zf:
        zf.extractall(data_dir)

    os.remove(os.path.join(data_dir, "UCF101.rar"))
    os.remove(os.path.join(data_dir, "UCF101_annotations.zip"))

    print("downloading ast...")
    os.makedirs('pretrained_weights', exist_ok=True)
    wget.download('https://www.dropbox.com/s/mdsa4t1xmcimia6/audioset_16_16_0.4422.pth?dl=1', os.path.join('pretrained_weights', 'audioset_16_16_0.4422.pth'))
    print("done")

if __name__ == "__main__":
    main()


downloading https://www.crcv.ucf.edu/data/UCF101/UCF101.rar...
downloading https://www.crcv.ucf.edu/data/UCF101/UCF101TrainTestSplits-RecognitionTask.zip...
extracting video data...
extracting annotations...
downloading ast...
done


In [2]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import torchaudio
from pathlib import Path
import torchvision.io as io


class UCF101Dataset(Dataset):
    def __init__(self, data_path, split_path, split="train", num_frames=8, t=4):
        self.data_path = Path(data_path)
        self.num_frames = num_frames
        self.t = t
        # read split file
        split_file = "trainlist01.txt" if split == "train" else "testlist01.txt"
        with open(os.path.join(split_path, split_file), "r") as f:
            self.video_list = [line.strip().split(" ")[0] for line in f.readlines()]

        # TODO add more randomizations
        self.video_transform = transforms.Compose(
            [
                transforms.Resize((256, 256)), # slightly larger for random crop
                transforms.RandomCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

        # create class to index mapping
        self.class_to_idx = {}
        classes = sorted(
            list(set(video_name.split("/")[0] for video_name in self.video_list))
        )
        for idx, classname in enumerate(classes):
            self.class_to_idx[classname] = idx

    def _load_video(self, video_path):
        try:
            vframes, _, _ = io.read_video(str(video_path), pts_unit="sec")
            total_frames = len(vframes)

            # ensure we don't sample beyond video length
            if total_frames < self.num_frames:
                indices = torch.linspace(0, total_frames - 1, total_frames).long()
                indices = torch.cat(
                    [
                        indices,
                        torch.tensor(
                            [total_frames - 1] * (self.num_frames - total_frames)
                        ),
                    ]
                )
            else:
                indices = torch.linspace(0, total_frames - 1, self.num_frames).long()

            frames = []
            for idx in indices:
                frame = vframes[idx]
                frame = Image.fromarray(frame.numpy())
                frame = self.video_transform(frame)
                frames.append(frame)

        except Exception as e:
            print(e)
            frames = [torch.zeros(3, 224, 224) for _ in range(self.num_frames)]

        return torch.stack(frames)  # [num_frames, c, h, w]

    # TODO specaugment
    def _load_audio(self, video_path):
        try:
            audio_array, sample_rate = torchaudio.load(str(video_path))
        except (RuntimeError, TypeError):
            # create a small amount of noise instead of pure zeros
            audio_array = torch.randn(1, 16000 * self.t) * 1e-4
            sample_rate = 16000

        # convert to mono
        if audio_array.shape[0] > 1:
            audio_array = torch.mean(audio_array, dim=0, keepdim=True)

        # resample if necessary
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            audio_array = resampler(audio_array)

        target_length = self.t * 16000
        if audio_array.shape[1] < target_length:
            # pad with zeros if audio is too short
            audio_array = torch.nn.functional.pad(
                audio_array, (0, target_length - audio_array.shape[1])
            )
        else:
            # trim if audio is too long
            audio_array = audio_array[:, :target_length]

        # create mel spectrogram
        spectrogram = torchaudio.transforms.MelSpectrogram(
            sample_rate=16000,
            n_mels=128,
            n_fft=1024,
            win_length=1024,
            hop_length=160,
        )(audio_array)

        spectrogram = torchaudio.transforms.AmplitudeToDB()(spectrogram)
        spectrogram = spectrogram.squeeze(0)  # remove channel dimension

        if spectrogram.shape[1] > 400:
            spectrogram = spectrogram[:, :400]
        elif spectrogram.shape[1] < 400:
            spectrogram = torch.nn.functional.pad(
                spectrogram, (0, 400 - spectrogram.shape[1])
            )

        # mean=0 std=0.5 according to ast
        spectrogram = (spectrogram - spectrogram.mean()) / (spectrogram.std() + 1e-6) * 0.5

        return spectrogram.unsqueeze(0)  # add channel dimension back [1, 128, 100*t]

    def __getitem__(self, idx):
        video_name = self.video_list[idx]
        video_path = self.data_path / "UCF-101" / video_name

        label = video_name.split("/")[0]
        video_tensor = self._load_video(video_path)
        audio_tensor = self._load_audio(video_path)
        class_idx = self.class_to_idx[label]

        return video_tensor, audio_tensor, class_idx

    def __len__(self):
        return len(self.video_list)


def get_dataloaders(data_path, split_path, batch_size=2, num_workers=0):

    train_dataset = UCF101Dataset(
        data_path=data_path,
        split_path=os.path.join(split_path, "ucfTrainTestlist"),
        split="train",
    )

    val_dataset = UCF101Dataset(
        data_path=data_path,
        split_path=os.path.join(split_path, "ucfTrainTestlist"),
        split="test",
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    return train_loader, val_loader


In [3]:
import torch
import torch.nn as nn
import timm
from timm.layers import trunc_normal_


# TODO change hardcoded values
class PatchEmbed(nn.Module):
    def __init__(
        self, img_size=(128, 400), patch_size=(16, 16), in_chans=1, embed_dim=768
    ):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        B, C, H, W = x.shape
        assert (
            H == self.img_size[0] and W == self.img_size[1]
        ), f"input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})"
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)
        return x


class Model(nn.Module):
    def __init__(self, num_classes=101, lf=10):
        super().__init__()

        # set num_classes=0 to remove classification head
        self.vv = timm.create_model(
            "vit_base_patch16_224.augreg_in21k", pretrained=True, num_classes=0
        )

        self.va = timm.create_model(
            "vit_base_patch16_224.augreg_in21k", pretrained=True, num_classes=0
        )

        # apply ast weights to va
        ast_weights = torch.load("pretrained_weights/audioset_16_16_0.4422.pth", weights_only=True)
        temp = self.va.state_dict()
        pretrained_dict = {}
        for k, v in ast_weights.items():
            if k.startswith("module."):
                k = k[7:]
            if k in temp and temp[k].shape == v.shape:
                pretrained_dict[k] = v
        temp.update(pretrained_dict)
        self.va.load_state_dict(temp)

        self.va.patch_embed = PatchEmbed(
            img_size=(128, 400),
            patch_size=(16, 16),
            in_chans=1,
            embed_dim=self.va.embed_dim,
        )

        # interpolate position embeddings for video
        num_patches_video = self.vv.patch_embed.num_patches
        self.vv.pos_embed = self.interpolate_pos_encoding(
            self.vv.pos_embed, num_patches_video
        )

        # interpolate position embeddings for audio
        num_patches_audio = self.va.patch_embed.num_patches
        self.va.pos_embed = self.interpolate_pos_encoding(
            self.va.pos_embed, num_patches_audio
        )

        # create new positional embeddings for fused sequence
        total_patches = num_patches_video + num_patches_audio + 2
        self.fused_pos_embed = nn.Parameter(
            torch.zeros(1, total_patches, self.vv.embed_dim)
        )
        trunc_normal_(self.fused_pos_embed, std=0.02)

        self.lf = lf
        self.num_features = self.vv.embed_dim

        # create new classification head for fused features
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.num_features),
            nn.Linear(self.num_features, num_classes),
        )

    def interpolate_pos_encoding(self, pos_embed, num_patches):
        pos_embed = pos_embed.float()
        N = pos_embed.shape[1] - 1  # original number of patches (excluding CLS token)

        # handle CLS token separately
        cls_pos_embed = pos_embed[:, 0:1, :]
        pos_embed = pos_embed[:, 1:, :]

        # interpolate patch position embeddings
        pos_embed = pos_embed.permute(0, 2, 1)
        pos_embed = nn.functional.interpolate(
            pos_embed, size=num_patches, mode="linear", align_corners=False
        )
        pos_embed = pos_embed.permute(0, 2, 1)

        # recombine with CLS token
        pos_embed = torch.cat((cls_pos_embed, pos_embed), dim=1)
        return nn.Parameter(pos_embed)

    def forward_features(self, x, v, lf):
        B = x.shape[0]
        x = v.patch_embed(x)
        if len(x.shape) > 3:  # just in case
            x = x.flatten(2).transpose(1, 2)
        cls_token = v.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x = v.pos_drop(x + v.pos_embed)

        for i, block in enumerate(v.blocks):
            if i < lf:
                x = block(x)
        return x

    def forward(self, video, audio):
        B, F, c, h, w = video.shape
        video = video.view(B * F, c, h, w)

        # process separately until fusion layer
        v_features = self.forward_features(video, self.vv, self.lf)
        v_features = v_features.view(B, F, -1, self.num_features)
        v_features = torch.mean(v_features, dim=1)
        a_features = self.forward_features(audio, self.va, self.lf)

        fused = torch.cat((v_features, a_features), dim=1)

        # add fused positional embeddings after concatenation
        fused = fused + self.fused_pos_embed

        # pass through remaining layers
        for i in range(self.lf, len(self.vv.blocks)):
            fused = self.vv.blocks[i](fused)

        v_cls = fused[:, 0]
        a_cls = fused[:, self.vv.patch_embed.num_patches + 1]

        v_logits = self.classifier(v_cls)
        a_logits = self.classifier(a_cls)

        output = (v_logits + a_logits) / 2

        return output


In [7]:
import os
import datetime
import torch
import torch.nn as nn
from tqdm import tqdm
import wandb


class Trainer:
    def __init__(self, model, train_loader, val_loader, criterion, optimizer, device):
        wandb.login(key="key")

        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device

        if wandb.run is None:
            current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            run_name = f"wabt_{current_time}"

            wandb.init(project="wabt", name=run_name)
            wandb.watch(self.model)
        self.step = 0

        self.best_val_loss = float("inf")
        self.checkpoint_dir = os.path.join("checkpoints", run_name)
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    def train_epoch(self):
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        pbar = tqdm(self.train_loader, desc="training")
        for video, audio, targets in pbar:
            video = video.to(self.device)
            audio = audio.to(self.device)
            targets = targets.to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(video, audio)
            loss = self.criterion(outputs, targets)

            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            self.step += 1
            wandb.log(
                {
                    "batch/train_loss": loss.item(),
                    "batch/train_acc": predicted.eq(targets).sum().item()
                    / targets.size(0),
                    "train/running_loss": running_loss / (pbar.n + 1),
                    "train/running_acc": correct / total,
                    "step": self.step,
                }
            )

            pbar.set_postfix(
                {"loss": running_loss / (pbar.n + 1), "acc": correct / total}
            )

        epoch_loss = running_loss / len(self.train_loader)
        epoch_acc = correct / total
        return epoch_loss, epoch_acc

    def validate(self):
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            pbar = tqdm(self.val_loader, desc="validation")
            for video, audio, targets in pbar:
                video = video.to(self.device)
                audio = audio.to(self.device)
                targets = targets.to(self.device)

                outputs = self.model(video, audio)
                loss = self.criterion(outputs, targets)

                running_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

                self.step += 1
                wandb.log(
                    {
                        "batch/val_loss": loss.item(),
                        "batch/val_acc": predicted.eq(targets).sum().item()
                        / targets.size(0),
                        "val/running_loss": running_loss / (pbar.n + 1),
                        "step": self.step,
                    }
                )

                pbar.set_postfix(
                    {"loss": running_loss / (pbar.n + 1), "acc": correct / total}
                )

        epoch_loss = running_loss / len(self.val_loader)
        epoch_acc = correct / total

        # save checkpoint
        checkpoint = {
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "val_loss": epoch_loss,
            "val_acc": epoch_acc,
            "step": self.step,
        }

        # save latest checkpoint
        latest_path = os.path.join(self.checkpoint_dir, "latest.pt")
        torch.save(checkpoint, latest_path)

        # save best checkpoint
        if epoch_loss < self.best_val_loss:
            self.best_val_loss = epoch_loss
            best_path = os.path.join(self.checkpoint_dir, "best.pt")
            torch.save(checkpoint, best_path)
            wandb.log({"best_val_loss": self.best_val_loss})

        return epoch_loss, epoch_acc


def train(model, train_loader, val_loader, num_epochs=10, device="cuda"):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=1e-4)
    trainer = Trainer(model, train_loader, val_loader, criterion, optimizer, device)

    for epoch in range(num_epochs):
        print(f"\nepoch {epoch+1}/{num_epochs}")
        train_loss, train_acc = trainer.train_epoch()
        val_loss, val_acc = trainer.validate()

        wandb.log(
            {
                "epoch": epoch + 1,
                "epoch/train_loss": train_loss,
                "epoch/train_acc": train_acc,
                "epoch/val_loss": val_loss,
                "epoch/val_acc": val_acc,
            }
        )

        print(f"train loss: {train_loss:.4f} | train acc: {train_acc:.2f}")
        print(f"val loss: {val_loss:.4f} | val acc: {val_acc:.2f}")


In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = Model().to(device)
train_loader, val_loader = get_dataloaders(data_path="data/", split_path="data/")

train(model, train_loader, val_loader, num_epochs=5, device=device)

model.safetensors:   0%|          | 0.00/410M [00:00<?, ?B/s]

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33msouhhmm[0m ([33msouhhmm-bits-pilani[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc



epoch 1/5


training: 100%|██████████| 4769/4769 [51:30<00:00,  1.54it/s, loss=1.82, acc=0.543]  
validation: 100%|██████████| 1892/1892 [13:04<00:00,  2.41it/s, loss=1.18, acc=0.681]


train loss: 1.8238 | train acc: 0.54
val loss: 1.1784 | val acc: 0.68

epoch 2/5


training: 100%|██████████| 4769/4769 [51:52<00:00,  1.53it/s, loss=0.351, acc=0.892] 
validation: 100%|██████████| 1892/1892 [13:12<00:00,  2.39it/s, loss=0.909, acc=0.76] 


train loss: 0.3510 | train acc: 0.89
val loss: 0.9093 | val acc: 0.76

epoch 3/5


training: 100%|██████████| 4769/4769 [51:50<00:00,  1.53it/s, loss=0.165, acc=0.95] 
validation: 100%|██████████| 1892/1892 [13:28<00:00,  2.34it/s, loss=1.11, acc=0.721]


train loss: 0.1650 | train acc: 0.95
val loss: 1.1142 | val acc: 0.72

epoch 4/5


training: 100%|██████████| 4769/4769 [51:53<00:00,  1.53it/s, loss=0.0549, acc=0.985]
validation: 100%|██████████| 1892/1892 [13:08<00:00,  2.40it/s, loss=0.85, acc=0.792] 


train loss: 0.0549 | train acc: 0.98
val loss: 0.8501 | val acc: 0.79

epoch 5/5


training: 100%|██████████| 4769/4769 [52:35<00:00,  1.51it/s, loss=0.0332, acc=0.991]
validation: 100%|██████████| 1892/1892 [13:14<00:00,  2.38it/s, loss=0.876, acc=0.799]


train loss: 0.0332 | train acc: 0.99
val loss: 0.8756 | val acc: 0.80


In [8]:
def load_model(checkpoint_path, device="cuda"):
    model = Model().to(device)
    checkpoint = torch.load(checkpoint_path, weights_only=True)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    return model


def get_class_mapping():
    # create a dataset instance just to get the class mapping
    dataset = UCF101Dataset(data_path="data/", split_path="data/ucfTrainTestlist", split="test")
    # invert the class_to_idx dictionary
    idx_to_class = {v: k for k, v in dataset.class_to_idx.items()}
    return idx_to_class


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # load model
    checkpoint_path = "checkpoints/wabt_20250206_182221/best.pt"
    model = load_model(checkpoint_path, device)

    # get class mapping
    idx_to_class = get_class_mapping()

    # create dataset and get a sample
    dataset = UCF101Dataset(
        data_path="data/",
        split_path="data/ucfTrainTestlist",
        split="test",
    )

    random_idxs = torch.randperm(len(dataset))[:10]
    for i in random_idxs:
        video, audio, true_label = dataset[i]

        # add batch dimension
        video = video.unsqueeze(0).to(device)
        audio = audio.unsqueeze(0).to(device)

        # get prediction
        with torch.no_grad():
            outputs = model(video, audio)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)

            # get top 3 predictions
            top_prob, top_idx = torch.topk(probabilities, 3)

        print(f"\nsample {i + 1}")
        print(f"true class: {idx_to_class[true_label]}")
        print("top 3 predictions:")
        for prob, idx in zip(top_prob[0], top_idx[0]):
            print(f"{idx_to_class[idx.item()]}: {prob.item()*100:.2f}%")
        print("-" * 50)


if __name__ == "__main__":
    main()



sample 1746
true class: JugglingBalls
top 3 predictions:
JugglingBalls: 75.22%
Nunchucks: 24.11%
PlayingGuitar: 0.36%
--------------------------------------------------

sample 3686
true class: WallPushups
top 3 predictions:
WritingOnBoard: 90.42%
YoYo: 1.41%
Hammering: 1.38%
--------------------------------------------------

sample 1988
true class: MilitaryParade
top 3 predictions:
MilitaryParade: 99.43%
HorseRace: 0.50%
VolleyballSpiking: 0.04%
--------------------------------------------------

sample 382
true class: BenchPress
top 3 predictions:
BenchPress: 99.06%
Punch: 0.44%
CleanAndJerk: 0.22%
--------------------------------------------------

sample 1193
true class: FrontCrawl
top 3 predictions:
FrontCrawl: 98.00%
BreastStroke: 1.92%
SkyDiving: 0.02%
--------------------------------------------------

sample 3380
true class: TaiChi
top 3 predictions:
BaseballPitch: 88.55%
SalsaSpin: 4.56%
TaiChi: 4.49%
--------------------------------------------------

sample 656
true class

In [9]:
import os
import torchaudio
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm

def analyze_dataset(data_path, split_path):
    train_file = os.path.join(split_path, "ucfTrainTestlist", "trainlist01.txt")
    test_file = os.path.join(split_path, "ucfTrainTestlist", "testlist01.txt")
    
    with open(train_file, "r") as f:
        train_videos = [line.strip().split(" ")[0] for line in f.readlines()]
    
    with open(test_file, "r") as f:
        test_videos = [line.strip().split(" ")[0] for line in f.readlines()]
    
    data_path = Path(data_path)
    
    def check_audio(video_list):
        has_audio = 0
        no_audio = 0
        
        for video in tqdm(video_list):
            video_path = data_path / "UCF-101" / video
            try:
                audio_array, _ = torchaudio.load(str(video_path))
                if audio_array.shape[1] > 0:  # check if audio has content
                    has_audio += 1
                else:
                    no_audio += 1
            except:
                no_audio += 1
        
        return has_audio, no_audio
    
    print("analyzing training split...")
    train_has_audio, train_no_audio = check_audio(train_videos)
    
    print("analyzing test split...")
    test_has_audio, test_no_audio = check_audio(test_videos)
    
    # create visualization
    labels = ['Training Set', 'Test Set']
    has_audio = [train_has_audio, test_has_audio]
    no_audio = [train_no_audio, test_no_audio]
    
    x = range(len(labels))
    width = 0.35
    
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.bar([i - width/2 for i in x], has_audio, width, label='Has Audio')
    ax.bar([i + width/2 for i in x], no_audio, width, label='No Audio')
    
    ax.set_ylabel('Number of Videos')
    ax.set_title('Audio Availability in UCF101 Dataset')
    ax.set_xticks(x)
    ax.set_xticklabels(labels)
    ax.legend()

    for i in x:
        ax.text(i - width/2, has_audio[i], str(has_audio[i]), ha='center', va='bottom')
        ax.text(i + width/2, no_audio[i], str(no_audio[i]), ha='center', va='bottom')
    
    plt.tight_layout()
    plt.savefig('audio_availability.png')
    plt.close()
    
    print("\nSummary:")
    print(f"Training Set - Total: {len(train_videos)}")
    print(f"  - With Audio: {train_has_audio} ({train_has_audio/len(train_videos)*100:.1f}%)")
    print(f"  - Without Audio: {train_no_audio} ({train_no_audio/len(train_videos)*100:.1f}%)")
    
    print(f"\nTest Set - Total: {len(test_videos)}")
    print(f"  - With Audio: {test_has_audio} ({test_has_audio/len(test_videos)*100:.1f}%)")
    print(f"  - Without Audio: {test_no_audio} ({test_no_audio/len(test_videos)*100:.1f}%)")

if __name__ == "__main__":
    analyze_dataset("data/", "data/")


analyzing training split...


100%|██████████| 9537/9537 [00:57<00:00, 167.26it/s]


analyzing test split...


100%|██████████| 3783/3783 [00:22<00:00, 167.86it/s]



Summary:
Training Set - Total: 9537
  - With Audio: 4893 (51.3%)
  - Without Audio: 4644 (48.7%)

Test Set - Total: 3783
  - With Audio: 1944 (51.4%)
  - Without Audio: 1839 (48.6%)
