#### A Quick Demo for the Direct Weight Estimation (DWE) Algorithm

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchmin import minimize

In [2]:
# define the one-hidden-layer neural network model
class NeuralNet(nn.Module):
    def __init__(self, input_size=1000, hidden_size=100, output_size=10):
        super(NeuralNet, self).__init__()
        self.fc1 = nn.Linear(in_features=input_size, out_features=hidden_size, bias=True)
        self.fc2 = nn.Linear(in_features=hidden_size, out_features=output_size, bias=True)
    def forward(self, X):
        FX = F.relu(self.fc1(X)) # hidden layer activation features
        prob = F.softmax(self.fc2(FX), dim=1) # probability output
        return FX, prob

In [3]:
# wrap the Direct Weight Estimation (DWE) algorithm as a class, following the sklearn style
class DWE:
    def __init__(self, input_size=1024, hidden_size=512, output_size=68, seed=1000, device=torch.device('cpu'),
                 epoch=200, start=50, lamdaE=0.1, div='L2', epsilon=0.01, batch_size=4, lr=1e-3, log=False):
        # in the training procedure, the total batch size per iteration is: batch_size * 2
        args_values = locals()
        args_values.pop("self")
        for arg,value in args_values.items():
            setattr(self, arg, value)
            
    def fit(self, Xs, ys, Xt, yt):
        class_labels = torch.unique(ys)
        c = len(class_labels) # number of source classes
        ws = torch.ones(len(ys)) # initialize all the source data weights to 1
        
        # define the neural network instance and the optimizer
        torch.manual_seed(self.seed)
        net = NeuralNet(input_size=self.input_size, hidden_size=self.hidden_size, output_size=self.output_size).to(self.device)
        optimizer = optim.SGD(params=net.parameters(), lr=self.lr, momentum=0.9)
        
        #=============train the PDA network==========================
        print('training the PDA network...')
        m_batch = self.batch_size
        for epoch in range(self.epoch):
            sc_loader = torch.utils.data.DataLoader(dataset=torch.cat((Xs, ys[:,None], ws[:,None]), dim=1),
                                                                       batch_size=self.batch_size, shuffle=True, drop_last=False)
            tg_loader = torch.utils.data.DataLoader(dataset=Xt,
                                                                        batch_size=self.batch_size, shuffle=True, drop_last=False) 
            
            log_loss, m_log_loss, sce_loss, m_sce_loss = 0.0, 0.0, 0.0, 0.0
            # 2 batches of identical size are drawn from the source and target datasets
            for sc_batch, tg_batch in zip(sc_loader, tg_loader):
                Xs_batch, ys_batch = sc_batch[:, :-2].to(self.device,torch.float32), sc_batch[:, -2].to(self.device,torch.int64)
                ws_batch = sc_batch[:, -1].to(self.device,torch.int64)
                Xt_batch = tg_batch.to(self.device, torch.float32)
                X_batch = torch.cat((Xs_batch, Xt_batch), dim=0)
                
                prob = net(X_batch)[1]
                negative_log = -torch.mean(ws_batch * torch.sum(torch.log(prob[:m_batch]) * F.one_hot(ys_batch, c), dim=1))      
                sce = -torch.mean(torch.sum((prob[m_batch:] - 1.0)**2, dim=1))   # squared-loss conditional entropy
                loss = negative_log + self.lamdaE * sce
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        
                log_loss += negative_log.item() * m_batch
                m_log_loss += m_batch
                sce_loss += sce.item() * m_batch
                m_sce_loss += m_batch

            with torch.no_grad():
                Xt, yt = torch.as_tensor(Xt, dtype=torch.float32, device=self.device), torch.as_tensor(yt, dtype=torch.int64, device=self.device)    
                yt_hat = torch.argmax(net(Xt)[1], dim=1)
                correct = torch.sum((yt_hat == yt)).item()
                m_test = len(yt)
            
            if ((epoch+1) >= self.start) and ((epoch+1) % 10==0):
                print('update the source data weights after', (epoch+1), 'epochs')
                ms, mt = len(Xs), len(Xt)
                Xs, ys = torch.as_tensor(Xs, dtype=torch.float32, device=self.device), torch.as_tensor(ys, dtype=torch.int64, device=self.device)  
                FXs, FXt = net(Xs)[0], net(Xt)[0]
                FXs, FXt = FXs / torch.norm(FXs, dim=1, keepdim=True), FXt / torch.norm(FXt, dim=1, keepdim=True)
                FX, y = torch.cat((FXs, FXt), dim=0), torch.cat((ys, yt_hat), dim=0)
                FX_norm = torch.sum(FX ** 2, axis = -1)
                FXs_norm = torch.sum(FXs ** 2, axis = -1)
                
                if self.div == 'L2':
                    # estimate the source data weights under the L2 distance
                    K = torch.exp(-(FX_norm[:,None] + FXs_norm[None,:] - 2 * torch.matmul(FX, FXs.t())) / (2 * 1.0 / torch.pi)) * torch.as_tensor(y[:,None]==ys, dtype=torch.float32, device=self.device) # kernel matrix  
                    Ks, Kt = K[:ms], K[ms:]
                    Kt_mean = torch.mean(Kt, dim=0)
                    H = torch.exp(-(FXs_norm[:,None] + FXs_norm[None,:] - 2 * torch.matmul(FXs, FXs.t())) / (4 * 1.0 / torch.pi)) * torch.as_tensor(ys[:,None]==ys, dtype=torch.float32, device=self.device)
                    invM = torch.inverse(H + self.epsilon * torch.eye(ms, device=self.device))

                    A = (1.0 / ms)**2 * torch.matmul(torch.matmul(Ks, invM), Ks.t()) - 0.5 * (1.0 / ms)**2 * torch.matmul(torch.matmul(torch.matmul(torch.matmul(Ks, invM), H), invM), Ks.t())
                    B1 = 2 * (1.0 / ms) * torch.matmul(torch.matmul(Kt_mean, invM), Ks.t()) 
                    B2 = (1.0 / ms) * torch.matmul(torch.matmul(torch.matmul(torch.matmul(Kt_mean, invM), H), invM), Ks.t())
                    B = B1 - B2
                    C = torch.matmul(torch.matmul(Kt_mean, invM), Kt_mean) - 0.5 * torch.matmul(torch.matmul(torch.matmul(torch.matmul(Kt_mean, invM), H), invM), Kt_mean)

                    def obj(w):
                        w =torch.softmax(w, dim=0) * ms
                        div = torch.matmul(torch.matmul(w, A), w)  - torch.matmul(B, w) + C
                        return div
                    w0 = torch.zeros(ms, dtype=torch.float32, device=self.device)
                    result = minimize(obj, w0, method='l-bfgs', max_iter=100) #
                    ws = torch.softmax(result.x, dim=0) * ms
                    
                else: 
                    # estimate the source data weights under the Chi2 divergence
                    pairwise_dist = torch.cdist(FX, FX, p=2)**2 
                    sigma = torch.median(pairwise_dist[pairwise_dist!=0]) # compute the Gaussian kernel width

                    K = torch.exp(-(FX_norm[:,None] + FXs_norm[None,:] - 2 * torch.matmul(FX, FXs.t())) / sigma) * torch.as_tensor(y[:,None]==ys, dtype=torch.float32, device=self.device) # kernel matrix  
                    Ks, Kt = K[:ms], K[ms:]
                    H = 1.0 / mt * torch.matmul(Kt.t(), Kt) 
                    invM = torch.inverse(H + self.epsilon * torch.eye(ms, device=self.device))
                    A = 2 * (1.0 / ms)**2 * torch.matmul(torch.matmul(Ks, invM), Ks.t())
                    B = (1.0 / ms)**2 * torch.matmul(torch.matmul(torch.matmul(torch.matmul(Ks, invM), H), invM), Ks.t())
                    C = A - B

                    def obj(w):
                        w = torch.softmax(w, dim=0) * ms
                        div = torch.matmul(torch.matmul(w, C), w)  - 1.0 
                        return div
                    w0 = torch.zeros(ms, dtype=torch.float32, device=self.device)
                    result = minimize(obj, w0, method='l-bfgs', max_iter=100) #
                    ws = torch.softmax(result.x, dim=0) * ms

            if True == self.log:
                print('epoch ', (epoch+1), ', log loss ',  "{:.5f}".format(log_loss / m_log_loss), 
                      ', sce loss ', "{:.5f}".format(sce_loss / m_sce_loss), 
                      ', total loss ', "{:.5f}".format(log_loss / m_log_loss + self.lamdaE * sce_loss / m_sce_loss),
                      ', test acc. ', "{:.5f}".format((correct / m_test) * 100))  
        #========================================================             
        self.net = net # save the network

    def score(self, Xt, yt):
        with torch.no_grad():
            Xt, yt = torch.as_tensor(Xt, dtype=torch.float32,device=self.device), torch.as_tensor(yt, dtype=torch.int64,device=self.device)    
            pred = torch.argmax(self.net(Xt)[1],dim=1)
            correct = torch.sum((pred == yt)).item()
            m_test = len(yt)
        return (correct / m_test) * 100

