In [1]:
from sklearn.decomposition import PCA
import torch
import scipy
from torch.utils.data import Dataset
from collections import Counter

The history saving thread hit an unexpected error (OperationalError('disk I/O error',)).History will not be written to the database.


In [2]:
def _load_tensors(domain):
    """Load the saved feature and label tensors of one domain.

    Reads ``OH_{domain}_features.pt`` and ``OH_{domain}_labels.pt`` from
    ``./data_handling/features/`` via ``torch.load``.

    Returns:
        A ``(features, labels)`` tuple with the two loaded objects.
    """
    features_path = f'./data_handling/features/OH_{domain}_features.pt'
    labels_path = f'./data_handling/features/OH_{domain}_labels.pt'
    # Features are read first, labels second.
    return torch.load(features_path), torch.load(labels_path)

class FeaturesDataset(Dataset):
    """Dataset that pairs pre-extracted features with their labels."""

    def __init__(self, features, labels):
        # Both containers are stored as given; nothing is copied or validated.
        self.labels = labels
        self.features = features

    def __len__(self):
        # The dataset size is defined by the labels container.
        return len(self.labels)

    def __getitem__(self, i):
        # Return the (feature, label) pair stored at position ``i``.
        feature, label = self.features[i], self.labels[i]
        return feature, label

In [53]:
def create_datasets(source, target, num_src_classes, num_total_classes):
    """Build the source and target ``FeaturesDataset`` objects.

    Source samples are kept only when their label is below
    ``num_src_classes``. Target samples are kept when their label is below
    ``num_total_classes``; every kept target label that is
    ``>= num_src_classes`` is then rewritten to ``num_src_classes``.

    Args:
        source: domain name handed to ``_load_tensors`` for the source data.
        target: domain name handed to ``_load_tensors`` for the target data.
        num_src_classes: exclusive upper bound on source labels, and the
            label assigned to all remaining target-only classes.
        num_total_classes: exclusive upper bound on target labels.

    Returns:
        ``(src_ds, tgt_ds)``.
    """
    feats, lbls = _load_tensors(source)
    keep = torch.where(lbls < num_src_classes)[0]
    src_ds = FeaturesDataset(feats[keep], lbls[keep])

    feats, lbls = _load_tensors(target)
    keep = torch.where(lbls < num_total_classes)[0]
    feats, lbls = feats[keep], lbls[keep]
    # Collapse every target-only class into the single label num_src_classes.
    lbls[lbls >= num_src_classes] = num_src_classes
    tgt_ds = FeaturesDataset(feats, lbls)

    return src_ds, tgt_ds

In [54]:
def get_PCA(features, dim):
    """Fit a ``dim``-component PCA on ``features`` and return the transformed features."""
    reducer = PCA(n_components=dim)
    return reducer.fit_transform(features)
def get_l2_norm(t):
    """Return the Euclidean norm of ``t`` computed along dimension 1."""
    squared_sum = torch.square(t).sum(dim=1)
    return torch.sqrt(squared_sum)
def l2_normalize(t):
    """Divide ``t`` by its norm along dimension 1 (as given by ``get_l2_norm``).

    No guard is applied for zero norms; the division is performed as-is.
    """
    norms = get_l2_norm(t)
    # Re-insert the reduced dimension so the division broadcasts over it.
    return t / norms[:, None]

