In [2]:
import json
from collections import OrderedDict
from typing import Type, Any, Callable, Union, List, Dict, Optional, cast

import torch
import torch.nn as nn
import tqdm as notebook_tqdm
from torch import Tensor
from torchvision.models.resnet import *
from torchvision.models.resnet import BasicBlock, Bottleneck
from torchvision.models.resnet import ResNet50_Weights

In [3]:
class IntResNet(ResNet):
    """ResNet whose forward pass stops after a chosen child module.

    The full ResNet is still constructed, so every submodule stays registered
    and a complete ``state_dict`` can be loaded into it. ``forward`` runs only
    the registered child modules up to and including ``output_layer``, in
    registration order.

    Args:
        output_layer: name of the last child module to run, e.g. ``'layer4'``.
        *args, **kwargs: forwarded unchanged to ``ResNet.__init__``
            (``block, layers`` plus any keyword options).

    Raises:
        ValueError: if ``output_layer`` is not the name of a child module.
    """

    def __init__(self, output_layer, *args, **kwargs):
        self.output_layer = output_layer
        super().__init__(*args, **kwargs)

        names = list(self._modules.keys())
        if output_layer not in names:
            # Fail at construction instead of silently running every module
            # (including the classifier) and crashing later in forward().
            raise ValueError(
                f"output_layer {output_layer!r} is not a child module; "
                f"expected one of {names}"
            )
        # Child module names in registration order, truncated after output_layer.
        self._layers = names[: names.index(output_layer) + 1]
        # Plain-dict view of the already-registered modules. A dict attribute is
        # not registered by nn.Module, so this does not duplicate parameters.
        self.layers = OrderedDict((name, getattr(self, name)) for name in self._layers)

    def _forward_impl(self, x):
        """Apply the selected child modules to ``x`` in order."""
        for name in self._layers:
            if name == "fc":
                # torchvision's ResNet.forward flattens the pooled features
                # before the linear classifier; do the same so that
                # output_layer='fc' works.
                x = torch.flatten(x, 1)
            x = self.layers[name](x)
        return x

    def forward(self, x):
        return self._forward_impl(x)

In [6]:
def new_resnet(
    arch: str,
    outlayer: str,
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    pretrained: bool,
    progress: bool,
    **kwargs: Any
) -> IntResNet:
    """Build an ``IntResNet`` truncated at ``outlayer``, optionally with pretrained weights.

    Args:
        arch: architecture key that selects the pretrained checkpoint URL,
            e.g. ``'resnet50'``. It must describe the same network as
            ``block``/``layers``/``kwargs``, otherwise the checkpoint will not load.
        outlayer: name of the last child module executed by ``forward``.
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: number of blocks in each of the four stages.
        pretrained: if True, download the checkpoint for ``arch`` and load it.
        progress: show a progress bar while downloading.
        **kwargs: forwarded to the model constructor.

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``pretrained`` is True and ``arch`` has no known checkpoint URL.
    """
    model_urls = {
        'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
        'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
        'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
        'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
        'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
        'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
        'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
        'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
        'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
    }

    # Validate before building the network so a typo fails fast.
    if pretrained and arch not in model_urls:
        raise ValueError(
            f"no pretrained weights known for arch {arch!r}; "
            f"expected one of {sorted(model_urls)}"
        )

    model = IntResNet(outlayer, block, layers, **kwargs)
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(model_urls[arch], progress=progress)
        model.load_state_dict(state_dict)
    return model

In [7]:
# ResNet-50 feature extractor whose forward pass ends after 'layer4';
# the pretrained checkpoint is downloaded on first use.
model = new_resnet(
    'resnet50',
    'layer4',
    Bottleneck,
    [3, 4, 6, 3],
    pretrained=True,
    progress=True,
)
# model = model.to('cuda:0')

Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to C:\Users\Sheryl/.cache\torch\hub\checkpoints\resnet50-19c8e357.pth
100%|██████████| 97.8M/97.8M [00:01<00:00, 54.8MB/s]


In [22]:
from PIL import Image
import requests
from io import BytesIO
import torchvision.transforms as transforms

def get_image_from_url(url, timeout=10.0, normalize=True):
    """Download an image and convert it into a batched, model-ready tensor.

    Args:
        url: HTTP(S) URL of the image.
        timeout: seconds to wait for the server. Without a timeout,
            ``requests.get`` can block indefinitely.
        normalize: apply the ImageNet mean/std normalisation that the
            torchvision pretrained ResNet weights expect.

    Returns:
        A float tensor of shape ``(1, 3, 224, 224)``.

    Raises:
        requests.HTTPError: on a 4xx/5xx response, instead of trying to decode
            an error page as an image.
        requests.RequestException: on connection errors or timeouts.
        PIL.UnidentifiedImageError: if the payload is not a decodable image.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()

    # Force 3 channels: grayscale, palette, RGBA and CMYK images would otherwise
    # produce tensors with a channel count the first convolution cannot accept.
    img = Image.open(BytesIO(response.content)).convert("RGB")

    steps = [transforms.Resize([224, 224]), transforms.ToTensor()]
    if normalize:
        steps.append(transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225]))
    tensor = transforms.Compose(steps)(img)

    # Add the batch dimension expected by the model.
    return tensor.unsqueeze(0)

torch.Size([1, 3, 224, 224])


In [23]:
# Run every non-GIF product image through the truncated ResNet.
# eval(): BatchNorm uses its stored running statistics instead of the current
# (single-image) batch's statistics, and stops overwriting them.
# no_grad(): inference only, so skip building the autograd graph.
model.eval()

# JSON is UTF-8; be explicit so the platform default (e.g. cp1252 on Windows)
# is not used. The `with` block closes the file handle.
with open('items_shuffle_1000.json', encoding='utf-8') as f:
    json_data = json.load(f)

with torch.no_grad():
    for item in json_data:
        for url in item['images']:
            if ".gif" in url:
                continue
            try:
                img = get_image_from_url(url)
                out = model(img)
            except Exception as exc:
                # Skip unreachable or undecodable images, but report which one
                # and why. A bare `except:` would also swallow KeyboardInterrupt.
                print(f"error: {url}: {exc!r}")
                continue
            # NOTE(review): `out` is overwritten on every iteration, so only the
            # last image's features survive this cell. Collect them (e.g. moved
            # to CPU, keyed by url) if they are needed downstream.