In [1]:
import shelve
import os
import torch
import warnings
import random
from torch.utils.data import Dataset

In [2]:
# Make the project folder the working directory; ModelData.check_identifier
# builds its default data path from os.getcwd() + 'Data'.
# NOTE(review): this absolute path is machine-specific, so the cell fails on
# any other machine — consider a configurable project/data directory.
os.chdir('/Users/rraj/PythonFunctions/DCNet/')
# Display the resulting working directory as the cell output.
os.getcwd()

'/Users/rraj/PythonFunctions/DCNet'

In [82]:
class ModelData:
    """Mixin with helpers for loading and describing a per-class data dictionary.

    The dictionary is read from a ``shelve`` file and maps each class label to
    an array-like object whose second axis (``shape[1]``) is used as the
    number of samples available for that class.
    """

    # Shelf key holding the dictionary for each supported ``data_type``.
    # NOTE(review): 'test_dict' is inferred from the 'train_dict' naming;
    # confirm against the code that writes the shelf.
    _SHELF_KEYS = {'train': 'train_dict', 'test': 'test_dict'}

    def read_data(self, data_identifier: str, data_type: str):
        """Load the train or test dictionary stored in a shelve file.

        Args:
            data_identifier: Either a path containing ``'Data'`` or a short
                identifier that is expanded by :meth:`check_identifier`.
            data_type: ``'train'`` or ``'test'``.

        Returns:
            The object stored under the matching key of the shelf.

        Raises:
            FileNotFoundError: If the resolved file does not exist.
            ValueError: If ``data_type`` is neither ``'train'`` nor ``'test'``.
            KeyError: If the shelf does not contain the requested key.
        """
        file_path = self.check_identifier(data_identifier)

        # Validate the request before touching the shelf so that a bad
        # ``data_type`` never leaves a file handle open.
        try:
            shelf_key = self._SHELF_KEYS[data_type]
        except KeyError:
            raise ValueError("invalid data type requested") from None

        # shelve expects the name without the '.db' suffix that the backend
        # appended on creation; only strip it when it is really there.
        shelf_path = file_path[:-3] if file_path.endswith('.db') else file_path

        shelf = shelve.open(shelf_path, 'r')
        try:
            return shelf[shelf_key]
        finally:
            # Always release the file, even when the key lookup fails.
            shelf.close()

    def check_identifier(self, data_identifier: str):
        """Resolve ``data_identifier`` to an existing file path.

        An identifier that does not contain ``'Data'`` is expanded to
        ``<cwd>/Data/Data-<identifier>.db``; anything else is used as given.

        Raises:
            FileNotFoundError: If the resolved path is not an existing file.
        """
        file_path = data_identifier
        if 'Data' not in data_identifier:
            file_path = os.path.join(os.getcwd(), 'Data', 'Data-' + data_identifier + '.db')

        if not os.path.isfile(file_path):
            raise FileNotFoundError(f"{file_path} not found")
        return file_path

    def labels_to_index_dict(self, data: dict):
        """Map every label (key of ``data``) to its position in key order."""
        return {label: indx for indx, label in enumerate(data.keys())}

    def index_to_labels_dict(self, labels_to_index_dict: dict):
        """Invert a label -> index mapping into index -> label."""
        return {indx: label for label, indx in labels_to_index_dict.items()}

    def get_nclasses(self, data: dict):
        """Return the number of classes, i.e. the number of keys in ``data``."""
        return len(data.keys())

    def get_sample_sizes(self, data: dict, index_to_label_dict: dict):
        """Return ``shape[1]`` of every class entry, ordered by class index.

        Raises:
            KeyError: If an index in ``range(len(data))`` has no label, or the
                label is missing from ``data``.
        """
        # Direct indexing (instead of ``.get``) reports the missing index
        # itself rather than a confusing ``KeyError: None``.
        return [data[index_to_label_dict[indx]].shape[1] for indx in range(len(data))]

    def get_max_batch_size(self, sample_sizes: list):
        """Return the largest class-balanced batch: smallest class size times class count."""
        return min(sample_sizes) * len(sample_sizes)

    

