In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, Dataset
from torchvision import datasets, transforms
from torchvision.io import read_image
from model import RedCNN
from customDataset import CatdogDataset
import cv2

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import shutil

from PIL import Image
from IPython.display import display
import warnings
from sklearn.preprocessing import normalize
warnings.filterwarnings('ignore')

In [2]:
gtpath="../dogData_60/gt/"
gtfiles=[]
for folder, subfolder, filenames in os.walk(gtpath):
    for files in filenames:
        gtfiles.append(files)
        

noisepath="../dogData_60/noise/"
noisefiles=[]
for folder, subfolder, filenames in os.walk(noisepath):
    for files in filenames:
        noisefiles.append(files)
        
#check whether both noise and gt have same files
print(gtfiles==noisefiles)
gtfiles_=gtfiles[:200]

True


In [6]:
m2dataset= CatdogDataset(data_path=noisepath,target_path=gtpath,filenames=gtfiles_)
train_size=int(0.7*len(m2dataset))
test_size=len(m2dataset)-train_size
trainset,testset=random_split(m2dataset,[train_size,test_size],generator=torch.manual_seed(42))


In [7]:
model_dog=RedCNN()
model_dog.load_state_dict(torch.load("./saved_models/task1_model.pth.tar"))
print(model_dog)
criterion=nn.MSELoss()
optimizer = torch.optim.Adam(model_dog.parameters(), lr=0.0001)

RedCNN(
  (conv1): Conv2d(3, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv2): Conv2d(96, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv3): Conv2d(96, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv4): Conv2d(96, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (d_conv1): ConvTranspose2d(96, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (d_conv2): ConvTranspose2d(96, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (d_conv3): ConvTranspose2d(96, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (d_conv4): ConvTranspose2d(96, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (d_output): ConvTranspose2d(96, 1, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (batchnorm): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)


In [8]:
##Training
optimizer.zero_grad()
ewc_lambda=0.1


train_loader = DataLoader(trainset, batch_size=32, shuffle=True)
test_loader = DataLoader(testset, batch_size=32, shuffle=True)


losses_batch=[]
trained_psnr=[]

for epoch in range(2):
    
    for i, data in enumerate(train_loader):

        inputt, target=data

        target_pred=model_dog(inputt)



        loss=criterion(target_pred,target)


        losses_batch.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"Epoch:{epoch} Loss:{loss.item()} Batch:{i}")


Epoch: 1 Loss:259.7182312011719 Batch:0
Epoch: 1 Loss:279.15966796875 Batch:10
Epoch: 1 Loss:248.8930206298828 Batch:20
Epoch: 1 Loss:260.74249267578125 Batch:30
Epoch: 1 Loss:274.3846130371094 Batch:40
Epoch: 1 Loss:276.06561279296875 Batch:50
Epoch: 1 Loss:271.1468811035156 Batch:60
Epoch: 1 Loss:271.7431945800781 Batch:70
Epoch: 1 Loss:249.69406127929688 Batch:80
Epoch: 1 Loss:251.94764709472656 Batch:90
Epoch: 1 Loss:248.34332275390625 Batch:100
Epoch: 1 Loss:252.2019805908203 Batch:110
Epoch: 1 Loss:265.9474792480469 Batch:120
Epoch: 1 Loss:262.1739196777344 Batch:130
Epoch: 1 Loss:251.37933349609375 Batch:140
Epoch: 1 Loss:256.38763427734375 Batch:150
Epoch: 1 Loss:266.0694274902344 Batch:160
Epoch: 1 Loss:251.97119140625 Batch:170
Epoch: 1 Loss:267.4698791503906 Batch:180


In [10]:
modelfile="trained_on_dogwithout_ewc.pth.tar"

torch.save(model_dog.state_dict(),modelfile)

In [11]:
gt_psnr=[]
train_psnr=[]
for i in range(len(testset)):
    noise,gt=testset[i]
    gt_psnr.append(cv2.PSNR(noise.numpy(),gt.numpy()))
    train_psnr.append(cv2.PSNR(model_dog(noise.view(-1,3,224,224)).detach().numpy().reshape(3,224,224),gt.numpy()))
    
print(f"Ground Truth PSNR(trainset): {sum(gt_psnr)/len(gt_psnr)}")
print(f"PSNR after applying model: {sum(train_psnr)/len(train_psnr)}")

Ground Truth PSNR(trainset): 18.688907881062395
PSNR after applying model: 23.892375907162517
