In [31]:
import os
import zipfile
import torch
import numpy as np
from datetime import datetime
from torchvision import models, transforms
from collections import OrderedDict
import copy

import object3d
from models.MeshNet import MeshNet

In [2]:
class ModelNet40WithImage(torch.utils.data.Dataset):
    """Paired (rendered image, mesh features) samples from a ModelNet40-style tree.

    Files are discovered as ``<data_root>/<label>/<mode>/<name>.npz``; the mesh
    rendered for the image is ``<name>.off`` in the same folder.  Every item
    renders one randomized view of the mesh and loads the pre-computed face
    features, padding or sub-sampling them to a fixed number of faces.

    Note that the class label is collected in ``self.data`` but is *not* part
    of the tuple returned by ``__getitem__``.

    Args:
        cfg: dict providing ``data_root`` and ``augmentation``.  The optional
            ``model3d_size`` entry sets the number of faces per sample
            (defaults to 1024, the value that used to be hard-coded).
        renderer: object exposing ``set_obj``, ``random_context``, ``render``
            and ``close``.  It is stored as-is, not copied.
        mode: name of the split sub-folder.  ``'train'`` selects the random
            crop/flip image transform, any other value the resize + center
            crop transform.
    """

    def __init__(self, cfg, renderer, mode='train'):
        self.root = cfg['data_root']
        self.augmentation = cfg['augmentation']
        self.mode = mode
        # number of faces every mesh is padded / sub-sampled to
        self.num_faces = cfg.get('model3d_size', 1024)

        # (path without extension, label index) for every .npz of the split;
        # labels are indexed in sorted folder-name order
        self.data = []
        for label_index, label in enumerate(sorted(os.listdir(self.root))):
            type_root = os.path.join(self.root, label, mode)
            for filename in os.listdir(type_root):
                if filename.endswith('.npz'):
                    identity = os.path.join(type_root, filename[:-4])
                    self.data.append((identity, label_index))

        # init renderer
        self.renderer = renderer

        # resnet transform (mean / std are the ImageNet statistics expected by
        # torchvision's pretrained models)
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        if mode == 'train':
            self.resnet_transform = transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])
        else:
            self.resnet_transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])

    @staticmethod
    def _choose_neighbors(linked, self_index, count=3):
        """Return exactly ``count`` neighbor indices for one face.

        More than ``count`` candidates are randomly thinned (without
        replacement); fewer are padded with the face's own index.
        """
        if len(linked) > count:
            return np.random.choice(linked, count, replace=False)
        padded = np.full(count, self_index, dtype=np.int64)
        padded[:len(linked)] = linked
        return padded

    def __getitem__(self, i):
        """Load sample ``i``.

        Returns:
            tuple ``(image, centers, corners, normals, neighbors_index)`` where
            ``image`` is the transformed render, ``centers`` / ``normals`` are
            float tensors of shape ``[3, num_faces]``, ``corners`` is a float
            tensor of shape ``[9, num_faces]`` (corner coordinates relative to
            the face center) and ``neighbors_index`` is a long tensor of shape
            ``[num_faces, 3]`` indexing into the sampled faces.

        Raises:
            ValueError: if the mesh file contains no faces.
        """
        identity, label = self.data[i]

        # generate 2D Image
        self.renderer.set_obj(identity + '.off')
        self.renderer.random_context(coverage=0.8, obj_color=True, obj_translate=False, obj_rotation=True)
        pil_image = self.renderer.render(binary=False).convert('RGB')
        torch_image = self.resnet_transform(pil_image)

        # process meshnet data; the context manager closes the archive's file
        # handle as soon as the arrays have been read
        with np.load(identity + '.npz') as data:
            centers = data['centers']  # [face, 3]
            corners = data['corners']  # [face, vertice, 3]
            normals = data['normals']  # [face, 3]
            neighbors_index = data['neighbors_index']  # [face, ?]

        num_point = len(centers)
        if num_point == 0:
            raise ValueError('mesh file has no faces: {}'.format(identity + '.npz'))

        target = self.num_faces
        if num_point < target:
            # too few faces: pad by repeating randomly chosen faces; neighbor
            # ids keep pointing at the original (un-moved) faces
            chosen_indexes = np.random.randint(0, num_point, size=(target - num_point))
            centers = np.concatenate((centers, centers[chosen_indexes]))
            corners = np.concatenate((corners, corners[chosen_indexes]))
            normals = np.concatenate((normals, normals[chosen_indexes]))
            neighbors_index = np.concatenate((neighbors_index, neighbors_index[chosen_indexes]))
            remap = None
        else:
            # enough faces: keep a random subset and remember where every kept
            # face ends up so neighbor ids can be re-indexed
            chosen_indexes = np.random.choice(num_point, size=target, replace=False)
            centers = centers[chosen_indexes]
            corners = corners[chosen_indexes]
            normals = normals[chosen_indexes]
            neighbors_index = neighbors_index[chosen_indexes]
            remap = np.full(num_point, -1, dtype=np.int64)
            remap[chosen_indexes] = np.arange(target)

        # choose 3 neighbors per face
        new_neighbors_index = np.empty([target, 3], dtype=np.int64)
        for idx in range(target):
            linked = np.asarray(neighbors_index[idx], dtype=np.int64)
            if remap is not None:
                # remove neighbors that were not sampled, then re-index the rest
                linked = remap[linked[np.isin(linked, chosen_indexes)]]
            new_neighbors_index[idx] = self._choose_neighbors(linked, idx)
        neighbors_index = new_neighbors_index

        # data augmentation (centers and corners are jittered independently)
        if self.augmentation and self.mode == 'train':
            centers = self.__augment__(centers)
            corners = self.__augment__(corners)

        # make corner relative to center
        corners = corners - centers[:, np.newaxis, :]
        corners = corners.reshape([-1, 9])

        # to tensor, channels first
        centers = torch.from_numpy(centers).float().permute(1, 0).contiguous()
        corners = torch.from_numpy(corners).float().permute(1, 0).contiguous()
        normals = torch.from_numpy(normals).float().permute(1, 0).contiguous()
        neighbors_index = torch.from_numpy(neighbors_index).long()

        return torch_image, centers, corners, normals, neighbors_index

    def __augment__(self, data):
        """Add Gaussian jitter (sigma 0.01, clipped to +/-0.05) to ``data``."""
        sigma, clip = 0.01, 0.05
        jittered_data = np.clip(sigma * np.random.randn(*data.shape), -clip, clip)
        return data + jittered_data

    def __len__(self):
        return len(self.data)

    def close(self):
        """Close the renderer held by this dataset."""
        self.renderer.close()

