In [None]:
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader, Dataset
from sklearn.metrics import f1_score
from numpy.lib.stride_tricks import sliding_window_view

In [None]:
# Each split file is loaded with allow_pickle=True and unwrapped with .item();
# presumably each holds a dict (a "sequences" mapping is read below).
# NOTE(review): allow_pickle=True unpickles the file, which can execute
# arbitrary code -- only load these files from a trusted source.
train, val, test = (
    np.load(split_path, allow_pickle=True).item()
    for split_path in ("train_split.npy", "val_split.npy", "test-release.npy")
)

In [None]:
def attach_neighbor_frames(x, num_frames):
    """Expand every frame of every sequence into a window of neighbouring frames.

    For each 2-D sequence ``seq`` of shape (T, D) in ``x``, every one of its T
    frames is replaced by a window of ``num_frames`` consecutive frames.
    ``(num_frames - 1) // 2`` frames come before it and ``num_frames // 2``
    frames come after it. Positions falling outside the sequence are filled
    with 1e-8. Windows never cross sequence boundaries.

    Args:
        x: iterable of (T_i, D) arrays, all with the same D.
        num_frames: window length (>= 1). Odd values give a centred window. For
            even values the extra frame is taken from the "after" side, so there
            is still exactly one window per frame.

    Returns:
        Array of shape (sum_i T_i, D, num_frames), frames in sequence order.
        Always this layout, including for ``num_frames == 1``.

    Raises:
        ValueError: if ``num_frames < 1`` or ``x`` is empty.
    """
    if num_frames < 1:
        raise ValueError(f"num_frames must be >= 1, got {num_frames}")
    before = (num_frames - 1) // 2
    after = num_frames // 2  # before + 1 + after == num_frames for every num_frames >= 1
    results = []
    for seq in x:
        seq = np.asarray(seq)
        D = seq.shape[1]
        # Same 1e-8 fill value that UnlabeledDataset uses for its padding.
        pad_before = np.zeros((before, D)) + 1e-8
        pad_after = np.zeros((after, D)) + 1e-8
        padded = np.vstack([pad_before, seq, pad_after])
        # The view has shape (T, 1, num_frames, D). Squeeze it to
        # (T, num_frames, D), giving one window per original frame.
        windows = sliding_window_view(padded, window_shape=(num_frames, D)).squeeze(axis=1)
        # Put features first: (T, D, num_frames), the Conv1d channel layout.
        results.append(windows.transpose(0, 2, 1))
    return np.concatenate(results, axis=0)

In [None]:
# Window length (in frames) per labelled frame: the frame itself plus 100
# neighbouring frames on each side. The Encoder/Decoder layer sizes are derived
# from it (28 * num_frames_per_cls).
num_frames_per_cls = 2 * 100 + 1

In [None]:
# Per-video keypoint arrays flattened to 28 features per frame.
# Labelled videos come from the train split; unlabelled ones from the test release.
labeled_x = [
    video["keypoints"].reshape(-1, 28)
    for video in train["sequences"].values()
]
unlabeled_x = [
    video["keypoints"].reshape(-1, 28)
    for video in test["sequences"].values()
]

In [None]:
# Build the labelled tensors straight from `train` instead of from the
# `labeled_x` list of the previous cell. Otherwise re-running this cell would
# window data that has already been windowed.
labeled_x = torch.tensor(
    attach_neighbor_frames(
        [seq["keypoints"].reshape(-1, 28) for seq in train["sequences"].values()],
        num_frames_per_cls,
    ),
    # float32 halves memory versus the float64 windows; consumers call .float() anyway.
    dtype=torch.float32,
)
y = torch.tensor(np.concatenate([seq["annotations"] for seq in train["sequences"].values()], axis=0))
assert len(labeled_x) == len(y), "expected exactly one window per annotated frame"
labeled_x.shape, y.shape

In [None]:
# Validation windows and labels, built the same way as the training tensors.
val_x = torch.tensor(
    attach_neighbor_frames(
        [seq["keypoints"].reshape(-1, 28) for seq in val["sequences"].values()],
        num_frames_per_cls,
    ),
    # float32 halves memory versus the float64 windows; consumers call .float() anyway.
    dtype=torch.float32,
)
val_y = torch.tensor(np.concatenate([seq["annotations"] for seq in val["sequences"].values()], axis=0))
assert len(val_x) == len(val_y), "expected exactly one window per annotated frame"
val_x.shape, val_y.shape

