In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import time
import wandb
import torch
import random
import pickle
import imageio
import librosa
import torchvision

import numpy as np
import pandas as pd
import torchmetrics as tm 
import plotly.express as px
import pytorch_lightning as pl
import matplotlib.pyplot as plt

from torch import nn
from pathlib import Path, PurePath
from IPython.display import Audio
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam, AdamW, RMSprop # optmizers
from warmup_scheduler import GradualWarmupScheduler
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau # Learning rate schedulers

import albumentations as A
# from albumentations.pytorch import ToTensorV2

import timm

ModuleNotFoundError: No module named 'albumentations'

In [None]:
print('timm version', timm.__version__)
print('torch version', torch.__version__)

In [None]:
# print(os.getenv('wandb_api_key'))

In [None]:
wandb.login(key=os.getenv('wandb_api_key'))

In [None]:
# detect and define device 
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(device)

### Config

In [None]:
train_dir = Path('E:\data\BirdCLEF')

In [None]:
class CFG:
    project = 'Bird-local-3'
    comment = 'clean-10'
    
    MIXUP = True
    USE_SCHD = False
    USE_UL = False
    # USE_MISSING_LABELS = False
    USE_SECONDARY = False
    USE_MISSING_LABELS = USE_SECONDARY
    USE_UPSAMPLE = False

    UL_THRESH = 0.05
    up_thr = 50
    
    # Competition Root Folder
    ROOT_FOLDER = train_dir
    birds_csv = train_dir / 'bird_preds.csv'
    AUDIO_FOLDER = train_dir / 'train_audio'
    DATA_DIR = train_dir / 'spectros'
    TRAIN_CSV = train_dir / 'train_metadata.csv'
    UNLABELED_CSV = train_dir / 'predictions.csv'
    RESULTS_DIR = train_dir / 'results'
    CKPT_DIR = RESULTS_DIR / 'ckpt'

    num_workers = 8
    # Maximum decibel to clip audio to
    # TOP_DB = 100
    TOP_DB = 80
    # Minimum rating
    MIN_RATING = 3.0
    # Sample rate as provided in competition description
    # SR = 32000
    SR = 32000

    image_size = 128
    
    ### split train and validation sets
    split_fraction = 0.95
    
    ### model
    model_name = 'eca_nfnet_l0' # 'resnet34', 'resnet200d', 'efficientnet_b1_pruned', 'efficientnetv2_m', efficientnet_b7 ...  
    
    ### training
    BATCH_SIZE = 128

    ### Optimizer
    N_EPOCHS = 240
    WARM_EPOCHS = 3
    COS_EPOCHS = N_EPOCHS - WARM_EPOCHS
    
    # LEARNING_RATE = 5*1e-5 # best
    LEARNING_RATE = 5e-5
    
    weight_decay = 1e-6 # for adamw

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    random_seed = 42

mel_spec_params = {
    "sample_rate": CFG.SR,
    "n_mels": 128,
    "f_min": 150,
    "f_max": CFG.SR / 2,
    "n_fft": 2048,
    "hop_length": 512,
    "normalized": True,
    "center" : True,
    "pad_mode" : "constant",
    "norm" : "slaney",
    "mel_scale" : "slaney"
}

CFG.mel_spec_params = mel_spec_params

sec_labels = ['lotshr1', 'orhthr1', 'magrob', 'indwhe1', 'bltmun1', 'asfblu1']

sample_submission = pd.read_csv(train_dir / 'sample_submission.csv')

# Set labels
CFG.LABELS = sample_submission.columns[1:].tolist()
if CFG.USE_MISSING_LABELS:
    CFG.LABELS += sec_labels
    
CFG.N_LABELS = len(CFG.LABELS)
print(f'# labels: {CFG.N_LABELS}')

bird2id = {b: i for i, b in enumerate(CFG.LABELS)}

display(sample_submission.head())

In [None]:
CFG.N_LABELS

