In [None]:
# Colab setup: load the TensorBoard notebook extension and mount Google Drive
# under /content/drive (the later cells read the project from
# '/content/drive/My Drive/human_motion'). force_remount=True asks Colab to
# remount even if Drive is already mounted in this runtime.
from google.colab import drive
%load_ext tensorboard
drive.mount('/content/drive', force_remount=True)

In [None]:
# Run the project's main.py from the mounted Drive as a shell subprocess.
# Model/hyperparameter flags: 2 LSTM layers, latent size 10, LSTM hidden size 100,
# batch size 4, unbalancing rate 0.5, --mode 'lstm' (presumably selects the LSTM
# classification stage -- TODO confirm against main.py), --train_continue 'off'.
# The *_ckpt_dir / *_log_dir flags give each stage (classification, oversampling,
# weight balancing, feature GAN, LSTM retrain) its experiment_2 checkpoint and log
# folder. Note: no comments can be placed inside the backslash-continued command.
!python3 '/content/drive/My Drive/human_motion/main.py' --num_lstm 2 --latent_size 10 --lstm_hidden_size 100 --batch_size 4 --unbalancing_rate 0.5 --mode 'lstm' \
--train_continue 'off' \
--lstm_ckpt_dir '/content/drive/My Drive/human_motion/ex_classification/experiment_2/checkpoint' \
--lstm_log_dir '/content/drive/My Drive/human_motion/ex_classification/experiment_2/log' \
--oversampling_ckpt_dir '/content/drive/My Drive/human_motion/ex_oversampling/experiment_2/checkpoint' \
--oversampling_log_dir '/content/drive/My Drive/human_motion/ex_oversampling/experiment_2/log' \
--weight_balancing_ckpt_dir '/content/drive/My Drive/human_motion/ex_weight_balancing/experiment_2/checkpoint' \
--weight_balancing_log_dir '/content/drive/My Drive/human_motion/ex_weight_balancing/experiment_2/log' \
--feature_gan_ckpt_dir '/content/drive/My Drive/human_motion/ex_feature_gan/experiment_2/checkpoint' \
--feature_gan_log_dir '/content/drive/My Drive/human_motion/ex_feature_gan/experiment_2/log' \
--lstm_retrain_ckpt_dir '/content/drive/My Drive/human_motion/ex_lstm_retrain/experiment_2/checkpoint' \
--lstm_retrain_log_dir '/content/drive/My Drive/human_motion/ex_lstm_retrain/experiment_2/log'

In [None]:
import numpy as np
import torch
import torch.nn.functional as F


def _focal_loss(labels, logits, alpha, gamma):
    """Compute the weighted sigmoid focal loss between `logits` and one-hot `labels`.

    Focal loss: alpha * (1 - p_t)^gamma * BCE(labels, logits), where p_t is
    sigmoid(logits) for positive entries and 1 - sigmoid(logits) for negatives.
    Args:
      labels: A float tensor of size [batch, no_of_classes] (one-hot targets).
      logits: A float tensor of size [batch, no_of_classes].
      alpha: A float tensor of size [batch, no_of_classes]; per-entry weights.
      gamma: float. Focusing parameter; 0.0 reduces to weighted BCE.
    Returns:
      A scalar float tensor: the summed weighted loss divided by the number of
      positive entries in `labels` (the batch size for one-hot targets).
    """
    bce = F.binary_cross_entropy_with_logits(input=logits, target=labels, reduction="none")

    if gamma == 0.0:
        modulator = 1.0
    else:
        # (1 - p_t)^gamma evaluated in log space. softplus(-x) == log(1 + exp(-x)),
        # but does not overflow for large negative logits.
        modulator = torch.exp(-gamma * labels * logits - gamma * F.softplus(-logits))

    weighted_loss = alpha * modulator * bce
    # Normalize by the number of positive targets.
    return weighted_loss.sum() / labels.sum()