In [None]:
class UnlabeledDataset(Dataset):
    """Frame-level dataset that builds neighbour windows lazily, one item at a time.

    ``data`` is a list of per-video arrays of shape (T_k, D). Item ``i`` is the
    ``i``-th frame of the concatenation of all videos. It is returned as a
    tensor of shape (D, num_frames) containing:

    - ``(num_frames - 1) // 2`` frames before it,
    - the frame itself,
    - ``num_frames // 2`` frames after it.

    Positions outside the frame's own video are filled with 1e-8, so windows
    never cross video boundaries. The result dtype is the video dtype promoted
    with float64, so every item has the same dtype whether or not padding was
    needed.
    """

    def __init__(self, data, num_frames):
        super().__init__()
        self.data = data
        self.num_frames = num_frames
        # lengths[k] = number of frames in videos 0..k, i.e. the exclusive end
        # offset of video k in the global frame index.
        self.lengths = np.cumsum([len(video) for video in data])

    def __len__(self):
        return int(self.lengths[-1]) if len(self.lengths) else 0

    def __getitem__(self, index):
        # int() accepts Python ints, NumPy integers and 0-d tensors, such as the
        # elements of the torch.randint(...) draw used to sample this dataset.
        index = int(index)
        total = len(self)
        if index < 0:
            index += total
        if not 0 <= index < total:
            raise IndexError(f"index out of range for dataset of {total} frames")

        seq_index = self._find_seq_index(index)
        video = self.data[seq_index]
        frame_index = index - (int(self.lengths[seq_index - 1]) if seq_index > 0 else 0)

        # Window [lo, hi) in video coordinates. The part inside
        # [0, len(video)) is copied; the rest keeps the 1e-8 fill.
        lo = frame_index - (self.num_frames - 1) // 2
        hi = lo + self.num_frames
        src_lo, src_hi = max(lo, 0), min(hi, len(video))
        window = np.full(
            (self.num_frames, video.shape[1]),
            1e-8,
            dtype=np.result_type(video.dtype, np.float64),
        )
        window[src_lo - lo:src_hi - lo] = video[src_lo:src_hi]
        return torch.from_numpy(window).permute(1, 0)

    def _find_seq_index(self, index):
        """Return k such that video k contains global frame ``index``."""
        # First k with lengths[k] > index. side="right" skips empty videos,
        # whose cumulative length equals that of the preceding video.
        return int(np.searchsorted(self.lengths, index, side="right"))

In [None]:
# Unlabelled windows are built on the fly; labelled/validation windows are
# already materialised tensors paired with their per-frame labels.
unlabeled_dataset = UnlabeledDataset(data=unlabeled_x, num_frames=num_frames_per_cls)
val_dataset = TensorDataset(val_x, val_y)
labeled_dataset = TensorDataset(labeled_x, y)

In [None]:
def gaussian_parameters(h, dim=-1):
    """Split ``h`` in half along ``dim`` into a mean and a strictly positive variance.

    The first half is returned unchanged as the mean. The second half is passed
    through softplus and offset by 1e-8 to give the variance.
    """
    half = h.size(dim) // 2
    mean, raw_var = torch.split(h, half, dim=dim)
    var = F.softplus(raw_var) + 1e-8
    return mean, var

def sample_gaussian(m, v):
    """Draw a reparameterised sample ``z = m + sqrt(v) * eps`` with ``eps ~ N(0, I)``.

    The noise comes from ``torch.randn_like(m)``, so it is created directly on
    ``m``'s device with ``m``'s dtype. There is no dependency on a global
    ``device`` and no CPU-to-GPU copy. Gradients flow to ``m`` and ``v``.

    Args:
        m: mean tensor.
        v: variance tensor, broadcastable against ``m``.

    Returns:
        Tensor with the broadcast shape of ``m`` and ``v``.
    """
    eps = torch.randn_like(m)
    return m + torch.sqrt(v) * eps

def kl_cat(q, log_q, log_p):
    """KL(q || p) between categorical distributions, summed over the last dimension.

    Args:
        q: probabilities of q.
        log_q: log-probabilities of q.
        log_p: log-probabilities of p. Anything that broadcasts against ``log_q``
            is accepted (e.g. a scalar for a uniform prior).
    """
    return torch.sum(q * (log_q - log_p), dim=-1)