In [4]:
import numpy as np
import pandas as pd
import scipy.io as sio
import numpy.linalg as la
from sklearn.preprocessing import scale,LabelEncoder

In [5]:
DEVICE = torch.device('cuda:0') #'cpu'
data = np.loadtxt('OfficeHome/Resnet50_Art_Art.csv', delimiter=',') # source domain
Xs, ys = torch.tensor(data[:,:-1]), data[:,-1]
ys = torch.tensor(LabelEncoder().fit(ys).transform(ys).astype(np.float64),dtype=torch.int64)
    
data = np.loadtxt('OfficeHome/Resnet50_Art_Product.csv', delimiter=',') # target domain
Xt, yt = torch.tensor(data[:,:-1]), data[:,-1]
yt = torch.tensor(LabelEncoder().fit(yt).transform(yt).astype(np.float64),dtype=torch.int64)
Xt, yt = Xt[yt < 25], yt[yt < 25]

In [6]:
instance = DWE(input_size=2048, hidden_size=1024, output_size=65, seed=0, device=DEVICE,
                         epoch=105, start=50, lamdaE=0.1, div='L2', epsilon=1e-2, batch_size=200, lr=1e-2, log=True)
instance.fit(Xs, ys, Xt, yt)
instance.score(Xt, yt)

