# Model

In [2]:
import torch
from torch import nn, optim
import matplotlib.pyplot as plt
import os
import rasterio
import pickle

# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f"Training on {device}.")

### Import data

In [3]:
def import_training_imgs(train_folder):
    """Read every ``.tif`` file found directly in ``train_folder`` into memory.

    Args:
        train_folder: Path of the directory to scan (sub-directories are not
            visited). A trailing path separator is optional.

    Returns:
        dict mapping each file name ending in ``.tif`` to the value returned
        by ``src.read()`` for the dataset opened from that file. Keys are
        inserted in sorted file-name order.
    """
    train_data = {}

    # os.listdir reports entries in arbitrary order; sorting makes the dict's
    # insertion order (and anything that slices "the first N" images) stable
    # between runs and machines.
    for file in sorted(os.listdir(train_folder)):
        if file.endswith('.tif'):
            # os.path.join rather than `train_folder + file`: concatenation
            # silently builds a wrong path when the folder is given without a
            # trailing separator.
            with rasterio.open(os.path.join(train_folder, file)) as src:
                train_data[file] = src.read()
    return train_data

# Cache the decoded images so later runs can skip re-reading every GeoTIFF.
# NOTE(review): pickle.load can execute arbitrary code from the file it reads;
# keep this only as long as the cache file is produced locally and trusted.
if os.path.exists('loaded/train_images.pkl'):
    # `with` closes the handle; a bare open(...) passed inline never did.
    with open('loaded/train_images.pkl', 'rb') as cache_file:
        train_imgs = pickle.load(cache_file)
else:
    train_imgs = import_training_imgs('data/train_images/')
    # The cache directory may not exist yet on a fresh checkout.
    os.makedirs('loaded', exist_ok=True)
    with open('loaded/train_images.pkl', 'wb') as cache_file:
        pickle.dump(train_imgs, cache_file)

### Training function

In [4]:
def train(n_epochs, optimizer, model, loss_fn, train_loader, val_loader=None):
    """Train ``model`` for ``n_epochs`` epochs and record the mean loss per epoch.

    The model and every batch are moved to the module-level ``device``.
    Progress is printed for the first epoch and for every fifth epoch.

    Args:
        n_epochs: Number of passes over ``train_loader``.
        optimizer: Object whose ``step()`` and ``zero_grad()`` are called once
            per training batch.
        model: ``nn.Module`` to train.
        loss_fn: Callable invoked as ``loss_fn(outputs, labels)``; its result
            must support ``.backward()`` and ``.item()``.
        train_loader: Sized iterable yielding ``(imgs, labels)`` batches.
        val_loader: Optional sized iterable yielding ``(imgs, labels)``
            batches. When ``None`` or empty, validation is skipped.

    Returns:
        Tuple ``(losses_train, losses_val)`` of lists holding one mean batch
        loss per epoch. ``losses_val`` is empty when validation was skipped.

    Raises:
        ValueError: If ``train_loader`` contains no batches.
    """
    n_train_batch = len(train_loader)
    if n_train_batch == 0:
        # Fail early with a clear message instead of a ZeroDivisionError
        # when the epoch mean is computed.
        raise ValueError('train_loader yields no batches')
    losses_train = []

    # A missing or empty validation loader simply disables validation.
    n_val_batch = len(val_loader) if val_loader is not None else 0
    run_validation = n_val_batch > 0
    losses_val = []

    optimizer.zero_grad()
    model = model.to(device)

    for epoch in range(1, n_epochs + 1):
        loss_train = 0.0
        loss_val = 0.0

        # Switch mode once per phase rather than once per batch.
        model.train()
        for imgs, labels in train_loader:
            imgs = imgs.to(device)
            labels = labels.to(device)

            outputs = model(imgs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            loss_train += loss.item()
        losses_train.append(loss_train / n_train_batch)

        if run_validation:
            model.eval()
            # No gradients are needed while scoring the validation set.
            with torch.no_grad():
                for imgs, labels in val_loader:
                    imgs = imgs.to(device)
                    labels = labels.to(device)

                    outputs = model(imgs)
                    loss = loss_fn(outputs, labels)
                    loss_val += loss.item()
            losses_val.append(loss_val / n_val_batch)

        if epoch == 1 or epoch % 5 == 0:
            print(f'--------- Epoch: {epoch} ---------')
            print('Training loss {:.5f}'.format(loss_train / n_train_batch))
            if run_validation:
                print('Validation loss {:.5f}'.format(loss_val / n_val_batch))
            print()

    return losses_train, losses_val

### A simple convolutional network

In [4]:
class SimpleConvNet(nn.Module):
    """Two convolution / max-pool stages followed by a dense sigmoid head.

    For an input of shape ``(batch, in_channels, image_size, image_size)`` the
    forward pass returns a ``(batch, image_size * image_size)`` tensor of
    sigmoid activations.

    Args:
        in_channels: Number of channels of the input image (default 12).
        image_size: Height and width of the square input (default 1024). It is
            halved by each of the two pooling layers, so it must be a positive
            multiple of 4.

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 4.
    """

    #Architecture
    def __init__(self, in_channels=12, image_size=1024):
        super().__init__()
        if image_size <= 0 or image_size % 4 != 0:
            raise ValueError('image_size must be a positive multiple of 4')
        # Spatial size left after the two 2x2 max-pool layers.
        pooled_size = image_size // 4

        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=6, kernel_size=(5, 5), padding=2)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=1, kernel_size=(3, 3), padding=1)
        self.pool2 = nn.MaxPool2d(2)
        self.flat = nn.Flatten()
        # NOTE(review): with the default image_size=1024 this layer maps
        # 256*256 inputs to 1024*1024 outputs, i.e. about 6.9e10 weights.
        # That will not fit in memory on ordinary hardware; pass a smaller
        # image_size or replace this dense head before training at full size.
        self.fc = nn.Linear(in_features=pooled_size * pooled_size,
                            out_features=image_size * image_size)
        self.a = nn.Sigmoid()

    #Forward Pass:
    def forward(self, x):
        out = self.conv1(x)
        # Was `self.pool`, an attribute that is never defined.
        out = torch.relu(self.pool1(out))
        # pool2 was declared but never applied; without it the flattened size
        # does not match fc.in_features.
        out = self.pool2(self.conv2(out))
        out = self.flat(out)
        # No ReLU here: sigmoid(relu(z)) can never drop below 0.5, which would
        # make low outputs unreachable.
        out = self.fc(out)
        return self.a(out)

