In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import pandas as pd
import numpy as np
from sklearn.cluster import KMeans
import time
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.cluster import DBSCAN
from sklearn.datasets import make_blobs
from sklearn.preprocessing import LabelEncoder
from sklearn.neighbors import NearestNeighbors
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras import layers
import datetime, os

In [None]:
def flatten_grads(grads):
    grads = torch.cat([grad.view(-1) for grad in grads])
    return grads

def unflatten_grads(flattened_grads, model):
    grads = []
    start = 0
    for param in model.parameters():
        end = start + param.numel()
        grads.append(flattened_grads[start:end].view(param.size()))
        start = end
    return grads

def sort_grads(grads, sort_order=True):
    avg_abs = [torch.mean(torch.abs(grad)).item() for grad in grads]
    order = sorted(range(len(avg_abs)), key=lambda i: avg_abs[i], reverse=sort_order)
    return order

def unsort_grads(grads, order):
    unsorted_grads = [grads[i] for i in order]
    return unsorted_grads

def orthogonalize_grads(grads):
    flattened_grads = [flatten_grads(cluster_grads) for cluster_grads in grads]

    #sort the vectors by average absolute value in descending order
    order = sort_grads(flattened_grads)
    flattened_grads = [flattened_grads[i] for i in order]

    #now we have a set of vectors, we can perform the gram schmidt process
    ortho_grads = []
    for i, grad in enumerate(flattened_grads):
        for j in range(i):
            #subtract the projection of the current vector onto the previous vectors
            if torch.norm(ortho_grads[j]) != 0:
                grad = grad - (grad @ ortho_grads[j]) / (ortho_grads[j] @ ortho_grads[j]) * ortho_grads[j]
        ortho_grads.append(grad)
    #sort the vectors back to their original order
    ortho_grads = unsort_grads(ortho_grads, order)
    return ortho_grads

Testing of the orthoganlization process

Tests to be carried out:
1. test orthogonalization algorithm as is, with many 0s and very small numbers 
2. Ensure any size tensor is passed in, flattened (keeping the seperate cluster vectors) and orthogonalized, and returned in the correct shape
3. Ensure the vectors are orthogonalized and no overflows or division by 0 occur.

In [25]:
#orthogonalization function that takes in a set of vectors and orthogonalizes them
def orthogonalize_vectors(vectors):
    ortho_vectors = []
    for vector in vectors:
        for ortho_vector in ortho_vectors:
            #handle the case where the vector is all zeros
            if torch.norm(ortho_vector) == 0:
                continue
            #subtract the projection of the current vector onto the previous vectors
            vector = vector - ((vector @ ortho_vector) / (ortho_vector @ ortho_vector) * ortho_vector)
        ortho_vectors.append(vector)
    return ortho_vectors
    
#create a vectorized version of the orthogonalization function that does the work in parallel, dont include the normalization step
def orthogonalize_vectors_vectorized(vectors):
    #create a matrix of the vectors
    vectors = torch.stack(vectors)
    #create a matrix of the orthogonalized vectors
    ortho_vectors = torch.zeros_like(vectors)
    #create a matrix of the previous vectors
    prev_vectors = torch.zeros_like(vectors)
    for i in range(vectors.size(0)):
        #subtract the projection of the current vector onto the previous vectors
        ortho_vectors[i] = vectors[i] - (vectors[i] @ prev_vectors.t() / (prev_vectors @ prev_vectors.t()) @ prev_vectors)
        #update the previous vectors
        prev_vectors[i] = ortho_vectors[i]
    return ortho_vectors

#first test, create a random set of vectors and orthogonalize them, record the average change of elements in the vectors after orthogonalization
# vectors = [torch.randn(10) for i in range(10)]
# vectors[1] = torch.zeros(10)
vectors = [[] for _ in range(3)]
vectors[0] = [0.3, 0.7, 0.6, 0.9, 0.1]
vectors[1] = [-0.3, 0.6, -0.8, 0.8, 0.8]
vectors[2] = [0.1, -0.1, -0.6, -0.9, 0]
vectors = [torch.tensor(vector) for vector in vectors]
print("Original vectors:")
print(vectors)

#test sort order
#sort the vectors by average absolute value in descending order, dont use the function
# avg_abs = [torch.mean(torch.abs(vectors[i])).item() for i in range(len(vectors))]
# order = sorted(range(len(avg_abs)), key=lambda i: avg_abs[i], reverse=True)
# vectors = [vectors[i] for i in order]
# print("Sorted vectors:")
# print(vectors)



#orthogonalize the vectors
ortho_vectors = orthogonalize_vectors(vectors)
print("Orthogonalized vectors:")
print(ortho_vectors)

#use the vectorized version of the orthogonalization function
ortho_vectors = orthogonalize_vectors_vectorized(vectors)
print("Vectorized Orthogonalized vectors:")
print(ortho_vectors)

#calculate the average change in the elements of the vectors for each vector
change_vectors = [torch.mean(torch.abs(vectors[i] - ortho_vectors[i])).item() for i in range(len(vectors))]
print("Average absolute Change vectors:")
print(change_vectors)

#create a vector for each vector to record the change in the elements of the vectors after orthogonalization
change_vectors = [-1*(vectors[i] - ortho_vectors[i]) for i in range(len(vectors))]
print("Change vectors:")
print(change_vectors)


Original vectors:
[tensor([0.3000, 0.7000, 0.6000, 0.9000, 0.1000]), tensor([-0.3000,  0.6000, -0.8000,  0.8000,  0.8000]), tensor([ 0.1000, -0.1000, -0.6000, -0.9000,  0.0000])]
Orthogonalized vectors:
[tensor([0.3000, 0.7000, 0.6000, 0.9000, 0.1000]), tensor([-0.4108,  0.3415, -1.0216,  0.4676,  0.7631]), tensor([ 0.3288,  0.3625, -0.1314, -0.3069,  0.0269])]


RuntimeError: expand(torch.FloatTensor{[3, 5]}, size=[5]): the number of sizes provided (1) must be greater or equal to the number of dimensions in the tensor (2)

Testing of the orthogonalization vectorization algorithms and their effeciency
Tests to be carried out:
1. Ensure the results are the same
2. Ensure the vectorized version is faster, with results