<a href="https://colab.research.google.com/github/steve-alex999/Cloud-Workload-Prediction/blob/main/btp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import numpy as np
import time
import pandas as pd
import math
from matplotlib import pyplot
# from pytorchmaster import *

torch.manual_seed(0)
np.random.seed(0)

# S is the source sequence length
# T is the target sequence length
# N is the batch size
# E is the feature number

#src = torch.rand((10, 32, 512)) # (S,N,E)
#tgt = torch.rand((20, 32, 512)) # (T,N,E)
#out = transformer_model(src, tgt)

input_window = 100 # number of input steps
output_window = 1 # number of prediction steps, in this model its fixed to one
batch_size = 30
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
PWS = 30
class PositionalEncoding(nn.Module):u

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.j=0
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        #pe.requires_grad = False
        self.register_buffer('pe', pe)

    def forward(self, x):
        # if self.j==0:
        #     print(self.pe)
        #     self.j+=1
        return x + self.pe[:x.size(0), :]


class TransAm(nn.Module):
    def __init__(self,feature_size=250,num_layers=4,dropout=0.1):
        super(TransAm, self).__init__()
        self.model_type = 'Transformer'
        self.j=0
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(feature_size)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=feature_size, nhead=10, dropout=dropout)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.decoder = nn.Linear(feature_size,1)
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self,src):
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            device = src.device
            mask = self._generate_square_subsequent_mask(len(src)).to(device)
            self.src_mask = mask
        if self.j==0:
            # print(src.shape[0],",", src.shape[1],",", src.shape[2])
            # print(src)
            self.j+=1
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src,self.src_mask)#, self.src_mask)
        output = self.decoder(output)
        return output

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask



# if window is 100 and prediction step is 1
# in -> [0..99]
# target -> [1..100]
def create_inout_sequences(input_data, tw):
    inout_seq = []
    L = len(input_data)
    # print(L)
    for i in range(L-tw):
        train_seq = input_data[i:i+tw]
        train_label = input_data[i+output_window:i+tw+output_window]
        inout_seq.append((train_seq ,train_label))
    return torch.FloatTensor(inout_seq)

def get_data():
    # construct a littel toy dataset
    time        = np.arange(0, 136.9, 0.1)
    # amplitude   = np.sin(time) + np.sin(time*0.05) +np.sin(time*0.12) *np.random.normal(-0.2, 0.2, len(time))

    # print(time)
    # print(time.shape[0])

    from sklearn.preprocessing import MinMaxScaler

    #loading weather data from a file
    series = pd.read_csv('minute.csv')
    f = series['Frequency'][:].tolist()
    amplitude   = np.sin(time) + np.sin(time*0.05) +np.sin(time*0.12) *f

    # looks like normalizing input values curtial for the model
    scaler = MinMaxScaler(feature_range=(-1, 1))

    # amplitude = scaler.fit_transform(series.to_numpy().reshape(-1, 1)).reshape(-1)
    amplitude = scaler.fit_transform(amplitude.reshape(-1, 1)).reshape(-1)

    # for i in range(4000):
    #     print(amplitude[i])

    sampels = 822
    train_data = amplitude[:sampels]
    test_data = amplitude[sampels:]

    # convert our train data into a pytorch train tensor
    #train_tensor = torch.FloatTensor(train_data).view(-1)
    # todo: add comment..
    train_sequence = create_inout_sequences(train_data,input_window)
    train_sequence = train_sequence[:-output_window] #todo: fix hack? -> din't think this through, looks like the last n sequences are to short, so I just remove them. Hackety Hack..

    #test_data = torch.FloatTensor(test_data).view(-1)
    test_data = create_inout_sequences(test_data,PWS)
    test_data = test_data[:-output_window] #todo: fix hack?

    return train_sequence.to(device),test_data.to(device)

def get_batch(source, i,batch_size):
    seq_len = min(batch_size, len(source) - 1 - i)
    data = source[i:i+seq_len]
    input = torch.stack(torch.stack([item[0] for item in data]).chunk(input_window,1)) # 1 is feature size
    target = torch.stack(torch.stack([item[1] for item in data]).chunk(input_window,1))
    return input, target


