In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from utils import *

In [2]:
class MyConv(nn.Module):
    def __init__(self, d=28):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=2)
        self.fc = nn.Linear(900,10)
        self.softmax=nn.LogSoftmax(dim=1)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = x.view(x.size(0),-1)
        x = self.fc(x)
        x = self.softmax(x)
        return x

In [3]:
from copy import deepcopy
class MyLin(nn.Module):
    def __init__(self, d=28):
        super().__init__()
        self.lin1 = conv_to_fc(myconv.conv1)
        self.fc = deepcopy(myconv.fc)
        self.softmax=nn.LogSoftmax(dim=1)

    def forward(self, x):
        x = self.lin1(x)
        x = F.relu(x)
        x = self.fc(x)
        x = self.softmax(x)
        return x

In [4]:
class BasicLin(nn.Module):
    def __init__(self, d=28):
        super().__init__()
        self.lin1 = nn.Linear(28*28,900)
        self.fc = nn.Linear(900,10)
        self.softmax=nn.LogSoftmax(dim=1)

    def forward(self, x):
        x = self.lin1(x)
        x = F.relu(x)
        x = self.fc(x)
        x = self.softmax(x)
        return x

In [19]:
train_inp = torch.randn((10000,1,28,28))
test_inp = torch.randn((1000,1,28,28))


In [20]:
myconv=MyConv()
mylin=MyLin()

In [22]:
conv_out = myconv(test_inp).squeeze()
lin_out = mylin(test_inp.view(test_inp.shape[0], -1)).squeeze()
torch.norm( conv_out - lin_out)

tensor(1.0152e-05, grad_fn=<LinalgVectorNormBackward0>)

In [23]:
lin = BasicLin()

In [41]:
train_conv_out = myconv(train_inp).squeeze()

In [42]:
conv_out = myconv(test_inp).squeeze()
lin_out = lin(test_inp.view(test_inp.shape[0], -1)).squeeze()
torch.norm( conv_out - lin_out)

tensor(19.1279, grad_fn=<LinalgVectorNormBackward0>)

## It's possible to learn conv structure with the same fc

In [43]:
mse = torch.nn.MSELoss()
#optimizer = torch.optim.SGD(lin.parameters(), lr=0.1, momentum=0.9)
optimizer = torch.optim.Adam(lin.parameters(), lr=0.001)
num_epochs = 10000

In [45]:
for epoch in range(num_epochs):
    
    # take random 128 batch
    batch_size=10000
    inputs = train_inp[torch.randint(0, train_inp.shape[0], (batch_size,))]
    labels = myconv(inputs).unsqueeze(-1)
    out = lin(inputs.view(inputs.shape[0], -1))
    loss = mse(out, labels.squeeze())
    
    ## Backwards pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if num_epochs%100==0: 
        
        fcn_out = lin(train_inp.view(train_inp.shape[0], -1)).squeeze()
        test_mse = torch.norm( train_conv_out - fcn_out)
        print(f'Train MSE is {test_mse}\n')
        
        fcn_out = lin(test_inp.view(test_inp.shape[0], -1)).squeeze()
        test_mse = torch.norm(conv_out - fcn_out)
        print(f'Test MSE is {test_mse}\n\n')
print('Finished Training')

Train MSE is 44.573326110839844

Test MSE is 16.357738494873047


Train MSE is 46.4608154296875

Test MSE is 17.085193634033203


Train MSE is 52.691925048828125

Test MSE is 18.995315551757812


Train MSE is 45.2849235534668

Test MSE is 17.311193466186523


Train MSE is 36.13462829589844

Test MSE is 15.31116008758545


Train MSE is 37.209171295166016

Test MSE is 15.708662986755371


Train MSE is 41.22621536254883

Test MSE is 16.841938018798828


Train MSE is 38.885948181152344

Test MSE is 16.574756622314453


Train MSE is 31.745832443237305

Test MSE is 15.35290241241455