In [None]:
# for reproducibility
def seed_torch(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

seed_torch(seed = CFG.random_seed)

In [None]:
meta_df = pd.read_csv(CFG.TRAIN_CSV)
meta_df.head(2)

In [None]:
columns = ['primary_label', 'secondary_labels', 'filename', 'file', 'duration']

In [None]:
meta_df['file'] = meta_df.apply(lambda row: row['filename'].split('/')[-1], axis=1)
meta_df['filename'] = fr'{str(CFG.AUDIO_FOLDER)}/' + meta_df['filename']

In [None]:
meta_df[columns].head(2)

In [None]:
meta_df.iloc[0].filename

### Load data

In [None]:
from dataset import bird_dataset2

In [None]:
dset = bird_dataset2(meta_df, CFG)

print(dset.__len__())

spect, label, = dset.__getitem__(0)
print(spect.shape, label.shape)
print(spect.dtype, label.dtype)

In [None]:
label.sum()

In [None]:
# interv.intersect.sum()

In [None]:
librosa.display.specshow(spect[0].numpy(), y_axis="mel", x_axis='s', sr=CFG.SR, cmap='gray')
plt.show()

In [None]:
librosa.display.specshow(spect[1].numpy(), y_axis="mel", x_axis='s', sr=CFG.SR, cmap='gray')
plt.show()

### Data Module

In [None]:
from dataset import bird_dataset2

In [None]:
class wav_datamodule(pl.LightningDataModule):
    def __init__(self, train_df, val_df, cfg=CFG, train_tfs=None, val_tfs=None):
        super().__init__()
        
        self.train_df = train_df
        self.val_df = val_df
        
        self.train_bs = cfg.BATCH_SIZE
        self.val_bs = cfg.BATCH_SIZE

        self.train_tfs = train_tfs
        self.val_tfs = val_tfs

        self.cfg = cfg
        
        self.num_workers = cfg.num_workers
        
    def train_dataloader(self):
        train_ds = bird_dataset2(self.train_df, self.cfg, tfs=self.train_tfs, mode='train')
        
        train_loader = torch.utils.data.DataLoader(
            train_ds,
            batch_size=self.train_bs,
            pin_memory=False,
            drop_last=False,
            shuffle=True,
            persistent_workers=True,
            num_workers=self.num_workers,
        )
        
        return train_loader
        
    def val_dataloader(self):
        val_ds = bird_dataset2(self.val_df, self.cfg, tfs=self.val_tfs, mode='val')
        
        val_loader = torch.utils.data.DataLoader(
            val_ds,
            batch_size=self.val_bs,
            pin_memory=False,
            drop_last=False,
            shuffle=False,
            persistent_workers=True,
            num_workers=2,
        )
        
        return val_loader

In [None]:
image_size = CFG.image_size

train_tfs = A.Compose([
    # A.HorizontalFlip(p=0.5),
    A.Resize(image_size, image_size),
    A.GaussNoise(p=0.5),
    # A.XYMasking(num_masks_x=4, num_masks_y=2, mask_x_length=26, mask_y_length=26, p=0.7),
    A.CoarseDropout(max_height=int(image_size * 0.375), max_width=int(image_size * 0.375), max_holes=1, p=0.7),
    A.Normalize()
])

val_tfs = A.Compose([
    A.Resize(image_size, image_size),
    A.Normalize()
])

In [None]:
t_df = meta_df[:-100]
# t_df = pd.concat([meta_df[:-100], ul_df[:-100]], ignore_index=True)
v_df = meta_df[-100:]

CFG2 = CFG()
CFG2.BATCH_SIZE = 16
CFG2.num_workers = 2

dm = wav_datamodule(t_df, v_df, cfg=CFG2)
# dm = wav_datamodule(t_df, v_df, cfg=CFG, train_tfs=train_tfs, val_tfs=val_tfs)

x, y = next(iter(dm.train_dataloader()))
x.shape, y.shape, x.dtype, y.dtype

In [None]:
# librosa.display.specshow(x[0].numpy()[0], y_axis="mel", x_axis='s', sr=CFG.SR)
# plt.show()

In [None]:
librosa.display.specshow(x[2].numpy()[0], y_axis="mel", x_axis='s', sr=CFG.SR, cmap='gray')
plt.show()

In [None]:
dm = wav_datamodule(t_df, v_df, cfg=CFG, train_tfs=train_tfs, val_tfs=val_tfs)

x, y = next(iter(dm.train_dataloader()))
x.shape, y.shape, x.dtype, y.dtype

In [None]:
librosa.display.specshow(x[0].numpy()[0], y_axis="mel", x_axis='s', sr=CFG.SR)
plt.show()

In [None]:
librosa.display.specshow(x[1].numpy()[0], y_axis="mel", x_axis='s', sr=CFG.SR)
plt.show()

In [None]:
librosa.display.specshow(x[2].numpy()[0], y_axis="mel", x_axis='s', sr=CFG.SR)
plt.show()

In [30]:
# img = x[0]
# img.shape, img.unsqueeze(dim=0).numpy().shape, img.expand(3, -1, -1).shape

In [31]:
# img.expand(3, -1, -1).permute(1, 2, 0).shape, img.expand(3, -1, -1).permute(1, 2, 0).numpy().transpose(2,0,1).shape

In [32]:
# train_tfs(image=img.numpy())

In [33]:
del dm

### Loss function

In [34]:
class FocalLossBCE(torch.nn.Module):
    def __init__(
            self,
            alpha: float = 0.25,
            gamma: float = 2,
            reduction: str = "mean",
            bce_weight: float = 1.0,
            focal_weight: float = 1.0,
    ):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.bce = torch.nn.BCEWithLogitsLoss(reduction=reduction)
        self.bce_weight = bce_weight
        self.focal_weight = focal_weight

    def forward(self, logits, targets):
        focall_loss = torchvision.ops.focal_loss.sigmoid_focal_loss(
            inputs=logits,
            targets=targets,
            alpha=self.alpha,
            gamma=self.gamma,
            reduction=self.reduction,
        )
        bce_loss = self.bce(logits, targets)
        return self.bce_weight * bce_loss + self.focal_weight * focall_loss

In [35]:
class GeM(torch.nn.Module):
    def __init__(self, p=3, eps=1e-6):
        super(GeM, self).__init__()
        self.p = torch.nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        bs, ch, h, w = x.shape
        x = torch.nn.functional.avg_pool2d(x.clamp(min=self.eps).pow(self.p), (x.size(-2), x.size(-1))).pow(
            1.0 / self.p)
        x = x.view(bs, ch)
        return x

### Optimizer

In [36]:
# Fix Warmup Bug
class GradualWarmupSchedulerV2(GradualWarmupScheduler):
    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        super(GradualWarmupSchedulerV2, self).__init__(optimizer, multiplier, total_epoch, after_scheduler)
        
    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                    
                return self.after_scheduler.get_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]
            
        if self.multiplier == 1.0:
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