def train(train_data):
    model.train() # Turn on the train mode \o/
    total_loss = 0.
    start_time = time.time()

    for batch, i in enumerate(range(0, len(train_data) - 1, batch_size)):
        data, targets = get_batch(train_data, i,batch_size)
        # if i==0:
            # print(data.shape[0],",", data.shape[1],",", data.shape[2])
            # print(data)
        optimizer.zero_grad()
        output = model.forward(data)
        # if i==0:
            # print(output.shape[0],",",output.shape[1],",",output.shape[2])
            # print(output)
        loss = criterion(output, targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.7)
        optimizer.step()

        total_loss += loss.item()
        log_interval = int(len(train_data) / batch_size / 5)
        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | '
                  'lr {:02.6f} | {:5.2f} ms | '
                  'loss {:5.5f} | ppl {:8.2f}'.format(
                    epoch, batch, len(train_data) // batch_size, scheduler.get_lr()[0],
                    elapsed * 1000 / log_interval,
                    cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()

def plot_and_loss(eval_model, data_source,epoch):
    eval_model.eval()
    total_loss = 0.
    test_result = torch.Tensor(0)
    truth = torch.Tensor(0)
    with torch.no_grad():
        for i in range(0, len(data_source) - 1):
            data, target = get_batch(data_source, i,1)
            output = eval_model(data)
            total_loss += criterion(output, target).item()
            test_result = torch.cat((test_result, output[-1].view(-1).cpu()), 0)
            truth = torch.cat((truth, target[-1].view(-1).cpu()), 0)

    #test_result = test_result.cpu().numpy() -> no need to detach stuff..
    len(test_result)

    # pyplot.plot(test_result,color="red")
    # pyplot.plot(truth[:1369],color="blue")
    # # pyplot.plot(test_result-truth,color="green")
    # pyplot.grid(True, which='both')
    # pyplot.axhline(y=0, color='k')
    # pyplot.savefig('graph/transformer-epoch%d.png'%epoch)
    # pyplot.close()

    return total_loss / i


# predict the next n steps based on the input data
def predict_future(eval_model, data_source,steps, epoch):
    eval_model.eval()
    total_loss = 0.
    test_result = torch.Tensor(0)
    truth = torch.Tensor(0)
    # data, targets = get_batch(data_source, 0,1)
    # print(len(data))
#     with torch.no_grad():

#         output = eval_model(data[-PWS:])

#         data = torch.cat((data, output[-1:]))
#         total_loss += len(data[0])* criterion(output, targets).cpu().item()
#     data = data.cpu().view(-1)

#     # I used this plot to visualize if the model pics up any long therm struccture within the data.
#     pyplot.plot(data,color="red")
#     pyplot.plot(data[:input_window],color="blue")
#     pyplot.grid(True, which='both')
#     pyplot.axhline(y=0, color='k')
#     pyplot.savefig('graph/transformer-future%d.png'%epoch)
#     pyplot.close()
#     return total_loss
    test_loss = evaluate(eval_model,data_source)
    return test_loss


def evaluate(eval_model, data_source):
    eval_model.eval() # Turn on the evaluation mode
    total_loss = 0.
    eval_batch_size = 1000
    with torch.no_grad():
        for i in range(0, len(data_source) - 1, eval_batch_size):
            data, targets = get_batch(data_source, i,eval_batch_size)
            output = eval_model(data)
            total_loss += len(data[0])* criterion(output, targets).cpu().item()

    return total_loss / len(data_source)

train_data, val_data = get_data()
model = TransAm().to(device)

criterion = nn.MSELoss()
lr = 0.005
#optimizer = torch.optim.SGD(model.parameters(), lr=lr)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

best_val_loss = float("inf")
epochs = 100 # The number of epochs
best_model = None

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train(train_data)

    if(epoch % 10 == 0):
        val_loss = plot_and_loss(model, val_data,epoch)
        test_loss = predict_future(model, val_data, 200, epoch)
    else:
        val_loss = evaluate(model, val_data)

    print('-' * 89)
    if(epoch % 10 == 0):
        print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.5f} | valid ppl {:8.2f}| test loss {:5.5f} |'.format(epoch, (time.time() - epoch_start_time),
                                     val_loss, math.exp(val_loss),test_loss))
    else:
        print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.5f} | valid ppl {:8.2f} |'.format(epoch, (time.time() - epoch_start_time),
                                     val_loss, math.exp(val_loss)))
    print('-' * 89)

    #if val_loss < best_val_loss:
    #    best_val_loss = val_loss
    #    best_model = model

    scheduler.step()

#src = torch.rand(input_window, batch_size, 1) # (source sequence length,batch size,feature number)
#out = model(src)
#
#print(out)
#print(out.shape)



| epoch   1 |     4/   24 batches | lr 0.005000 | 169.91 ms | loss 65.61912 | ppl 31478990189337562266468352000.00
| epoch   1 |     8/   24 batches | lr 0.005000 | 95.92 ms | loss 10.03000 | ppl 22697.23
| epoch   1 |    12/   24 batches | lr 0.005000 | 96.51 ms | loss 1.34651 | ppl     3.84
| epoch   1 |    16/   24 batches | lr 0.005000 | 96.15 ms | loss 1.11620 | ppl     3.05
| epoch   1 |    20/   24 batches | lr 0.005000 | 96.15 ms | loss 0.31507 | ppl     1.37
-----------------------------------------------------------------------------------------
| end of epoch   1 | time:  2.53s | valid loss 0.50149 | valid ppl     1.65 |
-----------------------------------------------------------------------------------------
| epoch   2 |     4/   24 batches | lr 0.004513 | 120.29 ms | loss 0.13061 | ppl     1.14
| epoch   2 |     8/   24 batches | lr 0.004513 | 96.34 ms | loss 0.11646 | ppl     1.12
| epoch   2 |    12/   24 batches | lr 0.004513 | 96.79 ms | loss 0.19167 | ppl     1.21
| 



