In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [2]:
class NeuralNet(nn.Module):
    """One-hidden-layer classification network.

    The ReLU hidden layer produces the features used for joint distribution
    alignment; the softmax output layer gives one probability per class.
    """

    def __init__(self, input_size=1000, hidden_size=100, output_size=10):
        super().__init__()
        # fc1 is created before fc2 so a given torch seed yields the same initial weights
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, X):
        """Return ``(hidden_features, probabilities)`` for the rows of ``X``.

        ``probabilities`` is the softmax of the output layer over dim 1.
        """
        hidden = self.fc1(X).relu()
        logits = self.fc2(hidden)
        return hidden, logits.softmax(dim=1)

In [3]:
# compute the Triangular Discrimination (TD) distance between joint distributions
def TD(Xs, ys, Xt, yt, device, reg=0.01):
    """Estimate the TD distance between the joint distributions of (Xs, ys) and (Xt, yt).

    A class-conditional Gaussian kernel is used: pairs of samples with different
    labels get kernel value 0. The kernel width is the median of the nonzero
    pairwise squared distances and is held constant (no gradient flows through it).
    See the paper for the derivation of the closed-form estimate.

    Args:
        Xs, Xt: 2-D feature tensors, one row per sample.
        ys, yt: 1-D label tensors matching the rows of Xs / Xt.
        device: device on which the regularising identity matrix is created.
        reg: ridge term added to H before solving for theta (default 0.01).

    Returns:
        0-dim tensor holding the estimate. Returns zero when either domain has
        no samples, because the estimator is undefined there.
    """
    ms, mt = len(Xs), len(Xt)
    if ms == 0 or mt == 0:
        return Xs.new_zeros(())
    X, y = torch.cat((Xs, Xt), dim=0), torch.cat((ys, yt), dim=0)

    # compute the Gaussian kernel width (constant w.r.t. the network parameters)
    with torch.no_grad():
        pairwise_dist = torch.cdist(X, X, p=2) ** 2
        nonzero = pairwise_dist[pairwise_dist != 0]
        # if every sample coincides the median is undefined; fall back to unit width
        sigma = torch.median(nonzero) if nonzero.numel() > 0 else pairwise_dist.new_tensor(1.0)

    # see the paper for the detailed derivations of the following equations
    X_norm = torch.sum(X ** 2, dim=-1)
    # squared distances with gradient; clamp removes tiny negatives caused by rounding
    sq_dist = (X_norm[:, None] + X_norm[None, :] - 2 * torch.matmul(X, X.t())).clamp_min(0)
    same_class = (y[:, None] == y[None, :]).to(X.dtype)
    K = torch.exp(-sq_dist / sigma) * same_class  # kernel matrix
    Ks, Kt = K[:ms], K[ms:]
    H = torch.matmul(Ks.t(), Ks) / ms + torch.matmul(Kt.t(), Kt) / mt
    b = torch.mean(Ks, dim=0) - torch.mean(Kt, dim=0)
    # theta = (H + reg*I)^{-1} b, via a linear solve instead of forming the inverse
    eye = torch.eye(ms + mt, dtype=H.dtype, device=device)
    theta = torch.linalg.solve(H + reg * eye, b)
    TD_hat = 2 * torch.matmul(b, theta) - torch.matmul(theta, torch.matmul(H, theta))
    return TD_hat

