In [None]:
import torch

# Stage switches: set any of these to False to skip the matching section below.
RUN_TRAIN = True
RUN_VALID = True
RUN_TEST  = True

# Only the training stage launches two processes (`torchrun --nproc_per_node=2`),
# so the 2-GPU requirement is enforced only when training is enabled. The
# validation / test cells select their own device and do not need a second GPU.
if RUN_TRAIN and (not torch.cuda.is_available() or torch.cuda.device_count() < 2):
    raise RuntimeError("Requires >= 2 GPUs with CUDA enabled.")

# HGNet-V2 Starter Notebook

This notebook builds on Egor Trushin's great starter notebook [here](https://www.kaggle.com/code/egortrushin/gwi-unet-with-float16-dataset), thanks for sharing.

The main purpose of this notebook is to show how to use 2 GPUs during model training, maximizing our weekly GPU quota in the Kaggle environment. 

In addition, I provide 2x pretrained model checkpoints that were trained for 50 epochs using this setup. Each model achieved a validation MAE of ~65-70.

Other additions:
- Flip augmentation
- Dataset preprocessing
- EMA (Exponential moving average)
- Unet w/ a pretrained encoder

In [None]:
%%writefile config.yaml

# Placeholder: train.py overwrites this with the process RANK at start-up.
local_rank: 0
# Root folder; data.get_data_files globs "<data_path>/*/*.npy" below it.
data_path: "/kaggle/input/openfwi-preprocessed-72x72/openfwi_72x72/"
# Folder scanned for "*.pt" checkpoints by the validation / inference cells.
model_path: "/kaggle/input/openfwi-preprocessed-72x72/models/"
# timm model name used for the encoder.
backbone: "hgnetv2_b4.ssld_stage2_ft_in1k"
# Training batch size of each process (one process per GPU).
batch_size: 256
# Rank 0 logs the running training loss every `print_freq` steps.
print_freq: 100
# Upper bound on the number of training epochs.
max_epochs: 1
# Early stopping: number of consecutive epochs without a validation improvement.
es_epochs: 3
# Base seed; train.py adds the process rank to it.
seed: 99
# Keyword arguments passed to torch.optim.AdamW.
optimizer:
    lr: 0.001
    weight_decay: 0.001
# `params` holds keyword arguments passed to ReduceLROnPlateau (mode is 'min').
scheduler:
    params:
        factor: 0.8
        patience: 0

### Preprocess

Here we downsample each input to (70,70) with area interpolation, pad it to (72,72), and then save it as fp16. This reduces each input to roughly 4% of its original size.