In [83]:
class TrainData(ModelData, Dataset):
    """Training split exposed as a dataset addressed by ``(class_index, sample_index)``.

    Attributes set on construction:
        data: dictionary returned by ``read_data(..., 'train')``.
        nclasses: number of keys in ``data``.
        labels_to_index / index_to_labels: label <-> class-index mappings.
        sample_sizes: ``shape[1]`` of every class entry, ordered by class index.
        max_batch_size: smallest class size multiplied by the class count.
    """

    def __init__(self, data_identifier: str):
        loaded = self.read_data(data_identifier, 'train')
        self.data = loaded
        self.nclasses = self.get_nclasses(loaded)
        self.labels_to_index = self.labels_to_index_dict(loaded)
        self.index_to_labels = self.index_to_labels_dict(self.labels_to_index)
        self.sample_sizes = self.get_sample_sizes(loaded, self.index_to_labels)
        self.max_batch_size = self.get_max_batch_size(self.sample_sizes)

    def __len__(self):
        """Total number of samples, summed over the second axis of every class entry."""
        return sum(entry.shape[1] for entry in self.data.values())

    def __getitem__(self, index):
        """Return one sample as a single-column tensor.

        ``index`` must be a 2-tuple ``(class_index, sample_index)``; any other
        value raises ``IndexError``. The entry is passed through
        ``torch.from_numpy``, so class entries are expected to be NumPy arrays.
        """
        if not isinstance(index, tuple) or len(index) != 2:
            raise IndexError(f"{index} not supported")

        class_indx, sample_indx = index[0], index[1]
        label = self.index_to_labels.get(class_indx, None)
        # Slice (rather than index) the column so the result stays 2-D.
        column = self.data[label][:, sample_indx:sample_indx+1]
        return torch.from_numpy(column)


In [84]:
class TestData(ModelData, Dataset):
    """Test split exposed as a dataset addressed by ``(class_index, sample_index)``.

    Mirrors the training dataset but reads the ``'test'`` data and does not
    compute a maximum batch size.
    """

    def __init__(self, data_identifier: str):
        loaded = self.read_data(data_identifier, 'test')
        self.data = loaded
        self.labels_to_index = self.labels_to_index_dict(loaded)
        self.nclasses = self.get_nclasses(loaded)
        self.index_to_labels = self.index_to_labels_dict(self.labels_to_index)
        self.sample_sizes = self.get_sample_sizes(loaded, self.index_to_labels)

    def __len__(self):
        """Total number of samples, summed over the second axis of every class entry."""
        return sum(entry.shape[1] for entry in self.data.values())

    def __getitem__(self, index):
        """Return the sample at ``(class_index, sample_index)`` as a single-column tensor.

        Raises:
            IndexError: If ``index`` is not a tuple of length two.
        """
        is_pair = isinstance(index, tuple) and len(index) == 2
        if not is_pair:
            raise IndexError(f"{index} not supported")

        label = self.index_to_labels.get(index[0], None)
        first, last = index[1], index[1]+1
        # A one-wide slice keeps the column dimension in the returned tensor.
        return torch.from_numpy(self.data[label][:, first:last])
        

