# Run in Google Colab

In [None]:
# Colab-only setup: install the third-party packages this notebook needs.
# `%pip` (rather than `!pip`) installs into the environment of the running kernel.
# NOTE(review): versions are not pinned. The custom fit-loop code further down
# (`trainer.fit_loop = ...`, `.connect(...)`) presumably needs a specific
# pytorch_lightning release line - pin the versions once they are confirmed.
%pip install tsai
%pip install geopandas
%pip install geojson
%pip install pytorch_lightning
%pip install neptune-client

In [None]:
# SECURITY: a GitHub personal access token used to be hardcoded in this URL.
# A token that has been saved in a notebook must be treated as leaked and revoked.
# Ask for the token at run time instead, so it never ends up in the notebook source.
import os
from getpass import getpass

os.environ['GITHUB_TOKEN'] = getpass('GitHub personal access token: ')
# `$GITHUB_TOKEN` is expanded by the shell from the environment variable set above.
# Note that git stores the clone URL (including the token) in the clone's .git/config.
!git clone https://$GITHUB_TOKEN@github.com/yuasosnin/aihacks-2022-fields

# Imports, data and setup

In [None]:
# Make the cloned repository the working directory so that the relative paths used
# below ('data/...', 'api_key', 'sample_solution.csv') resolve against it.
# Presumably this is also what makes `from src import ...` importable - confirm.
%cd aihacks-2022-fields

In [None]:
import numpy as np
import pandas as pd
import geopandas as gpd

from src import process_data, read_data

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import NeptuneLogger

from src import EnsembleVotingModel, StackKFoldDataModule, StackTransformer
from src.torch_utils.lightning import KFoldLoop, PrintMetricsCallback

In [None]:
pl.seed_everything(5)
# Read the API key that is later passed to NeptuneLogger from the `api_key` file
# in the current working directory.
with open('api_key') as f:
    # strip(): a trailing newline (or other surrounding whitespace) in the file
    # would otherwise become part of the key.
    API_KEY = f.read().strip()

In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
# Boolean row mask of length 4830 that keeps every row except index 2932
# (the bad Perm Krai outlier). It is applied to the training tables when they are
# loaded, so its length has to match their row count.
filt = np.ones(4830, dtype=bool)
filt[2932] = False

In [None]:
# Load the raw train/test tables; only the training table is row-masked with `filt`.
data = read_data('data/train_dataset_train_2.csv')[filt]
data_test = read_data('data/test_dataset_test_2.csv')
# process_data returns a pair, unpacked here into a time-series part and an
# id/metadata part (judging by the names; the helper lives in `src` - TODO confirm).
data_ts, data_id = process_data(data)
# Altitude comes from a separate CSV and is attached positionally as a plain list.
# NOTE(review): this CSV is NOT masked with `filt`, so it presumably already
# excludes the dropped row - confirm that the row counts match.
data_id['alt'] = pd.read_csv('data/altitude_train.csv')['alt'].tolist()
data_ts_test, data_id_test = process_data(data_test)
data_id_test['alt'] = pd.read_csv('data/altitude_test.csv')['alt'].tolist()

In [None]:
# Tables from the *_modis CSVs (plain and *_evi variants) for train and test.
# Missing values are replaced with 0.0; only the training tables are masked with `filt`.
data_ts_modis = pd.read_csv('data/train_dataset_modis.csv').fillna(0.0)[filt]
data_ts_modis_test = pd.read_csv('data/test_dataset_modis.csv').fillna(0.0)
data_ts_modis_evi = pd.read_csv('data/train_dataset_modis_evi.csv').fillna(0.0)[filt]
data_ts_modis_test_evi = pd.read_csv('data/test_dataset_modis_evi.csv').fillna(0.0)

In [None]:
# Tables from the *_landsat CSVs (plain and *_evi variants) for train and test.
# Missing values are replaced with 0.0; only the training tables are masked with `filt`.
data_ts_landsat = pd.read_csv('data/train_dataset_landsat.csv').fillna(0.0)[filt]
data_ts_landsat_test = pd.read_csv('data/test_dataset_landsat.csv').fillna(0.0)
data_ts_landsat_evi = pd.read_csv('data/train_dataset_landsat_evi.csv').fillna(0.0)[filt]
data_ts_landsat_test_evi = pd.read_csv('data/test_dataset_landsat_evi.csv').fillna(0.0)

