In [1]:
import albumentations as A
import numpy as np

import matplotlib.pyplot as plt
from torchvision.datasets import Caltech256 ,Caltech101 ,CIFAR100
import os
from PIL import Image
from urllib.request import urlretrieve
import requests
import tarfile
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset

In [3]:
from torch import nn
import torch

In [4]:
import torch
from torch import nn
from collections import OrderedDict


Config_channels = {
"A" : [64,"M" , 128,  "M"  , 256,256,"M" ,512,512 ,"M" , 512,512,"M"] ,
"A_lrn" : [64,"LRN","M" , 128,  "M"  , 256,256,"M" ,512,512 ,"M" , 512,512,"M"] ,
"B" :[64,64,"M" , 128,128,  "M"  , 256,256,"M" ,512,512 ,"M" , 512,512,"M"]  ,
"C" : [64,64,"M" , 128,128,  "M"  , 256,256,256,"M" ,512,512 ,512,"M" , 512,512,512,"M"] ,
"D" :[64,64,"M" , 128,128,  "M"  , 256,256,256,"M" ,512,512 ,512,"M" , 512,512,512,"M"] ,
"E" :[64,64,"M" , 128,128,  "M"  , 256,256,256,256,"M" ,512,512 ,512,512,"M" , 512,512,512,512,"M"]         ,

}



Config_kernel = {
"A" : [3,2 , 3,  2  , 3,3,2 ,3,3 ,2 , 3,3,2] ,
"A_lrn" : [3,2,2 , 3,  2  , 3,3,2 ,3,3 ,2 , 3,3,2] ,
"B" :[3,3,2 , 3,3,  2  , 3,3,2 ,3,3 ,2 , 3,3,2]  ,
"C" : [3,3,2 , 3,3,  2  , 3,3,1,2 ,3,3 ,1,2 , 3,3,1,2] ,
"D" :[3,3,2 , 3,3,  2  , 3,3,3,2 ,3,3 ,3,2 , 3,3,3,2] ,
"E" :[3,3,2 , 3,3,  2  , 3,3,3,3,2 ,3,3 ,3,3,2 , 3,3,3,3,2]         ,

}

def make_feature_extractor(cfg_c,cfg_k):
    feature_extract = []
    in_channels = 3
    i = 1
    for  out_channels , kernel in zip(cfg_c,cfg_k) :
        # print(f"{i} th layer {out_channels} processing")
        if out_channels == "M" :
            feature_extract += [nn.MaxPool2d(kernel,2) ]
        elif out_channels == "LRN":
            feature_extract += [nn.LocalResponseNorm(5,k=2) , nn.ReLU()]
        elif out_channels == 1:
            feature_extract+= [nn.Conv2d(in_channels,out_channels,kernel,stride = 1) , nn.ReLU()]
        else :
            feature_extract+= [nn.Conv2d(in_channels,out_channels,kernel,stride = 1 , padding = 1) , nn.ReLU()]

        if isinstance(out_channels,int) :   in_channels = out_channels
        i+=1
    return nn.Sequential(*feature_extract)


class Model_vgg(nn.Module) :
    def __init__(self,version , num_classes):
        conv_5_out_w ,conv_5_out_h = 7,7
        conv_5_out_dim =512
        conv_1_by_1_1_outchannel = 4096
        conv_1_by_1_2_outchannel = 4096
        self.num_classes = num_classes
        self.linear_out = 4096
        self.xavier_count = 4 
        self.last_xavier= 1  ## if >0 , initialize last 3 fully connected noraml distribution
        # conv_1_by_1_3_outchannel = num_classes
        super().__init__()
        self.feature_extractor = make_feature_extractor(Config_channels[version] , Config_kernel[version])
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))        
        self.output_layer = nn.Sequential(
                             nn.Conv2d(conv_5_out_dim  ,conv_1_by_1_1_outchannel ,7) ,
                             nn.ReLU(),
                             nn.Dropout2d(),
                             nn.Conv2d(conv_1_by_1_1_outchannel ,conv_1_by_1_2_outchannel,1 ) ,
                             nn.ReLU(),
                             nn.Dropout2d(),
                             nn.Conv2d(conv_1_by_1_2_outchannel ,num_classes,1 )
                             )
        # self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        # self.output_layer = nn.Sequential(
        #     nn.Linear(512 * 7 * 7, 4096),
        #     nn.ReLU(),
        #     nn.Dropout(),
        #     nn.Linear(4096, 4096),
        #     nn.ReLU(),
        #     nn.Dropout(),
        #     nn.Linear(4096, num_classes),
        # )
        
        print('weight initialize')
        self.apply(self._init_weights)
        print('weight intialize end')
    def forward(self,x):
        x = self.feature_extractor(x)
        # x= self.avgpool(x)  ##  If Linear is output, use this 
        # x= torch.flatten(x,start_dim = 1) ## If Linear is output, use this 
        x = self.output_layer(x)
        x= self.avgpool(x)
        x= torch.flatten(x,start_dim = 1)
        return x


    @torch.no_grad()
    def _init_weights(self,m):
        
        # print(m)
        if isinstance(m,nn.Conv2d):
            print('-------------')
            print(m.kernel_size)
            print(m.out_channels)
            if (m.out_channels == self.num_classes or m.out_channels == self.linear_out) and self.last_xavier>0 :
                print('xavier')
                # self.last_xavier-=1
                nn.init.xavier_uniform_(m.weight)
            elif self.xavier_count >0 :
                print('xavier')
                nn.init.xavier_uniform_(m.weight)
                self.xavier_count-=1
            else : 
                std = 0.2
                print(f'normal  std : {std}')
                torch.nn.init.normal_(m.weight,std=std)
                # if (m.out_channels == self.num_classes or m.out_channels == self.linear_out) :
                #     self.last_xavier+=1
            if m.bias is not None :
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            nn.init.constant_(m.bias, 0)

    #     pass