In [55]:
def get_projection_matrix(features, labels, dim):
    """Compute a label-supervised projection matrix.

    Builds the same-label affinity matrix ``W`` (``W[i, j] = 1`` when samples
    ``i`` and ``j`` share a label), its diagonal degree matrix ``D`` and
    ``L = D - W``, then solves the generalized symmetric eigenproblem
    ``(X D X^T) v = w (X L X^T + I) v`` with ``X = features^T``.

    Args:
        features: tensor of shape ``N x d``. It is cast to float64 internally,
            so float32 features are accepted as well.
        labels: 1-D tensor holding the ``N`` labels.
        dim: number of eigenvectors to keep.

    Returns:
        float64 tensor of shape ``d x dim`` whose columns are the eigenvectors
        belonging to the ``dim`` largest eigenvalues.
    """
    # torch.matmul does not promote dtypes, so everything is kept in float64.
    X = features.double().transpose(0, 1)

    # Vectorised pairwise label comparison (replaces the per-row Python loop).
    W = (labels.unsqueeze(0) == labels.unsqueeze(1)).double()
    D = torch.diag(W.sum(dim=1))
    L = D - W

    A = X.matmul(D).matmul(X.transpose(0, 1))
    # The identity term keeps B positive definite, as eigh requires.
    B = X.matmul(L).matmul(X.transpose(0, 1)) + torch.eye(len(X), dtype=torch.float64)
    # Remove floating-point asymmetry before calling the symmetric solver.
    A = (A + A.transpose(0, 1)) * 0.5
    B = (B + B.transpose(0, 1)) * 0.5

    w, v = scipy.linalg.eigh(A.numpy(), B.numpy())
    # The slice below relies on eigenvalues being returned in ascending order.
    assert w[-1] == w.max(), 'something changed in scipy'
    return torch.tensor(v[:, -dim:])

def project(features, P):
    """Apply projection matrix ``P`` to row-wise ``features``.

    Computes ``(P^T · features^T)^T``, i.e. one projected row per input row.
    """
    projected_cols = torch.matmul(torch.transpose(P, 0, 1), torch.transpose(features, 0, 1))
    return torch.transpose(projected_cols, 0, 1)

In [56]:
def get_centroids(features, labels):
    """Return the mean feature vector of every distinct label.

    Rows of the result follow the order returned by ``labels.unique()``.
    """
    per_class_means = []
    for cls in labels.unique():
        members = features[labels == cls]
        per_class_means.append(members.mean(dim=0))
    return torch.stack(per_class_means, dim=0)
    
def get_closest_centroid(f, centroids):
    """Find the centroid nearest to ``f`` and a confidence for that choice.

    Returns:
        ``(index, prob)`` where ``index`` is the argmin of the Euclidean
        distances from ``f`` to each centroid and ``prob`` is
        ``exp(-d[index]) / sum(exp(-d))``.
    """
    offsets = f - centroids
    distances = torch.sqrt(torch.square(offsets).sum(dim=-1))
    nearest = torch.argmin(distances)
    normaliser = torch.exp(-distances).sum()
    confidence = torch.exp(-distances[nearest]) / normaliser
    return nearest, confidence

def predict_closed_set_pseudo_labels(src_zs, src_labels, tgt_zs):
    """Pseudo-label each target embedding with its nearest source centroid.

    Centroids come from ``get_centroids(src_zs, src_labels)``; each entry of
    ``tgt_zs`` is then scored with ``get_closest_centroid``.

    Returns:
        ``(pseudo_labels, pseudo_probs)``: stacked centroid indices and the
        matching confidences, one per entry of ``tgt_zs``.
    """
    centroids = get_centroids(src_zs, src_labels)
    labels_out, probs_out = [], []
    for z in tgt_zs:
        nearest, confidence = get_closest_centroid(z, centroids)
        labels_out.append(nearest)
        probs_out.append(confidence)
    return torch.stack(labels_out), torch.stack(probs_out)

def get_initial_private(pseudo_labels, pseudo_probs, n_r):
    """Flag the ``n_r`` samples with the lowest ``pseudo_probs`` as private.

    Returns:
        A tensor shaped and typed like ``pseudo_labels`` holding 1 for the
        flagged samples and 0 elsewhere.
    """
    _, ascending_order = torch.sort(pseudo_probs)
    flags = torch.zeros_like(pseudo_labels)
    flags[ascending_order[:n_r]] = 1
    return flags

