In [None]:
%matplotlib inline

In [None]:
import matplotlib.pyplot as plt
import seaborn as sn
sn.set_context("poster")

import torch
from torch import nn as nn
# Choose tensor constructors once: the CUDA variants when a GPU is available,
# otherwise the CPU ones. `ttype` is used to build the float inputs and is
# passed into the layer parameter dicts; `ctype` is the matching long-tensor
# constructor (not referenced by the other cells visible here).
ttype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
ctype = torch.cuda.LongTensor if torch.cuda.is_available() else torch.LongTensor
print(ttype)
import torch.nn.functional as F
from matplotlib import gridspec
from sithcon import SITHCon_Layer, _SITHCon_Core, iSITH

from tqdm.notebook import tqdm

import itertools
from csv import DictWriter
import os 
from os.path import join
import glob

import numpy as np
import pandas as pd
import pickle
from math import factorial
import random

In [None]:
# Morse code for the 43 symbols used as classes in this notebook:
# '.' is a dot and '-' is a dash.
# The comma entry is keyed by ',' alone (it was previously keyed by ', ' with a
# trailing space, which made the comma symbol impossible to look up).
MORSE_CODE_DICT = { 'A':'.-', 'B':'-...',
                    'C':'-.-.', 'D':'-..', 'E':'.',
                    'F':'..-.', 'G':'--.', 'H':'....',
                    'I':'..', 'J':'.---', 'K':'-.-',
                    'L':'.-..', 'M':'--', 'N':'-.',
                    'O':'---', 'P':'.--.', 'Q':'--.-',
                    'R':'.-.', 'S':'...', 'T':'-',
                    'U':'..-', 'V':'...-', 'W':'.--',
                    'X':'-..-', 'Y':'-.--', 'Z':'--..',
                    '1':'.----', '2':'..---', '3':'...--',
                    '4':'....-', '5':'.....', '6':'-....',
                    '7':'--...', '8':'---..', '9':'----.',
                    '0':'-----', ',':'--..--', '.':'.-.-.-',
                    '?':'..--..', '/':'-..-.', '-':'-....-',
                    '(':'-.--.', ')':'-.--.-'}

In [None]:
# Sanity check: show one symbol's dot/dash code next to its binary expansion,
# then the number of symbols in the table.
print(MORSE_CODE_DICT['?'], MORSE_CODE_DICT['?'].replace('.', '10').replace('-', '1110'))
print(len(MORSE_CODE_DICT))

# Expand every code into a 0/1 integer array: a dot becomes 1,0 and a dash
# becomes 1,1,1,0; two extra zeros are appended after the last mark.
morse_code_numpy = {
    symbol: np.array(
        [int(bit) for bit in code.replace('.', '10').replace('-', '1110')] + [0, 0]
    )
    for symbol, code in MORSE_CODE_DICT.items()
}

# Print each encoded sequence followed by the symbol it belongs to.
for k in morse_code_numpy:
    print(morse_code_numpy[k], k)
print(len(morse_code_numpy))

# Use every symbol as a class (a 10-symbol subset is left commented out below).
subset = list(morse_code_numpy)
#subset = ['3', '7', 'Y', 'Q', 'J',
#          'M', 'R', 'U', 'H', 'D']

In [None]:
# Class index <-> symbol lookups: a symbol's class index is its position in `subset`.
id2key = subset
key2id = {symbol: position for position, symbol in enumerate(subset)}

# One input tensor per symbol (built with `ttype`), and the class labels
# 0 .. len(X)-1 in the same order.
X = [ttype(morse_code_numpy[symbol]) for symbol in subset]
Y = torch.LongTensor(np.arange(len(X)))
print(Y.max())
print(X, Y)

In [None]:
class SITHCon_Classifier(nn.Module):
    """Stack of SITHCon layers, each followed by a Linear + ReLU, with a linear read-out.

    Args:
        out_classes: size of the final linear layer's output.
        layer_params: list of per-layer parameter dicts; each dict is handed to
            ``SITHCon_Layer`` and must contain a ``'channels'`` entry.
        act_func: forwarded unchanged to every ``SITHCon_Layer``.
        batch_norm: accepted but not used by this class.
        dropout: accepted but not used by this class.
    """

    def __init__(self, out_classes, layer_params,
                 act_func=nn.ReLU, batch_norm=False,
                 dropout=.2):
        super().__init__()
        # One channels -> channels Linear per layer, applied after that layer.
        per_layer_mixers = [nn.Linear(spec['channels'], spec['channels'])
                            for spec in layer_params]
        self.transform_linears = nn.ModuleList(per_layer_mixers)
        self.sithcon_layers = nn.ModuleList(
            [SITHCon_Layer(spec, act_func) for spec in layer_params]
        )
        # The read-out consumes the channel count of the last layer.
        self.to_out = nn.Linear(layer_params[-1]['channels'], out_classes)

    def forward(self, inp):
        """Run ``inp`` through every layer and return the read-out.

        After each SITHCon layer, index 0 of axis 1 is selected, the remaining
        two non-batch axes are swapped so the Linear acts on what was axis 2 of
        the layer output, ReLU is applied, and the 4-D layout is restored for
        the next layer. The read-out is applied to the same swapped view of the
        last layer's result.
        # assumes layer outputs are (batch, 1, channels, time) - TODO confirm
        """
        hidden = inp
        for sith_layer, mixer in zip(self.sithcon_layers, self.transform_linears):
            hidden = sith_layer(hidden)
            mixed = F.relu(mixer(hidden[:, 0, :, :].transpose(1, 2)))
            # Back to 4-D with the two trailing axes in their original order.
            hidden = mixed.unsqueeze(1).transpose(2, 3)

        features = hidden.transpose(2, 3)[:, 0, :, :]
        return self.to_out(features)