In [76]:
model = Model_vgg('B',10)

weight initialize
-------------
(3, 3)
64
xavier
-------------
(3, 3)
64
xavier
-------------
(3, 3)
128
xavier
-------------
(3, 3)
128
xavier
-------------
(3, 3)
256
normal  std : 0.2
-------------
(3, 3)
256
normal  std : 0.2
-------------
(3, 3)
512
normal  std : 0.2
-------------
(3, 3)
512
normal  std : 0.2
-------------
(3, 3)
512
normal  std : 0.2
-------------
(3, 3)
512
normal  std : 0.2
-------------
(7, 7)
4096
xavier
-------------
(1, 1)
4096
xavier
-------------
(1, 1)
10
xavier
weight intialize end


In [6]:
model.state_dict()

AttributeError: 'collections.OrderedDict' object has no attribute 'detach'

In [5]:
chkpoint = torch.load('vgg_B_64_trainmin_256_testmin256_dataset_Cifar10.pt')

In [8]:
chkpoint['optimizer_state_dict']['state']

{0: {'momentum_buffer': tensor([[[[ 4.3574e-03,  4.6046e-03,  4.2620e-03],
            [ 4.3514e-03,  4.4192e-03,  4.8152e-03],
            [ 4.6992e-03,  5.0730e-03,  4.8431e-03]],
  
           [[ 3.2233e-03,  3.4432e-03,  3.7138e-03],
            [ 3.7529e-03,  3.5779e-03,  3.9388e-03],
            [ 3.8271e-03,  4.2433e-03,  3.9027e-03]],
  
           [[ 3.9892e-03,  4.0185e-03,  3.9739e-03],
            [ 4.3434e-03,  4.5734e-03,  4.5853e-03],
            [ 4.4056e-03,  4.4683e-03,  4.7937e-03]]],
  
  
          [[[ 1.0374e-02,  1.0987e-02,  1.1256e-02],
            [ 1.0649e-02,  1.0650e-02,  1.0871e-02],
            [ 1.0511e-02,  1.1177e-02,  1.1057e-02]],
  
           [[ 8.7675e-03,  9.2629e-03,  9.0371e-03],
            [ 8.9317e-03,  9.0869e-03,  9.2312e-03],
            [ 9.3814e-03,  9.6699e-03,  9.5732e-03]],
  
           [[ 3.9653e-03,  3.9669e-03,  4.3949e-03],
            [ 4.2141e-03,  4.1775e-03,  4.8659e-03],
            [ 4.3317e-03,  5.1887e-03,  5.1623e-03]]]

In [66]:
type(chkpoint['model_state_dict'])

collections.OrderedDict

In [30]:
new_dict = {}
for key in chkpoint['model_state_dict'].keys() :
    chkpoint['model_state_dict'][key] = chkpoint['model_state_dict'][key].detach().to('cpu')



In [107]:
temp_pt['model_state_dict'] = OrderedDict([(key,chkpoint['model_state_dict'][key].detach().to('cpu')) for key in chkpoint['model_state_dict'].keys() ])

AttributeError: 'list' object has no attribute 'keys'

In [84]:
temp_pt['optimizer_state_dict']['state'] = [chkpoint['optimizer_state_dict']['state'][key1][key2].detach().to('cpu')for key1 in chkpoint['optimizer_state_dict']['state'].keys() )for key2 in chkpoint['optimizer_state_dict']['state'][key1].keys() ])
        

In [None]:
temp_pt['optimizer_state_dict']['state'] = [chkpoint['optimizer_state_dict']['state'][key1][key2].detach().to('cpu')for key1 in chkpoint['optimizer_state_dict']['state'].keys() )for key2 in chkpoint['optimizer_state_dict']['state'][key1].keys() ])
         

