In [1]:
import os
import torch
import numpy as np
torch.manual_seed(42)
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
import time

from dataset import Agulhas2
from models.DetideNet import IGWResNet
from joint_transforms import Transform2

from utils.Pix2Pix import train_loop, val_loop
from utils import save_checkpoint, load_checkpoint, csv_writer, save_examples2

In [2]:
# ---------------------------------------------------------------------------
# Experiment configuration (everything a reader might want to tune).
# ---------------------------------------------------------------------------
MODEL = "DetideNet"
LEARNING_RATE = 1.0E-3
NUM_EPOCHS = 401
INPUT_SIZE = 256
BATCH_SIZE = 32
NUM_WORKERS = 0

# Read with .get() so that running this cell does not raise KeyError when the
# variable is not exported; the value is None in that case.
SCRATCH_BUCKET = os.environ.get('SCRATCH_BUCKET')
SNAPSHOT_DIR = os.path.join('outputs', MODEL, 'snapshots')
print(SNAPSHOT_DIR)

# Checkpoint restore settings. RESTORE_FROM is derived from SNAPSHOT_DIR so it
# points at the same tree the training loop writes its 'epoch-XXX' folders to.
USE_CHECKPOINT = False
RESTORE_FROM = os.path.join(SNAPSHOT_DIR, 'epoch-100')

outputs/DetideNet/snapshots


In [3]:
# Module-level criteria used by lossNN: mean-squared error and mean-absolute
# error, both with PyTorch's default reduction.
loss_fn, loss_l1 = nn.MSELoss(), nn.L1Loss()

def sshtoqSS(ssh_tensor):
    """Apply a finite-difference Laplacian ([1, -2, 1] stencil) to a field.

    The second difference is taken separately along the width (last) axis and
    the height (second-to-last) axis with unpadded convolutions. Each result
    is zero-padded back to the input's spatial size and the two are summed.
    Consequently the first/last columns carry only the height term, the
    first/last rows only the width term, and the four corners are zero.

    Args:
        ssh_tensor: 4-D tensor of shape (N, 1, H, W). The kernels have a
            single input channel, and H and W must each be at least 3 for the
            unpadded convolutions to be valid.

    Returns:
        Tensor with the same shape, dtype and device as ``ssh_tensor``.
    """
    # Build the stencil on the input's device/dtype instead of the legacy
    # torch.cuda.FloatTensor constructor, which only works for float32 CUDA
    # inputs and fails on CPU-only machines.
    kernel_x = torch.tensor(
        [1.0, -2.0, 1.0], dtype=ssh_tensor.dtype, device=ssh_tensor.device
    ).view(1, 1, 1, 3)
    kernel_y = kernel_x.transpose(2, 3)

    second_diff_x = F.conv2d(ssh_tensor, kernel_x, padding=0)
    second_diff_y = F.conv2d(ssh_tensor, kernel_y, padding=0)

    # F.pad takes (left, right, top, bottom) and pads with zeros by default.
    padded_x = F.pad(second_diff_x, (1, 1, 0, 0))
    padded_y = F.pad(second_diff_y, (0, 0, 1, 1))
    return padded_x + padded_y

def lossNN(ytrue, ypred, weight):
    """Composite loss: MSE on the fields plus a weighted L1 on their sshtoqSS transforms.

    Args:
        ytrue: reference tensor.
        ypred: predicted tensor, same shape as ``ytrue``.
        weight: scalar multiplier applied to the L1 term.

    Returns:
        ``loss_fn(ytrue, ypred) + weight * loss_l1(sshtoqSS(ytrue), sshtoqSS(ypred))``.
    """
    field_term = loss_fn(ytrue, ypred)

    transformed_true = sshtoqSS(ytrue)
    transformed_pred = sshtoqSS(ypred)
    transformed_term = loss_l1(transformed_true, transformed_pred)

    return field_term + weight * transformed_term