| epoch  11 |     4/   24 batches | lr 0.002844 | 121.18 ms | loss 0.10167 | ppl     1.11
| epoch  11 |     8/   24 batches | lr 0.002844 | 96.34 ms | loss 0.22015 | ppl     1.25
| epoch  11 |    12/   24 batches | lr 0.002844 | 97.28 ms | loss 0.15010 | ppl     1.16
| epoch  11 |    16/   24 batches | lr 0.002844 | 97.32 ms | loss 0.26713 | ppl     1.31
| epoch  11 |    20/   24 batches | lr 0.002844 | 97.72 ms | loss 0.59983 | ppl     1.82
-----------------------------------------------------------------------------------------
| end of epoch  11 | time:  2.35s | valid loss 0.06380 | valid ppl     1.07 |
-----------------------------------------------------------------------------------------
| epoch  12 |     4/   24 batches | lr 0.002702 | 119.92 ms | loss 0.14242 | ppl     1.15
| epoch  12 |     8/   24 batches | lr 0.002702 | 96.32 ms | loss 0.17688 | ppl     1.19
| epoch  12 |    12/   24 batches | lr 0.002702 | 97.03 ms | loss 0.23443 | ppl     1.26
| epoch  12 |    16/   24 ba



| epoch  21 |     4/   24 batches | lr 0.001703 | 117.82 ms | loss 0.17413 | ppl     1.19
| epoch  21 |     8/   24 batches | lr 0.001703 | 98.39 ms | loss 0.07370 | ppl     1.08
| epoch  21 |    12/   24 batches | lr 0.001703 | 96.45 ms | loss 0.20732 | ppl     1.23
| epoch  21 |    16/   24 batches | lr 0.001703 | 96.79 ms | loss 0.08153 | ppl     1.08
| epoch  21 |    20/   24 batches | lr 0.001703 | 96.73 ms | loss 0.05683 | ppl     1.06
-----------------------------------------------------------------------------------------
| end of epoch  21 | time:  2.33s | valid loss 0.07316 | valid ppl     1.08 |
-----------------------------------------------------------------------------------------
| epoch  22 |     4/   24 batches | lr 0.001618 | 121.83 ms | loss 0.17774 | ppl     1.19
| epoch  22 |     8/   24 batches | lr 0.001618 | 98.26 ms | loss 0.07162 | ppl     1.07
| epoch  22 |    12/   24 batches | lr 0.001618 | 97.01 ms | loss 0.20152 | ppl     1.22
| epoch  22 |    16/   24 ba



| epoch  31 |     4/   24 batches | lr 0.001020 | 119.66 ms | loss 0.19035 | ppl     1.21
| epoch  31 |     8/   24 batches | lr 0.001020 | 97.08 ms | loss 0.06616 | ppl     1.07
| epoch  31 |    12/   24 batches | lr 0.001020 | 97.29 ms | loss 0.18493 | ppl     1.20
| epoch  31 |    16/   24 batches | lr 0.001020 | 96.76 ms | loss 0.07269 | ppl     1.08
| epoch  31 |    20/   24 batches | lr 0.001020 | 97.02 ms | loss 0.06326 | ppl     1.07
-----------------------------------------------------------------------------------------
| end of epoch  31 | time:  2.34s | valid loss 0.06947 | valid ppl     1.07 |
-----------------------------------------------------------------------------------------
| epoch  32 |     4/   24 batches | lr 0.000969 | 121.84 ms | loss 0.19083 | ppl     1.21
| epoch  32 |     8/   24 batches | lr 0.000969 | 98.15 ms | loss 0.06597 | ppl     1.07
| epoch  32 |    12/   24 batches | lr 0.000969 | 98.09 ms | loss 0.18430 | ppl     1.20
| epoch  32 |    16/   24 ba



| epoch  41 |     4/   24 batches | lr 0.000610 | 119.55 ms | loss 0.19302 | ppl     1.21
| epoch  41 |     8/   24 batches | lr 0.000610 | 97.78 ms | loss 0.06501 | ppl     1.07
| epoch  41 |    12/   24 batches | lr 0.000610 | 97.29 ms | loss 0.18107 | ppl     1.20
| epoch  41 |    16/   24 batches | lr 0.000610 | 97.36 ms | loss 0.07118 | ppl     1.07
| epoch  41 |    20/   24 batches | lr 0.000610 | 97.42 ms | loss 0.06438 | ppl     1.07
-----------------------------------------------------------------------------------------
| end of epoch  41 | time:  2.35s | valid loss 0.06895 | valid ppl     1.07 |
-----------------------------------------------------------------------------------------
| epoch  42 |     4/   24 batches | lr 0.000580 | 121.79 ms | loss 0.19314 | ppl     1.21
| epoch  42 |     8/   24 batches | lr 0.000580 | 97.26 ms | loss 0.06495 | ppl     1.07
| epoch  42 |    12/   24 batches | lr 0.000580 | 97.40 ms | loss 0.18088 | ppl     1.20
| epoch  42 |    16/   24 ba



| epoch  51 |     4/   24 batches | lr 0.000365 | 119.32 ms | loss 0.19379 | ppl     1.21
| epoch  51 |     8/   24 batches | lr 0.000365 | 97.85 ms | loss 0.06460 | ppl     1.07
| epoch  51 |    12/   24 batches | lr 0.000365 | 97.15 ms | loss 0.17968 | ppl     1.20
| epoch  51 |    16/   24 batches | lr 0.000365 | 97.24 ms | loss 0.07067 | ppl     1.07
| epoch  51 |    20/   24 batches | lr 0.000365 | 97.77 ms | loss 0.06469 | ppl     1.07
-----------------------------------------------------------------------------------------
| end of epoch  51 | time:  2.34s | valid loss 0.06880 | valid ppl     1.07 |
-----------------------------------------------------------------------------------------
| epoch  52 |     4/   24 batches | lr 0.000347 | 120.08 ms | loss 0.19383 | ppl     1.21
| epoch  52 |     8/   24 batches | lr 0.000347 | 96.54 ms | loss 0.06458 | ppl     1.07
| epoch  52 |    12/   24 batches | lr 0.000347 | 96.99 ms | loss 0.17959 | ppl     1.20
| epoch  52 |    16/   24 ba