In [88]:
temp_pt={}
for key1 in chkpoint['optimizer_state_dict']['state'].keys()  :
    for key2 in chkpoint['optimizer_state_dict']['state'][key1].keys() :
        chkpoint['optimizer_state_dict']['state'][key1][key2] = chkpoint['optimizer_state_dict']['state'][key1][key2].detach().to('cpu')

In [85]:
type(chkpoint['optimizer_state_dict']['state'])

dict

In [58]:
torch.save(chkpoint,'temp.pt')

In [9]:
temp_pt = torch.load('vgg_B_64_trainmin_256_testmin256_dataset_Cifar.pt')

In [12]:
temp_pt['model_state_dict']

OrderedDict([('feature_extractor.0.weight',
              tensor([[[[ 0.0361,  0.0019,  0.0902],
                        [ 0.0637,  0.0823,  0.0462],
                        [ 0.0955, -0.0035,  0.0687]],
              
                       [[-0.0346, -0.0161,  0.0118],
                        [-0.0290, -0.0685, -0.0085],
                        [-0.0474, -0.0859, -0.0483]],
              
                       [[ 0.0345, -0.0451,  0.0106],
                        [-0.0423, -0.0220, -0.0121],
                        [ 0.0322,  0.0316, -0.0394]]],
              
              
                      [[[ 0.0449,  0.0202,  0.0244],
                        [-0.0516,  0.0632,  0.0541],
                        [ 0.0045,  0.0491,  0.0646]],
              
                       [[-0.0459,  0.0369, -0.0450],
                        [ 0.0201, -0.0594,  0.0711],
                        [ 0.0297,  0.0286,  0.0657]],
              
                       [[ 0.0057,  0.0160,  0.0326],
            

: 

In [8]:
torch.save(temp_pt,'vgg_B_64_trainmin_256_testmin256_dataset_Cifar.pt')

In [77]:
model.load_state_dict(temp_pt['model_state_dict'])

<All keys matched successfully>

In [95]:
from torch import optim
lr=1
weight_decay=1
momentum=1
optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay,momentum=momentum)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max',patience=10,threshold=1e-3,eps = 1e-5)


In [96]:
optimizer

SGD (
Parameter Group 0
    dampening: 0
    differentiable: False
    foreach: None
    lr: 1
    maximize: False
    momentum: 1
    nesterov: False
    weight_decay: 1
)

In [91]:
chkpoint['optimizer_state_dict']

{'state': {0: {'momentum_buffer': tensor([[[[ 4.3574e-03,  4.6046e-03,  4.2620e-03],
             [ 4.3514e-03,  4.4192e-03,  4.8152e-03],
             [ 4.6992e-03,  5.0730e-03,  4.8431e-03]],
   
            [[ 3.2233e-03,  3.4432e-03,  3.7138e-03],
             [ 3.7529e-03,  3.5779e-03,  3.9388e-03],
             [ 3.8271e-03,  4.2433e-03,  3.9027e-03]],
   
            [[ 3.9892e-03,  4.0185e-03,  3.9739e-03],
             [ 4.3434e-03,  4.5734e-03,  4.5853e-03],
             [ 4.4056e-03,  4.4683e-03,  4.7937e-03]]],
   
   
           [[[ 1.0374e-02,  1.0987e-02,  1.1256e-02],
             [ 1.0649e-02,  1.0650e-02,  1.0871e-02],
             [ 1.0511e-02,  1.1177e-02,  1.1057e-02]],
   
            [[ 8.7675e-03,  9.2629e-03,  9.0371e-03],
             [ 8.9317e-03,  9.0869e-03,  9.2312e-03],
             [ 9.3814e-03,  9.6699e-03,  9.5732e-03]],
   
            [[ 3.9653e-03,  3.9669e-03,  4.3949e-03],
             [ 4.2141e-03,  4.1775e-03,  4.8659e-03],
             [ 4.3317

In [97]:
optimizer.load_state_dict(chkpoint['optimizer_state_dict'])

In [103]:
chkpoint['model_state_dict'] = temp_pt['model_state_dict']

In [114]:
chkpoint.keys()

dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict', 'loss', 'steps'])

: 