In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, Dataset
from torchvision import datasets, transforms
from torchvision.io import read_image
from model import RedCNN
from customDataset import CatdogDataset
import cv2

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import shutil

from PIL import Image
from IPython.display import display
import warnings
from sklearn.preprocessing import normalize
warnings.filterwarnings('ignore')

In [7]:
gtpath="../catData/gt/"
gtfiles=[]
for folder, subfolder, filenames in os.walk(gtpath):
    for files in filenames:
        gtfiles.append(files)
        

noisepath="../catData/noise/"
noisefiles=[]
for folder, subfolder, filenames in os.walk(noisepath):
    for files in filenames:
        noisefiles.append(files)
        
#check whether both noise and gt have same files
print(gtfiles==noisefiles)

True


In [20]:
catdataset= CatdogDataset(data_path=noisepath,target_path=gtpath,filenames=gtfiles)
train_size=int(0.7*len(catdataset))
test_size=len(catdataset)-train_size
trainset,testset=random_split(catdataset,[train_size,test_size],generator=torch.manual_seed(42))


In [21]:
model=RedCNN()
print(model)
criterion=nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

RedCNN(
  (conv1): Conv2d(3, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv2): Conv2d(96, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv3): Conv2d(96, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv4): Conv2d(96, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (d_conv1): ConvTranspose2d(96, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (d_conv2): ConvTranspose2d(96, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (d_conv3): ConvTranspose2d(96, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (d_conv4): ConvTranspose2d(96, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (d_output): ConvTranspose2d(96, 1, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (batchnorm): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)


In [22]:
# ##Training
optimizer.zero_grad()

train_loader = DataLoader(trainset, batch_size=32, shuffle=True)
test_loader = DataLoader(testset, batch_size=32, shuffle=True)

for epoch in range(1):
    losses_batch=[]
    trained_psnr=[]
    for i, data in enumerate(train_loader):
        
        inputt, target=data
        
        target_pred=model(inputt)
        
        img=inputt.detach().numpy()
        recon_image=target_pred.detach().numpy()
        for j in range(len(target)):
            trained_psnr.append(cv2.PSNR(img[j],recon_image[j]))
        
        loss=criterion(target_pred,target)
        losses_batch.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (i%10==0):
            print(f"Epoch: {epoch} Loss:{loss.item()} Batch:{i}")




Epoch: 0 Loss:849.6525268554688 Batch:0
Epoch: 0 Loss:445.1080322265625 Batch:10
Epoch: 0 Loss:329.1041564941406 Batch:20
Epoch: 0 Loss:282.9747009277344 Batch:30
Epoch: 0 Loss:278.25848388671875 Batch:40
Epoch: 0 Loss:268.8425598144531 Batch:50
Epoch: 0 Loss:254.30479431152344 Batch:60
Epoch: 0 Loss:261.08843994140625 Batch:70
Epoch: 0 Loss:264.8851623535156 Batch:80
Epoch: 0 Loss:259.65899658203125 Batch:90
Epoch: 0 Loss:261.24163818359375 Batch:100
Epoch: 0 Loss:250.15689086914062 Batch:110
Epoch: 0 Loss:246.88502502441406 Batch:120
Epoch: 0 Loss:257.6952209472656 Batch:130
Epoch: 0 Loss:244.6479034423828 Batch:140


KeyboardInterrupt: 

In [23]:
fisher_dict={}
param_dict={}
            
for name, param in model.named_parameters():
    param_dict[name] = param.data.clone()
    fisher_dict[name] = param.grad.data.clone().pow(2)

modelfile="trained_on_cat_130batch.pth.tar"

torch.save(model.state_dict(),modelfile)

In [24]:
gt_psnr=[]

for i in range(len(trainset)):
    noise,gt=trainset[i]
    gt_psnr.append(cv2.PSNR(noise.numpy(),gt.numpy()))
    
    
print(f"Ground Truth PSNR: {sum(gt_psnr)/len(gt_psnr)}")
print(f"PSNR after training: {sum(trained_psnr)/len(trained_psnr)}")

Ground Truth PSNR: 18.774365291733467
PSNR after training: 21.489992950461797