### Model

In [37]:
print('Number of models available: ', len(timm.list_models(pretrained=True)))
print('Number of models available: ', len(timm.list_models()))
print('\ models: ', timm.list_models('eca_nfnet_*'))

Number of models available:  1329
Number of models available:  1032
\ models:  ['eca_nfnet_l0', 'eca_nfnet_l1', 'eca_nfnet_l2', 'eca_nfnet_l3']


In [38]:
backbone = 'eca_nfnet_l1'
# backbone = 'efficientnet_b4'
out_indices = (3, 4)

model = timm.create_model(
    backbone,
    features_only=True,
    pretrained=False,
    in_chans=3,
    num_classes=5,
    # out_indices=out_indices,
    )

model.feature_info.channels(), np.sum(model.feature_info.channels())

([64, 256, 512, 1536, 3072], 5440)

In [39]:
def mixup(data, targets, alpha, device):
    indices = torch.randperm(data.size(0))
    data2 = data[indices]
    targets2 = targets[indices]

    lam = torch.FloatTensor([np.random.beta(alpha, alpha)]).to(device)
    data = data * lam + data2 * (1 - lam)
    
    targets = targets * lam + targets2 * (1 - lam)
    return data, targets

    # data += data2
    # targets += targets2
    # return data, targets.clip(max=1)