| epoch  61 |     4/   24 batches | lr 0.000219 | 118.27 ms | loss 0.19410 | ppl     1.21
| epoch  61 |     8/   24 batches | lr 0.000219 | 97.11 ms | loss 0.06441 | ppl     1.07
| epoch  61 |    12/   24 batches | lr 0.000219 | 97.78 ms | loss 0.17902 | ppl     1.20
| epoch  61 |    16/   24 batches | lr 0.000219 | 98.32 ms | loss 0.07043 | ppl     1.07
| epoch  61 |    20/   24 batches | lr 0.000219 | 97.13 ms | loss 0.06480 | ppl     1.07
-----------------------------------------------------------------------------------------
| end of epoch  61 | time:  2.34s | valid loss 0.06875 | valid ppl     1.07 |
-----------------------------------------------------------------------------------------
| epoch  62 |     4/   24 batches | lr 0.000208 | 121.82 ms | loss 0.19411 | ppl     1.21
| epoch  62 |     8/   24 batches | lr 0.000208 | 97.63 ms | loss 0.06440 | ppl     1.07
| epoch  62 |    12/   24 batches | lr 0.000208 | 96.34 ms | loss 0.17898 | ppl     1.20
| epoch  62 |    16/   24 ba



| epoch  71 |     4/   24 batches | lr 0.000131 | 121.04 ms | loss 0.19424 | ppl     1.21
| epoch  71 |     8/   24 batches | lr 0.000131 | 97.87 ms | loss 0.06431 | ppl     1.07
| epoch  71 |    12/   24 batches | lr 0.000131 | 97.21 ms | loss 0.17867 | ppl     1.20
| epoch  71 |    16/   24 batches | lr 0.000131 | 96.75 ms | loss 0.07031 | ppl     1.07
| epoch  71 |    20/   24 batches | lr 0.000131 | 97.49 ms | loss 0.06486 | ppl     1.07
-----------------------------------------------------------------------------------------
| end of epoch  71 | time:  2.35s | valid loss 0.06873 | valid ppl     1.07 |
-----------------------------------------------------------------------------------------
| epoch  72 |     4/   24 batches | lr 0.000124 | 121.54 ms | loss 0.19424 | ppl     1.21
| epoch  72 |     8/   24 batches | lr 0.000124 | 96.76 ms | loss 0.06430 | ppl     1.07
| epoch  72 |    12/   24 batches | lr 0.000124 | 96.86 ms | loss 0.17865 | ppl     1.20
| epoch  72 |    16/   24 ba



| epoch  81 |     4/   24 batches | lr 0.000078 | 120.03 ms | loss 0.19430 | ppl     1.21
| epoch  81 |     8/   24 batches | lr 0.000078 | 96.89 ms | loss 0.06426 | ppl     1.07
| epoch  81 |    12/   24 batches | lr 0.000078 | 96.71 ms | loss 0.17848 | ppl     1.20
| epoch  81 |    16/   24 batches | lr 0.000078 | 96.95 ms | loss 0.07025 | ppl     1.07
| epoch  81 |    20/   24 batches | lr 0.000078 | 97.46 ms | loss 0.06488 | ppl     1.07
-----------------------------------------------------------------------------------------
| end of epoch  81 | time:  2.34s | valid loss 0.06872 | valid ppl     1.07 |
-----------------------------------------------------------------------------------------
| epoch  82 |     4/   24 batches | lr 0.000075 | 121.21 ms | loss 0.19431 | ppl     1.21
| epoch  82 |     8/   24 batches | lr 0.000075 | 97.73 ms | loss 0.06425 | ppl     1.07
| epoch  82 |    12/   24 batches | lr 0.000075 | 98.35 ms | loss 0.17846 | ppl     1.20
| epoch  82 |    16/   24 ba



| epoch  91 |     4/   24 batches | lr 0.000047 | 118.98 ms | loss 0.19435 | ppl     1.21
| epoch  91 |     8/   24 batches | lr 0.000047 | 97.13 ms | loss 0.06422 | ppl     1.07
| epoch  91 |    12/   24 batches | lr 0.000047 | 97.29 ms | loss 0.17837 | ppl     1.20
| epoch  91 |    16/   24 batches | lr 0.000047 | 97.21 ms | loss 0.07021 | ppl     1.07
| epoch  91 |    20/   24 batches | lr 0.000047 | 97.19 ms | loss 0.06489 | ppl     1.07
-----------------------------------------------------------------------------------------
| end of epoch  91 | time:  2.34s | valid loss 0.06871 | valid ppl     1.07 |
-----------------------------------------------------------------------------------------
| epoch  92 |     4/   24 batches | lr 0.000045 | 122.17 ms | loss 0.19435 | ppl     1.21
| epoch  92 |     8/   24 batches | lr 0.000045 | 96.74 ms | loss 0.06422 | ppl     1.07
| epoch  92 |    12/   24 batches | lr 0.000045 | 97.13 ms | loss 0.17836 | ppl     1.20
| epoch  92 |    16/   24 ba

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import torch
import torch.nn as nn
import numpy as np
import time
import pandas as pd
import math
from matplotlib import pyplot

