In [7]:
# Star imports come first so the explicit imports below take precedence on any name clash.
from dl_toolbox.lightning_modules import *
from dl_toolbox.networks import *
from dl_toolbox.torch_datasets import *
from torch.utils.data import DataLoader, Subset, ConcatDataset
from dl_toolbox.torch_collate import CustomCollate
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import Trainer, seed_everything

In [8]:
# Configuration + data setup for a small NWPU-RESISC45 classification run.
# Every path / size a reader may want to tune is a constant here, so the
# code below contains no magic numbers.
# NOTE(review): absolute, user-specific paths -- adapt to your environment.
OUTPUT_DIR = '/d/pfournie/ai4geo/outputs'
RESISC_PATH = '/d/pfournie/ai4geo/data/NWPU-RESISC45'
SEED = 0

# NWPU-RESISC45: 45 classes x 700 images. The index arithmetic below assumes
# ResiscDs enumerates images class by class, i.e. image j of class c sits at
# flat index IMGS_PER_CLASS * c + j -- TODO confirm against ResiscDs.
N_CLASSES = 45
IMGS_PER_CLASS = 700
N_TRAIN_PER_CLASS = 50  # train on images [0, 50) of each class
N_VAL_PER_CLASS = 50    # validate on the next 50 images of each class

MAX_EPOCHS = 100
BATCH_SIZE = 16
NUM_WORKERS = 4

# Seed python/numpy/torch RNGs (and dataloader worker processes) so that the
# shuffling, the 'd4' augmentations and the network initialisation done in the
# next cell are reproducible on a fresh kernel.
seed_everything(SEED, workers=True)


def per_class_indices(start, stop):
    """Return the flat dataset indices of images ``start .. stop-1`` of every class.

    The result is class-major: all selected images of class 0, then class 1, ...
    Raises AssertionError if the range would spill into a neighbouring class.
    """
    assert 0 <= start <= stop <= IMGS_PER_CLASS, (start, stop)
    return [
        IMGS_PER_CLASS * cls + j
        for cls in range(N_CLASSES)
        for j in range(start, stop)
    ]


logger = TensorBoardLogger(OUTPUT_DIR, name='test_module_CE')

trainer = Trainer(
    max_epochs=MAX_EPOCHS,
    logger=logger,
    callbacks=[
        ModelCheckpoint(),
    ],
    gpus=1,
)

# Train and validation subsets are disjoint by construction: consecutive,
# non-overlapping per-class index ranges. Only validation is left un-augmented.
train_set = Subset(
    dataset=ResiscDs(data_path=RESISC_PATH, img_aug='d4'),
    indices=per_class_indices(0, N_TRAIN_PER_CLASS),
)

val_set = Subset(
    dataset=ResiscDs(data_path=RESISC_PATH, img_aug='no'),
    indices=per_class_indices(
        N_TRAIN_PER_CLASS, N_TRAIN_PER_CLASS + N_VAL_PER_CLASS
    ),
)

train_dataloader = DataLoader(
    dataset=train_set,
    batch_size=BATCH_SIZE,
    collate_fn=CustomCollate(),
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    drop_last=True,  # keep every training batch the same size
)

val_dataloader = DataLoader(
    dataset=val_set,
    batch_size=BATCH_SIZE,
    collate_fn=CustomCollate(),
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [9]:
# Build the cross-entropy Lightning module on a VGG backbone:
# 3 input channels (RGB), one output per RESISC45 class (45).
# NOTE(review): the exact meaning of `weights=[]`, `ignore_index` and the
# initial_lr/final_lr pair is defined inside dl_toolbox's CE module, which is
# not visible here -- presumably an empty list means "no class weighting" and
# the learning rate is scheduled from initial_lr towards final_lr. TODO confirm.
ce_hparams = dict(
    network='Vgg',
    in_channels=3,
    out_channels=45,
    weights=[],
    ignore_index=-1,
    initial_lr=0.001,
    final_lr=0.0005,
)
module = CE(**ce_hparams)

# Potentially long-running: trains for up to the trainer's `max_epochs`,
# validating on `val_dataloader` during training.
trainer.fit(
    model=module,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type             | Params
---------------------------------------------
0 | network | Vgg              | 1.8 M 
1 | loss    | CrossEntropyLoss | 0     
---------------------------------------------
1.8 M     Trainable params
0         Non-trainable params
1.8 M     Total params
7.068     Total estimated model params size (MB)


Validation sanity check:  50%|█████████████████████████████████████████████████                                                 | 1/2 [00:02<00:02,  2.31s/it]



Epoch 0:  50%|█████████████████████████████████████████████▎                                             | 140/281 [00:06<00:07, 20.04it/s, loss=3.4, v_num=0]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                                                                                                     | 0/141 [00:00<?, ?it/s][A
Epoch 0:  51%|█████████████████████████████████████████████▉                                             | 142/281 [00:08<00:07, 17.71it/s, loss=3.4, v_num=0][A
Epoch 0:  54%|████████████████████████████████████████████████▉                                          | 151/281 [00:08<00:06, 18.60it/s, loss=3.4, v_num=0][A
Epoch 0:  57%|███████████████████████████████████████████████████▊                                       | 160/281 [00:08<00:06, 19.29it/s, loss=3.4, v_num=0][A
Epoch 0:  60%|██████████████████████████████████████████████████████▋                                    | 169/281 [00:08<00:05, 20.00it/s, loss=3.4, v_num=0]



Epoch 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 281/281 [00:12<00:00, 22.53it/s, loss=3.4, v_num=0]
Epoch 1:  50%|████████████████████████████████████████████▊                                             | 140/281 [00:05<00:05, 24.94it/s, loss=3.24, v_num=0][A
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                                                                                                     | 0/141 [00:00<?, ?it/s][A
Epoch 1:  51%|██████████████████████████████████████████████                                            | 144/281 [00:05<00:05, 24.55it/s, loss=3.24, v_num=0][A
Epoch 1:  54%|█████████████████████████████████████████████████                                         | 153/281 [00:05<00:04, 25.60it/s, loss=3.24, v_num=0][A
Epoch 1:  58%|███████████████████████████████████████████████████▉                                      | 162/281 [00:06<00:04, 26.60it/s, loss=3.24, v_num=0]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


Epoch 4: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 281/281 [00:22<00:00, 12.45it/s, loss=2.88, v_num=0]
Validating: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 141/141 [00:16<00:00, 75.54it/s][A