def kl_normal(qm, qv, pm, pv):
    """KL(N(qm, diag qv) || N(pm, diag pv)), summed over the last dimension.

    All arguments are tensors that broadcast against each other. ``qv`` and
    ``pv`` are variances (not standard deviations).
    """
    log_var_ratio = torch.log(pv) - torch.log(qv)
    scaled_sq_diff = (qm - pm).pow(2) / pv
    per_dim = 0.5 * (log_var_ratio + qv / pv + scaled_sq_diff - 1)
    return per_dim.sum(-1)

def log_normal(x, m, v):
    """Log-density of ``x`` under N(m, diag v), summed over the last dimension."""
    log_two_pi = np.log(2 * np.pi)
    scaled_sq_diff = (x - m).pow(2) / v
    per_dim = -0.5 * (torch.log(v) + scaled_sq_diff + log_two_pi)
    return per_dim.sum(-1)

def duplicate(x, rep):
    """Tile ``x`` ``rep`` times along dim 0, i.e. ``[x; x; ...; x]``.

    Row ``k * len(x) + i`` of the result is ``x[i]``. Trailing dimensions are
    left untouched.
    """
    trailing_ones = [1] * (x.dim() - 1)
    return x.repeat(rep, *trailing_ones)

In [None]:
class Encoder(nn.Module):
    """Posterior network q(z | x, y) with a diagonal Gaussian output.

    An MLP maps the concatenation of a flattened window
    (28 * num_frames_per_cls features) and a label vector (y_dim features) to
    the mean and variance of z.
    """

    def __init__(self, z_dim, y_dim):
        super().__init__()
        self.z_dim = z_dim
        self.y_dim = y_dim
        in_features = 28 * num_frames_per_cls + y_dim
        self.net = nn.Sequential(
            nn.Linear(in_features, 256),
            nn.ELU(),
            nn.Linear(256, 2 * z_dim),
        )

    def encode(self, x, y):
        """Return ``(mean, variance)`` of q(z | x, y), each of shape (batch, z_dim)."""
        hidden = self.net(torch.cat([x, y], dim=1))
        return gaussian_parameters(hidden, dim=1)

class Decoder(nn.Module):
    """Generative network p(x | z, y).

    An MLP maps the concatenation of a latent code (z_dim) and a label vector
    (y_dim) to the mean of a flattened window (28 * num_frames_per_cls values).
    """

    def __init__(self, z_dim, y_dim):
        super().__init__()
        self.z_dim, self.y_dim = z_dim, y_dim
        out_features = 28 * num_frames_per_cls
        self.net = nn.Sequential(
            nn.Linear(z_dim + y_dim, 256),
            nn.ELU(),
            nn.Linear(256, out_features),
        )

    def decode(self, z, y):
        """Return the reconstruction mean, shape (batch, 28 * num_frames_per_cls)."""
        return self.net(torch.cat([z, y], dim=1))

class Classifier(nn.Module):
    """1-D CNN mapping a (batch, 28, num_frames) window to (batch, y_dim) logits.

    The network has:

    - input batch-norm,
    - two conv blocks (conv -> batch-norm -> ReLU -> dropout -> stride-2 max-pool),
    - a global average over the remaining time steps,
    - a linear head.

    Module indices inside ``self.net`` are stable, so state_dict keys
    (e.g. ``net.13.weight``) do not depend on the pooling choice.
    """

    def __init__(self, y_dim):
        super().__init__()
        self.y_dim = y_dim
        self.net = nn.Sequential(
            nn.BatchNorm1d(28),

            nn.Conv1d(28, 512, kernel_size=5, padding=2),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.MaxPool1d(kernel_size=2, stride=2),

            nn.Conv1d(512, 256, kernel_size=5, padding=2),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.MaxPool1d(kernel_size=2, stride=2),

            # Global average over time. For 201-frame windows (201 -> 100 -> 50
            # after the two max-pools) this equals AvgPool1d(kernel_size=50).
            # Unlike a fixed kernel, it also handles any other window length >= 4.
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(1, 2),
            nn.Linear(256, self.y_dim),
        )

    def classify(self, x):
        """Return unnormalised class logits of shape (batch, y_dim)."""
        return self.net(x)

