# Correlation Alignment (CORAL)

CORAL minimizes domain shift by aligning the second-order statistics of source and target distributions, without requiring any target labels.

To minimize the distance between the second-order statistics (covariances) of the source and target features, CORAL applies a linear transformation A to the original source features and uses the Frobenius norm as the matrix distance metric.
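Written out, the objective and its closed-form solution look as follows. This follows the cited paper's formulation; here $C_S$ and $C_T$ denote the source and target feature covariances, each regularized by adding the identity matrix as `fit` does in the code.

```latex
\min_{A} \; \lVert A^{\top} C_S A - C_T \rVert_F^{2},
\qquad C_S = \operatorname{cov}(X_s) + I, \quad C_T = \operatorname{cov}(X_t) + I

A^{*} = C_S^{-1/2} \, C_T^{1/2}
\quad\Rightarrow\quad
(A^{*})^{\top} C_S A^{*} = C_T^{1/2} C_S^{-1/2} \, C_S \, C_S^{-1/2} C_T^{1/2} = C_T
```

The first factor whitens the source features and the second re-colors them with the target covariance; the adapted source features are $X_s A^{*}$, which is `np.dot(Xs, A_coral)` in `fit`.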

Reference

Sun B, Feng J, Saenko K. Return of frustratingly easy domain adaptation. In: AAAI Conference on Artificial Intelligence, 2016.

```python
import numpy as np
import scipy.io
import scipy.linalg
import sklearn.metrics
import sklearn.neighbors
```

```python
class CORAL:
    def __init__(self):
        super(CORAL, self).__init__()

    def fit(self, Xs, Xt):
        '''
        Perform CORAL on the source domain features
        :param Xs: ns * n_feature, source feature
        :param Xt: nt * n_feature, target feature
        :return: New source domain features
        '''
        # Adding the identity matrix regularizes both covariances so they are full rank
        cov_src = np.cov(Xs.T) + np.eye(Xs.shape[1])
        cov_tar = np.cov(Xt.T) + np.eye(Xt.shape[1])
        # Whitening with C_S^{-1/2}, then re-coloring with C_T^{1/2}
        A_coral = np.dot(scipy.linalg.fractional_matrix_power(cov_src, -0.5),
                         scipy.linalg.fractional_matrix_power(cov_tar, 0.5))
        # fractional_matrix_power may return a complex array with negligible
        # imaginary parts; keep the real part explicitly instead of calling
        # .astype(float), which emits a ComplexWarning
        Xs_new = np.real(np.dot(Xs, A_coral))
        return Xs_new

    def fit_predict(self, Xs, Ys, Xt, Yt):
        '''
        Perform CORAL, then predict using 1NN classifier
        :param Xs: ns * n_feature, source feature
        :param Ys: ns * 1, source label
        :param Xt: nt * n_feature, target feature
        :param Yt: nt * 1, target label
        :return: Accuracy and predicted labels of target domain
        '''
        # The adaptation uses only the unlabeled target features Xt
        Xs_new = self.fit(Xs, Xt)
        clf = sklearn.neighbors.KNeighborsClassifier(n_neighbors=1)
        clf.fit(Xs_new, Ys.ravel())
        y_pred = clf.predict(Xt)
        # Target labels Yt are used only to score the predictions
        acc = sklearn.metrics.accuracy_score(Yt.ravel(), y_pred)
        return acc, y_pred
```
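To check that the transformation in `fit` really aligns the covariances, here is a minimal sketch on random data. It swaps `scipy.linalg.fractional_matrix_power` for an eigendecomposition (`np.linalg.eigh`), which is valid because the regularized covariances are symmetric positive definite and always yields real matrices. The helpers `sym_matrix_power` and `coral_transform` and the `reg` parameter are hypothetical names introduced for illustration, not part of the notebook.

```python
import numpy as np


def sym_matrix_power(C, p):
    """Matrix power of a symmetric positive-definite matrix via eigendecomposition."""
    w, V = np.linalg.eigh(C)          # C = V diag(w) V^T
    return (V * w**p) @ V.T           # V diag(w^p) V^T


def coral_transform(Xs, Xt, reg=1.0):
    """Whiten Xs with the source covariance, then re-color it with the target covariance.

    reg is a hypothetical knob; reg=1.0 matches the identity added in CORAL.fit.
    """
    d = Xs.shape[1]
    cov_src = np.cov(Xs, rowvar=False) + reg * np.eye(d)
    cov_tar = np.cov(Xt, rowvar=False) + reg * np.eye(d)
    A = sym_matrix_power(cov_src, -0.5) @ sym_matrix_power(cov_tar, 0.5)
    return Xs @ A, A, cov_src, cov_tar


rng = np.random.default_rng(0)
Xs = rng.normal(size=(200, 5))
Xt = rng.normal(size=(150, 5)) @ rng.normal(size=(5, 5))   # correlated target features

Xs_new, A, cov_src, cov_tar = coral_transform(Xs, Xt)
# A closes the Frobenius gap between the regularized covariances: A^T C_S A = C_T
print(np.linalg.norm(A.T @ cov_src @ A - cov_tar, "fro"))  # close to 0 (rounding error only)
```

With `reg=0.0` (and a full-rank source covariance) the transformed source features have exactly the target's sample covariance. With the identity regularization used in `fit`, it is the regularized covariances that match, so the raw covariances are aligned only approximately.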

```python
import os

if __name__ == '__main__':
    domains = ['caltech_SURF_L10.mat', 'amazon_SURF_L10.mat', 'webcam_SURF_L10.mat', 'dslr_SURF_L10.mat']
    # Only the webcam -> dslr task (indices 2 -> 3) is run; widen these lists for other pairs
    for i in [2]:
        for j in [3]:
            if i != j:
                # os.path.join keeps the data paths portable across operating systems
                src, tar = os.path.join('data', domains[i]), os.path.join('data', domains[j])
                src_domain, tar_domain = scipy.io.loadmat(src), scipy.io.loadmat(tar)
                Xs, Ys, Xt, Yt = src_domain['fts'], src_domain['labels'], tar_domain['fts'], tar_domain['labels']
                coral = CORAL()
                acc, ypre = coral.fit_predict(Xs, Ys, Xt, Yt)
                print("Classification accuracy on", domains[i], "v.s.", domains[j], "=", acc)
```

Classification accuracy on webcam_SURF_L10.mat v.s. dslr_SURF_L10.mat = 0.7133757961783439
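A single accuracy is easier to interpret next to a no-adaptation baseline, i.e. 1NN trained on the raw source features. The sketch below shows that comparison on hypothetical synthetic data (the class means, sample sizes and the diagonal rescaling of the target are made up for illustration); `coral` and `one_nn_accuracy` are hypothetical helpers that mirror `CORAL.fit` and `CORAL.fit_predict`. No particular outcome is implied.

```python
import numpy as np
import scipy.linalg
from sklearn.metrics import accuracy_score
from sklearn.neighbors import KNeighborsClassifier


def coral(Xs, Xt):
    """Same transformation as CORAL.fit: regularized whitening + re-coloring."""
    cov_src = np.cov(Xs, rowvar=False) + np.eye(Xs.shape[1])
    cov_tar = np.cov(Xt, rowvar=False) + np.eye(Xt.shape[1])
    A = (scipy.linalg.fractional_matrix_power(cov_src, -0.5)
         @ scipy.linalg.fractional_matrix_power(cov_tar, 0.5))
    return np.real(Xs @ A)            # drop negligible imaginary parts


def one_nn_accuracy(Xs, Ys, Xt, Yt):
    """Train 1NN on (possibly transformed) source data, score on target labels."""
    clf = KNeighborsClassifier(n_neighbors=1).fit(Xs, Ys.ravel())
    return accuracy_score(Yt.ravel(), clf.predict(Xt))


# Hypothetical synthetic domains: shared class means, target features rescaled
rng = np.random.default_rng(0)
class_means = rng.normal(scale=3.0, size=(3, 4))
Ys = rng.integers(0, 3, size=300)
Yt = rng.integers(0, 3, size=200)
Xs = class_means[Ys] + rng.normal(size=(300, 4))
Xt = (class_means[Yt] + rng.normal(size=(200, 4))) @ np.diag([4.0, 0.5, 2.0, 1.0])

baseline = one_nn_accuracy(Xs, Ys, Xt, Yt)             # no adaptation
adapted = one_nn_accuracy(coral(Xs, Xt), Ys, Xt, Yt)   # CORAL-aligned source
print(f"1NN without adaptation: {baseline:.3f}, with CORAL: {adapted:.3f}")
```

Only the target features `Xt` enter `coral`; `Yt` is used solely for scoring. Whether CORAL helps depends on how much of the shift is second-order, since the transformation aligns covariances but does not match feature means.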

