In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy inputs for the distillation example: 3 samples (rows) x 5 classes (columns).
teacher_logits = torch.tensor(
    [
        [2.0, 1.0, 0.1, 0.5, 0.3],
        [1.2, 0.7, 0.8, 2.1, 0.4],
        [0.9, 1.4, 0.5, 1.2, 0.3],
    ]
)

student_logits = torch.tensor(
    [
        [1.5, 1.1, 0.2, 0.4, 0.6],
        [1.0, 0.8, 0.9, 1.9, 0.5],
        [1.0, 1.3, 0.4, 1.1, 0.2],
    ]
)

# Divisor applied to the logits before the softmax (see softmax_with_temperature).
temperature = 2.0

# Softening the logits using temperature scaling
def softmax_with_temperature(logits, temperature):
    """Return probabilities obtained from ``logits`` scaled by ``temperature``.

    The logits are divided by ``temperature`` and then normalised with a
    softmax taken over the last dimension.
    """
    scaled_logits = logits / temperature
    return scaled_logits.softmax(dim=-1)

# Soften both sets of logits with the same temperature (teacher first, then student).
teacher_probs, student_probs = (
    softmax_with_temperature(model_logits, temperature)
    for model_logits in (teacher_logits, student_logits)
)

# KL Divergence Loss function
def kl_divergence_loss(teacher_probs, student_probs, eps=1e-12):
    """Compute the KL divergence of the student from the teacher, averaged over the batch.

    Args:
        teacher_probs: Target probabilities (not log-probabilities).
        student_probs: Predicted probabilities (not log-probabilities); the
            logarithm is taken inside this function.
        eps: Lower bound applied to ``student_probs`` before taking the log.
            Probabilities at or above ``eps`` are left untouched.

    Returns:
        A scalar tensor: the pointwise KL terms summed and divided by the
        size of the first (batch) dimension.
    """
    # reduction='batchmean' sums all elements and divides by the batch size.
    kl_loss = nn.KLDivLoss(reduction='batchmean')

    # KLDivLoss expects its input in log space. Clamp first: log(0) is -inf,
    # which turns the loss into inf (or nan where the teacher is also 0).
    student_log_probs = student_probs.clamp_min(eps).log()
    return kl_loss(student_log_probs, teacher_probs)

# Evaluate the distillation loss on the softened distributions and report it.
loss = kl_divergence_loss(teacher_probs=teacher_probs, student_probs=student_probs)
print(f"KL Divergence Loss: {loss.item()}")


KL Divergence Loss: 0.00500534987077117


In [1]:
import numpy as np 
import pandas as pd 


In [2]:
# Load one recording from a hard-coded relative path; this only resolves when
# the kernel's working directory is a sibling of the dataset folder.
# NOTE(review): consider a configurable DATA_DIR; the file name also appears to
# embed a person's name — confirm it is safe to share.
sample_data = np.load("../dataset/SLEEP/all/test__static_BuiVanCanh.npy")


In [5]:
from audiomentations import ClippingDistortion 




In [9]:
from audiomentations import Compose, TimeStretch, \
                            PitchShift, Shift, ClippingDistortion, \
                            Gain, GainTransition, Reverse, AddGaussianNoise
import numpy as np
from tqdm import tqdm

# --- Individual audiomentations transforms -----------------------------------
# Each object is called as transform(samples=..., sample_rate=...).
# Alternative settings kept for reference:
#   clipping1 = ClippingDistortion(min_percentile_threshold=2, max_percentile_threshold=4, p=1.0,)
#   gain2 = Gain(min_gain_in_db=-3.0, max_gain_in_db=-2.1, p=1.0)
clipping = ClippingDistortion(
    min_percentile_threshold=1,
    max_percentile_threshold=2,
    p=1.0,
)
gain = Gain(
    min_gain_in_db=-2.0,
    max_gain_in_db=-1.1,
    p=1.0,
)
gaintransition = GainTransition(
    min_gain_in_db=1.1,
    max_gain_in_db=2.0,
    p=1.0,
)
gaussnoise = AddGaussianNoise(
    min_amplitude=0.1,
    max_amplitude=1.2,
    p=0.5,
)
timestretch = TimeStretch(
    min_rate=0.8,
    max_rate=1.25,
    p=0.5,
)
pitchshift = PitchShift(
    min_semitones=-4,
    max_semitones=4,
    p=0.5,
)
reverse = Reverse(p=1.0)

