In [2]:
import os
from glob import glob
import logging
from pathlib import Path
from typing import Union, Optional, Tuple, List, Dict

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import models
from sklearn.metrics import average_precision_score
from tqdm.auto import tqdm

In [3]:
class Config:
    """Notebook-wide hyper-parameters and paths, read through the ``cfg`` instance below."""

    # logging
    log_dir = Path("exp/train")  # base name; get_exp_name() derives a numbered run dir from it

    # data
    data_dir = Path("../data/")
    num_labels = 256  # length of the multi-hot tag vector
    crop_size = 81  # time steps every embedding sequence is cropped/padded to

    batch_size = 256
    eval_batch_size = 256
    num_workers = 4

    # net
    input_dim = 768
    hidden_dim = 512

    # Prefer the second GPU (as originally configured) but degrade gracefully on
    # single-GPU / CPU-only machines instead of failing with an invalid device ordinal.
    device = (
        "cuda:1" if torch.cuda.device_count() > 1
        else "cuda:0" if torch.cuda.is_available()
        else "cpu"
    )
    # Mixed precision (autocast + GradScaler) is only meaningful on CUDA.
    use_amp = torch.cuda.is_available()
    clip_value = None  # max gradient norm; None disables clipping
    lr = 1e-3
    min_lr = 1e-8  # eta_min of the cosine-annealing schedule

    n_epochs = 20

cfg = Config()

In [4]:
def get_exp_name(log_dir: Path) -> Path:
    """Return a fresh, numbered run directory next to ``log_dir``.

    Creates ``log_dir`` (and its parents) if missing, then scans its parent for entries
    named ``<name>_<integer>`` and returns ``<parent>/<name>_<max + 1>``. Because
    ``<name>`` itself is created here, the first run is ``<name>_1``. The returned
    directory is not created.

    Unrelated siblings (e.g. ``pre<name>_9`` or ``<name>_old``) are ignored. The result
    does not depend on the order in which the filesystem lists entries.

    Args:
        log_dir: Base run directory, e.g. ``Path("exp/train")``.

    Returns:
        Path of the next unused numbered run directory.
    """
    log_dir = Path(log_dir)
    log_dir.mkdir(parents=True, exist_ok=True)
    prefix = f"{log_dir.name}_"

    highest = 0
    for entry in log_dir.parent.iterdir():
        name = entry.name
        suffix = name[len(prefix):]
        # Only exact "<name>_<digits>" entries count; take the max, not the last one listed.
        if name.startswith(prefix) and suffix.isdecimal():
            highest = max(highest, int(suffix))
    return log_dir.parent / f"{prefix}{highest + 1}"

# Pick a fresh numbered run directory derived from cfg.log_dir and open a
# TensorBoard writer there; train() logs its per-epoch metrics through `tb`.
logdir = get_exp_name(cfg.log_dir)
print(logdir)
tb = SummaryWriter(log_dir=logdir)

exp/train_3


# Data

In [5]:
# df_train = pd.read_csv('../data/train.csv')
# df_test = pd.read_csv('../data/test.csv')

# track_idx2embeds = {}
# for fn in tqdm(glob('../data/track_embeddings/*')):
#     track_idx = int(fn.split('/')[3].split('.')[0])
#     embeds = np.load(fn)
#     track_idx2embeds[track_idx] = embeds

# def collate_fn(b):
#     track_idxs = torch.from_numpy(np.vstack([x[0] for x in b]))
#     targets = torch.from_numpy(np.vstack([x[2] for x in b]))
#     # embeds = [torch.from_numpy(x[1]) for x in b]
#     embeds = torch.stack([x[1] for x in b])
#     return track_idxs, embeds, targets


  0%|          | 0/76714 [00:00<?, ?it/s]

