In [None]:
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader
from sklearn.metrics import f1_score

In [None]:
# Load the three dataset splits; each .npy file holds a pickled dict, unwrapped with .item().
# NOTE(review): allow_pickle=True unpickles arbitrary objects — only load these files from a trusted source.
train, val, test = (
    np.load(split_path, allow_pickle=True).item()
    for split_path in (
        "./dataset/train_split.npy",
        "./dataset/val_split.npy",
        "./dataset/test-release.npy",
    )
)

In [None]:
# Stack every training sequence into one frame-level matrix with 28 features per row,
# plus the concatenated annotations (one label per row, as TensorDataset(labeled_x, y) requires).
labeled_x = torch.tensor(
    np.concatenate([seq["keypoints"] for seq in train["sequences"].values()], axis=0)
).reshape(-1, 28)
y = torch.tensor(
    np.concatenate([seq["annotations"] for seq in train["sequences"].values()], axis=0)
)
labeled_x.shape, y.shape

In [None]:
# Unlabeled pool for the ELBO term: only the keypoints of the test release are used.
unlabeled_x = torch.tensor(
    np.concatenate([seq["keypoints"] for seq in test["sequences"].values()], axis=0)
).reshape(-1, 28)
unlabeled_x.shape

In [None]:
# Per-feature standardization with statistics pooled over labeled + unlabeled frames;
# the same mean/std are reused later for the validation split.
std, mean = torch.std_mean(torch.cat([labeled_x, unlabeled_x]), dim=0)
labeled_x, unlabeled_x = labeled_x.sub(mean).div(std), unlabeled_x.sub(mean).div(std)

In [None]:
# Validation split, standardized with the pooled train/test statistics computed above.
val_y = torch.tensor(
    np.concatenate([seq["annotations"] for seq in val["sequences"].values()], axis=0)
)
val_x = (
    torch.tensor(
        np.concatenate([seq["keypoints"] for seq in val["sequences"].values()], axis=0)
    ).reshape(-1, 28)
    - mean
) / std
val_x.shape, val_y.shape

In [None]:
def gaussian_parameters(h, dim=-1):
    """Split ``h`` in half along ``dim`` into a mean and a strictly positive variance.

    The first half is returned unchanged as the mean; the second half is passed
    through softplus and offset by 1e-8 so the variance never reaches zero.
    Expects an even size along ``dim``.

    Returns:
        ``(mean, variance)``, each half the size of ``h`` along ``dim``.
    """
    half = h.size(dim) // 2
    mean, raw_var = torch.split(h, half, dim=dim)
    var = F.softplus(raw_var) + 1e-8
    return mean, var

def sample_gaussian(m, v):
    """Reparameterized Gaussian sample ``z = m + sqrt(v) * eps`` with ``eps ~ N(0, I)``.

    Args:
        m: Mean tensor.
        v: Variance tensor of the same shape as ``m`` (non-negative).

    Returns:
        A sample with the shape, dtype and device of ``m``; gradients flow to ``m`` and ``v``.
    """
    # randn_like draws the noise directly on m's device/dtype, so this helper does not
    # depend on the notebook-global `device` (defined in a later cell) and avoids a
    # CPU allocation followed by a host-to-device copy.
    eps = torch.randn_like(m)
    return m + torch.sqrt(v) * eps

def kl_cat(q, log_q, log_p):
    """KL(q || p) between categorical distributions, reduced over the last axis.

    Args:
        q: Probabilities of q.
        log_q: Log-probabilities of q.
        log_p: Log-probabilities of p; anything broadcastable against ``log_q``
            (e.g. a scalar for a uniform prior).

    Returns:
        Tensor with the last axis summed out.
    """
    return torch.sum(q * (log_q - log_p), dim=-1)

def kl_normal(qm, qv, pm, pv):
    """KL(N(qm, qv) || N(pm, pv)) for diagonal Gaussians, summed over the last axis.

    All arguments are tensors; ``pm``/``pv`` may broadcast against ``qm``/``qv``.
    """
    log_var_ratio = pv.log() - qv.log()
    sq_mean_gap = (qm - pm) ** 2
    per_dim = 0.5 * (log_var_ratio + qv / pv + sq_mean_gap / pv - 1)
    return per_dim.sum(dim=-1)

def log_normal(x, m, v):
    """Log-density of ``x`` under the diagonal Gaussian N(m, v), summed over the last axis.

    ``v`` must be a tensor (it goes through ``torch.log``); ``m``/``v`` broadcast against ``x``.
    """
    log_norm_const = -0.5 * torch.log(2 * np.pi * v)
    sq_term = (x - m) ** 2 / (2 * v)
    return (log_norm_const - sq_term).sum(-1)