In [3]:
class Identity(torch.jit.ScriptModule):
    r"""A placeholder identity operator that is argument-insensitive.

    Written as a ScriptModule with a scripted ``forward`` that returns its
    input unchanged.

    Args:
        args: any argument (unused)
        kwargs: any keyword argument (unused)

    Examples::

        >>> m = Identity(54, unused_argument1=0.1, unused_argument2=False)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 20])
    """
    def __init__(self, *args, **kwargs):
        # constructor arguments are accepted for call-site compatibility and ignored
        super(Identity, self).__init__()

    @torch.jit.script_method
    def forward(self, input):
        return input

In [28]:
class Image2ComEmbedding(torch.nn.Module):
    """Image branch: ResNet-50 features followed by a small MLP head.

    The ResNet-50 classifier layer is replaced by ``Identity`` so the backbone
    emits its 2048-d features, which the head maps 2048 -> 512 -> 1024.

    Args:
        cfg: dict with a ``refine`` flag; when it is falsy every backbone
            parameter gets ``requires_grad = False`` (the head stays trainable).
    """

    def __init__(self, cfg):
        super(Image2ComEmbedding, self).__init__()

        # pretrained backbone with the final fc swapped for a pass-through
        backbone = models.resnet50(pretrained=True)
        backbone.fc = Identity()
        self.resnet = backbone

        # projection head into the common embedding space
        head = OrderedDict()
        head['embed_resnet_fc1'] = torch.nn.Linear(2048, 512)
        head['embed_resnet_relu1'] = torch.nn.ReLU6()
        head['embed_resnet_dropout'] = torch.nn.Dropout(p=0.5)
        head['embed_resnet_fc2'] = torch.nn.Linear(512, 1024)
        head['embed_resnet_relu2'] = torch.nn.ReLU6()
        self.embed_resnet = torch.nn.Sequential(head)

        # freeze the backbone unless fine-tuning was requested
        fine_tune = cfg['refine']
        if not fine_tune:
            for weight in self.resnet.parameters():
                weight.requires_grad = False

    def forward(self, images):
        """Return the embedding of a batch of images."""
        return self.embed_resnet(self.resnet(images))
    