In [85]:
class TrainLoader:
    """Iterator yielding class-balanced batches of samples, one sample per step.

    ``data_source`` must provide ``nclasses``, ``sample_sizes`` (indexed by
    class index), ``max_batch_size`` and ``__getitem__`` accepting a
    ``(class_index, sample_index)`` tuple.

    Every batch contains ``nsamples`` samples of each class, with classes
    visited in a random order that is reshuffled after each batch. Iteration
    runs through ``n_iter`` batches back to back and then stops.

    * ``shuffle=False``: each class contributes ``nsamples`` consecutive
      samples starting at a random offset.
    * ``shuffle=True``: each sample is drawn at random (with replacement)
      from its class.
    """

    def __init__(self, data_source, batch_size: int = 0, n_iter: int = 1, shuffle: bool = False):
        self.data_source = data_source
        self.n_iter = n_iter
        self.shuffle = shuffle
        self.returned_index = 0
        # Order matters: nsamples needs the adjusted batch size, and the
        # sample offsets need both nsamples and the class order.
        self.adjusted_batch_size = self.adjust_batch_size(batch_size)
        self.nsamples = self.get_nsamples()
        self.class_indices_r = self.get_randomized_class_indices()
        self.sample_indices_r = self.get_randomized_sample_indices()

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next sample, or raise ``StopIteration`` after ``n_iter`` batches."""
        if self.n_iter <= 0 or self.returned_index >= self.adjusted_batch_size:
            raise StopIteration

        # position: slot of the class within the current (shuffled) order;
        # offset: how many samples of that class were already returned.
        position, offset = divmod(self.returned_index, self.nsamples)
        class_indx = self.class_indices_r[position]
        if self.shuffle:
            # sample_sizes is indexed by class index, not by position in the
            # shuffled order, so look it up with the class that was drawn.
            sample_indx = random.randrange(self.data_source.sample_sizes[class_indx])
        else:
            sample_indx = self.sample_indices_r[position] + offset

        # Resolve the index *before* update_batch(): on the last sample of a
        # batch it reshuffles the class order and offsets, which must only
        # affect the next batch.
        self.returned_index += 1
        self.update_batch()
        return self.data_source[(class_indx, sample_indx)]

    def update_batch(self):
        """Start a new batch once the current one has been fully returned."""
        if self.returned_index == self.adjusted_batch_size:
            self.n_iter -= 1
            random.shuffle(self.class_indices_r)
            self.sample_indices_r = self.get_randomized_sample_indices()
            self.returned_index = 0

    def adjust_batch_size(self, batch_size: int):
        """Round ``batch_size`` down to a multiple of the class count.

        The result is at least one sample per class and at most
        ``data_source.max_batch_size``. A warning is issued when the value
        differs from the requested one.
        """
        nclasses = self.data_source.nclasses
        per_class = max(batch_size // nclasses, 1)
        adjusted_batch_size = min(per_class * nclasses, self.data_source.max_batch_size)
        if adjusted_batch_size != batch_size:
            warnings.warn(f"batch size adjusted to {adjusted_batch_size}")
        return adjusted_batch_size

    def get_nsamples(self):
        """Number of samples each class contributes to one batch."""
        return self.adjusted_batch_size // self.data_source.nclasses

    def get_randomized_class_indices(self):
        """Return all class indices in random order."""
        class_indices_r = list(range(self.data_source.nclasses))
        random.shuffle(class_indices_r)
        return class_indices_r

    def get_randomized_sample_indices(self):
        """Pick, per class in the current order, a random start offset.

        The offset leaves room for ``nsamples`` consecutive samples, so
        ``start + nsamples`` never exceeds the class size.
        """
        sample_indices_r = []
        for class_indx in self.class_indices_r:
            random_start_limit = self.data_source.sample_sizes[class_indx] - self.nsamples
            sample_indices_r.append(random.randint(0, random_start_limit))
        return sample_indices_r

In [86]:
class TestLoader:
    """Iterator that walks every sample of every class in order.

    ``data_source`` must provide ``nclasses``, ``sample_sizes`` (indexed by
    class index), ``index_to_labels``, ``labels_to_index`` and ``__getitem__``
    accepting a ``(class_index, sample_index)`` tuple. Samples are returned
    class by class, in increasing sample index. Classes with zero samples are
    skipped rather than ending the iteration.
    """

    def __init__(self, data_source):
        self.data_source = data_source
        self.returned_class_indx = 0
        self.returned_sample_indx = 0

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next sample, or raise ``StopIteration`` when all classes are done."""
        # Move past empty classes first (covers an empty leading class).
        self.update_class()
        if self.returned_class_indx >= self.data_source.nclasses:
            raise StopIteration

        indx = self.returned_class_indx, self.returned_sample_indx
        self.returned_sample_indx += 1
        self.update_class()
        return self.data_source[indx]

    def update_class(self):
        """Advance to the next class that still has samples to return.

        Loops so that consecutive exhausted or empty classes are all skipped,
        and stops once the class cursor has moved past the last class.
        """
        sizes = self.data_source.sample_sizes
        while (self.returned_class_indx < self.data_source.nclasses
               and self.returned_sample_indx >= sizes[self.returned_class_indx]):
            self.returned_class_indx += 1
            self.returned_sample_indx = 0

    def get_index_to_class_dict(self):
        """Return the data source's class-index -> label mapping."""
        return self.data_source.index_to_labels

    def get_class_to_index_dict(self):
        """Return the data source's label -> class-index mapping."""
        return self.data_source.labels_to_index

In [10]:
# Build the train/test datasets and their loaders for the 'May04-2357' data set.
train_inputs = TrainData('May04-2357')
# TrainLoader may adjust the requested batch size (it warns when it does);
# n_iter=2 makes the loader run through two batches before stopping.
train_loader = TrainLoader(train_inputs, batch_size=128, n_iter=2)
test_inputs = TestData('May04-2357')
test_loader = TestLoader(test_inputs)