def duplicate(x, rep):
    """Tile ``x`` ``rep`` times along dim 0 as ``[x, x, ..., x]`` (blocks, not interleaved rows)."""
    repeats = (rep,) + (1,) * (x.dim() - 1)
    return x.repeat(*repeats)

In [None]:
class Encoder(nn.Module):
    """Inference network for q(z | x, y): (features, one-hot label) -> Gaussian parameters.

    Args:
        z_dim: Latent dimensionality; the MLP emits ``2 * z_dim`` values (mean, raw variance).
        y_dim: Width of the one-hot label concatenated to ``x``.
        x_dim: Number of input features per row (default 28, the flattened keypoint
            width used in this notebook).
    """

    def __init__(self, z_dim, y_dim, x_dim=28):
        super().__init__()
        self.z_dim = z_dim
        self.y_dim = y_dim
        self.x_dim = x_dim
        # Layer layout is unchanged so existing state_dicts still load.
        self.net = nn.Sequential(
            nn.Linear(x_dim + y_dim, 64),
            nn.ELU(),
            nn.Linear(64, 2 * z_dim),
        )

    def encode(self, x, y):
        """Return ``(mean, variance)`` of q(z | x, y), each of shape (batch, z_dim).

        Args:
            x: (batch, x_dim) float inputs.
            y: (batch, y_dim) one-hot labels.
        """
        xy = torch.cat((x, y), dim=1)
        h = self.net(xy)
        # Split the 2 * z_dim outputs into a mean and a softplus-positive variance.
        m, v = gaussian_parameters(h, dim=1)
        return m, v

class Decoder(nn.Module):
    """Generative network for p(x | z, y): (latent code, one-hot label) -> x_dim outputs.

    Args:
        z_dim: Latent dimensionality.
        y_dim: Width of the one-hot label concatenated to ``z``.
        x_dim: Number of output features per row (default 28, the flattened keypoint
            width used in this notebook).
    """

    def __init__(self, z_dim, y_dim, x_dim=28):
        super().__init__()
        self.z_dim = z_dim
        self.y_dim = y_dim
        self.x_dim = x_dim
        # Layer layout is unchanged so existing state_dicts still load.
        self.net = nn.Sequential(
            nn.Linear(z_dim + y_dim, 64),
            nn.ELU(),
            nn.Linear(64, x_dim)
        )

    def decode(self, z, y):
        """Map (batch, z_dim) latents and (batch, y_dim) labels to (batch, x_dim) outputs."""
        zy = torch.cat((z, y), dim=1)
        return self.net(zy)

class Classifier(nn.Module):
    """MLP classifier for q(y | x) producing unnormalized logits.

    Args:
        y_dim: Number of classes.
        x_dim: Number of input features per row (default 28, the flattened keypoint
            width used in this notebook).
    """

    def __init__(self, y_dim, x_dim=28):
        super().__init__()
        self.y_dim = y_dim
        self.x_dim = x_dim
        # Layer layout is unchanged so existing state_dicts still load.
        self.net = nn.Sequential(
            nn.Linear(x_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, y_dim)
        )

    def classify(self, x):
        """Return class logits of shape (batch, y_dim) for (batch, x_dim) inputs."""
        return self.net(x)