In [None]:
# Tables from the *_sentinel CSVs (plain and *_evi variants) for train and test.
# Missing values are replaced with 0 (integer literal here, 0.0 in the sibling cells);
# only the training tables are masked with `filt`.
data_ts_sentinel = pd.read_csv('data/train_dataset_sentinel.csv').fillna(0)[filt]
data_ts_sentinel_test = pd.read_csv('data/test_dataset_sentinel.csv').fillna(0)
data_ts_sentinel_evi = pd.read_csv('data/train_dataset_sentinel_evi.csv').fillna(0)[filt]
data_ts_sentinel_test_evi = pd.read_csv('data/test_dataset_sentinel_evi.csv').fillna(0)

In [None]:
def tensor_stack(*dfs):
    """Join arrays along a newly inserted axis 1 and return a float32 tensor.

    Every positional argument is indexed as ``arr[:, None, :]`` (a new axis is
    inserted at position 1) and the results are concatenated along that axis.
    For ``k`` two-dimensional inputs of shape ``(n, t)`` the result therefore
    has shape ``(n, k, t)``.

    Args:
        *dfs: array-likes supporting ``[:, None, :]`` indexing (NumPy arrays in
            this notebook, passed as ``DataFrame.values``).

    Returns:
        torch.Tensor with dtype ``torch.float32``.
    """
    expanded = []
    for arr in dfs:
        # add the axis that the inputs are joined along
        expanded.append(arr[:, np.newaxis, :])
    stacked = np.concatenate(expanded, axis=1)
    return torch.tensor(stacked, dtype=torch.float32)

In [None]:
# Turn every table into a float32 tensor. tensor_stack inserts a new axis 1 and
# joins its arguments along it; since one table is passed per call, that axis
# has size 1 for each tensor below.
ds = tensor_stack(data_ts.values)
ds_test = tensor_stack(data_ts_test.values)

# The *_evi tables are loaded above but not stacked here: the second argument of
# each call is commented out.
ds_modis = tensor_stack(data_ts_modis.values)#, data_ts_modis_evi.values)
ds_modis_test = tensor_stack(data_ts_modis_test.values)#, data_ts_modis_test_evi.values)

ds_landsat = tensor_stack(data_ts_landsat.values)#, data_ts_landsat_evi.values)
ds_landsat_test = tensor_stack(data_ts_landsat_test.values)#, data_ts_landsat_test_evi.values)

ds_sentinel = tensor_stack(data_ts_sentinel.values)#, data_ts_sentinel_evi.values)
ds_sentinel_test = tensor_stack(data_ts_sentinel_test.values)#, data_ts_sentinel_test_evi.values)

# Per-row constant features: the four columns area, lat, lon, alt as float32.
ds_const = torch.tensor(data_id[['area', 'lat', 'lon', 'alt']].values, dtype=torch.float32)
ds_const_test = torch.tensor(data_id_test[['area', 'lat', 'lon', 'alt']].values, dtype=torch.float32)

# Training target: the 'crop' column as int64 (torch.long).
ds_y = torch.tensor(data_id['crop'].values, dtype=torch.long)

In [None]:
# Both datasets hold the four series tensors followed by the constant features;
# the training dataset additionally carries the labels as its last tensor.
# The shape bookkeeping in the next cell relies on the constant features being the
# LAST tensor of pred_dataset.
train_dataset = TensorDataset(ds, ds_modis, ds_landsat, ds_sentinel, ds_const, ds_y)
pred_dataset = TensorDataset(ds_test, ds_modis_test, ds_landsat_test, ds_sentinel_test, ds_const_test)

In [None]:
# Derive the model's input sizes from the prediction tensors.
# All tensors but the last are the series inputs: dim 1 of each goes to c_ins and
# dim 2 to seq_lens. The last tensor holds the constant features; its dim 1 is
# c_const_in.
c_ins = [t.shape[1] for t in pred_dataset.tensors[:-1]]
c_const_in = pred_dataset.tensors[-1].shape[1]
seq_lens = [t.shape[2] for t in pred_dataset.tensors[:-1]]

# CV training

In [None]:
# Re-seed right before building the model (presumably so that weight initialisation
# does not depend on which cells ran earlier - confirm).
pl.seed_everything(5)
# Build the project's StackTransformer. c_ins / seq_lens / c_in_const come from the
# shape bookkeeping cell; the remaining values are hand-chosen hyperparameters.
# NOTE(review): the keyword `num_const_leayers` looks misspelled; it presumably
# matches the spelling in StackTransformer's signature - confirm before renaming.
pl_model = StackTransformer(
    c_ins=c_ins,
    seq_lens=seq_lens,
    d_model=64,
    nhead=16,
    dim_feedforward=64,
    d_head=64,
    num_layers=4,
    num_head_layers=2,
    dropout=0.2,
    fc_dropout=0.5,
    activation=nn.GELU,
    reduction='avg',
    const=True,
    c_in_const=c_const_in,
    num_const_leayers=2,
    lr=0.0001,
    wd=0.00001,
    gamma=0.99)