In [89]:
import numpy as np

def initialize_network_weights(data_identifier: str):
    """Collect one initialisation batch of training samples into a single tensor.

    A batch of roughly 25% of the training set (subject to the loader's own
    batch-size adjustment) is drawn and its samples are stacked with
    ``torch.hstack``, i.e. side by side for the 2-D single-column samples the
    dataset returns.

    Args:
        data_identifier: Identifier or path understood by ``TrainData``.

    Returns:
        The stacked samples, or an empty tensor if the loader yields nothing.

    Raises:
        RuntimeError: If the samples cannot be stacked (e.g. mismatched
            shapes). Previously a bare ``except`` hid this and silently
            restarted the batch from the offending sample.
    """
    train_inputs = TrainData(data_identifier)
    batch_size = int(0.25*len(train_inputs))
    train_loader = TrainLoader(train_inputs, batch_size=batch_size)

    # Gather first and stack once: avoids re-copying the growing batch on
    # every step and lets genuine stacking errors surface.
    samples = list(train_loader)
    if not samples:
        return torch.tensor([])
    return torch.hstack(samples)



In [90]:
# Draw the initialisation batch for the 'May04-2357' data set and show its shape.
# Later cells read `w` from the kernel, so this cell must run before them.
w = initialize_network_weights('May04-2357')
w.shape



torch.Size([102, 64])

In [92]:
from sklearn.decomposition import PCA
from scipy.stats import ortho_group

# Fit a full PCA on the transposed initialisation batch.
# NOTE(review): `w` comes from an earlier cell (recorded there as
# torch.Size([102, 64])), so PCA presumably sees 64 rows of 102 features — confirm.
pca = PCA()
pca.fit(w.numpy().T)
# Count how many leading components are needed for the cumulative explained
# variance ratio to exceed 0.95; number_of_comps stays 0 if that never happens.
total_var = 0
number_of_comps = 0
for comp in range(pca.n_components_):
    total_var += pca.explained_variance_ratio_[comp]
    if total_var > 0.95:
        number_of_comps = comp + 1
        break
print(number_of_comps)
# Keep the selected principal axes (rows of components_) and scale each axis
# by its singular value.
Q_t = pca.components_[0:number_of_comps, :]
Q = np.matmul(Q_t.T, np.diag(pca.singular_values_[:number_of_comps]))

# Multiply by the first number_of_comps rows of a random 500x500 orthogonal matrix.
# NOTE(review): 500 is hard-coded and ortho_group.rvs is unseeded, so phi
# differs on every run.
N = ortho_group.rvs(dim=500)
phi = np.matmul(Q, N[:number_of_comps, :])
print(phi.shape)

32
(102, 500)


In [62]:
# Draw five maximum-size batches from the training set and stack all samples
# row-wise into one NumPy array.
train_inputs = TrainData('May04-2357')
train_loader = TrainLoader(train_inputs, batch_size=train_inputs.max_batch_size, n_iter=5)

# Collect first and stack once. The previous bare `except` was used as control
# flow and would silently restart `batch` from the current sample on any
# stacking error; it also leaked a module-level `input` that shadowed the builtin.
sample_arrays = [sample.numpy() for sample in train_loader]
batch = np.vstack(sample_arrays) if sample_arrays else np.array([])

batch.shape

(960, 102)

In [3]:
import logging
# Configure the root logger to print "LEVEL:message" for DEBUG and above,
# then emit one message at each of three levels to check the setup.
logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.DEBUG)
logging.debug('This message should appear on the console')
logging.info('So should this')
logging.warning('And this, too')

DEBUG:This message should appear on the console
INFO:So should this


In [2]:
import torch
from torch.nn.functional import normalize
from torch.linalg import multi_dot

In [37]:
# Toy 3x4 matrix used to check that the factors returned by torch.pca_lowrank
# (without centring) multiply back to the input.
a = torch.tensor(
    [
        [4, 3, 2, 1],
        [5, 6, 7, 8],
        [0, 0, 0, 0],
    ],
    dtype=torch.float,
)
print(a)

# Row-wise centred copy, shown for comparison only; the decomposition below
# is applied to the un-centred `a`.
a_centered = a - a.mean(dim=1, keepdim=True)
print(a_centered)

U, S, V = torch.pca_lowrank(a, center=False)
# Rebuild the matrix as (U * diag(S)) * V^T.
x = U @ torch.diag(S)
y = x @ V.T
print(y)