torch.manual_seed(0)
np.random.seed(0)

# S is the source sequence length
# T is the target sequence length
# N is the batch size
# E is the feature number

#src = torch.rand((10, 32, 512)) # (S,N,E)
#tgt = torch.rand((20, 32, 512)) # (T,N,E)
#out = transformer_model(src, tgt)

input_window = 100 # number of input steps
output_window = 1 # number of prediction steps, in this model its fixed to one
batch_size = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class PositionalEncoding(nn.Module):

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.j=0
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        #pe.requires_grad = False
        self.register_buffer('pe', pe)

    def forward(self, x):
        # if self.j==0:
        #     print(self.pe)
        #     self.j+=1
        return x + self.pe[:x.size(0), :]


class TransAm(nn.Module):
    def __init__(self,feature_size=250,num_layers=1,dropout=0.1):
        super(TransAm, self).__init__()
        self.model_type = 'Transformer'
        self.j=0
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(feature_size)
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=feature_size, nhead=10, dropout=dropout)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.decoder = nn.Linear(feature_size,1)
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self,src):
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            device = src.device
            mask = self._generate_square_subsequent_mask(len(src)).to(device)
            self.src_mask = mask
        if self.j==0:
            # print(src.shape[0],",", src.shape[1],",", src.shape[2])
            # print(src)
            self.j+=1
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src,self.src_mask)#, self.src_mask)
        output = self.decoder(output)
        return output

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask



# if window is 100 and prediction step is 1
# in -> [0..99]
# target -> [1..100]
def create_inout_sequences(input_data, tw):
    inout_seq = []
    L = len(input_data)
    # print(L)
    for i in range(L-tw):
        train_seq = input_data[i:i+tw]
        train_label = input_data[i+output_window:i+tw+output_window]
        inout_seq.append((train_seq ,train_label))
    return torch.FloatTensor(inout_seq)

def get_data():
    # construct a littel toy dataset
    time        = np.arange(0, 400, 0.1)
    # amplitude   = np.sin(time) + np.sin(time*0.05) +np.sin(time*0.12) *np.random.normal(-0.2, 0.2, len(time))

    # print(time)
    # print(time.shape[0])

    from sklearn.preprocessing import MinMaxScaler

    #loading weather data from a file
    series = pd.read_csv('Normalize_4000.csv')
    f = series['Frequency'][:].tolist()
    amplitude   = np.sin(time) + np.sin(time*0.05) +np.sin(time*0.12) *f

    # looks like normalizing input values curtial for the model
    scaler = MinMaxScaler(feature_range=(-1, 1))

    # amplitude = scaler.fit_transform(series.to_numpy().reshape(-1, 1)).reshape(-1)
    amplitude = scaler.fit_transform(amplitude.reshape(-1, 1)).reshape(-1)

    # for i in range(4000):
    #     print(amplitude[i])

    sampels = 1000
    train_data = amplitude[:sampels]
    test_data = amplitude[sampels:]

    # convert our train data into a pytorch train tensor
    #train_tensor = torch.FloatTensor(train_data).view(-1)
    # todo: add comment..
    train_sequence = create_inout_sequences(train_data,input_window)
    train_sequence = train_sequence[:-output_window] #todo: fix hack? -> din't think this through, looks like the last n sequences are to short, so I just remove them. Hackety Hack..

    #test_data = torch.FloatTensor(test_data).view(-1)
    test_data = create_inout_sequences(test_data,input_window)
    test_data = test_data[:-output_window] #todo: fix hack?

    return train_sequence.to(device),test_data.to(device)

def get_batch(source, i,batch_size):
    seq_len = min(batch_size, len(source) - 1 - i)
    data = source[i:i+seq_len]
    input = torch.stack(torch.stack([item[0] for item in data]).chunk(input_window,1)) # 1 is feature size
    target = torch.stack(torch.stack([item[1] for item in data]).chunk(input_window,1))
    return input, target


