In [1]:
from pathlib import Path
import pandas as pd
from functools import partial
import json
import numpy as np
import geopandas as gpd
import matplotlib.pyplot as plt
import rasterio
from rasterio.plot import show
from rasterio.mask import mask
from tqdm.notebook import tqdm, trange
import fiona
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import datasets, models, transforms
import torchvision.transforms as T
from sklearn.metrics import confusion_matrix
import seaborn as sns
from time import time

## Preprocessing steps

The preprocessing steps's goal is to find the best stats to equilibrate the dataset and the model. In order to standardize the data, we need the average on the whole dataset.

In [2]:
# Main setting
train_path = Path('train')
test_path = Path('test')

print(train_path)
print(test_path)

train
test


In [3]:
# Look for information in the datafile
csv_res = train_path / 'Building_Solutions.csv'

def generate_csv(csv_res):
    csv_path = train_path / 'info'
    dfs = [pd.read_csv(csv_file) for csv_file in csv_path.iterdir()]
    result = pd.concat(dfs)
    result.to_csv(csv_res, index=None)
    return result

# Use function to load csv
# df = generate_csv(csv_res)
df = pd.read_csv(csv_res)
df

Unnamed: 0,ImageId,BuildingId,PolygonWKT_Pix,PolygonWKT_Geo
0,AOI_3_Paris_img485,1,"POLYGON ((31.68 11.69 0,31.06 11.39 0,44.13 -0...",POLYGON ((2.242656935000071 49.023104327000055...
1,AOI_3_Paris_img485,2,"POLYGON ((-0.0 199.48 0,-0.0 216.43 0,10.24 20...",POLYGON ((2.242571399947189 49.022597304268373...
2,AOI_3_Paris_img485,3,"POLYGON ((-0.0 153.61 0,-0.0 177.39 0,25.13 18...",POLYGON ((2.242571399947189 49.02272116311444 ...
3,AOI_3_Paris_img485,4,"POLYGON ((69.9 124.55 0,29.41 110.46 0,2.29 14...",POLYGON ((2.242760130000022 49.022799603000067...
4,AOI_3_Paris_img485,5,"POLYGON ((84.69 190.02 0,157.39 210.2 0,171.31...",POLYGON ((2.242800060000036 49.022622857000044...
...,...,...,...,...
223398,AOI_2_Vegas_img4107,4,"POLYGON ((650.0 509.9 0,650.0 443.3 0,321.98 5...",POLYGON ((-115.202217599790103 36.211160967734...
223399,AOI_2_Vegas_img4107,5,"POLYGON ((650.0 359.73 0,650.0 294.18 0,254.17...",POLYGON ((-115.202217599790103 36.211566416926...
223400,AOI_2_Vegas_img4107,6,"POLYGON ((639.48 150.23 0,125.53 392.7 0,136.6...",POLYGON ((-115.202245993999952 36.212132067000...
223401,AOI_2_Vegas_img4107,7,"POLYGON ((326.22 149.0 0,65.42 270.22 0,77.4 2...",POLYGON ((-115.203091813999947 36.212135405000...


In [4]:
cities = ["Paris", "Shanghai", "Khartoum", "Vegas"]

for city in cities:
    print(city, ':', 
        df[df['ImageId'].str.contains(city)]['ImageId'].unique().size, 
        '/', df['ImageId'].unique().size)

Paris : 1148 / 10593
Shanghai : 4582 / 10593
Khartoum : 1012 / 10593
Vegas : 3851 / 10593


In [5]:
def count_files_per_city(fp):
    count_dict = {city: 0 for city in cities}

    for filename in tqdm(fp.iterdir(), desc=f"Folder peeling"):
        for city in cities:
            if city in filename.stem:
                count_dict[city] += 1
    return count_dict

# GeoJSONs
print("GeoJSON")
print(count_files_per_city((train_path / "buildings")))

# Img
print("Images:")
# print("MUL", count_files_per_city((train_path / "data" / "MUL" / "MUL")))
print("MUL-PanSharpen", count_files_per_city((train_path / "data" / "MUL-PanSharpen" / "MUL-PanSharpen")))
# print("PAN",count_files_per_city((train_path / "data" / "PAN" / "PAN")))
# print("RGB-Pan",count_files_per_city((train_path / "data" / "RGB-PanSharpen" / "RGB-PanSharpen")))

GeoJSON


HBox(children=(HTML(value='Folder peeling'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='2…


{'Paris': 1148, 'Shanghai': 4582, 'Khartoum': 1012, 'Vegas': 3851}
Images:


HBox(children=(HTML(value='Folder peeling'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='2…


MUL-PanSharpen {'Paris': 1148, 'Shanghai': 4582, 'Khartoum': 1012, 'Vegas': 3851}


In [6]:
def compute_mean_std(filepath, n_ch):
    sum_channels = np.zeros(n_ch) #8, 3 or 1
    std_channels = np.zeros(n_ch)
    total_pixel = 0

    for img in tqdm((filepath).iterdir(), desc="10593"):
        with rasterio.open(img, 'r') as ds:
            try:
                arr = ds.read() 
            except:
                print(f"Uh oh, {img.stem} seems to be corrupted...")
            else:
                arr = arr.reshape(arr.shape[0], -1)
                sum_channels += arr.sum(axis=-1)
                total_pixel += arr[0].size
            
    mean_channels = sum_channels / total_pixel


    for img in tqdm((filepath).iterdir()):
        with rasterio.open(img, 'r') as ds:
            try:
                arr = ds.read() 
            except:
                print(f"Uh oh, {img.stem} seems to be corrupted...")
            else:
                arr = arr.reshape(arr.shape[0], -1)
                std_channels += np.sum((arr - mean_channels.reshape(n_ch, 1)) ** 2, axis=-1) 
            
    std_channels = np.sqrt(std_channels / total_pixel)

    stats = {'mean': mean_channels.tolist(), 'std': std_channels.tolist()}
    return stats

# folder_path = train_path / "data" / "MUL-PanSharpen" / "MUL-PanSharpen"
# stats = compute_mean_std(folder_path,8)
# with open(train_path / 'stats_mul_pan.json', 'w') as file:
#     json.dump(stats, file)

In [7]:
def load_stats(filepath):
    with open(filepath, 'r') as file:
        n_params = json.load(file)
    mean_channels = np.array(n_params['mean'])
    std_channels = np.array(n_params['std'])
    return mean_channels, std_channels

mean_channels, std_channels = load_stats(train_path / 'stats_mul_pan.json')
print(mean_channels)
print(std_channels)

[296.60002065 357.02413693 463.62619325 416.1747543  331.35564105
 408.23318352 478.25544466 363.4427353 ]
[105.56921283 148.62490128 224.65445921 226.26444245 195.07064322
 210.31440212 239.627737   197.79631543]


In [8]:
import pathlib

def norm_img(img, mean_arr, std_arr):
    res = (np.transpose(img, (1, 2, 0)) - mean_arr) / std_arr
    return np.transpose(res, (2,0,1))

def load_tif(fn, df, mean_vec, std_vec):
    img_id = "_".join(pathlib.Path(fn).stem.split("_")[1:]) # get img id
    train_path = pathlib.Path(fn).parents[3] # Get train path from img

    no_building = df[df['BuildingId']==-1]['ImageId'].unique().tolist()
    geojson_path = train_path / "buildings" / f"buildings_{img_id}.geojson"

    # Extract the file as a (8 x 650 x 650) cube
    with rasterio.open(fn) as tif:
        arr = tif.read()
        info = tif.meta
    
    info['count'] = 1
    # Extract geofeatures if the image has buildings
    if img_id in no_building:
        X = np.zeros((info['height'],info['width']), dtype='uint16')
        features = []
    else:
        with fiona.open(geojson_path, "r") as geojson:
            features = [feature["geometry"] for feature in geojson]
        X = np.ones((info['height'],info['width']), dtype='uint16')

    # Write polygons as a tif whose dimensions are the same than the opened tif
    with rasterio.open('temp.tif','w', **info) as new_ds:
        new_ds.write(X, 1)
    
    # Extract mask if necessary
    with rasterio.open('temp.tif') as tif:
        if features:
            mask_img, _ = rasterio.mask.mask(tif, features)
        else:
            mask_img = tif.read()
    
    
    # arr = norm_img(arr, mean_vec, std_vec)
    arr, mask_img = arr.astype('float32'), mask_img.squeeze().astype('int64')
    pathlib.Path('temp.tif').unlink()

    return arr, mask_img

# Need to standardize by avg / std and show as tensor
load_img = partial(
    load_tif,
    df = df, #df directly
    mean_vec = mean_channels,
    std_vec = std_channels
)

In [9]:
img_path = train_path / 'data' / 'MUL-PanSharpen'

# Define dataset here
ds = datasets.DatasetFolder(root=img_path, 
                       loader=load_img, 
                       extensions=('.tif',))

print("N° of images:", len(ds))
print("Type of img:", ds.classes[0])

N° of images: 10593
Type of img: MUL-PanSharpen


In [10]:
def split_dataset(ds, train_size=0.8, random_seed=0):
    if type(train_size) is float:
        train_size = int(len(ds)*train_size)
    train_ds, val_ds = random_split(ds, (train_size, len(ds)-train_size), generator=torch.Generator().manual_seed(random_seed))
    return train_ds, val_ds

train_ds, val_ds = split_dataset(ds, train_size=0.8, random_seed=123)
print(len(train_ds))
print(len(val_ds))

8474
2119


In [11]:
batch_size = 16 #16

train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0)
val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=True, num_workers=0)
print("N° of iterations per batch (train):", len(train_dl))
print("N° of iterations per batch (val):", len(val_dl))

N° of iterations per batch (train): 530
N° of iterations per batch (val): 133


## Definition of the model

The model used is an AE (autoencoder) trained specially for the pan-sharpened part of the multichannel dataset, thus for `8x650x650` images. 

Some modifications might be done in order to exploit this model for other datasets, like changing the number of channels.

In [12]:
# Train on whole dataset

class AutoencoderBuildingMulPs(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_init = nn.Conv2d(8, 16, 3, padding=1)
        self.conv1 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 16, 3, padding=1)
        self.conv3 = nn.Conv2d(16, 16, 3, padding=1)
        self.conv4 = nn.Conv2d(16, 8, 3, padding=1)
        
        self.conv5 = nn.Conv2d(8, 16, 3, padding=1)
        self.conv6 = nn.Conv2d(16, 16, 3, padding=1)
        self.conv7 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv8 = nn.Conv2d(32, 16, 3, padding=1)
        self.conv_last = nn.Conv2d(16, 2, 3, padding=1)
        
    def forward(self, x):
        # Encoder
        x = F.relu(self.conv_init(x))
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x,2)
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.max_pool2d(x,2)
        x = F.relu(self.conv4(x))
        x = F.max_pool2d(x,2)
        
        # Decoder
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        x = F.relu(self.conv5(x))
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        n_size = tuple([(dim+1)*2 for dim in x.shape[2:]])
        x = F.relu(self.conv6(x))
        x = F.relu(self.conv7(x))
        x = F.interpolate(x, mode='bilinear', 
                          align_corners=False, size=n_size)
        x = F.relu(self.conv8(x))
        x = self.conv_last(x)
        return x

In [13]:
ae_model = AutoencoderBuildingMulPs()
print(ae_model)

#Send to GPU
device = torch.device(("cuda" if torch.cuda.is_available() else "cpu"))
print(device)
ae_model = ae_model.to(device)

AutoencoderBuildingMulPs(
  (conv_init): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv1): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv5): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv6): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv7): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv8): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv_last): Conv2d(16, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
cuda


In [14]:
epochs = 100
criterion = nn.CrossEntropyLoss(weight=torch.tensor([.11, .89]))
criterion = criterion.to(device)
optimizer = optim.Adam(ae_model.parameters(), lr=1e-3)

In [15]:
#Defining a class where we register every parameter necessary to train the model
class ModelParameters:
    def __init__(self, model, device, epochs, criterion, optimizer, train_dl, val_dl, sim_bs=None):
        self.model = model.to(device)
        self.device = device
        self.epochs = epochs
        self.criterion = criterion.to(device)
        self.optimizer = optimizer
        self.train_dl = train_dl
        self.val_dl = val_dl
        if sim_bs:
            self.sim_bs = sim_bs // self.train_dl.batch_size 
    
    def __str__(self):
        return f"Model: {self.model}\nOn: {self.device}\nN°epochs: {epochs}"

mp = ModelParameters(ae_model, device, epochs, criterion, optimizer, 
                     train_dl, val_dl)
print(mp)

Model: AutoencoderBuildingMulPs(
  (conv_init): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv1): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv5): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv6): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv7): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv8): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv_last): Conv2d(16, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
On: cuda
N°epochs: 100


In [16]:
# Eval cell

def n_correct_pred(output, label):
    bs, ch, h, w = output.shape
    pred = output.argmax(dim=1)
    return (pred==label).sum() / h*w

def batch_cm(output, label):
    to_numpy = lambda tens: tens.cpu().detach().reshape(-1).numpy()
    cm = np.zeros((2,2))
    pred = output.argmax(dim=1)
    bs, ch, h, w = output.shape
    # use zip instead
    for i in range(bs):
        cm += confusion_matrix(to_numpy(label[i]), 
                          to_numpy(pred[i]), labels=[0, 1])
    return cm

@torch.no_grad()
def evaluate_model(mp):
    mp.model.eval()
    output_loss, n_corr_preds, total_cm = 0, 0, []
    
    for batch_n, ((img, mask), _) in tqdm(enumerate(mp.val_dl), desc="Model evaluation", unit="batch", total=len(mp.val_dl)):
        img, mask = img.to(mp.device), mask.to(mp.device)
        pred = mp.model(img)
        
        loss = mp.criterion(pred, mask)
        output_loss += loss.detach().item()
        corr_pred = n_correct_pred(pred, mask)
        n_corr_preds += corr_pred.detach().item()
        cm = batch_cm(pred,mask)
        total_cm.append(cm)
    
    return output_loss, n_corr_preds / len(mp.val_dl), total_cm

@torch.no_grad()
def eval_model_limit(mp, lim):
    mp.model.eval()
    total_size, output_loss, n_corr_preds = 0, 0, 0
    tot_bs = mp.train_dl.batch_size
    
    for batch_n, ((img, mask), _) in tqdm(enumerate(mp.train_dl), desc="Val training", total=lim, unit="batch"):
        img, mask = img.to(mp.device), mask.to(mp.device)
        pred = mp.model(img)
        total_size += len(pred.view(-1))
        loss = mp.criterion(pred, mask)
        output_loss += loss.detach().item()

        corr_pred = n_correct_pred(pred, mask)
        n_corr_preds += corr_pred.detach().item()

        if batch_n==lim:
            cm = batch_cm(pred,mask)
            break
    
    return output_loss, n_corr_preds / lim, 100* cm / (cm.sum() * tot_bs)


def train_model(mp):
    (train_path /'save_states').mkdir(exist_ok=True)
    
    total_results = {'train': {'loss': []}, 
                   'val': {'loss': [], 'correct_pred': [], 'cm': []}}
  

    mp.model.zero_grad()
    for epoch in trange(mp.epochs, desc="Train", unit="epoch"):
        train_epoch = 0
        for bn, ((img, mask), _) in tqdm(enumerate(mp.train_dl), 
                                     desc=f"Batch training", 
                                     total=len(mp.train_dl), unit="batch"):
            img, mask = img.to(mp.device), mask.to(mp.device)
            mp.model.train()
            pred = mp.model(img)
            loss = mp.criterion(pred, mask) #Avg loss on whole batch
            loss.backward()

            train_epoch += loss.detach().item()
#             test_batch, pred_batch, cm_batch = eval_model_limit(mp, 5)

#             print(f"Epoch {epoch}, batch {bn}, average loss: {np.array(test_batch).mean()}")

#             iou = cm_batch[-1, -1] / (cm_batch.sum() - cm_batch[0, 0])
#             print("IoU:", iou, '%')

            mp.model.train()
            mp.optimizer.step()                            
            mp.optimizer.zero_grad()

            
#         torch.save({'epoch': epoch,
#                'model_state_dict': mp.model.state_dict(),
#                'optimizer_state_dict': mp.optimizer.state_dict()}, 
#                train_path /'save_states'/ f'model_chkpt_ep_{epoch:03d}')

    
        total_results['train']['loss'].append(train_epoch)
        print(f"Train loss: {train_epoch}")

        # Evaluation step 
        output_loss, n_corr_preds, total_cm = evaluate_model(mp)
        
        last_cm = total_cm[-1]
        print(f"Val loss: {output_loss}")
        print(f"N of correct preds: {n_corr_preds}")
        print(f"Last CM:\n{total_cm[-1]}")
        
        iou = last_cm[-1, -1] / (last_cm.sum() - last_cm[0, 0])
        print(f"Last IoU: {iou}")
        
        total_results['val']['loss'].append(output_loss)
        total_results['val']['correct_pred'].append(n_corr_preds)
        total_results['val']['cm'].append(total_cm)

        torch.save({'epoch': epoch,
               'model_state_dict': mp.model.state_dict(),
               'optimizer_state_dict': mp.optimizer.state_dict(),
               'total_results': total_results}, train_path /'save_states'/ f'model_bs_{mp.train_dl.batch_size}_chkpt_ep_{epoch:03d}')
  
    torch.save(mp.model.state_dict(), train_path / f'ae_building_mask')
    
    return total_results  

In [None]:
import json
total_results = train_model(mp)


with open('total_result_dict.json', 'w') as fp:
    json.dump(total_results, fp)

HBox(children=(HTML(value='Train'), FloatProgress(value=0.0), HTML(value='')))

HBox(children=(HTML(value='Batch training'), FloatProgress(value=0.0, max=530.0), HTML(value='')))


Train loss: 242.66615882515907


HBox(children=(HTML(value='Model evaluation'), FloatProgress(value=0.0, max=133.0), HTML(value='')))

  



Val loss: 49.972219586372375
N of correct preds: 4964734.2105263155
Last CM:
[[1806218.  666180.]
 [  17306.  467796.]]
Last IoU: 0.40632616509247954


HBox(children=(HTML(value='Batch training'), FloatProgress(value=0.0, max=530.0), HTML(value='')))


Train loss: 205.59057426452637


HBox(children=(HTML(value='Model evaluation'), FloatProgress(value=0.0, max=133.0), HTML(value='')))


Val loss: 47.885592103004456
N of correct preds: 4781057.894736842
Last CM:
[[1743354.  581176.]
 [  45694.  587276.]]
Last IoU: 0.48369471216805887


HBox(children=(HTML(value='Batch training'), FloatProgress(value=0.0, max=530.0), HTML(value='')))


Train loss: 187.66497646272182


HBox(children=(HTML(value='Model evaluation'), FloatProgress(value=0.0, max=133.0), HTML(value='')))


Val loss: 45.61543545126915
N of correct preds: 5097387.969924812
Last CM:
[[1673832.  870334.]
 [  31068.  382266.]]
Last IoU: 0.29779195243629974


HBox(children=(HTML(value='Batch training'), FloatProgress(value=0.0, max=530.0), HTML(value='')))

In [None]:
serial_results = total_results.copy()

with open('total_result_dict.json', 'w') as fp:
    serial_results['val']['cm'] = [[cm.tolist()  for cm in ep] for ep in total_results['val']['cm']]
    json.dump(serial_results, fp)

In [None]:
import matplotlib.pyplot as plt

plt.plot(total_results["train"]["loss"])
plt.plot(total_results["val"]["loss"])
plt.title("Loss evolution")
plt.xlabel("Epoch")
plt.ylabel("Loss value")
plt.savefig("loss_value.png")

In [None]:
plt.plot(total_results["val"]["correct_pred"])
plt.title("N of correct predictions")
plt.xlabel("Epoch")
plt.ylabel("N correct prediction")
plt.savefig("corr_preds.png")

In [None]:
total_cm = total_results['val']['cm']
cm_prds = [sum([np.array(cm) for cm in ep]) for ep in total_cm]

fig, axs = plt.subplots(2, 3, figsize=(18,12))

for i, axrow in enumerate(axs):
    for j, ax in enumerate(axrow):
        idx = i*3 + j
        sns.heatmap(cm_prds[idx*5] / (cm_prds[idx*5].sum()), ax=ax, annot=True)
#         sns.heatmap(cm_prds[idx*5] / (cm_prds[idx*5].sum() * mp.val_dl.batch_size*len(mp.val_dl)), ax=ax, annot=True)

In [None]:
iou = [cm[-1, -1] / (cm.sum() - cm[0, 0])  for cm in cm_prds]
iou

In [None]:
plt.plot(iou)
plt.title("IoU")
plt.ylabel("IoU")
plt.xlabel("Epoch")
plt.savefig("iou.png")

In [None]:
assert False

## Resume training

In [None]:
chpt_list = [fn for fn in (train_path /'save_states').iterdir() if fn.name.startswith("model_bs_16")]

last_chpt = sorted(chpt_list, key=lambda x: int(x.name.split("_")[-1]))[-1]
last_chpt

In [None]:
def resume_training(mp, last_chpt):
    checkpoint = torch.load(last_chpt)
    mp.model.load_state_dict(checkpoint['model_state_dict'])
    mp.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    total_results = checkpoint["total_results"]
    
    for epoch in trange(start_epoch + 1, mp.epochs + start_epoch, desc="Train", unit="epoch"):
        train_epoch = 0
        for bn, ((img, mask), _) in tqdm(enumerate(mp.train_dl), 
                                     desc=f"Batch training", 
                                     total=len(mp.train_dl), unit="batch"):
            img, mask = img.to(mp.device), mask.to(mp.device)
            mp.model.train()
            pred = mp.model(img)
            loss = mp.criterion(pred, mask) #Avg loss on whole batch
            loss.backward()

            train_epoch += loss.detach().item()
            

            mp.model.train()
            mp.optimizer.step()                            
            mp.optimizer.zero_grad()
            

    
        total_results['train']['loss'].append(train_epoch)
        print(f"Train loss: {train_epoch}")

        # Evaluation step 
        output_loss, n_corr_preds, total_cm = evaluate_model(mp)
        
        last_cm = total_cm[-1]
        print(f"Val loss: {output_loss}")
        print(f"N of correct preds: {n_corr_preds}")
        print(f"Last CM:\n{total_cm[-1]}")
        
        iou = last_cm[-1, -1] / (last_cm.sum() - last_cm[0, 0])
        print(f"Last IoU: {iou}")
        
        total_results['val']['loss'].append(output_loss)
        total_results['val']['correct_pred'].append(n_corr_preds)
        total_results['val']['cm'].append(total_cm)

        torch.save({'epoch': epoch,
               'model_state_dict': mp.model.state_dict(),
               'optimizer_state_dict': mp.optimizer.state_dict(),
               'total_results': total_results}, train_path /'save_states'/ f'model_bs_{mp.train_dl.batch_size}_chkpt_ep_{epoch:03d}')
  
    torch.save(mp.model.state_dict(), train_path / f'ae_building_mask')
    
    return total_results  

In [None]:
import json
total_results = train_model(mp)


with open('total_result_dict.json', 'w') as fp:
    json.dump(total_results, fp)