In [1]:
import torch
# Load the saved training activations and labels; the next cell reads the
# 'activations', 'classes' and 'backgrounds' entries of this dict.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
activations_path = "d3s_training_activations.pth"
data = torch.load(activations_path)

In [2]:
# Flatten to one 2048-dim activation row per sample and 1-D label vectors.
activations = data['activations'].reshape(-1, 2048)
classes = data['classes'].reshape(-1)
backgrounds = data['backgrounds'].reshape(-1)
# The split cell indexes all three with the same row indices, so they must align;
# fail here rather than silently pairing features with the wrong labels.
assert len(activations) == len(classes) == len(backgrounds), (
    f"misaligned data: {len(activations)} activations, "
    f"{len(classes)} classes, {len(backgrounds)} backgrounds"
)

In [3]:
# Split train and validation sets
import numpy as np

N_VAL = 10000  # size of the held-out validation split
# Derive the sample count from the data instead of assuming 100000 rows.
n_samples = len(activations)

# Legacy global seeding is kept on purpose: with 100000 samples this draws exactly
# the same validation indices as a hard-coded np.random.choice(100000, ...), so the
# projection heads saved at the end of the notebook remain reproducible.
np.random.seed(0)
val_indices = np.random.choice(n_samples, size=N_VAL, replace=False)
val_indices.sort()
# Every index not in the validation set; setdiff1d returns them sorted and unique.
train_indices = np.setdiff1d(np.arange(n_samples), val_indices)

train_activations = activations[train_indices]
train_classes = classes[train_indices]
train_backgrounds = backgrounds[train_indices]

val_activations = activations[val_indices]
val_classes = classes[val_indices]
val_backgrounds = backgrounds[val_indices]

In [5]:
# Enable the Intel Extension for Scikit-learn (confirmed by the output below).
# patch_sklearn() is called before importing LogisticRegression so that, per the
# sklearnex docs, the patched estimator is the one imported — keep this order.
from sklearnex import patch_sklearn
patch_sklearn()
from sklearn.linear_model import LogisticRegression

Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)


In [6]:
## Section adapted from https://github.com/shauli-ravfogel/nullspace_projection ###

import scipy
from typing import List

def get_rowspace_projection(W: np.ndarray) -> np.ndarray:
    """
    Orthogonal projection onto the rowspace of W.

    ``I - P_W`` is then the projection onto W's nullspace, i.e. it removes every
    direction a linear map with weights W can read.

    :param W: (k, d) matrix, e.g. a linear classifier's weight matrix
    :return: (d, d) projection matrix onto the span of W's rows; the zero matrix
        when W is (numerically) all zeros
    """
    # Import the submodule explicitly: on older SciPy releases a bare `import scipy`
    # does not make `scipy.linalg` available.
    import scipy.linalg

    if np.allclose(W, 0):
        # A zero matrix has an empty rowspace, so the projection is zero.
        w_basis = np.zeros_like(W.T)
    else:
        w_basis = scipy.linalg.orth(W.T)  # orthonormal columns spanning W's rowspace

    P_W = w_basis.dot(w_basis.T)  # orthogonal projection on W's rowspace

    return P_W

def get_projection_to_intersection_of_nullspaces(rowspace_projection_matrices: List[np.ndarray], input_dim: int):
    """
    Projection onto the intersection of the nullspaces of w_1, ..., w_n.

    Given rowspace projections P_R(w_1), ..., P_R(w_n), this uses the
    intersection-projection formula of Ben-Israel 2013
    (http://benisrael.net/BEN-ISRAEL-NOV-30-13.pdf):

        N(w1) ∩ N(w2) ∩ ... ∩ N(wn) = N(P_R(w1) + P_R(w2) + ... + P_R(wn))

    so the result is the identity minus the projection onto the rowspace of the sum.

    :param rowspace_projection_matrices: List[np.array], rowspace projections,
        expected to be (input_dim, input_dim) each
    :param input_dim: dimensionality of the space; sizes the identity matrix
    :return: (input_dim, input_dim) projection matrix onto the common nullspace
    """
    summed_projections = np.sum(rowspace_projection_matrices, axis=0)
    return np.eye(input_dim) - get_rowspace_projection(summed_projections)


In [7]:
# Faster implementation for orthogonal projection, using GPU
## Section adapted from https://github.com/shauli-ravfogel/nullspace_projection ###

import scipy
from typing import List