def train(train_data):
    model.train() # Turn on the train mode \o/
    total_loss = 0.
    start_time = time.time()

    for batch, i in enumerate(range(0, len(train_data) - 1, batch_size)):
        data, targets = get_batch(train_data, i,batch_size)
        # if i==0:
            # print(data.shape[0],",", data.shape[1],",", data.shape[2])
            # print(data)
        optimizer.zero_grad()
        output = model.forward(data)
        # if i==0:
            # print(output.shape[0],",",output.shape[1],",",output.shape[2])
            # print(output)
        loss = criterion(output, targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.7)
        optimizer.step()

        total_loss += loss.item()
        log_interval = int(len(train_data) / batch_size / 5)
        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | '
                  'lr {:02.6f} | {:5.2f} ms | '
                  'loss {:5.5f} | ppl {:8.2f}'.format(
                    epoch, batch, len(train_data) // batch_size, scheduler.get_lr()[0],
                    elapsed * 1000 / log_interval,
                    cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()

def plot_and_loss(eval_model, data_source,epoch):
    eval_model.eval()
    total_loss = 0.
    test_result = torch.Tensor(0)
    truth = torch.Tensor(0)
    with torch.no_grad():
        for i in range(0, len(data_source) - 1):
            data, target = get_batch(data_source, i,1)
            output = eval_model(data)
            total_loss += criterion(output, target).item()
            test_result = torch.cat((test_result, output[-1].view(-1).cpu()), 0)
            truth = torch.cat((truth, target[-1].view(-1).cpu()), 0)
    print(test_result.size())
    #test_result = test_result.cpu().numpy() -> no need to detach stuff..
    len(test_result)

    pyplot.plot(test_result,color="red")
    pyplot.plot(truth[:2500],color="blue")
    # pyplot.plot(test_result-truth,color="green")
    pyplot.grid(True, which='both')
    pyplot.axhline(y=0, color='k')
    pyplot.savefig('graph/transformer-epoch%d.png'%epoch)
    pyplot.close()

    return total_loss / i


# predict the next n steps based on the input data
def predict_future(eval_model, data_source,steps, epoch):
    eval_model.eval()
    total_loss = 0.
    test_result = torch.Tensor(0)
    truth = torch.Tensor(0)
    data, target = get_batch(data_source, 0,1)
    # print(len(data))
    with torch.no_grad():
        for i in range(0, 10):
            output = eval_model(data[-input_window:])
            data = torch.cat((data, output[-1:]))

    data = data.cpu().view(-1)

    # I used this plot to visualize if the model pics up any long therm struccture within the data.
    pyplot.plot(data,color="red")
    pyplot.plot(data[:input_window],color="blue")
    pyplot.grid(True, which='both')
    pyplot.axhline(y=0, color='k')
    pyplot.savefig('graph/transformer-future%d.png'%epoch)
    pyplot.close()


def evaluate(eval_model, data_source):
    eval_model.eval() # Turn on the evaluation mode
    total_loss = 0.
    eval_batch_size = 1000
    with torch.no_grad():
        for i in range(0, len(data_source) - 1, eval_batch_size):
            data, targets = get_batch(data_source, i,eval_batch_size)
            output = eval_model(data)
            total_loss += len(data[0])* criterion(output, targets).cpu().item()
    return total_loss / len(data_source)

train_data, val_data = get_data()
model = TransAm().to(device)

criterion = nn.MSELoss()
lr = 0.005
#optimizer = torch.optim.SGD(model.parameters(), lr=lr)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

best_val_loss = float("inf")
epochs = 100 # The number of epochs
best_model = None

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train(train_data)

    if(epoch % 10 == 0):
        val_loss = plot_and_loss(model, val_data,epoch)
        predict_future(model, val_data, 200, epoch)
    else:
        val_loss = evaluate(model, val_data)

    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.5f} | valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                     val_loss, math.exp(val_loss)))
    print('-' * 89)

    #if val_loss < best_val_loss:
    #    best_val_loss = val_loss
    #    best_model = model

    scheduler.step()

#src = torch.rand(input_window, batch_size, 1) # (source sequence length,batch size,feature number)
#out = model(src)
#
#print(out)
#print(out.shape)



| epoch   1 |    17/   89 batches | lr 0.005000 | 15.45 ms | loss 13.28888 | ppl 590590.38
| epoch   1 |    34/   89 batches | lr 0.005000 | 13.12 ms | loss 0.15743 | ppl     1.17
| epoch   1 |    51/   89 batches | lr 0.005000 | 12.09 ms | loss 0.16324 | ppl     1.18
| epoch   1 |    68/   89 batches | lr 0.005000 | 11.31 ms | loss 0.10322 | ppl     1.11
| epoch   1 |    85/   89 batches | lr 0.005000 | 10.92 ms | loss 0.09764 | ppl     1.10
-----------------------------------------------------------------------------------------
| end of epoch   1 | time:  1.86s | valid loss 0.23705 | valid ppl     1.27
-----------------------------------------------------------------------------------------
| epoch   2 |    17/   89 batches | lr 0.004513 | 12.88 ms | loss 0.05841 | ppl     1.06
| epoch   2 |    34/   89 batches | lr 0.004513 | 11.24 ms | loss 0.03169 | ppl     1.03
| epoch   2 |    51/   89 batches | lr 0.004513 | 10.82 ms | loss 0.06227 | ppl     1.06
| epoch   2 |    68/   89 batc



| epoch  11 |    17/   89 batches | lr 0.002844 | 13.70 ms | loss 0.00809 | ppl     1.01
| epoch  11 |    34/   89 batches | lr 0.002844 | 11.86 ms | loss 0.00761 | ppl     1.01
| epoch  11 |    51/   89 batches | lr 0.002844 | 11.12 ms | loss 0.00582 | ppl     1.01
| epoch  11 |    68/   89 batches | lr 0.002844 | 10.68 ms | loss 0.00461 | ppl     1.00
| epoch  11 |    85/   89 batches | lr 0.002844 | 11.11 ms | loss 0.00427 | ppl     1.00
-----------------------------------------------------------------------------------------
| end of epoch  11 | time:  1.79s | valid loss 0.00849 | valid ppl     1.01
-----------------------------------------------------------------------------------------
| epoch  12 |    17/   89 batches | lr 0.002702 | 13.04 ms | loss 0.00666 | ppl     1.01
| epoch  12 |    34/   89 batches | lr 0.002702 | 11.29 ms | loss 0.00612 | ppl     1.01
| epoch  12 |    51/   89 batches | lr 0.002702 | 10.89 ms | loss 0.00670 | ppl     1.01
| epoch  12 |    68/   89 batche



