In [7]:
import os
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class Config:
    """Hyperparameters and output locations used by the patch pipeline.

    Every setting is a plain instance attribute assigned in ``__init__``;
    the constructor takes no arguments.
    """

    def __init__(self):
        settings = {
            'is_train': True,
            'image_size': 33,   # side length of each input patch
            'label_size': 33,   # side length of each label patch
            'scale': 2,         # forwarded to preprocess() by input_setup()
            'stride': 14,       # step between neighbouring patch origins
            'batch_size': 128,
            'learning_rate': 1e-4,
            'epoch': 10,
            'checkpoint_dir': './checkpoint',
            'sample_dir': './sample',
        }
        # Assign in declaration order so the attribute order is unchanged.
        for name, value in settings.items():
            setattr(self, name, value)

# Module-level settings instance; handed to ImageDataset further down in this cell.
config = Config()

def imread(path, is_grayscale=True):
    """Read an image file and return its YCbCr pixel data as a float32 array.

    Args:
        path: Location of the image file, passed to ``PIL.Image.open``.
        is_grayscale: When true (the default) only the first YCbCr channel
            (Y, luma) is returned; otherwise all YCbCr channels are returned.

    Returns:
        ``np.ndarray`` of dtype float32. Values are not rescaled here.
    """
    # Open inside a context manager so the file handle is released
    # deterministically instead of whenever the image object is collected.
    with Image.open(path) as img:
        ycbcr = img.convert('YCbCr')
        if is_grayscale:
            return np.array(ycbcr.split()[0], dtype=np.float32)
        return np.array(ycbcr, dtype=np.float32)

def modcrop(image, scale=3):
    """Crop an image so that its height and width are multiples of ``scale``.

    Rows and columns are removed from the bottom and right edges only.
    Both 2-D ``(H, W)`` and multi-channel ``(H, W, C)`` arrays are accepted;
    trailing axes are left untouched.

    Args:
        image: Array with at least two axes (height first, width second).
        scale: Positive integer the spatial dimensions must be divisible by.

    Returns:
        A slice (view) of ``image`` with the cropped spatial extent.
    """
    # Only the two leading axes are spatial; the previous ``h, w = image.shape``
    # raised on the 3-D arrays produced by imread(..., is_grayscale=False).
    h, w = image.shape[:2]
    h -= h % scale
    w -= w % scale
    return image[:h, :w]

def preprocess(image_path, scale=3):
    """Load the luma channel of an image, crop it and divide it by 255.

    The image is read with ``imread(..., is_grayscale=True)`` and cropped by
    ``modcrop`` so both dimensions are multiples of ``scale``.

    Args:
        image_path: Path of the image file to read.
        scale: Divisor passed to ``modcrop``.

    Returns:
        ``(input_, label_)``: two separate arrays holding the same values.
    """
    luma = modcrop(imread(image_path, is_grayscale=True), scale)
    # Two independent divisions keep the returned arrays from sharing memory.
    return luma / 255.0, luma / 255.0

def input_setup(config, image_path):
    """Cut one image into overlapping input/label patch stacks.

    The image is loaded through ``preprocess`` and scanned with a sliding
    window of ``config.image_size`` pixels that advances by ``config.stride``.
    For every window position the matching label patch of
    ``config.label_size`` pixels is taken at an offset of
    ``abs(image_size - label_size) // 2`` from the window origin.

    Args:
        config: Object exposing ``scale``, ``image_size``, ``label_size``
            and ``stride``.
        image_path: Path of the image file to process.

    Returns:
        ``(inputs, labels)`` with shapes ``(N, image_size, image_size, 1)``
        and ``(N, label_size, label_size, 1)``. ``N`` is 0 when the image is
        smaller than a single window.
    """
    input_, label_ = preprocess(image_path, config.scale)

    in_size, lab_size = config.image_size, config.label_size
    padding = abs(in_size - lab_size) // 2
    h, w = input_.shape

    sub_input_sequence = []
    sub_label_sequence = []
    for x in range(0, h - in_size + 1, config.stride):
        for y in range(0, w - in_size + 1, config.stride):
            sub_input = input_[x:x + in_size, y:y + in_size]
            sub_label = label_[x + padding:x + padding + lab_size,
                               y + padding:y + padding + lab_size]

            # Append an explicit single-channel axis to each patch.
            sub_input_sequence.append(sub_input.reshape([in_size, in_size, 1]))
            sub_label_sequence.append(sub_label.reshape([lab_size, lab_size, 1]))

    if not sub_input_sequence:
        # np.array([]) would have shape (0,), which breaks callers that
        # permute four axes; return correctly shaped empty stacks instead.
        return (np.empty((0, in_size, in_size, 1), dtype=input_.dtype),
                np.empty((0, lab_size, lab_size, 1), dtype=label_.dtype))

    return np.array(sub_input_sequence), np.array(sub_label_sequence)