In [4]:
class KMUR:
    """Open-set domain adaptation trainer built on ``NeuralNet`` and ``TD``.

    ``fit`` runs two stages:

    1. Pre-training with an open-set loss, in which the last output unit is the
       "unknown" class, to obtain pseudo labels for the target samples.
    2. Formal training with the same loss plus ``lamdaTD`` times the TD distance
       between the source hidden features and the hidden features of target
       samples pseudo-labelled as a known class. The target pseudo labels are
       re-predicted after every epoch.

    Every constructor argument is stored as an attribute of the same name.
    Each iteration draws one source batch and one target batch of up to
    ``batch_size`` samples each, i.e. up to ``batch_size * 2`` samples in total.
    """

    def __init__(self, input_size=1024, hidden_size=512, output_size=68, seed=1000, device=torch.device('cpu'), theta=0.5, lamda=1.0,
                 pretrain_epoch=100, epoch=200, lamdaTD=1.0, batch_size=200, lr=1e-3, log=False):
        # store every constructor argument (except self) as an attribute of the same name
        args_values = locals()
        args_values.pop("self")
        for arg, value in args_values.items():
            setattr(self, arg, value)

    def fit(self, Xs, ys, Xt, yt):
        """Train the network and store it in ``self.net``.

        Args:
            Xs, ys: source features and labels; labels must be exactly 0..K-1.
            Xt: target features.
            yt: target labels with every unknown sample labelled K. They are used
                only for the OS^start/UNK/HOS metrics printed when ``log`` is
                true, never in the training loss.

        Raises:
            ValueError: if the source labels are not 0..K-1, or if
                ``output_size`` is not K + 1 (K known classes plus the unknown).
        """
        Xs = torch.as_tensor(Xs, dtype=torch.float32, device=self.device)
        Xt = torch.as_tensor(Xt, dtype=torch.float32, device=self.device)
        ys = torch.as_tensor(ys, dtype=torch.int64, device=self.device)
        yt = torch.as_tensor(yt, dtype=torch.int64, device=self.device)

        class_labels = torch.unique(ys)
        c = len(class_labels) + 1  # number of all classes including the unknown
        # one_hot(ys, c - 1) in the loss and the "known" test (label < c - 1)
        # both require the known labels to be exactly 0..c-2
        if len(class_labels) == 0 or class_labels[0].item() != 0 or class_labels[-1].item() != c - 2:
            raise ValueError('source labels must be the consecutive integers 0..K-1')
        if self.output_size != c:
            raise ValueError('output_size must be {} (known classes + 1 unknown), got {}'.format(c, self.output_size))

        # define the neural network instance and the optimizer
        torch.manual_seed(self.seed)
        net = NeuralNet(input_size=self.input_size, hidden_size=self.hidden_size, output_size=self.output_size).to(self.device)
        optimizer = optim.Adam(params=net.parameters(), lr=self.lr, weight_decay=1e-4)
        # the label rides along as the last column so one loader shuffles features and labels together
        source_data = torch.cat((Xs, ys[:, None].to(Xs.dtype)), dim=1)

        #=============Pre-training the network==========================
        print('Pre-training for generating pseudo labels...')
        for epoch in range(self.pretrain_epoch):
            sc_loader, tg_loader = self._loader(source_data), self._loader(Xt)
            # zip stops at the shorter loader; the last batches may differ in size
            for sc_batch, tg_batch in zip(sc_loader, tg_loader):
                Xs_batch, ys_batch = self._split_batch(sc_batch)
                Xt_batch = tg_batch.to(self.device, torch.float32)
                probs, probt = net(Xs_batch)[1], net(Xt_batch)[1]
                loss = self._open_set_loss(probs, probt, ys_batch, c)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            if self.log:
                self._print_metrics(epoch, self._predict(net, Xt), yt, c)
        #========================================================

        # initial target pseudo labels (defined even when pretrain_epoch == 0)
        yt_hat = self._predict(net, Xt)

        #=============Formal training==========================
        print('Formal training...')
        for epoch in range(self.epoch):
            target_data = torch.cat((Xt, yt_hat[:, None].to(Xt.dtype)), dim=1)
            sc_loader, tg_loader = self._loader(source_data), self._loader(target_data)
            for sc_batch, tg_batch in zip(sc_loader, tg_loader):
                Xs_batch, ys_batch = self._split_batch(sc_batch)
                Xt_batch, yt_hat_batch = self._split_batch(tg_batch)
                # one forward pass per domain, shared by the open-set loss and TD
                FXs, probs = net(Xs_batch)
                FXt, probt = net(Xt_batch)
                loss = self._open_set_loss(probs, probt, ys_batch, c)

                #=====compute the TD distance===========
                # only target samples pseudo-labelled as a known class take part;
                # TD is undefined for an empty set, so skip the term in that case
                known = yt_hat_batch < c - 1
                if bool(known.any()):
                    TD_hat = TD(FXs, ys_batch, FXt[known], yt_hat_batch[known], self.device)
                    loss = loss + self.lamdaTD * TD_hat
                #================================
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # refresh the pseudo labels used by the next epoch
            yt_hat = self._predict(net, Xt)
            if self.log:
                self._print_metrics(epoch, yt_hat, yt, c)
        #========================================================
        self.net = net  # save the network

    def _loader(self, data):
        """Shuffled mini-batch loader over the rows of ``data`` (the last batch may be smaller)."""
        return torch.utils.data.DataLoader(dataset=data, batch_size=self.batch_size, shuffle=True, drop_last=False)

    def _split_batch(self, batch):
        """Split a batch into float32 features and the int64 label stored in its last column."""
        return batch[:, :-1].to(self.device, torch.float32), batch[:, -1].to(self.device, torch.int64)

    def _open_set_loss(self, probs, probt, ys_batch, c):
        """Open-set loss ``loss1 + lamda * |loss2|``.

        ``loss1`` is the theta-weighted cross-entropy of the source samples over
        the known classes. ``loss2`` is
        ``-mean(log p_t(unknown)) + theta * mean(log p_s(unknown))``.
        """
        log_probs, log_probt = self._safe_log(probs), self._safe_log(probt)
        loss1 = -self.theta * torch.mean(torch.sum(log_probs[:, :-1] * F.one_hot(ys_batch, c - 1), dim=1))
        loss2 = -torch.mean(log_probt[:, -1]) + self.theta * torch.mean(log_probs[:, -1])
        # the sign switch on loss2's value amounts to adding lamda * |loss2|
        if loss2.item() >= 0:
            return loss1 + self.lamda * loss2
        return loss1 - self.lamda * loss2

    @staticmethod
    def _safe_log(prob):
        """Compute ``log(prob)`` with ``prob`` clamped to the smallest positive normal float.

        A softmax output that underflowed to 0 then gives a large finite value
        instead of -inf. Otherwise ``-inf * 0`` in the one-hot product would
        turn the loss into NaN.
        """
        return torch.log(prob.clamp_min(torch.finfo(prob.dtype).tiny))

    @staticmethod
    def _predict(net, X):
        """Return the predicted class index (argmax of the output probabilities) for each row of X, without gradients."""
        with torch.no_grad():
            return torch.argmax(net(X)[1], dim=1)

    @staticmethod
    def _print_metrics(epoch, yt_hat, yt, c):
        """Print OS^start, UNK and HOS (in percent) for one epoch.

        - OS^start is the mean per-class accuracy over the known classes.
        - UNK is the accuracy on the unknown class (label c - 1).
        - HOS is the harmonic mean of OS^start and UNK.

        Classes absent from ``yt`` are left out of OS^start, and UNK is nan when
        there are no unknown samples, instead of raising ZeroDivisionError.
        """
        acc_class = torch.full((c,), float('nan'))
        for i in range(c):
            n_i = torch.sum(yt == i).item()
            if n_i > 0:
                acc_class[i] = torch.sum(yt_hat[yt == i] == i).item() / n_i
        known_acc = acc_class[:-1]
        known_acc = known_acc[~torch.isnan(known_acc)]
        OS_start, UNK = torch.mean(known_acc) * 100, acc_class[-1] * 100
        HOS = 2 * OS_start * UNK / (OS_start + UNK)
        print('epoch ', (epoch+1),
              ' OS^start ', "{:.5f}".format(OS_start),
              ' UNK ', "{:.5f}".format(UNK),
              ' HOS ', "{:.5f}".format(HOS))