In [6]:
class EmbeddingDataset(Dataset):
    """In-memory dataset of per-track embedding sequences with optional multi-hot labels.

    Reads ``<data_dir>/metadata.csv`` (tab-separated). It needs a ``track`` column and,
    for every stage except ``"infer"``, a ``tags`` column of comma-separated integer tag
    ids. An optional ``stage`` column selects the rows of labelled stages. Every
    ``<data_dir>/track_embeddings/<track>.npy`` is loaded eagerly into memory.

    Each item is a dict with:
      * ``"features"``: the sequence cropped / zero-padded along its first axis to
        ``crop_size`` steps. The offset is random in the ``"train"`` stage and centred
        otherwise, so evaluation is deterministic. Assumes each ``.npy`` is 2-D
        ``(time, dim)``; ``permute(1, 0)`` rejects other ranks.
      * ``"label"`` (all stages except ``"infer"``): float32 multi-hot vector of length
        ``num_labels``.
      * ``"track"``: the track id from the metadata.

    Args:
        data_dir: Directory containing ``metadata.csv`` and ``track_embeddings/``.
        num_labels: Length of the multi-hot label vector.
        crop_size: Number of time steps every returned sequence has.
        stage: One of ``"train"``, ``"val"``, ``"test"``, ``"infer"``.
        transform: Stored on the instance but not applied by this class.
    """

    def __init__(
        self,
        data_dir,
        num_labels=256,
        crop_size=60,
        stage="train",
        transform=None,
    ):
        super().__init__()
        # Validate before touching the filesystem so a typo fails fast.
        assert stage in ("train", "val", "test", "infer")

        self.data_dir = Path(data_dir)
        self.stage = stage
        self.transform = transform
        self.crop_size = crop_size
        self.num_labels = num_labels
        # Unlabelled ("infer") datasets keep labels=None; __getitem__ relies on this.
        self.labels = None

        self.meta_info = pd.read_csv(self.data_dir / "metadata.csv", sep="\t")

        if self.stage != "infer":
            # NOTE(review): without a "stage" column every labelled stage sees the full table
            # (train and val would be identical) — confirm the metadata always has it.
            if "stage" in self.meta_info.columns:
                self.meta_info = self.meta_info.loc[self.meta_info.stage == stage].reset_index(drop=True)
            one_hot = np.asarray(
                self.meta_info.tags.apply(self.process_tags).tolist(), dtype=np.float32
            )
            # reshape keeps a (0, num_labels) shape for an empty split.
            self.labels = torch.from_numpy(one_hot.reshape(len(self.meta_info), self.num_labels))

        self.tracks = self.meta_info.track.values

        print("Uploading embeddings into memory")
        self.embeddings = [self.get_embedding(track) for track in self.tracks]

    def process_tags(self, tags):
        """Convert a comma-separated tag string such as ``"3,17"`` into a multi-hot list.

        NaN / empty values give an all-zero vector instead of raising. A tag id
        ``>= num_labels`` raises IndexError.
        """
        one_hot_tags = np.zeros(self.num_labels, dtype=np.uint8)
        if not pd.isna(tags):
            ids = [int(tag) for tag in str(tags).split(",") if tag.strip()]
            if ids:
                one_hot_tags[ids] = 1
        return one_hot_tags.tolist()

    def get_embedding(self, track: int) -> torch.Tensor:
        """Load ``track_embeddings/<track>.npy`` as a tensor that shares memory with the array."""
        return torch.from_numpy(np.load(self.data_dir / "track_embeddings" / f"{track}.npy"))

    def __process_features(self, x: torch.Tensor):
        """Crop or zero-pad ``x`` along its first axis to exactly ``crop_size`` steps."""
        x = x.permute(1, 0)  # -> (dim, time): slicing and F.pad act on the last axis
        x_len = x.shape[-1]
        slack = abs(x_len - self.crop_size)
        if self.stage == "train":
            # randint's upper bound is exclusive; +1 makes every offset 0..slack reachable.
            offset = np.random.randint(0, slack + 1)
        else:
            # Deterministic centring for val/test/infer so metrics are reproducible.
            offset = slack // 2

        if x_len > self.crop_size:
            x = x[..., offset : offset + self.crop_size]
        else:
            x = torch.nn.functional.pad(x, (offset, slack - offset), "constant")
        return x.permute(1, 0)

    def __getitem__(self, idx):
        track_features = self.__process_features(self.embeddings[idx])
        if self.labels is None:
            return {
                "features": track_features,
                "track": self.tracks[idx],
            }
        return {
            "features": track_features,
            "label": self.labels[idx],
            "track": self.tracks[idx],
        }

    def __len__(self):
        return len(self.meta_info)