In [57]:
def select_pseudo_labels(t, T, num_src_classes, tgt_zs, pseudo_labels, pseudo_probs, is_private):
    """Select confident pseudo-labels and propagate the private flag.

    For every pseudo-class, the ``t * n // T`` non-private samples with the
    highest ``pseudo_probs`` keep their pseudo-label (``n`` being the number
    of non-private samples of that class). Each remaining non-private sample
    is compared with the labelled set (selected samples plus already-private
    ones); if its nearest labelled neighbour in ``tgt_zs`` is private, the
    sample becomes private too.

    Args:
        t: current step; together with ``T`` it sets the selected fraction.
        T: total number of steps.
        num_src_classes: label written for every private sample.
        tgt_zs: target embeddings, one row per sample.
        pseudo_labels: integer pseudo-label per target sample.
        pseudo_probs: confidence per target sample.
        is_private: 0/1 tensor marking samples already considered private.
            It is not modified.

    Returns:
        ``(selected_pseudo_labels, new_is_private)``. Unselected samples carry
        ``-1`` and private samples carry ``num_src_classes``.
    """
    selected_pseudo_labels = torch.full_like(pseudo_labels, -1)
    not_private = is_private == 0
    for c in pseudo_labels.unique():
        idxs = torch.where((pseudo_labels == c) & not_private)[0]
        n_keep = t * len(idxs) // T
        # A quota of 0 must select nothing: slicing with [-0:] would instead
        # return every sample of the class.
        if n_keep > 0:
            ascending = pseudo_probs[idxs].sort()[1]
            selected_pseudo_labels[idxs[ascending[-n_keep:]]] = c
    idxs_lbld = torch.cat((torch.where(selected_pseudo_labels >= 0)[0], torch.where(is_private == 1)[0]), dim=0)
    idxs_unlbld = torch.where((selected_pseudo_labels < 0) & not_private)[0]
    new_is_private = is_private.clone()
    # Without any labelled sample there is no neighbour to propagate from
    # (argmin over an empty tensor would raise).
    if len(idxs_lbld) > 0:
        for i in idxs_unlbld:
            dists = (tgt_zs[i] - tgt_zs[idxs_lbld]).square().sum(dim=1)
            nn_idx = idxs_lbld[dists.argmin()]
            # Propagation uses the incoming flags, not the ones set in this call.
            if is_private[nn_idx]:
                new_is_private[i] = 1
    selected_pseudo_labels[new_is_private > 0] = num_src_classes
    return selected_pseudo_labels, new_is_private

In [58]:
def evaluate(cs_pseudo_labels, is_private, num_src_classes, tgt_labels):
    """Return the accuracy of the final predictions against ``tgt_labels``.

    Predictions are ``cs_pseudo_labels`` with every private sample
    (``is_private == 1``) overridden by ``num_src_classes``. The input
    tensors are left untouched.
    """
    predictions = cs_pseudo_labels.clone()
    private_mask = is_private == 1
    predictions[private_mask] = num_src_classes
    hits = predictions == tgt_labels
    return hits.float().mean()

In [59]:
# Ad-hoc attribute container holding the experiment configuration.
params = type('test', (), {})()
params.pca_dim = 512  # number of PCA components (passed to get_PCA)
params.proj_dim = 128  # number of projection directions (passed to get_projection_matrix)
params.T = 10  # number of pseudo-label selection steps run in the main loop
params.n_r  = 1200  # number of lowest-confidence target samples initially flagged private
params.l2_first = False  # True: l2-normalize before PCA; False: PCA first, then l2-normalize
params.dataset = 'OfficeHome'  # not read by any of the cells shown in this notebook
params.source = 'art'  # source domain name, used to build the tensor file names
params.target = 'clipart'  # target domain name, used to build the tensor file names
params.num_src_classes = 25  # source keeps labels < 25; also the label given to private target samples
params.num_total_classes = 65  # target keeps labels < 65

In [62]:
# Build the source/target feature datasets for the configured domain pair.
src_ds, tgt_ds = create_datasets(params.source, params.target, params.num_src_classes, params.num_total_classes)
# Sanity checks: source labels must be exactly 0..num_src_classes-1, and target
# labels 0..num_src_classes (the last value is the merged target-only label).
assert (src_ds.labels.unique() == torch.arange(0, params.num_src_classes)).all()
assert (tgt_ds.labels.unique() == torch.arange(0, params.num_src_classes+1)).all()

In [63]:
# Number of source and target samples left after the class filtering.
len(src_ds), len(tgt_ds)