In [None]:
class SSVAE(nn.Module):
    """Semi-supervised VAE combining a label classifier with a label-conditioned VAE.

    The components are:

    - ``cls`` gives q(y | x),
    - ``enc`` gives q(z | x, y),
    - ``dec`` gives the mean of p(x | z, y), with a fixed observation variance of 0.1.

    For unlabelled data the ELBO enumerates every label and weights the
    per-label terms by q(y | x). The KL on y is taken against a uniform prior.
    The training loss is ``gen_weight * nelbo + class_weight * cross_entropy``.
    """

    def __init__(self, gen_weight=.01, class_weight=1):
        super().__init__()
        self.z_dim = 128
        self.y_dim = 4
        self.gen_weight = gen_weight
        self.class_weight = class_weight
        self.enc = Encoder(self.z_dim, self.y_dim)
        self.dec = Decoder(self.z_dim, self.y_dim)
        self.cls = Classifier(self.y_dim)

        # Standard-normal prior over z. It is a frozen Parameter so that it
        # follows the module across .to(device) calls.
        self.z_prior_m = torch.nn.Parameter(torch.zeros(1), requires_grad=False)
        self.z_prior_v = torch.nn.Parameter(torch.ones(1), requires_grad=False)

    def negative_elbo_bound(self, x):
        """Negative ELBO for an unlabelled batch ``x``, averaged over the batch.

        Args:
            x: unlabelled windows. They are flattened from dim 1 onward, so
                presumably shaped (batch, 28, num_frames).

        Returns:
            ``(nelbo, kl_z, kl_y, rec)``, each a scalar tensor (batch mean).
        """
        batch = x.size(0)
        y_logits = self.cls.classify(x)
        y_logprob = F.log_softmax(y_logits, dim=1)
        y_prob = y_logprob.exp()  # (batch, y_dim)

        # Enumerate every label for every example. Block k of the duplicated
        # batch (rows k*batch .. (k+1)*batch - 1) is paired with one-hot label k,
        # matching the [x; x; ...] layout produced by `duplicate`.
        labels = torch.arange(self.y_dim, device=x.device).repeat_interleave(batch)
        y = F.one_hot(labels, self.y_dim).to(x.dtype)
        x = duplicate(x.flatten(1, 2), self.y_dim)

        z_m, z_v = self.enc.encode(x, y)
        z = sample_gaussian(z_m, z_v)
        x_m = self.dec.decode(z, y)
        # Fixed observation variance, created on x_m's device with x_m's dtype.
        x_v = torch.full_like(x_m, 0.1)

        kl_y = kl_cat(y_prob, y_logprob, np.log(1.0 / self.y_dim))
        # Per-label terms are reshaped to (y_dim, batch) and weighted by q(y | x).
        kl_z = kl_normal(z_m, z_v, self.z_prior_m, self.z_prior_v)
        kl_z = (kl_z.reshape(self.y_dim, -1) * y_prob.t()).sum(dim=0)
        rec = -log_normal(x, x_m, x_v)
        rec = (rec.reshape(self.y_dim, -1) * y_prob.t()).sum(dim=0)
        nelbo = kl_y + kl_z + rec

        return nelbo.mean(), kl_z.mean(), kl_y.mean(), rec.mean()

    def classification_cross_entropy(self, x, y):
        """Cross-entropy of the classifier on ``x`` against one-hot targets ``y``."""
        y_logits = self.cls.classify(x)
        return F.cross_entropy(y_logits, y.argmax(1))

    def loss(self, x, xl, yl):
        """Weighted training objective plus a dict of scalar summaries.

        Args:
            x: unlabelled batch for the ELBO term. It is not used (and may be
                None) when ``gen_weight <= 0``.
            xl: labelled batch, fed to the classifier.
            yl: one-hot labels for ``xl``.

        Returns:
            ``(loss, summaries)`` where
            ``loss = gen_weight * nelbo + class_weight * ce``.
        """
        if self.gen_weight > 0:
            nelbo, kl_z, kl_y, rec = self.negative_elbo_bound(x)
        else:
            nelbo, kl_z, kl_y, rec = [0] * 4
        ce = self.classification_cross_entropy(xl, yl)
        loss = self.gen_weight * nelbo + self.class_weight * ce

        summaries = dict((
            ('train/loss', loss),
            ('class/ce', ce),
            ('gen/elbo', -nelbo),
            ('gen/kl_z', kl_z),
            ('gen/kl_y', kl_y),
            ('gen/rec', rec),
        ))

        return loss, summaries