In [4]:
def val_loop(dataloader, transform_params, model, saving_path, max_batches=5):
    """Run the model on a few validation batches and save example outputs.

    For every batch ``(ssh, it, bm)`` the model predicts ``bm_fake`` from
    ``ssh``; ``it_fake`` is the residual ``ssh - bm_fake``. Targets and
    predictions are concatenated along the channel axis and handed to
    ``save_examples2`` together with the 1-based batch counter.

    Note: this notebook-level definition shadows the ``val_loop`` imported
    from ``utils.Pix2Pix`` at the top of the notebook.

    Args:
        dataloader: iterable yielding ``(ssh, it, bm)`` tensor triples.
        transform_params: passed through unchanged to ``save_examples2``.
        model: the network to evaluate.
        saving_path: passed through unchanged to ``save_examples2``.
        max_batches: stop after this many batches (default 5, the previously
            hard-coded limit). ``None`` processes the whole dataloader.
    """
    # Evaluate on whatever device the model lives on rather than assuming
    # CUDA; a parameter-less model leaves the tensors where they are.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    # Remember the caller's mode: leaving the model in eval mode would make
    # all subsequent training epochs run with eval-mode layers.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for counter, (ssh, it, bm) in enumerate(dataloader, 1):
                if device is not None:
                    ssh = ssh.to(device)
                    it = it.to(device)
                    bm = bm.to(device)

                # Prediction and its residual with respect to the input.
                bm_fake = model(ssh)
                it_fake = ssh - bm_fake

                y_fake = torch.cat([it_fake, bm_fake], dim=1)
                y = torch.cat([it, bm], dim=1)

                save_examples2(ssh, y, y_fake, transform_params, counter, saving_path)
                if max_batches is not None and counter >= max_batches:
                    break
    finally:
        model.train(was_training)

In [7]:
# Without xbatcher
def main():
    """Train IGWResNet on the Agulhas2 'train' split and snapshot every epoch.

    Steps, in order:
      1. Pick the device (CUDA when available, otherwise CPU) and build the model.
      2. Build the train/val datasets and dataloaders with a shared
         ``Transform2(crop=96)`` joint transform.
      3. Create an Adam optimizer and a StepLR scheduler (halves the learning
         rate every 100 epochs); optionally restore a checkpoint when
         ``USE_CHECKPOINT`` is set.
      4. For each epoch: train on all batches with ``lossNN``, step the
         scheduler, run ``val_loop`` into ``SNAPSHOT_DIR/epoch-XXX`` and save
         a checkpoint there every 50 epochs.
      5. Write the per-epoch mean training loss to ``loss.npy`` inside the
         last epoch directory.
    """
    since = time.time()

    cudnn.enabled = True
    cudnn.benchmark = True

    # Resolve the device first so that nothing below assumes a GPU is present.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    device_name = torch.cuda.get_device_name(0) if device.type == 'cuda' else 'CPU'
    print(f"{MODEL} is deployed on {device_name}")

    # Loading model
    model = IGWResNet().to(device)

    os.makedirs(SNAPSHOT_DIR, exist_ok=True)

    # Dataloader
    joint_transforms = Transform2(crop=96)

    train_dataset = Agulhas2(split='train', joint_transform=joint_transforms)
    val_dataset = Agulhas2(split='val', joint_transform=joint_transforms)

    # Pinned host memory only helps host-to-GPU copies.
    pin_memory = device.type == 'cuda'
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                                  num_workers=NUM_WORKERS, pin_memory=pin_memory, drop_last=False)

    val_dataloader = DataLoader(val_dataset, batch_size=2, shuffle=True,
                                num_workers=NUM_WORKERS, pin_memory=pin_memory, drop_last=False)

    # Initializing the optimizer and learning-rate schedule
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)

    if USE_CHECKPOINT:
        load_checkpoint(f'{RESTORE_FROM}/model.pth.tar', model, optimizer, LEARNING_RATE)

    # Normalisation statistics of the training split, forwarded to val_loop.
    transform_params = {
        'inputs_mean': train_dataset.inps_mean_std[0],
        'inputs_std': train_dataset.inps_mean_std[1],
        'targets_mean': train_dataset.tars_mean_std[0],
        'targets_std': train_dataset.tars_mean_std[1],
        'targets_bm_mean': train_dataset.tars_bm_mean_std[0],
        'targets_bm_std': train_dataset.tars_bm_mean_std[1],
    }

    # Weight of the L1 term inside lossNN.
    laplacian_loss_weight = 1.0E3

    # Must live outside the epoch loop: re-creating it every epoch would leave
    # only the final epoch's value in the saved loss history.
    train_loss_per_epoch = []
    # Fallback location for loss.npy if the loop body never runs.
    current_dir = SNAPSHOT_DIR
    num_batches = len(train_dataloader)

    for epoch in range(NUM_EPOCHS):
        print('Epoch:', epoch,'LR:', scheduler.get_last_lr())

        # Validation runs in eval mode; make sure every training epoch starts
        # in train mode regardless of the mode the previous validation left.
        model.train()

        running_loss = 0.0
        for batch_idx, (ssh, _, bm) in enumerate(train_dataloader):

            ssh = ssh.to(device)
            bm = bm.to(device)

            pred = model(ssh)
            loss = lossNN(bm, pred, laplacian_loss_weight)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Single host sync per batch; reused for both bookkeeping and logging.
            batch_loss = loss.item()
            running_loss += batch_loss

            if batch_idx % 30 == 0:
                print(f"Epoch: [{epoch}/{NUM_EPOCHS}] Batch: {batch_idx:>2}/{num_batches} "
                      f"Loss: {batch_loss:.4f},")

        scheduler.step()

        # max(..., 1) avoids ZeroDivisionError on an empty dataloader.
        train_loss_per_epoch.append(running_loss / max(num_batches, 1))
        current_dir = os.path.join(SNAPSHOT_DIR, f'epoch-{epoch:003d}')
        os.makedirs(current_dir, exist_ok=True)

        val_loop(val_dataloader, transform_params, model, current_dir)

        if epoch % 50 == 0:
            save_checkpoint(model, optimizer,  os.path.join(current_dir, "model.pth.tar"))

    with open(os.path.join(current_dir, "loss.npy"), mode = 'wb') as f:
        np.save(f, np.array(train_loss_per_epoch))

    print(f"Training finished in {time.time() - since:.1f} s")