# Network

In [7]:
class Network(nn.Module):
    """Transformer encoder over embedding sequences, averaged over time, then an MLP head.

    ``forward`` takes ``x`` of shape ``(batch, time, input_dim)`` (the encoder is built with
    ``batch_first=True``). It returns raw logits of shape ``(batch, num_labels)``; no
    sigmoid is applied.

    Args:
        input_dim: Feature size of each time step (transformer ``d_model``).
        hidden_dim: Width of the projector's hidden layer.
        num_labels: Number of output logits.
        num_heads, num_layers, dim_feedforward, dropout: Encoder hyper-parameters. The
            defaults reproduce the original architecture, so state_dicts stay compatible.
    """

    def __init__(
        self,
        input_dim,
        hidden_dim,
        num_labels,
        num_heads=8,
        num_layers=3,
        dim_feedforward=2048,
        dropout=0.2,
    ):
        super().__init__()
        self.num_labels = num_labels

        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                input_dim,
                num_heads,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                batch_first=True,
            ),
            num_layers=num_layers,
        )
        # On a 3-D (batch, time, dim) tensor this averages over time -> (batch, 1, dim).
        # NOTE(review): zero-padded frames are attended to and included in this mean
        # (no padding mask is used).
        self.pooling = nn.AdaptiveAvgPool2d((1, input_dim))
        self.projector = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, num_labels),
        )

    def forward(self, x):
        x = self.transformer(x)
        # squeeze(1) rather than a bare squeeze(): keeps the batch axis when batch == 1.
        x = self.pooling(x).squeeze(1)
        return self.projector(x)

# Training

In [8]:
def _run_epoch(net, loader, criterion, device, epoch, use_amp, optimizer=None, scaler=None, clip_value=None, alpha=0.8):
    """Run one pass over ``loader``: train when ``optimizer`` is given, otherwise evaluate.

    Batches are dicts with ``"features"`` and ``"label"`` keys. Returns
    ``(mean_loss, average_precision)``, where ``mean_loss`` is the batch losses averaged
    with batch-size weights over the whole epoch. The progress bar shows an exponential
    moving average (weight ``alpha``) of batch losses; it is for display only.
    """
    is_train = optimizer is not None
    net.train(is_train)
    # Flash SDPA is enabled only on cuda:0, as in the original notebook.
    use_flash = str(device) == "cuda:0"
    loss_sum, n_samples, smooth_loss = 0.0, 0, None
    all_targets, all_preds = [], []

    with torch.set_grad_enabled(is_train):
        for data in (pbar := tqdm(loader)):
            batch = data["features"].to(device)
            targets = data["label"].to(device).float()
            if is_train:
                optimizer.zero_grad()
            with torch.cuda.amp.autocast(enabled=use_amp), torch.backends.cuda.sdp_kernel(enable_flash=use_flash):
                logits = net(batch)
                loss = criterion(logits, targets)

            if is_train:
                scaler.scale(loss).backward()
                if clip_value is not None:
                    # Unscale first so the clipping threshold applies to the true gradient norm.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(net.parameters(), clip_value)
                scaler.step(optimizer)
                scaler.update()

            batch_loss = loss.item()
            batch_size = targets.shape[0]
            loss_sum += batch_loss * batch_size
            n_samples += batch_size
            smooth_loss = batch_loss if smooth_loss is None else alpha * smooth_loss + (1 - alpha) * batch_loss
            all_targets.append(targets.cpu().numpy())
            all_preds.append(torch.sigmoid(logits.detach().float()).cpu().numpy())
            pbar.set_description(f"Epoch: {epoch} Loss: {smooth_loss:.6f}")

    if n_samples == 0:
        raise ValueError("DataLoader yielded no batches")
    score = average_precision_score(np.concatenate(all_targets), np.concatenate(all_preds))
    return loss_sum / n_samples, score


