In [1]:
import timm
import torch
import torchvision
from timm.optim.optim_factory import create_optimizer
from types import SimpleNamespace
from timm.data.transforms_factory import create_transform
from timm.data import ImageDataset
from torch.nn import CrossEntropyLoss
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader
from tqdm import tqdm

In [3]:
create_transform(input_size=224, is_training=True, no_aug=False, hflip=0.5, vflip=0.4)

Compose(
    RandomResizedCropAndInterpolation(size=(224, 224), scale=(0.08, 1.0), ratio=(0.75, 1.3333), interpolation=bilinear)
    RandomHorizontalFlip(p=0.5)
    RandomVerticalFlip(p=0.4)
    ColorJitter(brightness=[0.6, 1.4], contrast=[0.6, 1.4], saturation=[0.6, 1.4], hue=None)
    ToTensor()
    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)

In [18]:
import warnings
warnings.filterwarnings('ignore', '.*interpolation.*', )

model = timm.create_model('resnet50', pretrained=True, num_classes=1081)

create_transform_custom = create_transform(input_size=224, is_training=True, no_aug=False, hflip=0.5, vflip=0.4)

dataset = ImageDataset(r"C:\Users\lulu5\Documents\echantillon_2", 
                       transform=create_transform_custom)
loader = DataLoader(dataset, batch_size=100, shuffle=True, num_workers=16)



In [25]:
dataiter = iter(loader)



In [26]:
images, labels = dataiter.next()

In [30]:
from matplotlib.pyplot import imshow

In [32]:
img_grid = torchvision.utils.make_grid(images)

# show images
imshow(img_grid, one_channel=True)

In [19]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act3): ReLU(inplace=True)
      (downsample): Sequen

In [20]:
args = SimpleNamespace()
args.weight_decay = 0
args.lr = 1e-4
args.opt = 'sgd' #'lookahead_adam' to use `lookahead`
args.momentum = 0.9

optimizer = create_optimizer(args, model)
criterion = CrossEntropyLoss()

scheduler = MultiStepLR(optimizer, milestones=[30, 40], gamma=0.1)

In [21]:
for epoch in range(1):
    print(epoch)
    running_loss = 0
    for inputs, targets in tqdm(loader):
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        running_loss += loss.item()
    print('Loss:', running_loss/len(loader))
    scheduler.step()

  0%|          | 0/22 [00:00<?, ?it/s]

0


100%|██████████| 22/22 [12:20<00:00, 33.67s/it]

Loss: 6.99011924050071





In [22]:
torch.save(model.state_dict(),'model.torch')

In [9]:
torch.load('model.torch')

OrderedDict([('conv1.weight',
              tensor([[[[ 8.3267e-03,  2.8562e-02,  1.9534e-02,  ..., -4.8805e-02,
                         -9.6412e-03,  1.0822e-02],
                        [-3.6440e-02,  7.3036e-02, -6.1102e-04,  ..., -1.0023e-01,
                         -6.7153e-02,  3.0907e-02],
                        [-6.2676e-03,  1.1590e-01,  9.7607e-02,  ..., -1.9845e-02,
                         -1.8217e-03,  9.7529e-02],
                        ...,
                        [ 1.6272e-02,  2.7184e-01,  5.4999e-01,  ...,  5.5690e-01,
                          6.5946e-01,  5.1294e-01],
                        [-1.4850e-02,  1.1854e-01,  2.1143e-01,  ...,  3.3958e-01,
                          3.8925e-01,  3.1459e-01],
                        [-4.7336e-02,  1.3812e-02, -4.5144e-02,  ..., -9.2893e-02,
                         -1.6903e-02,  1.0397e-01]],
              
                       [[ 3.8145e-02,  1.0174e-01, -1.0503e-02,  ..., -9.1465e-02,
                         -1.8382