In [None]:
# Launch the full training run defined in main(); this is the long-running cell.
main()

DetideNet is deployed on Tesla T4
Epoch: 0 LR: [0.001]
Epoch: [0/201] Batch:  0/241 Loss: 3.3031,
Epoch: [0/201] Batch: 30/241 Loss: 3.7344,
Epoch: [0/201] Batch: 60/241 Loss: 1.6212,
Epoch: [0/201] Batch: 90/241 Loss: 2.1468,
Epoch: [0/201] Batch: 120/241 Loss: 2.3538,
Epoch: [0/201] Batch: 150/241 Loss: 1.6899,
Epoch: [0/201] Batch: 180/241 Loss: 2.5755,
Epoch: [0/201] Batch: 210/241 Loss: 2.3188,
Epoch: [0/201] Batch: 240/241 Loss: 2.1224,
Saving Checkpoint...
Epoch: 1 LR: [0.001]
Epoch: [1/201] Batch:  0/241 Loss: 1.4065,
Epoch: [1/201] Batch: 30/241 Loss: 1.7729,
Epoch: [1/201] Batch: 60/241 Loss: 1.8270,
Epoch: [1/201] Batch: 90/241 Loss: 2.3510,
Epoch: [1/201] Batch: 120/241 Loss: 1.6949,
Epoch: [1/201] Batch: 150/241 Loss: 1.8422,
Epoch: [1/201] Batch: 180/241 Loss: 2.1323,
Epoch: [1/201] Batch: 210/241 Loss: 2.0958,
Epoch: [1/201] Batch: 240/241 Loss: 1.2629,
Epoch: 2 LR: [0.001]
Epoch: [2/201] Batch:  0/241 Loss: 2.1188,
Epoch: [2/201] Batch: 30/241 Loss: 2.0681,
Epoch: [2/20