In [1]:
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import xarray as xr
from dataset import ClimateHackDataset
from basic_model import BasicModel
from loss import MS_SSIMLoss
import os

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

Setup dataset and dataloader

In [36]:
DATASET_PATH = "eumetsat_seviri_hrv_uk.zarr"
#DATASET_PATH = "gs://public-datasets-eumetsat-solar-forecasting/satellite/EUMETSAT/SEVIRI_RSS/v3/eumetsat_seviri_hrv_uk.zarr"

dataset = xr.open_dataset(
    DATASET_PATH,
    engine='zarr',
    chunks='auto'
)

Training Setup

In [4]:
BATCH_LEN = 16
EPOCHS = 10

In [32]:
cl_dataset = ClimateHackDataset(dataset, crops_per_slice=1, day_limit=7)
cl_loader = DataLoader(cl_dataset, batch_size=BATCH_LEN)
len(list(cl_dataset))

0

Create the model

In [6]:
model = BasicModel()
model.to(device)

BasicModel(
  (layer0): Linear(in_features=196608, out_features=1024, bias=True)
  (layer1): Linear(in_features=1024, out_features=1024, bias=True)
  (layer2): Linear(in_features=1024, out_features=98304, bias=True)
  (relu): ReLU()
  (softmax): Softmax(dim=-1)
)

Training utilities
# wow

In [7]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = MS_SSIMLoss(channels=24)

Training loop definition

In [8]:
dir_to_save = 'basic_models'
if not os.path.isdir(dir_to_save):
    os.makedirs(dir_to_save)

In [16]:
def train_epoch(dataloader, model, optimizer, criterion, losses, epoch):
    its_til_checkpoint = 10
    
    running_loss = 0
    count = 0
    iterable = iter(dataloader)
    for i, (coords, features, targets) in enumerate(iterable):
        features = features.to(device)
        targets = targets.to(device)

        predictions = model(features)

        loss = criterion(predictions, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        curr_len = features.shape[0]
        running_loss += curr_len * loss.item()
        count += curr_len

        if i % its_til_checkpoint == its_til_checkpoint-1:
            curr_loss = running_loss / count
            print(f'Current loss after {its_til_checkpoint} iterations: {curr_loss}')
            losses.append(curr_loss)

            plt.plot(losses, 'b-')
            plt.title("Loss over iterations")
            plt.show()

            file_name = f'epoch-{epoch}_iteration-{i}'
            torch.save(model.state_dict(), open(os.path.join(dir_to_save, file_name), 'w+'))

            running_loss = 0
            count = 0
    
    losses.append(running_loss / count)

In [21]:
losses = []
train_epoch(cl_loader, model, optimizer, criterion, losses, 0)

ZeroDivisionError: division by zero

In [None]:
def train(epochs, dataloader, model, optimizer, criterion):
    losses = []
    for epoch in range(epochs):
        print(f'Entering epoch {epoch}')
        train_epoch(dataloader, model, optimizer, criterion)
    
    return losses

In [None]:
train(EPOCHS, cl_loader, model, optimizer, criterion)