This has already been done for every datapoint in the OpenFWI dataset and can be found [here](https://www.kaggle.com/datasets/brendanartley/openfwi-preprocessed-72x72).

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

def _preprocess(x):
    x = F.interpolate(x, size=(70, 70), mode='area')
    x = F.pad(x, (1,1,1,1), mode='replicate')
    return x

def _helper(x):
    """
    Preprocess a NumPy array with `_preprocess`, cast it to float16 and print
    a short report of the shape and memory change.

    Parameters:
    ----------
    x : np.ndarray
        Array to preprocess; it is converted to a float32 tensor first.

    Returns:
    -------
    np.ndarray
        The preprocessed array as float16.
    """
    shape_in = x.shape
    mb_in = x.nbytes / 1e6

    # Interpolate and pad in float32, then store as float16
    processed = _preprocess(torch.from_numpy(x).float())
    x = processed.cpu().numpy().astype(np.float16)

    mb_out = x.nbytes / 1e6
    if mb_in:
        percent = 100 - 100 * (mb_in - mb_out) / mb_in
    else:
        percent = 0

    # Log
    report = [
        "Shape Change",
        "  {} -> {}".format(shape_in, x.shape),
        "",
        "Memory Usage",
        "  {:.1f} MB -> {:.1f} MB".format(mb_in, mb_out),
        "  ({:.1f}% of original size)".format(percent),
    ]
    print("\n".join(report))
    return x


In [None]:
# Run the preprocessing on one raw training file
x = _helper(
    np.load("/kaggle/input/waveform-inversion/train_samples/CurveFault_A/seis2_1_0.npy")
)

# Sanity check: the published preprocessed file must equal our result element-wise
z = np.load("/kaggle/input/openfwi-preprocessed-72x72/openfwi_72x72/CurveFault_A/seis2_1_0.npy")
assert np.all(z == x)

# Release both arrays before moving on
del z, x

### Dataset

Here, we introduce a flip augmentation. 

Unlike a normal horizontal flip, we have to reverse the source and receiver dimensions. To match this, we reverse the width dimension of the label as well.

We use this flip as TTA (test-time augmentation) during inference.

In [None]:
%%writefile data.py

import glob
import numpy as np
from torch.utils.data import Dataset

def inputs_files_to_output_files(input_files):
    """
    Map input file paths to the paths of their matching target files.

    Only the file name (the part after the last "/") is rewritten: a leading
    "seis" becomes "vel" and a leading "data" becomes "model". Directory
    components are never modified, so a dataset stored under a folder such as
    ".../data/..." or ".../seismic/..." still maps to the right files.

    Parameters:
    ----------
    input_files : iterable of str
        Input paths using "/" as the separator.

    Returns:
    -------
    list of str
        One output path per input path, in the same order. A path whose file
        name starts with neither prefix is returned unchanged.
    """
    prefix_map = (("seis", "vel"), ("data", "model"))

    def _convert(path):
        folder, sep, name = path.rpartition("/")
        for src, dst in prefix_map:
            if name.startswith(src):
                name = dst + name[len(src):]
                break
        return folder + sep + name

    return [_convert(f) for f in input_files]


def get_data_files(data_path):
    """
    Collect the input/target file pairs under `data_path` and split them into
    training and validation sets.

    Input files are the ``.npy`` files found by ``data_path + "/*/*.npy"``
    whose file name starts with "seis" or "data"; target paths are derived
    with `inputs_files_to_output_files`. Files whose path contains one of the
    hard-coded validation entries go to the validation split.

    Parameters:
    ----------
    data_path : str
        Root folder that holds one sub-folder per dataset family.

    Returns:
    -------
    tuple of four lists of str
        (train_inputs, train_outputs, valid_inputs, valid_outputs).
    """

    # All filenames.
    # - Sorted, because glob order depends on the file system and every DDP
    #   process calls this function independently: all ranks must index the
    #   same file list.
    # - Filtered on the file name only, so a directory called e.g. "data" in
    #   the path cannot turn target files into inputs.
    all_inputs = sorted(
        f for f in glob.glob(data_path + "/*/*.npy")
        if f.rpartition("/")[2].startswith(("seis", "data"))
    )
    all_outputs = inputs_files_to_output_files(all_inputs)
    assert all(x != y for x, y in zip(all_inputs, all_outputs))

    # Validation filenames
    val_fpaths = [
        'CurveFault_A/seis2_1_0.npy', 'CurveFault_A/seis2_1_1.npy',
        'CurveFault_B/seis6_1_0.npy', 'CurveFault_B/seis6_1_1.npy',
        'CurveVel_A/data1.npy', 'CurveVel_A/data10.npy',
        'CurveVel_B/data1.npy', 'CurveVel_B/data10.npy',
        'FlatFault_A/seis2_1_0.npy', 'FlatFault_A/seis2_1_1.npy',
        'FlatFault_B/seis6_1_0.npy', 'FlatFault_B/seis6_1_1.npy',
        'FlatVel_A/data1.npy', 'FlatVel_A/data10.npy',
        'FlatVel_B/data1.npy', 'FlatVel_B/data10.npy',
        'Style_A/data1.npy', 'Style_A/data10.npy',
        'Style_B/data1.npy', 'Style_B/data10.npy',
    ]

    train_inputs, train_outputs = [], []
    valid_inputs, valid_outputs = [], []

    # Iterate and split files
    for a, b in zip(all_inputs, all_outputs):
        if any(c in a for c in val_fpaths):
            valid_inputs.append(a)
            valid_outputs.append(b)
        else:
            train_inputs.append(a)
            train_outputs.append(b)

    return train_inputs, train_outputs, valid_inputs, valid_outputs



class SeismicDataset(Dataset):
    """
    Dataset over paired input / target ``.npy`` files, each holding
    `n_examples_per_file` examples along its first axis.

    Parameters:
    ----------
    inputs_files : list of str
        Paths of the input files.
    output_files : list of str
        Paths of the target files, aligned with `inputs_files`.
    mode : str
        "train" enables the random flip augmentation; any other value
        returns the examples unchanged.
    n_examples_per_file : int
        Number of examples stored in every file (default 500).
    """
    def __init__(self, inputs_files, output_files, mode, n_examples_per_file=500):
        assert len(inputs_files) == len(output_files)
        self.inputs_files = inputs_files
        self.output_files = output_files
        self.n_examples_per_file = n_examples_per_file
        self.mode = mode

    def __len__(self):
        return len(self.inputs_files) * self.n_examples_per_file

    def __getitem__(self, idx):
        """Return copies of the `idx`-th (input, target) pair as NumPy arrays."""
        # Calculate file offset and sample offset within file
        file_idx, sample_idx = divmod(idx, self.n_examples_per_file)

        # Memory-map the files so only the requested example is read
        x = np.load(self.inputs_files[file_idx], mmap_mode='r')[sample_idx]
        y = np.load(self.output_files[file_idx], mmap_mode='r')[sample_idx]

        # Random flip (train only): reverse the first and last axes of the
        # input and the LAST axis of the label. Indexing the label with
        # `...` keeps this the width axis whether or not the label carries a
        # leading channel axis; `y[:, ::-1]` would flip axis 1 of a
        # (1, H, W) label instead. This matches the flip TTA in the model,
        # which un-flips predictions along dim -1.
        if self.mode == "train" and np.random.random() < 0.5:
            x = x[::-1, :, ::-1]
            y = y[..., ::-1]

        # Copy: materialises the memmap views / negative strides as plain
        # contiguous arrays before returning them.
        return x.copy(), y.copy()


class TestDataset(Dataset):
    """Dataset yielding ``(loaded array, file stem)`` for every test ``.npy`` file."""
    def __init__(self, test_files):
        self.test_files = test_files

    def __len__(self):
        return len(self.test_files)

    def __getitem__(self, i):
        path = self.test_files[i]
        # File name without directories, cut at its first "."
        file_name = path.rsplit("/", 1)[-1]
        stem = file_name.partition(".")[0]
        return np.load(path), stem

### Utils

Same as Egor's.

In [None]:
%%writefile utils.py

import datetime
import random
import torch
import numpy as np

def format_time(elapsed):
    """Round a duration in seconds to the nearest second and format it as h:mm:ss."""
    whole_seconds = int(round(elapsed))
    delta = datetime.timedelta(seconds=whole_seconds)
    return str(delta)

def seed_everything(
    seed_value: int
) -> None:
    """
    Controlling a unified seed value for Python, NumPy, and PyTorch (CPU, GPU).

    When cuDNN is available it is also switched to deterministic mode with
    benchmarking disabled.

    Parameters:
    ----------
    seed_value : int
        The unified random seed value.
    """
    random.seed(seed_value) # Python
    np.random.seed(seed_value) # NumPy
    torch.manual_seed(seed_value) # PyTorch CPU
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed_value) # every visible GPU
    # `is_available` is a function and must be called: the bare attribute is
    # always truthy, which made this condition unconditionally true.
    if torch.backends.cudnn.is_available():
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