# --- Chained pipeline: noise -> time stretch -> pitch shift -------------------
# Uses its own transform instances (with a much smaller noise amplitude than
# `gaussnoise` above).
compose = Compose(
    [
        AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5),
        TimeStretch(min_rate=0.8, max_rate=1.25, p=0.5),
        PitchShift(min_semitones=-4, max_semitones=4, p=0.5),
    ]
)

In [10]:
# Transforms currently enabled for augmentation.
# Disabled for now: gaussnoise, timestretch, pitchshift, shift.
augments = [clipping, gain, gaintransition, reverse]

In [15]:
from tslearn.neighbors import KNeighborsTimeSeries
from tslearn.barycenters import softdtw_barycenter
from tslearn.metrics import gamma_soft_dtw

In [13]:
# Random stand-in batch of shape (16, 3, 1, 100); `data` is rebound by later cells.
data = torch.randn(16, 3, 1, 100 )

In [16]:
from tslearn.neighbors import KNeighborsTimeSeries
from tslearn.barycenters import softdtw_barycenter
from tslearn.metrics import gamma_soft_dtw

In [32]:
def softdtw_augment_train_set(x_train, y_train, classes, num_synthetic_ts, max_neighbors=5):
    """Generate synthetic time series per class as weighted soft-DTW barycenters.

    For every class in ``classes`` that has at least two series in
    ``x_train``, ``num_synthetic_ts`` samples are generated. Each sample is
    the soft-DTW barycenter of a randomly chosen series of that class and its
    nearest neighbours within the class, weighted so that more distant
    neighbours contribute less.

    Args:
        x_train: Array of time series indexed as (n_series, axis 1, axis 2);
            presumably (n_series, length, n_features) as tslearn expects —
            TODO confirm against the caller.
        y_train: Label per series, aligned with the first axis of ``x_train``.
        classes: Iterable of class labels to augment. Classes with fewer than
            two series are skipped because no neighbour exists.
        num_synthetic_ts: Number of synthetic series to create per class.
        max_neighbors: The number of neighbours drawn per sample lies in
            ``[1, max_neighbors - 1]`` (at least 1) and is additionally capped
            at ``class size - 1``.

    Returns:
        Tuple ``(synthetic_x, synthetic_y)`` of NumPy arrays. Both are empty
        when no class qualifies.
    """
    from tslearn.neighbors import KNeighborsTimeSeries
    from tslearn.barycenters import softdtw_barycenter
    from tslearn.metrics import gamma_soft_dtw

    x_train = np.asarray(x_train)
    y_train = np.asarray(y_train)

    synthetic_x_train = []
    synthetic_y_train = []
    for c in classes:
        # All series belonging to this class.
        c_x_train = x_train[y_train == c]
        n_class = len(c_x_train)
        if n_class < 2:
            # Zero or one series: nothing to average with, so skip the class.
            continue

        # Soft-DTW gamma estimated once for the whole class.
        class_gamma = gamma_soft_dtw(c_x_train)
        # The k-NN query asks for k + 1 series (representative included), so
        # k may not exceed the number of other series in the class.
        max_k = max(1, min(max_neighbors - 1, n_class - 1))

        for _ in range(num_synthetic_ts):
            # Pick a random representative; keep the leading axis so it is a
            # valid single-item query for kneighbors.
            representative_index = np.random.choice(n_class, size=1, replace=False)
            random_representative = c_x_train[representative_index]

            number_of_neighbors = int(np.random.randint(1, max_k + 1))
            knn = KNeighborsTimeSeries(
                n_neighbors=number_of_neighbors + 1,
                metric='softdtw',
                metric_params={'gamma': class_gamma},
            ).fit(c_x_train)
            neighbor_distances, neighbor_indices = knn.kneighbors(
                X=random_representative, return_distance=True
            )
            neighbor_indices = neighbor_indices[0]
            neighbor_distances = neighbor_distances[0]
            # Second-smallest distance is used as the scale of the decay.
            nearest_neighbor_distance = np.sort(neighbor_distances)[1]
            random_neighbors = c_x_train[neighbor_indices].astype(float)

            if nearest_neighbor_distance == 0:
                # Avoid dividing by zero (e.g. duplicated series), which would
                # yield NaN weights; fall back to a plain average.
                weights = np.ones(len(neighbor_distances), dtype=float)
            else:
                # Weight halves for each multiple of the nearest distance; the
                # small constant keeps every weight strictly positive.
                weights = np.exp(np.log(0.5) * neighbor_distances / nearest_neighbor_distance) + 0.0000001
            weights /= np.sum(weights)

            # Barycenter uses a gamma estimated on the selected neighbours only.
            random_neighbors_gamma = gamma_soft_dtw(random_neighbors)
            generated_sample = softdtw_barycenter(
                random_neighbors, weights=weights, gamma=random_neighbors_gamma
            )
            synthetic_x_train.append(generated_sample)
            synthetic_y_train.append(c)

    return np.array(synthetic_x_train), np.array(synthetic_y_train)

