In [None]:
!pip install tsai
!pip install geopandas
!pip install geojson
!pip install pytorch_lightning
!pip install neptune-client

In [6]:
!git clone https://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@github.com/yuasosnin/aihacks-2022-fields

Cloning into 'aihacks-2022-fields'...
remote: Enumerating objects: 100, done.[K
remote: Counting objects: 100% (73/73), done.[K
remote: Compressing objects: 100% (53/53), done.[K
remote: Total 100 (delta 29), reused 63 (delta 19), pack-reused 27[K
Receiving objects: 100% (100/100), 59.21 MiB | 16.54 MiB/s, done.
Resolving deltas: 100% (32/32), done.
Checking out files: 100% (44/44), done.


In [1]:
%cd aihacks-2022-fields

/content/aihacks-2022-fields


In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
import os, datetime
from tqdm.notebook import tqdm

import numpy as np
import pandas as pd
import geopandas as gpd
import geojson
# import shapely
# import shapely.geometry
from sklearn.metrics import recall_score

In [2]:
import matplotlib.pyplot as plt
# import contextily

In [3]:
from src import *

In [4]:
import warnings
warnings.filterwarnings('ignore')

In [5]:
data = read_data('data/train_dataset_train_2.csv')
data_test = read_data('data/test_dataset_test_2.csv')
data_ts, data_id = process_data(data)
data_ts_test, data_id_test = process_data(data_test)

In [6]:
data_ts_modis = pd.read_csv('data/train_dataset_modis.csv').fillna(0.0)
data_ts_modis_test = pd.read_csv('data/test_dataset_modis.csv').fillna(0.0)
data_ts_modis_2020 = pd.read_csv('data/train_dataset_modis_2020.csv').fillna(0.0)
data_ts_modis_test_2020 = pd.read_csv('data/test_dataset_modis_2020.csv').fillna(0.0)

In [7]:
data_ts_landsat = pd.read_csv('data/train_dataset_landsat.csv').fillna(0)
data_ts_landsat_test = pd.read_csv('data/test_dataset_landsat.csv').fillna(0)
data_ts_landsat_2020 = pd.read_csv('data/train_dataset_landsat_2020.csv').fillna(0)
data_ts_landsat_test_2020 = pd.read_csv('data/test_dataset_landsat_2020.csv').fillna(0)

In [8]:
data_ts_sentinel = pd.read_csv('data/train_dataset_sentinel.csv').fillna(0)
data_ts_sentinel_test = pd.read_csv('data/test_dataset_sentinel.csv').fillna(0)
data_ts_sentinel_2020 = pd.read_csv('data/train_dataset_sentinel_2020.csv').fillna(0)
data_ts_sentinel_test_2020 = pd.read_csv('data/test_dataset_sentinel_2020.csv').fillna(0)

In [9]:
train_dataframes = [
    data_ts, 
    data_ts_modis,
    data_ts_modis_2020,
    data_ts_landsat,
    data_ts_landsat_2020,
    data_ts_sentinel,
    data_ts_sentinel_2020]
pred_dataframes = [
    data_ts_test, 
    data_ts_modis_test,
    data_ts_modis_test_2020,
    data_ts_landsat_test,
    data_ts_landsat_test_2020,
    data_ts_sentinel_test,
    data_ts_sentinel_test_2020]

# neural

In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import NeptuneLogger
from src.torch_utils.lightning import PrintMetricsCallback

In [22]:
pl.seed_everything(2)

Global seed set to 2


2

In [12]:
with open('.api_key') as f:
    API_KEY = f.read()

In [13]:
from src import StackRNN, StackTransformer, StackInception#, StackDataModule

### Transformer

In [63]:
best_checkpointer = ModelCheckpoint(save_top_k=1, save_last=True, monitor='valid_recall', mode='max', filename='best')
neptune_logger = NeptuneLogger(
    api_key=API_KEY,
    project='fant0md/aihacks-2022-fields')
lr_monitor = LearningRateMonitor(logging_interval='epoch')
printer = PrintMetricsCallback(metrics=['valid_recall', 'train_recall', 'valid_loss', 'train_loss'])

trainer = pl.Trainer(
    log_every_n_steps=1, 
    logger=neptune_logger, 
    callbacks=[best_checkpointer, lr_monitor, printer], 
    max_epochs=100, 
    accelerator='auto',
    devices=1)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [64]:
pl.seed_everything(2)
pl_model = StackTransformer(
    seq_lens=[70, 139, 139, 18, 17, 55, 55],
    d_model=64, 
    nhead=16, 
    dim_feedforward=64, 
    d_head=64, 
    num_layers=4, 
    num_head_layers=1, 
    dropout=0.2, 
    fc_dropout=0.5,
    activation='relu', 
    reduction='avg', 
    lr=0.0001, wd=0, 
    # T_0=5, T_mult=1,
    gamma=0.99
)
pl_data = StackDataModule(train_dataframes, pred_dataframes, data_id['crop'], batch_size=64)

In [65]:
trainer.fit(pl_model, pl_data)

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name         | Type             | Params
--------------------------------------------------
0 | models       | ModuleList       | 2.8 M 
1 | pool         | AvgReduce        | 0     
2 | act          | ReLU             | 0     
3 | head         | MLP              | 4.6 K 
4 | criterion    | CrossEntropyLoss | 0     
5 | train_recall | Recall           | 0     
6 | valid_recall | Recall           | 0     
--------------------------------------------------
2.8 M     Trainable params
0         Non-trainable params
2.8 M     Total params
11.023    Total estimated model params size (MB)


https://app.neptune.ai/fant0md/aihacks-2022-fields/e/AIH-152
Remember to stop your run once you’ve finished logging your metadata (https://docs.neptune.ai/api/run#stop). It will be stopped automatically only when the notebook kernel/interactive console is terminated.


Sanity Checking: 0it [00:00, ?it/s]

epoch: -1
valid_recall: 0.1953125
valid_loss: 1.9390144348144531
--------------------------------------------------------------------------------


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

epoch: 0
valid_recall: 0.5113871693611145
train_recall: 0.6666666865348816
valid_loss: 1.7085974216461182
train_loss: 1.604060173034668
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 1
valid_recall: 0.6832298040390015
train_recall: 0.5833333134651184
valid_loss: 1.3656688928604126
train_loss: 1.412624716758728
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 2
valid_recall: 0.7556935548782349
train_recall: 0.7083333134651184
valid_loss: 1.0799596309661865
train_loss: 1.0194748640060425
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 3
valid_recall: 0.8136646151542664
train_recall: 0.7916666865348816
valid_loss: 0.8388142585754395
train_loss: 0.8575155735015869
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 4
valid_recall: 0.8757764101028442
train_recall: 0.875
valid_loss: 0.6669543981552124
train_loss: 0.6065910458564758
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 5
valid_recall: 0.8964803218841553
train_recall: 0.7083333134651184
valid_loss: 0.5384373664855957
train_loss: 0.8630786538124084
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 6
valid_recall: 0.9233954548835754
train_recall: 0.9166666865348816
valid_loss: 0.4172093868255615
train_loss: 0.41811463236808777
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 7
valid_recall: 0.9275362491607666
train_recall: 0.9583333134651184
valid_loss: 0.35300305485725403
train_loss: 0.2752988934516907
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 8
valid_recall: 0.9378882050514221
train_recall: 0.875
valid_loss: 0.27338600158691406
train_loss: 0.503095805644989
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 9
valid_recall: 0.9440993666648865
train_recall: 0.9583333134651184
valid_loss: 0.22306089103221893
train_loss: 0.2897878885269165
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 10
valid_recall: 0.9523809552192688
train_recall: 0.7916666865348816
valid_loss: 0.20407713949680328
train_loss: 0.33103707432746887
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 11
valid_recall: 0.9523809552192688
train_recall: 0.9583333134651184
valid_loss: 0.17233756184577942
train_loss: 0.2779725193977356
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 12
valid_recall: 0.9585921168327332
train_recall: 1.0
valid_loss: 0.14891596138477325
train_loss: 0.19033145904541016
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 13
valid_recall: 0.95652174949646
train_recall: 1.0
valid_loss: 0.1457735300064087
train_loss: 0.053930412977933884
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 14
valid_recall: 0.9606625437736511
train_recall: 1.0
valid_loss: 0.13015717267990112
train_loss: 0.119759202003479
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 15
valid_recall: 0.9627329111099243
train_recall: 1.0
valid_loss: 0.11774288862943649
train_loss: 0.09124702960252762
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 16
valid_recall: 0.9648033380508423
train_recall: 0.9583333134651184
valid_loss: 0.11183672398328781
train_loss: 0.23574556410312653
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 17
valid_recall: 0.9627329111099243
train_recall: 0.9583333134651184
valid_loss: 0.1058899313211441
train_loss: 0.14039133489131927
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 18
valid_recall: 0.9668737053871155
train_recall: 1.0
valid_loss: 0.09671875089406967
train_loss: 0.033966850489377975
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 19
valid_recall: 0.9648033380508423
train_recall: 0.9583333134651184
valid_loss: 0.09056002646684647
train_loss: 0.16965530812740326
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 20
valid_recall: 0.9648033380508423
train_recall: 1.0
valid_loss: 0.08607729524374008
train_loss: 0.11761762946844101
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 21
valid_recall: 0.9689440727233887
train_recall: 1.0
valid_loss: 0.0855131596326828
train_loss: 0.056194812059402466
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 22
valid_recall: 0.9648033380508423
train_recall: 0.9583333134651184
valid_loss: 0.08775588124990463
train_loss: 0.2036583572626114
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 23
valid_recall: 0.9668737053871155
train_recall: 1.0
valid_loss: 0.08007890731096268
train_loss: 0.07919421046972275
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 24
valid_recall: 0.9668737053871155
train_recall: 0.9583333134651184
valid_loss: 0.07762632519006729
train_loss: 0.12499453872442245
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 25
valid_recall: 0.9751552939414978
train_recall: 1.0
valid_loss: 0.07790300250053406
train_loss: 0.04335656762123108
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 26
valid_recall: 0.9751552939414978
train_recall: 1.0
valid_loss: 0.07378595322370529
train_loss: 0.014678813517093658
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 27
valid_recall: 0.9730848670005798
train_recall: 1.0
valid_loss: 0.07495799660682678
train_loss: 0.007868985645473003
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 28
valid_recall: 0.9730848670005798
train_recall: 1.0
valid_loss: 0.07507506012916565
train_loss: 0.06038275361061096
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 29
valid_recall: 0.9710144996643066
train_recall: 1.0
valid_loss: 0.07212501019239426
train_loss: 0.07776468247175217
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 30
valid_recall: 0.9710144996643066
train_recall: 1.0
valid_loss: 0.06606133282184601
train_loss: 0.047325775027275085
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 31
valid_recall: 0.9751552939414978
train_recall: 1.0
valid_loss: 0.07277316600084305
train_loss: 0.03904421254992485
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 32
valid_recall: 0.9730848670005798
train_recall: 1.0
valid_loss: 0.06531418859958649
train_loss: 0.051141735166311264
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 33
valid_recall: 0.9730848670005798
train_recall: 1.0
valid_loss: 0.06554281711578369
train_loss: 0.05543528124690056
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 34
valid_recall: 0.9710144996643066
train_recall: 0.9583333134651184
valid_loss: 0.07372785359621048
train_loss: 0.08098366111516953
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 35
valid_recall: 0.9730848670005798
train_recall: 0.9583333134651184
valid_loss: 0.06817658990621567
train_loss: 0.10386791080236435
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 36
valid_recall: 0.9751552939414978
train_recall: 0.9583333134651184
valid_loss: 0.06490536779165268
train_loss: 0.08189582824707031
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 37
valid_recall: 0.9730848670005798
train_recall: 1.0
valid_loss: 0.0643659457564354
train_loss: 0.011806427501142025
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 38
valid_recall: 0.9751552939414978
train_recall: 1.0
valid_loss: 0.06285632401704788
train_loss: 0.036192283034324646
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 39
valid_recall: 0.9730848670005798
train_recall: 1.0
valid_loss: 0.06139880791306496
train_loss: 0.025045258924365044
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 40
valid_recall: 0.9751552939414978
train_recall: 1.0
valid_loss: 0.06329711526632309
train_loss: 0.0199909470975399
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 41
valid_recall: 0.9710144996643066
train_recall: 1.0
valid_loss: 0.06511963158845901
train_loss: 0.019670134410262108
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 42
valid_recall: 0.9751552939414978
train_recall: 0.9583333134651184
valid_loss: 0.061410900205373764
train_loss: 0.07648005336523056
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 43
valid_recall: 0.9710144996643066
train_recall: 0.9583333134651184
valid_loss: 0.06465962529182434
train_loss: 0.07109063118696213
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 44
valid_recall: 0.9710144996643066
train_recall: 1.0
valid_loss: 0.06320013850927353
train_loss: 0.013015675358474255
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 45
valid_recall: 0.9730848670005798
train_recall: 1.0
valid_loss: 0.06329954415559769
train_loss: 0.0607793927192688
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 46
valid_recall: 0.9751552939414978
train_recall: 1.0
valid_loss: 0.06002414599061012
train_loss: 0.029872680082917213
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 47
valid_recall: 0.9730848670005798
train_recall: 1.0
valid_loss: 0.06448059529066086
train_loss: 0.009528719820082188
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 48
valid_recall: 0.9751552939414978
train_recall: 1.0
valid_loss: 0.057952918112277985
train_loss: 0.025623993948101997
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 49
valid_recall: 0.977225661277771
train_recall: 1.0
valid_loss: 0.06045970693230629
train_loss: 0.026175307109951973
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 50
valid_recall: 0.9751552939414978
train_recall: 1.0
valid_loss: 0.05891821160912514
train_loss: 0.022055188193917274
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 51
valid_recall: 0.9710144996643066
train_recall: 1.0
valid_loss: 0.06612921506166458
train_loss: 0.03686944767832756
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 52
valid_recall: 0.9730848670005798
train_recall: 1.0
valid_loss: 0.06420598924160004
train_loss: 0.017624082043766975
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 53
valid_recall: 0.9689440727233887
train_recall: 1.0
valid_loss: 0.06335237622261047
train_loss: 0.03824140504002571
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 54
valid_recall: 0.9751552939414978
train_recall: 1.0
valid_loss: 0.06428202986717224
train_loss: 0.012930411845445633
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 55
valid_recall: 0.9710144996643066
train_recall: 1.0
valid_loss: 0.0629815086722374
train_loss: 0.0070348079316318035
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 56
valid_recall: 0.9710144996643066
train_recall: 1.0
valid_loss: 0.06530751287937164
train_loss: 0.013465913943946362
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 57
valid_recall: 0.9710144996643066
train_recall: 1.0
valid_loss: 0.06056903675198555
train_loss: 0.02149958908557892
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 58
valid_recall: 0.9668737053871155
train_recall: 1.0
valid_loss: 0.06385119259357452
train_loss: 0.02766316384077072
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 59
valid_recall: 0.9730848670005798
train_recall: 1.0
valid_loss: 0.060789551585912704
train_loss: 0.013234184123575687
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 60
valid_recall: 0.9689440727233887
train_recall: 1.0
valid_loss: 0.07195231318473816
train_loss: 0.004268682096153498
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 61
valid_recall: 0.9689440727233887
train_recall: 1.0
valid_loss: 0.06679915636777878
train_loss: 0.003261161968111992
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 62
valid_recall: 0.9668737053871155
train_recall: 1.0
valid_loss: 0.07283247262239456
train_loss: 0.006758018862456083
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 63
valid_recall: 0.9751552939414978
train_recall: 1.0
valid_loss: 0.061539653688669205
train_loss: 0.032257843762636185
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 64
valid_recall: 0.9751552939414978
train_recall: 1.0
valid_loss: 0.05778868868947029
train_loss: 0.06136311590671539
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 65
valid_recall: 0.9730848670005798
train_recall: 1.0
valid_loss: 0.06858698278665543
train_loss: 0.01668465882539749
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 66
valid_recall: 0.9710144996643066
train_recall: 1.0
valid_loss: 0.07245229184627533
train_loss: 0.015693241730332375
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 67
valid_recall: 0.9751552939414978
train_recall: 1.0
valid_loss: 0.07142851501703262
train_loss: 0.009799388237297535
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 68
valid_recall: 0.9710144996643066
train_recall: 0.9583333134651184
valid_loss: 0.06202417612075806
train_loss: 0.058558497577905655
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 69
valid_recall: 0.9710144996643066
train_recall: 1.0
valid_loss: 0.06273555755615234
train_loss: 0.005547623615711927
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 70
valid_recall: 0.9710144996643066
train_recall: 1.0
valid_loss: 0.06799265742301941
train_loss: 0.012293371371924877
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 71
valid_recall: 0.9689440727233887
train_recall: 1.0
valid_loss: 0.06679040938615799
train_loss: 0.0010445552179589868
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 72
valid_recall: 0.9730848670005798
train_recall: 1.0
valid_loss: 0.0685654804110527
train_loss: 0.007739942986518145
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 73
valid_recall: 0.9751552939414978
train_recall: 1.0
valid_loss: 0.06115289032459259
train_loss: 0.0023190223146229982
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 74
valid_recall: 0.9751552939414978
train_recall: 1.0
valid_loss: 0.06454069912433624
train_loss: 0.006566172931343317
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 75
valid_recall: 0.977225661277771
train_recall: 1.0
valid_loss: 0.060753028839826584
train_loss: 0.009758570231497288
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 76
valid_recall: 0.9751552939414978
train_recall: 1.0
valid_loss: 0.0687670186161995
train_loss: 0.014846417121589184
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 77
valid_recall: 0.9751552939414978
train_recall: 1.0
valid_loss: 0.06304667145013809
train_loss: 0.0021138531155884266
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 78
valid_recall: 0.977225661277771
train_recall: 1.0
valid_loss: 0.07213591784238815
train_loss: 0.03859402611851692
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 79
valid_recall: 0.9730848670005798
train_recall: 1.0
valid_loss: 0.07586905360221863
train_loss: 0.020041827112436295
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 80
valid_recall: 0.9751552939414978
train_recall: 1.0
valid_loss: 0.0720469281077385
train_loss: 0.03933575749397278
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 81
valid_recall: 0.9751552939414978
train_recall: 1.0
valid_loss: 0.0729513168334961
train_loss: 0.0030328098218888044
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 82
valid_recall: 0.9730848670005798
train_recall: 1.0
valid_loss: 0.06847476214170456
train_loss: 0.004654310178011656
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 83
valid_recall: 0.9710144996643066
train_recall: 1.0
valid_loss: 0.07445622235536575
train_loss: 0.0024280440993607044
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 84
valid_recall: 0.9710144996643066
train_recall: 1.0
valid_loss: 0.07705342024564743
train_loss: 0.01194778922945261
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 85
valid_recall: 0.9751552939414978
train_recall: 1.0
valid_loss: 0.06963016837835312
train_loss: 0.008414790965616703
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 86
valid_recall: 0.9730848670005798
train_recall: 1.0
valid_loss: 0.07223418354988098
train_loss: 0.0003099218592979014
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 87
valid_recall: 0.9730848670005798
train_recall: 1.0
valid_loss: 0.06864862889051437
train_loss: 0.014194001443684101
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 88
valid_recall: 0.977225661277771
train_recall: 1.0
valid_loss: 0.0695272758603096
train_loss: 0.016016865149140358
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 89
valid_recall: 0.9751552939414978
train_recall: 1.0
valid_loss: 0.06022220849990845
train_loss: 0.0008619360160082579
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 90
valid_recall: 0.977225661277771
train_recall: 1.0
valid_loss: 0.06452665477991104
train_loss: 0.012724784202873707
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 91
valid_recall: 0.9730848670005798
train_recall: 1.0
valid_loss: 0.06901643425226212
train_loss: 0.008833418600261211
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 92
valid_recall: 0.9710144996643066
train_recall: 1.0
valid_loss: 0.07644544541835785
train_loss: 0.03145936504006386
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 93
valid_recall: 0.9751552939414978
train_recall: 1.0
valid_loss: 0.06674356758594513
train_loss: 0.029054006561636925
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 94
valid_recall: 0.977225661277771
train_recall: 1.0
valid_loss: 0.05939200893044472
train_loss: 0.005556520540267229
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 95
valid_recall: 0.9730848670005798
train_recall: 1.0
valid_loss: 0.07075627893209457
train_loss: 0.023409275338053703
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 96
valid_recall: 0.9730848670005798
train_recall: 1.0
valid_loss: 0.06687701493501663
train_loss: 0.0011143825249746442
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 97
valid_recall: 0.9730848670005798
train_recall: 1.0
valid_loss: 0.06927713006734848
train_loss: 0.01743386872112751
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 98
valid_recall: 0.9751552939414978
train_recall: 1.0
valid_loss: 0.06133008375763893
train_loss: 0.024144647642970085
--------------------------------------------------------------------------------


Validation: 0it [00:00, ?it/s]

epoch: 99
valid_recall: 0.979296088218689
train_recall: 1.0
valid_loss: 0.064743772149086
train_loss: 0.005552377551794052
--------------------------------------------------------------------------------


INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=100` reached.


In [66]:
preds = torch.cat(trainer.predict(pl_model, pl_data.test_dataloader())).argmax(1)
print(recall_score(preds, dataset_orig['test']['y'], average='macro'))

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 61it [00:00, ?it/s]

0.978189789239679


### CV

In [14]:
from src.torch_utils.lightning import KFoldLoop#, BaseKFoldDataModule

In [15]:
from src import StackKFoldDataModule, EnsembleVotingModel

In [16]:
from src import StackDataModule

In [17]:
pl.seed_everything(2)
pl_model = StackTransformer(
    seq_lens=[70, 139, 139, 18, 17, 55, 55],
    d_model=64, 
    nhead=4, 
    dim_feedforward=64, 
    d_head=64, 
    num_layers=2, 
    num_head_layers=1, 
    dropout=0.2, 
    fc_dropout=0.5,
    activation='relu', 
    reduction='avg', 
    lr=0.0001, wd=0, 
    # T_0=5, T_mult=1,
    gamma=0.99)
pl_data = StackKFoldDataModule(train_dataframes, pred_dataframes, data_id['crop'], batch_size=64)

In [20]:
pl.seed_everything(2)
best_checkpointer = ModelCheckpoint(save_top_k=1, save_last=True, monitor='valid_recall', mode='max', filename='best')
neptune_logger = NeptuneLogger(
    api_key=API_KEY,
    project='fant0md/aihacks-2022-fields',
    log_model_checkpoints=False)
lr_monitor = LearningRateMonitor(logging_interval='epoch')
printer = PrintMetricsCallback(metrics=['valid_recall', 'train_recall', 'valid_loss', 'train_loss'])

trainer = pl.Trainer(
    log_every_n_steps=1, 
    logger=neptune_logger, 
    callbacks=[best_checkpointer, lr_monitor, printer], 
    max_epochs=10,
    # limit_train_batches=2,
    # limit_val_batches=2,
    # limit_test_batches=2,
    # num_sanity_val_steps=0,
    accelerator='auto',
    devices=1)

internal_fit_loop = trainer.fit_loop
trainer.fit_loop = KFoldLoop(ensemble_model=EnsembleVotingModel, num_folds=4, export_path="lightning_logs/")
trainer.fit_loop.connect(internal_fit_loop)
trainer.fit(pl_model, pl_data)

Global seed set to 2
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name         | Type             | Params
--------------------------------------------------
0 | models       | ModuleList       | 2.4 M 
1 | pool         | AvgReduce        | 0     
2 | act          | ReLU             | 0     
3 | head         | MLP              | 4.6 K 
4 | criterion    | CrossEntropyLoss | 0     
5 | train_recall | Recall           | 0     
6 | valid_recall | Recall           | 0     
7 | test_recall  | Recall           | 0     
--------------------------------------------------
2.4 M     Trainable params
0         Non-trainable params
2.4 M     Total params
9.625     Total estimated model params size (MB)


https://app.neptune.ai/fant0md/aihacks-2022-fields/e/AIH-156


Info (NVML): NVML Shared Library Not Found. GPU usage metrics may not be reported. For more information, see https://docs.neptune.ai/you-should-know/what-can-you-log-and-display#hardware-consumption


Remember to stop your run once you’ve finished logging your metadata (https://docs.neptune.ai/api-reference/run#.stop). It will be stopped automatically only when the notebook kernel/interactive console is terminated.


Sanity Checking: 0it [00:00, ?it/s]

epoch: -1
valid_recall: 0.484375
valid_loss: 1.7639729976654053
--------------------------------------------------------------------------------
<class 'src.dataset.StackKFoldDataModule'>
STARTING FOLD 0


Training: 0it [00:00, ?it/s]

## submission

In [None]:
# TODO
# initizlize paramters
# run maaaany epochs with dropout
# flatten instead of pooling
# cosine lr scheduler

# try early stopping
# try to include geographical information
# im not gonna lose!

In [None]:
pl_model = StackTransformer.load_from_checkpoint('.neptune/None/version_None/checkpoints/best-v7.ckpt')

In [67]:
preds = trainer.predict(pl_model, pl_data)

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 61it [00:00, ?it/s]

In [68]:
submission = pd.read_csv('sample_solution.csv')
submission['crop'] = torch.cat(preds).argmax(1)

In [69]:
submission.to_csv('submission_v864.csv', index=False)

In [None]:
pl_model

In [None]:
# valid: 0.982401
# lb: 0.980064

In [None]:
# valid: 0.983436
# lb: 0.978759