In [None]:
# Mount Google Drive at /content/drive; the checkpoint, test image and ONNX
# output paths used in later cells all live under this mount point.
from google.colab import drive

drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import torch
import torch.nn as nn
from torchvision import models

import h5py
import torch
import shutil

def save_net(fname, net):
    """Dump every entry of ``net.state_dict()`` into an HDF5 file.

    Args:
        fname: Path of the HDF5 file; it is opened in ``'w'`` mode, so an
            existing file at that path is overwritten.
        net: Object exposing ``state_dict()``; each value is moved to the CPU
            and stored as a NumPy array in a dataset named after its key.
    """
    with h5py.File(fname, 'w') as archive:
        state = net.state_dict()
        for name in state:
            array = state[name].cpu().numpy()
            archive.create_dataset(name, data=array)
def load_net(fname, net):
    """Copy tensors stored in an HDF5 file into ``net`` in place.

    For every key of ``net.state_dict()`` the dataset with the same name is
    read from the file and copied into the corresponding tensor.

    Args:
        fname: Path of the HDF5 file to read (opened in ``'r'`` mode).
        net: Object exposing ``state_dict()``; its tensors are overwritten.

    Raises:
        KeyError: If the file has no dataset for one of the state-dict keys.
    """
    # Imported here because the notebook never imports numpy at the top level;
    # without it the np.asarray call below raised NameError.
    import numpy as np

    with h5py.File(fname, 'r') as h5f:
        for k, v in net.state_dict().items():
            # np.asarray materialises the HDF5 dataset as an ndarray so that
            # torch.from_numpy can wrap it.
            param = torch.from_numpy(np.asarray(h5f[k]))
            # state_dict() tensors share storage with the live parameters, so
            # an in-place copy updates the network itself.
            v.copy_(param)

def save_checkpoint(state, is_best, task_id, filename='checkpoint.pth.tar'):
    """Serialize ``state`` with ``torch.save`` and optionally keep a "best" copy.

    Args:
        state: Object handed to ``torch.save``.
        is_best: When truthy, the freshly written file is also copied to
            ``task_id + 'model_best.pth.tar'``.
        task_id: String prefix concatenated directly in front of both file
            names (no path separator is inserted).
        filename: Name of the checkpoint file, appended to ``task_id``.
    """
    checkpoint_path = task_id + filename
    torch.save(state, checkpoint_path)
    if not is_best:
        return
    best_path = task_id + 'model_best.pth.tar'
    shutil.copyfile(checkpoint_path, best_path)