In [40]:
class GeMModel(pl.LightningModule):
    def __init__(self, cfg = CFG, pretrained = True):
        super().__init__()

        self.cfg = cfg
        
        out_indices = (3, 4)

        self.criterion = FocalLossBCE()

        self.train_acc = tm.classification.MulticlassAccuracy(num_classes=self.cfg.N_LABELS)
        self.val_acc = tm.classification.MulticlassAccuracy(num_classes=self.cfg.N_LABELS)

        # self.train_acc = tm.classification.MultilabelAccuracy(num_labels=self.cfg.N_LABELS)
        self.val_macc = tm.classification.MultilabelAccuracy(num_labels=self.cfg.N_LABELS)

        self.train_auroc = tm.classification.MulticlassAUROC(num_classes=self.cfg.N_LABELS)
        self.val_auroc = tm.classification.MulticlassAUROC(num_classes=self.cfg.N_LABELS)

        # self.model_name = self.cfg.model_name
        print(self.cfg.model_name)
        
        self.backbone = timm.create_model(
            self.cfg.model_name, 
            features_only=True,
            pretrained=pretrained,
            in_chans=3,
            num_classes=self.cfg.N_LABELS,
            out_indices=out_indices,
        )

        feature_dims = self.backbone.feature_info.channels()

        self.global_pools = torch.nn.ModuleList([GeM() for _ in out_indices])
        self.mid_features = np.sum(feature_dims)
        
        self.neck = torch.nn.BatchNorm1d(self.mid_features)
        self.head = torch.nn.Linear(self.mid_features, self.cfg.N_LABELS)

    def forward(self, x):
        ms = self.backbone(x)
        
        h = torch.cat([global_pool(m) for m, global_pool in zip(ms, self.global_pools)], dim=1)
        x = self.neck(h)
        x = self.head(x)
        
        return x
        
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(model.parameters(), lr=self.cfg.LEARNING_RATE, weight_decay=CFG.weight_decay)
        
        if self.cfg.USE_SCHD:
            scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, self.cfg.COS_EPOCHS)
            scheduler_warmup = GradualWarmupSchedulerV2(optimizer, multiplier=10, total_epoch=self.cfg.WARM_EPOCHS, after_scheduler=scheduler_cosine)

            return [optimizer], [scheduler_warmup]
        else:
            # LRscheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.2)
            
            # return [optimizer], [LRscheduler]
            return optimizer

    def step(self, batch, batch_idx, mode='train'):
        x, y = batch

        if self.cfg.MIXUP and mode == 'train':
            x, y = mixup(x, y, 0.5, self.cfg.device)
        
        preds = self(x)
        
        loss = self.criterion(preds, y)
        
        if mode == 'train':
            self.train_acc(preds, y.argmax(1))
            # self.train_auroc(preds, y.argmax(1))
        else:
            self.val_acc(preds, y.argmax(1))
            self.val_macc(preds, y)
            # self.val_auroc(preds, y.argmax(1))
        
        self.log(f'{mode}/loss', loss, on_step=True, on_epoch=True)
        # self.log(f'{mode}/kl_loss', kl_loss, on_step=True, on_epoch=True)

        return loss
    
    def training_step(self, batch, batch_idx):
        loss = self.step(batch, batch_idx, mode='train')
        self.log(f'train/acc', self.train_acc, on_step=True, on_epoch=True)
        # self.log(f'train/auroc', self.train_auroc, on_step=True, on_epoch=True)
        
        return loss
        
    def validation_step(self, batch, batch_idx):
        loss = self.step(batch, batch_idx, mode='val')
        self.log(f'val/acc', self.val_acc, on_step=True, on_epoch=True)
        self.log(f'val/macc', self.val_macc, on_step=True, on_epoch=True)
        # self.log(f'val/auroc', self.val_auroc, on_step=True, on_epoch=True)
    
        return loss
    
    def on_train_epoch_end(self):
        self.train_acc.reset()

    def on_validation_epoch_end(self):
        self.val_acc.reset()
        self.val_macc.reset()

In [41]:
model = GeMModel(CFG)

