<a href="https://colab.research.google.com/github/yutakasawai/Pytorch_sandbox/blob/master/timm_Catalyst_21_42.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip list | grep catalyst

catalyst                      21.4.2        


In [2]:
!pip install catalyst



In [3]:
!pip install git+https://github.com/rwightman/pytorch-image-models

Collecting git+https://github.com/rwightman/pytorch-image-models
  Cloning https://github.com/rwightman/pytorch-image-models to /tmp/pip-req-build-536dvtne
  Running command git clone -q https://github.com/rwightman/pytorch-image-models /tmp/pip-req-build-536dvtne
Building wheels for collected packages: timm
  Building wheel for timm (setup.py) ... [?25l[?25hdone
  Created wheel for timm: filename=timm-0.4.8-cp37-none-any.whl size=344961 sha256=6eee62ec0431f102630de31401a15c124af2a5826cce26536dc038e196300fe9
  Stored in directory: /tmp/pip-ephem-wheel-cache-0apm2nu8/wheels/32/13/1e/cefd77fe01c775b407b9cdbc45a18e6805e57a395239922f72
Successfully built timm


In [4]:
import os
import random

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
from torchvision import datasets,transforms,models
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from catalyst import dl, utils,metrics
from catalyst.dl import AccuracyCallback
from catalyst.callbacks.misc import EarlyStoppingCallback

import timm


In [5]:
def seed_everything(seed=0):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's legacy global RNG, and torch's CPU and
    all CUDA generators, then puts cuDNN into deterministic mode.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    # Only affects Python processes started after this call; the running
    # interpreter's hash seed is fixed at startup.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all seeds every visible GPU, not just the current device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode autotunes conv algorithms per run and may select
    # non-deterministic kernels, undermining `deterministic` above.
    torch.backends.cudnn.benchmark = False


seed_everything(1192)

In [6]:
# Let Catalyst pick the available accelerator (CUDA on this runtime, per the
# output below; presumably CPU when no GPU is attached).
device = utils.get_device()
device

device(type='cuda')

In [7]:
# The backbone is loaded with ImageNet-pretrained weights, so inputs are
# normalised with the ImageNet channel statistics it was trained on
# (the values timm uses as IMAGENET_DEFAULT_MEAN / IMAGENET_DEFAULT_STD).
IMAGE_SIZE = 224
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
DATA_ROOT = "./data"

val_transform = transforms.Compose([
    transforms.Resize([IMAGE_SIZE, IMAGE_SIZE]),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

train_transform = transforms.Compose([
    # Pad-and-crop at CIFAR's native 32x32 resolution so the classic 4-pixel
    # jitter is a real shift; done after upsampling to 224 it moves the image
    # by under 2% and barely augments anything.
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.Resize([IMAGE_SIZE, IMAGE_SIZE]),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

train_dataset = datasets.CIFAR10(root=DATA_ROOT, train=True, download=True, transform=train_transform)
val_dataset = datasets.CIFAR10(root=DATA_ROOT, train=False, download=True, transform=val_transform)

Files already downloaded and verified
Files already downloaded and verified


In [8]:
# timm registry name of the backbone; read by the model-building code below.
model_name = "resnet18d"


In [9]:
class timmModel(nn.Module):
    """Wrap a timm backbone and replace its final linear layer with a new one.

    Args:
        num_classes: number of output logits of the new classification head.
        name: timm registry name; defaults to the notebook-level ``model_name``.
        pretrained: whether timm should load pretrained weights.

    Raises:
        ValueError: if none of the known head attributes holds an ``nn.Linear``.
            In that case the head cannot be swapped, and the model would
            otherwise silently keep its original output size.
    """

    # Attribute names that different model families use for the final
    # linear layer, checked in this order.
    _HEAD_ATTRS = ("fc", "_fc", "classifier", "last_linear")

    def __init__(self, num_classes, name=None, pretrained=True):
        super().__init__()
        # Backbone definition.
        self.model = timm.create_model(name or model_name, pretrained=pretrained)

        # Redefine the final layer, keeping its input width.
        for attr in self._HEAD_ATTRS:
            head = getattr(self.model, attr, None)
            if isinstance(head, nn.Linear):
                setattr(self.model, attr, nn.Linear(head.in_features, num_classes))
                break
        else:
            raise ValueError(
                f"{type(self.model).__name__} has no nn.Linear head in {self._HEAD_ATTRS}"
            )

    def forward(self, x):
        """Return the wrapped backbone's output for input batch ``x``."""
        return self.model(x)

# resolve_data_config reads `default_cfg` from a model *instance*; given the bare
# name string it silently falls back to generic defaults. An instance built
# without pretrained weights is enough to expose the model's config.
data_config = timm.data.resolve_data_config(
    {}, model=timm.create_model(model_name, pretrained=False), verbose=True
)

class EfficientNet_b0(nn.Module):
    """timm EfficientNet-B0 with its ``classifier`` replaced by an ``n_out`` head.

    Args:
        n_out: number of output logits.
        name: timm registry name; defaults to ``"efficientnet_b0"``, the model
            this class is named for. Previously the class built the notebook's
            global ``model_name`` (a ResNet) and never replaced its real head.
        pretrained: whether timm should load pretrained weights.

    Raises:
        ValueError: if the created model has no ``nn.Linear`` ``classifier``.
    """

    def __init__(self, n_out, name="efficientnet_b0", pretrained=True):
        super().__init__()
        self.effnet = timm.create_model(name, pretrained=pretrained)
        head = getattr(self.effnet, "classifier", None)
        if not isinstance(head, nn.Linear):
            raise ValueError(f"{name!r} has no nn.Linear `classifier` head to replace")
        # Use the head's real input width instead of a hard-coded 1280.
        self.effnet.classifier = nn.Linear(head.in_features, n_out)

    def forward(self, x):
        """Return the wrapped network's output for input batch ``x``."""
        return self.effnet(x)

In [10]:
# One output logit per CIFAR-10 class; move parameters to the selected device.
model = timmModel(num_classes=10).to(device)

In [11]:
# Layer-by-layer architecture dump (the module's string form).
print(str(model))

timmModel(
  (model): ResNet(
    (conv1): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
      (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act1): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_

In [12]:
# Names of the wrapper's direct submodules, then the module itself.
for child_name, _child in model.named_children():
    print(child_name)

model


In [13]:
# Short display labels for CIFAR-10 class indices 0-9, in the dataset's label order.
names = tuple("plane car bird cat deer dog frog horse ship truck".split())

In [14]:
# Pinned host memory only helps host->GPU copies, so it is enabled only with CUDA.
use_cuda = torch.cuda.is_available()
kwargs = dict(num_workers=0, pin_memory=True) if use_cuda else {}

BATCH_SIZE = 64
# Shuffle the training split each epoch; keep validation order fixed.
train_dataloader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, **kwargs
)
val_dataloader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False, **kwargs
)


In [15]:
criterion = nn.CrossEntropyLoss().to(device)
# Adam over every parameter (backbone and new head) at lr=1e-3, Adam's default.
# To train the backbone more gently, pass parameter groups instead, e.g. the head
# via model.model.fc.parameters() at 1e-3 and the remaining parameters at 1e-4.
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [16]:
import collections

# Insertion order matters: "train" runs before "valid" each epoch.
loaders = collections.OrderedDict([
    ("train", train_dataloader),
    ("valid", val_dataloader),
])

In [17]:
from catalyst.dl import SupervisedRunner

# Keys spelled out explicitly (SupervisedRunner's defaults): model outputs land
# under "logits" and labels under "targets", the keys AccuracyCallback reads.
runner = SupervisedRunner(
    input_key="features", output_key="logits", target_key="targets", loss_key="loss"
)

In [18]:
import shutil

# Start from an empty log directory so checkpoints/metrics from earlier runs are
# not mixed in. Portable pure-Python equivalent of `rm -rf ./logs` (no shell
# needed; a missing directory is not an error).
shutil.rmtree("./logs", ignore_errors=True)

In [None]:
runner.train(
    model=model,
    # Use the GPU when present, but fall back to CPU instead of crashing on a
    # runtime without one (the original hard-coded "cuda:0").
    engine=dl.DeviceEngine("cuda:0" if torch.cuda.is_available() else "cpu"),
    optimizer=optimizer,
    criterion=criterion,
    loaders=loaders,
    # "logits"/"targets" are the runner's output and target keys.
    callbacks=[AccuracyCallback(input_key="logits", target_key="targets", num_classes=10)],
    valid_loader="valid",
    # Epochs are ranked by top-1 validation accuracy; higher is better.
    valid_metric="accuracy01",
    minimize_valid_metric=False,
    logdir="./logs",
    num_epochs=5,
    verbose=True,
)

HBox(children=(FloatProgress(value=0.0, description='1/5 * Epoch (train)', max=782.0, style=ProgressStyle(desc…

  for k, v in self.batch_metrics.items()