def imsave(image, path):
    """Write an array to ``path`` as an 8-bit image.

    Args:
        image: Array of pixel values on a 0-255 scale, shaped ``(H, W)``,
            ``(H, W, 1)`` or a multi-channel layout PIL understands.
        path: Destination file name; the format is chosen by PIL.
    """
    pixels = np.asarray(image)
    # merge() returns (H, W, 1); PIL's fromarray has no mapping for a
    # trailing singleton channel, so drop it to get a plain 2-D image.
    if pixels.ndim == 3 and pixels.shape[-1] == 1:
        pixels = pixels[..., 0]
    # Saturate instead of letting the uint8 cast wrap values outside 0..255.
    pixels = np.clip(pixels, 0, 255).astype(np.uint8)
    Image.fromarray(pixels).save(path)

def merge(images, size):
    """Tile a stack of equally sized patches into one image, row by row.

    Args:
        images: Array shaped ``(N, h, w, c)``; a channel-less ``(N, h, w)``
            stack is also accepted and treated as ``c == 1``.
        size: ``(rows, cols)`` of the tile grid. Patch ``k`` is placed at
            row ``k // cols`` and column ``k % cols``.

    Returns:
        Float array shaped ``(h * rows, w * cols, c)``. Grid cells without a
        patch stay zero.
    """
    images = np.asarray(images)
    if images.ndim == 3:
        images = images[..., np.newaxis]

    # Take the channel count from the data rather than assuming one channel.
    h, w, channels = images.shape[1:4]
    rows, cols = size[0], size[1]
    img = np.zeros((h * rows, w * cols, channels))
    for idx, image in enumerate(images):
        j, i = divmod(idx, cols)
        img[j * h:j * h + h, i * w:i * w + w, :] = image

    return img

# PyTorch Dataset class for handling image data
class ImageDataset(Dataset):
    """Dataset whose items are the patch stacks of individual image files.

    ``len(dataset)`` equals the number of image paths. Indexing runs
    ``input_setup`` on one path and returns that image's input and label
    patches as float32 tensors with the last axis moved to position 1.
    """

    def __init__(self, image_paths, config):
        self.config = config
        self.image_paths = image_paths

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        inputs_last, labels_last = input_setup(self.config, self.image_paths[idx])
        # Same conversion for both stacks: float32 tensor, axes (0, 3, 1, 2).
        inputs_first, labels_first = (
            torch.tensor(stack, dtype=torch.float32).permute(0, 3, 1, 2)
            for stack in (inputs_last, labels_last)
        )
        return inputs_first, labels_first

# Collect the paths of the PNG files found in a directory
def prepare_data(dataset):
    """Return the paths of all ``.png`` files directly inside a directory.

    Args:
        dataset: Path of the directory to scan (not recursive).

    Returns:
        List of ``os.path.join(dataset, name)`` strings in sorted order.
    """
    # os.listdir() returns entries in arbitrary order; sorting makes index
    # based access such as ``dataset[0]`` reproducible across runs.
    return sorted(
        os.path.join(dataset, name)
        for name in os.listdir(dataset)
        if name.endswith('.png')
    )

# Main code to run the setup
# Build the dataset objects from every PNG found in the image directory.
dataset_path = "./images"
image_paths = prepare_data(dataset_path)
dataset = ImageDataset(image_paths, config)
# NOTE(review): one dataset item is a whole image's patch stack, so
# batch_size here counts images rather than patches - confirm this is intended.
data_loader = DataLoader(dataset, shuffle=True, batch_size=config.batch_size)

# Show the first patch of the first image beside its label patch.
sample_input, sample_label = dataset[0]
for panel, (patch_stack, caption) in enumerate(
        ((sample_input, 'Sub-input'), (sample_label, 'Sub-label')), start=1):
    plt.subplot(1, 2, panel)
    plt.imshow(patch_stack[0][0].numpy(), cmap='gray')
    plt.title(caption)

plt.show()


ImportError: /home/shivam/anaconda3/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cpp.so: undefined symbol: cudaGraphRetainUserObject, version libcudart.so.11.0

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import numpy as np
import os
import time
from math import ceil
from skimage.io import imsave
from skimage.util import view_as_windows

