In [6]:
import numpy as np
from sklearn.decomposition import PCA

import torch
import torch.nn as nn

In [22]:
X = np.array([[-1, -1, -1], [-2, -1, -2], [-3, -2, -2], [1, 1, 1], [2, 1, 2], [3, 2, 1]])

pca = PCA(n_components=2)
pca.fit(X)

print(pca.mean_, pca.whiten, pca.components_.shape, "\n")

def transform(X, mean, components, explained_variance, whiten = False):
    """From source of sklearn """
    Y = np.dot(X - mean, components.T)    
    if whiten:
        Y /= np.sqrt(explained_variance)
    return Y
    
Y = pca.transform(X)
print(Y, "\n")

Y = transform(X, pca.mean_, pca.components_, pca.explained_variance_, pca.whiten)
print(Y)

[ 0.          0.         -0.16666667] False (2, 3) 

[[ 1.61410591 -0.04030163]
 [ 2.84581994  0.48808345]
 [ 4.03668406 -0.25740461]
 [-1.78340263 -0.24177294]
 [-3.01511665 -0.77015802]
 [-3.69809063  0.82155375]] 

[[ 1.61410591 -0.04030163]
 [ 2.84581994  0.48808345]
 [ 4.03668406 -0.25740461]
 [-1.78340263 -0.24177294]
 [-3.01511665 -0.77015802]
 [-3.69809063  0.82155375]]


In [68]:
class Model_PCA(nn.Module):
    def __init__(self, dim1=3, dim2=2):                                                        
        super(Model_PCA, self).__init__()      
            
        self.mean = nn.Parameter( torch.zeros((dim1, ),     dtype=torch.float32),  requires_grad = False)
        self.comp = nn.Parameter( torch.zeros((dim2, dim1), dtype=torch.float32),  requires_grad = False)
 
    def set(self, mean, comp):  # numpy
        self.mean.copy_( torch.tensor(mean, dtype=torch.float32) )
        self.comp.copy_( torch.tensor(comp, dtype=torch.float32) )

    def forward(self, x):               
        x =  x - self.mean
        print(x.shape, self.comp.T.shape)
        x =  torch.mm(x, self.comp.T)        
        return x
    
model = Model_PCA()  
model.set(pca.mean_, pca.components_)   # from pca

Y = model(torch.tensor(X))
print(Y)
  
print(model.comp)    

torch.save({'model':  model.state_dict()}, 'model_pca.pt')        

m = Model_PCA(3,2)
state = torch.load('model_pca.pt')
m.load_state_dict(state['model'])  
print(m.comp)

Y = m(torch.tensor(X))
print(Y)

torch.Size([6, 3]) torch.Size([3, 2])
tensor([[ 1.6141, -0.0403],
        [ 2.8458,  0.4881],
        [ 4.0367, -0.2574],
        [-1.7834, -0.2418],
        [-3.0151, -0.7702],
        [-3.6981,  0.8216]])
Parameter containing:
tensor([[-0.7238, -0.4670, -0.5079],
        [ 0.3178,  0.4276, -0.8462]])
Parameter containing:
tensor([[-0.7238, -0.4670, -0.5079],
        [ 0.3178,  0.4276, -0.8462]])
torch.Size([6, 3]) torch.Size([3, 2])
tensor([[ 1.6141, -0.0403],
        [ 2.8458,  0.4881],
        [ 4.0367, -0.2574],
        [-1.7834, -0.2418],
        [-3.0151, -0.7702],
        [-3.6981,  0.8216]])


In [72]:
class Model(nn.Module):
    def __init__(self):                                                        
        super(Model, self).__init__()             
        self.fc  = nn.Linear(3, 2) # something
        #...
        self.pca = Model_PCA()
 
    def forward(self, x):               
        return self.fc(x)    

model = Model()    

state = torch.load('model_pca.pt')
print(state)
model.pca.load_state_dict(state['model'])  


