In [7]:
# Show the partial-overlap matrix. NOTE(review): P_partial is assigned in a
# later cell of this notebook, so this cell fails on a fresh kernel unless that
# cell has been run first.
P_partial

tensor([[0., 1., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])

In [6]:
import torch
import torch.nn as nn

def create_general_permutation_matrix(task1_classes, task2_classes, K):
    """Build the K x K 0/1 matrix that relocates Task 1 outputs to Task 2 positions.

    Entry ``[r, c]`` is 1 when the class at position ``c`` of ``task1_classes``
    is also present in ``task2_classes`` at position ``r``; every other entry
    is 0. Classes that appear in only one of the lists leave their row/column
    empty, so the result is a true permutation matrix only when both lists
    contain the same classes.

    Args:
        task1_classes: Class labels of Task 1, ordered by output index.
        task2_classes: Class labels of Task 2, ordered by output index.
        K: Side length of the square matrix to create.

    Returns:
        A ``torch`` tensor of shape ``(K, K)`` filled with zeros and ones.
    """
    # Position of every label inside each task's output vector.
    task1_mapping = dict((label, pos) for pos, label in enumerate(task1_classes))
    task2_mapping = dict((label, pos) for pos, label in enumerate(task2_classes))

    P = torch.zeros(K, K)
    for label, source_col in task1_mapping.items():
        target_row = task2_mapping.get(label)
        if target_row is None:
            # Label does not exist in Task 2: nothing to route.
            continue
        P[target_row, source_col] = 1

    return P


# Size of the toy classifier: F input features, K output classes.
F, K = 4, 3

# Seed PyTorch's global RNG before anything random is created, so that the
# layer initialisation and the sample batch below come out identical on every
# Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

D = nn.Linear(F, K, bias=True)  # This layer represents D

# Class labels per task, ordered by output index. Task 1 is the reference;
# the three Task 2 variants cover the overlap scenarios being compared.
task1_classes = [0, 1, 2]
task2_classes_full = [2, 1, 0]  # Same classes as Task 1, reordered
task2_classes_partial = [1, 8, 4]  # Only class 1 is shared with Task 1
task2_classes_no_overlap = [3, 4, 5]  # Nothing shared with Task 1

# Random input batch X used to compare the original and transformed layers.
N = 5  # Number of samples
X = torch.randn(N, F)

def predict_labels(y):
    """Return, for each row of ``y``, the index of its most probable class.

    The scores in ``y`` are normalised with a softmax along dimension 1 and
    the position of the largest probability in each row is returned.
    """
    return y.softmax(dim=1).argmax(dim=1)

# Function to apply transformation to D and create D'
def transform_D(D, P, F, K):
    """Create a new linear layer D' whose outputs are those of ``D`` remapped by ``P``.

    The new parameters are ``P @ D.weight`` and ``P @ D.bias``, hence
    ``D'(x) == D(x) @ P.T``. A row of ``P`` that is entirely zero produces
    zero weights and a zero bias for that output, i.e. a constant logit of 0
    regardless of the input.

    Args:
        D: Source ``nn.Linear`` layer; it is not modified.
        P: 2-D matrix with as many columns as ``D`` has output features. It is
            cast to the dtype/device of ``D.weight`` before multiplying, so an
            integer-valued ``P`` is accepted.
        F: Number of input features used to construct D'.
        K: Number of output features used to construct D'.

    Returns:
        A new ``nn.Linear`` carrying the transformed parameters. If ``D`` has
        no bias, the returned layer has none either.
    """
    # detach() instead of the discouraged `.data`: same values, no autograd link to D.
    source_weight = D.weight.detach()
    # Align P with the layer so the matmul cannot fail on a dtype/device mismatch.
    P = P.to(dtype=source_weight.dtype, device=source_weight.device)

    has_bias = D.bias is not None
    D_prime = nn.Linear(F, K, bias=has_bias)
    D_prime.weight = nn.Parameter(torch.matmul(P, source_weight))
    if has_bias:
        D_prime.bias = nn.Parameter(torch.matmul(P, D.bias.detach()))
    return D_prime

# One matrix per overlap scenario (full, partial, none), built against Task 1.
P, P_partial, P_no_overlap = (
    create_general_permutation_matrix(task1_classes, scenario_classes, K)
    for scenario_classes in (
        task2_classes_full,
        task2_classes_partial,
        task2_classes_no_overlap,
    )
)

# Derive one transformed layer D' from D for each of those matrices.
D_prime_full, D_prime_partial, D_prime_no_overlap = (
    transform_D(D, matrix, F, K) for matrix in (P, P_partial, P_no_overlap)
)

# Predicted labels of the original layer on X ...
y_original = predict_labels(D(X))

# ... and of every transformed layer on the same batch.
y_transformed_full, y_transformed_partial, y_transformed_no_overlap = (
    predict_labels(layer(X))
    for layer in (D_prime_full, D_prime_partial, D_prime_no_overlap)
)


def verify_predictions(y_original, y_transformed, task1_classes, task2_classes):
    """Count how often the transformed prediction names the same class as the original.

    Only samples whose original prediction (looked up in ``task1_classes``) is
    a class that also occurs in ``task2_classes`` are considered. For those,
    the transformed prediction is looked up by position in ``task2_classes``
    and compared with the original class.

    Args:
        y_original: Iterable of predicted indices into ``task1_classes``; each
            element must expose ``.item()``.
        y_transformed: Iterable of predicted indices into ``task2_classes``;
            each element must expose ``.item()``.
        task1_classes: Class labels of Task 1, ordered by output index.
        task2_classes: Class labels of Task 2, ordered by output index.

    Returns:
        ``(correct, total_common)``: the number of matching predictions and
        the number of samples that were considered.
    """
    # Output position in Task 2 -> class label.
    position_to_class = dict(enumerate(task2_classes))

    hits = 0
    shared = 0
    for source_pred, target_pred in zip(y_original, y_transformed):
        expected_class = task1_classes[source_pred.item()]
        if expected_class not in task2_classes:
            # The originally predicted class does not exist in Task 2: skip.
            continue
        shared += 1
        if position_to_class.get(target_pred.item()) == expected_class:
            hits += 1

    return hits, shared



# Show the predicted labels of the original layer and of each transformed layer.
print("Labels Original:\n", y_original)
print("Labels Full Transform:\n", y_transformed_full)
print("Labels Partial Transform:\n", y_transformed_partial)
print("Labels No Overlap Transform:\n", y_transformed_no_overlap)

# Score every scenario against the original predictions: (correct, total common).
(
    (correct_full, total_full),
    (correct_partial, total_partial),
    (correct_no_overlap, total_no_overlap),
) = (
    verify_predictions(y_original, scenario_preds, task1_classes, scenario_classes)
    for scenario_preds, scenario_classes in (
        (y_transformed_full, task2_classes_full),
        (y_transformed_partial, task2_classes_partial),
        (y_transformed_no_overlap, task2_classes_no_overlap),
    )
)

print('\n\n')
print(f"Correct Full Overlap: {correct_full}/{total_full}")
print(f"Correct Partial Overlap: {correct_partial}/{total_partial}")
print(f"Correct No Overlap: {correct_no_overlap}/{total_no_overlap}")


Labels Original:
 tensor([2, 1, 1, 2, 2])
Labels Full Transform:
 tensor([0, 1, 1, 0, 0])
Labels Partial Transform:
 tensor([0, 0, 1, 1, 0])
Labels No Overlap Transform:
 tensor([0, 0, 0, 0, 0])



Correct Full Overlap: 5/5
Correct Partial Overlap: 1/2
Correct No Overlap: 0/0


In [13]:
# Raw outputs of the original layer D on the batch X (one row per sample,
# one column per output of D).
D(X)

tensor([[-0.5984,  0.3288,  0.4514],
        [-0.2207,  0.4237, -0.0310],
        [-0.1954, -0.1542, -0.3184],
        [-0.1374, -0.1552,  0.2464],
        [-0.5863,  0.1020,  0.5980]], grad_fn=<AddmmBackward0>)

In [12]:
# Outputs of the full-overlap layer on X: its weights and bias are P @ those
# of D, so these are D(X)'s columns rearranged by P.
D_prime_full(X)

tensor([[ 0.4514,  0.3288, -0.5984],
        [-0.0310,  0.4237, -0.2207],
        [-0.3184, -0.1542, -0.1954],
        [ 0.2464, -0.1552, -0.1374],
        [ 0.5980,  0.1020, -0.5863]], grad_fn=<AddmmBackward0>)

In [14]:
# Outputs of the partial-overlap layer on X: only outputs with a non-zero row
# in P_partial carry values from D; all-zero rows give a constant 0.
D_prime_partial(X)

tensor([[ 0.3288,  0.0000,  0.0000],
        [ 0.4237,  0.0000,  0.0000],
        [-0.1542,  0.0000,  0.0000],
        [-0.1552,  0.0000,  0.0000],
        [ 0.1020,  0.0000,  0.0000]], grad_fn=<AddmmBackward0>)

In [2]:
# Alternative route to the partial-overlap labels: remap D's outputs with
# P_partial directly instead of building a transformed layer.
# NOTE(review): this cell's execution count is lower than the defining cell's,
# so its recorded output presumably comes from an earlier D/X — re-run to confirm.
predict_labels(D(X) @ P_partial.T)

tensor([0, 0, 0, 0, 0])