class Model3D2ComEmbedding(torch.nn.Module):
    """Mesh branch: MeshNet (built with ``head=False``) followed by an MLP head.

    The head maps the 256-d MeshNet output 256 -> 1280 -> 1024.

    Args:
        cfg: MeshNet configuration dict.  An optional ``pretrained`` entry is a
            checkpoint path loaded non-strictly into MeshNet; a falsy
            ``refine`` entry freezes all MeshNet parameters.
    """

    def __init__(self, cfg):
        super(Model3D2ComEmbedding, self).__init__()
        self.meshnet = MeshNet(cfg, head=False)

        checkpoint_path = cfg.get('pretrained', None)
        if checkpoint_path:
            def place_storage(storage, location):
                # put the weights on the GPU when one is available, else leave them as loaded
                return storage.cuda() if torch.cuda.is_available() else storage

            weights = torch.load(checkpoint_path, map_location=place_storage)
            # strict=False: keys that do not match the model are ignored
            self.meshnet.load_state_dict(weights, strict=False)

        # projection head; the two activation names carry a 'resnet' prefix and
        # are reproduced unchanged
        head = OrderedDict()
        head['embed_meshnet_fc1'] = torch.nn.Linear(256, 1280)
        head['embed_resnet_relu1'] = torch.nn.ReLU6()
        head['embed_meshnet_dropout'] = torch.nn.Dropout(p=0.5)
        head['embed_meshnet_fc2'] = torch.nn.Linear(1280, 1024)
        head['embed_resnet_relu2'] = torch.nn.ReLU6()
        self.embed_meshnet = torch.nn.Sequential(head)

        # freeze MeshNet unless fine-tuning was requested
        fine_tune = cfg['refine']
        if not fine_tune:
            for weight in self.meshnet.parameters():
                weight.requires_grad = False

    def forward(self, centers, corners, normals, neighbor_index):
        """Return the embedding of a batch of meshes."""
        features = self.meshnet(centers, corners, normals, neighbor_index)
        return self.embed_meshnet(features)
    