In [88]:
# Random NumPy array of shape (16, 3, 1, 100); `X` is rebound to a scalar in a later cell.
X = np.random.random((16, 3, 1 , 100))

In [76]:
# Candidate class ids 0..11 (12 classes); the bare name displays the array.
classes = np.arange(0, 12)
classes

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])

In [77]:
# 16 random integer labels drawn from 0..11; unseeded, so some classes may not appear at all.
y = np.random.randint(0, 12, size=16)

In [78]:
y

array([ 8,  9,  2,  3,  4,  2, 10,  2, 10, 11, 11,  5,  1,  6, 10,  4])

In [79]:
# Split the loaded array by column: columns 1..3 become `data`, the last column
# becomes `label`. This rebinds `data`, replacing the random tensor assigned earlier.
data = sample_data[:, 1:4]
label = sample_data[:, -1]
print(data.shape)

(36000, 3)


In [80]:
# Rebinds `X` to a single random float, discarding the (16, 3, 1, 100) array
# assigned earlier. NOTE(review): looks like leftover scratch — confirm and remove.
X = np.random.rand()

In [81]:
import random

In [94]:
# Rebind `data` to a random tensor of shape (16, 3, 1, 100), replacing the slice
# of `sample_data` assigned in an earlier cell.
data = torch.rand((16, 3,1,100))

In [108]:
def augment_data(data, aug_methods=None, sample_rate=8000):
    """Apply one randomly chosen augmentation to every sample of a batch.

    Args:
        data: Array-like of shape (batch, channels, 1, length). Axis 1 is
            treated as the channel axis and axis 3 as the time axis.
        aug_methods: Sequence of callables invoked as
            ``method(samples=<1-D signal>, sample_rate=<int>)`` that return a
            signal of the same length. One method is drawn per sample and
            called once per channel. With ``None`` or an empty sequence the
            data is returned unchanged (converted to a tensor).
        sample_rate: Sample rate forwarded to every augmentation call.

    Returns:
        ``torch.float64`` tensor of shape (batch, channels, 1, length).

    Raises:
        ValueError: If the third axis of ``data`` does not have size 1.
    """
    signals = np.asarray(data)
    b, d, singleton, s = signals.shape
    if singleton != 1:
        raise ValueError(
            f"expected data of shape (batch, channels, 1, length), got {signals.shape}"
        )

    # Nothing to apply (or nothing to apply it to): pass the batch through.
    if aug_methods is None or len(aug_methods) == 0 or b == 0:
        return torch.tensor(signals, dtype=torch.float64)

    # Preallocate instead of growing the result with np.concatenate per sample.
    augmented = np.empty((b, d, 1, s), dtype=np.float64)
    for i in range(b):
        # One method per sample so all channels of a sample get the same transform type.
        method = random.choice(aug_methods)
        for c in range(d):
            # Index the channel directly; a plain reshape to (b, s, d) would
            # interleave channel and time values instead of transposing them.
            augmented[i, c, 0, :] = method(samples=signals[i, c, 0, :], sample_rate=sample_rate)
    return torch.tensor(augmented, dtype=torch.float64)
        

In [109]:
# Convert the current `data` tensor to a NumPy array for the augmentation call below.
X_ = data.numpy()

In [110]:
import random 
random.choice(augments)

<audiomentations.augmentations.gain_transition.GainTransition at 0x1642b684130>

