In [1]:
import timm 
import torch

model = timm.create_model('resnet34')
x     = torch.randn(1, 3, 224, 224)
model(x).shape

torch.Size([1, 1000])

In [2]:
pretrained_resnet_34 = timm.create_model('resnet34', pretrained=True)

In [3]:
import timm 
import torch

model = timm.create_model('resnet34', num_classes=10)
x     = torch.randn(1, 3, 224, 224)
model(x).shape

torch.Size([1, 10])

# Список моделей с предварительно подготовленными весами

In [4]:
avail_pretrained_models = timm.list_models(pretrained=True)
len(avail_pretrained_models), avail_pretrained_models[:5]

(1163,
 ['bat_resnext26ts.ch_in1k',
  'beit_base_patch16_224.in22k_ft_in22k',
  'beit_base_patch16_224.in22k_ft_in22k_in1k',
  'beit_base_patch16_384.in22k_ft_in22k_in1k',
  'beit_large_patch16_224.in22k_ft_in22k'])

# Поиск архитектур моделей по шаблону

In [5]:
all_densenet_models = timm.list_models('*densenet*')
all_densenet_models

['densenet121',
 'densenet161',
 'densenet169',
 'densenet201',
 'densenet264d',
 'densenetblur121d']

# Точная настройка модели timm в fastai

In [6]:
'''
from fastai.vision.all import *

path = untar_data(URLs.PETS)/'images'
dls = ImageDataLoaders.from_name_func(
    path, get_image_files(path), valid_pct=0.2,
    label_func=lambda x: x[0].isupper(), item_tfms=Resize(224))
    
# if a string is passed into the model argument, it will now use timm (if it is installed)
learn = vision_learner(dls, 'vit_tiny_patch16_224', metrics=error_rate)

learn.fine_tune(1)
'''

"\nfrom fastai.vision.all import *\n\npath = untar_data(URLs.PETS)/'images'\ndls = ImageDataLoaders.from_name_func(\n    path, get_image_files(path), valid_pct=0.2,\n    label_func=lambda x: x[0].isupper(), item_tfms=Resize(224))\n    \n# if a string is passed into the model argument, it will now use timm (if it is installed)\nlearn = vision_learner(dls, 'vit_tiny_patch16_224', metrics=error_rate)\n\nlearn.fine_tune(1)\n"

# Создайте любую модель из timm.

In [8]:
import timm 
import torch
import segmentation_models_pytorch as smp
import cv2
import numpy as np
pretrained_resnet_34 = timm.create_model('convnext_nano',features_only=False, pretrained=True)
def img2tensor(image: np.ndarray,
               mean: list = None, std: list = None,
               size: int = 224) -> torch.Tensor:
    if mean is None:
        mean = mean=[0.485, 0.456, 0.406]
    if std is None:
        std=[0.229, 0.224, 0.225]
    t = cv2.resize(image, (size, size), interpolation=cv2.INTER_AREA)
    t = torch.from_numpy(t.astype(np.float32) / 255.0) # 0..255 -> 0..1
    t = t.permute(2, 0, 1)  # HWC -> CHW
    _m = torch.FloatTensor(mean).unsqueeze(1).unsqueeze(1) # Cx1x1
    _s =  torch.FloatTensor(std).unsqueeze(1).unsqueeze(1)
    t = (t - _m) / (_s + 1E-7)
    return t

Downloading:   0%|          | 0.00/62.4M [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


# Напишите код запуска модели на пользовательском изображении

In [9]:
import cv2
import torch
import urllib
import os
import numpy as np
from matplotlib import pyplot as plt
import segmentation_models_pytorch as smp
import timm
from tqdm.notebook import tqdm
%matplotlib inline
import urllib.request
'''

img_url = 'https://shack.explorer-russia.ru/gallery/auto/modification/2949.jpg'
req = urllib.request.urlopen(img_url)
arr = np.asarray(bytearray(req.read()), dtype=np.uint8)
img = cv2.imdecode(arr, -1)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

plt.figure(figsize=(12, 10))
plt.imshow(img)
plt.show()
'''
image = cv2.imread("756ac8864b07aaf0393073b2741213ad.jpeg")
x     = img2tensor(image,size=512).unsqueeze(0)
pretrained_resnet_34(x).shape

torch.Size([1, 1000])

In [10]:
import cv2 as cv
img = cv2.imread('756ac8864b07aaf0393073b2741213ad.jpeg')
while True: 
    cv.imshow('ImageDisplay',img)
    if cv.waitKey(20) & 0xFF == ord('d'):
        break
cv.destroyAllWindows()

In [11]:
print(x.shape)
# Option 1
with torch.no_grad():
    out = model(x)

print(out.shape)
print(out.max())

torch.Size([1, 3, 512, 512])
torch.Size([1, 10])
tensor(0.5682)


In [12]:
probabilities = torch.nn.functional.softmax(out[0], dim=0)  # BS is 1 => take 0 elem
print(probabilities.max())
print(probabilities.sum())

tensor(0.1513)
tensor(1.0000)


In [13]:
url = 'https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt' 
urllib.request.urlretrieve(url, 'imagenet_classes.txt')

('imagenet_classes.txt', <http.client.HTTPMessage at 0x17072cd9f90>)

In [14]:
import torch
with torch.no_grad():
    out = pretrained_resnet_34(x)
probabilities = torch.nn.functional.softmax(out[0], dim=0)
print(probabilities.shape)

torch.Size([1000])


In [15]:
# Get imagenet class mappings
with open("imagenet_classes.txt", "r") as f:
    categories = [s.strip() for s in f.readlines()]

In [18]:
# Print top categories per image
top5_prob, top5_catid = torch.topk(probabilities, 5)
for i in range(top5_prob.size(0)):
    print(categories[top5_catid[i]], top5_prob[i].item())

beagle 0.3965320885181427
fountain 0.006739823613315821
tennis ball 0.004441143479198217
Walker hound 0.0042942119762301445
Chihuahua 0.003589363768696785


In [None]:
# потом дописать классификацию и трассировку

## Создайте модель с помощью SMP с encoder частью из timm 

In [None]:
model = smp.Unet(
    encoder_name="resnet50",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=1,                      # model output channels (number of classes in your dataset)
    activation='sigmoid'
)
_=model.eval()
torch.save(model.encoder.state_dict(),'resnet50_1.pt')
model2= timm.create_model('resnet50',features_only=True,checkpoint_path='resnet50_1.pt')

### повторим для другой модели

In [None]:
model3 = smp.Unet(
    encoder_name="resnet50",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="ssl",     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=1,                      # model output channels (number of classes in your dataset)
    activation='sigmoid'
)
_=model3.eval()
torch.save(model.encoder.state_dict(),'resnet50_2.pt')
model4= timm.create_model('resnet50',features_only=True,checkpoint_path='resnet50_2.pt')