class IMEmbedding(torch.jit.ScriptModule):
    """Joint image / 3D-model embedding trained with a rank-weighted matching loss.

    Both branches are traced twice, once in eval mode and once in train mode.
    ``forward`` / ``forward_with_loss`` run the train-mode traces, while
    ``forward_eval`` / ``forward_with_loss_eval`` run the eval-mode traces.

    Args:
        cfg: dict with ``image_size`` and ``model3d_size`` (shapes of the
            example inputs used for tracing) plus the ``resnet`` and
            ``meshnet`` sub-configurations handed to the two branches.
    """

    def __init__(self, cfg):
        super(IMEmbedding, self).__init__()
        image_size = cfg['image_size']
        model3d_size = cfg['model3d_size']

        # image branch, traced in both modes with a random example input
        embed_image = Image2ComEmbedding(cfg['resnet'])
        self.embed_image = embed_image
        images = torch.randn(1, 3, image_size, image_size, dtype=torch.float32)
        embed_image.eval()
        self.embed_image_eval = torch.jit.trace(embed_image, [images], check_trace=False)
        embed_image.train()
        self.embed_image_train = torch.jit.trace(embed_image, [images], check_trace=False)

        # mesh branch, traced in both modes with random example inputs
        embed_model3d = Model3D2ComEmbedding(cfg['meshnet'])
        self.embed_model3d = embed_model3d
        centers = torch.randn(1, 3, model3d_size, dtype=torch.float32)
        normals = torch.randn(1, 3, model3d_size, dtype=torch.float32)
        corners = torch.randn(1, 9, model3d_size, dtype=torch.float32)
        neighbor_index = torch.randint(0, model3d_size, (1, model3d_size, 3), dtype=torch.long)
        embed_model3d.eval()
        self.embed_model3d_eval = torch.jit.trace(embed_model3d,
                                  [centers, corners, normals, neighbor_index], check_trace=False)
        embed_model3d.train()
        self.embed_model3d_train = torch.jit.trace(embed_model3d,
                                   [centers, corners, normals, neighbor_index], check_trace=False)

    # Matching loss for a batch in which image i and model i form a pair.
    # For each image (and each model) the gap between the best similarity and
    # the ground-truth similarity is weighted by the rank of the ground truth
    # (rank 0 = already best, contributes nothing).
    @torch.jit.script_method
    def __loss(self, images, model3d):
        # [image, model3d] or [batch, batch]
        im_similarity = torch.matmul(images, model3d.transpose(0, 1))
        # ground-truth
        sim_gt = torch.diagonal(im_similarity)
        # current match
        sorted_image, sarg_image = torch.sort(im_similarity, dim=1, descending=True)
        sorted_model3d, sarg_model3d = torch.sort(im_similarity, dim=0, descending=True)

        top1_image = sorted_image[:, 0]
        top1_model3d = sorted_model3d[0, :]

        # loss weights; the index grid is created on the same device as the
        # similarities so the comparisons below also work on GPU
        batch_size = im_similarity.shape[0]
        range_batch = torch.arange(0, batch_size, dtype=torch.long, device=im_similarity.device)
        grid_image, grid_model3d = torch.meshgrid(range_batch, range_batch)
        # nonzero() lists hits in row-major order, so the image ranks come out
        # ordered by image ...
        gt_image_rank = torch.nonzero(sarg_image == grid_image)[:, 1]
        # ... while the model hits must be transposed first, otherwise the
        # ranks would come out sorted by rank instead of ordered by model
        gt_model3d_rank = torch.nonzero((sarg_model3d == grid_model3d).transpose(0, 1))[:, 1]

        # combine loss with weights
        loss_image = torch.dot((top1_image - sim_gt), gt_image_rank.float() / batch_size)
        loss_model3d = torch.dot((top1_model3d - sim_gt), gt_model3d_rank.float() / batch_size)
        loss = (torch.sum(loss_image) + torch.sum(loss_model3d)) / batch_size

        return loss

    # Embed a batch with the train-mode traces.
    @torch.jit.script_method
    def forward(self, images, centers, corners, normals, neighbor_index):
        embedded_images = self.embed_image_train(images)
        embedded_model3d = self.embed_model3d_train(centers, corners, normals, neighbor_index)

        return embedded_images, embedded_model3d

    # Train-mode embedding followed by the matching loss.
    @torch.jit.script_method
    def forward_with_loss(self, images, centers, corners, normals, neighbor_index):
        embedded_images, embedded_model3d = self.forward(images, centers, corners, normals, neighbor_index)
        loss = self.__loss(embedded_images, embedded_model3d)
        return loss

    # Embed a batch with the eval-mode traces.
    @torch.jit.script_method
    def forward_eval(self, images, centers, corners, normals, neighbor_index):
        embedded_images = self.embed_image_eval(images)
        embedded_model3d = self.embed_model3d_eval(centers, corners, normals, neighbor_index)

        return embedded_images, embedded_model3d

    # Eval-mode embedding followed by the matching loss.
    @torch.jit.script_method
    def forward_with_loss_eval(self, images, centers, corners, normals, neighbor_index):
        embedded_images, embedded_model3d = self.forward_eval(images, centers, corners, normals, neighbor_index)
        loss = self.__loss(embedded_images, embedded_model3d)
        return loss

    def fit(self, dataloaders, optimizer=None, start_epoch=1, end_epochs=30, ckpt_root='ckpt_root'):
        """Run the train / validation loop and checkpoint after every train phase.

        Args:
            dataloaders: dict with ``'train'`` and ``'val'`` DataLoaders that
                yield ``(images, centers, corners, normals, neighbor_index)``.
            optimizer: optimizer to use; defaults to Adam (lr 1e-3) over all
                parameters of this module.
            start_epoch: first epoch number (inclusive).
            end_epochs: last epoch number (inclusive).
            ckpt_root: directory receiving ``<epoch:04d>.pkl`` state dicts; it
                is created when missing.

        Returns:
            list with the validation loss of every epoch that was run.
        """
        if optimizer is None:
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)

        # make sure the checkpoint directory exists before the first save
        if ckpt_root:
            os.makedirs(ckpt_root, exist_ok=True)

        # batches are moved to wherever the model's parameters live
        device = next(self.parameters()).device

        val_loss_hist = []
        for epoch in range(start_epoch, end_epochs + 1):
            print('-' * 60)
            print('Epoch: {} / {}'.format(epoch, end_epochs))
            print('-' * 60)

            for phase in ['train', 'val']:
                if phase == 'train':
                    self.train()
                else:
                    self.eval()

                running_loss = 0.0
                batch_size = dataloaders[phase].batch_size
                dataset_size = len(dataloaders[phase].dataset)
                total_steps = int(dataset_size / batch_size)

                for step, (images, centers, corners, normals, neighbor_index) in enumerate(dataloaders[phase]):
                    optimizer.zero_grad()

                    images = images.to(device)
                    centers = centers.to(device)
                    normals = normals.to(device)
                    corners = corners.to(device)
                    neighbor_index = neighbor_index.to(device)

                    # gradients are only tracked while training
                    with torch.set_grad_enabled(phase == 'train'):
                        if phase == 'train':
                            loss = self.forward_with_loss(images, centers, corners, normals, neighbor_index)
                            loss.backward()
                            optimizer.step()
                        else:
                            loss = self.forward_with_loss_eval(images, centers, corners, normals, neighbor_index)

                        batch_loss = loss.item()
                        running_loss += batch_loss * batch_size

                    print(f'{datetime.now()} {phase} epoch: {epoch}/{end_epochs} '
                          f'step:{step + 1}/{total_steps} loss: {batch_loss:.4f}')

                epoch_loss = running_loss / dataset_size
                print(f'{phase} epoch: {epoch}/{end_epochs} loss: {epoch_loss:.4f}')

                if phase == 'train':
                    filename = os.path.join(ckpt_root, f'{epoch:04d}.pkl')
                    torch.save(copy.deepcopy(self.state_dict()), filename)

                if phase == 'val':
                    val_loss_hist.append(epoch_loss)

        return val_loss_hist
    

