In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [2]:
class NeuralNet(nn.Module):
    """One-hidden-layer classifier.

    The ReLU hidden layer produces the feature representation that is later used
    for joint-distribution alignment; the output layer yields a softmax
    probability for each class.

    Args:
        input_size: number of input features.
        hidden_size: width of the hidden layer.
        output_size: number of output classes.
    """

    def __init__(self, input_size=1000, hidden_size=100, output_size=10):
        super().__init__()
        # fc1 is created before fc2 so seeded parameter initialization is unchanged
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, X):
        """Return ``(hidden_features, class_probabilities)``.

        Assumes ``X`` is (batch, input_size); the softmax is taken over dim 1.
        """
        hidden = self.fc1(X).relu()
        logits = self.fc2(hidden)
        return hidden, logits.softmax(dim=1)

In [3]:
# compute the Triangular Discrimination (TD) distance between joint distributions
def TD(Xs, ys, Xt, yt, device):
    """Estimate the Triangular Discrimination (TD) between two labeled samples.

    Features are compared with a Gaussian kernel masked to pairs that share the
    same label, so the estimate reflects the shift between the joint
    distributions of (features, label). See the paper for the derivation of the
    closed form used below.

    Args:
        Xs: source features, shape (ms, d).
        ys: source labels, shape (ms,).
        Xt: target features, shape (mt, d).
        yt: target labels, shape (mt,); cast to ``ys.dtype`` so integer- and
            float-coded labels can be mixed.
        device: kept for backward compatibility; auxiliary tensors are created
            on the device (and with the dtype) of the features.

    Returns:
        A 0-d tensor with the TD estimate, differentiable w.r.t. the features
        (the kernel width is treated as a constant). Returns zero when either
        sample is empty, where the estimate is undefined.
    """
    ms, mt = len(Xs), len(Xt)
    X = torch.cat((Xs, Xt), dim=0)
    if ms == 0 or mt == 0:
        # 1/ms, 1/mt and the per-domain means below are undefined for an empty domain
        return X.new_zeros(())
    y = torch.cat((ys, yt.to(ys.dtype)), dim=0)

    # Gaussian kernel width: median of the non-zero squared pairwise distances
    with torch.no_grad():
        sq_dist = torch.cdist(X, X, p=2) ** 2
        nonzero = sq_dist[sq_dist != 0]
        # if all points coincide every kernel entry is exp(0) = 1 for any width
        sigma = torch.median(nonzero) if nonzero.numel() > 0 else X.new_ones(())

    # see the paper for the detailed derivations of the following equations
    sq_norm = torch.sum(X ** 2, dim=-1)
    same_label = (y[:, None] == y[None, :]).to(X.dtype)
    K = torch.exp(-(sq_norm[:, None] + sq_norm[None, :] - 2 * X @ X.t()) / sigma) * same_label  # kernel matrix
    Ks, Kt = K[:ms], K[ms:]
    H = Ks.t() @ Ks / ms + Kt.t() @ Kt / mt
    b = Ks.mean(dim=0) - Kt.mean(dim=0)
    ridge = 0.01 * torch.eye(ms + mt, dtype=X.dtype, device=X.device)
    # solve (H + 0.01 I) theta = b directly instead of forming the inverse (more stable)
    theta = torch.linalg.solve(H + ridge, b)
    TD_hat = 2 * torch.dot(b, theta) - torch.dot(theta, H @ theta)
    return TD_hat