| epoch  21 |    34/   89 batches | lr 0.001703 | 10.79 ms | loss 0.00324 | ppl     1.00
| epoch  21 |    51/   89 batches | lr 0.001703 | 10.78 ms | loss 0.00508 | ppl     1.01
| epoch  21 |    68/   89 batches | lr 0.001703 | 10.82 ms | loss 0.00421 | ppl     1.00
| epoch  21 |    85/   89 batches | lr 0.001703 | 10.74 ms | loss 0.00293 | ppl     1.00
-----------------------------------------------------------------------------------------
| end of epoch  21 | time:  1.73s | valid loss 0.00505 | valid ppl     1.01
-----------------------------------------------------------------------------------------
| epoch  22 |    17/   89 batches | lr 0.001618 | 12.58 ms | loss 0.00428 | ppl     1.00
| epoch  22 |    34/   89 batches | lr 0.001618 | 11.14 ms | loss 0.00313 | ppl     1.00
| epoch  22 |    51/   89 batches | lr 0.001618 | 10.83 ms | loss 0.00367 | ppl     1.00
| epoch  22 |    68/   89 batches | lr 0.001618 | 10.87 ms | loss 0.00269 | ppl     1.00
| epoch  22 |    85/   89 batche



| epoch  31 |    17/   89 batches | lr 0.001020 | 13.45 ms | loss 0.00321 | ppl     1.00
| epoch  31 |    34/   89 batches | lr 0.001020 | 11.67 ms | loss 0.00246 | ppl     1.00
| epoch  31 |    51/   89 batches | lr 0.001020 | 10.97 ms | loss 0.00266 | ppl     1.00
| epoch  31 |    68/   89 batches | lr 0.001020 | 10.69 ms | loss 0.00226 | ppl     1.00
| epoch  31 |    85/   89 batches | lr 0.001020 | 10.81 ms | loss 0.00199 | ppl     1.00
-----------------------------------------------------------------------------------------
| end of epoch  31 | time:  1.77s | valid loss 0.00412 | valid ppl     1.00
-----------------------------------------------------------------------------------------
| epoch  32 |    17/   89 batches | lr 0.000969 | 12.75 ms | loss 0.00313 | ppl     1.00
| epoch  32 |    34/   89 batches | lr 0.000969 | 11.53 ms | loss 0.00240 | ppl     1.00
| epoch  32 |    51/   89 batches | lr 0.000969 | 10.78 ms | loss 0.00266 | ppl     1.00
| epoch  32 |    68/   89 batche



| epoch  41 |    34/   89 batches | lr 0.000610 | 10.94 ms | loss 0.00201 | ppl     1.00
| epoch  41 |    51/   89 batches | lr 0.000610 | 10.93 ms | loss 0.00219 | ppl     1.00
| epoch  41 |    68/   89 batches | lr 0.000610 | 10.83 ms | loss 0.00204 | ppl     1.00
| epoch  41 |    85/   89 batches | lr 0.000610 | 10.88 ms | loss 0.00178 | ppl     1.00
-----------------------------------------------------------------------------------------
| end of epoch  41 | time:  1.73s | valid loss 0.00287 | valid ppl     1.00
-----------------------------------------------------------------------------------------
| epoch  42 |    17/   89 batches | lr 0.000580 | 13.03 ms | loss 0.00257 | ppl     1.00
| epoch  42 |    34/   89 batches | lr 0.000580 | 11.56 ms | loss 0.00203 | ppl     1.00
| epoch  42 |    51/   89 batches | lr 0.000580 | 10.81 ms | loss 0.00222 | ppl     1.00
| epoch  42 |    68/   89 batches | lr 0.000580 | 10.85 ms | loss 0.00203 | ppl     1.00
| epoch  42 |    85/   89 batche



| epoch  51 |    34/   89 batches | lr 0.000365 | 10.99 ms | loss 0.00186 | ppl     1.00
| epoch  51 |    51/   89 batches | lr 0.000365 | 10.78 ms | loss 0.00195 | ppl     1.00
| epoch  51 |    68/   89 batches | lr 0.000365 | 11.21 ms | loss 0.00187 | ppl     1.00
| epoch  51 |    85/   89 batches | lr 0.000365 | 10.74 ms | loss 0.00159 | ppl     1.00
-----------------------------------------------------------------------------------------
| end of epoch  51 | time:  1.74s | valid loss 0.00296 | valid ppl     1.00
-----------------------------------------------------------------------------------------
| epoch  52 |    17/   89 batches | lr 0.000347 | 12.70 ms | loss 0.00224 | ppl     1.00
| epoch  52 |    34/   89 batches | lr 0.000347 | 10.92 ms | loss 0.00184 | ppl     1.00
| epoch  52 |    51/   89 batches | lr 0.000347 | 11.28 ms | loss 0.00193 | ppl     1.00
| epoch  52 |    68/   89 batches | lr 0.000347 | 10.79 ms | loss 0.00184 | ppl     1.00
| epoch  52 |    85/   89 batche



