In [1]:
import logging
from typing import Any, Optional

import numpy as np
from sklearn.metrics import roc_auc_score, accuracy_score
import timm
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from src.data import IDAOData, train_transforms, val_transforms
from src.utils import ExponentialAverage

logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Train

In [15]:
BATCH_SIZE = 8
N_EPOCHS = 6
device = torch.device('cuda')

In [34]:
import albumentations as A

CROP = 80
CENTER = 80
NORM_MEAN = 0.3938
NORM_STD = 0.15

def crop_out_center(img, **kwargs):
    height, width = img.shape
    from_h, to_h = height//2 - CENTER//2, height//2 + CENTER//2 
    from_w, to_w = height//2 - CENTER//2, height//2 + CENTER//2 

    img[from_h:to_h, from_w:to_w] = int(0.3938 * 255)
    return img
    
def train_transforms() -> Any:
    transforms = A.Compose(
        [
            A.CenterCrop(CROP, CROP),
#             A.Lambda(image=crop_out_center),
            A.Flip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.Normalize(mean=NORM_MEAN, std=NORM_STD),
        ]
    )

    return transforms


def val_transforms() -> Any:
    transforms = A.Compose(
        [
            A.CenterCrop(CROP, CROP),
#             A.Lambda(image=crop_out_center),            
            A.Normalize(mean=NORM_MEAN, std=NORM_STD),
        ]
    )

    return transforms


In [35]:
train_ds = IDAOData('data/train', transform=train_transforms())
val_ds = IDAOData('data/val', transform=val_transforms())
test_ds = IDAOData('data/test', transform=val_transforms())
test_holdout_ds = IDAOData('data/test_holdout', transform=val_transforms())

In [36]:
def get_train_dataloader():
    return DataLoader(
        train_ds,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=8,
        pin_memory=True
    )

def get_val_dataloader(dataset):
    return DataLoader(
        dataset,
        batch_size=4,
        shuffle=False,
        num_workers=8,
        pin_memory=True
    )

In [37]:
def update_pbar(pbar, loss: float, eval_loss: Optional[float]):
    if eval_loss is not None:
        pbar.set_postfix({'loss': loss, 'eval_loss': eval_loss})
    else:
        pbar.set_postfix({'loss': loss})
    pbar.update(1)

In [38]:
model = timm.create_model('efficientnet_b0', in_chans=1, num_classes=2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()
loss_fn = torch.nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 40, gamma=0.1)

In [39]:
eval_loss = None
pbar = tqdm(total=len(train_ds) // BATCH_SIZE)

for epoch in range(N_EPOCHS):
    
    pbar.reset()
    pbar.set_description(f'Epoch {epoch+1}/{N_EPOCHS}')
    
    # Train
    model.train()
    loss_avg = ExponentialAverage()
    
    for img, r_type, energy in get_train_dataloader():
        
        img = img.to(device)
        r_type = r_type.to(device)
        energy = energy.to(device)
        
        optimizer.zero_grad()
        
        with torch.cuda.amp.autocast():
            r_type_pred = model(img)
            loss = loss_fn(r_type_pred, r_type)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        update_pbar(pbar, loss_avg(loss.item()), eval_loss)
        
    # Eval
    model.eval()
    pbar_eval = tqdm(total=len(val_ds) // BATCH_SIZE, leave=False)
    pbar_eval.set_description(f'Eval')
    total_loss_eval = torch.tensor(0.0, device=device)
    
    for img, r_type, energy in get_val_dataloader(val_ds):
        img = img.to(device)
        r_type = r_type.to(device)
        energy = energy.to(device)

        with torch.no_grad(), torch.cuda.amp.autocast():
            r_type_pred = model(img)
            total_loss_eval += loss_fn(r_type_pred, r_type)
        
        pbar_eval.update(1)
    
    pbar_eval.close()

    eval_loss = total_loss_eval.item() * 4 / len(val_ds)
    update_pbar(pbar, loss_avg.running_avg, eval_loss)
    
    scheduler.step()

  0%|          | 0/1507 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

# Test

In [43]:
model.eval()
predictions = []
ground_truth = []
for img, r_type, energy in tqdm(get_val_dataloader(test_ds)):
    img = img.to(device)
    with torch.no_grad():
        r_type_pred = model(img)

    ground_truth.append(r_type.numpy())
    predictions.append(r_type_pred.argmax(dim=1).cpu().numpy())
    
predictions = np.concatenate(predictions)
ground_truth = np.concatenate(ground_truth)

ROCAUC = roc_auc_score(ground_truth, predictions)
accuracy = accuracy_score(ground_truth, predictions)
print(f'ROC-AUC score is {ROCAUC:.3f} and accuracy {accuracy: .3f}')

  0%|          | 0/171 [00:00<?, ?it/s]

ROC-AUC score is 0.763 and accuracy  0.762


# Test holdout

In [44]:
model.eval()
predictions = []
ground_truth = []
for img, r_type, energy in tqdm(get_val_dataloader(test_holdout_ds)):
    img = img.to(device)
    with torch.no_grad():
        r_type_pred = model(img)

    ground_truth.append(r_type.numpy())
    predictions.append(r_type_pred.argmax(dim=1).cpu().numpy())
    
predictions = np.concatenate(predictions)
ground_truth = np.concatenate(ground_truth)

ROCAUC = roc_auc_score(ground_truth, predictions)
accuracy = accuracy_score(ground_truth, predictions)
print(f'ROC-AUC score is {ROCAUC:.3f} and accuracy {accuracy: .3f}')

  0%|          | 0/3 [00:00<?, ?it/s]

ROC-AUC score is 0.500 and accuracy  0.500


In [45]:
import pandas as pd
pd.DataFrame(np.vstack([predictions, ground_truth]).T, columns=['Pred', 'True'])

Unnamed: 0,Pred,True
0,0,0
1,0,1
2,0,1
3,1,0
4,0,0
5,0,0
6,1,1
7,1,0
8,0,1
9,1,1