In [4]:
class KMUR:
    """Open-set domain adaptation trainer built on ``NeuralNet`` and ``TD``.

    Training has two stages:

    1. Pre-training: a classification loss on the labeled source data plus an
       unknown-class (open-set) term involving the unlabeled target data. The
       network's target predictions after this stage are the pseudo labels.
    2. Formal training: the same losses plus ``lamdaTD`` times the TD distance
       between the hidden features of the source batch and of the target
       samples currently pseudo-labeled as a known class. Pseudo labels are
       refreshed after every epoch.

    Labels are expected to be encoded so that the known (source) classes are
    ``0 .. c-2`` and the unknown class is ``c-1``, where ``c`` is the number of
    distinct source labels plus one; ``output_size`` should equal ``c``.

    Args:
        input_size, hidden_size, output_size: ``NeuralNet`` layer sizes.
        device: device on which the data and the network are placed.
        theta: scales the source cross-entropy and the source unknown-class term.
        lamda: weight of the (absolute) unknown-class loss term.
        pretrain_epoch: number of pre-training epochs (may be 0).
        epoch: number of formal-training epochs.
        lamdaTD: weight of the TD alignment term.
        batch_size: samples drawn per domain per iteration.
        lr: Adam learning rate.
        log: if true, print per-epoch target metrics.
    """

    def __init__(self, input_size=1024, hidden_size=512, output_size=68, device=torch.device('cpu'), theta=0.5, lamda=1.0,
                 pretrain_epoch=100, epoch=200, lamdaTD=1.0, batch_size=200, lr=1e-3, log=False):
        # in the training procedure, each iteration draws up to batch_size source
        # samples and up to batch_size target samples (2 * batch_size in total)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.device = device
        self.theta = theta
        self.lamda = lamda
        self.pretrain_epoch = pretrain_epoch
        self.epoch = epoch
        self.lamdaTD = lamdaTD
        self.batch_size = batch_size
        self.lr = lr
        self.log = log

    def fit(self, Xs, ys, Xt, yt):
        """Train the network on labeled source and unlabeled target data.

        Args:
            Xs: source features, (ns, input_size); array-like or tensor.
            ys: source labels in ``0 .. c-2``.
            Xt: target features, (nt, input_size).
            yt: target labels (known ``0 .. c-2``, unknown ``c-1``); used only to
                compute the logged evaluation metrics.

        The trained network is stored in ``self.net``.
        """
        Xs = torch.as_tensor(Xs, dtype=torch.float32, device=self.device)
        Xt = torch.as_tensor(Xt, dtype=torch.float32, device=self.device)
        ys = torch.as_tensor(ys, dtype=torch.int64, device=self.device)
        yt = torch.as_tensor(yt, dtype=torch.int64, device=self.device)

        class_labels = torch.unique(ys)  # sorted, so class_labels[-1] is the largest known label
        c = len(class_labels) + 1  # number of all classes including the unknown

        # define the neural network instance and the optimizer
        net = NeuralNet(input_size=self.input_size, hidden_size=self.hidden_size, output_size=self.output_size).to(self.device)
        optimizer = optim.Adam(params=net.parameters(), lr=self.lr, weight_decay=1e-4)

        #=============Pre-training the network==========================
        print('Pre-training for generating pseudo labels...')
        for epoch in range(self.pretrain_epoch):
            sc_loader = self._loader(Xs, ys)
            tg_loader = self._loader(Xt)
            # one source and one target batch per step; zip stops at the shorter loader
            for (Xs_batch, ys_batch), (Xt_batch,) in zip(sc_loader, tg_loader):
                probs, probt = net(Xs_batch)[1], net(Xt_batch)[1]
                loss = self._open_set_loss(probs, probt, ys_batch)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            if self.log:
                self._log_epoch(epoch, *self._evaluate(self._predict(net, Xt), yt, c))

        # pseudo labels for the first formal epoch (also defined when pretrain_epoch == 0)
        yt_hat = self._predict(net, Xt)

        #=============Formal training==========================
        print('Formal training...')
        for epoch in range(self.epoch):
            sc_loader = self._loader(Xs, ys)
            tg_loader = self._loader(Xt, yt_hat)
            for (Xs_batch, ys_batch), (Xt_batch, yt_hat_batch) in zip(sc_loader, tg_loader):
                # one forward pass per domain provides both hidden features and probabilities
                FXs, probs = net(Xs_batch)
                FXt, probt = net(Xt_batch)
                loss = self._open_set_loss(probs, probt, ys_batch)

                #=====compute the TD distance===========
                # align only target samples currently pseudo-labeled as a known class;
                # TD is undefined for an empty target sample, so skip it then
                known = yt_hat_batch <= class_labels[-1]
                if known.any():
                    TD_hat = TD(FXs, ys_batch, FXt[known], yt_hat_batch[known], self.device)
                    loss = loss + self.lamdaTD * TD_hat
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # refresh the pseudo labels with the updated network
            yt_hat = self._predict(net, Xt)
            if self.log:
                self._log_epoch(epoch, *self._evaluate(yt_hat, yt, c))
        #========================================================
        self.net = net # save the network

    def _loader(self, *tensors):
        """Freshly shuffled DataLoader over ``tensors`` with ``self.batch_size``."""
        return torch.utils.data.DataLoader(dataset=torch.utils.data.TensorDataset(*tensors),
                                           batch_size=self.batch_size, shuffle=True, drop_last=False)

    def _open_set_loss(self, probs, probt, ys_batch):
        """Source classification loss plus ``lamda`` times ``|loss2|``.

        ``loss1`` is the theta-weighted negative mean log-probability of the true
        class among the known-class outputs. ``loss2`` is the negative mean
        log-probability of the unknown class on target plus theta times the mean
        log-probability of the unknown class on source; it enters with its sign
        flipped when negative.
        """
        # gather the true-class probability instead of multiplying log-probs by a
        # one-hot mask: with the mask, a non-target probability that underflows to
        # 0 yields log(0) * 0 = nan and poisons the whole loss
        log_p_true = torch.log(probs[:, :-1].gather(1, ys_batch[:, None])).squeeze(1)
        loss1 = -self.theta * torch.mean(log_p_true)
        loss2 = -torch.mean(torch.log(probt[:, -1])) + self.theta * torch.mean(torch.log(probs[:, -1]))
        sign = 1.0 if loss2.item() >= 0 else -1.0
        return loss1 + self.lamda * sign * loss2

    @staticmethod
    def _predict(net, X):
        """Argmax class of the network's probability output for each row of ``X``."""
        with torch.no_grad():
            return torch.argmax(net(X)[1], dim=1)

    @staticmethod
    def _evaluate(yt_hat, yt, c):
        """Return ``(OS_star, UNK, HOS)`` in percent for predictions ``yt_hat``.

        OS_star is the mean per-class accuracy over the known classes
        ``0 .. c-2``, UNK the accuracy on class ``c-1`` and HOS their harmonic
        mean. Classes absent from ``yt`` have no defined accuracy: they are left
        out of OS_star, and UNK is NaN when no sample of class ``c-1`` exists.
        """
        acc_class = torch.full((c,), float('nan'))
        for i in range(c):
            is_i = yt == i
            n_i = int(is_i.sum())
            if n_i > 0:
                acc_class[i] = torch.sum(yt_hat[is_i] == i).item() / n_i
        known_acc = acc_class[:-1]
        known_acc = known_acc[~torch.isnan(known_acc)]
        OS_star, UNK = torch.mean(known_acc) * 100, acc_class[-1] * 100
        HOS = 2 * OS_star * UNK / (OS_star + UNK)
        return OS_star, UNK, HOS

    @staticmethod
    def _log_epoch(epoch, OS_star, UNK, HOS):
        """Print the per-epoch target metrics (epoch is 0-based, printed 1-based)."""
        print('epoch ', (epoch+1),
              ' OS^start ', "{:.5f}".format(OS_star),
              ' UNK ', "{:.5f}".format(UNK),
              ' HOS ', "{:.5f}".format(HOS))