torch.save({'model':  model.state_dict()}, 'model_all.pt')  
state = torch.load('model_all.pt')
print(state)

{'model': OrderedDict([('mean', tensor([ 0.0000,  0.0000, -0.1667])), ('comp', tensor([[-0.7238, -0.4670, -0.5079],
        [ 0.3178,  0.4276, -0.8462]]))])}
{'model': OrderedDict([('fc.weight', tensor([[ 0.3555,  0.3801, -0.1033],
        [-0.0138,  0.4348, -0.4440]])), ('fc.bias', tensor([-0.0600,  0.5334])), ('pca.mean', tensor([ 0.0000,  0.0000, -0.1667])), ('pca.comp', tensor([[-0.7238, -0.4670, -0.5079],
        [ 0.3178,  0.4276, -0.8462]]))])}


In [142]:
#p  = nn.Parameter( torch.zeros((10,)))
bn = nn.BatchNorm2d(num_features=4) 
fc1 = nn.Linear(2, 3) 
fc2 = nn.Linear(4, 5) 

m  = nn.Sequential(fc1, nn.Sequential(bn, fc2))

tot = 0
for n, p in bn.state_dict().items():
    pars = p.numel(); tot += pars
    print(f'{n:20s} : {pars}  =  {tuple(p.shape)}')
    
print(f"{'parameters':20s} :{tot:7d}")

weight               : 4  =  (4,)
bias                 : 4  =  (4,)
running_mean         : 4  =  (4,)
running_var          : 4  =  (4,)
num_batches_tracked  : 1  =  ()
parameters           :     17


In [145]:
for p in bn.parameters():
    print(f'{tuple(p.shape)} {p.requires_grad}')

(4,) True
(4,) True


In [146]:
for n, p in bn.named_parameters():    
    print(f"{n+':':10s} {tuple(p.shape)}, {p.requires_grad}")

weight:    (4,), True
bias:      (4,), True


In [180]:
im = torch.tensor([ [ [[1,2],[2,1]],  
                      [[1,1],[1,1]],
                      [[0,0],[0,0]],
                    ],
                   
                      
                  ]).float()
print(im.shape)

bn = nn.BatchNorm2d(num_features=3) 

for n, p in bn.state_dict().items():
    print(f'{n:20s} : {tuple(p.shape)} {p.requires_grad} {p}')
    
print(bn(im))

print("mean:",im.mean((0,2,3)))
print("var :",im.var((0,2,3)))


for n, p in bn.state_dict().items():
    print(f'{n:20s} : {tuple(p.shape)} {p.requires_grad} {p.data}')
    
for n, p in bn.named_parameters():    
    print(f"{n:10s} : {tuple(p.shape)}, {p.requires_grad}, {p.data}")    

torch.Size([1, 3, 2, 2])
weight               : (3,) False tensor([1., 1., 1.])
bias                 : (3,) False tensor([0., 0., 0.])
running_mean         : (3,) False tensor([0., 0., 0.])
running_var          : (3,) False tensor([1., 1., 1.])
num_batches_tracked  : () False 0
tensor([[[[-1.0000,  1.0000],
          [ 1.0000, -1.0000]],

         [[ 0.0000,  0.0000],
          [ 0.0000,  0.0000]],

         [[ 0.0000,  0.0000],
          [ 0.0000,  0.0000]]]], grad_fn=<NativeBatchNormBackward0>)
mean: tensor([1.5000, 1.0000, 0.0000])
var : tensor([0.3333, 0.0000, 0.0000])
weight               : (3,) False tensor([1., 1., 1.])
bias                 : (3,) False tensor([0., 0., 0.])
running_mean         : (3,) False tensor([0.1500, 0.1000, 0.0000])
running_var          : (3,) False tensor([0.9333, 0.9000, 0.9000])
num_batches_tracked  : () False 1
weight     : (3,), True, tensor([1., 1., 1.])
bias       : (3,), True, tensor([0., 0., 0.])
