<a href="https://colab.research.google.com/github/shpotes/tensorflowers/blob/entregable/notebooks/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%cd /content/
!rm -rf tensorflowers
!git clone -b entregable https://github.com/shpotes/tensorflowers.git
%cd tensorflowers
!pip install -qq -r requirements.txt

/content
Cloning into 'tensorflowers'...
remote: Enumerating objects: 426, done.
remote: Counting objects: 100% (426/426), done.
remote: Compressing objects: 100% (285/285), done.
remote: Total 426 (delta 212), reused 306 (delta 117), pack-reused 0
Receiving objects: 100% (426/426), 126.03 MiB | 30.11 MiB/s, done.
Resolving deltas: 100% (212/212), done.
/content/tensorflowers


In [1]:
%cd /content/tensorflowers/
!git pull

/content/tensorflowers
remote: Enumerating objects: 9, done.
remote: Counting objects: 100% (9/9), done.
remote: Compressing objects: 100% (1/1), done.
remote: Total 5 (delta 4), reused 5 (delta 4), pack-reused 0
Unpacking objects: 100% (5/5), done.
From https://github.com/shpotes/tensorflowers
   f619003..5e373ec  entregable -> origin/entregable
Updating f619003..5e373ec
Fast-forward
 src/data/pl_datamodule.py | 32 ++++++++++++++------------------
 1 file changed, 14 insertions(+), 18 deletions(-)


In [2]:
from functools import partial

import timm
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import numpy as np
import torch
import torch.nn as nn
from torchvision import models
from torchvision import transforms as T
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint

from src.data import TFColDataModule, create_train_transformations
from src.model import HydraModule
from src.utils.training_utils import turn_off_bn, load_backbone

Using custom data configuration default


In [3]:
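# Data module: project training augmentations (RandAugment disabled) and a
# resize + ImageNet normalisation pipeline for evaluation.
dm = TFColDataModule(
    image_train_transforms=create_train_transformations(with_rand_augmentation=False),
    image_eval_transforms=T.Compose([
        T.Resize(224),
        T.ToTensor(),
        T.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
    ]),
    batch_size=64,
)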
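
The evaluation pipeline above ends with `T.ToTensor()` followed by `T.Normalize(...)`. The sketch below reproduces that arithmetic for a single pixel in plain Python. The mean and standard deviation are the values timm exports as `IMAGENET_DEFAULT_MEAN` and `IMAGENET_DEFAULT_STD`; `normalize_pixel` is a hypothetical helper for illustration and is not part of the project.

```python
# Values exported by timm.data.constants.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb_uint8, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD):
    """Hypothetical helper: map one 8-bit RGB pixel to the values the model sees."""
    # ToTensor step: rescale each channel from [0, 255] to [0.0, 1.0].
    scaled = [channel / 255.0 for channel in rgb_uint8]
    # Normalize step: subtract the channel mean, divide by the channel std.
    return [(x - m) / s for x, m, s in zip(scaled, mean, std)]


black = normalize_pixel((0, 0, 0))        # every channel becomes negative
white = normalize_pixel((255, 255, 255))  # every channel becomes positive
print(black, white)
```

Black and white pixels land on opposite sides of zero, which is the input range an ImageNet-pretrained backbone expects.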

In [7]:
backbone = load_backbone(
    model_name="resnet50d",
    checkpoint="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50d_a1_0-e20cff14.pth",
)
turn_off_bn(backbone)

model = HydraModule(
    backbone,
    lr=3e-4,
    clf_loss="asl"
)
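
`clf_loss="asl"` selects an asymmetric loss for the multi-label head (the model summary lists it as `AsymmetricLossMultiLabel`). The following is a minimal plain-Python sketch of that loss for one label at a time. It assumes the commonly used settings `gamma_pos=1`, `gamma_neg=4` and a probability margin `clip=0.05`; the values `HydraModule` actually passes are not shown in this notebook, and both function names are hypothetical.

```python
import math


def asl_single_label(logit, target, gamma_pos=1.0, gamma_neg=4.0, clip=0.05, eps=1e-8):
    """Asymmetric loss for one label (sketch; hyperparameters are assumptions)."""
    p = 1.0 / (1.0 + math.exp(-logit))  # sigmoid probability
    if target == 1:
        # Positive label: focal-style down-weighting of confident predictions.
        return -((1.0 - p) ** gamma_pos) * math.log(max(p, eps))
    # Negative label: shift the probability down by the margin, so that
    # easy negatives (p <= clip) contribute no loss at all.
    p_m = max(p - clip, 0.0)
    return -(p_m ** gamma_neg) * math.log(max(1.0 - p_m, eps))


def asl_multi_label(logits, targets, **kwargs):
    """Sum the per-label losses over all labels of one sample."""
    return sum(asl_single_label(z, y, **kwargs) for z, y in zip(logits, targets))


easy_negative = asl_single_label(-4.0, 0)  # p is below the margin
hard_negative = asl_single_label(2.0, 0)   # confident but wrong
positive = asl_single_label(0.0, 1)        # p = 0.5
```

The asymmetry is that negatives are focused more aggressively (larger exponent plus the margin) than positives, which helps when each image has few positive labels among many negatives.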

In [8]:
!wandb login

[34m[1mwandb[0m: Currently logged in as: [33mshpotes[0m (use `wandb login --relogin` to force relogin)


In [9]:
# Log metrics to Weights & Biases and train on a single GPU.
logger = WandbLogger(
    project="challenge",
    name="last resort",
    entity="tensorflowers",
)
trainer = pl.Trainer(
    max_epochs=50,
    gpus=1,
    logger=logger,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [None]:
trainer.fit(model, dm)

Using custom data configuration default


Downloading and preparing dataset tf_col/default to /root/.cache/huggingface/datasets/tf_col/default/1.0.0/3b0e3e3ab9e837479b0682e6476d88f5d3345d1b26d6c0df84c1cc39703fdebb...


Dataset tf_col downloaded and prepared to /root/.cache/huggingface/datasets/tf_col/default/1.0.0/3b0e3e3ab9e837479b0682e6476d88f5d3345d1b26d6c0df84c1cc39703fdebb. Subsequent calls will reuse this data.


Using custom data configuration default
Reusing dataset tf_col (/root/.cache/huggingface/datasets/tf_col/default/1.0.0/3b0e3e3ab9e837479b0682e6476d88f5d3345d1b26d6c0df84c1cc39703fdebb)


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[34m[1mwandb[0m: Currently logged in as: [33mshpotes[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.7 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade



  | Name                | Type                     | Params
-----------------------------------------------------------------
0 | train_metric        | CrossEntropyMetric       | 0     
1 | val_metric          | CrossEntropyMetric       | 0     
2 | feature_extraction  | Sequential               | 23.5 M
3 | classification_head | Sequential               | 595 K 
4 | city_criterion      | CrossEntropyLoss         | 0     
5 | clf_criterion       | AsymmetricLossMultiLabel | 0     
6 | city_head           | Sequential               | 532 K 
-----------------------------------------------------------------
24.7 M    Trainable params
0         Non-trainable params
24.7 M    Total params
98.622    Total estimated model params size (MB)
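
The size estimate in the summary above is consistent with storing every parameter as a 32-bit float. A rough check, assuming 4 bytes per parameter and 1 MB = 10^6 bytes, and using the rounded counts as printed (so the result differs slightly from the reported 98.622 MB):

```python
# Parameter counts copied from the summary table, rounded as printed.
printed_params = {
    "feature_extraction": 23.5e6,
    "classification_head": 595e3,
    "city_head": 532e3,
}

BYTES_PER_FLOAT32 = 4  # assumption: 32-bit precision

total_params = sum(printed_params.values())
size_mb = total_params * BYTES_PER_FLOAT32 / 1e6

print(f"{total_params / 1e6:.3f} M params, about {size_mb:.1f} MB")
```

The backbone accounts for the bulk of the total; the two heads together add a little over one million parameters.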


  f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"


  "Argument interpolation should be of type InterpolationMode instead of int. "