# Model

In [1]:
import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from torch import nn, optim
import matplotlib.pyplot as plt
import os
import rasterio
import pickle

### Import data

In [2]:
def import_training_imgs(train_folder):
    """Load every ``.tif`` raster in ``train_folder`` as a float tensor.

    Files are visited in sorted filename order. ``os.listdir`` makes no
    ordering guarantee, and the images are later paired with the label masks
    purely by list position, so a deterministic order is required for that
    pairing to be reproducible.

    Args:
        train_folder: Directory that contains the training rasters.

    Returns:
        list[torch.Tensor]: One float32 tensor per ``.tif`` file, holding
        whatever ``src.read()`` returns for that file (the notebook output
        reports ``[12, 1024, 1024]`` for these images).
    """
    x_train = []

    for file in sorted(os.listdir(train_folder)):
        if not file.endswith('.tif'):
            continue  # skip side-car files and anything that is not a GeoTIFF
        file_path = os.path.join(train_folder, file)
        with rasterio.open(file_path) as src:
            image_np = src.read()
        # Cast to float32 so the tensor can be fed straight into the conv layers.
        x_train.append(torch.from_numpy(image_np).float())
    return x_train

# Cache the decoded image tensors so later runs can skip re-reading every raster.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load a cache that was produced locally by this notebook.
cache_path = './data/loaded/train_images.pkl'

if os.path.exists(cache_path):
    print('Loading training images from pickle file...')
    # Context manager closes the handle even if unpickling fails.
    with open(cache_path, 'rb') as cache_file:
        x_train = pickle.load(cache_file)
else:
    print('Importing training images...')
    x_train = import_training_imgs('data/train_images/')
    # Create the cache directory on first run instead of failing on open().
    os.makedirs(os.path.dirname(cache_path), exist_ok=True)
    with open(cache_path, 'wb') as cache_file:
        pickle.dump(x_train, cache_file)

Loading training images from pickle file...


In [3]:
def import_training_labels(train_folder):
    """Load every ``.tif`` annotation mask in ``train_folder`` as a tensor.

    Files are visited in sorted filename order. ``os.listdir`` makes no
    ordering guarantee, and the masks are later paired with the training
    images purely by list position, so a deterministic order is required for
    that pairing to be reproducible.

    Args:
        train_folder: Directory that contains the annotation rasters.

    Returns:
        list[torch.Tensor]: One tensor per ``.tif`` file, holding whatever
        ``src.read()`` returns for that file in the raster's own dtype
        (no cast is applied here).
    """
    y_train = []

    for file in sorted(os.listdir(train_folder)):
        if not file.endswith('.tif'):
            continue  # skip side-car files and anything that is not a GeoTIFF
        file_path = os.path.join(train_folder, file)
        with rasterio.open(file_path) as src:
            image_np = src.read()
        y_train.append(torch.from_numpy(image_np))
    return y_train

In [4]:
# Read the annotation masks straight from disk (no pickle cache is used for them).
y_train = import_training_labels(train_folder='data/masked_annotations/')

In [5]:
# Sanity check: display the dimensions of the second training image.
x_train[1].size()

torch.Size([12, 1024, 1024])

In [12]:

# Images and masks are paired by list position, so the two lists must line up.
assert len(x_train) == len(y_train), (
    f'{len(x_train)} images but {len(y_train)} label masks'
)

x_train_tensor = torch.stack(x_train)  # Shape: [num_samples, 12, 1024, 1024]
# Replace NaN / +-inf pixels with 0 in place (no second copy of the stack is made).
# NOTE(review): the recorded training run reports a NaN loss; non-finite raster
# values (e.g. nodata pixels) are one plausible cause, because a single NaN input
# propagates through the convolutions into the loss. Confirm with
# `torch.isfinite(torch.stack(x_train)).all()` and pick a fill value that suits the data.
x_train_tensor.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)

# squeeze(1) removes the band axis when it has size 1; long() gives the int64
# class indices that nn.CrossEntropyLoss expects as targets.
# Shape: [num_samples, 1024, 1024], assuming single-band masks — TODO confirm.
y_train_tensor = torch.stack(y_train).squeeze(1).long()

train_dataset = TensorDataset(x_train_tensor, y_train_tensor)

In [13]:
# Mini-batch size used by the training loader.
batch_size = 10

# Reshuffle the samples at the start of every epoch.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

### Training function