In [5]:
# Experiment configuration shared by the datasets and the embedding model.
cfg = dict(
    # image branch: refine=False keeps the ResNet backbone frozen
    resnet=dict(refine=False),
    # mesh branch: MeshNet settings, frozen as well, initialised from a checkpoint
    meshnet=dict(
        refine=False,
        structural_descriptor=dict(num_kernel=64, sigma=0.2),
        mesh_convolution=dict(aggregation_method='Concat'),
        pretrained='./data/meshnet.pkl',
    ),
    batch_size=4,
    # sizes of the example inputs used when tracing the two branches
    image_size=224,
    model3d_size=1024,
    # dataset location and whether train-time jitter is applied
    data_root='./data/ModelNet40',
    augmentation=True,
)


In [7]:
# A single renderer instance is shared by both datasets.
# NOTE(review): num_workers=0 keeps all rendering in this process, presumably
# because the shared renderer cannot be used from worker processes - confirm.
renderer = object3d.Panda3DRenderer(output_size=(224, 224),
                cast_shadow=True, light_on=True)

# Phase name -> DataLoader; the 'val' phase reads the 'test' split on disk.
# The batch size comes from cfg so it cannot drift from the rest of the setup.
loader_dict = {
    phase: torch.utils.data.DataLoader(
        ModelNet40WithImage(cfg, renderer, mode=split),
        batch_size=cfg['batch_size'], num_workers=0,
        shuffle=True, pin_memory=True, drop_last=True)
    for phase, split in (('train', 'train'), ('val', 'test'))
}

In [29]:
# Build the joint embedding model; the constructor traces both branches in eval and train mode.
embedding = IMEmbedding(cfg)

In [32]:
# Train with fit()'s defaults: Adam (lr 1e-3), epochs 1-30, checkpoints written under 'ckpt_root'.
# The returned validation-loss history is not stored.
embedding.fit(loader_dict)

------------------------------------------------------------
Epoch: 1 / 30
------------------------------------------------------------
2019-04-30 18:45:25.845807 train epoch: 1/30 step:1/2460 loss: 14.2821
2019-04-30 18:45:32.353376 train epoch: 1/30 step:2/2460 loss: 10.8488
2019-04-30 18:45:38.947149 train epoch: 1/30 step:3/2460 loss: 22.9048
2019-04-30 18:45:47.431597 train epoch: 1/30 step:4/2460 loss: 8.4489
2019-04-30 18:45:53.719084 train epoch: 1/30 step:5/2460 loss: 16.5312
2019-04-30 18:46:00.358389 train epoch: 1/30 step:6/2460 loss: 13.9870
2019-04-30 18:46:07.237415 train epoch: 1/30 step:7/2460 loss: 7.9650
2019-04-30 18:46:13.551401 train epoch: 1/30 step:8/2460 loss: 17.9350


KeyboardInterrupt: 

In [26]:
# Scratch cell: grab the training DataLoader for inspection.
t = loader_dict['train']

In [27]:
# Scratch cell: display the batch size the training DataLoader was built with.
t.batch_size

4