In [8]:
import numpy as np
from memory_profiler import memory_usage


def make_spiral_moons(n_samples=10000, k=4, noise=0.1, imbalance=True, padding=0.3, random_state=None):
    np.random.seed(random_state)
    
    x_stack = np.array([])
    y_stack = np.array([])
    labels = np.array([])    
    
    if imbalance:
        sizes = _random_split(n_samples, k, random_state=random_state)
    else:
        eps = np.append(np.ones(n_samples % k), np.zeros(k - n_samples % k))
        eps = np.random.permutation(eps).astype(np.int)
        sizes = [n_samples // k + eps[i] for i in range(k-1)]
        sizes.append(n_samples - np.sum(sizes))
    for i in range(k):
        size = sizes[i]
        x = np.random.normal(loc=np.pi/2+padding, scale=0.5, size=size)
        sin_gauss = np.sin(np.linspace(0, np.pi, size)) * (np.random.normal(loc=0, scale=noise, size=size))
        y = np.sin(x - padding) - .2 + sin_gauss
        theta = 2*np.pi * i / k 
        x_ = np.cos(theta)*x - np.sin(theta)*y
        y_ = np.sin(theta)*x + np.cos(theta)*y
        label = (np.ones(len(x_)) * i).astype(np.int)
        x_stack = np.append(x_stack, x_)
        y_stack = np.append(y_stack, y_)
        labels = np.append(labels, label)
    x_stack = np.ravel(x_stack)
    y_stack = np.ravel(y_stack)
    labels = np.ravel(labels)
    return x_stack, y_stack, labels
  
def _random_split(n_samples, k, random_state=None):
    np.random.seed(random_state)
    div = np.random.choice(list(range(1, n_samples-1)), k-1, replace=False).astype(np.int)
    div = np.sort(div)
    div = np.append(div, n_samples)
    ret = []
    x = 0
    for i in range(k):
        x_ = div[i] - x
        ret.append(x_)
        x = div[i]
    return ret

In [9]:
import matplotlib.pyplot as plt

x, y, labels = make_spiral_moons(k=8, noise=0.2, imbalance=False)
data = np.stack((x, y), -1).astype(np.float32)
print(data.shape)

(10000, 2)


In [14]:
import faiss
import time

dim=2
# make index using brute-force L2 distance searching algorithm
index = faiss.IndexFlatL2(dim)
index.add(data)    



def trial():
    k = 100    # we want to see 4 nearest neighbors
    n_epoch = 10
    total = 0
    for i in range(n_epoch):
        start = time.time()
        D, I = index.search(data, k) # D: distance between each data points and top k neibors, I: index of top k neibors including original data
        ellapse = time.time() - start
        total += ellapse
    print(total / n_epoch)

memory_out = memory_usage(trial)
print(np.mean(memory_out))





0.40184309482574465
159.65625


In [15]:


nlist = 10
k = 10
quantizer = faiss.IndexFlatL2(dim)  # 
index = faiss.IndexIVFFlat(quantizer, dim, nlist) # Voronoi
assert not index.is_trained
index.train(data)
assert index.is_trained
index.add(data)


def trial():
    n_epoch = 100
    total = 0
    for i in range(n_epoch):
        start = time.time()
        D, I = index.search(data, k) # D: distance between each data points and top k neibors, I: index of top k neibors including original data
        ellapse = time.time() - start
        total += ellapse
    print(total / n_epoch)

memory_out = memory_usage(trial)
print(np.mean(memory_out))

0.04367564201354981
146.9605129076087


In [17]:
m = 2
index = faiss.IndexIVFPQ(quantizer, dim, nlist, m, 8)
assert not index.is_trained
index.train(data)
assert index.is_trained
index.add(data)

def trial():
    n_epoch = 100
    total = 0
    for i in range(n_epoch):
        start = time.time()
        D, I = index.search(data, k) # D: distance between each data points and top k neibors, I: index of top k neibors including original data
        ellapse = time.time() - start
        total += ellapse
    print(total / n_epoch)
    
memory_out = memory_usage(trial)
print(memory_out)



0.03906921863555908
[138.81640625, 138.81640625, 134.97265625, 134.97265625, 134.97265625, 134.97265625, 134.97265625, 134.97265625, 134.97265625, 134.97265625, 134.97265625, 134.97265625, 134.97265625, 134.97265625, 134.97265625, 134.97265625, 134.97265625, 134.97265625, 134.97265625, 134.97265625, 134.97265625, 134.97265625, 134.97265625, 134.97265625, 134.97265625, 134.97265625, 134.97265625, 134.97265625, 134.97265625, 134.97265625, 134.97265625, 134.97265625, 134.97265625, 134.97265625, 134.97265625, 134.97265625, 134.97265625, 134.97265625, 134.97265625, 134.97265625, 134.97265625]


In [None]:


class FaissKNNClassifier:
    """ Scikit-learn wrapper interface for Faiss KNN.
    Parameters
    ----------
    n_neighbors : int (Default = 5)
                Number of neighbors used in the nearest neighbor search.
    n_jobs : int (Default = None)
             The number of jobs to run in parallel for both fit and predict.
              If -1, then the number of jobs is set to the number of cores.
    algorithm : {'brute', 'voronoi'} (Default = 'brute')
        Algorithm used to compute the nearest neighbors:
            - 'brute' will use the :class: `IndexFlatL2` class from faiss.
            - 'voronoi' will use :class:`IndexIVFFlat` class from faiss.
            - 'hierarchical' will use :class:`IndexHNSWFlat` class from faiss.
        Note that selecting 'voronoi' the system takes more time during
        training, however it can significantly improve the search time
        on inference. 'hierarchical' produce very fast and accurate indexes,
        however it has a higher memory requirement. It's recommended when
        you have a lots of RAM or the dataset is small.
        For more information see: https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index
    n_cells : int (Default = 100)
        Number of voronoi cells. Only used when algorithm=='voronoi'.
    n_probes : int (Default = 1)
        Number of cells that are visited to perform the search. Note that the
        search time roughly increases linearly with the number of probes.
        Only used when algorithm=='voronoi'.
    References
    ----------
    Johnson Jeff, Matthijs Douze, and Hervé Jégou. "Billion-scale similarity
    search with gpus." arXiv preprint arXiv:1702.08734 (2017).
    """

    def __init__(self,
                 n_neighbors=5,
                 n_jobs=None,
                 algorithm='brute',
                 n_cells=100,
                 n_probes=1):

        self.n_neighbors = n_neighbors
        self.n_jobs = n_jobs
        self.algorithm = algorithm
        self.n_cells = n_cells
        self.n_probes = n_probes

        import faiss
        self.faiss = faiss

    def predict(self, X):
        """Predict the class label for each sample in X.
        Parameters
        ----------
        X : array of shape (n_samples, n_features)
            The input data.
        Returns
        -------
        preds : array, shape (n_samples,)
                Class labels for samples in X.
        """
        idx = self.kneighbors(X, self.n_neighbors, return_distance=False)
        class_idx = self.y_[idx]
        counts = np.apply_along_axis(
            lambda x: np.bincount(x, minlength=self.n_classes_), axis=1,
            arr=class_idx.astype(np.int16))
        preds = np.argmax(counts, axis=1)
        return preds

    def kneighbors(self, X, n_components=None, return_distance=True):
        n_components = n_components or self.n_components

        elif n_components <= 0:
            raise ValueError("Expected n_components > 0. Got {}".format(n_components))
        else:
            if not np.issubdtype(type(n_components), np.integer):
                raise TypeError("n_components does not take {} value, enter integer value".format(type(n_components)))

        check_is_fitted(self, 'index_')

        X = np.atleast_2d(X).astype(np.float32)
        dist, idx = self.index_.search(X, n_neighbors)
        if return_distance:
            return dist, idx
        else:
            return idx

    def predict(self, X):
        idx = self.kneighbors(X, self.n_neighbors, return_distance=False)
        class_idx = self.y_[idx]
        counts = np.apply_along_axis(
            lambda x: np.bincount(x, minlength=self.n_classes_), axis=1,
            arr=class_idx.astype(np.int16))

        preds = counts / self.n_neighbors

        return preds_proba

    def fit(self, X, labels):
        X = np.atleast_2d(X).astype(np.float32)
        X = np.ascontiguousarray(X)
        self.index_ = self._get_index(X)
        self.index_.add(X)
        self.labels_ = labels
        self.n_classes_ = len(np.unique(labels))
        return self

    def _get_index(self, X):
        dim = X.shape[1]
        if self.algorithm == 'brute':
            index = self.faiss.IndexFlatL2(d)
        elif self.algorithm == 'voronoi':
            quantizer = self.faiss.IndexFlatL2(dim)
            index = self.faiss.IndexIVFFlat(quantizer, dim, self.ncells)
            index.train(X)
            index.nprobe = self.nprobe
        elif self.algorithm == 'hierarchical':
            index = self.faiss.IndexHNSWFlat(dim, 32)
            index.hnsw.efConstruction = 40
        else:
            raise ValueError("Invalid algorithm option. Expected ['brute', 'voronoi', 'hierarchical'], got {}".format(self.algorithm))
        return index
            
            

In [21]:
import numpy as np

n_components = np.array(2)
print(n_components)
print(type(n_components))
print(np.issubdtype(np.int, np.integer))
print(np.issubdtype(int, np.integer))
print(np.issubdtype(np.uint8, np.integer))
print(np.issubdtype(int, np.integer))


2
<class 'numpy.ndarray'>
True
True
True
True


In [17]:
import torch
import torch.nn.functional as F

index = torch.randint(low=0, high=5, size=(16,)).to(torch.long)
F.one_hot(index, num_classes=6).shape

torch.Size([16, 6])

In [31]:

F.one_hot(torch.tensor(1), num_classes=6).repeat(3, 1)

tensor([[0, 1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0]])

In [40]:
unlabeled_loss = 0
unlabeled_loss = unlabeled_loss + torch.Tensor([9.])
unlabeled_loss += torch.Tensor([9.])
unlabeled_loss += torch.Tensor([9.])
print(torch.mean(unlabeled_loss))

tensor(27.)


In [41]:
torch.randperm(100)

tensor([32, 12, 15, 97, 16, 57, 69, 98, 73, 33, 79, 58, 50, 89, 53, 95, 27, 43,
        28, 30, 76, 48, 66, 51, 26, 40, 45, 10, 86, 39, 83, 13, 67, 52, 72, 25,
        90, 64, 68, 55,  3, 38, 62, 96, 74, 24, 35,  2, 61, 88, 63, 37, 49, 84,
         5, 70,  6, 54,  0, 85, 20, 71, 44,  8,  9, 91, 41, 81, 99, 36, 60, 82,
        93, 56, 31,  7, 92, 75, 11, 29, 19, 59, 47, 21, 22, 77, 65,  1, 17, 87,
        94,  4, 23, 34, 46, 14, 78, 42, 80, 18])

In [50]:
data = torch.randn((100, 1, 486, 486))
data[torch.randperm(100), :]

tensor([[[[ 1.5367e-01, -9.8013e-01, -3.9470e-01,  ..., -4.8490e-01,
            1.1265e+00,  2.8259e-01],
          [ 1.1218e+00,  6.5208e-01, -5.2254e-01,  ..., -7.9252e-01,
            1.6929e+00,  8.5924e-01],
          [ 1.0971e+00, -9.5824e-01,  4.1752e-01,  ..., -1.0960e-01,
            1.0058e+00,  3.3421e-02],
          ...,
          [-6.4956e-01,  2.8781e+00, -6.5023e-01,  ..., -5.5994e-01,
           -1.1134e+00, -4.5962e-01],
          [-7.9998e-01, -6.8363e-02,  8.0596e-01,  ..., -1.5374e-01,
            7.8103e-01, -3.7725e-01],
          [ 3.1900e-01,  5.2063e-01, -1.3744e-01,  ...,  8.4876e-01,
           -4.5599e-01, -1.4758e+00]]],


        [[[-4.2002e-01,  6.3339e-01, -2.2178e+00,  ..., -1.0256e+00,
            7.0947e-01, -9.1578e-02],
          [ 4.7112e-01,  1.4571e-01,  3.2345e-01,  ..., -2.1362e-01,
            4.2283e-01,  9.6770e-02],
          [-6.7567e-01,  4.2544e-01, -6.5585e-01,  ..., -1.2827e+00,
            8.4240e-01,  5.5083e-01],
          ...,
   

In [57]:
targets = torch.randint(low=0, high=10, size=(100,))
torch.nonzero(targets==0)[:,0]



tensor([ 5, 16, 22, 31, 54, 62, 84, 92, 94])

In [62]:
torch.range(0, 100-1)
list(range(100))

  """Entry point for launching an IPython kernel.


[0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38,
 39,
 40,
 41,
 42,
 43,
 44,
 45,
 46,
 47,
 48,
 49,
 50,
 51,
 52,
 53,
 54,
 55,
 56,
 57,
 58,
 59,
 60,
 61,
 62,
 63,
 64,
 65,
 66,
 67,
 68,
 69,
 70,
 71,
 72,
 73,
 74,
 75,
 76,
 77,
 78,
 79,
 80,
 81,
 82,
 83,
 84,
 85,
 86,
 87,
 88,
 89,
 90,
 91,
 92,
 93,
 94,
 95,
 96,
 97,
 98,
 99]

In [73]:
L = 100
shuffle=False
idx = torch.randperm(L) if shuffle else torch.arange(L)
idx.shape

torch.Size([100])

In [99]:
import  numpy as np

    
L = 1000
idx = torch.arange(L)

data = torch.randn(L, 1, 213, 213)
targets = torch.randint(low=0, high=10, size=(L,))

targets = targets[idx]
num_per_class = 10
y_dim = 10
uni_idx = np.empty(0).astype(np.integer)
for i in range(y_dim):
    uni_idx = np.append(uni_idx, torch.nonzero(targets==i)[:,0].numpy()[:num_per_class])

rem_idx = np.array(list(set(idx.numpy()) - set(uni_idx)))

rem_idx

array([ 70,  71,  72,  77,  82,  84,  86,  87,  88,  89,  92,  93,  95,
        96, 101, 103, 104, 105, 106, 107, 110, 111, 112, 114, 115, 117,
       120, 121, 122, 124, 125, 126, 128, 130, 132, 133, 134, 135, 136,
       137, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 151,
       152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164,
       165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177,
       178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190,
       191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203,
       204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216,
       217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229,
       230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242,
       243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255,
       256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268,
       269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 28

In [100]:
for i in torch.Tensor([5, 10, 16, 19]).to(torch.long):
    print(targets==i)


tensor([False,  True, False, False,  True, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False,  True, False, False, False, False,
        False, False, False, False, False,  True, False, False, False,  True,
        False, False, False, False, False, False, False, False,  True, False,
        False, False, False, False, False, False, False,  True, False, False,
         True, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False,  True,  True,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False,  True, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False,  True, False, False, False, False,  True, False, False,
         True, False,  True, False, False, False, False, False, 

In [123]:
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader,TensorDataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

def get_dataloader(num_training=50000, num_labeled=3000, batch_size=200):

    train = MNIST(root='./mnist',download=True)
    train_data=(train.train_data.view(-1,784).float()/255.0)
    train_label=train.train_labels

    dataset={}
    dataset['data']=[]
    dataset['label']=[]

    dataset['test_data']=[]
    dataset['test_label']=[]

    num_per_class=num_training//10
    num_labeled_per_class=num_labeled//10
    for i in range(10):
        ind_i=torch.nonzero(train_label==i)[:,0].numpy()
        np.random.shuffle(ind_i)
        dataset['data'].append(train_data[ind_i[:num_per_class],:])
        dataset['label'].append(train_label[ind_i[:num_per_class]])

        dataset['test_data'].append(train_data[ind_i[num_per_class:],:])
        dataset['test_label'].append(train_label[ind_i[num_per_class:]])


    datas=torch.cat(dataset['data'],0)
    labels=torch.cat(dataset['label'],0)
    labels=torch.zeros(labels.size(0),10).scatter_(1, labels.view(-1,1), 1)

    # dataset={}

    dataset['labeled_data']=datas[torch.Tensor(np.concatenate([np.arange(i*num_per_class,i*num_per_class+num_labeled_per_class) for i in range(10)],0)).long(),:]
    dataset['labeled_label']=labels[torch.Tensor(np.concatenate([np.arange(i*num_per_class,i*num_per_class+num_labeled_per_class) for i in range(10)],0)).long(),:]

    dataset['unlabeled_data']=datas[torch.Tensor(np.concatenate([np.arange(i*num_per_class+num_labeled_per_class,(i+1)*num_per_class) for i in range(10)],0)).long(),:]
    dataset['unlabeled_label']=labels[torch.Tensor(np.concatenate([np.arange(i*num_per_class+num_labeled_per_class,(i+1)*num_per_class) for i in range(10)],0)).long(),:]

    dataset['test_data']=torch.cat(dataset['test_data'],0)
    dataset['test_label']=torch.cat(dataset['test_label'],0)

    dataloader={}
    dataloader['labeled'] = DataLoader(TensorDataset(dataset['labeled_data'], dataset['labeled_label']),
                                       batch_size=num_labeled // (num_training // batch_size), shuffle=True,
                                       num_workers=4)

    dataloader['unlabeled'] = DataLoader(TensorDataset(dataset['unlabeled_data'], dataset['unlabeled_label']),
                                       batch_size=batch_size-num_labeled // (num_training // batch_size), shuffle=True,
                                       num_workers=4)

    dataloader['test'] = DataLoader(TensorDataset(dataset['test_data'],dataset['test_label']),
                                    batch_size=500,shuffle=False,num_workers=4)

    return dataloader

def bce_loss(inputs, targets):
    loss = F.binary_cross_entropy(inputs, targets, reduction='none').view(inputs.shape[0], -1)
    return loss

def _log_norm(x, mean=None, var=None):
    if mean is None:
        mean = torch.zeros_like(x)
    if var is None:
        var = torch.ones_like(x)
    return -0.5 * (torch.log(2.0 * np.pi * var) + torch.pow(x - mean, 2) / var )

def log_norm_kl(x, mean, var, mean_=None, var_=None):
    log_p = _log_norm(x, mean, var)
    log_q = _log_norm(x, mean_, var_)
    loss = log_p - log_q
    return loss

def entropy(logits):
    p = logits.softmax(-1)
    log_p = logits.log_softmax(-1)
    entropy = -(p * log_p)
    return entropy

def softmax_cross_entropy(input, target):
    loss = F.cross_entropy(input, target, reduction='none')
    return loss

class Gaussian(nn.Module):
    def __init__(self, in_dim, out_dim, eps=1e-8):
        super().__init__()
        self.features = nn.Linear(in_dim, out_dim * 2)
        self.eps = eps

    def forward(self, x, reparameterize=True):
        x = self.features(x)
        mean, logit = torch.split(x, x.shape[1] // 2, -1)
        var = F.softplus(logit) + self.eps
        if reparameterize:
            x = self._reparameterize(mean, var)
        else:
            x = mean
        return x, mean, var
    
    def _reparameterize(self, mean, var):
        if torch.is_tensor(var):
            std = torch.pow(var, 0.5)
        else:
            std = np.sqrt(var)
        eps = torch.randn_like(mean)
        x = mean + eps * std
        return x
    
class VAE_M2(nn.Module):
    def __init__(self, hidden_dim=64, z_dim=32, y_dim=10, device='cpu'):
        super().__init__()

        self.device = device
        self.z_dim = z_dim
        self.y_dim = y_dim
        
        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.BatchNorm1d(256, momentum=0.01),  
            nn.ReLU(inplace=True),
            nn.Linear(256, hidden_dim),
            nn.BatchNorm1d(hidden_dim, momentum=0.01),  
            nn.ReLU(inplace=True)
        )
        
        self.y_inference = nn.Sequential(
            nn.Linear(hidden_dim, 256),
            nn.BatchNorm1d(256, momentum=0.01),
            nn.Linear(256, y_dim)
        )
        
        self.z_inference = nn.Sequential(
            nn.Linear(hidden_dim + y_dim, 256),
            nn.BatchNorm1d(256, momentum=0.01),
            Gaussian(256, z_dim)
        )
        
        self.decoder = nn.Sequential(
            nn.Linear(z_dim + y_dim, 256),
            nn.BatchNorm1d(256, momentum=0.01),  
            nn.ReLU(inplace=True),
            nn.Linear(256, 784),
            nn.BatchNorm1d(784, momentum=0.01),  
            nn.Sigmoid()
        )
        
        self = self.to(device)
    
    def forward(self, ux, lx, ly, return_loss=False, alpha=1.):
        ux = ux.to(self.device) # (batch_size, 1, x_size, x_size)
        lx = lx.to(self.device) # (batch_size, 1, x_size, x_size)
        ly = ly.to(self.device) # (batch_size, )
        
        labeled_loss = self.labeled_loss(lx, ly, alpha)
        unlabeled_loss = self.unlabeled_loss(ux)

        if return_loss:
            return labeled_loss + unlabeled_loss
        return x_reconst
    
    def predict(self, x):
        x = x.to(self.device) # (batch_size, 1, x_size, x_size)
        
        x_hidden = self.encoder(x) # (batch_size, 512 * block.expansion)
        _, _, p_y = self.y_inference(x_hidden)
        _, y_pred = torch.max(p_y, -1)
        
        return y_pred
    
    # labeled loss
    def labeled_loss(self, x, y, alpha=1.):

        x_hidden = self.encoder(x) # (batch_size, 512 * block.expansion

        z, z_mean, z_var = self.z_inference(torch.cat((x_hidden, y), -1)) # (batch_size, z_dim)
        x_reconst = self.decoder(torch.cat((z, y), -1)) # (batch_size, 1, ??, ??)

        log_p_x = bce_loss(x_reconst, x).sum(-1)
        log_p_y = -np.log(1 / self.y_dim)
        log_p_z = log_norm_kl(z, z_mean, z_var, torch.zeros_like(z), torch.ones_like(z)).sum(-1)

        y_logits = self.y_inference(x_hidden)
    
        sup_loss = alpha * softmax_cross_entropy(y_logits, torch.argmax(y,1)).sum(-1)

        loss = (log_p_x + log_p_y + log_p_z + sup_loss).mean() # batch mean

        return loss

    # unlabeled loss
    def unlabeled_loss(self, x):
        unlabeled_loss = 0
        x_hidden = self.encoder(x) # (batch_size, 512 * block.expansion)
        y_logits = self.y_inference(x_hidden)
        qy = F.softmax(y_logits, -1)

        for i in range(self.y_dim):
            qyi = qy[:, i]
            y = F.one_hot(torch.tensor(i), num_classes=self.y_dim).repeat(x.shape[0], 1).to(self.device, dtype=torch.float32)
            z, z_mean, z_var = self.z_inference(torch.cat((x_hidden, y), -1))
            x_reconst = self.decoder(torch.cat((z, y), -1))

            log_p_x = bce_loss(x_reconst, x).sum(-1)
            log_p_y = -np.log(1 / self.y_dim)
            log_p_z = log_norm_kl(z, z_mean, z_var, torch.zeros_like(z), torch.ones_like(z)).sum(-1)
            log_q_y = torch.log(qyi + 1e-10)

            unlabeled_loss += (log_p_x + log_p_y + log_p_z + log_q_y) * qyi

        loss = unlabeled_loss.mean() # batch mean
        return loss

In [126]:
model = VAE_M2()
dataloader = get_dataloader()
n_epoch = 100
for epoch in range(1, n_epoch+1):
    loss = 0
    for step, (labeled_batch, unlabeled_batch) in enumerate(zip(dataloader['labeled'], dataloader['unlabeled'])):
        lx, ly = labeled_batch
        ux, _ = unlabeled_batch
        loss = model(ux, lx, ly, return_loss=True, alpha=0.1 * 200)
        loss += loss.item()
    print(f'loss = {loss:.3f} at epoch {epoch}')

loss = 3631.426 at epoch 1
loss = 3754.611 at epoch 2
loss = 3742.406 at epoch 3
loss = 3748.849 at epoch 4
loss = 3637.217 at epoch 5
loss = 3634.219 at epoch 6
loss = 3662.515 at epoch 7
loss = 3700.556 at epoch 8
loss = 3652.645 at epoch 9
loss = 3853.604 at epoch 10
loss = 3703.984 at epoch 11
loss = 3725.702 at epoch 12
loss = 3781.822 at epoch 13
loss = 3600.037 at epoch 14
loss = 3739.469 at epoch 15
loss = 3763.411 at epoch 16
loss = 3697.403 at epoch 17
loss = 3659.354 at epoch 18
loss = 3695.303 at epoch 19
loss = 3883.104 at epoch 20
loss = 3667.581 at epoch 21
loss = 3709.542 at epoch 22
loss = 3590.102 at epoch 23
loss = 3695.595 at epoch 24
loss = 3760.744 at epoch 25
loss = 3824.709 at epoch 26
loss = 3753.752 at epoch 27
loss = 3705.874 at epoch 28
loss = 3762.392 at epoch 29
loss = 3846.200 at epoch 30
loss = 3730.970 at epoch 31
loss = 3699.435 at epoch 32
loss = 3704.175 at epoch 33
loss = 3689.582 at epoch 34
loss = 3704.817 at epoch 35
loss = 3663.893 at epoch 36
l