# Model

This model includes several modifications beyond a standard U-Net architecture.


### Encoder

The model uses the `HgnetV2` backbone from timm as the encoder. We have to make a few modifications for this to work with the Unet. See more info on this backbone [here](https://huggingface.co/timm/hgnetv2_b4.ssld_stage1_in22k_in1k).


First, we reduce the stride of the stem convolution from (2,2) to (1,1). This increases the size of the feature maps in the backbone. Second, we reduce the stride of the downsample convolution in the deepest block from (2,2) to (1,1). We do this so that upsampling in the decoder can be done without padding.

```python
# Original feature map
[torch.Size([18, 18]), torch.Size([9, 9]), torch.Size([5, 5]), torch.Size([3, 3])]

# Updated stem conv
[torch.Size([36, 36]), torch.Size([18, 18]), torch.Size([9, 9]), torch.Size([5, 5])]

# Updated downsample conv
[torch.Size([36, 36]), torch.Size([18, 18]), torch.Size([9, 9]), torch.Size([9, 9])]
```

### Decoder

The decoder has a few modifications as well. 

We remove all BatchNorm2d layers and add intermediate convolutions to the skip connections. I found that removing the normalization layers increased the convergence speed, and the intermediate convolutions improved the model's predictiveness.

In [None]:
%%writefile model.py

import torch
import torch.nn as nn
import timm

class EnsembleModel(nn.Module):
    def __init__(self, models):
        super().__init__()
        self.models = nn.ModuleList(models).eval()

    def forward(self, x):
        output = None
        
        for m in self.models:
            logits= m(x)
            
            if output is None:
                output = logits
            else:
                output += logits
                
        output /= len(self.models)
        return output


class ConvBnAct2d(nn.Module):
    """Bias-free 2D convolution followed by an in-place ReLU (no normalization layer)."""
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding: int = 0,
        stride: int = 1,
    ):
        super().__init__()

        conv_kwargs = dict(stride=stride, padding=padding, bias=False)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, **conv_kwargs)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.act(self.conv(x))