# Helper functions
def input_setup(config, image_path):
    """Cut an image stored as a ``.npy`` array into input/label patch stacks.

    The array is zero-padded on both spatial axes by
    ``abs(image_size - label_size) // 2`` and scanned with a window of
    ``config.image_size`` pixels advancing by ``config.stride``. The label
    patch for each window is taken from the same padded array, offset by the
    padding amount.

    Args:
        config: Object exposing ``image_size``, ``label_size`` and ``stride``.
        image_path: Path of a ``.npy`` file holding an ``(H, W, C)`` array;
            a 2-D ``(H, W)`` array is treated as single-channel.

    Returns:
        ``(inputs, labels)`` shaped ``(N, image_size, image_size, C)`` and
        ``(N, label_size, label_size, C)``. ``N`` is 0 when the padded image
        is smaller than one window.
    """
    image = np.load(image_path)
    if image.ndim == 2:
        # np.pad below is given three axis pairs, so add the channel axis.
        image = image[:, :, np.newaxis]

    in_size, lab_size = config.image_size, config.label_size
    padding = abs(in_size - lab_size) // 2
    padded_image = np.pad(image, ((padding, padding), (padding, padding), (0, 0)), 'constant')

    sub_input_sequence = []
    sub_label_sequence = []
    for x in range(0, padded_image.shape[0] - in_size + 1, config.stride):
        for y in range(0, padded_image.shape[1] - in_size + 1, config.stride):
            sub_input_sequence.append(padded_image[x:x + in_size, y:y + in_size])
            sub_label_sequence.append(padded_image[x + padding:x + padding + lab_size,
                                                   y + padding:y + padding + lab_size])

    if not sub_input_sequence:
        # Keep four axes even without patches so a later permute(0, 3, 1, 2)
        # does not fail on a shape-(0,) array.
        channels = padded_image.shape[2]
        return (np.empty((0, in_size, in_size, channels), dtype=padded_image.dtype),
                np.empty((0, lab_size, lab_size, channels), dtype=padded_image.dtype))

    return np.asarray(sub_input_sequence), np.asarray(sub_label_sequence)

def merge(images, size):
    """Assemble a ``(N, h, w, c)`` patch stack into a single tiled image.

    Patches fill the grid row by row: patch ``k`` lands in row ``k // cols``
    and column ``k % cols``, where ``size`` is ``(rows, cols)``.

    Returns:
        Float array shaped ``(h * rows, w * cols, c)``; unfilled cells are 0.
    """
    tile_h, tile_w = images.shape[1], images.shape[2]
    canvas = np.zeros((tile_h * size[0], tile_w * size[1], images.shape[3]))

    cols = size[1]
    for position in range(len(images)):
        row, col = divmod(position, cols)
        top, left = row * tile_h, col * tile_w
        canvas[top:top + tile_h, left:left + tile_w, :] = images[position]

    return canvas

class RECONNET(nn.Module):
    """Two fully connected layers followed by six convolutions.

    The input is flattened, squeezed through a narrow fully connected layer
    (``fc1``), expanded back to the original size (``fc2``), reshaped into an
    image and refined by two conv stacks of kernel sizes 11, 1 and 7.
    All convolutions are padded so the spatial size is preserved.

    Args:
        image_size: Side length of the square input patches.
        label_size: Accepted for interface compatibility; the output always
            has the same spatial size as the input.
        measurement_rate: Ratio of the ``fc1`` width to the flattened input
            size; the width is rounded up.
        c_dim: Number of image channels.
    """

    def __init__(self, image_size=33, label_size=33, measurement_rate=1e-1, c_dim=1):
        super(RECONNET, self).__init__()
        self.image_size = image_size
        self.label_size = label_size
        self.c_dim = c_dim
        # Derived from the arguments instead of the former hard-coded 1089
        # (= 1 * 33 * 33), so other patch sizes / channel counts also work.
        self.n_features = c_dim * image_size * image_size
        self.fc_size = int(ceil(measurement_rate * self.n_features))

        self.fc1 = nn.Linear(self.n_features, self.fc_size)
        self.fc2 = nn.Linear(self.fc_size, self.n_features)

        self.conv1 = nn.Conv2d(c_dim, 64, kernel_size=11, padding=5)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=1)
        self.conv3 = nn.Conv2d(32, c_dim, kernel_size=7, padding=3)
        self.conv4 = nn.Conv2d(c_dim, 64, kernel_size=11, padding=5)
        self.conv5 = nn.Conv2d(64, 32, kernel_size=1)
        self.conv6 = nn.Conv2d(32, c_dim, kernel_size=7, padding=3)

    def forward(self, x):
        """Map a batch of patches to a batch of the same shape.

        Args:
            x: Tensor whose elements regroup into rows of
                ``c_dim * image_size * image_size`` values.

        Returns:
            Tensor shaped ``(batch, c_dim, image_size, image_size)``.
        """
        # reshape() equals view() for contiguous input but also accepts
        # non-contiguous tensors (e.g. the result of permute()).
        x = x.reshape(-1, self.n_features)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = x.view(-1, self.c_dim, self.image_size, self.image_size)
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        x = torch.relu(self.conv4(x))
        x = torch.relu(self.conv5(x))
        # No activation on the last layer.
        x = self.conv6(x)
        return x