Train MSE is 27.05884552001953

Test MSE is 14.752788543701172


Train MSE is 28.868864059448242

Test MSE is 15.303986549377441


Train MSE is 31.195207595825195

Test MSE is 15.908234596252441


Train MSE is 29.3188533782959

Test MSE is 15.70212173461914


Train MSE is 24.316781997680664

Test MSE is 14.980239868164062


Train MSE is 21.12152862548828

Test MSE is 14.601717948913574


Train MSE is 22.4383525

Train MSE is 0.15257027745246887

Test MSE is 14.1031494140625


Train MSE is 0.15369781851768494

Test MSE is 14.102893829345703


Train MSE is 0.14391417801380157

Test MSE is 14.10275936126709


Train MSE is 0.1356729120016098

Test MSE is 14.102688789367676


Train MSE is 0.13628965616226196

Test MSE is 14.102566719055176


Train MSE is 0.12931963801383972

Test MSE is 14.102426528930664


Train MSE is 0.11852636933326721

Test MSE is 14.102256774902344


Train MSE is 0.11790107935667038

Test MSE is 14.102229118347168


Train MSE is 0.1159408763051033

Test MSE is 14.10244083404541


Train MSE is 0.10703030228614807

Test MSE is 14.102721214294434


Train MSE is 0.10328154265880585

Test MSE is 14.103011131286621


Train MSE is 0.10266035050153732

Test MSE is 14.10307788848877


Train MSE is 0.09537650644779205

Test MSE is 14.102959632873535


Train MSE is 0.09024560451507568

Test MSE is 14.102753639221191


Train MSE is 0.09041868895292282

Test MSE is 14.102649688720703


Tr

Train MSE is 0.0014495893847197294

Test MSE is 14.103192329406738


Train MSE is 0.001396635314449668

Test MSE is 14.103192329406738


Train MSE is 0.0013363627949729562

Test MSE is 14.103187561035156


Train MSE is 0.0012907792115584016

Test MSE is 14.103184700012207


Train MSE is 0.001259857788681984

Test MSE is 14.10318374633789


Train MSE is 0.0012141101760789752

Test MSE is 14.10318374633789


Train MSE is 0.0011590697104111314

Test MSE is 14.10318660736084


Train MSE is 0.0011132039362564683

Test MSE is 14.103188514709473


Train MSE is 0.001080342335626483

Test MSE is 14.103187561035156


Train MSE is 0.0010538367787376046

Test MSE is 14.103184700012207


Train MSE is 0.0010183124104514718

Test MSE is 14.103182792663574


Train MSE is 0.000971113913692534

Test MSE is 14.103184700012207


Train MSE is 0.0009371138294227421

Test MSE is 14.10318660736084


Train MSE is 0.0009141791379079223

Test MSE is 14.103187561035156


Train MSE is 0.0008842844399623573

Test M

Train MSE is 5.569779023062438e-05

Test MSE is 14.103188514709473


Train MSE is 5.520986087503843e-05

Test MSE is 14.10318660736084


Train MSE is 5.459368912852369e-05

Test MSE is 14.103187561035156


Train MSE is 5.443023474072106e-05

Test MSE is 14.10318660736084


Train MSE is 5.4086609452497214e-05

Test MSE is 14.103188514709473


Train MSE is 5.317589602782391e-05

Test MSE is 14.103187561035156


Train MSE is 5.319219417287968e-05

Test MSE is 14.103187561035156


Train MSE is 5.290394983603619e-05

Test MSE is 14.103187561035156


Train MSE is 5.2667710406240076e-05

Test MSE is 14.103187561035156


Train MSE is 5.215838973526843e-05

Test MSE is 14.103187561035156


Train MSE is 5.187015267438255e-05

Test MSE is 14.103187561035156


Train MSE is 5.161913577467203e-05

Test MSE is 14.103187561035156


Train MSE is 5.1332022849237546e-05

Test MSE is 14.103188514709473