eca_nfnet_l0


In [42]:
foo = model(x)

In [43]:
foo.shape

torch.Size([128, 182])

### Split

In [44]:
from sklearn.model_selection import ShuffleSplit, StratifiedShuffleSplit

In [45]:
def upsample_data(df, thr=20):
    # get the class distribution
    class_dist = df['primary_label'].value_counts()

    # identify the classes that have less than the threshold number of samples
    down_classes = class_dist[class_dist < thr].index.tolist()

    # create an empty list to store the upsampled dataframes
    up_dfs = []

    # loop through the undersampled classes and upsample them
    for c in down_classes:
        # get the dataframe for the current class
        class_df = df.query("primary_label==@c")
        # find number of samples to add
        num_up = thr - class_df.shape[0]
        # upsample the dataframe
        class_df = class_df.sample(n=num_up, replace=True, random_state=CFG.random_seed)
        # append the upsampled dataframe to the list
        up_dfs.append(class_df)

    # concatenate the upsampled dataframes and the original dataframe
    up_df = pd.concat([df] + up_dfs, axis=0, ignore_index=True)
    
    return up_df

In [46]:
sss = StratifiedShuffleSplit(n_splits=1, test_size=1-CFG.split_fraction, random_state=CFG.random_seed)
train_idx, val_idx = next(sss.split(meta_df.filename, meta_df.primary_label))

t_df = meta_df.iloc[train_idx]
v_df = meta_df.iloc[val_idx]

print(t_df.shape)

if not CFG.USE_SECONDARY:
    t_df = t_df[t_df['secondary_labels'] == '[]']

if CFG.USE_UPSAMPLE:
    t_df = upsample_data(t_df, thr=CFG.up_thr)

if CFG.USE_UL:
    t_df = pd.concat([t_df, ul_df], ignore_index=True)

t_df.shape, v_df.shape

(23236, 14)


((21433, 14), (1223, 14))

In [47]:
# t_df = t_df[t_df['rating'] > 1]
# t_df.shape

In [48]:
short_df = t_df[t_df['duration'] < 10]
short_df.shape, short_df.primary_label.nunique()

((4570, 14), 178)

In [49]:
short_labels = short_df.primary_label.unique().tolist()

missing = list(set(CFG.LABELS) - set(short_labels))
extra = list(set(short_labels) - set(CFG.LABELS))

len(short_labels), len(missing), len(extra)

(178, 4, 0)

In [50]:
missing

['crfbar1', 'wynlau1', 'whbsho3', 'grehor1']

In [51]:
missing_df = t_df[t_df['primary_label'].isin(missing)]
missing_df.primary_label.value_counts()

primary_label
grehor1    34
whbsho3    21
crfbar1    11
wynlau1     6
Name: count, dtype: int64

In [52]:
foo_df = missing_df[missing_df['duration'] < 40]
foo_df.shape, foo_df.primary_label.value_counts()

((38, 14),
 primary_label
 grehor1    21
 crfbar1    10
 whbsho3     5
 wynlau1     2
 Name: count, dtype: int64)

In [53]:
t_df.shape, short_df.shape, foo_df.shape

((21433, 14), (4570, 14), (38, 14))

In [54]:
t_df = pd.concat([short_df, foo_df], axis=0)
t_df.shape

(4608, 14)

### Train

In [55]:
# dm = wav_datamodule(t_df,v_df)
dm = wav_datamodule(t_df, v_df, CFG, train_tfs=train_tfs, val_tfs=val_tfs) 

In [56]:
CFG.BATCH_SIZE

128

In [57]:
len(dm.train_dataloader()), len(dm.val_dataloader())

(36, 10)

In [58]:
model = GeMModel(CFG)

eca_nfnet_l0


In [59]:
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import Callback, LearningRateMonitor

In [60]:
run_name = f'{CFG.model_name} {CFG.LEARNING_RATE} {CFG.N_EPOCHS} eps {CFG.comment}'

In [61]:
wandb_logger = WandbLogger(
    name=run_name,
    project=CFG.project,
    job_type='train',
    save_dir=CFG.RESULTS_DIR,
    # config=cfg,
)