In [5]:
import numpy as np
import scipy.io as sio
from sklearn.preprocessing import scale, LabelEncoder

In [6]:
def readData(sc, tg, data_dir='PIE_Multiview'):
    """Load, normalize and label-encode a source/target pair of PIE_Multiview domains.

    The PIE_Multiview dataset has 67 classes; the first 30 classes are defined as
    the known classes and the remaining 37 classes are combined as unknown.

    Each ``<data_dir>/<domain>.mat`` file provides a ``features`` matrix, which is
    transposed on load so that rows are samples, and a ``labels`` vector. Every
    sample is divided by its feature sum and the result is standardized with
    ``sklearn.preprocessing.scale``.

    Both domains' labels are encoded with one ``LabelEncoder`` fitted on their
    union, so the same raw label always maps to the same integer, even if one
    domain lacks some classes.

    Args:
        sc: source domain name (file stem), e.g. ``'C27'``.
        tg: target domain name (file stem).
        data_dir: directory holding the ``.mat`` files.

    Returns:
        ``(Xs, ys, Xt, yt)``: float64 feature tensors and int64 label tensors.
    """
    import os

    def load(domain):
        # read one domain and return (normalized features tensor, raw labels)
        data = sio.loadmat(os.path.join(data_dir, domain + '.mat'))
        X = data['features'].astype(np.float64).T
        X = scale(X / X.sum(axis=1, keepdims=True))
        return torch.tensor(X), data['labels'].ravel()

    Xs, ys_raw = load(sc)   # source domain
    Xt, yt_raw = load(tg)   # target domain
    encoder = LabelEncoder().fit(np.concatenate((ys_raw, yt_raw)))
    ys = torch.tensor(encoder.transform(ys_raw), dtype=torch.int64)
    yt = torch.tensor(encoder.transform(yt_raw), dtype=torch.int64)
    return Xs, ys, Xt, yt

domains = ['C27', 'C05', 'C37', 'C25', 'C02']

In [7]:
# Everything runs on the CPU. Seeding torch's global RNG (used for nn.Linear
# initialization and DataLoader shuffling) makes Restart & Run All reproducible.
DEVICE = torch.device('cpu')
SEED = 0
torch.manual_seed(SEED);