tensor([[4., 3., 2., 1.],
        [5., 6., 7., 8.],
        [0., 0., 0., 0.]])
tensor([[ 1.5000,  0.5000, -0.5000, -1.5000],
        [-1.5000, -0.5000,  0.5000,  1.5000],
        [ 0.0000,  0.0000,  0.0000,  0.0000]])
tensor([[4.0000, 3.0000, 2.0000, 1.0000],
        [5.0000, 6.0000, 7.0000, 8.0000],
        [0.0000, 0.0000, 0.0000, 0.0000]])


In [40]:
def get_init_batch(train_inputs: TrainData):
    """Stack one initialisation batch drawn from ``train_inputs`` into a tensor.

    The requested batch size is 25% of the dataset length; ``TrainLoader`` may
    adjust it. Samples are stacked with ``torch.hstack``, i.e. side by side
    for the 2-D single-column samples the dataset returns.

    Args:
        train_inputs: Training dataset to sample from.

    Returns:
        The stacked samples, or an empty tensor if the loader yields nothing.

    Raises:
        RuntimeError: If the samples cannot be stacked (e.g. mismatched
            shapes). Previously a bare ``except`` hid this and silently
            restarted the batch from the offending sample.
    """
    # TODO: implement with row-wise operations
    batch_size = int(0.25*len(train_inputs))
    train_loader = TrainLoader(train_inputs, batch_size=batch_size)

    # Gather first and stack once instead of re-stacking inside a try/except.
    samples = list(train_loader)
    if not samples:
        return torch.tensor([])
    return torch.hstack(samples)

def get_ncomps(S: torch.Tensor, threshold: float = 0.95):
    """Count the leading components needed to exceed a share of the total energy.

    Each value's share is its square divided by the sum of all squares. The
    shares are accumulated in the order given until the running total is
    strictly greater than ``threshold``.

    Args:
        S: 1-D tensor of singular values.
        threshold: Cumulative share to exceed; defaults to 0.95.

    Returns:
        The number of components as a Python ``int``. If the threshold is
        never exceeded (including for an empty ``S``), ``len(S)`` is returned.
    """
    energy = S**2
    shares = energy/torch.sum(energy)
    running_total = 0
    for n_comp, share in enumerate(shares, start=1):
        running_total += share
        if running_total > threshold:
            return n_comp
    return len(shares)

def randomizer_matrix(m: int, n: int):
    """Return a random ``m x n`` matrix taken from the ``U`` factor of a low-rank PCA.

    A uniform random ``m x m`` matrix is decomposed with
    ``torch.pca_lowrank`` and the first ``n`` columns of ``U`` are returned.

    Args:
        m: Number of rows; must be strictly greater than ``n``.
        n: Number of columns.

    Raises:
        AssertionError: If ``m <= n``.
    """
    assert m > n, f"m ({m}) must be greater than n ({n})"
    temp = torch.rand(m, m)
    # pca_lowrank defaults to q = min(6, m, n) components, which would cap the
    # result at 6 columns. Request at least n (and never fewer than the old
    # default, so results for n <= 6 are produced exactly as before).
    q = max(n, min(6, m))
    U, _, _ = torch.pca_lowrank(temp, q=q)
    return U[:, :n]

def input_svd_matrices(init_batch: torch.Tensor):
    """Decompose the row-centred batch and keep the dominant components.

    Each row of ``init_batch`` is centred by subtracting its mean over
    ``dim=1``; the centred matrix is decomposed with ``torch.pca_lowrank`` and
    the number of components to keep is chosen by ``get_ncomps``.

    Args:
        init_batch: 2-D tensor.
            NOTE(review): presumably features x samples, as produced by
            ``get_init_batch`` — confirm.

    Returns:
        Tuple ``(U[:, :k], S[:k], k)`` with ``k`` the selected component count.
    """
    init_batch_centered = init_batch - torch.mean(init_batch, dim=1, keepdim=True)
    # Ask for the full spectrum: the default q = min(6, m, n) would truncate
    # it to 6 values before the cumulative-energy selection even starts.
    q = min(init_batch_centered.shape[-2:])
    # center=False: the input is already centred above; the default would
    # subtract the column means as well, i.e. centre a second time along the
    # other axis.
    U, S, _ = torch.pca_lowrank(init_batch_centered, q=q, center=False)
    n_comps = get_ncomps(S)
    return U[:, :n_comps], S[:n_comps], n_comps