Train MSE is 5.099162444821559e-05

Test MSE is 14.103187561035156


Train MSE is 5.0748167268466204e-

Train MSE is 4.9468693759990856e-05

Test MSE is 14.103188514709473


Train MSE is 4.92989165650215e-05

Test MSE is 14.103187561035156


Train MSE is 4.937092671752907e-05

Test MSE is 14.103187561035156


Train MSE is 4.983518738299608e-05

Test MSE is 14.10318660736084


Train MSE is 4.973828254151158e-05

Test MSE is 14.10318660736084


Train MSE is 4.947027628077194e-05

Test MSE is 14.103188514709473


Train MSE is 4.957343844580464e-05

Test MSE is 14.103187561035156


Train MSE is 4.986298154108226e-05

Test MSE is 14.103188514709473


Train MSE is 4.971170346834697e-05

Test MSE is 14.103189468383789


Train MSE is 4.9494112317916006e-05

Test MSE is 14.103187561035156


Train MSE is 4.985628402209841e-05

Test MSE is 14.103187561035156


Train MSE is 4.983818143955432e-05

Test MSE is 14.103187561035156


Train MSE is 4.998779695597477e-05

Test MSE is 14.103188514709473


Train MSE is 5.003397382097319e-05

Test MSE is 14.103187561035156


Train MSE is 5.017224248149432e-05


Train MSE is 5.424821574706584e-05

Test MSE is 14.103187561035156


Train MSE is 5.4801148507976905e-05

Test MSE is 14.103187561035156


Train MSE is 5.484456414706074e-05

Test MSE is 14.103188514709473


Train MSE is 5.468134258990176e-05

Test MSE is 14.103188514709473


Train MSE is 5.463272100314498e-05

Test MSE is 14.103188514709473


Train MSE is 5.492302079801448e-05

Test MSE is 14.103187561035156


Train MSE is 5.464494461193681e-05

Test MSE is 14.103187561035156


Train MSE is 5.498055907082744e-05

Test MSE is 14.103187561035156


Train MSE is 5.502151179825887e-05

Test MSE is 14.10318660736084


Train MSE is 5.525887536350638e-05

Test MSE is 14.103188514709473


Train MSE is 5.5211792641784996e-05

Test MSE is 14.103187561035156


Train MSE is 5.5052754760254174e-05

Test MSE is 14.103187561035156


Train MSE is 5.548023909796029e-05

Test MSE is 14.10318660736084


Train MSE is 5.543911174754612e-05

Test MSE is 14.103187561035156


Train MSE is 5.514392978511751e-0

Train MSE is 6.148680404294282e-05

Test MSE is 14.103188514709473


Train MSE is 6.17566256551072e-05

Test MSE is 14.103187561035156


Train MSE is 6.178572948556393e-05

Test MSE is 14.10318660736084


Train MSE is 6.202034273883328e-05

Test MSE is 14.103187561035156


Train MSE is 6.217195914359763e-05

Test MSE is 14.103188514709473


Train MSE is 6.199925701366737e-05

Test MSE is 14.103187561035156


Train MSE is 6.249000580282882e-05

Test MSE is 14.103187561035156


Train MSE is 6.283471884671599e-05

Test MSE is 14.103185653686523


Train MSE is 6.250387377804145e-05

Test MSE is 14.103187561035156


Train MSE is 6.234201282495633e-05

Test MSE is 14.103187561035156


Train MSE is 6.276048952713609e-05

Test MSE is 14.103188514709473


Train MSE is 6.293370097409934e-05

Test MSE is 14.103187561035156


Train MSE is 6.263991963351145e-05

Test MSE is 14.10318660736084


Train MSE is 6.289970770012587e-05

Test MSE is 14.103187561035156


Train MSE is 6.311622564680874e-05

T

KeyboardInterrupt: 

In [34]:
out.shape

torch.Size([10000, 10])

In [36]:
labels.shape

torch.Size([10000, 10, 1])