In [4]:
!pip install av
!pip install rarfile
!pip install wget



In [4]:
import os
import ssl
import wget
import urllib.request
import rarfile
import zipfile

context = ssl._create_unverified_context()
def download_url(url, path):
    print(f"downloading {url}...")
    with urllib.request.urlopen(url, context=context) as response:
        with open(path, 'wb') as f:
            f.write(response.read())

def main():
    data_dir = "data"
    os.makedirs(data_dir, exist_ok=True)

    video_url = "https://www.crcv.ucf.edu/data/UCF101/UCF101.rar"
    annotation_url = "https://www.crcv.ucf.edu/data/UCF101/UCF101TrainTestSplits-RecognitionTask.zip"

    download_url(video_url, os.path.join(data_dir, "UCF101.rar"))
    download_url(annotation_url, os.path.join(data_dir, "UCF101_annotations.zip"))

    print("extracting video data...")
    rf = rarfile.RarFile(os.path.join(data_dir, "UCF101.rar"))
    rf.extractall(data_dir)

    print("extracting annotations...")
    with zipfile.ZipFile(os.path.join(data_dir, "UCF101_annotations.zip")) as zf:
        zf.extractall(data_dir)

    os.remove(os.path.join(data_dir, "UCF101.rar"))
    os.remove(os.path.join(data_dir, "UCF101_annotations.zip"))

    print("downloading ast...")
    os.makedirs('pretrained_weights', exist_ok=True)
    wget.download('https://www.dropbox.com/s/mdsa4t1xmcimia6/audioset_16_16_0.4422.pth?dl=1', os.path.join('pretrained_weights', 'audioset_16_16_0.4422.pth'))
    print("done")

if __name__ == "__main__":
    main()


downloading https://www.crcv.ucf.edu/data/UCF101/UCF101.rar...
downloading https://www.crcv.ucf.edu/data/UCF101/UCF101TrainTestSplits-RecognitionTask.zip...
extracting video data...
extracting annotations...
downloading ast...
done


In [5]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import torchaudio
from pathlib import Path
import torchvision.io as io