In [5]:
import numpy as np
import scipy.io as sio
from sklearn.preprocessing import scale, LabelEncoder

In [6]:
def _load_domain(data_dir, name):
    """Read ``<data_dir>/<name>.mat``; return (normalised float64 feature tensor, raw label array)."""
    data = sio.loadmat(data_dir + '/' + name + '.mat')
    X = data['features'].astype(np.float64).T  # transposed so that rows are samples
    # divide each sample by its feature sum, then standardise every feature column
    X = torch.tensor(scale(X / X.sum(axis=1, keepdims=True)))
    return X, data['labels'].ravel()


def readData(sc, tg, data_dir='PIE_Multiview'):
    """Load the source domain ``sc`` and target domain ``tg`` from ``data_dir``.

    Each ``<data_dir>/<domain>.mat`` file is read for its ``features`` and
    ``labels`` entries. Features are preprocessed separately per domain.

    Labels of both domains are encoded by a single LabelEncoder fitted on the
    union of their raw labels. The same raw label therefore maps to the same
    integer in both domains, even if one domain lacks some class.

    Returns:
        Xs, ys, Xt, yt: float64 feature tensors (samples x features) and int64 label tensors.
    """
    # the PIE_Multiview dataset has 67 classes
    # the first 30 classes are defined as the known classes and the remaining 37 classes are combined as unknown
    Xs, ys_raw = _load_domain(data_dir, sc)  # source domain
    Xt, yt_raw = _load_domain(data_dir, tg)  # target domain
    encoder = LabelEncoder().fit(np.concatenate((ys_raw, yt_raw)))
    ys = torch.as_tensor(encoder.transform(ys_raw), dtype=torch.int64)
    yt = torch.as_tensor(encoder.transform(yt_raw), dtype=torch.int64)
    return Xs, ys, Xt, yt

