In [None]:
# Install Lightning Flash straight from its GitHub repository (no tag or commit is pinned in the URL),
# together with the `audio` and `image` extras; `-q` keeps pip quiet, `-U` upgrades an existing install.
!pip install -qU 'git+https://github.com/PyTorchLightning/lightning-flash.git#egg=lightning-flash[audio,image]'

## Training with Flash

[Flash](https://lightning-flash.readthedocs.io/en/stable) is a framework of tasks for fast prototyping, baselining, finetuning and solving business and scientific problems with deep learning. It is focused on:

- Predictions
- Finetuning
- Task-based training

It is built for data scientists, machine learning practitioners, and applied researchers.

In [None]:
import os
import functools
from pathlib import Path

import torch
from torch import nn
import numpy as np
from torch.nn import functional as F
from torchmetrics import Accuracy, F1, Recall, Precision, MeanAbsoluteError
import torchvision

import flash
from flash.audio import AudioClassificationData
from flash.core.finetuning import FreezeUnfreeze
from flash.image import ImageClassifier

from flash.core.data.transforms import ApplyToKeys, merge_transforms
from flash.audio.classification.transforms import default_transforms

import albumentations


# Root folder of the competition data under /kaggle/input; the label CSV and the
# train/test arrays used below are all resolved relative to it.
PATH_DATASET = Path("/kaggle/input/seti-breakthrough-listen")

### Define Flash DataModule

In this section we define the dataset that is used later for classification.

In [None]:
def resolver(root, id):
    """Build the path of the ``.npy`` file that belongs to a sample id.

    Files are looked up in a sub-folder named after the first character of the
    id, i.e. ``<root>/<id[0]>/<id>.npy``.

    Args:
        root: Directory that contains the per-character sub-folders.
        id: Sample identifier (the name shadows the builtin but is kept so that
            keyword callers keep working).

    Returns:
        The joined file path as a string.
    """
    shard_dir = id[0]
    file_name = f"{id}.npy"
    return os.path.join(root, shard_dir, file_name)

def preprocess(array):
    """Stack the input vertically and return it transposed as ``float32``.

    Args:
        array: Sequence of arrays (or an array iterated along its first axis)
            accepted by ``np.vstack``; the stacked result must be 2-D.

    Returns:
        ``np.ndarray`` of dtype ``float32`` holding the stacked data with its
        two axes swapped.
    """
    stacked = np.vstack(array)
    swapped = np.transpose(stacked, (1, 0))
    return swapped.astype(np.float32)

def mixup(batch, alpha=1.0):
    """Apply mixup to a collated batch by blending each sample with a random partner.

    A single mixing weight ``lam`` is drawn from ``Beta(alpha, alpha)`` per call
    and applied to the whole batch. Inputs and targets are blended with the same
    weight and the same permutation.

    The weight is drawn with torch's random generator instead of NumPy's global
    one: when this runs inside DataLoader worker processes, NumPy's state can be
    identical in every worker, which would make all workers produce the same
    sequence of weights, whereas torch is seeded per worker.

    Args:
        batch: Mapping holding an ``"input"`` tensor and a ``"target"`` tensor
            that share their first (batch) dimension. It is updated in place.
        alpha: Concentration of the symmetric Beta distribution; must be positive.

    Returns:
        The same ``batch`` object, with ``"input"`` replaced by the blended
        inputs and ``"target"`` replaced by the blended float targets, which
        carry an extra axis inserted at dim 1.
    """
    images = batch["input"]
    # Targets are cast to float and given an axis at dim 1 before blending.
    targets = batch["target"].float().unsqueeze(1)

    lam = torch.distributions.Beta(float(alpha), float(alpha)).sample().item()
    # One permutation for both tensors keeps inputs and targets paired.
    perm = torch.randperm(images.size(0))

    batch["input"] = images * lam + images[perm] * (1 - lam)
    batch["target"] = targets * lam + targets[perm] * (1 - lam)
    return batch

def convert_val_targets(batch):
    """Cast the batch targets to float and insert an axis at dim 1.

    Args:
        batch: Mapping with a ``"target"`` tensor; it is updated in place.

    Returns:
        The same ``batch`` object with its ``"target"`` entry replaced.
    """
    float_targets = batch["target"].float()
    batch["target"] = torch.unsqueeze(float_targets, 1)
    return batch


class AlbumentationsAdapter(nn.Module):
    """Expose a keyword-based image transform as a one-argument ``nn.Module``.

    The wrapped callable is invoked as ``transform(image=...)`` and must return
    a mapping; only its ``"image"`` entry is handed back.
    """

    def __init__(self, transform):
        """Keep a reference to the callable that should be wrapped."""
        super().__init__()
        self.transform = transform

    def forward(self, x):
        """Run the wrapped transform on ``x`` and return the resulting image."""
        result = self.transform(image=x)
        return result["image"]

# Training pipeline; each key names the preprocessing stage the transform is attached to.
train_transform = {
    # Applied to the raw sample before it is turned into a tensor.
    "pre_tensor_transform": nn.Sequential(
        ApplyToKeys("input", preprocess),
        ApplyToKeys(
            "input",
            AlbumentationsAdapter(albumentations.Compose([
                # Resize to 340x340, augment geometrically, then resize to the 224x224 model input.
                albumentations.Resize(340, 340),
                # Flips are currently disabled:
                #albumentations.HorizontalFlip(p=0.5),
                #albumentations.VerticalFlip(p=0.5),
                albumentations.ShiftScaleRotate(
                    shift_limit=0.1,
                    scale_limit=0.15,
                    rotate_limit=20,
                    p=0.5,
                ),
                albumentations.Resize(224, 224),
                # Brightness-only jitter. `RandomBrightness` is deprecated in albumentations;
                # `RandomBrightnessContrast` with `contrast_limit=0` is its documented equivalent
                # and keeps working with newer releases (this notebook installs unpinned).
                albumentations.RandomBrightnessContrast(
                    brightness_limit=0.6,
                    contrast_limit=0,
                    p=0.5,
                ),
            ]))
        ),
    ),
    # Convert the augmented array and the label to tensors.
    "to_tensor_transform": nn.Sequential(
        ApplyToKeys("input", torchvision.transforms.ToTensor()),
        ApplyToKeys("target", torch.as_tensor),
    ),
    # Blend samples within each collated batch.
    "per_batch_transform": mixup,
}

# Validation pipeline: deterministic resize only, no random augmentation.
val_transform = dict(
    # Raw sample -> stacked float32 array -> 224x224 image.
    pre_tensor_transform=nn.Sequential(
        ApplyToKeys("input", preprocess),
        ApplyToKeys("input", AlbumentationsAdapter(albumentations.Resize(224, 224))),
    ),
    # Array and label -> tensors.
    to_tensor_transform=nn.Sequential(
        ApplyToKeys("input", torchvision.transforms.ToTensor()),
        ApplyToKeys("target", torch.as_tensor),
    ),
    # Targets are cast to float and reshaped per batch.
    per_batch_transform=convert_val_targets,
)

# Locations of the label CSV and of the training arrays inside the dataset folder.
train_csv = str(PATH_DATASET / "train_labels.csv")
train_root = str(PATH_DATASET / "train")

datamodule = AudioClassificationData.from_csv(
    # CSV columns holding the sample id and the label.
    input_field="id",
    target_fields="target",
    # Training data; `resolver` turns an id from the CSV into a file path below `train_root`.
    train_file=train_csv,
    train_images_root=train_root,
    train_resolver=resolver,
    train_transform=train_transform,
    # Validation settings (split value 0.1; presumably the held-out fraction of the training data — confirm).
    val_split=0.1,
    val_resolver=resolver,
    val_transform=val_transform,
    # Loading.
    batch_size=64,
    num_workers=os.cpu_count(),
)

### Define the classification model

Flash offers a rich collection of backbones for image classification and simple plug-in integration with [TorchMetrics](https://torchmetrics.readthedocs.io/en/stable/).

In [None]:
# Single-output classifier on a pretrained EfficientNet-B3, trained with BCE-with-logits.
model = ImageClassifier(
    # Backbone.
    backbone="efficientnet_b3",
    pretrained=True,
    # NOTE(review): presumably forwarded to the backbone to select single-channel input — confirm.
    backbone_kwargs={"in_chans": 1},
    # Head, objective and tracked metric.
    num_classes=1,
    loss_fn=nn.BCEWithLogitsLoss(),
    metrics=[MeanAbsoluteError()],
    # Optimisation.
    learning_rate=1e-5,
)

### Training

Training is based on the standard [Lightning](https://pytorch-lightning.readthedocs.io/en/stable) trainer.

In [None]:
import pytorch_lightning as pl

# Log metrics to CSV under logs/ so they can be loaded with pandas afterwards.
logger = pl.loggers.CSVLogger(save_dir='logs/')

# Number of visible CUDA devices; 0 means training runs on CPU.
num_gpus = torch.cuda.device_count()

trainer = flash.Trainer(
    max_epochs=10,
    gpus=num_gpus,
    logger=logger,
    # Run validation every half training epoch.
    val_check_interval=0.5,
    # 16-bit mixed precision is only usable on GPU; fall back to full precision
    # on CPU so the notebook still runs when `gpus` resolves to 0.
    precision=16 if num_gpus > 0 else 32,
)
trainer.fit(model, datamodule=datamodule)

trainer.save_checkpoint("audio_classification_model.pt")

### Logged stats

Let's check some training statistics: how the loss and the defined metrics evolved over time/epochs.

In [None]:
import pandas as pd

# Metrics file written by the CSV logger during training.
metrics = pd.read_csv(f'{trainer.logger.log_dir}/metrics.csv')
display(metrics.head())

# Average every logged column per epoch. `as_index=False` keeps the epoch as a
# regular column and leaves a 0..n-1 row index, which the plots use as x axis.
agg_col = "epoch"
df_metrics = metrics.groupby(agg_col, as_index=False).mean()

df_metrics[['train_bcewithlogitsloss_step']].plot(grid=True, legend=True, xlabel=agg_col)
df_metrics[['train_meanabsoluteerror_step', 'val_meanabsoluteerror']].plot(grid=True, legend=True, xlabel=agg_col)
# df_metrics[['train_f1_step', 'train_recall_step', 'train_precision_step', 'val_f1', 'val_recall', 'val_precision']].plot(grid=True, legend=True, xlabel=agg_col)

## Predictions

With the trained model, we just need to run inference to prepare the submission file.

In [None]:
import torch

from flash.core.data.process import Serializer


class IdPredictionSerializer(Serializer):
    """Turn one raw prediction into an ``{"id", "target"}`` row for the submission table."""

    def serialize(self, sample):
        """Convert a single prediction sample into a submission row.

        Args:
            sample: Mapping with ``"preds"`` (a single raw model output) and
                ``"metadata"``, whose ``"filepath"`` entry is the source file.

        Returns:
            Dict with ``"id"`` (file name without its extension) and ``"target"``
            (sigmoid of the raw output, as a Python float).
        """
        # `as_tensor` accepts tensors and plain sequences alike and avoids the
        # copy-construct warning `torch.tensor` emits when given a tensor.
        logit = torch.as_tensor(sample["preds"], dtype=torch.float)
        probability = logit.sigmoid().item()

        filepath = sample["metadata"]["filepath"]
        # `splitext` strips only the final extension, so an id that itself
        # contains a dot is not truncated.
        file_id = os.path.splitext(os.path.basename(filepath))[0]
        return {"id": file_id, "target": probability}

In [None]:
import glob
from tqdm import tqdm
from itertools import chain

# Collect every .npy file located one directory level below <dataset>/test.
test_pattern = str(PATH_DATASET / "test" / "*" / "*.npy")
found_npy = glob.glob(test_pattern)

# Inference pipeline: deterministic resize, inputs only (no targets at predict time).
predict_transform = dict(
    pre_tensor_transform=nn.Sequential(
        ApplyToKeys("input", preprocess),
        ApplyToKeys("input", AlbumentationsAdapter(albumentations.Resize(224, 224))),
    ),
    to_tensor_transform=ApplyToKeys("input", torchvision.transforms.ToTensor()),
)

datamodule = AudioClassificationData.from_files(
    predict_files=found_npy,
    predict_transform=predict_transform,
    batch_size=1024,
    num_workers=os.cpu_count(),
)

# Emit {"id", "target"} rows instead of raw model outputs.
model.serializer = IdPredictionSerializer()

submission = trainer.predict(model, datamodule=datamodule)

# The prediction result is nested; flatten it by one level into a single list of rows.
submission = [row for batch_rows in submission for row in batch_rows]
pd.DataFrame(submission).set_index("id").to_csv("submission.csv")