class UCF101Dataset(Dataset):
    def __init__(self, data_path, split_path, split="train", num_frames=8, t=4):
        self.data_path = Path(data_path)
        self.num_frames = num_frames
        self.t = t
        # read split file
        split_file = "trainlist01.txt" if split == "train" else "testlist01.txt"
        with open(os.path.join(split_path, split_file), "r") as f:
            self.video_list = [line.strip().split(" ")[0] for line in f.readlines()]

        # TODO add more randomizations
        self.video_transform = transforms.Compose(
            [
                transforms.Resize((256, 256)), # slightly larger for random crop
                transforms.RandomCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

        # create class to index mapping
        self.class_to_idx = {}
        classes = sorted(
            list(set(video_name.split("/")[0] for video_name in self.video_list))
        )
        for idx, classname in enumerate(classes):
            self.class_to_idx[classname] = idx

    def _load_video(self, video_path):
        try:
            vframes, _, _ = io.read_video(str(video_path), pts_unit="sec")
            total_frames = len(vframes)

            # ensure we don't sample beyond video length
            if total_frames < self.num_frames:
                indices = torch.linspace(0, total_frames - 1, total_frames).long()
                indices = torch.cat(
                    [
                        indices,
                        torch.tensor(
                            [total_frames - 1] * (self.num_frames - total_frames)
                        ),
                    ]
                )
            else:
                indices = torch.linspace(0, total_frames - 1, self.num_frames).long()

            frames = []
            for idx in indices:
                frame = vframes[idx]
                frame = Image.fromarray(frame.numpy())
                frame = self.video_transform(frame)
                frames.append(frame)

        except Exception as e:
            print(e)
            frames = [torch.zeros(3, 224, 224) for _ in range(self.num_frames)]

        return torch.stack(frames)  # [num_frames, c, h, w]

    # TODO specaugment
    def _load_audio(self, video_path):
        try:
            audio_array, sample_rate = torchaudio.load(str(video_path))
        except (RuntimeError, TypeError):
            # create a small amount of noise instead of pure zeros
            audio_array = torch.randn(1, 16000 * self.t) * 1e-4
            sample_rate = 16000

        # convert to mono
        if audio_array.shape[0] > 1:
            audio_array = torch.mean(audio_array, dim=0, keepdim=True)

        # resample if necessary
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            audio_array = resampler(audio_array)

        target_length = self.t * 16000
        if audio_array.shape[1] < target_length:
            # pad with zeros if audio is too short
            audio_array = torch.nn.functional.pad(
                audio_array, (0, target_length - audio_array.shape[1])
            )
        else:
            # trim if audio is too long
            audio_array = audio_array[:, :target_length]

        # create mel spectrogram
        spectrogram = torchaudio.transforms.MelSpectrogram(
            sample_rate=16000,
            n_mels=128,
            n_fft=1024,
            win_length=1024,
            hop_length=160,
        )(audio_array)

        spectrogram = torchaudio.transforms.AmplitudeToDB()(spectrogram)
        spectrogram = spectrogram.squeeze(0)  # remove channel dimension

        if spectrogram.shape[1] > 400:
            spectrogram = spectrogram[:, :400]
        elif spectrogram.shape[1] < 400:
            spectrogram = torch.nn.functional.pad(
                spectrogram, (0, 400 - spectrogram.shape[1])
            )

        # mean=0 std=0.5 according to ast
        spectrogram = (spectrogram - spectrogram.mean()) / (spectrogram.std() + 1e-6) * 0.5

        return spectrogram.unsqueeze(0)  # add channel dimension back [1, 128, 100*t]

    def __getitem__(self, idx):
        video_name = self.video_list[idx]
        video_path = self.data_path / "UCF-101" / video_name

        label = video_name.split("/")[0]
        video_tensor = self._load_video(video_path)
        audio_tensor = self._load_audio(video_path)
        class_idx = self.class_to_idx[label]

        return video_tensor, audio_tensor, class_idx

    def __len__(self):
        return len(self.video_list)


def get_dataloaders(data_path, split_path, batch_size=2, num_workers=0):

    train_dataset = UCF101Dataset(
        data_path=data_path,
        split_path=os.path.join(split_path, "ucfTrainTestlist"),
        split="train",
    )

    val_dataset = UCF101Dataset(
        data_path=data_path,
        split_path=os.path.join(split_path, "ucfTrainTestlist"),
        split="test",
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    return train_loader, val_loader


In [6]:
import torch
import torch.nn as nn
import timm
from timm.layers import trunc_normal_


# TODO: change hardcoded values
class PatchEmbed(nn.Module):
    def __init__(
        self, img_size=(128, 400), patch_size=(16, 16), in_chans=1, embed_dim=768
    ):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        # B, C, H, W = x.shape
        # assert (
        #     H == self.img_size[0] and W == self.img_size[1]
        # ), f"input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})"
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)
        return x


class Model(nn.Module):
    def __init__(self, num_classes=101, lf=10, num_bottlenecks=8):
        super().__init__()

        # set num_classes=0 to remove classification head
        self.vv = timm.create_model(
            "vit_base_patch16_224.augreg_in21k", pretrained=True, num_classes=0
        )

        self.va = timm.create_model(
            "vit_base_patch16_224.augreg_in21k", pretrained=True, num_classes=0
        )

        # apply ast weights to va
        ast_weights = torch.load(
            "pretrained_weights/audioset_16_16_0.4422.pth", weights_only=True
        )
        temp = self.va.state_dict()
        pretrained_dict = {}
        for k, v in ast_weights.items():
            if k.startswith("module."):
                k = k[7:]
            if k in temp and temp[k].shape == v.shape:
                pretrained_dict[k] = v
        temp.update(pretrained_dict)
        self.va.load_state_dict(temp)

        self.va.patch_embed = PatchEmbed(
            img_size=(128, 400),
            patch_size=(16, 16),
            in_chans=1,
            embed_dim=self.va.embed_dim,
        )

        # interpolate position embeddings for video
        num_patches_video = self.vv.patch_embed.num_patches
        self.vv.pos_embed = self.interpolate_pos_encoding(
            self.vv.pos_embed, num_patches_video
        )

        # interpolate position embeddings for audio
        num_patches_audio = self.va.patch_embed.num_patches
        self.va.pos_embed = self.interpolate_pos_encoding(
            self.va.pos_embed, num_patches_audio
        )

        self.lf = lf
        self.num_features = self.vv.embed_dim

        # create new classification head for fused features
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.num_features),
            nn.Linear(self.num_features, num_classes),
        )

        # create bottleneck fusion tokens
        self.num_bottlenecks = num_bottlenecks
        self.zfsn = nn.Parameter(torch.zeros(1, num_bottlenecks, self.vv.embed_dim))
        trunc_normal_(self.zfsn, std=0.02)

    def interpolate_pos_encoding(self, pos_embed, num_patches):
        pos_embed = pos_embed.float()
        N = pos_embed.shape[1] - 1  # original number of patches (excluding CLS token)

        # handle CLS token separately
        cls_pos_embed = pos_embed[:, 0:1, :]
        pos_embed = pos_embed[:, 1:, :]

        # interpolate patch position embeddings
        pos_embed = pos_embed.permute(0, 2, 1)
        pos_embed = nn.functional.interpolate(
            pos_embed, size=num_patches, mode="linear", align_corners=False
        )
        pos_embed = pos_embed.permute(0, 2, 1)

        # recombine with CLS token
        pos_embed = torch.cat((cls_pos_embed, pos_embed), dim=1)
        return nn.Parameter(pos_embed)

    def forward_features(self, x, v, lf):
        B = x.shape[0]
        x = v.patch_embed(x)
        if len(x.shape) > 3:  # just in case
            x = x.flatten(2).transpose(1, 2)
        cls_token = v.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x = v.pos_drop(x + v.pos_embed)

        for i, block in enumerate(v.blocks):
            if i < lf:
                x = block(x)
        return x

    def forward(self, video, audio):
        B, F, C, H, W = video.shape
        video = video.view(B * F, C, H, W)

        # process separately until fusion layer
        v_features = self.forward_features(video, self.vv, self.lf)
        v_features = v_features.view(B, F, -1, self.num_features)
        v_features = torch.mean(v_features, dim=1)
        a_features = self.forward_features(audio, self.va, self.lf)

        # expand fusion tokens for batch
        zfsn = self.zfsn.expand(B, -1, -1)

        # process remaining layers with bottleneck fusion
        for block in self.vv.blocks[self.lf :]:
            # eqn 8
            v_concat = torch.cat([v_features, zfsn], dim=1)
            a_concat = torch.cat([a_features, zfsn], dim=1)

            v_out = block(v_concat)
            a_out = block(a_concat)

            # split features and fusion tokens
            v_features = v_out[:, : v_features.shape[1]]
            a_features = a_out[:, : a_features.shape[1]]
            v_zfsn = v_out[:, v_features.shape[1] :]
            a_zfsn = a_out[:, a_features.shape[1] :]

            # eqn 9
            zfsn = (v_zfsn + a_zfsn) / 2

        # get CLS tokens
        v_cls = v_features[:, 0]
        a_cls = a_features[:, 0]

        v_logits = self.classifier(v_cls)
        a_logits = self.classifier(a_cls)

        output = (v_logits + a_logits) / 2

        return output