def train(net, train_dataloader, val_dataloader, n_epochs, optimizer, criterion, device, scheduler=None, use_amp=True):
    """Train ``net`` for ``n_epochs`` epochs, evaluating on ``val_dataloader`` after each.

    Each epoch prints the epoch-mean train/val loss and the sklearn
    ``average_precision_score`` (default averaging) of sigmoid predictions. The same
    numbers, plus the learning rate used, are logged to the notebook-global
    TensorBoard writer ``tb``. Gradient clipping uses the notebook-global
    ``cfg.clip_value``. ``scheduler`` (if given) is stepped once per epoch.

    Returns:
        A list with one metrics dict per epoch. The original returned None; callers that
        ignore the result are unaffected.
    """
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    history = []

    for epoch in range(n_epochs):
        lr = optimizer.param_groups[0]["lr"]  # LR actually used this epoch (before stepping)

        train_loss, train_score = _run_epoch(
            net, train_dataloader, criterion, device, epoch, use_amp,
            optimizer=optimizer, scaler=scaler, clip_value=cfg.clip_value,
        )
        if scheduler is not None:
            scheduler.step()
        print('Train Loss:', train_loss)
        print('Train AP:', train_score)

        val_loss, val_score = _run_epoch(net, val_dataloader, criterion, device, epoch, use_amp)
        print('Val Loss:', val_loss)
        print('Val AP:', val_score)

        tb.add_scalar('Loss/train', train_loss, epoch)
        tb.add_scalar('Loss/val', val_loss, epoch)
        tb.add_scalar('AP/train', train_score, epoch)
        tb.add_scalar('AP/val', val_score, epoch)
        tb.add_scalar('LR', lr, epoch)
        tb.flush()

        history.append({
            "epoch": epoch,
            "lr": lr,
            "train_loss": train_loss,
            "train_ap": train_score,
            "val_loss": val_loss,
            "val_ap": val_score,
        })
    return history

In [9]:
# Labelled train/val splits; the val loader keeps a fixed order so per-epoch metrics are comparable.
train_dataset = EmbeddingDataset(
    data_dir=cfg.data_dir,
    num_labels=cfg.num_labels,
    crop_size=cfg.crop_size,
    stage="train",
)
val_dataset = EmbeddingDataset(
    data_dir=cfg.data_dir,
    num_labels=cfg.num_labels,
    crop_size=cfg.crop_size,
    stage="val",
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=cfg.batch_size,
    shuffle=True,
    num_workers=cfg.num_workers,
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=cfg.eval_batch_size,
    shuffle=False,
    num_workers=cfg.num_workers,
)

In [10]:
# Model, loss and LR schedule. Architecture sizes come from Config so the config cell is
# the single source of truth (previously 768/512 were hard-coded here).
net = Network(
    input_dim=cfg.input_dim,
    hidden_dim=cfg.hidden_dim,
    num_labels=cfg.num_labels,
).to(cfg.device)
optimizer = torch.optim.Adam(net.parameters(), lr=cfg.lr)
# Multi-label targets: an independent sigmoid + binary cross-entropy per tag.
criterion = torch.nn.BCEWithLogitsLoss()
# train() steps the scheduler once per epoch, so T_max=n_epochs anneals over the whole run.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.n_epochs, eta_min=cfg.min_lr)

In [11]:
# Launch training; per-epoch metrics are printed and written to TensorBoard under `logdir`.
train(
    net=net,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    n_epochs=cfg.n_epochs,
    optimizer=optimizer,
    criterion=criterion,
    device=cfg.device,
    scheduler=scheduler,
    use_amp=cfg.use_amp,
)

  0%|          | 0/200 [00:00<?, ?it/s]

KeyboardInterrupt: 