In [None]:
# Identity ordering over 43 sample indices; the printed maximum is 42.
permute = np.arange(43)
print(permute.max())

# Three Layers

In [None]:
def gen_model(p):
    """Build a three-layer SITHCon_Classifier from a flat hyper-parameter list.

    Args:
        p: indexable with at least six entries, read as
            p[0] -> tau_max, p[1] -> ntau, p[2] -> k,
            p[3] -> kernel_width, p[4] -> dilation, p[5] -> buff_max.

    Returns:
        A SITHCon_Classifier with ``len(X)`` output classes. The model is moved
        to the GPU when CUDA is available and left on the CPU otherwise, which
        matches how ``ttype`` is selected.
    """
    channels = 25

    def layer_spec(in_features):
        # All layers share the same settings; only the input width differs.
        return dict(in_features=in_features,
                    tau_min=.1, tau_max=p[0], buff_max=p[5],
                    dt=1, ntau=p[1], k=p[2], g=0.0, ttype=ttype,
                    channels=channels, kernel_width=p[3], dilation=p[4],
                    dropout=None, batch_norm=None)

    # The first layer reads a single input feature; each later layer reads the
    # previous layer's channels.
    layer_params = [layer_spec(1), layer_spec(channels), layer_spec(channels)]
    model = SITHCon_Classifier(len(X), layer_params, act_func=None)
    if torch.cuda.is_available():
        model = model.cuda()
    return model


In [None]:
# Hyper-parameter sets, in the order gen_model reads them:
# [tau_max, ntau, k, kernel_width, dilation, buff_max]
params = [[4000, 400, 35, 23, 2, 6500],]
model = gen_model(params[0])

# Count every parameter element in the model.
tot_weights = sum(weight.numel() for weight in model.parameters())
print("Total Weights:", tot_weights)

# Read `c` from the first layer's SITH module and compute m*(c+1)**(ntau-1)
# for a 320-tau configuration; the result is used as tau_max (and three times
# it as buff_max) of a second parameter set.
# NOTE(review): presumably this is the largest tau of a geometric grid that
# starts at m=0.1 with ratio (c+1) - confirm against the sithcon package.
c = model.sithcon_layers[0].sithcon.sith.c
print(c)
ntau = 320
m = .1
maxt = m * (c + 1) ** (ntau - 1)
print(maxt)
params.append([maxt, ntau, 35, 23, 2, maxt * 3])

In [None]:
# Repetition factors used during training: every element of a sample is
# repeated this many times before being fed to the model. The same values are
# drawn as vertical marker lines in the final plot.
training_lens = [1,4,10,25,100]

In [None]:
# ---- Training configuration -------------------------------------------------
runs = 5
epochs = 150000
Trainscale = 10  # initial value only; reassigned for every sample in the loop below
# Fall back to the CPU when no GPU is present, as is already done for `ttype`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

batch_size = 10
n_samples = len(X)  # derived from the data instead of a hard-coded 43
batches = int(np.ceil(n_samples / batch_size))

# Checkpoints are written under perf/; make sure the directory exists so that
# torch.save does not fail after a long training run.
os.makedirs('perf', exist_ok=True)