def initialize_network_connections(layer_dims: list, data_identifier: str):
    """Build a column-normalised weight matrix for every layer size.

    For each entry ``dim`` of ``layer_dims`` a random ``dim x k`` matrix is
    combined as ``left @ diag(sigma_n) @ right.T``, where ``k`` and the first
    ``right`` come from ``input_svd_matrices`` and each layer's ``left``
    becomes the next layer's ``right``.

    Args:
        layer_dims: Layer sizes; every size must exceed the component count
            ``k`` (enforced by ``randomizer_matrix``).
        data_identifier: Identifier or path understood by ``TrainData``.

    Returns:
        None.
        NOTE(review): the computed matrices are currently discarded — the
        function looks unfinished; decide how the weights should be returned.
    """
    train_inputs = TrainData(data_identifier)
    init_batch = get_init_batch(train_inputs)
    right_matrix, sigma, n_comps = input_svd_matrices(init_batch)
    # NOTE(review): a zero singular value makes this inf, and an empty
    # layer_dims raises ZeroDivisionError.
    sigma_n = sigma**(-1/len(layer_dims))
    for dim in layer_dims:
        # `dim` already is the layer size; indexing layer_dims with it
        # (as before) picked an unrelated entry or raised IndexError.
        left_matrix = randomizer_matrix(dim, n_comps)
        w = multi_dot([left_matrix, torch.diag(sigma_n), right_matrix.T])
        w_n = normalize(w, p = 2.0, dim = 0)
        right_matrix = left_matrix
    return None


torch.Size([3, 3])
tensor([13.9901,  2.8770,  0.0000])


In [46]:
# Share of the total squared singular values carried by each component.
# Relies on `S` left in the kernel by the torch.pca_lowrank demo cell.
S_n = S**2/torch.sum(S**2)
# Cumulative shares for 1..len(S_n) components.
# NOTE(review): this list is neither stored nor the cell's last expression,
# so it is computed and then dropped.
[torch.sum(S_n[:k]) for k in range(1, len(S_n)+1)]
# Element-wise reciprocal; a zero singular value shows up as inf in the output.
print(S**-1)

tensor([0.0715, 0.3476,    inf])


In [55]:
# Print U and its transpose (in the recorded output U is symmetric, so both
# prints look the same), then rebuild the matrix from U, diag(S) and V^T.
# Relies on `U`, `S`, `V` left in the kernel by the torch.pca_lowrank demo cell.
print(U)
print(torch.transpose(U, 1, 0))
torch.linalg.multi_dot([U, torch.diag(S), V.T])

tensor([[-0.3404, -0.9403,  0.0000],
        [-0.9403,  0.3404,  0.0000],
        [ 0.0000,  0.0000,  1.0000]])
tensor([[-0.3404, -0.9403,  0.0000],
        [-0.9403,  0.3404,  0.0000],
        [ 0.0000,  0.0000,  1.0000]])


tensor([[4.0000, 3.0000, 2.0000, 1.0000],
        [5.0000, 6.0000, 7.0000, 8.0000],
        [0.0000, 0.0000, 0.0000, 0.0000]])

In [4]:
from torch import nn
class HashLayer(nn.Module):
    """Linear projection followed by (not yet implemented) hashing steps.

    The linear layer maps ``input_size`` features to ``hash_length * n_hash``
    outputs. ``hash_function`` and ``map_hash_table`` are placeholders that
    return ``None``, so ``forward`` currently returns ``None`` as well.
    """

    def __init__(self, input_size: int, hash_length: int, n_hash: int):
        super().__init__()
        projection_size = hash_length*n_hash
        self.hl = nn.Linear(input_size, projection_size)
        # TODO: possibly initialise custom weights here

    def forward(self, x):
        """Project ``x``, hash the projection and return the hashed output."""
        hashed = self.hash_function(self.hl(x))
        # The table lookup result is not used yet; the call is kept so the
        # placeholder is exercised exactly as before.
        self.map_hash_table(hashed)
        return hashed

    def hash_function(self, out):
        """Placeholder for the hashing step; returns ``None``."""
        return None

    def map_hash_table(self, out):
        """Placeholder for the hash-table mapping; returns ``None``."""
        # TODO: work out how to integrate the table as a network layer
        return None

# TODO: implement a way to get the input indices, or make sure that inputs are generated in order at the maximum batch size

[1, 2]