domains = ['C27', 'C05', 'C37', 'C25', 'C02']

In [7]:
# Device passed to TD and KMUR below (all computation runs on the CPU).
DEVICE = torch.device("cpu")

In [8]:
# TD distance from the source domain C27 (domains[0]) to each remaining domain C05, C37, C25, C02;
# the printed values grow along this order, i.e. the domain (joint distribution) shift is increasing.
for tg in domains[1:]:
    Xs, ys, Xt, yt = readData(domains[0], tg)
    print(TD(Xs, ys, Xt, yt, DEVICE))

tensor(0.4771, dtype=torch.float64)
tensor(0.9514, dtype=torch.float64)
tensor(1.1012, dtype=torch.float64)
tensor(1.1921, dtype=torch.float64)


In [9]:
%%time
print('=============OSDA results under increasing domain shift=============')
for sc in ['C27']:
    for tg in ['C05', 'C37', 'C25', 'C02']:
        print('===============', sc, '-->', tg, '=================')
        Xs, ys, Xt, yt = readData(sc, tg)
        Xs, ys = Xs[ys < 30], ys[ys < 30]
        yt[yt >=30] = 30
        Xs, ys, Xt, yt = Xs.to(DEVICE), ys.to(DEVICE), Xt.to(DEVICE), yt.to(DEVICE)
        instance = KMUR(input_size=1024, hidden_size=512, output_size=31, device=DEVICE, 
                                    theta=30./67, lamda=1.0, lamdaTD=1.0, 
                                    pretrain_epoch=50, epoch=100, batch_size=400, lr=1e-3, log=True)
        instance.fit(Xs, ys, Xt, yt)

Pre-training for generating pseudo labels...
epoch  1  OS^start  1.74603  UNK  98.97040  HOS  3.43152
epoch  2  OS^start  1.58730  UNK  99.48520  HOS  3.12475
epoch  3  OS^start  8.09524  UNK  94.85200  HOS  14.91734
epoch  4  OS^start  34.28572  UNK  60.36036  HOS  43.73131
epoch  5  OS^start  18.88889  UNK  89.44659  HOS  31.19102
epoch  6  OS^start  35.71429  UNK  62.67696  HOS  45.50126
epoch  7  OS^start  36.66667  UNK  74.25997  HOS  49.09309
epoch  8  OS^start  23.01587  UNK  86.74389  HOS  36.37920
epoch  9  OS^start  24.60318  UNK  80.69498  HOS  37.70916
epoch  10  OS^start  31.74603  UNK  72.45818  HOS  44.14907
epoch  11  OS^start  47.30159  UNK  78.63577  HOS  59.07058
epoch  12  OS^start  71.11112  UNK  57.78636  HOS  63.76002
epoch  13  OS^start  76.34921  UNK  28.82883  HOS  41.85395
epoch  14  OS^start  76.34921  UNK  40.41184  HOS  52.85003
epoch  15  OS^start  63.01588  UNK  71.94337  HOS  67.18434
epoch  16  OS^start  52.69841  UNK  73.61648  HOS  61.42540
epoch  17

epoch  85  OS^start  85.87302  UNK  67.95367  HOS  75.86963
epoch  86  OS^start  87.14286  UNK  66.92406  HOS  75.70677
epoch  87  OS^start  89.84126  UNK  69.11197  HOS  78.12495
epoch  88  OS^start  91.74604  UNK  67.31017  HOS  77.65105
epoch  89  OS^start  93.96826  UNK  63.96396  HOS  76.11597
epoch  90  OS^start  92.38097  UNK  66.15186  HOS  77.09662
epoch  91  OS^start  87.14285  UNK  72.84428  HOS  79.35461
epoch  92  OS^start  86.19048  UNK  72.58687  HOS  78.80592
epoch  93  OS^start  87.46031  UNK  70.78507  HOS  78.24410
epoch  94  OS^start  90.31745  UNK  70.52767  HOS  79.20513
epoch  95  OS^start  91.58730  UNK  67.95367  HOS  78.02000
epoch  96  OS^start  92.53969  UNK  64.09267  HOS  75.73295
epoch  97  OS^start  92.69841  UNK  66.53796  HOS  77.46928
epoch  98  OS^start  88.73015  UNK  72.71558  HOS  79.92858
epoch  99  OS^start  85.87302  UNK  71.81467  HOS  78.21718
epoch  100  OS^start  86.98413  UNK  69.62677  HOS  77.34358
Pre-training for generating pseudo labe

