#### A Quick Demo for the Joint Distribution Weighted Alignment (JDWA) Approach with a Simple Neural Network

In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchmin import minimize

In [2]:
# define the one-hidden-layer neural network model
class NeuralNet(nn.Module):
    """One-hidden-layer classifier that exposes its hidden features.

    ``forward`` returns a pair ``(FX, prob)``: the ReLU activations of the
    hidden layer and the softmax class probabilities (normalised over dim 1).
    """

    def __init__(self, input_size=1000, hidden_size=100, output_size=10):
        super().__init__()
        # fc1 is created before fc2 so that seeded initialisation is unchanged
        self.fc1 = nn.Linear(input_size, hidden_size, bias=True)
        self.fc2 = nn.Linear(hidden_size, output_size, bias=True)

    def forward(self, X):
        hidden = self.fc1(X)
        FX = F.relu(hidden)  # hidden layer activation features
        logits = self.fc2(FX)
        prob = F.softmax(logits, dim=1)  # probability output
        return FX, prob

In [3]:
# wrap the Joint Distribution Weighted Alignment (JDWA) approach as a class, following the sklearn style
class JDWA:
    """
    Joint Distribution Weighted Alignment (JDWA) for multi-source domain
    adaptation, wrapped as a class following the sklearn style.

    In the training procedure, the total batch size per iteration is: batch_size * num_class * num_domain
    For instance, if there are 5 domains with 68 classes in each domain,
    then batch_size=4 means drawing 4 samples from every class in every domain,
    resulting in 4*68*5 samples in the total batch size.

    Parameters
    ----------
    input_size, hidden_size, output_size : layer sizes forwarded to ``NeuralNet``.
    seed : seeds torch (network init, loader shuffling) and the draw of the
        initial random target labels.
    device : torch device on which the network and the batches live.
    epoch_pretrain : number of pretraining epochs (uniform source weights).
    epoch : number of JDWA training epochs.
    lamdaRE : weight of the estimated relative-entropy term in the loss.
    lamdaE : weight of the target conditional-entropy term in the loss.
    epsilon : coefficient of the squared-norm regulariser on ``theta``.
    batch_size : samples drawn per class per domain in every iteration.
    lr : learning rate of the SGD optimizer (momentum 0.9).
    log : if true, print the per-epoch losses and target accuracy.
    """

    def __init__(self, input_size=1024, hidden_size=512, output_size=68, seed=1000, device=torch.device('cpu'),
                 epoch_pretrain=200, epoch=200, lamdaRE=1, lamdaE=0.01, epsilon=1e-3, batch_size=4, lr=1e-3, log=False):
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.seed = seed
        self.device = device
        self.epoch_pretrain = epoch_pretrain
        self.epoch = epoch
        self.lamdaRE = lamdaRE
        self.lamdaE = lamdaE
        self.epsilon = epsilon
        self.batch_size = batch_size
        self.lr = lr
        self.log = log

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    def _build_loaders(self, X_list, y_list):
        """Build one shuffled DataLoader per (domain, class) pair.

        Every row is ``[features..., label, domain index]``; domains are
        numbered in list order, so the last list entry gets the largest index.
        """
        loaders = []
        for domain, (X, y) in enumerate(zip(X_list, y_list)):
            for label, count in zip(*np.unique(y, return_counts=True)):
                members = y == label
                rows = np.hstack((X[members], y[members][:, None], domain * np.ones((count, 1))))
                loaders.append(torch.utils.data.DataLoader(dataset=torch.tensor(rows),
                                                           batch_size=self.batch_size, shuffle=True, drop_last=False))
        return loaders

    def _unpack_batch(self, batches):
        """Concatenate per-loader batches and split them into features, labels and domain indices."""
        Xyl = torch.cat(batches, dim=0)
        X = Xyl[:, :-2].to(self.device, torch.float32)
        y = Xyl[:, -2].to(self.device, torch.int64)
        l = Xyl[:, -1].to(self.device, torch.int64)
        return X, y, l

    @staticmethod
    def _weighted_nll(prob, y, src_masks, weights):
        """Weighted sum over source domains of the mean negative log-likelihood.

        The true-class probability is gathered *before* taking the log. The
        value and gradient equal ``sum(log(prob) * one_hot)`` whenever that
        expression is finite, but a zero probability on a non-true class no
        longer produces ``-inf * 0 = NaN``.
        """
        total = 0.0
        for i, mask in enumerate(src_masks):
            true_class_prob = prob[mask].gather(1, y[mask][:, None])
            total += weights[i] * (-torch.mean(torch.log(true_class_prob)))
        return total

    @staticmethod
    def _entropy(prob, tgt_mask):
        """Mean conditional entropy of the predictions on the target rows."""
        p = prob[tgt_mask]
        return -torch.mean(torch.sum(p * torch.log(p + 1e-7), dim=1))

    @staticmethod
    def _relative_entropy(P, src_masks, tgt_mask, alpha, theta):
        """Estimated relative entropy between the alpha-weighted sources and the target."""
        re = -torch.mean(torch.exp(torch.matmul(P[tgt_mask], theta) - 1.0))
        for i, mask in enumerate(src_masks):
            re += alpha[i] * torch.mean(torch.matmul(P[mask], theta))
        return re

    def _theta_objective(self, P, src_masks, tgt_mask, alpha):
        """Return the function of ``theta`` minimised to estimate the relative entropy."""
        def objective(theta):
            re = self._relative_entropy(P, src_masks, tgt_mask, alpha, theta)
            reg = self.epsilon * torch.matmul(theta, theta)
            return -re + reg
        return objective

    @staticmethod
    def _evaluate(net, Xt, yt):
        """Predict the target labels; return (predictions, #correct, #samples)."""
        with torch.no_grad():
            yt_hat = torch.argmax(net(Xt)[1], dim=1)
            correct = torch.sum((yt_hat == yt)).item()
        return yt_hat, correct, len(yt)

    # ------------------------------------------------------------------
    # public API
    # ------------------------------------------------------------------
    def fit(self, X_list, y_list, Xt, yt):
        """Pretrain and then train the network on the source domains and the target domain.

        Parameters
        ----------
        X_list, y_list : lists with one feature array / label array per source
            domain. They are not modified.
        Xt : target-domain features.
        yt : target-domain labels; only used to report the target accuracy,
            never in a loss.

        Returns
        -------
        self, with the trained network stored in ``self.net``.
        """
        class_labels = np.unique(y_list[0])
        n, c = len(X_list), len(class_labels) # number of source domains, number of classes
        # generate random target labels from a generator seeded with `seed`,
        # so that the class-wise batching of the target data is reproducible
        rng = np.random.RandomState(self.seed % (2 ** 32))
        yt_rand = rng.choice(a=class_labels, size=len(Xt), replace=True, p=1.0 * np.ones(c) / c)
        # work on copies: appending the target to the caller's lists would turn
        # it into an extra "source" domain on a subsequent call
        X_all = list(X_list) + [Xt]
        y_all = list(y_list) + [yt_rand]

        # define the neural network instance and the optimizer
        torch.manual_seed(self.seed)
        net = NeuralNet(input_size=self.input_size, hidden_size=self.hidden_size, output_size=self.output_size).to(self.device)
        optimizer = optim.SGD(params=net.parameters(), lr=self.lr, momentum=0.9)

        # tensors used for the per-epoch target evaluation (converted once)
        Xt_eval = torch.as_tensor(Xt, dtype=torch.float32, device=self.device)
        yt_eval = torch.as_tensor(yt, dtype=torch.int64, device=self.device)

        #=====pretrain the network to estimate the pseudo target labels=======
        print('Pretraining...')
        uniform = [1.0 / n] * n # weights for the source domains are identical in the pretraining procedure
        yt_hat = None
        for epoch in range(self.epoch_pretrain):
            dataset_loaders = self._build_loaders(X_all, y_all)

            log_loss, m_log_loss, ent_loss, m_ent_loss = 0.0, 0.0, 0.0, 0.0
            # zip stops as soon as the smallest (domain, class) loader is exhausted
            for batches in zip(*dataset_loaders):
                X, y, l = self._unpack_batch(batches)
                src_masks = [l == i for i in range(n)] # source domain labels {0, ..., n-1}
                tgt_mask = l == n                      # target domain label n

                FX, prob = net(X)
                negative_log = self._weighted_nll(prob, y, src_masks, uniform)
                ent = self._entropy(prob, tgt_mask)   # conditional entropy loss
                loss = negative_log + self.lamdaE * ent
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                n_src, n_tgt = int((l != n).sum()), int(tgt_mask.sum())
                log_loss += negative_log.item() * n_src
                m_log_loss += n_src
                ent_loss += ent.item() * n_tgt
                m_ent_loss += n_tgt

            yt_hat, correct, m_test = self._evaluate(net, Xt_eval, yt_eval)

            if self.log:
                print('epoch ',epoch, ', log loss ',  "{:.5f}".format(log_loss / m_log_loss),
                      ', entropy loss ', "{:.5f}".format(ent_loss / m_ent_loss),
                      ', total loss ', "{:.5f}".format(log_loss / m_log_loss + self.lamdaE * ent_loss / m_ent_loss),
                      ', test acc. ', "{:.5f}".format((correct / m_test) * 100))
        #========================================================

        # without pretraining there are no pseudo labels yet: take them from
        # the freshly initialised network instead of failing on an undefined name
        if yt_hat is None and self.epoch > 0:
            yt_hat, _, _ = self._evaluate(net, Xt_eval, yt_eval)

        #=====train the MSDA network====================================
        print('Training...')
        for epoch in range(self.epoch):
            # update the pseudo target labels every epoch
            y_all[-1] = yt_hat.cpu().numpy()
            dataset_loaders = self._build_loaders(X_all, y_all)

            log_loss, m_log_loss, ent_loss, m_ent_loss = 0.0, 0.0, 0.0, 0.0
            # n + 1 batches of identical size are drawn from the n + 1 source and target datasets
            # each batch contains the same number of samples from each class
            for batches in zip(*dataset_loaders):
                X, y, l = self._unpack_batch(batches)
                src_masks = [l == i for i in range(n)]
                tgt_mask = l == n

                # compute the Gaussian kernel width
                pairwise_dist = torch.cdist(X, X, p=2)**2
                sigma = torch.median(pairwise_dist[pairwise_dist!=0])

                # compute the product kernel matrix
                FX, prob = net(X)
                FX_norm = torch.sum(FX ** 2, axis = -1)
                K = torch.exp(-(FX_norm[:,None] + FX_norm[None,:] - 2 * torch.matmul(FX, FX.t())) / sigma) # feature kernel matrix
                Deltay = torch.as_tensor(y[:,None]==y, dtype=torch.float64, device=FX.device) # label kernel matrix
                P = torch.as_tensor(K * Deltay, dtype=torch.double) # product kernel matrix

                # optimize the relevance weights
                def ObjAlpha(alpha):
                    alpha = torch.softmax(alpha, dim=0)
                    theta0 = torch.zeros(len(X), dtype=torch.double, device=self.device)
                    result = minimize(self._theta_objective(P, src_masks, tgt_mask, alpha), theta0,
                                      method='l-bfgs', max_iter=5)
                    theta = result.x
                    # the estimated relative entropy as a loss of the relevance weights
                    re = self._relative_entropy(P, src_masks, tgt_mask, alpha, theta)
                    negative_log = self._weighted_nll(prob, y, src_masks, alpha)
                    return negative_log + self.lamdaRE * re
                alpha0 = 1.0 * torch.ones(n, device=self.device) / n
                result = minimize(ObjAlpha, alpha0, method='l-bfgs', max_iter=5)
                alpha = torch.softmax(result.x, dim=0)

                negative_log = self._weighted_nll(prob, y, src_masks, alpha)
                theta0 = torch.zeros(len(X), dtype=torch.double, device=self.device)
                result = minimize(self._theta_objective(P, src_masks, tgt_mask, alpha), theta0, method='l-bfgs')
                # the estimated relative entropy as a loss of the feature extractor
                re = self._relative_entropy(P, src_masks, tgt_mask, alpha, result.x)
                ent = self._entropy(prob, tgt_mask)  # conditional entropy loss
                loss = negative_log + self.lamdaRE * re + self.lamdaE * ent
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                n_src, n_tgt = int((l != n).sum()), int(tgt_mask.sum())
                log_loss += negative_log.item() * n_src
                m_log_loss += n_src
                ent_loss += ent.item() * n_tgt
                m_ent_loss += n_tgt

            yt_hat, correct, m_test = self._evaluate(net, Xt_eval, yt_eval)

            if self.log:
                # note: the relative entropy reported here is that of the last batch of the epoch
                print('epoch ',epoch, ', log loss ',  "{:.5f}".format(log_loss / m_log_loss),
                      ', relative entropy loss ', "{:.5f}".format(re.item()),
                      ', entropy loss ', "{:.5f}".format(ent_loss / m_ent_loss),
                      ', total loss ', "{:.5f}".format(log_loss / m_log_loss + self.lamdaRE * re.item() + self.lamdaE * ent_loss / m_ent_loss),
                      ', test acc. ', "{:.5f}".format((correct / m_test) * 100))
            #========================================================
        self.net = net # save the network
        return self

    def score(self, Xt, yt):
        """Return the classification accuracy (in percent) of the fitted network on (Xt, yt)."""
        with torch.no_grad():
            Xt, yt = torch.as_tensor(Xt, dtype=torch.float32,device=self.device), torch.as_tensor(yt, dtype=torch.int64,device=self.device)
            pred = torch.argmax(self.net(Xt)[1],dim=1)
            correct = torch.sum((pred == yt)).item()
            m_test = len(yt)
        return (correct / m_test) * 100