In [None]:
def evaluate_f1(model, dataloader, labels=(0, 1, 2)):
    """Macro F1 of ``model.cls`` on the batches of ``dataloader``.

    The classifier runs in eval mode: dropout is off, and batch-norm uses its
    running statistics without updating them. The model's previous train/eval
    mode is restored afterwards, even if an exception is raised. Inputs are
    moved to the device of the model's parameters.

    Args:
        model: an ``nn.Module`` with a ``cls.classify`` method.
        dataloader: yields ``(x, integer_label)`` batches.
        labels: classes included in the macro average. The default (0, 1, 2)
            leaves class 3 out of the average, as before.

    Returns:
        The sklearn macro-averaged F1 score.
    """
    device = next(model.parameters()).device
    was_training = model.training
    model.eval()
    y_true, y_pred = [], []
    try:
        with torch.no_grad():
            for xl, yl in dataloader:
                logits = model.cls.classify(xl.to(device).float())
                y_pred.append(logits.argmax(1).cpu().numpy())
                y_true.append(yl.cpu().numpy())
    finally:
        model.train(was_training)
    y_true = np.concatenate(y_true) if y_true else np.array([], dtype=np.int64)
    y_pred = np.concatenate(y_pred) if y_pred else np.array([], dtype=np.int64)
    return f1_score(y_true, y_pred, average="macro", labels=list(labels))

In [None]:
epochs = 5
# Use the GPU when one is available; fall back to the CPU so the notebook still runs elsewhere.
device = "cuda" if torch.cuda.is_available() else "cpu"
labeled_batch_size = 64
unlabeled_batch_size = 320
learning_rate = 1e-3
seed = 1

In [None]:
# Seed torch on all devices. This fixes the loader's shuffling and the
# torch.randint draws used for unlabelled sampling.
torch.manual_seed(seed);
labeled_loader = DataLoader(
    dataset=labeled_dataset,
    batch_size=labeled_batch_size,
    shuffle=True,
)

In [None]:
# gen_weight=0 skips the ELBO term entirely, so this run trains only the
# classifier (a supervised baseline).
model = SSVAE(gen_weight=0).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
for epoch in range(epochs):
    total_loss = 0.0
    total_count = 0
    for i, (xl, yl) in enumerate(labeled_loader):
        optimizer.zero_grad()

        # Always draw the indices so the RNG stream is identical whether or not
        # the generative term is active. Only materialise the windows when the
        # ELBO actually uses them (SSVAE.loss ignores x when gen_weight <= 0).
        unlabeled_indices = torch.randint(0, len(unlabeled_dataset), (unlabeled_batch_size,))
        if model.gen_weight > 0:
            xu = torch.stack([unlabeled_dataset[index] for index in unlabeled_indices], dim=0)
            xu = xu.to(device).float()
        else:
            xu = None

        yl = F.one_hot(yl.long(), num_classes=model.y_dim).to(device).float()
        xl = xl.to(device).float()
        loss, summaries = model.loss(xu, xl, yl)

        loss.backward()
        optimizer.step()

        total_loss += loss.item() * len(xl)
        total_count += len(xl)

        if i % 1000 == 0:
            train_f1 = evaluate_f1(model, DataLoader(labeled_dataset, batch_size=10000))
            val_f1 = evaluate_f1(model, DataLoader(val_dataset, batch_size=10000))
            # Defensive: keep training with dropout/batch-norm in train mode after evaluation.
            model.train()
            print(f"Epoch {epoch}, Train_loss={total_loss / total_count:.4f}, Train-F1={train_f1:.4f}, Val-F1={val_f1:.4f}")

In [None]:
# model = SSVAE().to(device)
# optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# for epoch in range(epochs):
#     total_loss = 0
#     total_count = 0
#     for i, (xl, yl) in enumerate(labeled_loader):
#         optimizer.zero_grad()

#         unlabeled_indices = torch.randint(0, len(unlabeled_dataset), (unlabeled_batch_size,))
#         xu = torch.stack([unlabeled_dataset[index] for index in unlabeled_indices], dim=0)

#         xu = xu.to(device).float()
#         yl = yl.new(np.eye(4)[yl]).to(device).float()
#         xl = xl.to(device).float()
#         loss, summaries = model.loss(xu, xl, yl)
    
#         loss.backward()
#         optimizer.step()

#         total_loss += loss.detach().item() * len(xl)
#         total_count += len(xl)

#         if i % 1000 == 0:
#             print(f"Train_loss={total_loss / total_count:.4f}")
#             train_f1 = evaluate_f1(model, DataLoader(labeled_dataset, batch_size=1000))
#             val_f1 = evaluate_f1(model, DataLoader(val_dataset, batch_size=1000))
#             print(f"Epoch {epoch}, Train-F1={train_f1:.4f}, Val-F1={val_f1:.4f}")