<a href="https://colab.research.google.com/github/zoubidaameur/Deep-Multi-Task-Learning-for-Image-Video-Distortions-Identification/blob/main/Deep_MTL_distortion_identification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### **Deep Multi-Task Learning for Image/Video Distortions Identification**




<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/zoubidaameur/Deep-Multi-Task-Learning-for-Image-Video-Distortions-Identification/blob/main/Deep_MTL_distortion_identification.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/zoubidaameur/Deep-Multi-Task-Learning-for-Image-Video-Distortions-Identification/blob/main/Deep_MTL_distortion_identification.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
  <td>
    <a href="https://raw.githubusercontent.com/zoubidaameur/Deep-Multi-Task-Learning-for-Image-Video-Distortions-Identification/main/Deep_MTL_distortion_identification.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
  </td>
</table>

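This notebook demonstrates image distortion identification with deep multi-task learning. A single model, built from a shared DenseNet169 backbone and one binary classification head per distortion type (blur, JPEG compression and noise), identifies several distortion types in the same image simultaneously. Data generators are provided for the TID2013, KADID-10k, CSIQ and LIVE Multiply Distorted (LIVEMD) databases.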

### **Import required libraries**

In [None]:
import os, sys
import pickle
import csv

import numpy as np
import pandas as pd
import tensorflow
from skimage.util import view_as_windows
from sklearn.metrics import accuracy_score, precision_score, recall_score
from tensorflow.keras import applications
from tensorflow.keras.applications.densenet import DenseNet169
from tensorflow.keras.callbacks import Callback, TensorBoard
from tensorflow.keras.layers import MaxPooling2D, Dense, Dropout, Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing import image



### **Load dataset**

In [None]:
############## Uncomment the section of the desired dataset #####################
# Each dataset is unpacked directly into its own folder on Google Drive
# (/content/drive/MyDrive/<name>/), which is where the *_GENERATOR classes read it.
# Extracting straight into the target folder avoids `mv /content/* ...`, which
# would also try to move the mounted `drive` folder itself.
#
# Google Drive must be mounted first:
# from google.colab import drive
# drive.mount('/content/drive')
#
# The commands below use the `unrar` command-line tool (the pip package "unrar"
# is only a Python binding to the library). Install it with apt if it is missing:
# !apt-get -qq install -y unrar


####### TID-2013 #########
# !wget "http://www.ponomarenko.info/tid2013/tid2013.rar"
# !mkdir -p /content/drive/MyDrive/tid
# !unrar x "/content/tid2013.rar" /content/drive/MyDrive/tid/


####### KADID-10K #########
# !wget "https://datasets.vqa.mmsp-kn.de/archives/kadid10k.zip"
# !mkdir -p /content/drive/MyDrive/kadid
# !unzip -q /content/kadid10k.zip -d /content/drive/MyDrive/kadid


####### CSIQ #########
# !wget "http://vision.eng.shizuoka.ac.jp/csiq/dst_imgs.zip"
# !mkdir -p /content/drive/MyDrive/csiq
# !unzip -q /content/dst_imgs.zip -d /content/drive/MyDrive/csiq

####### LIVEMD #########
# NOTE(review): the wget line below uses a signed Box download link that has
# most likely expired. If it fails, fetch the LIVE Multiply Distorted database
# from its official page. wget (run from /content) saves it as /content/download.
# !wget https://public.boxcloud.com/d/1/b1!IU4S1kNcRl9668x9nt0yijL48I6EGcI3qmccUX2YNXfVw4O2LNS4fEyI3x5aNXOL2OZWHNt-Z7vTEijwPWtsasa8_P2sdaE44u-7QR1N6cNOC3afB8Szq4biRIvtRNmLTnom6NZfdFNQSMbjG6g2yTbPpRoE1YuEGIT648tUedT_eDMHGEPDyINX4hOPrRV1CvIDYMqR4K7Oa0TrM689E8nF-RDRTH2ijx0PSDc84TxdORQ79XRpIq59K3-1OEkLvnDrpcPLxsZXiZAHNjrjggCjYNscJ83COC3_JUWgR6RQ_GpvoyB_60ba1b6o76mQ1UbFRnJ3snPCEuTxb_396uRdq4tEWrnf4G-dn5NKLdvofFiuXfFFEssLoRk3beeY10EuU7z-z6w2sB_3bgJnMysFwUleBBmEgk7zbizL6rtqZ6jxcRhzmGFD7JubS8sP_nQOrIo9JMbf95oMIfLQsom7A1LlgoSyeHJ23QTQuS1Syzjo7_iHE98jBJV3LgSRGsRLPCLfgbCEFmAQEWZ5qIHETx9FEsHtPMrCB8elqLLpfzYbRf5yq8_75sM7pn_Z2ardTDAOa_Uot-nP_rqVMTCHcJSjDPygX7wNwiGITIIQZrL5zX5LUXzCmRdGjgOCuATUIOKniWKrPRmd4lowJ5kJMHM4Gd87OifblNSMxxH1jyaViMT9z5cH8kPl2Eybl3SHTsmbiernufEAnKBaHk4YjEFFTgDaNMklhclyobRNiXJaem5IiO7qmgNhgOhbT0yAYPswQA19ufH0YHzEdEuXuQ1GW2ZqUsY-kojA_0Lz2kt-kWYgIUG8UpopB_YuB64n96icArMIbbDKdbXnsZVu0myUuEMOTkgMTWKsiCYsinvBwOD4b0l1kDtkBAsQToEU34us_xyloT_A1PoV_E6bKJPTdeaTpryWWe-RSM6T5eVnKWlRZqa73-T5WqMIkGLFTdRWJTGs1VlT_GZdFCqPR-VLg6jYvZQ7ju7Bps9defkzJ-r5VmZFEV6LFultyy2tUIEaNvRi6V1YNBHRGfTh-KBaN00B6wNaQcM1FprBRiTN6jixzInXuKTFTOnS7KxiarHX_7x1P19qlf9XU61puHYJIsTQU5CyjA2n5c__JpyQIx94cOykBtle4tA88m8V4CGBTFsNRqZoNpCC1rqh-r-gTiQKz7Tbc8ea4MQiIVd0TUAKFiXge2Thczjgl-_n0nD3M7dO3FiOZvAQNUxEonDGfjsoYGtMNHgoKQO6fH2y68joH26PDr0bdeSdW820M0N50DXIu4p8OJZRNB8mWcGqrCDWXEjq/download
# !mkdir -p /content/drive/MyDrive/livemd
# !unrar x -plivemultidistortiondatabase2013 /content/download /content/drive/MyDrive/livemd/
# !mv /content/drive/MyDrive/livemd/To_Release/Part1/blurjpeg/* /content/drive/MyDrive/livemd/
# !mv /content/drive/MyDrive/livemd/To_Release/Part2/blurnoise/* /content/drive/MyDrive/livemd/

### **Data generator**

In [None]:
class generator_overlapping(tensorflow.keras.utils.Sequence):
    """Keras ``Sequence`` that yields overlapping image patches with three labels.

    Every image listed in ``self.list_IDs`` is split into ``self.patches``
    windows of shape ``(*self.dim, self.n_channels)`` taken with step
    ``self.overlap_stride``. Each patch inherits the image's three labels
    (``labels1``/``labels2``/``labels3``, one per output tower). A batch of
    ``batch_size`` images therefore holds ``patches * batch_size`` samples.

    Subclasses must set ``list_IDs``, ``labels1``..``labels3``, ``batch_size``,
    ``dim``, ``n_channels``, ``shuffle``, ``db_path``, ``patches`` and
    ``overlap_stride`` *before* calling ``super().__init__()``.
    """

    def __init__(self):
        # Let the Keras base class initialise itself (newer Keras versions expect
        # this; on older ones it is a no-op), then build the first sample order.
        super().__init__()
        self.on_epoch_end()

    def __len__(self):
        """Number of batches per epoch; a trailing partial batch is dropped."""
        return len(self.list_IDs) // self.batch_size

    def load_pkl(self, list_IDs_path, labels_path1, labels_path2, labels_path3, part):
        """Load the IDs of partition ``part`` and the three label mappings.

        Returns:
            ``(list_IDs, labels1, labels2, labels3)``. ``list_IDs`` is
            ``pickle.load(list_IDs_path)[part]``, and each ``labelsN`` is the object
            stored in ``labels_pathN`` (indexed by image ID in ``update_yN``).

        NOTE: ``pickle.load`` can execute arbitrary code, so only load trusted files.
        """
        def _read(path):
            # `with` closes the file even if unpickling fails.
            with open(path, 'rb') as handle:
                return pickle.load(handle)

        list_IDs = _read(list_IDs_path)[part]
        return list_IDs, _read(labels_path1), _read(labels_path2), _read(labels_path3)

    def loading_img(self):
        """Load the image whose ID was processed last (``self.ID``)."""
        return image.load_img(self.db_path + self.ID)

    def init_y(self):
        """Allocate an uninitialised float32 label column for one batch of patches."""
        return np.empty((self.patches * self.batch_size, 1), dtype=np.float32)

    # Label lookup hooks (one per tower); subclasses may override them.
    def update_y1(self, ID):
        return self.labels1[ID]

    def update_y2(self, ID):
        return self.labels2[ID]

    def update_y3(self, ID):
        return self.labels3[ID]

    def update_x(self, x):
        """Hook to transform the patch array of one image (identity by default)."""
        return x

    def __getitem__(self, index):
        """Return batch ``index`` as ``(X, [y1, y2, y3])``."""
        batch_indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        list_IDs_temp = [self.list_IDs[k] for k in batch_indexes]
        X, y1, y2, y3 = self.__data_generation(list_IDs_temp)
        return X, [y1, y2, y3]

    def on_epoch_end(self):
        """Rebuild (and shuffle, if requested) the sample order after each epoch."""
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def ajust(self, img):
        """Hook to adjust a preprocessed image before patching (identity by default)."""
        return img

    def __data_generation(self, list_IDs_temp):
        """Build one batch: the patches of every image in ``list_IDs_temp`` plus labels.

        Returns:
            ``X`` of shape ``(patches * batch_size, *dim, n_channels)`` and three
            label arrays of shape ``(patches * batch_size, 1)``.

        Raises:
            ValueError: if an image does not yield exactly ``self.patches``
                windows (wrong image size or stride for this dataset). Previously a
                mismatched window array was silently broadcast into the batch.
        """
        window = (*self.dim, self.n_channels)
        X = np.empty((self.patches * self.batch_size, *window), dtype=np.float32)
        y1 = self.init_y()
        y2 = self.init_y()
        y3 = self.init_y()

        for i, ID in enumerate(list_IDs_temp):
            self.ID = ID  # kept for loading_img() and subclass hooks
            img = image.img_to_array(image.load_img(self.db_path + ID))
            img = tensorflow.keras.applications.densenet.preprocess_input(img)
            img = self.ajust(img)
            # view_as_windows returns a grid of windows; flatten it to a patch list.
            x = view_as_windows(np.ascontiguousarray(img), window,
                                self.overlap_stride).reshape((-1, *window))
            x = self.update_x(x)
            if len(x) != self.patches:
                raise ValueError(
                    f"Image {ID!r} yields {len(x)} patches with stride "
                    f"{self.overlap_stride}, expected {self.patches}")

            rows = slice(i * self.patches, (i + 1) * self.patches)
            X[rows] = x
            y1[rows] = self.update_y1(ID)
            y2[rows] = self.update_y2(ID)
            y3[rows] = self.update_y3(ID)
        return X, y1, y2, y3

   
    
class LIVEMD_GENERATOR(generator_overlapping):
    """Patch generator for the LIVE Multiply Distorted (LIVEMD) database.

    Each image yields ``self.patches = 8`` windows of size ``dim`` taken with
    stride 350. All locations default to files under ``root`` (the folder the
    download cell extracts LIVEMD into), and each one can be overridden.

    NOTE(review): the pickle names mirror TID's naming scheme
    (``partition_<db>1.pickle``, ``blur_<db>.pickle``, ...). Confirm them against
    the label files actually produced for LIVEMD.
    """

    def __init__(self, batch_size=1, dim=(224, 224), n_channels=3,
                 n_output=1, shuffle=True, part='complete', base='vgg19',
                 root='/content/drive/MyDrive/livemd/', db_path=None,
                 list_IDs_path=None, labels_path1=None, labels_path2=None,
                 labels_path3=None):
        self.base = base  # stored only; not used by the generator
        self.dim = dim
        self.batch_size = batch_size
        self.n_channels = n_channels
        self.n_output = n_output  # stored only; not used by the generator
        self.shuffle = shuffle
        self.input_dim = (self.dim[0], self.dim[1], self.n_channels)
        # Images are read as db_path + ID, so db_path must end with a separator.
        self.db_path = os.path.join(root, '') if db_path is None else db_path
        files_dir = os.path.join(root, 'files')
        if list_IDs_path is None:
            list_IDs_path = os.path.join(files_dir, 'partition_livemd1.pickle')
        if labels_path1 is None:
            labels_path1 = os.path.join(files_dir, 'blur_livemd.pickle')
        if labels_path2 is None:
            labels_path2 = os.path.join(files_dir, 'jpeg_livemd.pickle')
        if labels_path3 is None:
            labels_path3 = os.path.join(files_dir, 'noise_livemd.pickle')
        self.patches = 8
        self.overlap_stride = 350
        self.list_IDs, self.labels1, self.labels2, self.labels3 = super().load_pkl(
            list_IDs_path, labels_path1, labels_path2, labels_path3, part)
        super().__init__()

class TID_GENERATOR(generator_overlapping):
    """Patch generator for the TID2013 database.

    Each image yields ``self.patches = 4`` windows of size ``dim`` taken with
    stride 150. By default, images are read from ``<root>/distorted_images/`` and
    the partition and label pickles from ``<root>/files/``. Each location can be
    overridden.
    """

    def __init__(self, batch_size=1, dim=(224, 224), n_channels=3,
                 n_output=1, shuffle=True, part='complete', base='vgg19',
                 root='/content/drive/MyDrive/tid/', db_path=None,
                 list_IDs_path=None, labels_path1=None, labels_path2=None,
                 labels_path3=None):
        self.base = base  # stored only; not used by the generator
        self.dim = dim
        self.batch_size = batch_size
        self.n_channels = n_channels
        self.n_output = n_output  # stored only; not used by the generator
        self.shuffle = shuffle
        self.input_dim = (self.dim[0], self.dim[1], self.n_channels)
        # Images are read as db_path + ID, so db_path must end with a separator.
        if db_path is None:
            db_path = os.path.join(root, 'distorted_images', '')
        self.db_path = db_path
        files_dir = os.path.join(root, 'files')
        if list_IDs_path is None:
            list_IDs_path = os.path.join(files_dir, 'partition_tid1.pickle')
        if labels_path1 is None:
            labels_path1 = os.path.join(files_dir, 'blur_tid.pickle')
        if labels_path2 is None:
            labels_path2 = os.path.join(files_dir, 'jpeg_tid.pickle')
        if labels_path3 is None:
            labels_path3 = os.path.join(files_dir, 'noise_tid.pickle')
        self.patches = 4
        self.overlap_stride = 150
        self.list_IDs, self.labels1, self.labels2, self.labels3 = super().load_pkl(
            list_IDs_path, labels_path1, labels_path2, labels_path3, part)
        super().__init__()

class CSIQ_GENERATOR(generator_overlapping):
    """Patch generator for the CSIQ database.

    Each image yields ``self.patches = 4`` windows of size ``dim`` taken with
    stride 160. All locations default to files under ``root`` (the folder the
    download cell extracts CSIQ into), and each one can be overridden.

    NOTE(review): the pickle names mirror TID's naming scheme and the image IDs
    are assumed to be relative to ``root`` (e.g. include a ``dst_imgs/...``
    prefix). Confirm both, or pass ``db_path``/``*_path`` explicitly.
    """

    def __init__(self, batch_size=1, dim=(224, 224), n_channels=3,
                 n_output=1, shuffle=True, part='complete', base='vgg19',
                 root='/content/drive/MyDrive/csiq/', db_path=None,
                 list_IDs_path=None, labels_path1=None, labels_path2=None,
                 labels_path3=None):
        self.base = base  # stored only; not used by the generator
        self.dim = dim
        self.batch_size = batch_size
        self.n_channels = n_channels
        self.n_output = n_output  # stored only; not used by the generator
        self.shuffle = shuffle
        self.input_dim = (self.dim[0], self.dim[1], self.n_channels)
        # Images are read as db_path + ID, so db_path must end with a separator.
        self.db_path = os.path.join(root, '') if db_path is None else db_path
        files_dir = os.path.join(root, 'files')
        if list_IDs_path is None:
            list_IDs_path = os.path.join(files_dir, 'partition_csiq1.pickle')
        if labels_path1 is None:
            labels_path1 = os.path.join(files_dir, 'blur_csiq.pickle')
        if labels_path2 is None:
            labels_path2 = os.path.join(files_dir, 'jpeg_csiq.pickle')
        if labels_path3 is None:
            labels_path3 = os.path.join(files_dir, 'noise_csiq.pickle')
        self.patches = 4
        self.overlap_stride = 160
        self.list_IDs, self.labels1, self.labels2, self.labels3 = super().load_pkl(
            list_IDs_path, labels_path1, labels_path2, labels_path3, part)
        super().__init__()

class KADID_GENERATOR(generator_overlapping):
    """Patch generator for the KADID-10k database.

    Each image yields ``self.patches = 6`` windows of size ``dim`` taken with
    stride 100. All locations default to files under ``root`` (the folder the
    download cell extracts KADID-10k into), and each one can be overridden.

    NOTE(review): the pickle names mirror TID's naming scheme and the image IDs
    are assumed to be relative to ``root``. Confirm both, or pass
    ``db_path``/``*_path`` explicitly.
    """

    def __init__(self, batch_size=1, dim=(224, 224), n_channels=3,
                 n_output=1, shuffle=True, part='complete', base='vgg19',
                 root='/content/drive/MyDrive/kadid/', db_path=None,
                 list_IDs_path=None, labels_path1=None, labels_path2=None,
                 labels_path3=None):
        self.base = base  # stored only; not used by the generator
        self.dim = dim
        self.batch_size = batch_size
        self.n_channels = n_channels
        self.n_output = n_output  # stored only; not used by the generator
        self.shuffle = shuffle
        self.input_dim = (self.dim[0], self.dim[1], self.n_channels)
        # Images are read as db_path + ID, so db_path must end with a separator.
        self.db_path = os.path.join(root, '') if db_path is None else db_path
        files_dir = os.path.join(root, 'files')
        if list_IDs_path is None:
            list_IDs_path = os.path.join(files_dir, 'partition_kadid1.pickle')
        if labels_path1 is None:
            labels_path1 = os.path.join(files_dir, 'blur_kadid.pickle')
        if labels_path2 is None:
            labels_path2 = os.path.join(files_dir, 'jpeg_kadid.pickle')
        if labels_path3 is None:
            labels_path3 = os.path.join(files_dir, 'noise_kadid.pickle')
        self.patches = 6
        self.overlap_stride = 100
        self.list_IDs, self.labels1, self.labels2, self.labels3 = super().load_pkl(
            list_IDs_path, labels_path1, labels_path2, labels_path3, part)
        super().__init__()


  

### **Build the model**

In [None]:
def build_model(max_pool=False, weights='imagenet', dropOutRate=0.25, hiddenLayerDim=512,
                num_denseLayer=2, input_shape=(224, 224, 3), include_top=False,
                fine_tune_all=False, num_towers=2):
    """Build a DenseNet169 backbone followed by three binary classification towers.

    The backbone output is optionally max-pooled (2x2), then flattened. The
    flattened features feed three identical heads named "Tower1", "Tower2" and
    "Tower3". Each head is ``num_denseLayer`` x (Dense(relu) -> Dropout)
    followed by a single sigmoid unit.

    Args:
        max_pool: apply ``MaxPooling2D((2, 2))`` to the backbone output first.
        weights: backbone weights passed to ``DenseNet169`` (e.g. 'imagenet' or None).
        dropOutRate: dropout rate inside every tower.
        hiddenLayerDim: width of every hidden Dense layer.
        num_denseLayer: number of Dense/Dropout pairs per tower.
        input_shape, include_top: passed to ``DenseNet169``.
        fine_tune_all: when it equals False, every backbone layer is frozen.
        num_towers: currently unused; exactly three towers are always built.

    Returns:
        A ``Model`` whose outputs are ``[Tower1, Tower2, Tower3]``.
    """
    backbone = DenseNet169(weights=weights, include_top=include_top, input_shape=input_shape)

    if fine_tune_all == False:  # noqa: E712 - keeps the original equality semantics
        for backbone_layer in backbone.layers:
            backbone_layer.trainable = False

    shared = backbone.layers[-1].output
    if max_pool:
        shared = MaxPooling2D(pool_size=(2, 2))(shared)
    shared = Flatten()(shared)

    def _tower(tower_id):
        # Layer names ("DenseTower<k><i>", "DropoutTower<k><i>", "Tower<k>") are
        # what losses/metrics dicts and saved weights refer to; keep them stable.
        hidden = shared
        for depth in range(num_denseLayer):
            hidden = Dense(hiddenLayerDim, activation='relu',
                           name=f"DenseTower{tower_id}{depth}")(hidden)
            hidden = Dropout(dropOutRate, name=f"DropoutTower{tower_id}{depth}")(hidden)
        return Dense(1, name=f"Tower{tower_id}", activation="sigmoid")(hidden)

    tower_outputs = [_tower(tower_id) for tower_id in (1, 2, 3)]
    return Model(inputs=backbone.layers[0].output, outputs=tower_outputs)


### **Train model**

In [None]:
def Training(batch_size=8, db='CSIQ', dropOutRate=0.25, hiddenLayerDim=512, epochs=60,
             num_denseLayer=2, fine_tune_all=False, save_json=False, max_pool=True,
             learning_rate=0.0001, weights_path='weights.h5'):
    """Train the three-tower distortion classifier on one dataset.

    Args:
        batch_size: images per batch. Each image contributes ``patches`` samples,
            so the effective batch is larger.
        db: one of 'TID', 'CSIQ', 'KADID', 'LIVEMD'.
        dropOutRate, hiddenLayerDim, num_denseLayer, fine_tune_all, max_pool:
            forwarded to ``build_model`` (previously the first four were
            silently ignored). Use the same ``max_pool`` when testing.
        epochs: number of training epochs.
        save_json: also write the model architecture to ``model.json``.
        learning_rate: Adam learning rate.
        weights_path: file receiving the best (lowest ``val_loss``) weights.

    Returns:
        True once training has finished.

    Raises:
        ValueError: if ``db`` is not a known dataset name.

    NOTE: ``workers``/``use_multiprocessing`` target tf.keras 2.x; Keras 3's
    ``fit`` no longer accepts them.
    """
    generator_classes = {
        'TID': TID_GENERATOR,
        'CSIQ': CSIQ_GENERATOR,
        'KADID': KADID_GENERATOR,
        'LIVEMD': LIVEMD_GENERATOR,
    }
    if db not in generator_classes:
        raise ValueError(f"Unknown db {db!r}; expected one of {sorted(generator_classes)}")
    generator_cls = generator_classes[db]

    # Generator parameters shared by the training and validation splits.
    params = {
        'dim': (224, 224),
        'batch_size': batch_size,
        'n_output': 1,
        'n_channels': 3,
        'shuffle': True,
    }
    training_generator = generator_cls(part='train', **params)
    validation_generator = generator_cls(part='test', **params)

    model = build_model(weights='imagenet', dropOutRate=dropOutRate,
                        hiddenLayerDim=hiddenLayerDim, num_denseLayer=num_denseLayer,
                        input_shape=(*params['dim'], params['n_channels']),
                        fine_tune_all=fine_tune_all, max_pool=max_pool)
    model.summary()

    # One binary (sigmoid) output per distortion type, all weighted equally.
    towers = ('Tower1', 'Tower2', 'Tower3')
    model.compile(optimizer=Adam(learning_rate=learning_rate),
                  loss={tower: 'binary_crossentropy' for tower in towers},
                  loss_weights={tower: 1.0 for tower in towers},
                  metrics={tower: 'accuracy' for tower in towers})

    if save_json:
        with open('model.json', 'w') as json_file:
            json_file.write(model.to_json())

    tensorboard = TensorBoard(log_dir='./logs', histogram_freq=0,
                              write_graph=True, write_images=False)
    checkpoint = tensorflow.keras.callbacks.ModelCheckpoint(
        weights_path, monitor='val_loss', verbose=0, save_best_only=True,
        save_weights_only=True, mode='min')

    print('Start training of ' + db)
    # `fit` accepts a keras Sequence directly; `fit_generator` is deprecated.
    model.fit(training_generator,
              validation_data=validation_generator,
              epochs=epochs,
              callbacks=[checkpoint, tensorboard],
              use_multiprocessing=True,
              workers=4)

    print('Training is finished')
    return True


Training(db='TID', batch_size=6, hiddenLayerDim=512, num_denseLayer=2, dropOutRate=0.25,
         fine_tune_all=False, epochs=60, save_json=True)


### **Test model**

In [None]:
def predictions(patches= 4 ,partition_path='',labels_path1='',labels_path2='',labels_path3='',y_pred1='',y_pred2='',y_pred3= '',part='test',save_csv=True):

            
            pickle_in = open(partition_path,'rb')
            partition = pickle.load(pickle_in)
            pickle_in.close()
            
            pickle_in2 = open(labels_path1,'rb')
            labels1 = pickle.load(pickle_in2)
            pickle_in2.close()
            
            pickle_in2 = open(labels_path2,'rb')
            labels2 = pickle.load(pickle_in2)
            pickle_in2.close()
            
            pickle_in2 = open(labels_path3,'rb')
            labels3 = pickle.load(pickle_in2)
            pickle_in2.close()


                     
            print('Start evaluation...')
            truE1=[]
            truE2=[]
            truE3=[]
            truEname=[]
            for im in partition[part]:
                truE1.append(labels1[im])
                truE2.append(labels2[im])
                truE3.append(labels3[im])
                truEname.append(im)

            y_true1=np.array(truE1)
            y_true2=np.array(truE2)
            y_true3=np.array(truE3)           
            y_truename=np.array(truEname)
            

            y_pred1=y_pred1.reshape(-1,)
            y_pred2=y_pred2.reshape(-1,)
            y_pred3=y_pred3.reshape(-1,)
  


            if (save_csv):
                with open('test.csv', 'w') as f:
                    fnames = ['name','pred blur', 'true blur', 'pred JPEG','true JPEG','pred noise','true noise']       
                    writer = csv.DictWriter(f, fieldnames=fnames)
                    writer.writeheader()

                    for i in range(y_true1.size-1):
                        pred1 = 0
                        pred2 = 0
                        pred3 = 0 
                        for k in range(patches):
                          pred1=y_pred1[(i*patches)+k]+pred1
                        blur=pred1/patches
                        
                        for k in range(patches):
                          pred2=y_pred2[(i*patches)+k]+pred2
                        JPEG=pred2/patches
                        
                        for k in range(patches):
                          pred3=y_pred3[(i*patches)+k]+pred3
                        noise=pred3/patches                  
                        
                        writer.writerow({'name': y_truename[i],'pred blur' : blur, 'true blur': y_true1[i],'pred JPEG' : JPEG , 'true JPEG': y_true2[i], 'pred noise' : noise , 'true noise': y_true3[i]})

            return True

def test_model(save_csv=True, db='LIVEMD', max_pool=False, hiddenLayerDim=512):

    dim=(224,224)
    params = {'dim': dim,
        'batch_size': 1,
        'n_output': 1,
        'n_channels': 3,
        'shuffle': False,
    }

    if (db=='TID'):
        test_generator =TID_GENERATOR(part='test', **params)

    if (db=='CSIQ'):
        test_generator = CSIQ_GENERATOR(part='test', **params)

    if (db=='KADID'):
        test_generator = KADID_GENERATOR(part='test', **params)

    if (db=='LIVEMD'):
        test_generator = LIVEMD_GENERATOR(part='test', **params)        


    model = build_model (weights=None, dropOutRate=0.25, hiddenLayerDim=512, num_denseLayer=2, input_shape=(224, 224, 3))
    adam=Adam(lr=0.0001)
    losses = {
            'Tower1': 'mean_squared_error',
            'Tower2': 'mean_squared_error',
            'Tower3': 'mean_squared_error',
    }
    lossWeights = {
            'Tower1': 1.0, 'Tower2': 1.0
            , 'Tower3':1.0
  
    }
    metrics = {
            'Tower1': 'accuracy', 
            'Tower2': 'accuracy',
            'Tower3':'accuracy'
    }

    model.compile(optimizer=adam,
              loss= losses,
              loss_weights=lossWeights,
              metrics=metrics)
    
    
    model.load_weights('weights.h5')

    y_pred1 , y_pred2, y_pred3 = model.predict_generator(generator=test_generator)

    if (db == "TID"):
      return predictions(patches= 4 ,partition_path='',labels_path1='',labels_path2='',labels_path3='',y_pred1='',y_pred2='',y_pred3= '',part='test',save_csv=True)
    if (db == "CSIQ"):
      return predictions(patches= 4 ,partition_path='',labels_path1='',labels_path2='',labels_path3='',y_pred1='',y_pred2='',y_pred3= '',part='test',save_csv=True)
    if (db == "KADID"):
      return predictions(patches= 6 ,partition_path='',labels_path1='',labels_path2='',labels_path3='',y_pred1='',y_pred2='',y_pred3= '',part='test',save_csv=True)
    if (db == "LIVEMD"):
      return predictions(patches= 8 ,partition_path='',labels_path1='',labels_path2='',labels_path3='',y_pred1='',y_pred2='',y_pred3= '',part='test',save_csv=True)


# Training() builds the model with max_pool=True; the test model must match or
# load_weights() fails on the first tower's Dense weights.
# NOTE(review): the weights come from the Training(db='TID', ...) call above;
# use the same db here for in-dataset evaluation.
test_model(db='LIVEMD', hiddenLayerDim=512, max_pool=True, save_csv=True)


### **Evaluate test**

In [None]:
# Per-image scores written by test_model(). Each score becomes a hard 0/1
# decision by rounding (round-half-to-even, like the built-in round(): 0.5 -> 0).
data = pd.read_csv("test.csv")

DISTORTIONS = ("blur", "JPEG", "noise")
pred_columns = [f"pred {name}" for name in DISTORTIONS]
true_columns = [f"true {name}" for name in DISTORTIONS]

y_pred = data[pred_columns].round().astype(int).to_numpy()
y_true = data[true_columns].astype(int).to_numpy()

# Subset accuracy: an image counts as correct only if all three labels match.
print('######### Accuracy #########')
print(accuracy_score(y_true, y_pred))

# Per-distortion precision, then recall (column k corresponds to DISTORTIONS[k]).
for metric_name, metric in (('Precision', precision_score), ('Recall', recall_score)):
    for k, name in enumerate(DISTORTIONS):
        print(f'######### {metric_name} {name.upper()} #########')
        print(metric(y_true[:, k], y_pred[:, k]))