In [8]:
# Instantiate the network with its default layer sizes.
model = SimpleConvNet()

In [None]:
# Smoke test: push one cached training image through the untrained network.
image = train_imgs['image_1.tif']
print(image.shape)

# `image` is whatever rasterio's read() returned (a NumPy array per its
# documentation), so convert it to a float tensor and add a leading batch
# axis before calling the module. no_grad avoids building an autograd graph
# for a throw-away forward pass.
with torch.no_grad():
    prediction = model(torch.as_tensor(image, dtype=torch.float32).unsqueeze(0))

# Show only the shape; the full output tensor would flood the cell.
prediction.shape

### Hyperparameters and training

In [8]:
# Define hyperparameters
n_epochs = 10
lr = 1e-2
model = SimpleConvNet()
optimizer = optim.SGD(model.parameters(), lr=lr)
# NOTE(review): placeholder. train() invokes loss_fn(outputs, labels), so a
# real loss callable has to be assigned here before training can run.
loss_fn = None
# NOTE(review): this is a plain dict holding the first 10 cached images keyed
# by file name. Iterating a dict yields only its keys, whereas train() unpacks
# an (imgs, labels) pair from every item. Presumably a loader that pairs
# image tensors with labels is still to be built here; confirm.
train_loader = {k: train_imgs[k] for k in list(train_imgs)[:10]}
# NOTE(review): placeholder. No validation data is supplied yet.
val_loader = None

In [None]:
# Run the training loop; it returns one mean loss per epoch for each split.
losses_train, losses_val = train(
    n_epochs=n_epochs,
    optimizer=optimizer,
    model=model,
    loss_fn=loss_fn,
    train_loader=train_loader,
    val_loader=val_loader,
)

# Plot the loss curves against the 1-based epoch number.
xvalues = range(1, n_epochs + 1)
train_yvalues, val_yvalues = losses_train, losses_val

fig, ax = plt.subplots()
ax.plot(xvalues, train_yvalues, label='training')
ax.plot(xvalues, val_yvalues, label='validation')
ax.set(
    title=f'Training and Validation loss for {model.__class__.__name__}',
    xlabel='Epochs',
    ylabel='Loss',
)
ax.legend()
plt.show()
