## Analiza atencji dla STSA-Net do rozpoznawania ćwiczeń


### Wczytanie zbioru danych

TODO: dodać wczytywanie zbioru danych. Model zdefiniowany poniżej oczekuje tensorów o kształcie (B, T, J, D), domyślnie T=7, J=41, D=36.

### Zdefiniowanie sieci

In [None]:
import torch
import torch.nn as nn

# Hyper-parameters shared by the model classes below (used as default arguments).
DEBUG = False  # not referenced anywhere in the visible code
DROPOUT = 0.1  # default dropout probability for the attention blocks and the feed-forward network
NUM_HEADS = 4  # default number of attention heads for STSABlock / STSANet
FFN_MULT = 4  # feed-forward hidden width is H * FFN_MULT
NUM_JOINTS = 41  # default size of the joint axis (J); sizes the joint positional embedding
JOINT_DIM = 36  # default number of input features per joint (D); input size of the embedding layer
NUM_CLASSES = 6  # default number of output logits of the classifier head
T = 7  # default size of the time axis; sizes the temporal positional embedding
H = 16  # default embedding width used throughout the network

class SpatialAttentionBlock(nn.Module):
    """Self-attention across the joint axis, applied to every time step separately.

    The input is a 4-D tensor unpacked as ``(B, T, J, H)``. The first two axes
    are merged so that each ``(J, H)`` slice forms one attention sequence. The
    attention output goes through dropout, is added back to the input
    (residual connection) and is layer-normalised over the last axis. The
    result has the same shape as the input.
    """

    def __init__(self, H, num_heads=8, dropout=DROPOUT):
        super().__init__()
        self.mha = nn.MultiheadAttention(
            embed_dim=H,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.norm = nn.LayerNorm(H)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        batch, steps, joints, width = x.shape

        # Fold time into the batch axis: one sequence = the joints of one time step.
        tokens = x.reshape(batch * steps, joints, width)
        attended, _ = self.mha(tokens, tokens, tokens, need_weights=False)

        # Post-norm residual: LayerNorm(dropout(attention) + input).
        normed = self.norm(self.dropout(attended) + tokens)
        return normed.reshape(batch, steps, joints, width)


class TemporalAttentionBlock(nn.Module):
    """Self-attention across the time axis, applied to every joint separately.

    The input is a 4-D tensor unpacked as ``(B, T, J, H)``. The time and joint
    axes are swapped and the joint axis is merged into the batch, so that each
    ``(T, H)`` slice forms one attention sequence. The attention output goes
    through dropout, is added back to the input (residual connection) and is
    layer-normalised over the last axis. The result is returned in the
    original ``(B, T, J, H)`` axis order.
    """

    def __init__(self, H, num_heads=8, dropout=DROPOUT):
        super().__init__()
        self.mha = nn.MultiheadAttention(
            embed_dim=H,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.norm = nn.LayerNorm(H)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        batch, steps, joints, width = x.shape

        # (B, T, J, H) -> (B, J, T, H) -> (B * J, T, H): one sequence per joint.
        per_joint = x.transpose(1, 2).reshape(batch * joints, steps, width)
        attended, _ = self.mha(per_joint, per_joint, per_joint, need_weights=False)

        # Post-norm residual: LayerNorm(dropout(attention) + input).
        normed = self.norm(self.dropout(attended) + per_joint)

        # Undo the folding and restore the (B, T, J, H) axis order.
        return normed.reshape(batch, joints, steps, width).transpose(1, 2)


class STSABlock(nn.Module):
    """One spatial-temporal block: joint attention, time attention, feed-forward.

    The block maps a ``(B, T, J, H)`` tensor to a tensor of the same shape:

    1. spatial attention (``s_attn``) with a residual connection,
    2. temporal attention (``t_attn``) with a residual connection,
    3. a feed-forward network applied to the last axis, preceded by
       ``ffn_norm`` and wrapped in a residual connection,
    4. a final LayerNorm.

    Args:
        H: Width of the last (feature) axis.
        num_heads: Number of attention heads passed to both attention blocks.
        dropout: Dropout probability used in the attention blocks and the
            feed-forward network.
        ffn_mult: The feed-forward hidden layer has ``H * ffn_mult`` units.
    """

    def __init__(self, H, num_heads=NUM_HEADS, dropout=DROPOUT, ffn_mult=FFN_MULT):
        super().__init__()
        self.s_attn = SpatialAttentionBlock(H, num_heads=num_heads, dropout=dropout)
        self.t_attn = TemporalAttentionBlock(H, num_heads=num_heads, dropout=dropout)

        self.ffn_norm = nn.LayerNorm(H)
        self.ffn = nn.Sequential(
            nn.Linear(H, H * ffn_mult),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(H * ffn_mult, H),
            nn.Dropout(dropout),
        )
        self.final_norm = nn.LayerNorm(H)

    def forward(self, x):
        # NOTE(review): s_attn / t_attn already add their input back before
        # normalising, so these sums are a second residual path - confirm
        # that this is intended.
        x = x + self.s_attn(x)
        x = x + self.t_attn(x)

        # Normalise before the feed-forward network. ffn_norm was previously
        # constructed but never applied, leaving a dead trainable layer.
        # nn.Linear acts on the last axis, so the 4-D tensor is passed as is.
        x = x + self.ffn(self.ffn_norm(x))

        return self.final_norm(x)


class STSANet(nn.Module):
    """Classifier built from stacked spatial-temporal self-attention blocks.

    Each input feature vector is linearly embedded, learned positional
    embeddings for the time axis and the joint axis are added, the result is
    passed through ``num_blocks`` ``STSABlock`` layers, layer-normalised,
    averaged over the time and joint axes and projected to class logits.

    Args:
        T: Size of the time axis of the input.
        J: Size of the joint axis of the input.
        D: Number of input features per joint.
        H: Embedding width.
        num_classes: Number of output logits.
        num_heads: Number of attention heads in every block.
        dropout: Dropout probability used in every block.
        ffn_mult: Feed-forward expansion factor used in every block.
        num_blocks: Number of stacked ``STSABlock`` layers.
    """

    def __init__(self, T=T, J=NUM_JOINTS, D=JOINT_DIM, H=H, num_classes=NUM_CLASSES,
                 num_heads=NUM_HEADS, dropout=DROPOUT, ffn_mult=FFN_MULT, num_blocks=2):
        super().__init__()
        self.T = T
        self.J = J
        self.D = D
        self.H = H

        self.embed = nn.Linear(D, H)

        # Learned positional embeddings, broadcast over the other axes.
        self.pos_time = nn.Parameter(torch.randn(1, T, 1, H) * 0.02)
        self.pos_joint = nn.Parameter(torch.randn(1, 1, J, H) * 0.02)

        self.blocks = nn.ModuleList([
            STSABlock(H, num_heads=num_heads, dropout=dropout, ffn_mult=ffn_mult)
            for _ in range(num_blocks)
        ])

        self.pool_norm = nn.LayerNorm(H)
        self.cls = nn.Linear(H, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(B, num_classes)`` for input ``x``.

        Args:
            x: Tensor of shape ``(B, T, J, D)`` matching the sizes the model
                was constructed with.

        Raises:
            ValueError: If ``x`` is not 4-D or its trailing three dimensions
                differ from ``(self.T, self.J, self.D)``.
        """
        # Validate explicitly: an input whose time or joint axis has size 1
        # would otherwise be silently broadcast against the positional
        # embeddings and produce activations of the wrong shape.
        expected = (self.T, self.J, self.D)
        if x.dim() != 4 or tuple(x.shape[1:]) != expected:
            raise ValueError(
                f"STSANet expects input of shape (B, {self.T}, {self.J}, {self.D}), "
                f"got {tuple(x.shape)}"
            )

        x = self.embed(x)
        x = x + self.pos_time + self.pos_joint

        for blk in self.blocks:
            x = blk(x)

        x = self.pool_norm(x)
        # Average over the time and joint axes -> (B, H).
        pooled = x.mean(dim=(1, 2))
        return self.cls(pooled)

### Trening

TODO: dodać kod, który zamieni zbiór danych na PyTorchowe DataLoadery (`train_loader` i `val_loader`, używane w pętli treningowej poniżej).

In [None]:
import tqdm
import numpy as np

from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import f1_score

# Prefer Apple's MPS backend, then CUDA, and fall back to the CPU.
DEVICE = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
NUM_CLASSES = 6
NUM_EPOCHS = 20
LR = 0.5e-3

model = STSANet().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()

# mode='max' because the scheduler is stepped with the validation macro-F1,
# where higher is better; the learning rate is halved when it plateaus.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=3
)


def _run_epoch(model, loader, criterion, device, desc, optimizer=None):
    """Run one full pass over ``loader`` and return its metrics.

    The pass is a training pass (train mode, gradients enabled, one optimizer
    step per batch) when ``optimizer`` is given, and an evaluation pass (eval
    mode, gradients disabled) otherwise.

    Args:
        model: Network called on each batch of clips.
        loader: Iterable yielding ``(clips, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.
        desc: Text shown next to the progress bar.
        optimizer: Optimizer to step after every batch, or ``None`` to
            evaluate only.

    Returns:
        Tuple ``(avg_loss, macro_f1, preds, labels)``: the sample-weighted
        mean loss, the macro-averaged F1 score, and the concatenated
        predicted and true labels as NumPy arrays.
    """
    training = optimizer is not None
    if training:
        model.train()
    else:
        model.eval()

    loss_sum = 0.0
    batch_preds = []
    batch_labels = []

    with torch.set_grad_enabled(training):
        for clips, labels in tqdm.tqdm(loader, desc=desc):
            clips = clips.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True).long()

            if training:
                optimizer.zero_grad()

            outputs = model(clips)
            loss = criterion(outputs, labels)

            if training:
                loss.backward()
                optimizer.step()

            # The batch loss is weighted by the batch size so that the
            # epoch average stays correct when the last batch is smaller.
            loss_sum += loss.item() * clips.size(0)

            _, preds = outputs.max(1)
            batch_preds.append(preds.cpu().numpy())
            batch_labels.append(labels.cpu().numpy())

    all_preds = np.concatenate(batch_preds)
    all_labels = np.concatenate(batch_labels)

    avg_loss = loss_sum / len(all_labels)
    macro_f1 = f1_score(all_labels, all_preds, average='macro')
    return avg_loss, macro_f1, all_preds, all_labels


for epoch in range(NUM_EPOCHS):
    avg_train_loss, train_f1, train_preds_all, train_labels_all = _run_epoch(
        model, train_loader, criterion, DEVICE, "Training", optimizer=optimizer
    )

    # validation
    avg_val_loss, val_f1, val_preds_all, val_labels_all = _run_epoch(
        model, val_loader, criterion, DEVICE, f"Epoch {epoch+1} [Valid]"
    )

    scheduler.step(val_f1)

    print(
        f"Epoch [{epoch+1}/{NUM_EPOCHS}] "
        f"Train Loss: {avg_train_loss:.4f} | Train F1: {train_f1:.4f} "
        f"| Val Loss: {avg_val_loss:.4f} | Val F1: {val_f1:.4f}"
    )

### Analiza atencji modelu

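TODO: dodać hooki do warstw atencji modelu, zbierać wagi atencji i przedstawić je na wykresach. Uwaga: bloki atencji wywołują `nn.MultiheadAttention` z `need_weights=False`, więc wagi nie są obecnie zwracane.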