In [None]:
class SSVAE(nn.Module):
    """Semi-supervised VAE: a classifier q(y | x) plus a label-conditioned Gaussian VAE.

    Unlabeled inputs contribute a negative ELBO that enumerates every label and
    weights the per-label terms by q(y | x); labeled inputs contribute a
    cross-entropy term on the classifier.

    Args:
        gen_weight: Weight of the negative-ELBO term; ``0`` skips the generative
            branch entirely (classifier-only training).
        class_weight: Weight of the classification cross-entropy term.
        z_dim: Latent dimensionality (default 16).
        y_dim: Number of label classes (default 4).
    """

    def __init__(self, gen_weight=1, class_weight=100, z_dim=16, y_dim=4):
        super().__init__()
        self.z_dim = z_dim
        self.y_dim = y_dim
        self.gen_weight = gen_weight
        self.class_weight = class_weight
        self.enc = Encoder(self.z_dim, self.y_dim)
        self.dec = Decoder(self.z_dim, self.y_dim)
        self.cls = Classifier(self.y_dim)

        # Standard-normal prior over z, stored as frozen Parameters so it moves with .to(device).
        self.z_prior_m = torch.nn.Parameter(torch.zeros(1), requires_grad=False)
        self.z_prior_v = torch.nn.Parameter(torch.ones(1), requires_grad=False)

    def negative_elbo_bound(self, x):
        """Negative ELBO for an unlabeled batch, marginalizing over all labels.

        Args:
            x: (batch, features) float inputs.

        Returns:
            ``(nelbo, kl_z, kl_y, rec)`` scalar tensors, each averaged over the batch.
        """
        batch = x.size(0)
        y_logits = self.cls.classify(x)
        y_logprob = F.log_softmax(y_logits, dim=1)
        y_prob = y_logprob.exp()  # (batch, y_dim); equal to softmax(y_logits)

        # Enumerate every (x, label) pair. duplicate() tiles x as whole blocks
        # [x_0..x_{B-1}, x_0..x_{B-1}, ...], so labels are laid out as [0]*B + [1]*B + ...
        labels = torch.arange(self.y_dim, device=x.device).repeat_interleave(batch)
        y = F.one_hot(labels, num_classes=self.y_dim).to(x.dtype)
        x = duplicate(x, self.y_dim)

        z_m, z_v = self.enc.encode(x, y)
        z = sample_gaussian(z_m, z_v)
        x_m = self.dec.decode(z, y)
        # Fixed observation variance 0.1 for the Gaussian likelihood; full_like keeps x_m's device/dtype.
        x_v = torch.full_like(x_m, 0.1)

        # KL to a uniform label prior.
        kl_y = kl_cat(y_prob, y_logprob, np.log(1.0 / self.y_dim))
        # Per-pair terms reshaped to (y_dim, batch), then weighted by q(y | x) and summed over labels.
        kl_z = kl_normal(z_m, z_v, self.z_prior_m, self.z_prior_v)
        kl_z = (kl_z.reshape(self.y_dim, -1) * y_prob.t()).sum(dim=0)
        rec = -log_normal(x, x_m, x_v)
        rec = (rec.reshape(self.y_dim, -1) * y_prob.t()).sum(dim=0)
        nelbo = kl_y + kl_z + rec

        return nelbo.mean(), kl_z.mean(), kl_y.mean(), rec.mean()

    def classification_cross_entropy(self, x, y):
        """Mean cross-entropy of the classifier on a labeled batch.

        Args:
            x: (batch, features) inputs.
            y: (batch, y_dim) one-hot labels, converted to class indices via argmax.
        """
        y_logits = self.cls.classify(x)
        return F.cross_entropy(y_logits, y.argmax(1))

    def loss(self, x, xl, yl):
        """Training objective ``gen_weight * nelbo + class_weight * ce``.

        Args:
            x: Unlabeled batch for the ELBO term (unused when ``gen_weight <= 0``).
            xl: Labeled inputs.
            yl: One-hot labels for ``xl``.

        Returns:
            ``(loss, summaries)``; the ``gen/*`` summaries are ``0`` when the
            generative term is skipped.
        """
        if self.gen_weight > 0:
            nelbo, kl_z, kl_y, rec = self.negative_elbo_bound(x)
        else:
            nelbo, kl_z, kl_y, rec = [0] * 4
        ce = self.classification_cross_entropy(xl, yl)
        loss = self.gen_weight * nelbo + self.class_weight * ce

        summaries = {
            'train/loss': loss,
            'class/ce': ce,
            'gen/elbo': -nelbo,
            'gen/kl_z': kl_z,
            'gen/kl_y': kl_y,
            'gen/rec': rec,
        }

        return loss, summaries

    def compute_sigmoid_given(self, z, y):
        """Sigmoid of the decoder output for latents ``z`` and one-hot labels ``y``.

        NOTE(review): negative_elbo_bound treats the decoder output as a Gaussian mean
        over standardized (real-valued) features, so a sigmoid/Bernoulli reading is
        inconsistent with the training likelihood — confirm before relying on it.
        """
        logits = self.dec.decode(z, y)
        return torch.sigmoid(logits)

    def sample_z(self, batch):
        """Draw ``batch`` latent samples from the prior N(z_prior_m, z_prior_v).

        Returns:
            Tensor of shape (batch, z_dim) on the prior's device.
        """
        # Previously referenced the undefined `ut` module and a nonexistent `self.z_prior`.
        m = self.z_prior_m.expand(batch, self.z_dim)
        v = self.z_prior_v.expand(batch, self.z_dim)
        return m + torch.sqrt(v) * torch.randn_like(m)

    def sample_x_given(self, z, y):
        """Bernoulli sample from ``compute_sigmoid_given(z, y)``.

        NOTE(review): see compute_sigmoid_given — this does not match the Gaussian
        observation model used during training.
        """
        return torch.bernoulli(self.compute_sigmoid_given(z, y))


