In [4]:
import gc
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from sklearn import model_selection

import torch
import torch.nn as nn
import torch.nn.functional as F

import pytorch_lightning as pl

from numpy import genfromtxt
import datatable as dt
import pandas as pd

In [8]:
inputDate='2019-01-03 01:00:00'

In [14]:
def data_load(inputDate,variable=['temp','prec'],path=r'/hdd/temp/',nrow=1159,ncol=1505):
    df=np.empty((nrow,ncol,len(variable)),dtype=float)
    for k,var in enumerate(variable):
        data=dt.fread(f'{path}/{var}{str(inputDate)}.csv')
        df[:,:,k]=data.to_numpy().reshape(nrow,ncol)[::-1,::]
    return df

In [18]:
def data_gen(inputDate,shape,DateRange=pd.to_timedelta(1,unit='day')):
    generateDate=[str(i) for i in \
                  pd.date_range(start=pd.to_datetime(inputDate)-\
                                DateRange,end=inputDate , freq='H')]
    df=np.empty(shape,dtype=float)
    for l, td in enumerate(generateDate):
        df[l,:,:,:]=data_load(inputDate)
    return df

In [22]:
DateRange=pd.to_timedelta(1,unit='day')
generateDate=[str(i) for i in pd.date_range(start=pd.to_datetime(inputDate)-DateRange,end=inputDate , freq='H')]

In [26]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, dates, test=False):
        self.test = test
#         self.dates=dates
        self.dataDate=pd.DataFrame(dates)
    def __len__(self):
        return len(self.dataDate)

    def __getitem__(self, idx,nrow=1159,ncol=1505,variable=['temp','prec']):
        date = self.dataDate.T.to_dict('index')[0][idx]
        data = data_gen(date,shape=(25,1159,1505,len(variable)))
        x = data[:-1,:,:,:]
        x = x.astype(np.float32)
        x = torch.tensor(x, dtype=torch.float)
        if self.test:
            return x
        else:
            y = df_[-1:,:,:,:]
            y = y.astype(np.float32)
            y = torch.tensor(y, dtype=torch.float)
            y = y.unsqueeze(-1)
            return x, y
        
    def data_load(inputDate,path=r'/hdd/temp/',variable=['temp','prec'],nrow=1159,ncol=1505):
        df=np.empty((nrow,ncol,len(variable)),dtype=float)
        for k,var in enumerate(variable):
            print(inputDate)
            data=dt.fread(f'{path}/{var}{str(inputDate)}.csv')
            df[:,:,k]=data.to_numpy().reshape(nrow,ncol)[::-1,::]
        return df
    
    def data_gen(inputDate,shape,DateRange=pd.to_timedelta(1,unit='day')):
        generateDate=[str(i) for i in pd.date_range(start=pd.to_datetime(inputDate)-DateRange,end=inputDate , freq='H')]
        df=np.empty(shape,dtype=float)
        for l, td in enumerate(generateDate):
            df[l,:,:,:]=data_load(inputDate)
        return df

In [30]:
class DataModule(pl.LightningDataModule):
    def __init__(self, batch_size, test=False, num_workers=4):
        super().__init__()
        self.test = test
        self.batch_size = batch_size
        self.num_workers = 4

    def setup(self,generateDate, stage="train"):
        if stage == "train":
            train_paths, val_paths = model_selection.train_test_split(
                generateDate, test_size=0.1, shuffle=True
            )
            self.train_dataset = Dataset(train_paths)
            self.val_dataset = Dataset(val_paths)
        else:
            self.test_dataset = Dataset(generateDate, test=True)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=2 * self.batch_size,
            pin_memory=True,
            num_workers=self.num_workers,
        )

    def test_dataloader(self):
        return torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=2 * self.batch_size,
            pin_memory=True,
            num_workers=self.num_workers,
        )