def CB_loss(labels, logits, samples_per_cls, no_of_classes, loss_type, beta, gamma):
    """Compute the Class Balanced Loss between `logits` and the ground truth `labels`.

    Class Balanced Loss: ((1-beta)/(1-beta^n))*Loss(labels, logits)
    where n is the number of training samples of the sample's class and Loss is
    one of the standard losses used for Neural Networks. The per-class weights
    are rescaled so that they sum to `no_of_classes`.
    Args:
      labels: A int tensor of size [batch].
      logits: A float tensor of size [batch, no_of_classes].
      samples_per_cls: A python list of size [no_of_classes].
      no_of_classes: total number of classes. int
      loss_type: string. One of "sigmoid", "focal", "softmax".
      beta: float. Hyperparameter for Class balanced loss.
      gamma: float. Hyperparameter for Focal loss (ignored by other loss types).
    Returns:
      cb_loss: A scalar float tensor representing class balanced loss, on the
        same device as `logits`.
    Raises:
      ValueError: if `loss_type` is not one of the supported names, or if
        `samples_per_cls` does not have exactly `no_of_classes` entries.
    """
    if loss_type not in ("focal", "sigmoid", "softmax"):
        raise ValueError(
            f"Unknown loss_type {loss_type!r}; expected 'focal', 'sigmoid' or 'softmax'"
        )
    if len(samples_per_cls) != no_of_classes:
        raise ValueError(
            f"samples_per_cls has {len(samples_per_cls)} entries, expected {no_of_classes}"
        )

    # Effective number of samples per class is (1 - beta^n) / (1 - beta);
    # the class weight is its inverse.
    effective_num = 1.0 - np.power(beta, samples_per_cls)
    class_weights = (1.0 - beta) / np.array(effective_num)
    # Rescale so the class weights sum to no_of_classes.
    class_weights = class_weights / np.sum(class_weights) * no_of_classes

    # Keep every tensor on the logits' device so GPU inputs work.
    labels = labels.to(logits.device)
    labels_one_hot = F.one_hot(labels, no_of_classes).float()

    class_weights = torch.as_tensor(class_weights, dtype=torch.float32, device=logits.device)
    # Each sample gets the weight of its ground-truth class, broadcast across all
    # class columns to match the [batch, no_of_classes] loss layout.
    weights = class_weights[labels].unsqueeze(1).expand(-1, no_of_classes)

    if loss_type == "focal":
        return _focal_loss(labels_one_hot, logits, weights, gamma)
    if loss_type == "sigmoid":
        return F.binary_cross_entropy_with_logits(input=logits, target=labels_one_hot, weight=weights)
    # loss_type == "softmax"
    pred = logits.softmax(dim=1)
    return F.binary_cross_entropy(input=pred, target=labels_one_hot, weight=weights)



if __name__ == '__main__':
    # Smoke test: class-balanced focal loss on a random batch of 10 samples over
    # 5 classes. Seed torch's global RNG so the printed logits, labels and loss
    # are reproducible across Restart & Run All.
    torch.manual_seed(0)
    no_of_classes = 5
    logits = torch.rand(10,no_of_classes).float()
    print(logits)
    labels = torch.randint(0,no_of_classes, size = (10,))
    print(labels)
    beta = 0.9999
    gamma = 2.0
    samples_per_cls = [2,3,1,2,2]
    loss_type = "focal"
    cb_loss = CB_loss(labels, logits, samples_per_cls, no_of_classes,loss_type, beta, gamma)
    print(cb_loss)

In [None]:
# Preview of the framed "unknown mode" message: a 40-char rule, the message, a rule.
print('\n'.join(['=' * 40, 'The entered "mode" does not exist', '=' * 40]))

In [None]:
# 12 per-class values: even-indexed classes are halved (float 0.5), odd-indexed
# classes keep the integer 1.
sample_per_cls = [1 * 0.5 if idx % 2 == 0 else 1 for idx in range(12)]

print(sample_per_cls)

In [None]:
from torch.autograd import Variable
import torch
import numpy as np
import torch.nn as nn

batch_size = 4
latent_dim = 10
n_classes = 12
SEED = 0

# One seed covers everything random here: torch's global RNG drives both the
# Embedding weight initialisation and the noise/label sampling below.
torch.manual_seed(SEED)

label_emb = nn.Embedding(n_classes, latent_dim)

# Sample directly as tensors; `torch.autograd.Variable` has been a no-op
# wrapper since PyTorch 0.4.
z = torch.randn(batch_size, latent_dim)
gen_labels = torch.randint(0, n_classes, (batch_size,))

print(z)
print(gen_labels)

# Look up each generated label's embedding once and reuse it.
label_vectors = label_emb(gen_labels)
print(label_vectors)

# Element-wise product of noise and label embedding (label-conditioned noise).
# Out-of-place, so `z` keeps the raw noise and the cell is safe to re-run.
z_conditioned = z * label_vectors
print(z_conditioned)