In [1]:
import sys
import torch
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

sys.path.append('./models')
from models.matrix_a_linear import MatrixALinear

In [2]:
# Raise PyTorch's summarization threshold so tensors with up to 10k elements are printed in
# full rather than abbreviated with "...". This only affects torch tensor printing, not NumPy arrays.
torch.set_printoptions(threshold=10000)
def load_model(input_dim: int, output_dim: int,
               path: str = "../../saved_models/matrix_a.pkl") -> "MatrixALinear":
    """Build a bias-free ``MatrixALinear`` and load saved matrix-A weights into it.

    Args:
        input_dim: Forwarded to ``MatrixALinear(input_dim=...)``.
        output_dim: Forwarded to ``MatrixALinear(output_dim=...)``.
        path: State-dict checkpoint to load. Defaults to the original hard-coded location.

    Returns:
        The ``MatrixALinear`` instance with the checkpoint weights loaded. It stays on the CPU,
        where it was constructed.
    """
    matrix_a_linear = MatrixALinear(input_dim=input_dim, output_dim=output_dim, bias=False)
    # The module is built on the CPU and never moved, so map the checkpoint straight to the CPU.
    # Mapping it to CUDA first (as before) only allocated GPU memory that load_state_dict
    # immediately copied back into the CPU parameters.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(path, map_location="cpu")
    matrix_a_linear.load_state_dict(state_dict)

    return matrix_a_linear

def save_model(new_model: "MatrixALinear", save_dir: str = "../../saved_models") -> str:
    """Save a PCA-reduced model's state dict, tagging the filename with its component count.

    The number in ``matrix_a_after_pca_<n>.pkl`` is read from the model's ``linear.in_features``.
    This replaces a hard-coded ``20`` that did not match the number of components actually used.

    Args:
        new_model: Model exposing a ``linear`` layer whose input width is the number of
            retained PCA components.
        save_dir: Directory to write into; defaults to the original location.

    Returns:
        The path the state dict was written to.
    """
    n_components = new_model.linear.in_features
    path = f"{save_dir}/matrix_a_after_pca_{n_components}.pkl"
    torch.save(new_model.state_dict(), path)
    return path

In [3]:
# Load matrix A (a 100 -> 100 bias-free linear layer) and expose its weight to scikit-learn.
model = load_model(100, 100)
# .detach() is the supported replacement for the legacy .data attribute. .cpu() keeps .numpy()
# valid even if the model is ever moved to a GPU. On the CPU the resulting array shares memory
# with the weight tensor.
matrix_a_np = model.linear.weight.detach().cpu().numpy()

In [13]:
# Peek at the raw (not yet standardized) weights of matrix A; NumPy abbreviates large arrays with "...".
matrix_a_np

array([[ 0.81704366, -0.2660322 , -0.08401775, ..., -0.30233875,
        -0.0488075 , -0.11722549],
       [-0.08853493,  0.61349964, -0.14548564, ..., -0.04823741,
         0.00391661, -0.01794411],
       [-0.2871844 , -0.2862975 ,  0.75088465, ..., -0.19625098,
        -0.1456132 , -0.18078662],
       ...,
       [-0.16676742, -0.07168139,  0.08625817, ...,  0.838528  ,
        -0.06899016, -0.19889902],
       [-0.01211979, -0.09629856,  0.16925535, ..., -0.2892095 ,
         1.0440856 , -0.07014164],
       [-0.13707556,  0.14942361,  0.14688483, ..., -0.25762913,
        -0.16823018,  1.0223261 ]], dtype=float32)

In [14]:
# Standardize each column to zero mean / unit variance before PCA. scikit-learn treats rows as
# samples and columns as features. This rebinds ``matrix_a_np``, so the raw weights are no longer
# reachable under that name after this cell runs.
scaler = StandardScaler()
matrix_a_np = scaler.fit_transform(matrix_a_np)

In [15]:
# Project each row of the standardized matrix onto its 23 leading principal components, so
# ``principalComponents`` has one row per row of A and 23 columns.
# svd_solver="full" pins the exact, deterministic SVD instead of relying on the "auto" heuristic,
# whose solver-selection rules differ between scikit-learn versions.
pca = PCA(n_components=23, svd_solver="full")
principalComponents = pca.fit_transform(matrix_a_np)

In [16]:
# principalComponents = StandardScaler().fit_transform(principalComponents)

In [17]:
# Fraction of the standardized data's total variance captured by each retained component.
pca.explained_variance_ratio_

array([0.49193215, 0.031477  , 0.02599666, 0.02435079, 0.0194992 ,
       0.01871691, 0.01654205, 0.01557036, 0.01453817, 0.01268661],
      dtype=float32)

In [18]:
# Install the PCA projection as the weight of a new bias-free layer. nn.Linear stores its weight
# as (out_features, in_features), so a (n_rows, n_components) projection gives
# input_dim = n_components and output_dim = n_rows. Reading both from the array keeps this cell
# in sync with the PCA cell instead of repeating hard-coded 23 / 100.
new_weights = torch.from_numpy(principalComponents)
n_rows, n_components = new_weights.shape

matrix_a_linear_after_pca = MatrixALinear(input_dim=n_components, output_dim=n_rows, bias=False)
new_dict = {
    'linear.weight': new_weights
}
matrix_a_linear_after_pca.load_state_dict(new_dict)

<All keys matched successfully>

In [19]:
# Persist the reduced layer's state dict; the target path is decided inside save_model.
save_model(new_model=matrix_a_linear_after_pca)

In [20]:
# Display the module repr to confirm the reduced layer's in/out feature sizes.
matrix_a_linear_after_pca

MatrixALinear(
  (linear): Linear(in_features=10, out_features=100, bias=False)
)

In [21]:
# Inspect the installed weight (the PCA projection). .detach().cpu() replaces the legacy .data
# access and keeps .numpy() valid regardless of the layer's device.
matrix_a_linear_after_pca.linear.weight.detach().cpu().numpy()

array([[-2.36836100e+00, -2.16956687e+00,  1.38259435e+00,
        -1.92839849e+00, -2.29927516e+00, -8.53460252e-01,
         1.64182782e+00,  5.89460507e-02,  1.63576472e+00,
         6.20316803e-01],
       [-7.02263832e+00,  2.69627720e-01, -5.67205608e-01,
        -2.42504144e+00,  5.66472411e-01, -1.31610060e+00,
         1.30753660e+00,  2.54322886e+00, -3.99155676e-01,
         1.07079379e-01],
       [-1.14587164e+01,  5.53330123e-01,  1.32999861e+00,
         4.74807143e-01,  5.81735037e-02, -5.76772809e-01,
         4.15994287e-01,  6.84409738e-01,  2.42402124e+00,
         3.82434905e-01],
       [ 1.64653337e+00,  5.61855584e-02, -6.02150917e-01,
        -6.75311327e-01,  5.68060875e-01, -1.10657072e+00,
         6.20446622e-01,  1.72931695e+00, -5.52933514e-01,
         3.78852987e+00],
       [ 1.14193811e+01, -1.09902728e+00, -1.16183746e+00,
         1.74328458e+00, -1.58468872e-01, -6.95529103e-01,
         8.26171320e-03,  2.92645574e-01, -1.26432955e+00,
         9.