In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, Dataset
from torchvision import datasets, transforms
from torchvision.io import read_image
from model import RedCNN
from customDataset import CatdogDataset
import cv2

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import shutil

from PIL import Image
from IPython.display import display
import warnings
from sklearn.preprocessing import normalize
# NOTE(review): this silences *every* warning for the rest of the kernel session,
# including any shape/broadcasting warnings PyTorch may raise from the loss below.
# Consider narrowing it (e.g. by category/module) so genuine problems stay visible.
warnings.filterwarnings('ignore')

In [10]:
gtpath = "../dogData/gt/"
noisepath = "../dogData/noise/"


def list_filenames(root):
    """Return the base names of every file under ``root``, sorted.

    ``os.walk`` yields directory entries in an arbitrary, filesystem-dependent
    order. Sorting makes the result deterministic, which matters twice here:
    the gt/noise comparison below is order-sensitive, and the seeded
    ``random_split`` in the next cell only reproduces the same split if the
    dataset's file order is stable across machines and runs.

    Note: the walk is recursive but only base names are kept, so files inside
    sub-folders lose their relative path.

    Args:
        root: directory to scan. A missing directory yields an empty list.

    Returns:
        Sorted list of file base names.
    """
    return sorted(
        filename
        for _folder, _subfolders, filenames in os.walk(root)
        for filename in filenames
    )


gtfiles = list_filenames(gtpath)
noisefiles = list_filenames(noisepath)

# Check that the noise and gt folders contain exactly the same files: the
# dataset below is built from a single filename list (gtfiles) for both folders.
files_match = gtfiles == noisefiles
print(files_match)
assert files_match, (
    f"gt/noise file lists differ (gt: {len(gtfiles)}, noise: {len(noisefiles)}; "
    f"e.g. only in one folder: {sorted(set(gtfiles) ^ set(noisefiles))[:5]})"
)

True


In [11]:
TRAIN_FRACTION = 0.7
SPLIT_SEED = 42

# Noisy images are the model inputs and clean (gt) images the targets; both
# folders are indexed by the same filename list built in the previous cell.
dogdataset = CatdogDataset(data_path=noisepath, target_path=gtpath, filenames=gtfiles)

n_total = len(dogdataset)
train_size = int(TRAIN_FRACTION * n_total)
test_size = n_total - train_size

# torch.manual_seed reseeds the *global* RNG and returns the default generator,
# so this also fixes the RNG state seen by later cells.
split_generator = torch.manual_seed(SPLIT_SEED)
trainset, testset = random_split(dogdataset, [train_size, test_size], generator=split_generator)


In [19]:
# Start from the weights trained on the cat data; the next cell fine-tunes on dogs.
model_dog = RedCNN()
# map_location="cpu": this notebook runs entirely on the CPU (model outputs are
# converted with .numpy()), so remap tensors in case the checkpoint was saved
# from a GPU run -- otherwise torch.load fails on a machine without CUDA.
model_dog.load_state_dict(
    torch.load("trained_on_cat_130batch.pth.tar", map_location="cpu")
)
# NOTE(review): the repr printed below lists d_output with 1 output channel while
# conv1 takes 3 input channels. If forward() really returns 1 channel, MSELoss
# against a 3-channel target would broadcast silently (warnings are suppressed
# in the import cell) -- confirm output/target shapes.
print(model_dog)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model_dog.parameters(), lr=0.0001)

