<a href="https://colab.research.google.com/github/shpotes/tensorflowers/blob/inference/notebooks/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%cd /content/
!rm -rf tensorflowers
!git clone -b entregable https://github.com/shpotes/tensorflowers.git
%cd tensorflowers
!pip install -qq -r requirements.txt

/content
Cloning into 'tensorflowers'...
remote: Enumerating objects: 454, done.[K
remote: Counting objects: 100% (454/454), done.[K
remote: Compressing objects: 100% (309/309), done.[K
remote: Total 454 (delta 229), reused 320 (delta 121), pack-reused 0[K
Receiving objects: 100% (454/454), 126.04 MiB | 30.30 MiB/s, done.
Resolving deltas: 100% (229/229), done.
/content/tensorflowers


In [2]:
# Switch the kernel's working directory to the repository checkout and pull
# the latest commits of the currently checked-out branch.
%cd /content/tensorflowers/
!git pull

/content/tensorflowers
Already up to date.


In [3]:
from functools import partial

import timm
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import numpy as np
import torch
import torch.nn as nn
from torchvision import models
from torchvision import transforms as T
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint

from src.data import TFColDataModule, create_train_transformations
from src.model import HydraModule
from src.utils.training_utils import turn_off_bn, load_backbone

Using custom data configuration default


In [4]:
# Training-time transforms come from the project helper, with the
# `with_rand_augmentation` flag enabled.
train_transforms = create_train_transformations(with_rand_augmentation=True)

# Evaluation-time transforms: resize to 224x224, convert to a tensor, then
# normalize with timm's ImageNet default mean / std constants.
eval_transforms = T.Compose(
    [
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
    ]
)

# Data module wiring both pipelines together with a batch size of 64.
dm = TFColDataModule(
    image_train_transforms=train_transforms,
    image_eval_transforms=eval_transforms,
    batch_size=64,
)

In [5]:
# Backbone architecture name and the checkpoint URL handed to `load_backbone`.
BACKBONE_NAME = "resnet50d"
BACKBONE_CHECKPOINT_URL = "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50d_a1_0-e20cff14.pth"

backbone = load_backbone(model_name=BACKBONE_NAME, checkpoint=BACKBONE_CHECKPOINT_URL)
# NOTE(review): presumably disables/freezes the backbone's batch-norm layers
# in place -- confirm against src.utils.training_utils.turn_off_bn.
turn_off_bn(backbone)

# Lightning module built on top of the backbone; `clf_loss="asl"` selects the
# classification loss by name (see src.model.HydraModule for accepted values).
model = HydraModule(backbone, lr=3e-4, clf_loss="asl")

In [6]:
# Colab-only import: mount Google Drive, since the checkpoint directory used
# below lives under the mount point.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

# Keep the 3 checkpoints with the lowest "val_cross_entropy_loss".
# NOTE(review): this callback has no effect unless it is handed to the
# Trainer through `callbacks=[...]`; also confirm that the module actually
# logs a metric under this exact name.
checkpoint_callback = ModelCheckpoint(
    dirpath="/content/drive/MyDrive/BORRAR",
    filename="resnet50d-asl-randaug-{epoch:02d}-{val_cross_entropy_loss:.2f}",
    monitor="val_cross_entropy_loss",
    mode="min",
    save_top_k=3,
)

Mounted at /content/drive


In [7]:
# Log in to Weights & Biases from the shell; the notebook imports WandbLogger
# for experiment logging, which needs an authenticated session.
!wandb login

[34m[1mwandb[0m: Currently logged in as: [33mshpotes[0m (use `wandb login --relogin` to force relogin)


In [8]:
# Log the run to Weights & Biases: project "challenge", entity "tensorflowers".
logger = WandbLogger(
    project="challenge",
    name="resnet50d-asl-randaug",
    entity="tensorflowers",
)

# Single-GPU trainer. The ModelCheckpoint callback built earlier in the
# notebook must be registered here; otherwise its monitor / top-k / dirpath
# settings are never applied and no checkpoint lands in the Drive folder.
trainer = pl.Trainer(
    gpus=1,
    logger=logger,
    callbacks=[checkpoint_callback],
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [None]:
# Launch training; the data module is passed through the explicit keyword.
trainer.fit(model, datamodule=dm)