<a href="https://colab.research.google.com/github/vinayak19th/ARCNN-keras/blob/main/ARCNN_att.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ARCNN Attention Based
![PyPI - Python Version](https://img.shields.io/pypi/pyversions/Tensorflow?&logo=Python&style=for-the-badge)
![Tensorflow Version](https://img.shields.io/static/v1?label=Tensorflow&message=2.1%2B&color=ffcc00&logo=Tensorflow&logoColor=ffcc00&style=for-the-badge)
![Docker Image Size](https://img.shields.io/static/v1?label=DockerImage&message=3.35GB&color=0066ff&logo=Docker&style=for-the-badge)

Part of the [ARCNN-Keras](https://github.com/vinayak19th/ARCNN-keras) repo

In [7]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, Lambda, Conv2DTranspose, SeparableConv2D
import glob
import os

## Creating the Dataloader

In [2]:
@tf.function
def create_pairs(flist,jpq=(10,20)):
    images = tf.TensorArray(tf.float16,dynamic_size=True,size=0,infer_shape=True)
    images_comp = tf.TensorArray(tf.float16,dynamic_size=True,size=0,infer_shape=True)
    c =0
    for file in flist:
        y = tf.image.decode_jpeg(tf.io.read_file(file))
        x = tf.image.random_jpeg_quality(y,jpq[0],jpq[1])
        y = tf.expand_dims(tf.image.rgb_to_yuv(tf.cast(y,tf.float16))[:,:,0],-1)/255
        x = tf.expand_dims(tf.image.rgb_to_yuv(tf.cast(x,tf.float16))[:,:,0],-1)/255
        images = images.write(c,y)
        images_comp =images_comp.write(c,x)
        c+=1
    y = images.stack()
    x = images_comp.stack()
    return (x,y)

@tf.function
def create_patches(x,y,p,s):
    print("Shapes")
    batch_size = tf.shape(y)[0]
    print(batch_size)
    #Extracting patches and converting into batches
    y_patches = tf.image.extract_patches(images=y,sizes=(1,p,p,1),strides=(1,s,s,1),rates=(1,1,1,1),padding='VALID')
    #Calculating patch sizes and batches
    shapes= tf.shape(y_patches)
    patch_batch = int(shapes[1]*shapes[2]*batch_size)
    
    y_patches = tf.reshape(y_patches,(patch_batch,p,p,1))
    print("y_patches :",y_patches.shape)
    
    x_patches = tf.image.extract_patches(images=x,sizes=(1,p,p,1),strides=(1,s,s,1),rates=(1,1,1,1),padding='VALID')
    x_patches = tf.reshape(x_patches,(patch_batch,p,p,1))
    print("x_patches :",x_patches.shape)
    return (x_patches,y_patches)

def create_artifact_dataset(fpath = "HarmonicI_720p_1000k_1440p_bicubic/480/",batch_size=32,p=100,s=42,jpq=(10,20),fformat="*.jpg"):
    """
    Wrapper function to return tf.dataset object with all the data
        fpath : Path to folder containing jpeg files
            ex:HarmonicI_720p_1000k_1440p_bicubic/480/
            HR should be a similar directory with the parent changed from 480 to 960
            ex:HarmonicI_720p_1000k_1440p_bicubic/960/
        batch_size : size of batches per batch of patches
        p : Patch size
        s : stride size
        jpq : Tuple(min,max)
            ex: jpq = (10,20) ; where min quality is 10 and max is 20
    """
    flist = glob.glob(os.path.join(fpath,fformat))
    print("flist:",len(flist))
    artifact_dataset = tf.data.Dataset.from_tensor_slices(flist).batch(32)
    
    func = lambda x: create_pairs(x,jpq)
    artifact_dataset = artifact_dataset.map(func,num_parallel_calls=tf.data.experimental.AUTOTUNE)
    print("JPEG Pairs created with quality of range:",jpq,"\n--------------------")
    
    func = lambda x,y: create_patches(x,y,p,s)
    artifact_dataset = artifact_dataset.map(func,num_parallel_calls=tf.data.experimental.AUTOTUNE)
    print("Created Patches\n--------------------")
    
    artifact_dataset = artifact_dataset.unbatch().batch(batch_size)
    print("Dataset batches of batch size",batch_size,"\n--------------------")
    print("Dataset Spec:\n",artifact_dataset.element_spec)
    
    artifact_dataset = artifact_dataset.cache()
    return artifact_dataset

In [8]:
data = create_artifact_dataset(fpath="/content/drive/MyDrive/Colab_Notebooks/arcnn/480")
data = data.prefetch(tf.data.experimental.AUTOTUNE)

flist: 720
JPEG Pairs created with quality of range: (10, 20) 
--------------------
Shapes
Tensor("strided_slice:0", shape=(), dtype=int32)
y_patches : (None, 100, 100, 1)
x_patches : (None, 100, 100, 1)
Created Patches
--------------------
Dataset batches of batch size 32 
--------------------
Dataset Spec:
 (TensorSpec(shape=(None, 100, 100, 1), dtype=tf.float16, name=None), TensorSpec(shape=(None, 100, 100, 1), dtype=tf.float16, name=None))


## Creating the model


### Model Architecture

In [9]:
class PixelAttention(tf.keras.layers.Layer):
    def __init__(self, nf,name ='PixAttention'):
        super(PixelAttention, self).__init__(name=name)
        self._name = name
        self.conv1 = Conv2D(filters=nf,kernel_size=1)
    
    def call(self,x):
        y = self.conv1(x)
        self.sig = tf.keras.activations.sigmoid(y)
        out = tf.math.multiply(x,y)
        return out

In [10]:
def get_ARCNN_att(input_shape=(32,32,1)):
    inp = Input(shape=input_shape)
    conv1 = Conv2D(32,5,dilation_rate=4,activation='relu', padding='same', use_bias=True,name="Feature_extract")(inp)
    conv2 = Conv2D(32,5,dilation_rate=2,activation='relu', padding='same', use_bias=True,name="Feature_Enhance")(conv1)
    pa2 = PixelAttention(32,name="PA2")(conv2)
    conv3 = Conv2D(32,1,activation='relu', padding='valid', use_bias=True,name="Mapping")(pa2)
    pa3 = PixelAttention(32,name="PA3")(conv3)
    conv4 = Conv2D(1,3,dilation_rate=4,name="Image",padding='same')(pa3)
    ARCNN = Model(inputs=inp,outputs=conv4)
    return ARCNN

In [19]:
ARCNN = get_ARCNN_att([None,None,1])
ARCNN.summary()

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, None, None, 1)]   0         
_________________________________________________________________
Feature_extract (Conv2D)     (None, None, None, 32)    832       
_________________________________________________________________
Feature_Enhance (Conv2D)     (None, None, None, 32)    25632     
_________________________________________________________________
PA2 (PixelAttention)         (None, None, None, 32)    1056      
_________________________________________________________________
Mapping (Conv2D)             (None, None, None, 32)    1056      
_________________________________________________________________
PA3 (PixelAttention)         (None, None, None, 32)    1056      
_________________________________________________________________
Image (Conv2D)               (None, None, None, 1)     289 

### Custom losses and metrics

In [15]:
def ssim(y_true,y_pred):
    return tf.image.ssim(y_true,y_pred,max_val=1.0)

def psnr(y_true,y_pred):
    return tf.image.psnr(y_true,y_pred,max_val=1.0)

@tf.function
def custom_loss(y_true, y_pred):
    alpha = tf.constant(0.30)
    mssim = alpha*(1-tf.image.ssim_multiscale(y_true,y_pred,max_val=1.0,filter_size=3))
    mse = tf.metrics.mae(y_true, y_pred)
    loss = tf.reduce_mean(mssim) + (1-alpha)*tf.reduce_mean(mse)
    return loss


### Checkpoints & Training

In [21]:
def makedirs(path):
    try:
        os.mkdir(path)
    except:
        pass

In [24]:
filepath="/content/drive/MyDrive/Colab_Notebooks/arcnn/checkpoints/weights-improvement-{epoch:02d}-{ssim:.2f}.hdf5"
cp = tf.keras.callbacks.ModelCheckpoint(filepath,monitor="ssim",verbose=1,save_weights_only=True)
makedirs("/content/drive/MyDrive/Colab_Notebooks/arcnn/checkpoints/")

In [26]:
optim = tf.keras.optimizers.Adam(learning_rate=5e-4)
ARCNN.compile(optimizer=optim,loss=custom_loss,metrics=[ssim,psnr])
ARCNN.fit(data,epochs=10,callbacks=[cp])

Epoch 1/10

Epoch 00001: saving model to /content/drive/MyDrive/Colab_Notebooks/arcnn/checkpoints/weights-improvement-01-0.89.hdf5
Epoch 2/10

Epoch 00002: saving model to /content/drive/MyDrive/Colab_Notebooks/arcnn/checkpoints/weights-improvement-02-0.90.hdf5
Epoch 3/10

Epoch 00003: saving model to /content/drive/MyDrive/Colab_Notebooks/arcnn/checkpoints/weights-improvement-03-0.90.hdf5
Epoch 4/10

Epoch 00004: saving model to /content/drive/MyDrive/Colab_Notebooks/arcnn/checkpoints/weights-improvement-04-0.90.hdf5
Epoch 5/10

Epoch 00005: saving model to /content/drive/MyDrive/Colab_Notebooks/arcnn/checkpoints/weights-improvement-05-0.90.hdf5
Epoch 6/10

Epoch 00006: saving model to /content/drive/MyDrive/Colab_Notebooks/arcnn/checkpoints/weights-improvement-06-0.90.hdf5
Epoch 7/10

Epoch 00007: saving model to /content/drive/MyDrive/Colab_Notebooks/arcnn/checkpoints/weights-improvement-07-0.90.hdf5
Epoch 8/10

Epoch 00008: saving model to /content/drive/MyDrive/Colab_Notebooks/arc

<tensorflow.python.keras.callbacks.History at 0x7f2494ab3990>

In [28]:
print("Saving Model")
ARCNN.save("/content/drive/MyDrive/Colab_Notebooks/arcnn/att",save_format="tf")

Saving Model




INFO:tensorflow:Assets written to: /content/drive/MyDrive/Colab_Notebooks/arcnn/att/assets


INFO:tensorflow:Assets written to: /content/drive/MyDrive/Colab_Notebooks/arcnn/att/assets


# Testing

In [43]:
import PIL

In [33]:
model = tf.keras.models.load_model("/content/drive/MyDrive/Colab_Notebooks/arcnn/att",compile=False)









In [39]:
flist = np.asarray(glob.glob(os.path.join("/content/drive/MyDrive/Colab_Notebooks/arcnn/tests/","*")))
count = 0
total = len(flist) 
print("Processing",total,"files")


Processing 3 files


In [41]:
def process_image_SR(impath):
  im = PIL.Image.open(impath)
  im = im.convert('YCbCr') # For single channel inference
  im = np.asanyarray(im)
  y = np.expand_dims(im[:,:,0],-1)/255 # Normalizing input
  uv = np.asanyarray(im)[:,:,1:]
  #print("uv:",uv.shape,"| y:",y.shape)
  return (y,uv)

In [44]:
prog = tf.keras.utils.Progbar(total,unit_name='frames')
for i in flist:
  im_y,im_uv = process_image_SR(i)
  #print(im_y.shape)
  im_y = np.expand_dims(im_y,0)
  outs = ARCNN.predict(im_y)
  count += 1
  out = outs.reshape(im_y.shape[1], im_y.shape[2]) #Removing batch dimensions
  y_pred = np.stack([out*255,im_uv[:,:,0],im_uv[:,:,1]],axis=-1)
  y_pred= np.clip(y_pred,0,255).astype('uint8')
  y_pred = PIL.Image.fromarray(y_pred,mode='YCbCr').convert('RGB')
  fname = "out"+ i.split("/")[-1]
  converter = PIL.ImageEnhance.Color(y_pred)
  y_pred = converter.enhance(1.4)
  y_pred.save("/content/drive/MyDrive/Colab_Notebooks/arcnn/outputs/"+fname)
  prog.update(count)
print("\nDone")


Done
