In [101]:
import torch
import numpy as np
import torch.nn as nn
from torch import optim
from torch.optim import AdamW
import matplotlib.pyplot as plt
from torchvision.models import ResNet
from torch.nn import CrossEntropyLoss
from torch.optim.lr_scheduler import StepLR
from torchvision.models.resnet import ResNet

from torch.utils.data import Dataset, DataLoader

from utils.log import *
from utils.data import *

1 - Wandb logging and setup from config  
1.1 - Sort of builder from cfg
2 - Dataloader with slicing  
3 - Define some models  
4 - Wrap it into slurm tasks and add saving of experiments

In [102]:
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

class BasicBlock1x1(nn.Module):
    """Basic Resnet block but instead of 3x3 convs we use 1x1 convs"""
    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride = 1,
                 downsample = None,
                 groups = 1,
                 base_width = 64,
                 dilation = 1,
                 norm_layer = None):
        
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv1x1(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


In [103]:
X_train_s1, y_train_s1, full_distrib_train_s1, X_val_s1, y_val_s1, full_distrib_val_s1 = load_pickled_ds('data/pickled_data/raw_train_test_splitted_s1_60.pkl')
X_train_s2, y_train_s2, full_distrib_train_s2, X_val_s2, y_val_s2, full_distrib_val_s2 = load_pickled_ds('data/pickled_data/raw_train_test_splitted_s2_60.pkl')

In [4]:
channel_order_s2 = ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B08A", "B11", "B12", "B01", "B09"]
channel_order_s1 = ["VV", "VH", "VV-VH"]

In [208]:
g = torch.Generator()
g.manual_seed(42)

bs = 512

ri = ["NDVI", "EVI", "NDWI", "GNDVI", "SAVI", "ARVI", "MSAVI"]

train_dataset_s1 = NumpyDataset(X_train_s1, y_train_s1, sentinel_number=1, band_order=channel_order_s1, requested_indices=["VV", "VH", "VV-VH"])
train_dataset_s2 = NumpyDataset(X_train_s2, y_train_s2, sentinel_number=2, band_order=channel_order_s2,
                                requested_indices=ri)


val_dataset_s1 = NumpyDataset(X_val_s1, y_val_s1, sentinel_number=1, band_order=channel_order_s1, requested_indices=["VV", "VH", "VV-VH"])
val_dataset_s2 = NumpyDataset(X_val_s2, y_val_s2, sentinel_number=2, band_order=channel_order_s2, 
                              requested_indices=ri)


train_dataloader_s1 = DataLoader(train_dataset_s1, batch_size=bs, shuffle=True, generator=g)
train_dataloader_s2 = DataLoader(train_dataset_s2, batch_size=bs, shuffle=True, generator=g)

val_dataloader_s1 = DataLoader(val_dataset_s1, batch_size=bs, shuffle=True, generator=g)
val_dataloader_s2 = DataLoader(val_dataset_s2, batch_size=bs, shuffle=True, generator=g)

In [209]:
import torch
import torch.nn as nn
import torch.optim as optim

from models import get_model

ce = CrossEntropyLoss()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

sample = next(iter(train_dataloader_s1))[0]
n_bands = sample.shape[1]
input_size = sample.shape[2]
rs = get_model('fc.MLPSmall')
rs = rs(n_classes=15, n_bands=n_bands, input_size=input_size)
rs.to(device)

criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.AdamW(rs.parameters(), lr=0.001, weight_decay=0.0001)
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

num_epochs = 30  

for epoch in range(num_epochs):
    rs.train() 

    running_loss = 0.0
    correct = 0
    total = 0
    
    for inputs, _, labels in train_dataloader_s1:
        inputs, labels = inputs.to(device, dtype=torch.float), labels.to(device, dtype=torch.long)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = rs.forward(inputs)
        
        # Compute loss
        loss = criterion(outputs, labels)
        loss.backward()

        optimizer.step()

        # Statistics
        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    train_accuracy = 100 * correct / total
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_dataloader_s2):.4f}, Accuracy: {train_accuracy:.2f}%")

    # Validation loop
    rs.eval()  # Set model to evaluation mode
    val_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():  # No gradient calculation during validation
        for inputs, _, labels in train_dataloader_s1:
            inputs, labels = inputs.to(device, dtype=torch.float), labels.to(device)
            
            # Forward pass
            outputs = rs(inputs)
            
            # Compute loss
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            # Statistics
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_accuracy = 100 * correct / total
    print(f"Validation Loss: {val_loss/len(val_dataloader_s2):.4f}, Accuracy: {val_accuracy:.2f}%")

print("Training completed.")