In [4]:
import numpy as np
import pandas as pd
import scipy.io as sio
import numpy.linalg as la
from sklearn.preprocessing import scale,LabelEncoder

In [5]:
def _load_domain(name, data_dir='PIE/'):
    """Load one domain from ``data_dir + name + '.mat'`` and preprocess it.

    The features are read from the ``'fea'`` entry and the labels from the
    ``'gnd'`` entry of the .mat file. Labels are re-encoded with a
    ``LabelEncoder`` and returned as float64; every feature row is divided by
    its row sum and the result is standardised with sklearn's ``scale``.
    """
    data = sio.loadmat(data_dir + name + '.mat')
    X, y = data['fea'].astype(np.float64), data['gnd'].ravel()
    # NOTE(review): the encoder is fitted per domain, so the codes only agree
    # across domains if every domain contains the same label set - confirm.
    y = LabelEncoder().fit_transform(y).astype(np.float64)
    X = scale(X / X.sum(axis=1, keepdims=True))
    return X, y


def readData(tg, domains, data_dir='PIE/'):
    """Load the target domain and every other listed domain as a source.

    Parameters
    ----------
    tg : name of the target domain (file ``data_dir + tg + '.mat'``).
    domains : iterable of domain names; every entry different from ``tg`` is
        loaded as a source domain, in iteration order.
    data_dir : prefix prepended to the file names (defaults to ``'PIE/'``).

    Returns
    -------
    Xs_list, ys_list : lists of source features / labels, one entry per source.
    Xt, yt : target features / labels.
    """
    Xt, yt = _load_domain(tg, data_dir)

    Xs_list, ys_list = [], []
    for sc in domains:
        if sc == tg:
            continue  # the target domain is never used as a source
        Xs, ys = _load_domain(sc, data_dir)
        Xs_list.append(Xs)
        ys_list.append(ys)

    return Xs_list, ys_list, Xt, yt