training the PDA network...
epoch  1 , log loss  4.10115 , sce loss  -63.01585 , total loss  -2.20044 , test acc.  11.09244
epoch  2 , log loss  3.69547 , sce loss  -63.01852 , total loss  -2.60638 , test acc.  18.82353
epoch  3 , log loss  3.17982 , sce loss  -63.02477 , total loss  -3.12266 , test acc.  34.34174
epoch  4 , log loss  2.54052 , sce loss  -63.04595 , total loss  -3.76408 , test acc.  37.70308
epoch  5 , log loss  1.92830 , sce loss  -63.09247 , total loss  -4.38095 , test acc.  52.49300
epoch  6 , log loss  1.47785 , sce loss  -63.15549 , total loss  -4.83770 , test acc.  59.60784
epoch  7 , log loss  1.15947 , sce loss  -63.21565 , total loss  -5.16209 , test acc.  59.83193
epoch  8 , log loss  0.93590 , sce loss  -63.28697 , total loss  -5.39280 , test acc.  64.03361
epoch  9 , log loss  0.77394 , sce loss  -63.34939 , total loss  -5.56100 , test acc.  64.76190
epoch  10 , log loss  0.69634 , sce loss  -63.40007 , total loss  -5.64367 , test acc.  65.21008
epoch  11 ,

