In [10]:
import os
import re
import cv2
import torch
import torchvision
import pandas as pd
import numpy as np
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from TrainDataset import TrainDataset
from TestDataset import TestDataset
#from model import ReconNet
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import matplotlib.pyplot as plt
import matplotlib.pyplot as pyplot
%matplotlib inline
from utils import resize_image, restore_image
#from ourModel import Model
from model import Model

In [11]:
path = './images/all_images/BSDS200'
compression_percentage = 1
test_file_name = 'test1.png'
num_epochs = 200

#name='Our_new_model_2fclayer_200_epoch_' + str(compression_percentage) + '_compression_rate'
name='Reconnet_original_200_epoch_' + str(compression_percentage) + '_compression_rate'
original_image = './images/test_sample_' + name + '.png'
sample_image = './images/sample_result_' + name + '.png'
model_name = 'model_state_'+ name +'.pth'
log_file_name = 'model_state_'+ name +'.txt'
compression_rate = compression_percentage/100
ratio_dict = {1: 10, 4: 43, 10: 109, 20: 218, 25: 272, 30: 327, 40: 436, 50: 545}

In [12]:
def normalize(v):
    return v / np.sqrt(v.dot(v))

def generate_phi(x, y):
    np.random.seed(333)
    phi = np.random.normal(size=(x, y))
    n = len(phi)
    
    # Perform Gram-Schmidt orthonormalization
    phi[0, :] = normalize(phi[0, :])
    
    for i in range(1, n):
        Ai = phi[i, :]
        for j in range(0, i):
            Aj = phi[j, :]
            t = Ai.dot(Aj)
            Ai = Ai - t * Aj
        phi[i, :] = normalize(Ai)
        
    return phi

mat = generate_phi(ratio_dict[compression_percentage], 1089)
mat = torch.from_numpy(mat)

In [13]:
import torch
import time
import numpy as np

transformations = transforms.Compose([transforms.ToTensor()])
train_data = TrainDataset(path,mat,transformations,compression_rate)

train_dl = DataLoader(train_data,batch_size=128)

train_iter = iter(train_dl)
images , labels = next(train_iter)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

model = Model(ratio_dict, compression_percentage, measurement_rate=compression_rate)
model = nn.DataParallel(model)
model = model.to(device)
print(model)

optimizer = optim.Adam(model.parameters(), lr=0.001)

from torch.optim.lr_scheduler import StepLR
scheduler = StepLR(optimizer, step_size=5, gamma=0.1)

criterion = nn.MSELoss()
criterion = criterion.to(device)


DataParallel(
  (module): Model(
    (fc1): Linear(in_features=10, out_features=1089, bias=True)
    (conv1): Conv2d(1, 32, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
    (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv3): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv4): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv5): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv6): Conv2d(8, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
)


In [14]:
import torch
import time
import numpy as np
import os
from skimage.metrics import peak_signal_noise_ratio as psnr

# Create directories to save models and logs
os.makedirs('saved_models', exist_ok=True)
os.makedirs('logs', exist_ok=True)

total_start_time = time.time()

def train(model, criterion, optimizer, train_dl, num_epochs=10, scheduler=None, early_stopping_patience=10, log_file_name='training_log.txt'):
    best_train_loss = float('inf')
    patience_counter = 0

    # Move model to CUDA device once
    model = model.cuda()
    
    print(model)

    # Open the log file
    log_file_path = os.path.join('logs', log_file_name)
    with open(log_file_path, 'w') as log_file:
        log_file.write(str(model) + '\n\n')

        for epoch in range(num_epochs):
            start_time = time.time()
            train_loss = []
            psnr_values = []
            model.train()

            for inp, lbl in train_dl:
                inp = inp.cuda().float()
                lbl = lbl.cuda().float()

                optimizer.zero_grad()
                out = model(inp)
                out = out.view(lbl.size())
                loss = criterion(out, lbl)
                loss.backward()
                optimizer.step()

                train_loss.append(loss.item())

                # Calculate PSNR
                out_cpu = out.detach().cpu().numpy()
                lbl_cpu = lbl.detach().cpu().numpy()
                psnr_values.append(psnr(lbl_cpu, out_cpu))

            epoch_time = time.time() - start_time
            total_time = time.time() - total_start_time  # Total time taken for all epochs so far

            mean_train_loss = np.mean(train_loss)
            mean_psnr = np.mean(psnr_values)

            log_message = (f'Epoch: {epoch+1}/{num_epochs}, Training Loss: {mean_train_loss:.10f}, '
                           f'PSNR: {mean_psnr:.4f}, Time taken: {epoch_time:.2f} seconds, '
                           f'Total time taken: {total_time:.2f} seconds\n')

            print(log_message)
            log_file.write(log_message)

            # Save the best model
            if mean_train_loss < best_train_loss:
                best_train_loss = mean_train_loss
                patience_counter = 0
                model_path = os.path.join('saved_models', f'best_model_epoch_{epoch+1}.pth')
                torch.save(model.state_dict(), model_path)
                save_message = f"Saved best model at epoch {epoch+1} with training loss: {mean_train_loss:.6f}\n"
                print(save_message)
                log_file.write(save_message)
            else:
                patience_counter += 1
                if patience_counter >= early_stopping_patience:
                    early_stop_message = "Early stopping triggered\n"
                    print(early_stop_message)
                    log_file.write(early_stop_message)
                    break

            # Learning rate scheduler
            if scheduler:
                scheduler.step()

        total_duration = time.time() - total_start_time
        completion_message = f"Training completed in: {total_duration:.2f} seconds\n"
        print(completion_message)
        log_file.write(completion_message)


In [15]:
train(model, criterion, optimizer, train_dl, num_epochs=200, scheduler=scheduler, early_stopping_patience=10, log_file_name=log_file_name)


DataParallel(
  (module): Model(
    (fc1): Linear(in_features=10, out_features=1089, bias=True)
    (conv1): Conv2d(1, 32, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
    (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv3): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv4): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv5): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv6): Conv2d(8, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
)
Epoch: 1/200, Training Loss: 0.0253394695, PSNR: 18.3993, Time taken: 30.82 seconds, Total time taken: 35.44 seconds

Saved best model at epoch 1 with training loss: 0.025339

Epoch: 2/200, Training Loss: 0.0165371638, PSNR: 19.7381, Time taken: 31.06 seconds, Total time taken: 66.50 seconds

Saved best model at epoch 2 with training loss: 0.016537

Epoch: 3/200, Training Loss: 0.0162802115, PSNR: 19.8242, Time taken: 31.13

KeyboardInterrupt: 

In [None]:
#best_model_name = './saved_models/best_model_epoch_133.pth'

In [16]:
torch.save(model.state_dict(), model_name)

In [17]:
state_dict = torch.load(model_name,map_location ='cpu')
model.load_state_dict(state_dict)
model.eval()

DataParallel(
  (module): Model(
    (fc1): Linear(in_features=10, out_features=1089, bias=True)
    (conv1): Conv2d(1, 32, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
    (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv3): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv4): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv5): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv6): Conv2d(8, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
)