In [8]:
# TD distance from domain C27 to each of the other poses. The printed values
# grow from C05 to C02, i.e. the joint-distribution shift increases.
target_domains = ('C05', 'C37', 'C25', 'C02')
for tg in target_domains:
    Xs, ys, Xt, yt = readData('C27', tg)
    print(TD(Xs, ys, Xt, yt, device=DEVICE))

tensor(0.4771, dtype=torch.float64)
tensor(0.9514, dtype=torch.float64)
tensor(1.1012, dtype=torch.float64)
tensor(1.1921, dtype=torch.float64)


In [9]:
%%time
print('=============OSDA results under increasing domain shift=============')
# PIE_Multiview has N_CLASSES identities: the first N_KNOWN are the known classes,
# and all remaining identities are merged into one "unknown" class labeled N_KNOWN.
N_CLASSES, N_KNOWN = 67, 30
models = {}  # trained KMUR instances keyed by (source, target), kept for later inspection
for sc in ['C27']:
    for tg in ['C05', 'C37', 'C25', 'C02']:
        print('===============', sc, '-->', tg, '=================')
        Xs, ys, Xt, yt = readData(sc, tg)
        known_src = ys < N_KNOWN
        Xs, ys = Xs[known_src], ys[known_src]  # the source keeps only the known classes
        yt = yt.clamp(max=N_KNOWN)             # every unknown target identity -> N_KNOWN
        Xs, ys, Xt, yt = Xs.to(DEVICE), ys.to(DEVICE), Xt.to(DEVICE), yt.to(DEVICE)
        instance = KMUR(input_size=1024, hidden_size=512, output_size=N_KNOWN + 1, device=DEVICE,
                        theta=N_KNOWN / N_CLASSES, lamda=1.0, lamdaTD=1.0,
                        pretrain_epoch=50, epoch=100, batch_size=400, lr=1e-3, log=True)
        instance.fit(Xs, ys, Xt, yt)
        models[(sc, tg)] = instance

Pre-training for generating pseudo labels...
epoch  1  OS^start  2.53968  UNK  99.09910  HOS  4.95245
epoch  2  OS^start  4.12698  UNK  99.74260  HOS  7.92602
epoch  3  OS^start  8.88889  UNK  89.83269  HOS  16.17707
epoch  4  OS^start  28.41270  UNK  75.16087  HOS  41.23684
epoch  5  OS^start  40.15873  UNK  65.25096  HOS  49.71831
epoch  6  OS^start  42.38095  UNK  64.86486  HOS  51.26605
epoch  7  OS^start  52.38095  UNK  44.01544  HOS  47.83521
epoch  8  OS^start  54.60318  UNK  56.62806  HOS  55.59719
epoch  9  OS^start  56.34921  UNK  47.49035  HOS  51.54189
epoch  10  OS^start  54.60318  UNK  55.08366  HOS  54.84237
epoch  11  OS^start  54.12699  UNK  42.21364  HOS  47.43372
epoch  12  OS^start  56.19048  UNK  42.47104  HOS  48.37688
epoch  13  OS^start  58.57143  UNK  46.33205  HOS  51.73774
epoch  14  OS^start  63.65080  UNK  52.25225  HOS  57.39103
epoch  15  OS^start  70.00000  UNK  36.80824  HOS  48.24678
epoch  16  OS^start  73.17460  UNK  39.63964  HOS  51.42285
epoch  17

epoch  86  OS^start  90.00000  UNK  68.21107  HOS  77.60515
epoch  87  OS^start  87.14285  UNK  71.42857  HOS  78.50708
epoch  88  OS^start  82.06349  UNK  75.41827  HOS  78.60068
epoch  89  OS^start  81.58731  UNK  74.25997  HOS  77.75139
epoch  90  OS^start  83.17461  UNK  64.47877  HOS  72.64305
epoch  91  OS^start  86.66667  UNK  67.69627  HOS  76.01579
epoch  92  OS^start  88.09525  UNK  72.71558  HOS  79.66997
epoch  93  OS^start  82.85715  UNK  73.23038  HOS  77.74689
epoch  94  OS^start  80.00000  UNK  68.85457  HOS  74.01003
epoch  95  OS^start  85.23809  UNK  68.21107  HOS  75.77990
epoch  96  OS^start  92.06350  UNK  73.35907  HOS  81.65382
epoch  97  OS^start  90.63493  UNK  72.97297  HOS  80.85062
epoch  98  OS^start  87.93651  UNK  66.66667  HOS  75.83847
epoch  99  OS^start  87.77778  UNK  75.93307  HOS  81.42694
epoch  100  OS^start  86.34920  UNK  76.06178  HOS  80.87967
Pre-training for generating pseudo labels...
epoch  1  OS^start  0.63492  UNK  98.19122  HOS  1.261