In [14]:
def train(n_epochs, optimizer, model, loss_fn, train_loader, val_loader=None,
          return_val_losses=False):
    """Train ``model`` and return the mean training loss of every epoch.

    The model is put into training mode at the start of each epoch and into
    evaluation mode for the validation pass, so layers such as dropout or
    batch norm behave correctly in every epoch, not only the first one.
    After the call the model is left in evaluation mode when ``val_loader``
    was given and in training mode otherwise.

    Args:
        n_epochs: Number of passes over ``train_loader``.
        optimizer: Optimizer that updates the parameters of ``model``.
        model: Module called as ``model(imgs)``.
        loss_fn: Callable invoked as ``loss_fn(outputs, labels)`` that returns
            a scalar tensor.
        train_loader: Sized iterable yielding ``(imgs, labels)`` batches.
        val_loader: Optional sized iterable yielding ``(imgs, labels)``
            batches. When given, the mean validation loss is computed after
            every epoch and printed next to the training loss.
        return_val_losses: When True, return ``(losses_train, losses_val)``
            instead of only ``losses_train``. ``losses_val`` is empty if no
            ``val_loader`` was supplied.

    Returns:
        list[float]: Mean per-batch training loss, one entry per epoch.
        With ``return_val_losses=True`` a tuple of two such lists
        (training, validation).
    """
    n_train_batch = len(train_loader)
    losses_train = []
    losses_val = []
    if val_loader is not None:
        n_val_batch = len(val_loader)

    optimizer.zero_grad()

    for epoch in range(1, n_epochs + 1):
        loss_train = 0.0
        loss_val = 0.0

        # Re-enable training mode every epoch: the validation pass below
        # switches the model to eval mode.
        model.train()
        for imgs, labels in train_loader:
            outputs = model(imgs)
            loss = loss_fn(outputs, _drop_label_channel(outputs, labels))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            loss_train += loss.item()
        losses_train.append(loss_train / n_train_batch)

        if val_loader is not None:
            model.eval()
            with torch.no_grad():
                for imgs, labels in val_loader:
                    outputs = model(imgs)
                    # Same label handling as in training so both losses are comparable.
                    loss = loss_fn(outputs, _drop_label_channel(outputs, labels))
                    loss_val += loss.item()

            losses_val.append(loss_val / n_val_batch)

        if epoch == 1 or epoch % 5 == 0:
            print(f'--------- Epoch: {epoch} ---------')
            print('Training loss {:.5f}'.format(loss_train / n_train_batch))
            if val_loader is not None:
                print('Validation loss {:.5f}'.format(loss_val / n_val_batch))
            print()

    if return_val_losses:
        return losses_train, losses_val
    return losses_train


def _drop_label_channel(outputs, labels):
    """Squeeze axis 1 of ``labels`` only when it has the same rank as ``outputs``.

    Labels of shape ``[B, 1, H, W]`` become ``[B, H, W]``; labels that already
    have fewer dimensions than the outputs are returned unchanged.
    """
    if labels.dim() == outputs.dim():
        return labels.squeeze(1)
    return labels

### A simple convolutional network

In [15]:
class SimpleConvNet(nn.Module):
    """Small encoder-decoder CNN that outputs per-pixel class logits.

    Two conv + ReLU + max-pool stages reduce the spatial size by a factor of
    four; two stride-2 transposed convolutions bring it back. The output has
    the same height and width as the input when both are multiples of 4.
    Other sizes come back smaller, because each pooling stage floors odd
    sizes.

    Args:
        in_channels: Number of channels of the input image (default 12).
        n_classes: Number of output channels, i.e. classes (default 5).
    """

    def __init__(self, in_channels=12, n_classes=5):
        super().__init__()

        # Encoder: downsample by 4 (e.g. 1024x1024 -> 256x256)
        self.down1 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)  # H, W -> H/2, W/2
        )
        self.down2 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)  # H/2, W/2 -> H/4, W/4
        )

        # Decoder: upsample by 4 back to the input resolution
        self.up1 = nn.ConvTranspose2d(in_channels=64, out_channels=32,
                                      kernel_size=2, stride=2)  # H/4 -> H/2
        # Last layer emits one channel of raw logits per class (no activation).
        self.up2 = nn.ConvTranspose2d(in_channels=32, out_channels=n_classes,
                                      kernel_size=2, stride=2)  # H/2 -> H

    def forward(self, x):
        """Map ``[batch, in_channels, H, W]`` images to ``[batch, n_classes, H, W]`` logits."""
        x = self.down1(x)        # [batch, 32, H/2, W/2]
        x = self.down2(x)        # [batch, 64, H/4, W/4]
        x = F.relu(self.up1(x))  # [batch, 32, H/2, W/2]
        x = self.up2(x)          # [batch, n_classes, H, W]
        return x


In [16]:
# Network, loss and optimizer used for the training run below.
model = SimpleConvNet()
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)


In [17]:
# Number of passes over the training set.
n_epochs = 1

losses_train = train(
    n_epochs=n_epochs,
    optimizer=optimizer,
    model=model,
    loss_fn=loss_fn,
    train_loader=train_loader,
)


--------- Epoch: 1 ---------
Training loss nan