<module 'models.fc' from '/home/al/projects/global_rs_project/models/fc.py'> MLPSmall
64 32
Epoch [1/30], Loss: 2.6436, Accuracy: 12.81%
Validation Loss: 23.0210, Accuracy: 17.61%
Epoch [2/30], Loss: 2.5820, Accuracy: 16.64%
Validation Loss: 22.6941, Accuracy: 17.68%
Epoch [3/30], Loss: 2.5542, Accuracy: 17.36%
Validation Loss: 22.4890, Accuracy: 18.25%
Epoch [4/30], Loss: 2.5372, Accuracy: 18.43%
Validation Loss: 22.3052, Accuracy: 19.32%
Epoch [5/30], Loss: 2.5201, Accuracy: 18.96%
Validation Loss: 22.2002, Accuracy: 20.10%
Epoch [6/30], Loss: 2.5031, Accuracy: 19.64%
Validation Loss: 21.4435, Accuracy: 24.71%
Epoch [7/30], Loss: 2.3870, Accuracy: 24.48%
Validation Loss: 20.6916, Accuracy: 26.14%
Epoch [8/30], Loss: 2.3583, Accuracy: 25.12%
Validation Loss: 20.5481, Accuracy: 26.40%
Epoch [9/30], Loss: 2.3431, Accuracy: 25.19%
Validation Loss: 20.4618, Accuracy: 26.52%
Epoch [10/30], Loss: 2.3308, Accuracy: 25.60%
Validation Loss: 20.4207, Accuracy: 26.54%
Epoch [11/30], Loss: 2.3277

In [7]:
def train_and_evaluate(train_loader, val_loader, lr, weight_decay, loss_fn, metric_fn,
                       device, num_epochs, step2decay, decay_lr):
    
    
    indiced_sample, raw_sample = next(iter(train_loader))[:2]
    n_bands_raw, n_bands_indices = indiced_sample.shape[1], raw_sample.shape[1]

    rs_indiced = Resnet1x1(n_classes=15, n_bands=n_bands_indices)
    rs_raw = Resnet1x1(n_classes=15, n_bands=n_bands_raw)

    rs_indiced.model.to(device)
    rs_raw.model.to(device)

    loss_fn_raw = loss_fn().to(device)
    loss_fn_indiced = loss_fn().to(device)


    optimizer_raw = optim.AdamW(rs_raw.model.parameters(), lr=lr, weight_decay=weight_decay)
    optimizer_indiced = optim.AdamW(rs_indiced.model.parameters(), lr=lr, weight_decay=weight_decay)

    scheduler_raw = StepLR(optimizer_raw, step_size=step2decay, gamma=decay_lr)
    scheduler_indiced = StepLR(optimizer_indiced, step_size=step2decay, gamma=decay_lr)


    iter_log = IterLog()
    train_log = TrainLog()

    for epoch in range(num_epochs):
        
        iter_log.on_iter_start()

        rs_indiced.model.train() 
        rs_raw.model.train() 


        for inputs_raw, inputs_indiced, labels in train_loader:
            inputs_raw, inputs_indiced, labels = inputs_raw.to(device, dtype=torch.float), inputs_indiced.to(device, dtype=torch.float), labels.to(device, dtype=torch.long)

            optimizer_raw.zero_grad()
            optimizer_indiced.zero_grad()

            outputs_raw = rs_raw.forward(inputs_raw)
            outputs_indiced = rs_indiced.forward(inputs_indiced)

            loss_raw = loss_fn_raw(outputs_raw, labels)
            loss_raw.backward()

            loss_indiced = loss_fn_indiced(outputs_indiced, labels)
            loss_indiced.backward()

            optimizer_raw.step()
            optimizer_indiced.step()

            scheduler_raw.step()
            scheduler_indiced.step()

            _, predicted_raw = torch.max(outputs_raw, 1)
            _, predicted_indiced = torch.max(outputs_indiced, 1)

            metric_raw = metric_fn(predicted_raw, labels)
            metric_indiced = metric_fn(predicted_indiced, labels)

            iter_log.add_on_train_iter_end(train_loss_indiced=loss_indiced.item(), train_loss_raw=loss_raw.item(),
                                           train_metric_raw=metric_raw.item(), train_metric_indiced=metric_indiced.item())

        rs_raw.model.eval()
        rs_indiced.model.eval()

        with torch.no_grad(): 
            for inputs_raw, inputs_indiced, labels in val_loader:
                inputs_raw, inputs_indiced, labels = inputs_raw.to(device, dtype=torch.float), inputs_indiced.to(device, dtype=torch.float), labels.to(device, dtype=torch.long)

                outputs_raw = rs_raw.forward(inputs_raw)
                outputs_indiced = rs_indiced.forward(inputs_indiced)

                loss_raw = loss_fn_raw(outputs_raw, labels)
                loss_indiced = loss_fn_indiced(outputs_indiced, labels)

                _, predicted_raw = torch.max(outputs_raw, 1)
                _, predicted_indiced = torch.max(outputs_indiced, 1)
                
                metric_raw = metric_fn(predicted_raw, labels)
                metric_indiced = metric_fn(predicted_indiced, labels)

                iter_log.add_on_val_iter_end(val_loss_indiced=loss_indiced.item(), val_loss_raw=loss_raw.item(),
                                             val_metric_raw=metric_raw.item(), val_metric_indiced=metric_indiced.item())

        iter_log.on_epoch_end()
        print(f"Epoch [{epoch+1}/{num_epochs}], {iter_log}")

        train_log.on_epoch_end(iter_log)

    train_log.terminate()

In [8]:
def accuracy(pred, gt):
    return (pred == gt).sum() / len(pred)

In [9]:
train_and_evaluate(train_dataloader_s2, val_dataloader_s2, 0.001, 0.0001, CrossEntropyLoss, accuracy, 'cuda:0', 40, 'test', step2decay=15, decay_lr=0.1)

TypeError: train_and_evaluate() got multiple values for argument 'step2decay'

In [None]:
def train_and_evaluate(train_loader, val_loader, lr, weight_decay, loss_fn, metric_fn,
                       device, num_epochs, experiment_name, scheduler):