epoch  71  OS^start  63.49207  UNK  56.58915  HOS  59.84220
epoch  72  OS^start  65.87302  UNK  55.42636  HOS  60.19984
epoch  73  OS^start  66.34921  UNK  52.84238  HOS  58.83049
epoch  74  OS^start  66.34921  UNK  51.67959  HOS  58.10276
epoch  75  OS^start  64.28572  UNK  55.03876  HOS  59.30395
epoch  76  OS^start  65.55556  UNK  54.65117  HOS  59.60877
epoch  77  OS^start  67.77779  UNK  52.84238  HOS  59.38542
epoch  78  OS^start  70.79365  UNK  47.41602  HOS  56.79320
epoch  79  OS^start  68.88889  UNK  50.25840  HOS  58.11706
epoch  80  OS^start  67.30158  UNK  52.84238  HOS  59.20190
epoch  81  OS^start  65.07937  UNK  55.29716  HOS  59.79080
epoch  82  OS^start  67.61904  UNK  50.64600  HOS  57.91456
epoch  83  OS^start  68.73016  UNK  49.22481  HOS  57.36475
epoch  84  OS^start  66.19048  UNK  52.58398  HOS  58.60786
epoch  85  OS^start  65.23810  UNK  54.13437  HOS  59.16982
epoch  86  OS^start  66.34921  UNK  54.65117  HOS  59.93472
epoch  87  OS^start  68.88889  UNK  52.7

epoch  58  OS^start  38.73016  UNK  50.83655  HOS  43.96516
epoch  59  OS^start  40.00000  UNK  50.19305  HOS  44.52055
epoch  60  OS^start  41.58730  UNK  47.74775  HOS  44.45511
epoch  61  OS^start  43.17460  UNK  45.55984  HOS  44.33516
epoch  62  OS^start  43.17460  UNK  46.71815  HOS  44.87653
epoch  63  OS^start  42.38095  UNK  49.42085  HOS  45.63097
epoch  64  OS^start  42.38095  UNK  49.03475  HOS  45.46570
epoch  65  OS^start  42.53968  UNK  49.54955  HOS  45.77782
epoch  66  OS^start  42.85714  UNK  48.77735  HOS  45.62600
epoch  67  OS^start  43.96825  UNK  46.58945  HOS  45.24092
epoch  68  OS^start  44.60318  UNK  45.68855  HOS  45.13934
epoch  69  OS^start  43.33333  UNK  46.58945  HOS  44.90244
epoch  70  OS^start  41.26984  UNK  49.54955  HOS  45.03228
epoch  71  OS^start  40.31746  UNK  51.48005  HOS  45.22007
epoch  72  OS^start  40.00000  UNK  51.09395  HOS  44.87143
epoch  73  OS^start  40.15873  UNK  49.67825  HOS  44.41413
epoch  74  OS^start  40.47619  UNK  48.9

epoch  45  OS^start  37.30159  UNK  48.90605  HOS  42.32278
epoch  46  OS^start  41.11112  UNK  42.47104  HOS  41.78002
epoch  47  OS^start  42.06349  UNK  42.08494  HOS  42.07421
epoch  48  OS^start  38.88889  UNK  46.97555  HOS  42.55142
epoch  49  OS^start  38.73016  UNK  47.87645  HOS  42.82035
epoch  50  OS^start  37.77777  UNK  48.39125  HOS  42.43088
epoch  51  OS^start  37.93650  UNK  46.97555  HOS  41.97492
epoch  52  OS^start  37.61905  UNK  46.97555  HOS  41.77987
epoch  53  OS^start  38.88889  UNK  44.91634  HOS  41.68587
epoch  54  OS^start  40.47619  UNK  42.98584  HOS  41.69328
epoch  55  OS^start  39.84127  UNK  44.65894  HOS  42.11277
epoch  56  OS^start  38.73016  UNK  47.74775  HOS  42.76880
epoch  57  OS^start  37.46032  UNK  48.77735  HOS  42.37626
epoch  58  OS^start  38.73016  UNK  45.94595  HOS  42.03060
epoch  59  OS^start  41.42857  UNK  41.31274  HOS  41.37058
epoch  60  OS^start  41.26984  UNK  43.62934  HOS  42.41681
epoch  61  OS^start  37.93651  UNK  49.5