# Dance recognition training - full

In [None]:
### --- Download dataset
# Download, unpack and pip-install the `body_postures` dataset package, then install torchmetrics.
# Uses IPython shell magics (`!cmd`); the /content path suggests this was written for Google Colab.
# NOTE(review): `zipfile` is not used in this cell.
import zipfile
from pathlib import Path

dataset_path = Path("/content") / "body-postures"

# Skip the download/install when `dataset_path` already exists.
# NOTE(review): the archive is presumably unpacked into `body-postures-dataset` (the directory the
# `cd` below enters), not `body-postures`; unless something else creates /content/body-postures this
# guard never fires and every run re-downloads and re-installs — confirm the intended directory.
if not dataset_path.exists():
  !wget https://www.dropbox.com/s/b9gfafnh6aesrsu/body-postures-dataset.bin
  # alternative mirror: https://lenzgregor.com/nextcloud/s/tbamCq9Eo95qfLc/download/body-postures-dataset.bin
  # The download is a gzipped tarball despite the .bin extension; rename it and unpack with `tar -xzf`.
  !mv body-postures-dataset.bin body-postures-dataset.tar.gz
  !tar -xzf body-postures-dataset.tar.gz
  # Install the extracted package so `from body_postures import ...` works in later cells.
  !cd body-postures-dataset; python -m pip install .

# NOTE(review): unpinned install — consider `%pip install torchmetrics==<version>` so re-runs get the same API.
!pip install torchmetrics

In [None]:
### --- Imports
from typing import Optional

from tqdm.auto import tqdm
import torch
import torchmetrics
from torch import nn

# Dataset
from body_postures import BodyPostureFrames

## Settings and hyperparameters

In [None]:
n_classes = 7          # number of output classes; must equal the number of entries in `classes`
batch_size = 256       # samples per batch for both training and validation loaders
learning_rate = 1e-3   # Adam step size
weight_decay = 0       # L2 penalty passed to Adam; 0 disables it

# Seed torch's global RNG so weight initialisation, dropout masks and any random
# DataLoader shuffling are reproducible across Restart & Run All.
seed = 0
torch.manual_seed(seed)

## Data loading and inspection

In [None]:
def _make_frame_dataset(cache_path, training, event_count=3000, hot_pixel_filter_freq=60):
    """Build one BodyPostureFrames split using the settings shared by train and validation.

    Args:
        cache_path: Directory used by BodyPostureFrames for its frame cache.
        training: True for the training split, False for the validation split.
        event_count: Passed through as ``event_count``.
        hot_pixel_filter_freq: Passed through as ``hot_pixel_filter_freq``.
    """
    # NOTE(review): reset_cache=True presumably rebuilds the cache on every run; set it to False
    # to reuse an existing cache once it is known to be valid (confirm against BodyPostureFrames).
    return BodyPostureFrames(
        cache_path=cache_path,
        event_count=event_count,
        reset_cache=True,
        train=training,
        hot_pixel_filter_freq=hot_pixel_filter_freq,
    )


frame_dataset_training = _make_frame_dataset("cache/frames_training", training=True)
print(f"training samples: {len(frame_dataset_training)}")
frame_dataset_validation = _make_frame_dataset("cache/frames_val", training=False)
print(f"validation samples: {len(frame_dataset_validation)}")

# Reshuffle the training set every epoch so batches are not drawn in the dataset's fixed
# storage order (which may be grouped by class). Validation order does not affect metrics.
train_loader = torch.utils.data.DataLoader(frame_dataset_training, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(frame_dataset_validation, batch_size=batch_size)

# Human-readable names for the integer class labels.
classes = {
    0: "background",
    1: "clap",
    2: "mj",
    3: "salive",
    4: "star",
    5: "wave",
    6: "other",
}
assert len(classes) == n_classes, "n_classes must match the number of class names"

## Model definition

In [None]:
class ANN(nn.Module):
    """Small bias-free convolutional classifier for 2-channel frames.

    The shape comments below trace a ``(2, 128, 128)`` input. The final
    ``nn.Linear`` expects ``32 * 4 * 4`` flattened features, so inputs must
    downsample to 4x4 after the three conv/pool stages.

    Args:
        n_classes: Number of output scores produced per sample.
    """

    def __init__(self, n_classes=10):
        super().__init__()

        # The actual network
        self.net = nn.Sequential(
            nn.Conv2d(2, 16, kernel_size=(3, 3), stride=2, padding=1, bias=False),  # 16, 64, 64
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=(2, 2)),  # 16, 32, 32
            nn.Dropout2d(0.1),

            nn.Conv2d(16, 32, kernel_size=(3, 3), stride=2, padding=1, bias=False),  # 32, 16, 16
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=(2, 2)),  # 32, 8, 8
            nn.Dropout2d(0.25),

            nn.Conv2d(32, 32, kernel_size=(3, 3), stride=1, padding=1, bias=False),  # 32, 8, 8
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=(2, 2)),  # 32, 4, 4
            nn.Flatten(),  # 32 * 4 * 4 = 512
            # Plain Dropout: the tensor is already flat (N, 512). Dropout2d on 2-D input is
            # deprecated in PyTorch (it warns and behaves like element-wise dropout anyway).
            nn.Dropout(0.5),

            nn.Linear(32 * 4 * 4, n_classes, bias=False),
            # NOTE(review): this ReLU clamps the class scores fed to CrossEntropyLoss at 0.
            # Kept as-is, presumably so the all-ReLU / bias-free network can later be converted
            # to a spiking equivalent — confirm before removing.
            nn.ReLU(),
        )

    def forward(self, x):
        """Return per-class scores of shape ``(N, n_classes)`` for a batch of frames.

        ``x`` may be anything accepted by ``torch.as_tensor``. It is cast to float32,
        and a singleton dimension 1 is squeezed away: ``(N, 1, 2, H, W)`` becomes
        ``(N, 2, H, W)``, while an ``(N, 2, H, W)`` input passes through unchanged.
        """
        x = torch.as_tensor(x, dtype=torch.float32).squeeze(1)
        return self.net(x)