In [None]:
# Training configuration.
epochs = 30
# Fall back to CPU so the notebook still runs on machines without a CUDA GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
labeled_batch_size = 64
unlabeled_batch_size = 320  # unlabeled frames sampled per labeled batch
learning_rate = 1e-3
seed = 1

In [None]:
# Seed torch's global RNG, which drives DataLoader shuffling, nn layer init and torch.randint below.
torch.manual_seed(seed);
labeled_dataset, unlabeled_dataset = TensorDataset(labeled_x, y), TensorDataset(unlabeled_x)
labeled_loader = DataLoader(dataset=labeled_dataset, batch_size=labeled_batch_size, shuffle=True)

In [None]:
# Classifier-only baseline: gen_weight=0 makes SSVAE.loss skip the ELBO term.
model = SSVAE(gen_weight=0)
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
# Train the current `model`; evaluate macro-F1 on the train and val splits after each epoch.
log_every = 3000  # print the running loss every this many labeled batches
for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    total_count = 0
    for i, (xl, yl) in enumerate(labeled_loader):
        optimizer.zero_grad()

        # Unlabeled frames for the ELBO term, sampled with replacement.
        unlabeled_indices = torch.randint(0, len(unlabeled_dataset), (unlabeled_batch_size,))
        xu = unlabeled_dataset[unlabeled_indices][0].to(device).float()
        # SSVAE.loss expects one-hot labels; F.one_hot avoids the numpy round-trip.
        yl = F.one_hot(yl.long(), num_classes=model.y_dim).to(device).float()
        xl = xl.to(device).float()
        loss, summaries = model.loss(xu, xl, yl)

        loss.backward()
        optimizer.step()

        total_loss += loss.item() * len(xl)
        total_count += len(xl)

        if i % log_every == 0:
            print(f"Train_loss={total_loss / total_count:.4f}")

    # Inference only: no_grad avoids building an autograd graph over the full splits.
    model.eval()
    with torch.no_grad():
        train_pred = model.cls.classify(labeled_x.to(device).float()).argmax(1)
        val_pred = model.cls.classify(val_x.to(device).float()).argmax(1)
    # Macro-F1 averaged over classes 0-2 only; class 3 is left out of the score.
    train_f1 = f1_score(y, train_pred.cpu(), average="macro", labels=[0, 1, 2])
    val_f1 = f1_score(val_y, val_pred.cpu(), average="macro", labels=[0, 1, 2])

    print(f"Epoch {epoch}, Train-F1={train_f1:.4f}, Val-F1={val_f1:.4f}")

In [None]:
# Full semi-supervised model with the default loss weights (gen_weight=1, class_weight=100).
model = SSVAE()
model = model.to(device)
optimizer = torch.optim.Adam(lr=learning_rate, params=model.parameters())

In [None]:
# Train the semi-supervised `model`; evaluate macro-F1 on the train and val splits after each epoch.
log_every = 3000  # print the running loss every this many labeled batches
for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    total_count = 0
    for i, (xl, yl) in enumerate(labeled_loader):
        optimizer.zero_grad()

        # Unlabeled frames for the ELBO term, sampled with replacement.
        unlabeled_indices = torch.randint(0, len(unlabeled_dataset), (unlabeled_batch_size,))
        xu = unlabeled_dataset[unlabeled_indices][0].to(device).float()
        # SSVAE.loss expects one-hot labels; F.one_hot avoids the numpy round-trip.
        yl = F.one_hot(yl.long(), num_classes=model.y_dim).to(device).float()
        xl = xl.to(device).float()
        loss, summaries = model.loss(xu, xl, yl)

        loss.backward()
        optimizer.step()

        total_loss += loss.item() * len(xl)
        total_count += len(xl)

        if i % log_every == 0:
            print(f"Train_loss={total_loss / total_count:.4f}")

    # Inference only: no_grad avoids building an autograd graph over the full splits.
    model.eval()
    with torch.no_grad():
        train_pred = model.cls.classify(labeled_x.to(device).float()).argmax(1)
        val_pred = model.cls.classify(val_x.to(device).float()).argmax(1)
    # Macro-F1 averaged over classes 0-2 only; class 3 is left out of the score.
    train_f1 = f1_score(y, train_pred.cpu(), average="macro", labels=[0, 1, 2])
    val_f1 = f1_score(val_y, val_pred.cpu(), average="macro", labels=[0, 1, 2])

    print(f"Epoch {epoch}, Train-F1={train_f1:.4f}, Val-F1={val_f1:.4f}")