class DecoderBlock2d(nn.Module):
    """
    One decoder stage: transposed-conv upsampling, an optional skip branch, and
    two 3x3 conv + ReLU layers.

    When a skip tensor is given it goes through `intermediate_conv` and is
    concatenated to the upsampled input on the channel axis. Without a skip,
    `intermediate_conv` is applied to the upsampled input itself.
    """
    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        scale_factor: int = 2,
    ):
        super().__init__()

        # Learned upsampling; the channel count is preserved
        self.upsample = nn.ConvTranspose2d(
            in_channels,
            in_channels,
            kernel_size=scale_factor,
            stride=scale_factor,
        )

        # Two 3x3 convs sized for the skip, or for the input when there is no skip
        ksize = 3
        mid_channels = in_channels if skip_channels == 0 else skip_channels
        self.intermediate_conv = nn.Sequential(
            ConvBnAct2d(mid_channels, mid_channels, ksize, ksize // 2),
            ConvBnAct2d(mid_channels, mid_channels, ksize, ksize // 2),
        )

        self.conv1 = ConvBnAct2d(
            in_channels + skip_channels, out_channels, kernel_size=3, padding=1,
        )
        self.conv2 = ConvBnAct2d(
            out_channels, out_channels, kernel_size=3, padding=1,
        )

    def forward(self, x, skip=None):
        x = self.upsample(x)

        if skip is None:
            x = self.intermediate_conv(x)
        else:
            x = torch.cat([x, self.intermediate_conv(skip)], dim=1)

        return self.conv2(self.conv1(x))


class UnetDecoder2d(nn.Module):
    """
    Unet decoder.
    Source: https://arxiv.org/abs/1505.04597

    Builds one `DecoderBlock2d` per entry of the zipped
    (input channels, skip channels, decoder channels, scale factors) lists;
    the zip stops at the shortest of the four. With exactly four encoder
    channels the first entry of `decoder_channels` is dropped.
    """
    def __init__(
        self,
        encoder_channels: tuple[int],
        skip_channels: tuple[int] = None,
        decoder_channels: tuple = (256, 128, 64, 32, 16),
        scale_factors: tuple = (1,2,2,2),
    ):
        super().__init__()

        if len(encoder_channels) == 4:
            decoder_channels = decoder_channels[1:]
        self.decoder_channels = decoder_channels

        # Default skips: every encoder level after the first, then "no skip" (0)
        if skip_channels is None:
            skip_channels = [*encoder_channels[1:], 0]

        # Build decoder blocks: each one consumes the previous block's output
        in_channels = [encoder_channels[0], *decoder_channels[:-1]]
        stage_specs = zip(in_channels, skip_channels, decoder_channels, scale_factors)
        self.blocks = nn.ModuleList([
            DecoderBlock2d(ic, sc, dc, scale_factor=sf)
            for ic, sc, dc, sf in stage_specs
        ])

    def forward(self, feats: list[torch.Tensor]):
        """Return `[feats[0], block_0 output, block_1 output, ...]`."""
        outputs = [feats[0]]
        skips = feats[1:]

        # Decoder blocks
        for idx, block in enumerate(self.blocks):
            skip = skips[idx] if idx < len(skips) else None
            outputs.append(block(outputs[-1], skip=skip))

        return outputs

class SegmentationHead2d(nn.Module):
    """Convolution with `kernel_size // 2` padding followed by bilinear upsampling."""
    def __init__(
        self,
        in_channels,
        out_channels,
        scale_factor: tuple[int] = (2,2),
        kernel_size: int = 3,
    ):
        super().__init__()
        same_padding = kernel_size // 2
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, padding=same_padding,
        )
        self.upsample = nn.Upsample(
            scale_factor=scale_factor, mode='bilinear', align_corners=False,
        )

    def forward(self, x):
        return self.upsample(self.conv(x))

class HGUNet(nn.Module):
    """
    U-Net with a timm feature-extractor encoder, `UnetDecoder2d` decoder and a
    single-channel `SegmentationHead2d`.

    In training mode `forward` returns one prediction. In eval mode it also
    predicts on a flipped copy of the input and returns the mean of both
    (flip test-time augmentation).

    Parameters:
    ----------
    backbone : str
        timm model name, created with 5 input channels and `features_only=True`.
    pretrained : bool
        Forwarded to `timm.create_model`. Pass False when a checkpoint is
        loaded right after construction, to skip fetching pretrained encoder
        weights. Defaults to True (the previous hard-coded behaviour).
    """
    def __init__(
        self,
        backbone: str,
        pretrained: bool = True,
    ):
        super().__init__()

        # Encoder
        self.backbone = timm.create_model(
            backbone,
            in_chans=5,
            pretrained=pretrained,
            features_only=True,
            drop_path_rate=0.2,
        )
        # Channel counts ordered deepest feature map first
        ecs = [info["num_chs"] for info in self.backbone.feature_info][::-1]

        # Decoder
        self.decoder = UnetDecoder2d(
            encoder_channels=ecs,
            decoder_channels=(256, 128, 64, 32),
            scale_factors=(1,2,2,2),
        )

        self.seg_head = SegmentationHead2d(
            in_channels=self.decoder.decoder_channels[-1],
            out_channels=1,
            scale_factor=2,
        )
        self._update_stem()

    def _update_stem(self):
        """Set the stride of the stem conv and of the last stage's downsample conv to 1."""
        self.backbone.stem.stem1.conv.stride = (1,1)
        self.backbone.stages_3.downsample.conv.stride = (1,1)

    def _predict(self, x):
        """Single pass: encoder -> decoder -> head, then crop and rescale."""
        feats = self.backbone(x)
        feats = feats[::-1]
        decoded = self.decoder(feats)
        x_seg = self.seg_head(decoded[-1])
        # Drop one border pixel per side of the last two dims
        x_seg = x_seg[..., 1:-1, 1:-1]
        # Affine rescaling of the head output
        # NOTE(review): presumably maps to the target value range - confirm
        return x_seg * 1500 + 3000

    def proc_flip(self, x_in):
        """Predict on the input flipped along dims -3 and -1, then flip the output back along dim -1."""
        # Flip TTA during inference
        x_in = torch.flip(x_in, dims=[-3, -1])
        x_seg = self._predict(x_in)
        return torch.flip(x_seg, dims=[-1])

    def forward(self, x):
        x_seg = self._predict(x)

        if self.training:
            return x_seg

        # Eval: average the plain and the flipped prediction
        p1 = self.proc_flip(x)
        return torch.mean(torch.stack([x_seg, p1]), dim=0)

### EMA

This is a common strategy used to increase the stability of validation performance between steps/epochs. This implementation is from Tereka [here](https://www.kaggle.com/competitions/blood-vessel-segmentation/discussion/475080#2641635).

Note: We have to be patient with EMA. We should see good results after 5 epochs or so.

In [None]:
%%writefile ema.py

from copy import deepcopy
import torch
import torch.nn as nn

class ModelEMA(nn.Module):
    """
    Keeps an exponential moving average of a model's state dict.

    `module` is a deep copy of the given model kept in eval mode. `update`
    blends every state-dict entry as ``decay * ema + (1 - decay) * model``;
    `set` copies the model's values over directly.
    """
    def __init__(self, model, decay=0.99, device=None):
        super().__init__()
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay
        self.device = device
        # Optionally keep the averaged copy on a separate device
        if self.device is not None:
            self.module.to(device=device)

    def _update(self, model, update_fn):
        """Write `update_fn(ema_value, model_value)` into every EMA state-dict entry."""
        with torch.no_grad():
            ema_values = self.module.state_dict().values()
            src_values = model.state_dict().values()
            for ema_v, model_v in zip(ema_values, src_values):
                if self.device is not None:
                    model_v = model_v.to(device=self.device)
                ema_v.copy_(update_fn(ema_v, model_v))

    def update(self, model):
        def blend(e, m):
            return self.decay * e + (1. - self.decay) * m

        self._update(model, update_fn=blend)

    def set(self, model):
        def take_model_value(e, m):
            return m

        self._update(model, update_fn=take_model_value)

# Train

Here is the main training script. 

By using 2x GPUs, we can use larger batch sizes and speed up model training. No more wasted Quota!

I won't go into the details of the script as there are already many good resources explaining DDP. Here are a couple of good starting points. 

- [Run DDP scripts with 2 T 4](https://www.kaggle.com/code/cpmpml/run-ddp-scripts-with-2-t-4) by @CPMP
- [Getting Started with Distributed Data Parallel](https://docs.pytorch.org/tutorials/intermediate/ddp_tutorial.html)
- [Distributed Data Parallel Docs](https://docs.pytorch.org/docs/stable/notes/ddp.html)

In [None]:
%%writefile train.py

import os
import yaml
import time
import gc
import numpy as np

import torch
import torch.nn as nn

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler

from data import inputs_files_to_output_files, get_data_files, SeismicDataset
from utils import format_time, seed_everything
from model import HGUNet
from ema import ModelEMA

def printy(*args):
    """
    Print `args` from the rank-0 process only.

    The rank is read from the module-level `cfg` dict ("local_rank" key). If
    `cfg` has not been defined yet, the message is printed.

    The previous version asserted ``cfg["local_rank"] == 0`` and printed in the
    ``except`` branch, which silenced rank 0 and printed on every other rank.
    """
    cfg = globals().get("cfg")
    rank = cfg.get("local_rank", 0) if isinstance(cfg, dict) else 0
    if rank == 0:
        print(*args)

def count_parameters(model):
    """Total number of elements over the parameters of `model` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def setup(rank, world_size):
    """Select CUDA device `rank` for this process and join the NCCL process group."""
    torch.cuda.set_device(rank)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

def cleanup():
    """Wait at a barrier for all processes, then destroy the process group."""
    dist.barrier()
    dist.destroy_process_group()

def train(cfg):
    """
    Train HGUNet with DistributedDataParallel and validate an EMA copy of it
    after every epoch.

    Every process runs this function on the GPU given by cfg["local_rank"].
    The validation loss is summed over all processes with an all-reduce, so
    the LR scheduler and the early-stopping decision see the same value on
    every rank.

    Parameters:
    ----------
    cfg : dict
        Keys read here: "data_path", "backbone", "batch_size", "print_freq",
        "max_epochs", "es_epochs", "optimizer", "scheduler" (its "params"),
        "local_rank" and "world_size".

    Side effects:
        Rank 0 prints progress and saves the EMA weights to "best_model.pth"
        whenever the validation loss improves.
    """
    is_main = cfg["local_rank"] == 0

    # ========= Datasets / Dataloaders ==========
    printy("="*25)
    train_inputs, train_outputs, valid_inputs, valid_outputs = get_data_files(cfg["data_path"])
    printy("TRAIN_FILES:", len(train_inputs))
    printy("VALID_FILES:", len(valid_inputs))
    printy("="*25)

    dstrain = SeismicDataset(train_inputs, train_outputs, mode="train")
    samplertrain = DistributedSampler(
        dstrain,
        num_replicas=cfg["world_size"],
        rank=cfg["local_rank"],
    )
    dltrain = DataLoader(
        dstrain,
        sampler=samplertrain,
        batch_size=cfg["batch_size"],
        pin_memory=False,
        drop_last=True,
        num_workers=4,
        persistent_workers=False,
    )

    dsvalid = SeismicDataset(valid_inputs, valid_outputs, mode="valid")
    samplervalid = DistributedSampler(
        dsvalid,
        num_replicas=cfg["world_size"],
        rank=cfg["local_rank"],
    )
    dlvalid = DataLoader(
        dsvalid,
        sampler=samplervalid,
        batch_size=32,
        pin_memory=False,
        drop_last=False,
        num_workers=4,
        persistent_workers=False,
    )

    # ========== Model / EMA ==========
    printy("="*25)
    model = HGUNet(backbone=cfg["backbone"])
    model = model.to(cfg["local_rank"])
    model = DistributedDataParallel(
        model,
        device_ids=[cfg["local_rank"]],
    )
    printy("backbone: {}".format(cfg["backbone"]))
    printy("n_params: {:_}".format(count_parameters(model)))

    # The EMA tracks the unwrapped model so its state-dict keys carry no DDP prefix
    ema_model = ModelEMA(model.module, decay=0.99)
    printy("="*25)

    # ========== Training ==========
    criterion = nn.L1Loss()
    optimizer = torch.optim.AdamW(model.parameters(), **cfg["optimizer"])  # hparams
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', **cfg["scheduler"]["params"])

    best_val_loss = 10000.0
    epochs_wo_improvement = 0
    t0 = time.time()

    for epoch in range(1, cfg["max_epochs"] + 1):

        # Shuffle sampler
        dltrain.sampler.set_epoch(epoch)

        # Train
        model.train()
        train_losses = []
        for step, (inputs, targets) in enumerate(dltrain):

            inputs = inputs.to(cfg["local_rank"])
            targets = targets.to(cfg["local_rank"])
            optimizer.zero_grad()

            # NOTE(review): autocast is used without a GradScaler; if this runs
            # in float16, small gradients may underflow - confirm this is intended.
            with torch.autocast(device_type="cuda"):
                outputs = model(inputs)
                loss = criterion(outputs, targets)

            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

            if ema_model is not None:
                ema_model.update(model.module)

            if is_main:
                if step % cfg["print_freq"] == cfg["print_freq"] - 1 or step == len(dltrain) - 1:
                    trn_loss = np.mean(train_losses)
                    t1 = format_time(time.time() - t0)
                    free, total = torch.cuda.mem_get_info(device=0)
                    mem_used = (total - free) / 1024**3
                    lr = optimizer.param_groups[-1]['lr']
                    print(
                        f"Epoch: {epoch:02d}  Step {step+1}/{len(dltrain)}  Trn Loss: {trn_loss:.2f}  LR: {lr:.2e}  GPU Usage: {mem_used:.2f}GB  Elapsed Time: {t1}",
                        flush=True,
                    )

        # Epoch-level training loss, defined even if the loader yielded no batch
        trn_loss = np.mean(train_losses) if train_losses else float("nan")

        # Valid
        model.eval()
        valid_losses = []
        for inputs, targets in dlvalid:
            inputs = inputs.to(cfg["local_rank"])
            targets = targets.to(cfg["local_rank"])

            with torch.inference_mode():
                with torch.autocast(device_type="cuda"):

                    if ema_model is not None:
                        outputs = ema_model.module(inputs)
                    else:
                        outputs = model(inputs)

            loss = criterion(outputs, targets)
            valid_losses.append(loss.item())

        # Gather loss on same device: sum of batch losses and batch count over all ranks
        v = torch.tensor([sum(valid_losses), len(valid_losses)], device=cfg["local_rank"])
        dist.all_reduce(v, op=dist.ReduceOp.SUM)
        val_loss = (v[0] / v[1]).item()

        scheduler.step(val_loss)

        # Early-stopping bookkeeping runs on EVERY rank. `val_loss` is the
        # all-reduced value, so all ranks take the same decision and leave the
        # loop together. Breaking on rank 0 only would leave the other ranks
        # waiting in the next collective call.
        improved = val_loss < best_val_loss
        if improved:
            best_val_loss = val_loss
            epochs_wo_improvement = 0
        else:
            epochs_wo_improvement += 1

        # Log / checkpoint (rank 0 only)
        if is_main:
            t1 = format_time(time.time() - t0)
            free, total = torch.cuda.mem_get_info(device=0)
            mem_used = (total - free) / 1024**3
            print(
                f"\nEpoch: {epoch:02d}  Trn Loss: {trn_loss:.2f}  Val Loss: {val_loss:.2f}  GPU Usage: {mem_used:.2f}GB  Elapsed Time: {t1}",
                flush=True,
            )
            if improved:
                torch.save(ema_model.module.state_dict(), "best_model.pth")
                print(f"\nNew best val_loss: {val_loss:.2f}\n", flush=True)
            else:
                print(f"\nEpochs without improvement: {epochs_wo_improvement}\n", flush=True)

        if epochs_wo_improvement == cfg["es_epochs"]:
            break

    # Cleanup
    del model, ema_model, optimizer, scheduler
    del dltrain, dlvalid, dstrain, dsvalid
    gc.collect()
    torch.cuda.empty_cache()

    return

if __name__ == "__main__":
    # Entry point of each worker process: reads its rank from the environment,
    # joins the process group, loads config.yaml and runs training.

    # GPU Specs
    # RANK and WORLD_SIZE must be present in the environment (KeyError otherwise).
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    # Total memory of the GPU whose index equals this process's rank
    _, total = torch.cuda.mem_get_info(device=rank)

    # Init
    setup(rank, world_size)
    # Sleep `rank` seconds before printing and `world_size - rank` after, so
    # every process spends the same total time here
    # (presumably to keep the per-rank lines in rank order - confirm).
    time.sleep(rank)
    print(f"Rank: {rank}, World size: {world_size}, GPU memory: {total / 1024**3:.2f}GB", flush=True)
    time.sleep(world_size - rank)

    # Load cfg
    with open("config.yaml", "r") as file_obj:
        cfg = yaml.safe_load(file_obj)
    # The rank is stored as "local_rank"; train() uses it both as the sampler
    # rank and as the CUDA device index.
    cfg["local_rank"]= rank
    cfg["world_size"]= world_size
    # Offset the seed by the rank so each process gets a different seed
    seed_everything(cfg["seed"]+rank)

    # Run
    train(cfg)
    # Barrier + process-group teardown
    cleanup()

In [None]:
if RUN_TRAIN:
    print("Starting training..")
    # Launch train.py with 2 processes; train.py reads RANK / WORLD_SIZE from its environment.
    !torchrun --nproc_per_node=2 train.py

# Valid

First, we load in 3x pretrained models for inference. These models were trained with an effective batch_size of 512 (256 per GPU).

The output prediction is an average of each model.

In [None]:
import yaml
import glob

import torch

from model import HGUNet, EnsembleModel

# Read the shared config and pick the inference device
with open("config.yaml", "r") as file_obj:
    cfg = yaml.safe_load(file_obj)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build one HGUNet per "*.pt" checkpoint found under model_path
fpaths = glob.glob(cfg["model_path"] + "*.pt")
models = []
for f in fpaths:
    m = HGUNet(backbone=cfg["backbone"])
    m.load_state_dict(torch.load(f, map_location=device, weights_only=True))
    models.append(m)

# Wrap the members in an averaging ensemble, in eval mode on the target device
model = EnsembleModel(models).to(device).eval()
print("n_models: {:_}".format(len(models)))

Next, we score the ensemble on the validation set.

In [None]:
import gc
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, SequentialSampler

from data import inputs_files_to_output_files, get_data_files, SeismicDataset


if RUN_VALID:
    # Imported here so the cell does not depend on `np` from an earlier cell
    import numpy as np

    # ========= Datasets / Dataloaders ==========
    print("="*25)
    _, _, valid_inputs, valid_outputs = get_data_files(cfg["data_path"])
    print("VALID_FILES:", len(valid_inputs))
    print("="*25)

    dsvalid = SeismicDataset(valid_inputs, valid_outputs, mode="valid")
    dlvalid = DataLoader(
        dsvalid,
        sampler=SequentialSampler(dsvalid),
        batch_size=32,
        num_workers=4,
    )

    # ========== Validate ==========
    criterion = nn.L1Loss()
    val_logits = []
    val_targets = []

    with torch.inference_mode():
        with torch.autocast(device.type):

            # Iterate all samples
            for inputs, targets in tqdm(dlvalid, total=len(dlvalid)):
                inputs = inputs.to(device)

                outputs = model(inputs)

                # Targets are only compared on the CPU, so they stay there
                val_logits.append(outputs.cpu())
                val_targets.append(targets)

        # Compute Loss in float32, outside autocast
        val_logits = torch.cat(val_logits, dim=0).float()
        val_targets = torch.cat(val_targets, dim=0).float()

        loss = criterion(
            input=val_logits,
            target=val_targets,
        ).item()

    print("="*25)
    print("val_loss: {:.2f}".format(loss))
    print("="*25)

    # By class scores: the parent folder name of each file is its class.
    # The sampler is sequential, so sample i belongs to file
    # i // n_examples_per_file; repeat each label that many times.
    print("="*25)
    ds_idxs = np.array([_.split("/")[-2] for _ in valid_inputs])
    ds_idxs = np.repeat(ds_idxs, repeats=dsvalid.n_examples_per_file)

    with torch.no_grad():
        for idx in sorted(np.unique(ds_idxs)):

            # Mask
            mask = ds_idxs == idx

            # Score predictions
            mae = F.l1_loss(val_logits[mask], val_targets[mask], reduction='mean').item()
            print("{:15} {:.2f}".format(idx, mae))
    print("="*25)

    # Cleanup
    del dsvalid, dlvalid
    del inputs, targets, outputs
    del ds_idxs, val_logits, val_targets
    gc.collect()
    torch.cuda.empty_cache()

# Test

Finally, we make predictions on the test data.

In [None]:
import csv
import time

from data import TestDataset
from utils import format_time

if RUN_TEST:
    t0 = time.time()

    # Submission layout: one row per (sample, y position), odd x columns only
    test_files = glob.glob("/kaggle/input/open-wfi-test/test/*.npy")
    x_cols = [f"x_{i}" for i in range(1, 70, 2)]
    fieldnames = ["oid_ypos"] + x_cols

    ds = TestDataset(test_files)
    dl = DataLoader(ds, batch_size=32, num_workers=4, pin_memory=False)

    with open("submission.csv", "wt", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()

        with torch.inference_mode(), torch.autocast(device.type):
            for inputs, oids_test in tqdm(dl, total=len(dl)):
                # Raw test files are preprocessed on the fly, on the inference device
                inputs = _preprocess(inputs.to(device))
                outputs = model(inputs)

                # Channel 0 of every prediction, as a NumPy array
                y_preds = outputs[:, 0].cpu().numpy()

                for y_pred, oid_test in zip(y_preds, oids_test):
                    for y_pos in range(70):
                        row = {
                            col: y_pred[y_pos, x_pos]
                            for col, x_pos in zip(x_cols, range(1, 70, 2))
                        }
                        row["oid_ypos"] = f"{oid_test}_y_{y_pos}"

                        writer.writerow(row)

    t1 = format_time(time.time() - t0)
    print(f"Inference Time: {t1}")

We can also view a few samples to make sure things look reasonable.

In [None]:
import matplotlib.pyplot as plt 

if RUN_TEST:
    # Plot a few samples. `outputs` and `oids_test` are the values left over
    # from the last batch of the inference loop in the previous cell.
    fig, axes = plt.subplots(3, 5, figsize=(10, 6))
    axes = axes.flatten()

    # The last batch may hold fewer samples than there are axes
    n = min(len(outputs), len(axes))

    for i in range(n):
        # Channel 0 of sample i (the extra copy of sample 0 made here before was never used)
        img = outputs[i, 0].cpu().numpy()
        idx = oids_test[i]

        # Plot
        axes[i].imshow(img, cmap='gray')
        axes[i].set_title(idx)
        axes[i].axis('off')

    # Hide the axes that received no image
    for i in range(n, len(axes)):
        axes[i].axis('off')

    plt.tight_layout()
    plt.show()