In [106]:
import logging
from typing import Any, Optional

import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error
import timm
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from src.data import IDAOData, train_transforms, val_transforms
from src.utils import ExponentialAverage

logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Train

In [107]:
BATCH_SIZE = 128
N_EPOCHS = 8
device = torch.device('cuda')

In [108]:
import albumentations as A

CROP = 100
CENTER = 80
NORM_MEAN = 0.3938
NORM_STD = 0.15

def crop_out_center(img, **kwargs):
    height, width = img.shape
    from_h, to_h = height//2 - CENTER//2, height//2 + CENTER//2 
    from_w, to_w = height//2 - CENTER//2, height//2 + CENTER//2 

    img[from_h:to_h, from_w:to_w] = int(0.3938 * 255)
    return img
    
def train_transforms() -> Any:
    transforms = A.Compose(
        [
            A.CenterCrop(CROP, CROP),
#             A.Lambda(image=crop_out_center),
            A.Flip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.Normalize(mean=NORM_MEAN, std=NORM_STD),
        ]
    )

    return transforms


def val_transforms() -> Any:
    transforms = A.Compose(
        [
            A.CenterCrop(CROP, CROP),
#             A.Lambda(image=crop_out_center),            
            A.Normalize(mean=NORM_MEAN, std=NORM_STD),
        ]
    )

    return transforms


In [109]:
train_ds = IDAOData('data/train', transform=train_transforms())
val_ds = IDAOData('data/val', transform=val_transforms())
test_ds = IDAOData('data/test', transform=val_transforms())
test_holdout_ds = IDAOData('data/test_holdout', transform=val_transforms())

In [110]:
def get_train_dataloader():
    return DataLoader(
        train_ds,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=8,
        pin_memory=True
    )

def get_val_dataloader(dataset):
    return DataLoader(
        dataset,
        batch_size=4,
        shuffle=False,
        num_workers=8,
        pin_memory=True
    )

In [111]:
def update_pbar(pbar, loss: float, eval_loss: Optional[float]):
    if eval_loss is not None:
        pbar.set_postfix({'loss': loss, 'eval_loss': eval_loss})
    else:
        pbar.set_postfix({'loss': loss})
    pbar.update(1)

In [112]:
model = timm.create_model('efficientnet_b0', in_chans=1, num_classes=1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
scaler = torch.cuda.amp.GradScaler()
loss_fn = torch.nn.MSELoss()
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 3, gamma=0.1)

In [113]:
eval_loss = None
pbar = tqdm(total=len(train_ds) // BATCH_SIZE)

for epoch in range(N_EPOCHS):
    
    pbar.reset()
    pbar.set_description(f'Epoch {epoch+1}/{N_EPOCHS}')
    
    # Train
    model.train()
    loss_avg = ExponentialAverage()
    
    for img, r_type, energy in get_train_dataloader():
        
        img = img.to(device)
        r_type = r_type.to(device)
        energy = energy.to(device).to(torch.float)
        
        optimizer.zero_grad()
        
        with torch.cuda.amp.autocast():
            energy_pred = torch.squeeze(model(img), dim=1)
            loss = loss_fn(energy_pred, energy)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        update_pbar(pbar, loss_avg(loss.item()), eval_loss)
        
    # Eval
    model.eval()
    pbar_eval = tqdm(total=len(val_ds) // BATCH_SIZE, leave=False)
    pbar_eval.set_description(f'Eval')
    total_loss_eval = torch.tensor(0.0, device=device)
    
    for img, r_type, energy in get_val_dataloader(val_ds):
        img = img.to(device)
        r_type = r_type.to(device)
        energy = energy.to(device).to(torch.float)

        with torch.no_grad(), torch.cuda.amp.autocast():
            energy_pred = torch.squeeze(model(img), dim=1)
            total_loss_eval += loss_fn(energy_pred, energy)
        
        pbar_eval.update(1)
    
    pbar_eval.close()

    eval_loss = total_loss_eval.item() * 4 / len(val_ds)
    update_pbar(pbar, loss_avg.running_avg, eval_loss)
    
    scheduler.step()

  0%|          | 0/94 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

# Test

In [114]:
model.eval()
predictions = []
ground_truth = []
for img, r_type, energy in tqdm(get_val_dataloader(test_ds)):
    img = img.to(device)
    with torch.no_grad():
        energy_pred =  torch.squeeze(model(img), dim=1)

    ground_truth.append(energy.numpy())
    predictions.append(energy_pred.cpu().numpy())
    
predictions = np.concatenate(predictions)
ground_truth = np.concatenate(ground_truth)

RMSE = mean_squared_error(ground_truth, predictions, squared=False)
print(f'RMSE score is {RMSE:.3f}')

  0%|          | 0/171 [00:00<?, ?it/s]

RMSE score is 0.812


# Test holdout

In [115]:
model.eval()
predictions = []
ground_truth = []
for img, r_type, energy in tqdm(get_val_dataloader(test_holdout_ds)):
    img = img.to(device)
    with torch.no_grad():
        energy_pred = torch.squeeze(model(img), dim=1)

    ground_truth.append(energy.numpy())
    predictions.append(energy_pred.cpu().numpy())
    
predictions = np.concatenate(predictions)
ground_truth = np.concatenate(ground_truth)

RMSE = mean_squared_error(ground_truth, predictions, squared=False)
print(f'RMSE score is {RMSE:.3f}')

  0%|          | 0/3 [00:00<?, ?it/s]

RMSE score is 5.023


In [118]:
diffs = np.abs(np.repeat(np.array([[1,3,6,10,20,30]]), 12, 0) - np.expand_dims(predictions, 1))
index = diffs.argmin(1)
prediction_clas = [[1,3,6,10,20,30][ind] for ind in index]
prediction_clas = np.array(prediction_clas)

In [119]:
import pandas as pd
pd.DataFrame(np.vstack([predictions, prediction_clas, ground_truth]).T, columns=['Pred', 'Pred class', 'True'])

Unnamed: 0,Pred,Pred class,True
0,27.542919,30.0,20.0
1,12.314597,10.0,10.0
2,24.063484,20.0,30.0
3,1.730715,1.0,1.0
4,9.821837,10.0,6.0
5,29.738873,30.0,20.0
6,3.381606,3.0,3.0
7,2.165079,3.0,1.0
8,20.362335,20.0,30.0
9,9.21964,10.0,10.0