# names of the PIE domains; each one corresponds to a file 'PIE/<name>.mat'
domains = 'C05 C07 C09 C27 C29'.split()

In [6]:
# adapt from the remaining domains to the target domain 'C05' and report the target accuracy
DEVICE = torch.device('cpu') # 'cuda:0'
target_domain = 'C05'
Xs_list, ys_list, Xt, yt = readData(target_domain, domains)
hyperparams = dict(
    input_size=1024, hidden_size=512, output_size=68, seed=0,
    epoch_pretrain=50, epoch=100, lamdaRE=10, lamdaE=1e-3,
    epsilon=1e-3, batch_size=1, lr=1e-2, log=True,
)
instance = JDWA(device=DEVICE, **hyperparams)
instance.fit(Xs_list, ys_list, Xt, yt)
instance.score(Xt, yt)

Pretraining...
epoch  0 , log loss  4.08850 , entropy loss  4.19485 , total loss  4.09270 , test acc.  16.29652
epoch  1 , log loss  3.53632 , entropy loss  4.11505 , total loss  3.54043 , test acc.  31.18247
epoch  2 , log loss  2.89301 , entropy loss  3.93695 , total loss  2.89695 , test acc.  47.50900
epoch  3 , log loss  2.24238 , entropy loss  3.67903 , total loss  2.24606 , test acc.  59.27371
epoch  4 , log loss  1.65814 , entropy loss  3.29787 , total loss  1.66144 , test acc.  65.21609
epoch  5 , log loss  1.21132 , entropy loss  2.90704 , total loss  1.21423 , test acc.  69.17767
epoch  6 , log loss  0.89307 , entropy loss  2.53652 , total loss  0.89560 , test acc.  70.13806
epoch  7 , log loss  0.69678 , entropy loss  2.27251 , total loss  0.69905 , test acc.  72.11885
epoch  8 , log loss  0.55812 , entropy loss  2.00926 , total loss  0.56013 , test acc.  73.73950
epoch  9 , log loss  0.45823 , entropy loss  1.88834 , total loss  0.46012 , test acc.  74.42977
epoch  10 , log