In [None]:
ann = ANN(n_classes)

## -- Prepare training
# - Loss: per-class scores of shape (N, n_classes) against integer class labels
criterion = torch.nn.CrossEntropyLoss()

# - Optimizer
optimizer = torch.optim.Adam(ann.parameters(), lr=learning_rate, weight_decay=weight_decay)

# - Confusion matrix
# torchmetrics >= 0.11 requires an explicit `task`; "multiclass" matches the single-label,
# n_classes-way classification (argmax over scores) performed in this notebook.
confusion_matrix = torchmetrics.ConfusionMatrix(task="multiclass", num_classes=n_classes)

# - Training loop
def _train_one_epoch() -> float:
    """Run one optimisation pass of ``ann`` over ``train_loader``.

    Returns:
        Mean of the per-batch losses, or ``nan`` if the loader yields no batches.
    """
    ann.train()  # enable dropout for training
    losses = []
    pbar = tqdm(train_loader, desc="Batches training", leave=False)
    for data, labels in pbar:
        logits = ann(data)
        loss = criterion(logits, labels)
        # Clear stale gradients first so each step uses only this batch's gradients
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        pbar.set_postfix(loss=losses[-1])
    return sum(losses) / len(losses) if losses else float("nan")


@torch.no_grad()
def _validate():
    """Evaluate ``ann`` on ``val_loader`` without tracking gradients.

    Returns:
        ``(mean_batch_loss, accuracy)`` as Python floats; both ``nan`` if the loader is empty.
    """
    ann.eval()  # disable dropout while measuring validation performance
    confusion_matrix.reset()  # discard counts left over from any earlier use of the metric
    losses = []
    pbar = tqdm(val_loader, desc="Batches validation", leave=False)
    for data, labels in pbar:
        logits = ann(data)
        losses.append(criterion(logits, labels).item())
        confusion_matrix.update(logits.argmax(dim=1), labels)
        pbar.set_postfix(loss=losses[-1])
    if not losses:
        return float("nan"), float("nan")
    confusion = confusion_matrix.compute()
    confusion_matrix.reset()
    # Accuracy = correctly classified samples (diagonal) / all samples
    accuracy = (confusion.trace() / confusion.sum()).item()
    return sum(losses) / len(losses), accuracy


def train(num_epochs: int = 100, validate_every: Optional[int] = 5):
    """Train the notebook-level ``ann`` and periodically validate it.

    Uses the global ``ann``, ``criterion``, ``optimizer``, ``train_loader``,
    ``val_loader`` and ``confusion_matrix`` defined in earlier cells.

    Args:
        num_epochs: Number of passes over ``train_loader``.
        validate_every: Validate on epochs where ``epoch % validate_every == 0``
            (so epoch 0 is always validated). ``None`` or ``0`` disables validation.

    Returns:
        dict with ``"train_loss"`` (one mean loss per epoch) and ``"val_epoch"``,
        ``"val_loss"``, ``"val_accuracy"`` (one entry per validation run).
    """
    history = {"train_loss": [], "val_epoch": [], "val_loss": [], "val_accuracy": []}

    pbar_epoch = tqdm(range(num_epochs), desc="Epochs")
    for epoch in pbar_epoch:
        train_loss = _train_one_epoch()
        history["train_loss"].append(train_loss)
        postfix = {"train_loss": train_loss}

        if validate_every and epoch % validate_every == 0:
            val_loss, val_acc = _validate()
            history["val_epoch"].append(epoch)
            history["val_loss"].append(val_loss)
            history["val_accuracy"].append(val_acc)
            postfix.update(val_loss=val_loss, val_accuracy=val_acc)

        pbar_epoch.set_postfix(**postfix)

    return history
                
    
            

In [None]:
# Run training with the default schedule, spelled out explicitly: 100 epochs, validating every 5th epoch.
train(num_epochs=100, validate_every=5)