In [None]:
# Wrap the train and prediction datasets in the project's k-fold data module.
# The seed passed here is the same value (5) used for pl.seed_everything.
pl_data = StackKFoldDataModule(
    train_dataset=train_dataset, 
    pred_dataset=pred_dataset,
    const=True,
    batch_size=64,
    seed=5)

In [None]:
# Checkpointing: keep the single best checkpoint by maximum 'valid_recall'
# (file name 'best') and additionally save the last one.
best_checkpointer = ModelCheckpoint(
    save_top_k=1, save_last=True, monitor='valid_recall', mode='max', filename='best')
# Experiment logging to Neptune with the key read from the `api_key` file;
# uploading of model checkpoints is switched off.
neptune_logger = NeptuneLogger(
    api_key=API_KEY, project='fant0md/aihacks-2022-fields', log_model_checkpoints=False)
# Log the learning rate once per epoch.
lr_monitor = LearningRateMonitor(logging_interval='epoch')
# Project callback, configured with the four metric names it should handle.
printer = PrintMetricsCallback(
    metrics=['valid_recall', 'train_recall', 'valid_loss', 'train_loss'])

trainer = pl.Trainer(
    log_every_n_steps=1,
    logger=neptune_logger,
    callbacks=[best_checkpointer, lr_monitor, printer],
    max_epochs=100,
    accelerator='auto',
    devices=1)

# Replace the trainer's fit loop with the project's KFoldLoop (8 folds, using the
# 'last' checkpoints) and connect the original loop to it.
# NOTE(review): assigning `trainer.fit_loop` and calling `.connect(...)` presumably
# depends on a specific pytorch_lightning loop API - confirm the installed version.
internal_fit_loop = trainer.fit_loop
trainer.fit_loop = KFoldLoop(
    ensemble_model=EnsembleVotingModel, num_folds=8, checkpoint_type='last')
trainer.fit_loop.connect(internal_fit_loop)

In [None]:
# Start training. The trainer's fit loop was replaced by KFoldLoop above, so this
# presumably trains one model per fold - see src.torch_utils.lightning to confirm.
trainer.fit(pl_model, pl_data)

# Submission

In [None]:
# TODO
# run many more epochs with dropout
# I'm not gonna lose!

In [None]:
# Take the checkpoint paths recorded by the KFoldLoop and build the ensemble from
# the StackTransformer class plus those paths, then freeze it for inference.
ckpt_paths = trainer.fit_loop.checkpoint_paths
infer_model = EnsembleVotingModel(StackTransformer, ckpt_paths)
infer_model.freeze()

In [None]:
# ckpt_paths = [x.replace('last', 'best') for x in ckpt_paths]
# infer_model = EnsembleVotingModel(StackTransformer, ckpt_paths)
# infer_model.freeze()
# trainer.test(infer_model, pl_data.test_dataloader())

In [None]:
# Run the ensemble over the data module's prediction dataloader.
# `preds` is later passed to torch.cat, i.e. it is expected to be a sequence of tensors.
preds = trainer.predict(infer_model, pl_data.predict_dataloader())

In [None]:
# Start from the sample submission file and overwrite its 'crop' column.
submission = pd.read_csv('sample_solution.csv')
# Concatenate the per-batch predictions and take the argmax over dim 1 to get one
# class index per row. Convert to a NumPy array explicitly instead of relying on
# pandas' implicit handling of torch tensors, so the column holds plain integers.
submission['crop'] = torch.cat(preds).argmax(1).cpu().numpy()

In [None]:
# Tag that is embedded in the names of the output files written below.
version = 'act'

In [None]:
# Write the submission to the current working directory, without the index column.
submission.to_csv(f'submission_{version}.csv', index=False)

In [None]:
import shutil
# Zip the checkpoint directory into checkpoints_<version>.zip in the working directory.
# NOTE(review): the source directory is a hardcoded absolute Colab path
# ('/content/...'); it presumably is where this run's checkpoints were written -
# confirm, and adjust when running outside Colab.
shutil.make_archive(
    f'checkpoints_{version}', 'zip',
    '/content/aihacks-2022-fields/.neptune/None/version_None/checkpoints')