epoch  26 , log loss  0.06221 , relative entropy loss  0.00480 , entropy loss  0.71180 , total loss  0.11096 , test acc.  93.09724
epoch  27 , log loss  0.05998 , relative entropy loss  0.01640 , entropy loss  0.68386 , total loss  0.22470 , test acc.  93.36735
epoch  28 , log loss  0.05893 , relative entropy loss  0.00686 , entropy loss  0.65923 , total loss  0.12821 , test acc.  93.24730
epoch  29 , log loss  0.05796 , relative entropy loss  -0.00188 , entropy loss  0.64731 , total loss  0.03982 , test acc.  93.39736
epoch  30 , log loss  0.05614 , relative entropy loss  0.00214 , entropy loss  0.64028 , total loss  0.07815 , test acc.  93.48739
epoch  31 , log loss  0.05545 , relative entropy loss  0.00569 , entropy loss  0.63339 , total loss  0.11301 , test acc.  93.54742
epoch  32 , log loss  0.05403 , relative entropy loss  0.01427 , entropy loss  0.62097 , total loss  0.19735 , test acc.  93.81753
epoch  33 , log loss  0.05262 , relative entropy loss  0.00354 , entropy loss  0.6

epoch  89 , log loss  0.02542 , relative entropy loss  -0.01168 , entropy loss  0.24415 , total loss  -0.09113 , test acc.  94.65786
epoch  90 , log loss  0.02470 , relative entropy loss  0.00172 , entropy loss  0.23288 , total loss  0.04209 , test acc.  94.59784
epoch  91 , log loss  0.02430 , relative entropy loss  0.00039 , entropy loss  0.24212 , total loss  0.02839 , test acc.  94.62785
epoch  92 , log loss  0.02442 , relative entropy loss  -0.00093 , entropy loss  0.23770 , total loss  0.01540 , test acc.  94.59784
epoch  93 , log loss  0.02402 , relative entropy loss  -0.00751 , entropy loss  0.23811 , total loss  -0.05083 , test acc.  94.65786
epoch  94 , log loss  0.02391 , relative entropy loss  0.00050 , entropy loss  0.24041 , total loss  0.02915 , test acc.  94.59784
epoch  95 , log loss  0.02385 , relative entropy loss  -0.00542 , entropy loss  0.23334 , total loss  -0.03012 , test acc.  94.65786
epoch  96 , log loss  0.02384 , relative entropy loss  -0.01295 , entropy lo

94.68787515006002