In [4]:
import os
import warnings

import numpy as np
import numpy.random as rand
import pandas as pd
import scipy.io as sio
from scipy.optimize import minimize

import matplotlib.pyplot as plt
import matplotlib.animation as animation
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns

from sklearn.metrics import mean_squared_error
from sklearn.model_selection import KFold

from tqdm import tqdm

plt.style.use('seaborn-pastel')
plt.ion()

%matplotlib tk 

result_folder = 'Results'

In [2]:
class FileReader:
    """Read a text file in which every line starts with an id token.

    Used for id/title lists such as ``movie_ids.txt`` (presumably lines like
    ``"1 Toy Story (1995)"`` -- TODO confirm); everything after the first
    token on each line is kept.
    """

    def __init__(self, file):
        # Path of the text file to read.
        self.file = file

    def read_file(self):
        """Return one string per line, with the leading token removed.

        The remaining tokens of a line are re-joined with single spaces, so
        runs of whitespace collapse to one space and a blank or one-token
        line yields ``''``. The file is decoded as ISO-8859-1 (Latin-1).

        Returns:
            list[str], or ``None`` if the file could not be opened or read.
            A ``RuntimeWarning`` is emitted in that case.

        Raises:
            ValueError: if no file name was given (``None`` or empty).
        """
        # Check the file name, not ``self``: an instance is always truthy.
        if not self.file:
            raise ValueError("File name not specified. ")
        try:
            with open(self.file, "r", encoding="ISO-8859-1") as f:
                return [' '.join(line.split()[1:]) for line in f]
        except OSError as exc:
            # Keep the best-effort contract (return None), but make the failure visible.
            warnings.warn('Could not read {!r}: {}'.format(self.file, exc),
                          RuntimeWarning, stacklevel=2)
            return None

In [3]:
class DATA():
    """Static helpers for loading the notebook's text and MATLAB (.mat) inputs."""

    @staticmethod
    def readTXT(file='movie_ids.txt'):
        """Return the titles listed in the text file ``file``, read via ``FileReader``.

        Args:
            file: path of the id/title list. Defaults to ``'movie_ids.txt'``.

        Returns:
            list[str], or ``None`` if the file could not be read.
        """
        return FileReader(file).read_file()

    @staticmethod
    def readMAT(file):
        """Load a MATLAB ``.mat`` file with ``scipy.io.loadmat``.

        Returns:
            dict mapping variable names to their loaded values.
        """
        return sio.loadmat(file)

    @staticmethod
    def loadData(file, fields=('Y', 'R')):
        """Load ``file`` and return the requested variables as a list, in ``fields`` order.

        Args:
            file: path to a ``.mat`` file.
            fields: one variable name, or a list/tuple of names. Defaults to
                ``('Y', 'R')``; a tuple is used so the shared default is immutable.

        Raises:
            TypeError: ``fields`` is not a str, list or tuple.
            KeyError: a requested variable is not present in the file.
        """
        return list(DATA.processData(DATA.readMAT(file), fields))

    @staticmethod
    def processData(matData, fields):
        """Yield ``matData[field]`` for each requested field, in order.

        Args:
            matData: mapping such as the dict returned by ``scipy.io.loadmat``.
            fields: a single key (str) or a list/tuple of keys.

        Raises:
            TypeError: ``fields`` has an unsupported type. Because this is a
                generator, the error surfaces on first iteration.
            KeyError: a requested key is missing from ``matData``.
        """
        if isinstance(fields, str):
            yield matData[fields]
        elif isinstance(fields, (list, tuple)):
            for field in fields:
                yield matData[field]
        else:
            raise TypeError('value of {} can be either string, list or tuple, provided is {} '.format("'fields'", type(fields)))