In [7]:
import os
import datetime
import torch
import torch.nn as nn
from tqdm import tqdm
import wandb


class Trainer:
    def __init__(self, model, train_loader, val_loader, criterion, optimizer, device):
        wandb.login(key="2f3ffd7baf545af396e18e48bfa20b33d2609dcc")

        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device

        if wandb.run is None:
            current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            run_name = f"wabt_{current_time}"

            wandb.init(project="wabt", name=run_name)
            wandb.watch(self.model)
        self.step = 0

        self.best_val_loss = float("inf")
        self.checkpoint_dir = os.path.join("checkpoints", run_name)
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    def train_epoch(self):
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        pbar = tqdm(self.train_loader, desc="training")
        for video, audio, targets in pbar:
            video = video.to(self.device)
            audio = audio.to(self.device)
            targets = targets.to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(video, audio)
            loss = self.criterion(outputs, targets)

            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            self.step += 1
            wandb.log(
                {
                    "batch/train_loss": loss.item(),
                    "batch/train_acc": predicted.eq(targets).sum().item()
                    / targets.size(0),
                    "train/running_loss": running_loss / (pbar.n + 1),
                    "train/running_acc": correct / total,
                    "step": self.step,
                }
            )

            pbar.set_postfix(
                {"loss": running_loss / (pbar.n + 1), "acc": correct / total}
            )

        epoch_loss = running_loss / len(self.train_loader)
        epoch_acc = correct / total
        return epoch_loss, epoch_acc

    def validate(self):
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            pbar = tqdm(self.val_loader, desc="validation")
            for video, audio, targets in pbar:
                video = video.to(self.device)
                audio = audio.to(self.device)
                targets = targets.to(self.device)

                outputs = self.model(video, audio)
                loss = self.criterion(outputs, targets)

                running_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

                self.step += 1
                wandb.log(
                    {
                        "batch/val_loss": loss.item(),
                        "batch/val_acc": predicted.eq(targets).sum().item()
                        / targets.size(0),
                        "val/running_loss": running_loss / (pbar.n + 1),
                        "step": self.step,
                    }
                )

                pbar.set_postfix(
                    {"loss": running_loss / (pbar.n + 1), "acc": correct / total}
                )

        epoch_loss = running_loss / len(self.val_loader)
        epoch_acc = correct / total

        # save checkpoint
        checkpoint = {
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "val_loss": epoch_loss,
            "val_acc": epoch_acc,
            "step": self.step,
        }

        # save latest checkpoint
        latest_path = os.path.join(self.checkpoint_dir, "latest.pt")
        torch.save(checkpoint, latest_path)

        # save best checkpoint
        if epoch_loss < self.best_val_loss:
            self.best_val_loss = epoch_loss
            best_path = os.path.join(self.checkpoint_dir, "best.pt")
            torch.save(checkpoint, best_path)
            wandb.log({"best_val_loss": self.best_val_loss})

        return epoch_loss, epoch_acc