for r in range(runs):
    model = gen_model(params[0]).to(device)

    loss_func = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters())
    perf_file = join('perf','sithconmulti_morsedecoder_run_TEST{}.csv'.format(r))

    progress_bar = tqdm(range(int(epochs)), bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}')
    # Number of epochs (not necessarily consecutive) that met the stop criterion.
    times_100 = 0

    for epoch_idx in progress_bar:
        perfs = []
        losses = []
        model.train()

        # Draw ONE permutation per epoch so the batches partition the data:
        # every sample is visited exactly once per epoch, and the epoch-level
        # accuracy used for early stopping covers every class.
        permute = np.arange(0, n_samples)
        np.random.shuffle(permute)

        for batch_idx in range(batches):
            batch_start = batch_idx * batch_size
            # The last batch may be shorter than batch_size.
            n_in_batch = int(min(n_samples - batch_start, batch_size))

            optimizer.zero_grad()
            loss = 0
            for i in range(n_in_batch):
                # Cycle through the training repetition factors by position in the batch.
                Trainscale = training_lens[i % len(training_lens)]
                sample_idx = permute[batch_start + i]

                # Repeat every element Trainscale times, then flatten to (1, 1, 1, length).
                iv = X[sample_idx]
                iv = iv.unsqueeze(0).unsqueeze(0).unsqueeze(0).to(device)
                iv = iv.unsqueeze(-1)
                iv = iv.repeat(1,1,1,1,Trainscale)
                iv = iv.reshape(1,1,1,-1)
                tv = Y[sample_idx].to(device)

                out = model(iv)
                # Only the last index along axis 1 of the output is scored.
                final_step = out[:, -1, :]
                # tv is already on `device`; view(1) gives the (1,) target
                # CrossEntropyLoss expects without a CUDA-only constructor.
                loss += loss_func(final_step, tv.view(1))
                perfs.append((torch.argmax(final_step, dim=-1) ==
                              tv).sum().item())

            # Average the accumulated per-sample losses over the batch.
            loss = loss / n_in_batch
            loss.backward()
            optimizer.step()

            losses.append(loss.detach().cpu().numpy())

            s = "{}:{:2} Loss: {:.4f}, Perf: {:.4f}"
            format_list = [epoch_idx, batch_idx, np.mean(losses),
                           np.sum(perfs)/((len(perfs)))]
            s = s.format(*format_list)
            progress_bar.set_description(s)

        # Stop after the third epoch with perfect accuracy and a mean loss below 0.11.
        if (np.sum(perfs)/((len(perfs))) == 1.0) and (np.mean(losses) < .11):
            times_100 += 1
            if times_100 >= 3:
                break

    # Save the weights next to perf_file, with the ".csv" suffix replaced by ".pt".
    torch.save(model.state_dict(), perf_file[:-4]+".pt")


In [None]:
# Save the current model's weights to perf_file with its last four characters
# replaced by ".pt". This repeats the save done at the end of each training
# run, for whichever `model` / `perf_file` are current in the kernel.
torch.save(model.state_dict(), perf_file[:-4]+".pt")

In [None]:
# Evaluate each saved run at a range of repetition factors ("scales").
# The weights are loaded into the `model` object already present in the kernel.
# NOTE(review): the training cell in this notebook writes
# 'sithconmulti_morsedecoder_run_TEST{r}.pt', whereas this cell loads
# 'sithconmulti_morsedecoder_run_{r}.pt' - confirm the intended checkpoints exist.
runs = 5
device = 'cuda'
for r in range(runs):
    checkpoint_path = join('perf','sithconmulti_morsedecoder_run_{}.pt'.format(r))
    model.load_state_dict(torch.load(checkpoint_path))
    model.eval()
    with torch.no_grad():
        evald = []
        evaldDict = {'test_perf':[],
                     'rate':[]}
        for nr in [1,2,3,4,5,6,7,8,9,10,12,13,15,18,20,25,50,75,90,100,125,250,500,1000]:
            perfs = []
            for batch_idx, iv in enumerate(X):
                # Repeat every element nr times, then flatten to (1, 1, 1, length).
                iv = (iv.unsqueeze(0).unsqueeze(0).unsqueeze(0).to(device)
                        .unsqueeze(-1)
                        .repeat(1,1,1,1,nr)
                        .reshape(1,1,1,-1))
                tv = Y[batch_idx].to(device)
                out = model(iv)

                # Score only the last index along axis 1 of the output.
                predicted = torch.argmax(out[:, -1, :], dim=-1)
                perfs.append((predicted == tv).sum().item())

            # Fraction of symbols classified correctly at this scale.
            accuracy = sum(perfs)/len(perfs)
            evaldDict['test_perf'].append(accuracy)
            evaldDict['rate'].append(nr)
            print(nr, accuracy)
            evald.append({'scale':nr,
                          'perf':accuracy})

        # One row per scale; written to disk once per run.
        scale_perfs = pd.DataFrame(evald)
        scale_perfs.to_pickle(join("perf", "sithconmulti_morse_ttest_{}.dill".format(r)))

In [None]:
# Accuracy versus repetition factor on a log x-axis. `scale_perfs` holds the
# table of the last run processed by the evaluation cell.
ax = sn.lineplot(data=scale_perfs, x='scale', y='perf')
ax.set_xscale('log')
# Dashed red lines mark the repetition factors used during training.
for x in training_lens:
    ax.axvline(x, color='red', linestyle='--')