RedCNN(
  (conv1): Conv2d(3, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv2): Conv2d(96, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv3): Conv2d(96, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv4): Conv2d(96, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (d_conv1): ConvTranspose2d(96, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (d_conv2): ConvTranspose2d(96, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (d_conv3): ConvTranspose2d(96, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (d_conv4): ConvTranspose2d(96, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (d_output): ConvTranspose2d(96, 1, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (batchnorm): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)


In [20]:
## Training: one pass over the dog training set, with an EWC penalty that pulls
## parameters towards the reference values stored in parameter_dict.
ewc_lambda = 0.1

# Per-parameter reference weights (theta*) and Fisher weights, keyed by
# named_parameters() names. Loaded onto the CPU to match model_dog.
param_dict = torch.load("parameter_dict.pth.tar", map_location="cpu")
fisher_dict = torch.load("fisher_dict.pth.tar", map_location="cpu")

train_loader = DataLoader(trainset, batch_size=32, shuffle=True)
# Evaluation order does not matter, so keep it deterministic.
test_loader = DataLoader(testset, batch_size=32, shuffle=False)


def ewc_penalty(model, fisher, reference):
    """Return sum over all parameters of ``fisher * (param - reference) ** 2``.

    Args:
        model: module whose ``named_parameters()`` names are keys of both dicts.
        fisher: mapping name -> per-element importance weights.
        reference: mapping name -> anchor parameter values (theta*).

    Returns:
        Scalar tensor (differentiable w.r.t. the model parameters).
    """
    return sum(
        (fisher[name] * (param - reference[name]).pow(2)).sum()
        for name, param in model.named_parameters()
    )


losses_batch = []
trained_psnr = []
# Explicit train mode (the model has BatchNorm) in case an earlier run left it in eval().
model_dog.train()
for i, (inputt, target) in enumerate(train_loader):
    target_pred = model_dog(inputt)

    # PSNR of each reconstruction against its *clean* ground truth; comparing
    # with the noisy input would reward reproducing the noise.
    # cv2.PSNR uses its default peak value R=255 -- TODO confirm the image range.
    clean_image = target.detach().numpy()
    recon_image = target_pred.detach().numpy()
    for j in range(len(target)):
        trained_psnr.append(cv2.PSNR(clean_image[j], recon_image[j]))

    # Logged loss = reconstruction MSE + weighted EWC penalty.
    loss = criterion(target_pred, target) + ewc_lambda * ewc_penalty(
        model_dog, fisher_dict, param_dict
    )

    losses_batch.append(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if i % 10 == 0:
        print(f"Epoch: 1 Loss:{loss.item()} Batch:{i}")




Epoch: 1 Loss:252.43943786621094 Batch:0
Epoch: 1 Loss:273.6606140136719 Batch:10
Epoch: 1 Loss:265.46087646484375 Batch:20
Epoch: 1 Loss:274.1610107421875 Batch:30
Epoch: 1 Loss:263.66766357421875 Batch:40
Epoch: 1 Loss:251.27737426757812 Batch:50
Epoch: 1 Loss:264.0395202636719 Batch:60
Epoch: 1 Loss:281.7471618652344 Batch:70
Epoch: 1 Loss:257.4654846191406 Batch:80
Epoch: 1 Loss:245.90756225585938 Batch:90
Epoch: 1 Loss:270.5928039550781 Batch:100
Epoch: 1 Loss:258.30633544921875 Batch:110
Epoch: 1 Loss:269.86566162109375 Batch:120
Epoch: 1 Loss:265.7371520996094 Batch:130
Epoch: 1 Loss:269.8819580078125 Batch:140
Epoch: 1 Loss:259.597900390625 Batch:150
Epoch: 1 Loss:250.15341186523438 Batch:160
Epoch: 1 Loss:248.4051971435547 Batch:170
Epoch: 1 Loss:252.6630096435547 Batch:180


In [21]:
# Reference only (not executed): the commented-out lines below appear to be the
# recipe that produced parameter_dict.pth.tar / fisher_dict.pth.tar, which the
# training cell loads. They use the squared gradients of `model` as the Fisher
# weights; `model` is not defined in this notebook.
# fisher_dict={}
# param_dict={}
#
# for name, param in model.named_parameters():
#     param_dict[name] = param.data.clone()
#     fisher_dict[name] = param.grad.data.clone().pow(2)
#
# param_dictfile="parameter_dict.pth.tar"
# fisher_dictfile="fisher_dict.pth.tar"
#
# torch.save(fisher_dict,fisher_dictfile)
# torch.save(param_dict,param_dictfile)

# Persist the dog-fine-tuned (EWC-regularised) weights.
modelfile = "trained_on_dogwithewc.pth.tar"
dog_state = model_dog.state_dict()
torch.save(dog_state, modelfile)

In [None]:
# model=RedCNN()

# model.load_state_dict(torch.load(modelfile))

In [None]:
# Disabled evaluation sketch: average PSNR of the noisy inputs vs. ground truth,
# compared with PSNR of the model's reconstructions vs. ground truth on trainset.
# NOTE(review): before re-enabling --
#   * the last print reports `trained_psnr` (collected inside the training loop),
#     not the `train_psnr` computed here; it presumably should use `train_psnr`.
#   * `model` only exists if the commented-out cell above is re-enabled; call
#     model.eval() and wrap the loop in torch.no_grad() so BatchNorm uses its
#     running statistics and no autograd graph is built per sample.
#   * view/reshape hard-code 3x224x224 images -- TODO confirm the image size.
# gt_psnr=[]
# train_psnr=[]
# for i in range(len(trainset)):
#     noise,gt=trainset[i]
#     gt_psnr.append(cv2.PSNR(noise.numpy(),gt.numpy()))
#     train_psnr.append(cv2.PSNR(model(noise.view(-1,3,224,224)).detach().numpy().reshape(3,224,224),gt.numpy()))
    
# print(f"Ground Truth PSNR(trainset): {sum(gt_psnr)/len(gt_psnr)}")
# print(f"PSNR after training: {sum(trained_psnr)/len(trained_psnr)}")