def orth_torch(A, device=None):
    """
    Orthonormal basis for the column space (range) of A, computed with torch SVD.

    Mirrors scipy.linalg.orth: singular values at or below
    max(s) * eps * max(M, N) are treated as zero.

    :param A: 2-D numpy array or tensor of shape (M, N)
    :param device: torch device used for the SVD; defaults to CUDA when it is
        available and CPU otherwise, so the function also runs without a GPU
    :return: numpy array of shape (M, K), K = numerical rank of A, with
        orthonormal columns
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # as_tensor avoids the copy/warning torch.tensor() incurs on tensor inputs.
    A_t = torch.as_tensor(A).to(device)
    if A_t.numel() == 0:
        # Empty input has an empty range; torch.amax would fail on zero singular values.
        return A_t.new_zeros((A_t.shape[0], 0)).cpu().numpy()

    u, s, vh = torch.linalg.svd(A_t, full_matrices=False)
    M, N = u.shape[0], vh.shape[1]
    rcond = torch.finfo(s.dtype).eps * max(M, N)
    tol = torch.amax(s) * rcond
    num = int((s > tol).sum())  # numerical rank, as a Python int for slicing
    return u[:, :num].cpu().numpy()



def get_rowspace_projection_torch(W: np.ndarray) -> np.ndarray:
    """
    Orthogonal projection onto the rowspace of W, using orth_torch (torch SVD)
    to build the basis instead of scipy.linalg.orth.

    :param W: (k, d) matrix whose rowspace is projected onto
    :return: (d, d) projection matrix; the zero matrix when W is numerically zero
    """
    if not np.allclose(W, 0):
        basis = orth_torch(W.T)  # orthonormal columns spanning W's rowspace
    else:
        basis = np.zeros_like(W.T)  # empty rowspace -> zero projection
    return basis.dot(basis.T)

def get_projection_to_intersection_of_nullspaces_torch(rowspace_projection_matrices: List[np.ndarray], input_dim: int):
    """
    Projection onto the intersection of the nullspaces of w_1, ..., w_n, with the
    rowspace projection of the summed matrix computed via get_rowspace_projection_torch.

    Uses the intersection-projection formula of Ben-Israel 2013
    (http://benisrael.net/BEN-ISRAEL-NOV-30-13.pdf):

        N(w1) ∩ N(w2) ∩ ... ∩ N(wn) = N(P_R(w1) + P_R(w2) + ... + P_R(wn))

    :param rowspace_projection_matrices: List[np.array], rowspace projections
        P_R(w_1), ..., P_R(w_n), expected to be (input_dim, input_dim) each
    :param input_dim: dimensionality of the space; sizes the identity matrix
    :return: (input_dim, input_dim) projection matrix onto the common nullspace
    """
    summed_projections = np.sum(rowspace_projection_matrices, axis=0)
    identity = np.eye(input_dim)
    return identity - get_rowspace_projection_torch(summed_projections)



In [19]:
# Baseline: background-label probe trained on all 2048 features.
# NOTE(review): `multi_class` is deprecated in recent scikit-learn releases — confirm
# the installed version still accepts it before upgrading.
clf = LogisticRegression(
    random_state=0,
    multi_class='multinomial',
    max_iter=400,
    C=0.1,
)
clf.fit(train_activations, train_backgrounds)
print(clf.score(train_activations, train_backgrounds))  # train accuracy
print(clf.score(val_activations, val_backgrounds))  # validation accuracy

0.8603888888888889
0.8057


In [35]:
# Baseline: foreground-class probe trained on all 2048 features.
# NOTE(review): `multi_class` is deprecated in recent scikit-learn releases — confirm
# the installed version still accepts it before upgrading.
clf = LogisticRegression(
    random_state=0,
    multi_class='multinomial',
    max_iter=400,
    C=0.1,
    verbose=1,
)
clf.fit(train_activations, train_classes)
print(clf.score(train_activations, train_classes))  # train accuracy
print(clf.score(val_activations, val_classes))  # validation accuracy

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:  7.6min finished


0.9842333333333333
0.8144


In [9]:
# Iteratively find subspaces with background info: fit a linear background probe,
# project its rowspace out of the activations, and repeat until the probe's
# validation accuracy drops below TARGET_VAL_ACC.
TARGET_VAL_ACC = .2  # Let's go for < 20% val accuracy
FEATURE_DIM = train_activations.shape[1]
# Each productive iteration removes at least one direction, so running more
# iterations than there are feature dimensions cannot make progress.
MAX_ITERS = FEATURE_DIM

train_activations_proj = train_activations
val_activations_proj = val_activations
rowspace_projections = []

for iteration in range(MAX_ITERS):
    clf = LogisticRegression(random_state=0, multi_class='multinomial', max_iter=400, C = 0.1).fit(train_activations_proj, train_backgrounds)
    print(clf.score(train_activations_proj, train_backgrounds))
    val_acc = clf.score(val_activations_proj, val_backgrounds)
    print(val_acc)
    if val_acc < TARGET_VAL_ACC:
        break
    W = clf.coef_
    P_rowspace_wi = get_rowspace_projection_torch(W)
    if np.allclose(P_rowspace_wi, 0):
        # A zero projection would leave P unchanged and repeat this iteration forever.
        print("Probe weights vanished; no further directions can be removed.")
        break
    rowspace_projections.append(P_rowspace_wi)
    # Recompute the projection onto the common nullspace of all probes found so far
    # and apply it to the ORIGINAL activations (not the previous projection).
    P = get_projection_to_intersection_of_nullspaces_torch(rowspace_projections, FEATURE_DIM)
    train_activations_proj = (P.dot(train_activations.T)).T
    val_activations_proj = (P.dot(val_activations.T)).T
else:
    print(f"Warning: stopped after {MAX_ITERS} iterations without reaching "
          f"validation accuracy < {TARGET_VAL_ACC}")



0.8603888888888889
0.8057
0.7753666666666666
0.7577
0.7213222222222222
0.7134
0.6795666666666667
0.6734
0.6403666666666666
0.6352
0.6023888888888889
0.6012
0.5635111111111111
0.5667
0.5238222222222222
0.5279
0.4902111111111111
0.4939
0.45587777777777777
0.4616
0.42344444444444446
0.4307
0.39344444444444443
0.399
0.3653222222222222
0.3695
0.3380111111111111
0.3391
0.3123444444444444
0.3207
0.29102222222222224
0.3015
0.2732111111111111
0.2808
0.2579888888888889
0.2686
0.24657777777777778
0.2547
0.23303333333333334
0.2398
0.21876666666666666
0.2267
0.20578888888888888
0.207
0.1928111111111111
0.1923


In [10]:
def test_proj_after_it(it, input_dim=2048, dims_per_iter=10):
    """
    Split the activation space using the first `it` background subspaces found by
    the iterative removal loop, then report how well linear probes recover the
    background and foreground labels from each part.

    Reads the module-level `rowspace_projections` list plus the train/val
    activations and class/background labels.

    :param it: number of removal iterations to use; must satisfy
        1 <= it <= len(rowspace_projections)
    :param input_dim: dimensionality of the activations (2048 in this notebook)
    :param dims_per_iter: directions assumed removed per iteration (10 here,
        presumably the number of background classes, i.e. rows of clf.coef_ —
        confirm if the label set changes)
    :return: (P_trunc, anti_P_trunc). P_trunc is expected to have shape
        (input_dim - dims_per_iter * it, input_dim): orthonormal rows spanning the
        "foreground" subspace with background directions removed. anti_P_trunc is
        expected to have shape (dims_per_iter * it, input_dim): rows spanning the
        removed "background" subspace.
    :raises ValueError: if `it` / `dims_per_iter` are out of range or leave no
        foreground dimensions
    :raises numpy.linalg.LinAlgError: if the stacked bases are not invertible
    """
    if not 1 <= it <= len(rowspace_projections):
        raise ValueError(
            f"it must be between 1 and {len(rowspace_projections)}, got {it}"
        )
    if dims_per_iter < 1:
        raise ValueError(f"dims_per_iter must be positive, got {dims_per_iter}")
    n_background_dims = dims_per_iter * it
    n_foreground_dims = input_dim - n_background_dims
    if n_foreground_dims <= 0:
        raise ValueError(
            f"{it} iterations x {dims_per_iter} dims leaves no foreground dimensions"
        )

    # Foreground projection: onto the common nullspace of the first `it` probes.
    P = get_projection_to_intersection_of_nullspaces_torch(rowspace_projections[:it], input_dim)
    # Background projection: onto the complement of P's range.
    anti_P = get_projection_to_intersection_of_nullspaces_torch([P], input_dim)
    # Truncate bases to the expected dimensions; numerical noise makes orth_torch
    # return slightly more columns than the theoretical rank.
    P_trunc = (orth_torch(P)[:, :n_foreground_dims]).T
    anti_P_trunc = (orth_torch(anti_P)[:, :n_background_dims]).T
    # Confirm invertibility before the slow probe fits below; np.linalg.inv raises
    # LinAlgError if the stacked matrix is singular or not square.
    np.linalg.inv(np.concatenate([P_trunc, anti_P_trunc]))

    # Project data onto each subspace.
    train_activations_antiproj = (anti_P_trunc.dot(train_activations.T)).T
    val_activations_antiproj = (anti_P_trunc.dot(val_activations.T)).T
    train_activations_proj = (P_trunc.dot(train_activations.T)).T
    val_activations_proj = (P_trunc.dot(val_activations.T)).T

    def _fit_and_report(title, X_train, y_train, X_val, y_val):
        # Fit one logistic-regression probe and print its train and val accuracy.
        print(title)
        clf = LogisticRegression(random_state=0, multi_class='multinomial', max_iter=400, C = 0.1, verbose=1).fit(X_train, y_train)
        print(clf.score(X_train, y_train))
        print(clf.score(X_val, y_val))

    # Train and evaluate classifiers using projected heads
    _fit_and_report("Background features, backgound accs",
                    train_activations_antiproj, train_backgrounds,
                    val_activations_antiproj, val_backgrounds)
    _fit_and_report("Background features, foreground accs",
                    train_activations_antiproj, train_classes,
                    val_activations_antiproj, val_classes)
    _fit_and_report("Foreground features, backgound accs",
                    train_activations_proj, train_backgrounds,
                    val_activations_proj, val_backgrounds)
    _fit_and_report("Foreground features, foreground accs",
                    train_activations_proj, train_classes,
                    val_activations_proj, val_classes)
    return P_trunc, anti_P_trunc
    

In [11]:
# Split using the background subspaces from the first 11 removal iterations.
P_11, anti_P_11 = test_proj_after_it(it=11)

Background features, backgound accs


[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.7s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.


0.8606
0.8053
Background features, foreground accs


[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:   47.6s finished


0.5676555555555556
0.5221
Foreground features, backgound accs


[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    7.2s finished


0.39287777777777777
0.399
Foreground features, foreground accs


[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:  5.8min finished


0.9814555555555555
0.8134


In [12]:
# Split using the background subspaces from the first 22 removal iterations.
P_22, anti_P_22 = test_proj_after_it(it=22)

Background features, backgound accs


[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    1.1s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.


0.8605111111111111
0.8059
Background features, foreground accs


[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:   58.4s finished


0.7582222222222222
0.692
Foreground features, backgound accs


[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    5.7s finished


0.1925888888888889
0.1918
Foreground features, foreground accs


[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:  5.4min finished


0.9785666666666667
0.8124


In [13]:
# Split using the background subspaces from the first 8 removal iterations.
P_8, anti_P_8 = test_proj_after_it(it=8)

Background features, backgound accs


[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.


RUNNING THE L-BFGS-B CODE

           * * *

Machine precision = 2.220D-16
 N =         1110     M =           10

At X0         0 variables are exactly at the bounds

At iterate    0    f=  2.07233D+05    |proj g|=  1.90914D+04

At iterate   50    f=  4.44203D+04    |proj g|=  2.23931D+01

At iterate  100    f=  4.44060D+04    |proj g|=  4.72462D-01

           * * *

Tit   = total number of iterations
Tnf   = total number of function evaluations
Tnint = total number of segments explored during Cauchy searches
Skip  = number of BFGS updates skipped
Nact  = number of active bounds at final generalized Cauchy point
Projg = norm of the final projected gradient
F     = final function value

           * * *

   N    Tit     Tnf  Tnint  Skip  Nact     Projg        F
 1110    115    120      1     0     0   1.008D-01   4.441D+04
  F =   44405.945657474040     

CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH             
RUNNING THE L-BFGS-B CODE

           * * *

Machine precision = 2.220

[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.3s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:   16.9s finished


0.4306888888888889
0.3874
Foreground features, backgound accs


[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    7.8s finished


0.4904
0.4947
Foreground features, foreground accs


[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:  6.0min finished


0.9820888888888889
0.8134


In [15]:
# Save projection heads. Each head is stored transposed, i.e. shape (2048, k), so
# `activations @ head` reproduces the (P.dot(activations.T)).T projection used above.
heads_to_save = [
    ("disentangle_8.pth", P_8, anti_P_8),
    ("disentangle_11.pth", P_11, anti_P_11),
    ("disentangle_22.pth", P_22, anti_P_22),
]
for out_path, fg_proj, bg_proj in heads_to_save:
    torch.save({'foreground': torch.tensor(fg_proj.T), 'background': torch.tensor(bg_proj.T)}, out_path)