| epoch  61 |    17/   89 batches | lr 0.000219 | 14.55 ms | loss 0.00213 | ppl     1.00
| epoch  61 |    34/   89 batches | lr 0.000219 | 12.07 ms | loss 0.00169 | ppl     1.00
| epoch  61 |    51/   89 batches | lr 0.000219 | 11.13 ms | loss 0.00174 | ppl     1.00
| epoch  61 |    68/   89 batches | lr 0.000219 | 10.92 ms | loss 0.00178 | ppl     1.00
| epoch  61 |    85/   89 batches | lr 0.000219 | 10.72 ms | loss 0.00162 | ppl     1.00
-----------------------------------------------------------------------------------------
| end of epoch  61 | time:  1.81s | valid loss 0.00285 | valid ppl     1.00
-----------------------------------------------------------------------------------------
| epoch  62 |    17/   89 batches | lr 0.000208 | 12.41 ms | loss 0.00205 | ppl     1.00
| epoch  62 |    34/   89 batches | lr 0.000208 | 11.02 ms | loss 0.00166 | ppl     1.00
| epoch  62 |    51/   89 batches | lr 0.000208 | 10.89 ms | loss 0.00182 | ppl     1.00
| epoch  62 |    68/   89 batche



| epoch  71 |    17/   89 batches | lr 0.000131 | 11.78 ms | loss 0.00192 | ppl     1.00
| epoch  71 |    34/   89 batches | lr 0.000131 | 10.92 ms | loss 0.00177 | ppl     1.00
| epoch  71 |    51/   89 batches | lr 0.000131 | 10.85 ms | loss 0.00166 | ppl     1.00
| epoch  71 |    68/   89 batches | lr 0.000131 | 10.79 ms | loss 0.00173 | ppl     1.00
| epoch  71 |    85/   89 batches | lr 0.000131 | 10.84 ms | loss 0.00146 | ppl     1.00
-----------------------------------------------------------------------------------------
| end of epoch  71 | time:  1.73s | valid loss 0.00227 | valid ppl     1.00
-----------------------------------------------------------------------------------------
| epoch  72 |    17/   89 batches | lr 0.000124 | 12.71 ms | loss 0.00188 | ppl     1.00
| epoch  72 |    34/   89 batches | lr 0.000124 | 11.22 ms | loss 0.00163 | ppl     1.00
| epoch  72 |    51/   89 batches | lr 0.000124 | 11.03 ms | loss 0.00168 | ppl     1.00
| epoch  72 |    68/   89 batche



| epoch  81 |    17/   89 batches | lr 0.000078 | 12.87 ms | loss 0.00192 | ppl     1.00
| epoch  81 |    34/   89 batches | lr 0.000078 | 11.48 ms | loss 0.00160 | ppl     1.00
| epoch  81 |    51/   89 batches | lr 0.000078 | 10.84 ms | loss 0.00159 | ppl     1.00
| epoch  81 |    68/   89 batches | lr 0.000078 | 10.74 ms | loss 0.00170 | ppl     1.00
| epoch  81 |    85/   89 batches | lr 0.000078 | 11.08 ms | loss 0.00143 | ppl     1.00
-----------------------------------------------------------------------------------------
| end of epoch  81 | time:  1.76s | valid loss 0.00241 | valid ppl     1.00
-----------------------------------------------------------------------------------------
| epoch  82 |    17/   89 batches | lr 0.000075 | 12.59 ms | loss 0.00187 | ppl     1.00
| epoch  82 |    34/   89 batches | lr 0.000075 | 11.43 ms | loss 0.00156 | ppl     1.00
| epoch  82 |    51/   89 batches | lr 0.000075 | 10.86 ms | loss 0.00161 | ppl     1.00
| epoch  82 |    68/   89 batche



| epoch  91 |    34/   89 batches | lr 0.000047 | 10.97 ms | loss 0.00164 | ppl     1.00
| epoch  91 |    51/   89 batches | lr 0.000047 | 10.95 ms | loss 0.00167 | ppl     1.00
| epoch  91 |    68/   89 batches | lr 0.000047 | 10.84 ms | loss 0.00171 | ppl     1.00
| epoch  91 |    85/   89 batches | lr 0.000047 | 10.70 ms | loss 0.00149 | ppl     1.00
-----------------------------------------------------------------------------------------
| end of epoch  91 | time:  1.72s | valid loss 0.00280 | valid ppl     1.00
-----------------------------------------------------------------------------------------
| epoch  92 |    17/   89 batches | lr 0.000045 | 12.45 ms | loss 0.00201 | ppl     1.00
| epoch  92 |    34/   89 batches | lr 0.000045 | 10.93 ms | loss 0.00161 | ppl     1.00
| epoch  92 |    51/   89 batches | lr 0.000045 | 10.78 ms | loss 0.00171 | ppl     1.00
| epoch  92 |    68/   89 batches | lr 0.000045 | 10.90 ms | loss 0.00171 | ppl     1.00
| epoch  92 |    85/   89 batche