In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Optional environment-variable overrides, left disabled; uncomment a line to
# set that variable for this kernel.
# %env CUDA_LAUNCH_BLOCKING 1
# %env CUDA_VISIBLE_DEVICES 1
# %env TORCH_CUDNN_V8_API_ENABLED=1

In [3]:
import os
import sys
import torch

# Make the parent directory importable so the `pyroml` imports below resolve
# when the notebook is run from its own folder. The membership check keeps the
# cell idempotent: re-running it does not append a duplicate sys.path entry.
project_root = os.path.abspath("..")
if project_root not in sys.path:
    sys.path.append(project_root)

from pyroml import Trainer
from pyroml.template.cifar100 import Cifar100Dataset
from pyroml.models import Backbone
from pyroml.models.falcon import Falcon, FalconConfig




In [4]:
# Use the first CUDA device when one is available; otherwise fall back to CPU
# so the notebook still runs on a machine without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Only ask about bf16 support when CUDA is present, so the query never has to
# touch a CUDA device on a CPU-only machine.
bf16_supported = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

# Report the CUDA environment and the device selected above.
print(
    f"""
CUDA available: {torch.cuda.is_available()}
CUDA device # : {torch.cuda.device_count()}
CUDA bf16     : {bf16_supported}
CUDA device   : {device}
"""
)


CUDA available: True
CUDA device # : 1
CUDA bf16     : True
CUDA device   : cuda:0



In [20]:
# Build the train and test splits of the Cifar100Dataset, in that order.
tr_ds, te_ds = (Cifar100Dataset(split=split_name) for split_name in ("train", "test"))

# Sanity check: size of each split and the keys of the first training sample.
len(tr_ds), len(te_ds), tr_ds[0].keys()

(50000, 10000, dict_keys(['img', 'fine_label', 'coarse_label']))

In [21]:
# Display how many fine and coarse labels the training set exposes.
tuple(map(len, (tr_ds.fine_labels, tr_ds.coarse_labels)))

(100, 20)

In [None]:
# Wrap the training set in NeighborsWrapper, rebinding `tr_ds` to the wrapper.
# NOTE(review): `NeighborsWrapper` is not imported in any cell visible in this
# notebook, so this cell would raise NameError on a fresh kernel — add the
# import to the imports cell.
# NOTE(review): because `tr_ds` is rebound in place, re-running this cell wraps
# the wrapper again. Later cells also read `tr_ds.fine_labels` and
# `tr_ds.coarse_labels`; presumably the wrapper forwards them — TODO confirm.
tr_ds = NeighborsWrapper(tr_ds)

In [22]:
# The backbone is told to use one class per fine label of the training set.
num_classes = len(tr_ds.fine_labels)

# Arguments for the "resnet50" backbone, collected in one place.
backbone_kwargs = {
    "name": "resnet50",
    "num_classes": num_classes,
    "image_size": (3, 224, 224),
}
backbone = Backbone.load(**backbone_kwargs)

# Show the loaded module structure.
backbone

Backbone has 23,712,932 params


TimmBackbone(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act1): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (drop_block): Identity()
        (act2): ReLU(inplace=True)
        (aa): Identity()
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, m

In [11]:
# Gather the three sizes FalconConfig needs: the first entry of the backbone's
# `last_dim`, and the fine / coarse label counts of the training set.
embed_dim = backbone.last_dim[0]
n_fine = len(tr_ds.fine_labels)
n_coarse = len(tr_ds.coarse_labels)

config = FalconConfig(embed_dim=embed_dim, fine_classes=n_fine, coarse_classes=n_coarse)

# Build the Falcon model from the backbone and the config, then display it.
falcon = Falcon(backbone, config)
falcon

Falcon(
  (backbone): TimmBackbone(
    (model): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act1): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (drop_block): Identity()
          (act2): ReLU(inplace=True)
          (aa): Identity()
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bi

In [19]:
# NOTE(review): this import sits outside the notebook's imports cell; consider
# moving it there so all dependencies are declared in one place.
from pyroml.core.stage import Stage


# Smoke test: call `falcon.step` once with a random tensor of shape
# (1, 3, 224, 224) and the TRAIN stage.
# NOTE(review): the saved output of this cell is
# `TypeError: new(): invalid data type 'str'`, so as written the call fails and
# would stop a Restart & Run All. Presumably `step` expects a batch mapping
# (e.g. with an 'img' key, like the dataset samples) rather than a bare tensor
# — TODO confirm against Falcon.step and fix or remove this cell.
falcon.step(torch.rand(1, 3, 224, 224), stage=Stage.TRAIN)

TypeError: new(): invalid data type 'str'

In [None]:
# Training hyper-parameters, gathered in one mapping so they are easy to tweak.
# The trainer runs on the `device` chosen earlier, uses bfloat16, and has
# wandb turned off.
trainer_options = {
    "lr": 0.001,
    "batch_size": 32,
    "max_epochs": 32,
    "wandb": False,
    "device": device,
    "dtype": torch.bfloat16,
}
trainer = Trainer(**trainer_options)

In [None]:
# Fit `falcon` with the trainer, passing the train split and then the test split.
trainer.fit(falcon, tr_ds, te_ds)