In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms.functional as TF
from PIL import Image
import numpy as np
from UNet3d import UNet3D
from utils import *
from losses import *

In [2]:
gen_net = UNet3D(n_frames=10, feat_channels=[32, 128, 128, 256, 512],).cuda().train()

In [3]:
from torchsummary import summary
summary(gen_net, (3,10,128,128))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1     [-1, 32, 10, 128, 128]           2,624
         LayerNorm-2           [-1, 163840, 32]              64
       LayerNorm3D-3     [-1, 32, 10, 128, 128]               0
              GELU-4     [-1, 32, 10, 128, 128]               0
            Conv3d-5     [-1, 32, 10, 128, 128]          27,680
         LayerNorm-6           [-1, 163840, 32]              64
       LayerNorm3D-7     [-1, 32, 10, 128, 128]               0
              GELU-8     [-1, 32, 10, 128, 128]               0
            Conv3d-9     [-1, 32, 10, 128, 128]              96
     Conv3D_Block-10     [-1, 32, 10, 128, 128]               0
        MaxPool3d-11       [-1, 32, 10, 64, 64]               0
           Conv3d-12       [-1, 32, 10, 64, 64]             896
           Conv3d-13      [-1, 128, 10, 64, 64]           4,224
         DWconv3D-14      [-1, 128, 10,

In [4]:
from dataset import DataLoaderTurb
from torch.utils.data import DataLoader

dataloader = DataLoaderTurb('./image2/img2')
train_loader = DataLoader(dataset=dataloader, batch_size=20, \
                          shuffle=True, num_workers=8, drop_last=True, pin_memory=True)
# mean_im = data.mean(0).permute(1,2,0).clamp(0,1).detach().cpu().numpy()
# pg_save = Image.fromarray(np.uint8(mean_im * 255))
# pg_save.save('mean2.jpg', 'JPEG')
mean_im = TF.to_tensor(Image.open('./image2/mean.jpg').convert("RGB")).cuda()
mean_im = torch.stack([mean_im, mean_im], dim=0)

In [5]:
lr = 5e-6
max_iters = 10000
optimizer = optim.AdamW(gen_net.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-8)
scheduler = CosineDecayWithWarmUpScheduler(optimizer, 5, init_warmup_lr=5e-7, \
                                           warm_up_steps=1000,max_lr=lr, min_lr=1e-7,num_step_down=max_iters)
criterion_char = CharbonnierLoss()

In [6]:
# initialization
niter = 0

current_loss = 0
psnr = 0
best_psnr = 0
loss_mean = []
psnr_mean = []
while True:
    for data in train_loader:
        data = data.view((-1, 10, 3, 128, 128)).permute(0,2,1,3,4).cuda()
        output = gen_net(data)
        
        loss = criterion_char(output, mean_im)
        loss.backward()
        optimizer.step()
        scheduler.step()
        current_loss += loss.item()
        psnr += calculate_psnr(output.detach().cpu().numpy()*255, \
                               mean_im.detach().cpu().numpy()*255, border=0)
        niter += 1
        if niter % 100 == 0:
            loss_mean.append(current_loss/100)
            psnr_mean.append(psnr/100)
            out = output[0,...].data.squeeze().float().cpu().clamp_(0, 1).numpy() * 255
            out = np.transpose(out, (1, 2, 0)).round().astype(np.uint8)
            out_save = Image.fromarray(out)
            out_save.save(f'./image/init_imgs/img_{niter}.jpg', "JPEG")
            
            print('iteration: {:d} lr: {:.8f} loss-100-iter: {:.8f} psnr-100-iter: {:4f}'.format(
                niter, optimizer.param_groups[0]['lr'], current_loss/100, psnr/100))
            if psnr/100 > best_psnr:
                torch.save({'step': niter, 
                            'best_psnr': psnr/100,
                            'state_dict': gen_net.state_dict(),
                            'optimizer' : optimizer.state_dict()
                            }, f"best_res.pth")      
                best_psnr = psnr/100
            current_loss = 0
            psnr = 0
    if niter >= max_iters:
        break
        
torch.save({'step': niter, 
            'state_dict': gen_net.state_dict(),
            'optimizer' : optimizer.state_dict()
            }, f"model_step_{niter}.pth")

iteration: 100 lr: 0.00000095 loss-100-iter: 0.23476454 psnr-100-iter: 9.503778
iteration: 200 lr: 0.00000140 loss-100-iter: 0.17692056 psnr-100-iter: 11.316060
iteration: 300 lr: 0.00000185 loss-100-iter: 0.13901213 psnr-100-iter: 14.167670
iteration: 400 lr: 0.00000230 loss-100-iter: 0.09433629 psnr-100-iter: 18.617477
iteration: 500 lr: 0.00000275 loss-100-iter: 0.04128907 psnr-100-iter: 26.063072
iteration: 600 lr: 0.00000320 loss-100-iter: 0.03906770 psnr-100-iter: 26.434854
iteration: 700 lr: 0.00000365 loss-100-iter: 0.06422441 psnr-100-iter: 22.155913
iteration: 800 lr: 0.00000410 loss-100-iter: 0.06107317 psnr-100-iter: 22.338365
iteration: 900 lr: 0.00000455 loss-100-iter: 0.04134213 psnr-100-iter: 25.431169
iteration: 1000 lr: 0.00000500 loss-100-iter: 0.02447659 psnr-100-iter: 29.379364
iteration: 1100 lr: 0.00000500 loss-100-iter: 0.03323309 psnr-100-iter: 27.069005
iteration: 1200 lr: 0.00000500 loss-100-iter: 0.02195956 psnr-100-iter: 30.047891
iteration: 1300 lr: 0.0000

In [7]:
out.shape

(128, 128, 3)