In [62]:
loss_ckpt = pl.callbacks.ModelCheckpoint(
    monitor='val/loss',
    auto_insert_metric_name=False,
    dirpath=CFG.CKPT_DIR / run_name,
    filename='ep_{epoch:02d}_loss_{val/loss:.5f}',
    every_n_epochs=8,
    save_top_k=2,
    mode='min',
)

loss_ckpt2 = pl.callbacks.ModelCheckpoint(
    # monitor='val/loss',
    # auto_insert_metric_name=False,
    dirpath=CFG.CKPT_DIR / run_name,
    # filename='ep_{epoch:02d}_loss_{val/loss:.5f}',
    every_n_epochs=8,
    # save_top_k=2,
    # mode='min',
)

In [63]:
acc_ckpt = pl.callbacks.ModelCheckpoint(
    monitor='val/acc',
    auto_insert_metric_name=False,
    dirpath=CFG.CKPT_DIR / run_name,
    filename='ep_{epoch:02d}_acc_{val/acc:.5f}',
    save_top_k=2,
    mode='max',
)

In [64]:
lr_monitor = LearningRateMonitor(logging_interval='step')

In [65]:
CFG.device

'cuda'

In [66]:
trainer = pl.Trainer(
    max_epochs=CFG.N_EPOCHS,
    deterministic=True,
    accelerator=CFG.device,
    default_root_dir=CFG.RESULTS_DIR,
    gradient_clip_val=0.5, 
    # gradient_clip_algorithm="value",
    check_val_every_n_epoch=5,
    logger=wandb_logger,
    callbacks=[loss_ckpt, acc_ckpt, lr_monitor],
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [67]:
trainer.fit(model, dm)

You are using a CUDA device ('NVIDIA GeForce RTX 4090 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type               | Params
----------------------------------------------------
0 | criterion    | FocalLossBCE       | 0     
1 | train_acc    | MulticlassAccuracy | 0     
2 | val_acc      | MulticlassAccuracy | 0     
3 | val_macc     | MultilabelAccuracy | 0     
4 | train_auroc  | MulticlassAUROC    | 0     
5 | val_auroc    | MulticlassAUROC    | 0     
6 | backbone     | FeatureListNet     | 21.8 M
7 | global_pools | ModuleList         | 2     
8 | neck         | BatchNorm1d        | 7.7 K 
9 | head         | Linear             | 699 K 
----------------------------------------------------
22.5 M    Trainable params
0         Non-trainable params
22.5 M    Total params
90.183    Total estimated model params size (MB)


Sanity Checking: |                                                                        | 0/? [00:00<?, ?it/…

C:\Users\Asus\.conda\envs\llms\lib\site-packages\pytorch_lightning\loops\fit_loop.py:298: The number of training batches (36) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |                                                                               | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

Validation: |                                                                             | 0/? [00:00<?, ?it/…

`Trainer.fit` stopped: `max_epochs=240` reached.


In [68]:
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
lr-Adam,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/acc_epoch,▁▁▁▁▁▁▂▂▂▃▄▄▅▆▆▆▆▆▇▆▇▇▇▇▇▇▇▇▇▇▇▇██▇█████
train/acc_step,▁▁▁▁▁▂▂▃▃▄▅▅▅▆▆▇▆▇▇▆▆▇▆▆▇▆▇█▇▇▇▇▇▆▇▇▇▆▇█
train/loss_epoch,█▆▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/loss_step,█▆▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▂▁▂▁▂▁▃▃▁▃▁▄▁▄▁▄▅▅▁▁▅▅▆▁▆▁▆▁▇▁▇▇▁█▁█▁
val/acc_epoch,▁▁▁▁▁▂▃▄▅▆▇▇▇▇█▇▇▇▇██▇███████▇███▇██▇█▇█
val/acc_step,▁▁▁▂▂▂▄▅▆▇▆▇▆▇▇▇▇▆█▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇▆▇█▇▆█
val/loss_epoch,█▅▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,239.0
lr-Adam,5e-05
train/acc_epoch,0.56093
train/acc_step,0.42922
train/loss_epoch,0.01945
train/loss_step,0.02333
trainer/global_step,8639.0
val/acc_epoch,0.34583
val/acc_step,0.47266
val/loss_epoch,0.02037


### Predict

In [69]:
x, y = next(iter(dm.train_dataloader()))

In [70]:
foo = model(x)
# foo = model(x.to(CFG.device))
foo.shape

torch.Size([128, 182])

In [71]:
y[1]

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.])

In [72]:
foo.sigmoid().topk(3,dim=-1)

torch.return_types.topk(
values=tensor([[0.9670, 0.0125, 0.0089],
        [0.8242, 0.0426, 0.0233],
        [0.9425, 0.0108, 0.0101],
        [0.0849, 0.0691, 0.0676],
        [0.9149, 0.0091, 0.0090],
        [0.9079, 0.0596, 0.0369],
        [0.4304, 0.2121, 0.1969],
        [0.3939, 0.0879, 0.0486],
        [0.3369, 0.0516, 0.0472],
        [0.9209, 0.0062, 0.0051],
        [0.9492, 0.0561, 0.0172],
        [0.9530, 0.0130, 0.0116],
        [0.9393, 0.0277, 0.0100],
        [0.9686, 0.0121, 0.0117],
        [0.9416, 0.0247, 0.0216],
        [0.9733, 0.0154, 0.0105],
        [0.8123, 0.0098, 0.0091],
        [0.9109, 0.0275, 0.0201],
        [0.8941, 0.0108, 0.0082],
        [0.1612, 0.0824, 0.0722],
        [0.9070, 0.0127, 0.0102],
        [0.0804, 0.0724, 0.0669],
        [0.1031, 0.0749, 0.0723],
        [0.0739, 0.0652, 0.0603],
        [0.8196, 0.0531, 0.0507],
        [0.7812, 0.0147, 0.0104],
        [0.9646, 0.0166, 0.0090],
        [0.0764, 0.0713, 0.0701],
        [0.1481,

In [73]:
topk = foo.sigmoid().topk(3,dim=-1)

In [74]:
vals = topk[0].detach().numpy()
idx = topk[1].detach().numpy()
vals.shape, idx.shape

((128, 3), (128, 3))

In [75]:
# idx, vals

In [76]:
np.concatenate([vals,idx], axis=-1).shape

(128, 6)

In [77]:
torch.nn.functional.softmax(foo[0], dim=-1)

tensor([9.0485e-06, 3.7186e-05, 6.9245e-06, 2.8890e-05, 7.4366e-05, 5.7888e-06,
        8.1787e-06, 3.6819e-06, 2.0124e-06, 1.6958e-04, 7.2844e-05, 1.1508e-05,
        9.6789e-06, 3.4729e-05, 6.2720e-05, 3.4647e-05, 2.6148e-05, 9.9267e-01,
        1.6048e-05, 2.6983e-05, 1.2380e-05, 3.7459e-06, 1.0328e-05, 5.0724e-06,
        2.9185e-05, 6.7577e-06, 1.2446e-05, 3.9360e-05, 2.8123e-05, 1.8150e-05,
        8.6618e-06, 1.7170e-05, 5.6242e-06, 1.0515e-04, 1.0609e-05, 4.5380e-06,
        2.0474e-05, 4.2801e-04, 6.1276e-05, 2.9585e-04, 2.4093e-04, 3.5689e-05,
        1.5127e-05, 1.1204e-05, 3.0596e-04, 1.2075e-05, 5.6250e-05, 8.1111e-06,
        2.1951e-05, 2.7671e-05, 2.6034e-05, 1.3223e-05, 8.6438e-06, 1.5183e-04,
        4.1561e-06, 1.7744e-04, 7.9636e-06, 2.2835e-04, 5.0331e-06, 9.7821e-05,
        3.3174e-05, 3.5365e-06, 2.1331e-04, 3.9997e-05, 4.8452e-05, 2.6205e-04,
        2.1295e-05, 2.1324e-05, 2.4811e-06, 2.2991e-06, 3.2479e-05, 1.3173e-04,
        4.3004e-05, 3.2632e-05, 9.0304e-

In [78]:
torch.nn.functional.softmax(foo, dim=-1).max(dim=-1)

torch.return_types.max(
values=tensor([0.9927, 0.9290, 0.9921, 0.0576, 0.9865, 0.9688, 0.2474, 0.5941, 0.5422,
        0.9895, 0.9889, 0.9893, 0.9890, 0.9951, 0.9884, 0.9942, 0.9552, 0.9681,
        0.9778, 0.1941, 0.9908, 0.0529, 0.0721, 0.0470, 0.9100, 0.9435, 0.9905,
        0.0486, 0.1141, 0.9882, 0.9043, 0.0693, 0.0725, 0.9771, 0.9836, 0.9916,
        0.9798, 0.0522, 0.1377, 0.1026, 0.9838, 0.9826, 0.9874, 0.9755, 0.9912,
        0.1648, 0.9904, 0.9900, 0.9878, 0.9835, 0.1425, 0.0507, 0.9909, 0.9860,
        0.9867, 0.6258, 0.0498, 0.3531, 0.9758, 0.8579, 0.0607, 0.1185, 0.9905,
        0.1322, 0.0648, 0.9439, 0.9504, 0.9892, 0.0580, 0.2974, 0.8125, 0.9770,
        0.9746, 0.6183, 0.0655, 0.9946, 0.0550, 0.9936, 0.1522, 0.0635, 0.0459,
        0.3631, 0.9908, 0.9904, 0.1558, 0.0470, 0.0575, 0.2643, 0.8787, 0.9872,
        0.7546, 0.9599, 0.9858, 0.7708, 0.9924, 0.9623, 0.9736, 0.9503, 0.0950,
        0.9748, 0.9938, 0.3057, 0.4859, 0.0595, 0.0544, 0.3069, 0.0647, 0.9855,
        0

In [79]:
torch.nn.functional.softmax(foo, dim=-1).argmax(dim=-1)

tensor([ 17, 141, 177,  44, 143,  33,  62,  65, 138,  39,  39, 108,  82, 106,
         14, 176,  17,  43,  39,  40,  65,  44,  44,  40, 115,   2, 107,  44,
         57,  81, 110,  44, 123,  82,  37,  98,  66,  71,  33,  40,  27, 106,
        105, 172, 107,  53, 107,  37,  17, 115,  65,  44,  13,  39,   1,   2,
         44, 157,  71, 154,  44,  65,  50, 156,  44, 109,  65, 105,  44,  39,
         82,  57,  74, 166,  65,  70,  40, 106,  71,  44,  40,  76,  46, 167,
         45,  40,  44, 143,  38,  71,  64,  98,  40,  57,  40, 138,  82,  53,
        107, 106, 166, 143,  35, 106,  44,  39,  44,  40,  82,  40,  55,  65,
        106, 131,  14,  65,  53,  14, 167,  57,  53,  45, 177,  65, 107,  44,
         10,  50])

In [80]:
y.argmax(dim=-1)

tensor([ 17, 141, 177, 120, 143,  33, 100,  65, 138,  39,  39, 108,  82, 106,
         14, 176,  17,  43,  39,  40,  65, 158,  10,  43, 115,   2, 107,  49,
         13,  81, 110, 177,  39,  82,  37,  98,  66,   7,  76,  33,  27, 106,
        105, 172, 107, 142, 107,  37,  17, 115,  65, 177,  13,  39,   1,   2,
         40, 157,  71, 154,   4,  59,  50, 156, 132, 109,  65, 105,  68,  39,
         82,  57,  74, 166,  13,  70, 105, 106,  40,  44, 107,  76,  46, 167,
         65, 144,  53, 164,  38,  71,  64,  98,  40,  57,  40, 138,  82,  53,
         63, 106, 166, 143,  65, 122,  40,  53,  53,  40,  82,  40,  55,  53,
        106, 131,  14,  10,  53,  14, 167,  57,  53,  45, 100, 106,  57, 111,
         10, 132])