def train(model, train_loader, val_loader, num_epochs=10, device="cuda"):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=1e-4)
    trainer = Trainer(model, train_loader, val_loader, criterion, optimizer, device)

    for epoch in range(num_epochs):
        print(f"\nepoch {epoch+1}/{num_epochs}")
        train_loss, train_acc = trainer.train_epoch()
        val_loss, val_acc = trainer.validate()

        wandb.log(
            {
                "epoch": epoch + 1,
                "epoch/train_loss": train_loss,
                "epoch/train_acc": train_acc,
                "epoch/val_loss": val_loss,
                "epoch/val_acc": val_acc,
            }
        )

        print(f"train loss: {train_loss:.4f} | train acc: {train_acc:.2f}")
        print(f"val loss: {val_loss:.4f} | val acc: {val_acc:.2f}")


In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = Model().to(device)
train_loader, val_loader = get_dataloaders(data_path="data/", split_path="data/")

train(model, train_loader, val_loader, num_epochs=3, device=device)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33msouhhmm[0m ([33msouhhmm-bits-pilani[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc



epoch 1/3


training: 100%|██████████| 4769/4769 [51:28<00:00,  1.54it/s, loss=1.92, acc=0.536] 
validation: 100%|██████████| 1892/1892 [13:14<00:00,  2.38it/s, loss=1.02, acc=0.725] 


train loss: 1.9194 | train acc: 0.54
val loss: 1.0198 | val acc: 0.72

epoch 2/3


training: 100%|██████████| 4769/4769 [51:38<00:00,  1.54it/s, loss=0.281, acc=0.923] 
validation: 100%|██████████| 1892/1892 [13:13<00:00,  2.39it/s, loss=0.924, acc=0.742]


train loss: 0.2811 | train acc: 0.92
val loss: 0.9236 | val acc: 0.74

epoch 3/3


training: 100%|██████████| 4769/4769 [51:41<00:00,  1.54it/s, loss=0.1, acc=0.973]    
validation: 100%|██████████| 1892/1892 [13:13<00:00,  2.38it/s, loss=0.622, acc=0.831]


train loss: 0.1003 | train acc: 0.97
val loss: 0.6221 | val acc: 0.83