In [111]:
# Augment the whole batch; each sample gets one transform picked at random from `augments`.
X_aug = augment_data(data=X_, aug_methods=augments)

In [112]:
X_aug

tensor([[[[0.2536, 0.5685, 0.2473,  ..., 0.2194, 0.7682, 0.8477]],

         [[0.0309, 0.1892, 0.0762,  ..., 0.2703, 0.8331, 0.7106]],

         [[0.1296, 0.2756, 0.3507,  ..., 0.0644, 0.1638, 0.5196]]],


        [[[0.0645, 0.9293, 0.2191,  ..., 0.5499, 1.1299, 0.7302]],

         [[0.9811, 1.1445, 0.2802,  ..., 0.2461, 0.3684, 1.0350]],

         [[1.0778, 0.3801, 0.0821,  ..., 0.1415, 1.1435, 0.7015]]],


        [[[0.4483, 0.9199, 1.0368,  ..., 0.3495, 1.1591, 0.6289]],

         [[0.7175, 0.3225, 1.1039,  ..., 0.6642, 1.1281, 0.7562]],

         [[0.0991, 0.7617, 1.0161,  ..., 0.5629, 1.1050, 0.0405]]],


        ...,


        [[[0.3789, 0.3624, 0.4582,  ..., 0.0122, 0.2401, 0.9509]],

         [[0.9689, 0.9098, 0.1220,  ..., 0.8096, 0.3053, 0.7822]],

         [[0.4266, 0.3893, 0.7411,  ..., 0.9236, 0.8019, 0.8451]]],


        [[[0.2012, 0.2984, 0.3266,  ..., 0.5943, 0.2121, 0.5822]],

         [[0.3819, 0.2409, 0.8805,  ..., 0.8280, 0.0135, 0.2906]],

         [[0.4073, 0.0603

In [None]:

# for augment_method in augments:
#     X_aug = augment_method(samples=X_vsl, sample_rate=8000)
#     Y_aug = augment_method(samples=Y_vsl, sample_rate=8000)
#     Z_aug = augment_method(samples=Z_vsl, sample_rate=8000)
#     aug_data = np.transpose(np.array([X_aug, Y_aug, Z_aug, label_pos]))
    

In [None]:
def augemt_signal(data, labels, augements, sample_rate=8000):
    """Apply every augmentation in ``augements`` to every sample of ``data``.

    The function and parameter names keep their original spelling so existing
    callers continue to work.

    Args:
        data: Array indexed as ``data[i, :, k]``. Columns 0, 1 and 2 are the
            three signals that get augmented; column 3 is copied through
            unchanged.
        labels: Sequence of labels aligned with the first axis of ``data``.
        augements: Iterable of callables invoked as
            ``method(samples=<1-D signal>, sample_rate=<int>)``.
        sample_rate: Sample rate forwarded to every augmentation call.

    Returns:
        Tuple ``(data_aug, labels_aug, errors_data)``. ``data_aug`` stacks one
        (length, 4) array per (sample, augmentation) pair, ordered by sample
        and then by augmentation; ``labels_aug`` repeats each label once per
        augmentation. Both are empty arrays when nothing was produced.
        ``errors_data`` holds one list per sample; nothing is recorded in it,
        so every entry is an empty list.
    """
    augmented_samples = []
    augmented_labels = []
    errors_data = []
    for i in tqdm(range(len(data)), total=len(data)):
        signals = (data[i, :, 0], data[i, :, 1], data[i, :, 2])
        label_pos = data[i, :, 3]
        # Use the argument, not a module-level list, so callers control the set of transforms.
        for augment_method in augements:
            channels = [
                augment_method(samples=signal, sample_rate=sample_rate)
                for signal in signals
            ]
            # Rows are [x, y, z, label_pos]; transpose to (length, 4).
            augmented_samples.append(np.transpose(np.array(channels + [label_pos])))
            augmented_labels.append(labels[i])
        errors_data.append([])

    if not augmented_samples:
        return np.array([]), np.array([]), errors_data

    # Stack once at the end instead of concatenating inside the loop.
    data_aug = np.stack(augmented_samples, axis=0)
    labels_aug = np.stack([np.asarray(label) for label in augmented_labels], axis=0)
    return data_aug, labels_aug, errors_data