epoch  85 , log loss  0.00179 , sce loss  -63.91315 , total loss  -6.38952 , test acc.  83.58543
epoch  86 , log loss  0.00218 , sce loss  -63.91411 , total loss  -6.38923 , test acc.  83.41737
epoch  87 , log loss  0.00184 , sce loss  -63.91522 , total loss  -6.38968 , test acc.  83.58543
epoch  88 , log loss  0.00201 , sce loss  -63.91584 , total loss  -6.38957 , test acc.  83.52941
epoch  89 , log loss  0.00188 , sce loss  -63.91677 , total loss  -6.38979 , test acc.  83.47339
update the source data weights after 90 epochs
epoch  90 , log loss  0.00166 , sce loss  -63.91777 , total loss  -6.39012 , test acc.  83.52941
epoch  91 , log loss  0.00162 , sce loss  -63.91890 , total loss  -6.39027 , test acc.  83.58543
epoch  92 , log loss  0.00191 , sce loss  -63.92008 , total loss  -6.39010 , test acc.  83.64146
epoch  93 , log loss  0.00174 , sce loss  -63.92096 , total loss  -6.39035 , test acc.  83.64146
epoch  94 , log loss  0.00188 , sce loss  -63.92202 , total loss  -6.39032 , tes

83.80952380952381

In [7]:
instance = DWE(input_size=2048, hidden_size=1024, output_size=65, seed=0, device=DEVICE,
                         epoch=105, start=50, lamdaE=0.1, div='Chi2', epsilon=1e-2, batch_size=200, lr=1e-2, log=True)
instance.fit(Xs, ys, Xt, yt)
instance.score(Xt, yt)

training the PDA network...
epoch  1 , log loss  4.10115 , sce loss  -63.01585 , total loss  -2.20044 , test acc.  11.09244
epoch  2 , log loss  3.69547 , sce loss  -63.01852 , total loss  -2.60638 , test acc.  18.82353
epoch  3 , log loss  3.17982 , sce loss  -63.02477 , total loss  -3.12266 , test acc.  34.34174
epoch  4 , log loss  2.54052 , sce loss  -63.04595 , total loss  -3.76408 , test acc.  37.70308
epoch  5 , log loss  1.92830 , sce loss  -63.09247 , total loss  -4.38095 , test acc.  52.49300
epoch  6 , log loss  1.47785 , sce loss  -63.15549 , total loss  -4.83770 , test acc.  59.60784
epoch  7 , log loss  1.15947 , sce loss  -63.21565 , total loss  -5.16209 , test acc.  59.83193
epoch  8 , log loss  0.93590 , sce loss  -63.28697 , total loss  -5.39280 , test acc.  64.03361
epoch  9 , log loss  0.77394 , sce loss  -63.34939 , total loss  -5.56100 , test acc.  64.76190
epoch  10 , log loss  0.69634 , sce loss  -63.40007 , total loss  -5.64367 , test acc.  65.21008
epoch  11 ,

epoch  84 , log loss  0.00561 , sce loss  -63.89235 , total loss  -6.38362 , test acc.  80.11204
epoch  85 , log loss  0.00510 , sce loss  -63.89330 , total loss  -6.38423 , test acc.  80.05602
epoch  86 , log loss  0.00544 , sce loss  -63.89414 , total loss  -6.38397 , test acc.  80.05602
epoch  87 , log loss  0.00525 , sce loss  -63.89596 , total loss  -6.38435 , test acc.  80.05602
epoch  88 , log loss  0.00485 , sce loss  -63.89680 , total loss  -6.38483 , test acc.  80.00000
epoch  89 , log loss  0.00563 , sce loss  -63.89799 , total loss  -6.38417 , test acc.  80.05602
update the source data weights after 90 epochs
epoch  90 , log loss  0.00484 , sce loss  -63.89923 , total loss  -6.38508 , test acc.  80.05602
epoch  91 , log loss  0.00451 , sce loss  -63.90117 , total loss  -6.38561 , test acc.  80.05602
epoch  92 , log loss  0.00469 , sce loss  -63.90280 , total loss  -6.38559 , test acc.  79.94398
epoch  93 , log loss  0.00406 , sce loss  -63.90346 , total loss  -6.38628 , tes

80.3921568627451