def train(config, model, device, train_loader, optimizer, criterion):
    """Run the optimisation loop for ``config.epoch`` epochs.

    Progress is printed every 10 steps. A checkpoint of
    ``model.state_dict()`` is written to ``config.checkpoint_dir`` whenever
    the step index within the epoch is a multiple of 500 (including step 0).

    Args:
        config: Object exposing ``epoch`` and ``checkpoint_dir``.
        model: Module to optimise; switched to training mode.
        device: Device the batches are moved to.
        train_loader: Iterable yielding ``(input, target)`` tensor pairs.
        optimizer: Optimiser bound to the model parameters.
        criterion: Callable ``criterion(output, target)`` returning the loss.
    """
    model.train()
    # torch.save does not create missing directories; do it once up front.
    os.makedirs(config.checkpoint_dir, exist_ok=True)
    start_time = time.time()
    for epoch in range(config.epoch):
        # Loop variable named ``inputs`` so it cannot be confused with a
        # module imported under the alias ``data``.
        for batch_idx, (inputs, target) in enumerate(train_loader):
            inputs, target = inputs.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(inputs)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % 10 == 0:
                print(f"Epoch: [{epoch+1}], step: [{batch_idx}], time: [{time.time() - start_time:.4f}], loss: [{loss.item():.8f}]")
            if batch_idx % 500 == 0:
                torch.save(model.state_dict(), os.path.join(config.checkpoint_dir, f'model_epoch_{epoch+1}_step_{batch_idx}.pth'))

def test(config, model, device, test_loader):
    """Run the model over a loader and save each batch as a tiled image.

    Every batch is written to ``<config.sample_dir>/test.png``, so after the
    loop the file holds the tiles of the last batch only.

    Args:
        config: Object exposing ``sample_dir``.
        model: Module to evaluate; switched to eval mode. Its output is
            assumed to be laid out ``(N, C, H, W)`` - TODO confirm for models
            other than RECONNET.
        device: Device the inputs are moved to.
        test_loader: Iterable yielding ``(input, target)`` pairs; the target
            is not used.
    """
    model.eval()
    # imsave does not create the output directory.
    os.makedirs(config.sample_dir, exist_ok=True)
    image_path = os.path.join(config.sample_dir, "test.png")
    with torch.no_grad():
        for inputs, _target in test_loader:
            output = model(inputs.to(device))
            # merge() reads height, width and channels from axes 1, 2 and 3,
            # so move the channel axis of the network output to the end.
            tiles = output.cpu().numpy().transpose(0, 2, 3, 1)
            n_tiles = len(tiles)
            if n_tiles == 0:
                continue
            # Smallest near-square grid that holds every tile. The former
            # int(sqrt(N)) x int(sqrt(N)) grid was too small whenever N was
            # not a perfect square.
            ny = int(np.ceil(np.sqrt(n_tiles)))
            nx = -(-n_tiles // ny)  # ceiling division
            result = merge(tiles, [nx, ny])
            result = result.squeeze()
            imsave(image_path, result)

class Config:
    """Run settings for training / testing: sizes, schedule and output paths.

    The constructor takes no arguments; every setting is an instance
    attribute.
    """

    def __init__(self):
        # Keyword order is preserved, so attributes are created in this order.
        vars(self).update(
            is_train=True,
            image_size=33,
            label_size=33,
            scale=2,
            stride=14,
            epoch=10,
            batch_size=128,
            learning_rate=1e-4,
            checkpoint_dir='./checkpoint',
            sample_dir='./sample',
        )

# Main code to run the setup
config = Config()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Patch stacks are cut from a single image stored as a NumPy array on disk.
image_path = "./images/test1.npy"
train_data, train_label = input_setup(config, image_path)

# Both stacks get the same conversion: float32 tensor with axes (0, 3, 1, 2).
train_dataset = data.TensorDataset(
    *(torch.tensor(patches, dtype=torch.float32).permute(0, 3, 1, 2)
      for patches in (train_data, train_label))
)
train_loader = data.DataLoader(train_dataset, shuffle=True, batch_size=config.batch_size)

model = RECONNET(image_size=config.image_size, label_size=config.label_size).to(device)
optimizer = optim.SGD(model.parameters(), lr=config.learning_rate, momentum=0.9)
criterion = nn.MSELoss()

# The same loader feeds both modes; config.is_train selects which one runs.
if not config.is_train:
    test(config, model, device, train_loader)
else:
    train(config, model, device, train_loader, optimizer, criterion)


ImportError: /home/shivam/anaconda3/lib/python3.8/site-packages/torch/lib/libtorch_cuda_cpp.so: undefined symbol: cudaGraphRetainUserObject, version libcudart.so.11.0