(1089, 4365)

In [64]:
# PCA is fitted jointly on source and target features (source rows first).
pca_features = torch.cat((src_ds.features, tgt_ds.features), dim=0)
if params.l2_first:
    # Normalize rows first, then reduce dimensionality.
    pca_features = l2_normalize(pca_features)
    pca_features = torch.tensor(get_PCA(pca_features, params.pca_dim))
else:
    # Reduce dimensionality first, then normalize the reduced rows.
    pca_features = torch.tensor(get_PCA(pca_features, params.pca_dim))
    pca_features = l2_normalize(pca_features)
# Split back into source and target parts using the source length.
src_features, tgt_features = pca_features[:len(src_ds)], pca_features[len(src_ds):]

src_labels, tgt_labels = src_ds.labels, tgt_ds.labels

In [65]:
# Step 0: learn the projection from the labelled source data only, then
# pseudo-label the target set by nearest source centroid in the projected space.
P = get_projection_matrix(src_features, src_labels, params.proj_dim)
src_zs, tgt_zs = project(src_features, P), project(tgt_features, P)
cs_pseudo_labels, cs_pseudo_probs = predict_closed_set_pseudo_labels(src_zs, src_labels, tgt_zs)
# The n_r least confident target samples start out flagged as private.
is_private = get_initial_private(cs_pseudo_labels, cs_pseudo_probs, params.n_r)
print(0, is_private.sum())

0 tensor(1200)


In [66]:
# Iterative refinement: at each step select pseudo-labels, relearn the
# projection on source + selected target samples, and re-predict pseudo-labels.
for t in range(1, params.T+1):
    selected, is_private = select_pseudo_labels(t, params.T, params.num_src_classes, tgt_zs, cs_pseudo_labels, cs_pseudo_probs, is_private)
    # Running count of target samples currently flagged as private.
    print(t, is_private.sum())
    # selected >= 0 covers both confident samples and private ones
    # (private samples carry the label num_src_classes).
    _features = torch.cat((src_features, tgt_features[selected >= 0]), dim=0)
    _labels = torch.cat((src_labels, selected[selected >= 0]), dim=0)
    P = get_projection_matrix(_features, _labels, params.proj_dim)
    src_zs, tgt_zs = project(src_features, P), project(tgt_features, P)
    # Centroids are still computed from the source samples only.
    cs_pseudo_labels, cs_pseudo_probs = predict_closed_set_pseudo_labels(src_zs, src_labels, tgt_zs)

1 tensor(3081)
2 tensor(3753)
3 tensor(3999)
4 tensor(4098)
5 tensor(4153)
6 tensor(4181)
7 tensor(4190)
8 tensor(4193)
9 tensor(4193)
10 tensor(4193)


In [67]:
# Ground-truth labels that occur among the target samples flagged as private.
tgt_labels[is_private == 1].unique()

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25])

In [68]:
# Final predictions: closed-set pseudo-labels, with private samples overridden
# by the unknown-class label; the last expression is the overall accuracy.
# (Same computation as evaluate(); labels_predicted is reused by later cells.)
labels_predicted = cs_pseudo_labels.clone()
labels_predicted[is_private == 1] = params.num_src_classes
(labels_predicted == tgt_labels).float().mean()

tensor(0.6241)

In [72]:
# Accuracy restricted to target samples whose true label is the unknown class.
acc_unk = (labels_predicted[tgt_labels == params.num_src_classes] == params.num_src_classes).float().mean(); acc_unk

tensor(0.9792)

In [75]:
# Per-class accuracy for each of the num_src_classes known classes.
accs = torch.stack([(labels_predicted[tgt_labels == c] == c).float().mean() for c in range(params.num_src_classes)])

In [79]:
# Known-class accuracy: unweighted mean of the per-class accuracies.
acc_knw = accs.mean(); acc_knw

tensor(0.0602)

In [81]:
# Harmonic mean of the unknown-class and known-class accuracies.
h_score = 2 * acc_unk * acc_knw / (acc_unk + acc_knw); h_score

tensor(0.1134)