In [34]:
class Block(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.net(x)

In [38]:
class Encoder(nn.Module):
    def __init__(self, chs=[4, 64, 128]):
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.blocks = nn.ModuleList(
            [Block(chs[i], chs[i + 1]) for i in range(len(chs) - 1)]
        )
        self.conv = nn.Conv2d(128, 512, kernel_size=3, padding=1)

    def forward(self, x):
        ftrs = []
        for block in self.blocks:
            x = block(x)
            ftrs.append(x)
            x = self.pool(x)
        x = self.conv(x)
        ftrs.append(x)
        return ftrs

In [42]:
class Decoder(nn.Module):
    def __init__(self, chs=[512, 128, 64]):
        super().__init__()
        self.tr_convs = nn.ModuleList(
            [
                nn.ConvTranspose2d(chs[i], chs[i + 1], kernel_size=2, stride=2)
                for i in range(len(chs) - 1)
            ]
        )
        self.blocks = nn.ModuleList(
            [Block(2 * chs[i + 1], chs[i + 1]) for i in range(len(chs) - 1)]
        )

    def forward(self, x, ftrs):
        for i, ftr in enumerate(ftrs):
            x = self.tr_convs[i](x)
            x = torch.cat([ftr, x], dim=1)
            x = self.blocks[i](x)
        return x

In [46]:
class Baseline(pl.LightningModule):
    def __init__(self, lr=1e-3, enc_chs=[4, 64, 128], dec_chs=[512, 128, 64]):
        super().__init__()
        self.lr = lr
        self.criterion = nn.L1Loss()
        self.encoder = Encoder(enc_chs)
        self.decoder = Decoder(dec_chs)
        self.out = nn.Sequential(
            nn.Conv2d(64, 1, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        ftrs = self.encoder(x)
        ftrs = ftrs[::-1]
        x = self.decoder(ftrs[0], ftrs[1:])
        out = self.out(x)
        return out

    def shared_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self.shared_step(batch, batch_idx)
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self.shared_step(batch, batch_idx)
        self.log("val_loss", loss)
        return {"loss": loss}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
        print(f"Epoch {self.current_epoch} | MAE: {avg_loss}")

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

In [50]:
datamodule = DataModule(batch_size=32)
datamodule.setup(generateDate)

In [54]:
model = Baseline()

In [58]:
trainer = pl.Trainer(
    gpus=1, max_epochs=10, precision=16, progress_bar_refresh_rate=50, benchmark=True
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.


In [65]:
trainer.fit(model, datamodule)


  | Name      | Type       | Params
-----------------------------------------
0 | criterion | L1Loss     | 0     
1 | encoder   | Encoder    | 666 K 
2 | decoder   | Decoder    | 664 K 
3 | out       | Sequential | 577   


Validation sanity check: 0it [00:00, ?it/s]

ParserError: Caught ParserError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/root/anaconda3/envs/jupyter/lib/python3.6/site-packages/pandas/core/arrays/datetimes.py", line 1979, in objects_to_datetime64ns
    values, tz_parsed = conversion.datetime_to_datetime64(data)
  File "pandas/_libs/tslibs/conversion.pyx", line 200, in pandas._libs.tslibs.conversion.datetime_to_datetime64
TypeError: Unrecognized value type: <class 'str'>

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/root/anaconda3/envs/jupyter/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 198, in _worker_loop
    data = fetcher.fetch(index)
  File "/root/anaconda3/envs/jupyter/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/root/anaconda3/envs/jupyter/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/tmp/xpython_4025623/3210728914.py", line 11, in __getitem__
    data = data_gen(date,shape=(25,1159,1505,len(variable)))
  File "/tmp/xpython_4025623/464845066.py", line 3, in data_gen
    pd.date_range(start=pd.to_datetime(inputDate)-\
  File "/root/anaconda3/envs/jupyter/lib/python3.6/site-packages/pandas/util/_decorators.py", line 208, in wrapper
    return func(*args, **kwargs)
  File "/root/anaconda3/envs/jupyter/lib/python3.6/site-packages/pandas/core/tools/datetimes.py", line 796, in to_datetime
    result = convert_listlike(np.array([arg]), box, format)[0]
  File "/root/anaconda3/envs/jupyter/lib/python3.6/site-packages/pandas/core/tools/datetimes.py", line 463, in _convert_listlike_datetimes
    allow_object=True,
  File "/root/anaconda3/envs/jupyter/lib/python3.6/site-packages/pandas/core/arrays/datetimes.py", line 1984, in objects_to_datetime64ns
    raise e
  File "/root/anaconda3/envs/jupyter/lib/python3.6/site-packages/pandas/core/arrays/datetimes.py", line 1975, in objects_to_datetime64ns
    require_iso8601=require_iso8601,
  File "pandas/_libs/tslib.pyx", line 465, in pandas._libs.tslib.array_to_datetime
  File "pandas/_libs/tslib.pyx", line 688, in pandas._libs.tslib.array_to_datetime
  File "pandas/_libs/tslib.pyx", line 822, in pandas._libs.tslib.array_to_datetime_object
  File "pandas/_libs/tslib.pyx", line 813, in pandas._libs.tslib.array_to_datetime_object
  File "pandas/_libs/tslibs/parsing.pyx", line 225, in pandas._libs.tslibs.parsing.parse_datetime_string
  File "/root/anaconda3/envs/jupyter/lib/python3.6/site-packages/dateutil/parser/_parser.py", line 1374, in parse
    return DEFAULTPARSER.parse(timestr, **kwargs)
  File "/root/anaconda3/envs/jupyter/lib/python3.6/site-packages/dateutil/parser/_parser.py", line 649, in parse
    raise ParserError("Unknown string format: %s", timestr)
dateutil.parser._parser.ParserError: Unknown string format: i