epoch  73  OS^start  69.04762  UNK  54.13437  HOS  60.68824
epoch  74  OS^start  66.34921  UNK  57.36434  HOS  61.53050
epoch  75  OS^start  66.19048  UNK  56.84754  HOS  61.16428
epoch  76  OS^start  66.82539  UNK  56.33075  HOS  61.13092
epoch  77  OS^start  67.14285  UNK  56.71834  HOS  61.49192
epoch  78  OS^start  68.88889  UNK  55.55556  HOS  61.50794
epoch  79  OS^start  69.36508  UNK  56.58915  HOS  62.32917
epoch  80  OS^start  68.09524  UNK  57.75193  HOS  62.49853
epoch  81  OS^start  69.04762  UNK  57.36434  HOS  62.66608
epoch  82  OS^start  69.20634  UNK  56.97674  HOS  62.49890
epoch  83  OS^start  72.53968  UNK  52.97158  HOS  61.23022
epoch  84  OS^start  71.58730  UNK  55.03876  HOS  62.23168
epoch  85  OS^start  70.31747  UNK  56.20155  HOS  62.47205
epoch  86  OS^start  66.98414  UNK  59.04393  HOS  62.76390
epoch  87  OS^start  68.09524  UNK  57.62274  HOS  62.42280
epoch  88  OS^start  68.73016  UNK  56.45995  HOS  61.99374
epoch  89  OS^start  69.52381  UNK  55.5

epoch  59  OS^start  43.96826  UNK  52.89576  HOS  48.02060
epoch  60  OS^start  42.69841  UNK  53.66796  HOS  47.55885
epoch  61  OS^start  40.79365  UNK  55.08366  HOS  46.87373
epoch  62  OS^start  41.26984  UNK  54.44015  HOS  46.94884
epoch  63  OS^start  41.42857  UNK  52.89576  HOS  46.46513
epoch  64  OS^start  42.53968  UNK  51.60875  HOS  46.63742
epoch  65  OS^start  42.85714  UNK  54.44015  HOS  47.95918
epoch  66  OS^start  42.22223  UNK  55.85586  HOS  48.09145
epoch  67  OS^start  42.85714  UNK  55.21236  HOS  48.25647
epoch  68  OS^start  42.38095  UNK  54.95495  HOS  47.85579
epoch  69  OS^start  43.01587  UNK  52.89576  HOS  47.44695
epoch  70  OS^start  43.49207  UNK  51.73745  HOS  47.25779
epoch  71  OS^start  44.12698  UNK  50.32175  HOS  47.02122
epoch  72  OS^start  43.65080  UNK  51.22265  HOS  47.13457
epoch  73  OS^start  43.01587  UNK  51.60875  HOS  46.92215
epoch  74  OS^start  43.01587  UNK  52.89576  HOS  47.44695
epoch  75  OS^start  43.65080  UNK  52.6

epoch  45  OS^start  25.71429  UNK  43.37194  HOS  32.28657
epoch  46  OS^start  25.39683  UNK  45.17375  HOS  32.51412
epoch  47  OS^start  25.23810  UNK  44.40154  HOS  32.18312
epoch  48  OS^start  25.87302  UNK  44.14415  HOS  32.62463
epoch  49  OS^start  24.92064  UNK  45.43114  HOS  32.18605
epoch  50  OS^start  25.87301  UNK  44.01544  HOS  32.58942
epoch  51  OS^start  26.50794  UNK  43.50064  HOS  32.94203
epoch  52  OS^start  26.34921  UNK  44.14415  HOS  33.00065
epoch  53  OS^start  26.03175  UNK  44.40154  HOS  32.82112
epoch  54  OS^start  26.19048  UNK  44.53025  HOS  32.98236
epoch  55  OS^start  26.82540  UNK  43.37194  HOS  33.14854
epoch  56  OS^start  28.09524  UNK  40.41184  HOS  33.14637
epoch  57  OS^start  27.61905  UNK  42.72844  HOS  33.55114
epoch  58  OS^start  27.46032  UNK  43.50064  HOS  33.66757
epoch  59  OS^start  28.09524  UNK  42.47104  HOS  33.81882
epoch  60  OS^start  28.09524  UNK  41.69884  HOS  33.57130
epoch  61  OS^start  28.25397  UNK  42.7