class CSRNet(nn.Module):
    """Convolutional network made of a VGG-16-style frontend, a dilated
    backend and a 1x1 convolution that yields a single-channel map.

    Args:
        load_weights: When False (the default) every layer is initialised by
            ``_initialize_weights`` and the frontend is then overwritten with
            the leading tensors of torchvision's pretrained VGG-16. When True
            that whole step is skipped and the layers keep PyTorch's default
            initialisation until weights are loaded by the caller (e.g. with
            ``load_state_dict``).
    """

    def __init__(self, load_weights=False):
        super(CSRNet, self).__init__()
        # Plain attribute initialised to 0; nothing in this class reads it.
        self.seen = 0
        # Integers are output channel counts of 3x3 convolutions, 'M' is a
        # 2x2 max-pool (see make_layers).
        self.frontend_feat = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512]
        self.backend_feat = [512, 512, 512, 256, 128, 64]
        self.frontend = make_layers(self.frontend_feat)
        self.backend = make_layers(self.backend_feat, in_channels=512, dilation=True)
        self.output_layer = nn.Conv2d(64, 1, kernel_size=1)
        if not load_weights:
            # NOTE(review): newer torchvision releases prefer the `weights=`
            # argument; `pretrained=True` is kept for older installs.
            mod = models.vgg16(pretrained=True)
            self._initialize_weights()
            # The frontend layout mirrors the start of VGG-16, so its
            # state-dict entries line up positionally with the first entries
            # of the VGG-16 state dict. zip() stops once the (shorter)
            # frontend is exhausted. The previous code indexed
            # `state_dict().items()[i]` inside an `xrange` loop, which fails
            # on Python 3 because dict views are not subscriptable.
            frontend_tensors = self.frontend.state_dict().values()
            vgg_tensors = mod.state_dict().values()
            for target, source in zip(frontend_tensors, vgg_tensors):
                # state_dict() tensors share storage with the parameters, so
                # copying in place updates the frontend itself.
                target.copy_(source)

    def forward(self, x):
        """Run ``x`` through frontend, backend and the 1x1 output layer."""
        x = self.frontend(x)
        x = self.backend(x)
        x = self.output_layer(x)
        return x

    def _initialize_weights(self):
        """Initialise Conv2d weights from N(0, 0.01) with zero bias, and
        BatchNorm2d layers with weight 1 and bias 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


def make_layers(cfg, in_channels=3, batch_norm=False, dilation=False):
    """Assemble an ``nn.Sequential`` from a compact layer specification.

    Args:
        cfg: Iterable whose items are either ``'M'`` (a 2x2 max-pool with
            stride 2) or a number giving the output channels of a 3x3
            convolution.
        in_channels: Input channels of the first convolution.
        batch_norm: When true, a ``BatchNorm2d`` is placed between each
            convolution and its ReLU.
        dilation: When true, convolutions use dilation 2 (with padding 2);
            otherwise dilation 1 (with padding 1).

    Returns:
        The layers wrapped in an ``nn.Sequential``, in ``cfg`` order.
    """
    rate = 2 if dilation else 1
    modules = []
    channels = in_channels
    for spec in cfg:
        if spec == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        # Padding equals the dilation rate, which keeps the spatial size
        # unchanged for a 3x3 kernel.
        modules.append(nn.Conv2d(channels, spec, kernel_size=3, padding=rate, dilation=rate))
        if batch_norm:
            modules.append(nn.BatchNorm2d(spec))
        modules.append(nn.ReLU(inplace=True))
        channels = spec
    return nn.Sequential(*modules)

In [None]:
def xrange(x):
    """Python 2 compatibility shim: return ``range(x)``."""
    return range(0, x)

# Read the saved checkpoint from Drive; map_location places its tensors on the
# CPU, so no GPU runtime is needed for the rest of the notebook.
checkpoint = torch.load('/content/drive/MyDrive/CSRNet/PartAmodel_best.pth.tar', map_location=torch.device('cpu'))


In [None]:
# load_weights=True skips the VGG-16 initialisation branch in CSRNet.__init__;
# the weights are loaded from the checkpoint in a later cell.
model = CSRNet(load_weights=True)

In [None]:
# Load the tensors stored under the checkpoint's 'state_dict' key into the model.
model.load_state_dict(checkpoint['state_dict'])

<All keys matched successfully>

In [None]:
!pip install onnx

Collecting onnx
  Downloading onnx-1.14.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (14.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.6/14.6 MB[0m [31m13.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: onnx
Successfully installed onnx-1.14.0


In [None]:
import torch.onnx
import torchvision.transforms.functional as F
from PIL import Image

In [None]:
# Test the model: prepare one image, feed it through the network and print the
# sum of the output map.
path = '/content/drive/MyDrive/CSRNet/IMG_4.jpg'
img = Image.open(path)
w, h = img.size
# Shrink to 90% of the original size; floor division already yields ints.
img = img.resize((w * 9 // 10, h * 9 // 10))
# to_tensor scales pixels to [0, 1]; multiply by 255 to return to the 0-255
# scale before subtracting the per-channel offsets.
# NOTE(review): the offsets look like per-channel dataset means -- confirm
# they match the preprocessing used when the checkpoint was trained.
img = 255.0 * F.to_tensor(img.convert('RGB'))
img[0, :, :] = img[0, :, :] - 92.8207477031
img[1, :, :] = img[1, :, :] - 95.2757037428
img[2, :, :] = img[2, :, :] - 104.877445883
# Add a leading batch dimension; `tensor` is reused by the export cell.
tensor = img.unsqueeze(0)

# Inference only: switch to eval mode and skip building the autograd graph.
model.eval()
with torch.no_grad():
    output = model(tensor)
output = output.detach().cpu().sum().numpy()

print(output)


4286.402


In [None]:
# Step 1: Reuse the preprocessed image batch from the inference cell as the
# example input.
example_input = tensor

# Step 2: Trace the model with the example input to obtain a TorchScript module.
traced_model = torch.jit.trace(model, example_input)

# Step 3: Export to ONNX on Drive.
# The network consists only of convolution, pooling and ReLU layers, so batch
# size and spatial size are marked dynamic instead of being pinned to the shape
# of this one example image. The output axes get their own names because the
# three 2x2 poolings in the frontend make the output smaller than the input.
torch.onnx.export(
    traced_model,
    example_input,
    '/content/drive/MyDrive/CSRNet/csrnet.onnx',
    opset_version=11,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch', 2: 'height', 3: 'width'},
        'output': {0: 'batch', 2: 'out_height', 3: 'out_width'},
    },
)



verbose: False, log level: Level.ERROR

