In [1]:
# Run this to install the package.
# !pip install ..

# Run these to fix the progress bar (if they are broken) and reload.
# !jupyter nbextension enable --py widgetsnbextension
# !jupyter labextension install @jupyter-widgets/jupyterlab-manager

In [2]:
import timm
import torch
import wandb
from tqdm.notebook import tqdm
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.optim import lr_scheduler
from mids_plane_classification.loaders.dataloader import PlaneDataModule

In [3]:
EPOCHS = 30
LR = 0.1
TRAIN_BATCH = 16
VAL_BATCH = 16
NUM_WORKERS = 1
LOG_INTERVAL = 15

wandb.config = {
    "learning_rate": LR,
    "epochs": EPOCHS,
    "train_batch_size": TRAIN_BATCH,
    "val_batch_size": VAL_BATCH
}

model = timm.create_model('fbnetv3_b', pretrained = True)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

train_transform = A.Compose([
    A.Resize(width=224, height=224),
    # A.RandomResizedCrop(height=224, width=224, scale=(0.7, 0.9)),
    # A.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
    # A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
    # A.RandomBrightnessContrast(p=0.5),
    A.Normalize(),
    ToTensorV2(),
])

val_transform = A.Compose([
    A.Resize(width=224, height=224),
    A.Normalize(),
    ToTensorV2(),
])

dm = PlaneDataModule(
    train_batch_size=TRAIN_BATCH,
    val_batch_size=VAL_BATCH,
    data_dir='../data',
    train_transform=train_transform,
    val_transform=val_transform,
    num_workers=NUM_WORKERS,
    seed=2
)
dm.setup()

train_loader = dm.train_dataloader()
val_loader = dm.val_dataloader()

optimizer = torch.optim.Adam(model.parameters(), lr=LR)
criterion = torch.nn.CrossEntropyLoss()
scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

In [4]:
wandb.init(project="mids_plane_classification", entity="mids-w251")
wandb.watch(model)

[34m[1mwandb[0m: Currently logged in as: [33msotoodaa[0m ([33mmids-w251[0m). Use [1m`wandb login --relogin`[0m to force relogin


[]

In [5]:
val_labels = np.array([label[1] for label in val_loader.dataset])

for epoch in range(EPOCHS):
    print('Epoch {}/{}'.format(epoch+1, EPOCHS))
    model.train()
    running_loss = 0.0
    pb = tqdm(train_loader)
    
    for step, batch in enumerate(pb):
        inputs = batch[0].to(device, dtype=torch.float)
        labels = batch[1].to(device).long()

        # Set gradients to zero.
        optimizer.zero_grad()

        # Predictions.
        outputs = model(inputs)

        # Get loss.
        loss = criterion(outputs, labels)

        # Backprop.
        loss.backward()

        # Step in the optimizer and scheduler.
        optimizer.step()
        scheduler.step()

        # Calculate the running loss.
        running_loss += loss.item()
        train_loss = running_loss / (step + 1)

        # Update the progress bar.
        pb.set_postfix(train_loss=(train_loss))
        
        if step % LOG_INTERVAL == 0:
            wandb.log({"train_loss": train_loss})
        
    # Keep track of the validation predictions.
    val_preds = []

    # Set model to eval mode.
    model.eval()
    running_loss = 0.0

    # Create separate progress bar for validation.
    pbv = tqdm(val_loader)

    for step, batch in enumerate(pbv):
        inputs = batch[0].to(device, dtype=torch.float)
        labels = batch[1].to(device).long()
        
        # No gradients for evaluation.
        with torch.no_grad():
            outputs = model(inputs)
            loss = criterion(outputs, labels)
        
        # Retain all predictions for full accuracy later.
        val_preds.append(outputs)
        
        # Compute validation loss.
        running_loss += loss.item()
        val_loss = running_loss / (step + 1)
        
        # Update the progress bar.
        pbv.set_postfix(valid_loss=(val_loss))
        
        if step % LOG_INTERVAL == 0:
            wandb.log({"val_loss": val_loss})

    preds = torch.cat(val_preds) \
        .argmax(1) \
        .detach() \
        .cpu() \
        .numpy()
    
    val_accuracy = (val_labels == preds).mean()
    wandb.log({"val_accuracy": val_accuracy})
    print(f'val_accuracy: {val_accuracy}')

Epoch 1/30


  0%|          | 0/321 [00:00<?, ?it/s]

  0%|          | 0/81 [00:00<?, ?it/s]

val_accuracy: 0.5277127244340359
Epoch 2/30


  0%|          | 0/321 [00:00<?, ?it/s]

  0%|          | 0/81 [00:00<?, ?it/s]

val_accuracy: 0.5253708040593287
Epoch 3/30


  0%|          | 0/321 [00:00<?, ?it/s]

  0%|          | 0/81 [00:00<?, ?it/s]

val_accuracy: 0.5238095238095238
Epoch 4/30


  0%|          | 0/321 [00:00<?, ?it/s]

  0%|          | 0/81 [00:00<?, ?it/s]

val_accuracy: 0.5183450429352069
Epoch 5/30


  0%|          | 0/321 [00:00<?, ?it/s]

  0%|          | 0/81 [00:00<?, ?it/s]

val_accuracy: 0.5300546448087432
Epoch 6/30


  0%|          | 0/321 [00:00<?, ?it/s]

  0%|          | 0/81 [00:00<?, ?it/s]

val_accuracy: 0.5269320843091335
Epoch 7/30


  0%|          | 0/321 [00:00<?, ?it/s]

  0%|          | 0/81 [00:00<?, ?it/s]

val_accuracy: 0.5292740046838408
Epoch 8/30


  0%|          | 0/321 [00:00<?, ?it/s]

  0%|          | 0/81 [00:00<?, ?it/s]

val_accuracy: 0.5238095238095238
Epoch 9/30


  0%|          | 0/321 [00:00<?, ?it/s]

  0%|          | 0/81 [00:00<?, ?it/s]

val_accuracy: 0.5362997658079626
Epoch 10/30


  0%|          | 0/321 [00:00<?, ?it/s]

KeyboardInterrupt: 