In [5]:
class AnimateScores():
    """Animated and static line plots of a cost curve over iterations.

    ``plotScores`` replays a recorded cost history frame by frame.
    ``plotScoresRealtime`` appends one cost per frame taken from a frame source.
    ``plot_static_graph`` draws a whole history at once. Files are written
    under the module-level ``result_folder``.
    """

    def __init__(self, iterations):
        """Create the figure and an empty cost line.

        Args:
            iterations: number of frames animated by ``plotScores``. It also
                sets the x-axis range (0 .. 1.5 * iterations), the frame
                interval of ``plotScores`` and the pause length of
                ``plotScoresRealtime``.
        """
        self.itr = iterations

        fig = plt.figure(figsize=(20, 5))
        ax = plt.axes(xlim=(0, iterations * 1.5), ylim=(0, 1000))
        ax.set_ylabel('Cost/1000')
        ax.set_xlabel('Iterations')

        self.fig = fig
        self.plt = plt
        self.ax = ax

        # The single Line2D that every frame callback updates in place.
        self.costLine, = ax.plot([], [], color='blue', marker='o', label='Cost')
        ax.legend(shadow=True, fontsize='x-large')

        # History accumulated by update_line_realtime.
        self.costs = []
        self.counts = []
        # x value assigned to the next realtime frame.
        self.count = 0

    def plotScores(self, counts, costs):
        """Animate a pre-recorded cost history, revealing one more point per frame.

        Args:
            counts: x values (iteration numbers). They are sliced with
                ``[..., :k]``, so NumPy arrays are required (plain lists raise).
            costs: y values, with the same layout as ``counts``.
        """
        self.anim = animation.FuncAnimation(self.fig, self.update_line,
                                            frames=range(0, self.itr),
                                            fargs=(counts, costs),
                                            interval=self.itr * 10, blit=True,
                                            init_func=self.init, repeat=True)

    def plotScoresRealtime(self, frames):
        """Animate costs as they arrive; each value produced by ``frames`` becomes the next point.

        ``plt.pause(self.itr)`` then runs the GUI event loop for ``self.itr``
        seconds so the animation can advance.
        """
        self.anim = animation.FuncAnimation(self.fig, self.update_line_realtime,
                                            frames=frames,
                                            blit=True, init_func=self.init, repeat=True)
        self.plt.pause(self.itr)

    def init(self):
        """Blitting init function: clear the cost line and return it as the animated artist."""
        self.costLine.set_data([], [])
        return self.costLine,

    def update_line(self, num, counts, costs):
        """Frame callback for ``plotScores``: show the first ``num + 1`` points.

        The existing line is updated in place rather than adding a new Line2D
        each frame. Including the current point means the last frame
        (``num == itr - 1``) shows ``itr`` points.
        """
        end = num + 1
        self.costLine.set_data(np.ravel(counts[..., :end]), np.ravel(costs[..., :end]))
        return self.costLine,

    def update_line_realtime(self, frame):
        """Frame callback for ``plotScoresRealtime``: append ``frame`` as the cost at x = ``self.count``."""
        self.costs.append(frame)
        self.counts.append(self.count)
        self.count = self.count + 1

        self.costLine.set_data(self.counts, self.costs)
        return self.costLine,

    def save(self, filename=None):
        """Save the last animation to ``<result_folder>/<filename>.gif`` using the pillow writer.

        Does nothing when ``filename`` is None. Raises AttributeError if no
        animation has been created yet.
        """
        if filename is None:
            return
        os.makedirs(result_folder, exist_ok=True)
        f = '{}/{}.gif'.format(result_folder, filename)
        print(".... Saving animation {}".format(f))
        self.anim.save(f, writer="pillow", fps=5)

    @staticmethod
    def plot_static_graph(costs, file=None):
        """Plot a full cost history as a static line chart.

        Args:
            costs: sequence of cost values, drawn at x = 1 .. len(costs).
            file: if given, the figure is also written to ``<result_folder>/<file>.png``.
        """
        length = len(costs)

        fig = plt.figure(figsize=(12, 5))
        ax = plt.axes(xlim=(0, length + 1))
        ax.set_ylabel('Cost/1000')
        ax.set_xlabel('Iterations')

        ax.plot(range(1, length + 1), costs, color='blue', marker='o', label='Cost')
        ax.legend(shadow=True, fontsize='x-large')

        if file is not None:
            os.makedirs(result_folder, exist_ok=True)
            fig.savefig('{}/{}.png'.format(result_folder, file))

        plt.show()