In [1]:
import os
from os.path import exists
import torch
import torch.nn as nn
from torch.nn.functional import log_softmax, pad
import math
import copy
import time
from torch.optim.lr_scheduler import LambdaLR
import pandas as pd
import altair as alt
from torchtext.data.functional import to_map_style_dataset
from torch.utils.data import DataLoader
from torchtext.vocab import build_vocab_from_iterator
import torchtext.datasets as datasets
import spacy
import GPUtil
import warnings
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

warnings.filterwarnings("ignore")
RUN_EXAMPLES = True

In [2]:
# Some convenience helper functions used throughout the notebook


def is_interactive_notebook():
    """Return True when this code runs as the top-level ``__main__`` module."""
    running_as_main = __name__ == "__main__"
    return running_as_main


def show_example(fn, args=()):
    """Call ``fn(*args)`` and return its result, but only for example runs.

    The call happens only when the code runs as ``__main__`` and the
    module-level ``RUN_EXAMPLES`` flag is truthy; otherwise None is returned.

    Args:
        fn: callable to execute.
        args: positional arguments for ``fn``. The default is an immutable
            empty tuple so no mutable object is shared between calls.
    """
    if __name__ == "__main__" and RUN_EXAMPLES:
        return fn(*args)
    return None


def execute_example(fn, args=()):
    """Call ``fn(*args)`` for its side effects, but only for example runs.

    The call happens only when the code runs as ``__main__`` and the
    module-level ``RUN_EXAMPLES`` flag is truthy. The result of ``fn`` is
    discarded and None is always returned.

    Args:
        fn: callable to execute.
        args: positional arguments for ``fn``. The default is an immutable
            empty tuple so no mutable object is shared between calls.
    """
    if __name__ == "__main__" and RUN_EXAMPLES:
        fn(*args)


class DummyOptimizer(torch.optim.Optimizer):
    """Optimizer stand-in whose methods do nothing.

    It exposes a single parameter group with ``lr`` set to 0, so code that
    reads ``param_groups[0]["lr"]`` keeps working.
    """

    def __init__(self):
        # Optimizer.__init__ is intentionally not called: no parameters are managed.
        self.param_groups = [dict(lr=0)]

    def step(self):
        """No-op replacement for an optimisation step."""
        return None

    def zero_grad(self, set_to_none=False):
        """No-op replacement for gradient clearing; ``set_to_none`` is ignored."""
        return None


class DummyScheduler:
    """Learning-rate scheduler stand-in whose ``step`` does nothing."""

    def step(self):
        """No-op replacement for a scheduler step."""
        return None

In [3]:
class Generator(nn.Module):
    """Final linear layer applied after decoding.

    It projects each ``d_model``-sized vector to a single float. No softmax
    is applied: the output is a regression value (the predicted number of
    deaths), so there is no vocabulary dimension.
    """

    def __init__(self, d_model):
        super(Generator, self).__init__()
        # One output unit per position.
        self.proj = nn.Linear(d_model, 1)

    def forward(self, x):
        """Return the projection of ``x``; the last dimension becomes 1."""
        return self.proj(x)

In [4]:
def clones(module, N):
    """Produce N identical layers.

    Every entry is an independent ``copy.deepcopy`` of ``module``. Deep
    copies are required because sharing one object by reference would make
    all N layers the same layer.
    """
    copies = []
    for _ in range(N):
        copies.append(copy.deepcopy(module))
    return nn.ModuleList(copies)

In [5]:
class LayerNorm(nn.Module):
    """Layer normalisation (see https://arxiv.org/pdf/1607.06450.pdf).

    Normalises over the last dimension using the mean and standard deviation
    of that dimension, then applies a learnt gain (``a_2``) and bias (``b_2``).
    """

    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))   # gain
        self.b_2 = nn.Parameter(torch.zeros(features))  # bias
        self.eps = eps

    def forward(self, x):
        """Return ``a_2 * (x - mean) / (std + eps) + b_2`` over the last dim."""
        centered = x - x.mean(-1, keepdim=True)
        # eps is added to the standard deviation so the division stays finite.
        scale = x.std(-1, keepdim=True) + self.eps
        return self.a_2 * centered / scale + self.b_2

class SublayerConnection(nn.Module):
    """Residual connection wrapped around a normalised sublayer.

    For code simplicity the layer norm is applied to the sublayer's input
    (norm first) rather than to its output.
    """

    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        # Dropout on the sublayer output, to reduce overfitting.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """Return ``x + dropout(sublayer(norm(x)))``.

        ``sublayer`` is any callable that maps a tensor to a tensor of the
        same size.
        """
        normalised = self.norm(x)
        transformed = self.dropout(sublayer(normalised))
        return x + transformed

In [6]:
class Encoder(nn.Module):
    """Stack of N cloned encoder layers followed by a final layer norm."""

    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.layers = clones(layer, N)
        # The norm width is taken from the layer's ``size`` attribute.
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        """Feed ``x`` (with ``mask``) through every layer in order, then normalise."""
        hidden = x
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, mask)
        return self.norm(hidden)

class EncoderLayer(nn.Module):
    """One encoder layer: self-attention followed by a feed-forward block.

    Each of the two steps is wrapped in its own ``SublayerConnection``.
    """

    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        # Index 0 wraps self-attention, index 1 wraps the feed-forward block.
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask):
        """Apply self-attention then feed-forward (Figure 1, left)."""

        def attend(inp):
            # SublayerConnection expects a one-argument callable, so the
            # mask is bound here and query, key and value are all ``inp``.
            return self.self_attn(inp, inp, inp, mask)

        attended = self.sublayer[0](x, attend)
        return self.sublayer[1](attended, self.feed_forward)

In [7]:
class Decoder(nn.Module):
    """Stack of N cloned decoder layers (with masking) and a final layer norm."""

    def __init__(self, layer, N):
        super(Decoder, self).__init__()
        self.layers = clones(layer, N)
        # The norm width is taken from the layer's ``size`` attribute.
        self.norm = LayerNorm(layer.size)

    def forward(self, x, memory, src_mask, tgt_mask):
        """Run ``x`` through every layer in order, then normalise.

        ``memory`` is the context produced by the encoder; it and both masks
        are passed unchanged to each layer.
        """
        hidden = x
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, memory, src_mask, tgt_mask)
        return self.norm(hidden)

class DecoderLayer(nn.Module):
    """One decoder layer: self-attention, source attention, feed-forward.

    Each of the three steps is wrapped in its own ``SublayerConnection``.
    """

    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        # Indices 0, 1, 2 wrap self-attn, src-attn and feed-forward respectively.
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, src_mask, tgt_mask):
        """Apply the three sublayers in order (Figure 1, right)."""

        def self_attend(inp):
            # Query, key and value are all the decoder input, under tgt_mask.
            return self.self_attn(inp, inp, inp, tgt_mask)

        def cross_attend(inp):
            # Queries come from the decoder; keys and values from encoder memory.
            return self.src_attn(inp, memory, memory, src_mask)

        hidden = self.sublayer[0](x, self_attend)
        hidden = self.sublayer[1](hidden, cross_attend)
        return self.sublayer[2](hidden, self.feed_forward)

In [8]:
def subsequent_mask(size):
    """Build a causal mask that hides subsequent positions.

    Returns a boolean tensor of shape ``(1, size, size)`` that is True on
    and below the main diagonal and False above it.
    """
    # torch.tril keeps the lower triangle (diagonal included) of the ones matrix.
    visible = torch.tril(torch.ones(1, size, size))
    return visible == 1

In [9]:
def attention(query, key, value, mask=None, dropout=None):
    """Compute scaled dot-product attention.

    Used by the self- and cross-attention sublayers.

    Args:
        query, key, value: tensors whose last dimension is the per-head size.
        mask: optional tensor; score positions where ``mask == 0`` are filled
            with -1e9 before the softmax.
        dropout: optional callable applied to the attention weights.

    Returns:
        ``(weighted_values, attention_weights)``. The weights are returned as
        well because they are useful for visualisation.
    """
    head_dim = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(head_dim)
    if mask is not None:
        # A large negative score becomes (effectively) zero weight after softmax.
        scores = scores.masked_fill(mask == 0, -1e9)
    weights = scores.softmax(dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    weighted = torch.matmul(weights, value)
    return weighted, weights

In [10]:
class MultiHeadedAttention(nn.Module):
    """Multi-head attention (Figure 2) built on ``attention``.

    The most recent attention weights are stored in ``self.attn``.
    """

    def __init__(self, h, d_model, dropout=0.1):
        """Take in the number of heads ``h`` and the model size ``d_model``."""
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # d_v is assumed to always equal d_k.
        self.d_k = d_model // h
        self.h = h
        # Three input projections (query, key, value) plus the output projection.
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """Project the inputs, attend per head, then merge heads and project."""
        # The same mask is applied to all h heads, hence the extra head axis.
        head_mask = None if mask is None else mask.unsqueeze(1)
        batch = query.size(0)

        def split_heads(projection, tensor):
            # d_model -> (h, d_k); the head axis is moved in front of the sequence axis.
            projected = projection(tensor)
            return projected.view(batch, -1, self.h, self.d_k).transpose(1, 2)

        q_proj, k_proj, v_proj, out_proj = self.linears
        q_heads = split_heads(q_proj, query)
        k_heads = split_heads(k_proj, key)
        v_heads = split_heads(v_proj, value)

        # Attention over all heads at once.
        context, self.attn = attention(
            q_heads, k_heads, v_heads, mask=head_mask, dropout=self.dropout
        )

        # "Concat" the heads back into one d_model-sized vector with a view.
        merged = context.transpose(1, 2).contiguous()
        merged = merged.view(batch, -1, self.h * self.d_k)
        return out_proj(merged)

In [11]:
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward network with a ReLU in between (FFN equation)."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)   # expand to the hidden width
        self.w_2 = nn.Linear(d_ff, d_model)   # project back to the model width
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``w_2(dropout(relu(w_1(x))))``."""
        hidden = self.w_1(x).relu()
        hidden = self.dropout(hidden)
        return self.w_2(hidden)

In [12]:
#NOTE: MAJOR changes from here on
"""
We need to embed 3 parameters: Variables, Country/Location, Time.
Time embedding will be done via the standard sinusoidal positional encoding in d dim (FUTURE WORK: use learnt time embeddings as suggested by https://arxiv.org/abs/2109.12218)
Country/Loc & Variable embeddings will be learnt. Var emb will be d-1 dim and Country/Loc emb will be d-1 dim; the two are summed and the result is concatenated to the associated value, giving a d-dim vector.
Note: In the source above they had a learnt time emb, and concatenated time embedding to the value, then added the spatial embedding. In ours we concatenate the summed var+loc emb to the value and then add the positional emb.
"""
class Embeddings(nn.Module):
    """Learnt lookup-table embedding scaled by ``sqrt(d_emb)``.

    Args:
        d_emb: embedding width (``d_model - 1`` in this notebook).
        vocab: number of distinct ids (number of countries or of variables).
    """

    def __init__(self, d_emb, vocab):
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab, d_emb)
        self.d_emb = d_emb

    def forward(self, x):
        """Look up the ids in ``x`` and multiply the result by ``sqrt(d_emb)``."""
        scale = math.sqrt(self.d_emb)
        return self.lut(x) * scale
    
"""
Note that we have 3 variables of 14 past days, so we need to append appropriately. Also, the format of our input is
[<14 days of past cases><past vacc><13 days of past deaths><day 14><day 14 repeated for decoder><day 15 and day 16>]
our target outcome is [<day 15,16,17>]
"""

class InpPositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for the encoder input.

    The input sequence is ``var_num`` consecutive blocks of ``past`` time
    steps (one block per variable). The same ``past`` positional vectors are
    added to every block, so equal days of different variables share an
    encoding.

    Args:
        d_model: width of each input vector.
        dropout: dropout probability applied after adding the encoding.
        past: number of days of past data per variable.
        var_num: number of variables (blocks) in the sequence.
    """

    def __init__(self, d_model, dropout, past, var_num):
        super(InpPositionalEncoding, self).__init__()
        self.past = past
        self.var_num = var_num
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(past, d_model)
        position = torch.arange(0, past).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine
        # columns, so the cosine block is trimmed to fit.
        pe[:, 1::2] = torch.cos(angles)[:, : d_model // 2]
        # Buffers are persistent by default and never receive gradients.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Return ``dropout(x + encoding)`` without modifying ``x``.

        Args:
            x: tensor of shape ``(batch, past * var_num, d_model)``.

        Raises:
            AssertionError: if the sequence length is not ``past * var_num``.
        """
        expected_len = self.past * self.var_num
        assert x.size(1) == expected_len, (
            "expected sequence length %d (past * var_num), got %d"
            % (expected_len, x.size(1))
        )
        # Tile the `past` encodings once per variable block. Adding out of
        # place leaves the caller's tensor untouched.
        tiled = self.pe.repeat(1, self.var_num, 1)
        return self.dropout(x + tiled)

class OutPositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for the decoder input.

    Encodings are computed for ``past + pred - 1`` positions (the ``-1``
    accounts for the repeated last past day). The decoder input is given the
    encodings of the final ``pred`` positions, starting from the first of them.

    Args:
        d_model: width of each input vector.
        dropout: dropout probability applied after adding the encoding.
        past: number of days of past data.
        pred: number of days to predict.
    """

    def __init__(self, d_model, dropout, past, pred):
        super(OutPositionalEncoding, self).__init__()
        self.past = past
        self.pred = pred
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        total = past + pred - 1
        pe = torch.zeros(total, d_model)
        position = torch.arange(0, total).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine
        # columns, so the cosine block is trimmed to fit.
        pe[:, 1::2] = torch.cos(angles)[:, : d_model // 2]
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Return ``dropout(x + encoding)`` without modifying ``x``.

        Args:
            x: tensor of shape ``(batch, length, d_model)`` with
                ``length <= pred``.

        Raises:
            AssertionError: if ``length > pred``. Such an input previously
                passed through with no positional information at all.
        """
        length = x.size(1)
        assert length <= self.pred, (
            "decoder input length %d exceeds pred=%d" % (length, self.pred)
        )
        # First of the last `pred` rows of the buffer; a positive index
        # avoids the empty slice that `-pred:-pred+length` gives at length == pred.
        start = self.pe.size(1) - self.pred
        return self.dropout(x + self.pe[:, start:start + length])



In [13]:
'''def example_positional1():
    pe = InpPositionalEncoding(20, 0, 14, 3)
    y = pe.forward(torch.zeros(1, 42, 20)) #first index is batches, 2nd index is dates, 3rd index is dim of model
    print([y[0, :, dim] for dim in [3,4,5]])
    data = pd.concat(
        [
            pd.DataFrame(
                {
                    "embedding": y[0, :, dim],
                    "dimension": dim,
                    "position": list(range(42)),
                }
            )
            for dim in [4,5,6,7]
        ]
    )

    return (
        alt.Chart(data)
        .mark_line()
        .properties(width=800)
        .encode(x="position", y="embedding", color="dimension:N")
        .interactive()
    )

def example_positional2():
    pe = OutPositionalEncoding(20, 0, 200, 100)
    y = pe.forward(torch.zeros(1, 86, 20)) #first index is batches, 2nd index is dates, 3rd index is dim of model
    print([y[0, :, dim] for dim in [3,4,5]])
    data = pd.concat(
        [
            pd.DataFrame(
                {
                    "embedding": y[0, :, dim],
                    "dimension": dim,
                    "position": list(range(86)),
                }
            )
            for dim in [4,5,6,7]
        ]
    )

    return (
        alt.Chart(data)
        .mark_line()
        .properties(width=800)
        .encode(x="position", y="embedding", color="dimension:N")
        .interactive()
    )

show_example(example_positional1)'''


'def example_positional1():\n    pe = InpPositionalEncoding(20, 0, 14, 3)\n    y = pe.forward(torch.zeros(1, 42, 20)) #first index is batches, 2nd index is dates, 3rd index is dim of model\n    print([y[0, :, dim] for dim in [3,4,5]])\n    data = pd.concat(\n        [\n            pd.DataFrame(\n                {\n                    "embedding": y[0, :, dim],\n                    "dimension": dim,\n                    "position": list(range(42)),\n                }\n            )\n            for dim in [4,5,6,7]\n        ]\n    )\n\n    return (\n        alt.Chart(data)\n        .mark_line()\n        .properties(width=800)\n        .encode(x="position", y="embedding", color="dimension:N")\n        .interactive()\n    )\n\ndef example_positional2():\n    pe = OutPositionalEncoding(20, 0, 200, 100)\n    y = pe.forward(torch.zeros(1, 86, 20)) #first index is batches, 2nd index is dates, 3rd index is dim of model\n    print([y[0, :, dim] for dim in [3,4,5]])\n    data = pd.concat(\n     

In [14]:
def spreademb(var_num, size):
    """Return variable ids, each repeated ``size`` times, as a 1-D LongTensor.

    For example ``spreademb(3, 2)`` gives ``[0, 0, 1, 1, 2, 2]``.
    """
    repeated_ids = [var_id for var_id in range(var_num) for _ in range(size)]
    return torch.LongTensor(repeated_ids)

spreademb(3,7)

tensor([0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2])

In [15]:
class EncoderDecoder(nn.Module):
    """Encoder-decoder model over numeric series with learnt embeddings.

    Every input value is concatenated with the sum of a location embedding
    and a variable embedding, passed through the matching positional
    encoding and then through the encoder or decoder stack.

    All intermediate tensors are created on the device of the incoming
    ``src`` / ``tgt`` tensor, so the model is not tied to GPU 0.
    """
    # IMPORTANT: CHANGE SRC&TGT EMB?

    def __init__(self, encoder, decoder, loc_emb, var_emb, generator, inppos, outpos, past, pred, var_num):
        """Store the sub-modules and the sequence-layout hyperparameters.

        Args:
            encoder, decoder: encoder and decoder stacks.
            loc_emb, var_emb: embedding modules for location and variable ids.
            generator: final projection applied to the decoder output.
            inppos, outpos: positional encodings for encoder / decoder input.
            past: number of past days per variable.
            pred: number of days to predict.
            var_num: number of variables in the source sequence.
        """
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.loc_emb = loc_emb
        self.var_emb = var_emb
        self.generator = generator
        self.past = past
        self.pred = pred
        self.var_num = var_num
        self.inppos = inppos
        self.outpos = outpos

    def forward(self, src, tgt, src_mask, tgt_mask):
        """Take in and process masked src and target sequences."""
        # The encoder output is the context ("memory") handed to the decoder.
        memory = self.encode(src, src_mask)
        return self.generator(self.decode(memory, src_mask, tgt, tgt_mask))

    def _embed(self, seq, var_ids):
        """Concatenate each value in ``seq`` with its loc+var embedding.

        Args:
            seq: ``(batch, 1 + length)`` tensor; column 0 is the location id
                and the remaining columns are the values.
            var_ids: 1-D LongTensor with the variable id of every value.

        Returns:
            ``(batch, length, d)`` tensor with the raw value in channel 0 and
            the summed embedding in the remaining channels.
        """
        device = seq.device
        # Embedding layers need integer ids; .long() keeps the tensor on its device.
        loc_ids = seq[:, 0].long()
        loc_emb = self.loc_emb(loc_ids).unsqueeze(1)           # (batch, 1, d-1)
        var_emb = self.var_emb(var_ids.to(device))             # (length, d-1)
        emb = (loc_emb + var_emb).expand(seq.size(0), seq.size(1) - 1, -1)
        values = seq[:, 1:].unsqueeze(-1)                      # (batch, length, 1)
        return torch.cat((values, emb), dim=2)

    def encode(self, src, src_mask):
        """Encode ``src``.

        ``src`` is ``batchsize * [<country><case num for past days><vacc num><death num>]``.
        """
        # One variable id per value: `past` copies of each of the `var_num` ids.
        var_ids = spreademb(self.var_num, self.past)
        embedded = self._embed(src, var_ids)
        return self.encoder(self.inppos(embedded), src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        """Decode ``tgt`` against the encoder ``memory``.

        ``tgt`` is ``batchsize * [<country><death num for (x < pred) days>]``.
        """
        # Every target value is a death count, i.e. the last variable id.
        var_ids = torch.full((tgt.size(1) - 1,), self.var_num - 1, dtype=torch.long)
        embedded = self._embed(tgt, var_ids)
        return self.decoder(self.outpos(embedded), memory, src_mask, tgt_mask)
    
def make_model(
    loc_num, var_num=3, past=14, pred=3, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1
):
    """Helper: construct a model from hyperparameters.

    Args:
        loc_num: number of distinct locations (location-embedding vocabulary).
        var_num: number of input variables (variable-embedding vocabulary).
        past: number of past days per variable.
        pred: number of days to predict.
        N: number of encoder layers and of decoder layers.
        d_model: model width; the embeddings use ``d_model - 1`` because one
            channel holds the raw value.
        d_ff: hidden width of the feed-forward blocks.
        h: number of attention heads.
        dropout: dropout probability used in every sub-module.

    Returns:
        An ``EncoderDecoder`` whose parameters of more than one dimension are
        Xavier-uniform initialised.
    """
    c = copy.deepcopy
    # `dropout` is forwarded here too; previously the attention blocks always
    # used their own default regardless of this argument.
    attn = MultiHeadedAttention(h, d_model, dropout)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    inppos = InpPositionalEncoding(d_model, dropout, past, var_num)
    outpos = OutPositionalEncoding(d_model, dropout, past, pred)
    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),
        Embeddings(d_model - 1, loc_num),  # loc emb
        Embeddings(d_model - 1, var_num),  # var emb
        Generator(d_model),
        inppos,
        outpos,
        past, pred, var_num
    )

    # This was important from their code.
    # Initialize parameters with Glorot / fan_avg.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return model

In [16]:
'''def inference_test():
    test_model = make_model(loc_num=1, var_num=2, past=5, pred=7)
    test_model.cuda(0)
    test_model.eval()
    src = torch.LongTensor([[0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5]]).to(device=0)
    src_mask = torch.ones(1, 1, 10).to(device=0)

    memory = test_model.encode(src, src_mask)
    ys = torch.LongTensor([[0,5]]).to(device=0) #repeat last day

    for i in range(6):
        out = test_model.decode(
            memory, src_mask, ys, subsequent_mask(ys.size(1)-1).type_as(src.data)
        )
        out = test_model.generator(out)
        out = out.view(out.size(0),out.size(1)).to(device=0)
        out = out[:, -1:] #take last input
        #print(ys.size(),out.size())
        ys = torch.cat(
            [ys, out], dim=1
        )

    print("Example Untrained Model Prediction:", ys)


def run_tests():
    for _ in range(10):
        inference_test()


show_example(run_tests)'''

'def inference_test():\n    test_model = make_model(loc_num=1, var_num=2, past=5, pred=7)\n    test_model.cuda(0)\n    test_model.eval()\n    src = torch.LongTensor([[0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5]]).to(device=0)\n    src_mask = torch.ones(1, 1, 10).to(device=0)\n\n    memory = test_model.encode(src, src_mask)\n    ys = torch.LongTensor([[0,5]]).to(device=0) #repeat last day\n\n    for i in range(6):\n        out = test_model.decode(\n            memory, src_mask, ys, subsequent_mask(ys.size(1)-1).type_as(src.data)\n        )\n        out = test_model.generator(out)\n        out = out.view(out.size(0),out.size(1)).to(device=0)\n        out = out[:, -1:] #take last input\n        #print(ys.size(),out.size())\n        ys = torch.cat(\n            [ys, out], dim=1\n        )\n\n    print("Example Untrained Model Prediction:", ys)\n\n\ndef run_tests():\n    for _ in range(10):\n        inference_test()\n\n\nshow_example(run_tests)'

In [17]:
"""Typically, our src will have no padding. For past=14, pred=3,
    the input we want for src is [<country><first 14 days of cases><.. of vacc><.. of deaths>]
    and the format for tgt (when training) is [<country><day14><day15><day16><day17>]
    """
class Batch:
    """Object for holding a batch of data with mask during training.

    Attributes set for every batch:
        src: source tensor moved to ``device``.
        src_mask: always None (the source has no padding to mask).

    Attributes set only when ``tgt`` is given:
        tgt: ``tgt`` without its last column (decoder input, location included).
        raw: the value columns of ``tgt`` without the location column.
        tgt_y: float regression targets with a trailing singleton dimension.
        tgt_mask: causal mask sized from ``raw``.
        ntokens: number of target values in the batch.
    """

    def __init__(self, src, tgt=None, device=None):
        """Move the batch to ``device`` and derive the decoder tensors.

        Args:
            src: source tensor.
            tgt: optional target tensor laid out as
                ``[<country><last past day><future days...>]``.
            device: where to place the tensors. Defaults to the current CUDA
                device when CUDA is available, otherwise to the CPU.
        """
        if device is None:
            # Follow torch.cuda.set_device rather than always using GPU 0.
            device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"

        self.src = src.to(device=device)
        self.src_mask = None

        if tgt is not None:
            # The location column is kept: loc_emb turns it into channels later.
            self.tgt = tgt[:, :-1].to(device=device)
            self.raw = tgt[:, 1:-1].to(device=device)  # death data (exclude country column)
            # Model outputs to compare against; float is needed for the MSE loss.
            tgt_y = tgt[:, 2:].float()
            self.tgt_y = tgt_y.unsqueeze(-1).to(device=device)
            self.tgt_mask = self.make_std_mask(self.raw).to(device=device)
            self.ntokens = self.tgt_y.numel()

    @staticmethod
    def make_std_mask(tgt):
        """Create a mask that hides future positions, sized by ``tgt``'s last dim."""
        return subsequent_mask(tgt.size(-1))

def ProcessRaw(rawbatch, past=14, pred=3, var_num=3):
    """Turn one raw batch from the dataloader into a ``Batch`` object.

    ``rawbatch`` should be a torch tensor of shape
    ``batchsize * (var_num * past + pred + 1)``. With the defaults each row is
    ``[<country><first 14 days of cases><.. of vacc><.. of deaths><day 15-17 of deaths>]``.

    Args:
        rawbatch: raw tensor from the dataloader.
        past: number of past days per variable.
        pred: number of future days to predict.
        var_num: number of variables stored before the future values
            (previously hard-coded as 3).

    Returns:
        ``Batch(src, tgt)`` where ``src`` is the row without the future
        values and ``tgt`` is ``[<country><last past day><future days>]``.
    """
    src = rawbatch[:, :-pred]
    # Column index var_num * past is the last past day of the final variable
    # (column 0 is the country), so this slice is [last past day, future days...].
    tgt = rawbatch[:, past * var_num:]
    loc = rawbatch[:, 0:1]
    tgt = torch.cat((loc, tgt), dim=1)

    return Batch(src, tgt)

class TrainState:
    """Counters for the steps, examples and tokens processed during training."""

    # Number of training batches processed.
    step: int = 0
    # Number of optimizer updates (gradient accumulation steps).
    accum_step: int = 0
    # Total number of examples used.
    samples: int = 0
    # Total number of tokens processed.
    tokens: int = 0

In [18]:
"Note: We do not use label smoothing as we are not using cross entropy as our loss function. Our output is just 1 number so we also do not softmax."
def run_epoch(
    data_iter,
    model,
    loss_compute,  # we use nn.MSELoss()
    optimizer,
    scheduler,
    mode="train",
    accum_iter=1,
    train_state=None,
):
    """Run one pass over ``data_iter``, training when ``mode`` asks for it.

    Args:
        data_iter: iterable yielding ``Batch`` objects.
        model: module called as ``model(src, tgt, src_mask, tgt_mask)``.
        loss_compute: callable ``(output, target) -> scalar loss tensor``.
        optimizer: optimizer stepped when ``i % accum_iter == 0``.
        scheduler: learning-rate scheduler stepped after every training batch.
        mode: ``"train"`` or ``"train+log"`` to back-propagate and update;
            any other value only evaluates.
        accum_iter: optimizer-update interval, in batches.
        train_state: ``TrainState`` to update. A fresh one is created when
            omitted, so separate calls never share counters.

    Returns:
        ``(mean_loss, train_state)`` where ``mean_loss`` is the token-weighted
        average loss, detached from the autograd graph.

    Raises:
        ZeroDivisionError: if ``data_iter`` yields no batches.
    """
    if train_state is None:
        train_state = TrainState()
    training = mode in ("train", "train+log")

    start = time.time()
    total_tokens = 0
    total_loss = 0
    tokens = 0
    n_accum = 0
    for i, batch in enumerate(data_iter):
        # Call the module itself (not .forward) so registered hooks run.
        out = model(batch.src, batch.tgt, batch.src_mask, batch.tgt_mask)
        loss = loss_compute(out, batch.tgt_y)
        if training:
            loss.backward()
            train_state.step += 1
            train_state.samples += batch.src.shape[0]
            train_state.tokens += batch.ntokens
            if i % accum_iter == 0:
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
                n_accum += 1
                train_state.accum_step += 1
            scheduler.step()

        # Detach before accumulating so the running total does not keep every
        # batch's autograd graph alive.
        total_loss += loss.detach() * batch.ntokens
        total_tokens += batch.ntokens
        tokens += batch.ntokens
        if i % 40 == 1 and training:
            lr = optimizer.param_groups[0]["lr"]
            elapsed = time.time() - start
            print(
                (
                    "Epoch Step: %6d | Accumulation Step: %3d | Loss: %6.8f "
                    + "| Tokens / Sec: %7.1f | Learning Rate: %6.1e"
                )
                % (i, n_accum, loss, tokens / elapsed, lr)
            )
            start = time.time()
            tokens = 0
        del loss
    return total_loss / total_tokens, train_state

In [19]:
def rate(step, model_size, factor, warmup):
    """Learning-rate multiplier with warm-up followed by inverse-sqrt decay.

    A step of 0 is treated as 1, because ``LambdaLR`` starts at step 0 and
    zero cannot be raised to a negative power.
    """
    effective_step = 1 if step == 0 else step
    decay = effective_step ** (-0.5)
    warm = effective_step * warmup ** (-1.5)
    return factor * (
        model_size ** (-0.5) * min(decay, warm)
    )

'''def example_learning_schedule():
    opts = [
        [512, 1, 4000],  # example 1
        [512, 1, 8000],  # example 2
        [256, 1, 4000],  # example 3
    ]

    dummy_model = torch.nn.Linear(1, 1)
    learning_rates = []

    # we have 3 examples in opts list.
    for idx, example in enumerate(opts):
        # run 20000 epoch for each example
        optimizer = torch.optim.Adam(
            dummy_model.parameters(), lr=1, betas=(0.9, 0.98), eps=1e-9
        )
        lr_scheduler = LambdaLR(
            optimizer=optimizer, lr_lambda=lambda step: rate(step, *example)
        )
        tmp = []
        # take 20K dummy training steps, save the learning rate at each step
        for step in range(20000):
            tmp.append(optimizer.param_groups[0]["lr"])
            optimizer.step()
            lr_scheduler.step()
        learning_rates.append(tmp)

    learning_rates = torch.tensor(learning_rates)

    # Enable altair to handle more than 5000 rows
    alt.data_transformers.disable_max_rows()

    opts_data = pd.concat(
        [
            pd.DataFrame(
                {
                    "Learning Rate": learning_rates[warmup_idx, :],
                    "model_size:warmup": ["512:4000", "512:8000", "256:4000"][
                        warmup_idx
                    ],
                    "step": range(20000),
                }
            )
            for warmup_idx in [0, 1, 2]
        ]
    )

    return (
        alt.Chart(opts_data)
        .mark_line()
        .properties(width=600)
        .encode(x="step", y="Learning Rate", color="model_size:warmup:N")
        .interactive()
    )


example_learning_schedule()'''

'def example_learning_schedule():\n    opts = [\n        [512, 1, 4000],  # example 1\n        [512, 1, 8000],  # example 2\n        [256, 1, 4000],  # example 3\n    ]\n\n    dummy_model = torch.nn.Linear(1, 1)\n    learning_rates = []\n\n    # we have 3 examples in opts list.\n    for idx, example in enumerate(opts):\n        # run 20000 epoch for each example\n        optimizer = torch.optim.Adam(\n            dummy_model.parameters(), lr=1, betas=(0.9, 0.98), eps=1e-9\n        )\n        lr_scheduler = LambdaLR(\n            optimizer=optimizer, lr_lambda=lambda step: rate(step, *example)\n        )\n        tmp = []\n        # take 20K dummy training steps, save the learning rate at each step\n        for step in range(20000):\n            tmp.append(optimizer.param_groups[0]["lr"])\n            optimizer.step()\n            lr_scheduler.step()\n        learning_rates.append(tmp)\n\n    learning_rates = torch.tensor(learning_rates)\n\n    # Enable altair to handle more than 5000 r

In [20]:
def data_gen(V, batch_size, nbatches):
    """Generate random data for a src-tgt copy task.

    Yields ``nbatches`` ``Batch`` objects. Each holds ``batch_size`` rows of
    10 random integers in ``[1, V)`` whose first column is set to 1, used as
    both source and target.
    """
    for _ in range(nbatches):
        data = torch.randint(1, V, size=(batch_size, 10))
        data[:, 0] = 1
        src = data.requires_grad_(False).clone().detach()
        tgt = data.requires_grad_(False).clone().detach()
        # Batch takes no padding-index argument here; passing one raised TypeError.
        yield Batch(src, tgt)

def inference(model, src, src_mask, pred):
    """Autoregressively predict ``pred`` future values with a trained model.

    Args:
        model: trained model exposing ``encode``, ``decode`` and ``generator``.
        src: ``(batch, 1 + n)`` tensor whose first column is the location id
            and whose last column is the final observed value.
        src_mask: mask passed through to the encoder and decoder.
        pred: number of future steps to generate.

    Returns:
        ``(batch, 2 + pred)`` tensor laid out as
        ``[<location><last observed value><pred predictions>]``.
    """
    memory = model.encode(src, src_mask)
    loc = src[:, 0:1]       # location column only (0:1 keeps it 2-D)
    initial = src[:, -1:]   # repeat the last observed day as the first decoder input
    ys = torch.cat([loc, initial], dim=1)

    for _ in range(pred):
        tgt_mask = subsequent_mask(ys.size(1) - 1).type_as(src.data)
        out = model.generator(model.decode(memory, src_mask, ys, tgt_mask))
        # Round so predictions are whole death counts. The result stays on
        # ys's device and dtype so it can be concatenated to ys.
        # NOTE(review): if the model is trained on normalised targets,
        # rounding may not be wanted here - confirm.
        out = torch.round(out).to(ys.dtype)
        out = out.view(out.size(0), out.size(1))
        next_value = out[:, -1:]  # only the newest position is a new prediction
        ys = torch.cat([ys, next_value], dim=1)
    return ys


In [21]:
"We do not need a collate batch function as we have preprocessed all our data including shuffle"
class DfParser(torch.utils.data.Dataset):
    """Dataset over the rows of a CSV file, for use with a torch DataLoader.

    The first CSV column is dropped; the remaining columns of every row are
    stored as one float tensor.
    """

    def __init__(self, file_name):
        frame = pd.read_csv(file_name)
        values = frame.iloc[:, 1:].values
        self.x = torch.tensor(values).type(torch.FloatTensor)

    def __len__(self):
        """Number of rows in the dataset."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the row (or rows) selected by ``idx``."""
        return self.x[idx]

def create_dataloaders(
    device,
    batch_size=12000,
    is_distributed=True,
):
    """Build the train, validation and test dataloaders.

    The data is read from ``norm_train.csv``, ``norm_valid.csv`` and
    ``norm_test.csv`` in the working directory. Order is never shuffled here,
    because the files are already shuffled during preprocessing.

    Args:
        device: currently unused; kept so existing callers keep working.
        batch_size: batch size of every loader.
        is_distributed: when True and a ``torch.distributed`` process group is
            initialised, every loader gets a ``DistributedSampler`` so each
            process reads its own shard and ``sampler.set_epoch`` exists.
            Otherwise plain sequential loading is used.

    Returns:
        ``(train_dataloader, valid_dataloader, test_dataloader)``.
    """
    # A DistributedSampler needs an initialised process group; without one,
    # fall back to sequential loading instead of raising.
    use_sampler = is_distributed and dist.is_available() and dist.is_initialized()

    def build_loader(file_name):
        dataset = DfParser(file_name)
        sampler = DistributedSampler(dataset, shuffle=False) if use_sampler else None
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            sampler=sampler,
        )

    train_dataloader = build_loader('norm_train.csv')
    valid_dataloader = build_loader('norm_valid.csv')
    test_dataloader = build_loader('norm_test.csv')
    return train_dataloader, valid_dataloader, test_dataloader

In [22]:
def train_worker(
    gpu,
    ngpus_per_node,
    config,  # batchsize, maxpadding, baselr, warmup, num_epochs, accum_iter, file prefix
    is_distributed=False,
):
    """Train the model on one GPU, optionally as one process of a DDP group.

    Args:
        gpu: index of the CUDA device used by this process (also its rank).
        ngpus_per_node: number of processes / GPUs taking part in training.
        config: dict read for ``batch_size``, ``base_lr``, ``warmup``,
            ``num_epochs``, ``accum_iter`` and ``file_prefix``.
        is_distributed: wrap the model in ``DistributedDataParallel`` when True.

    Side effects:
        Resumes from ``<file_prefix>latest.pt`` when that file exists, rewrites
        it after every epoch and writes ``<file_prefix>final.pt`` at the end
        (main process only). Checkpoints always hold the un-wrapped module's
        weights, so they load into a plain ``make_model`` instance whether or
        not training was distributed.
    """
    print(f"Train worker process using GPU: {gpu} for training", flush=True)
    torch.cuda.set_device(gpu)

    d_model = 512
    model = make_model(loc_num=189, var_num=3, past=14, pred=3, N=6, d_model=d_model, d_ff=2048, h=8, dropout=0.1)
    model.cuda(gpu)
    module = model
    is_main_process = True

    if is_distributed:
        dist.init_process_group(
            "nccl", init_method="env://", rank=gpu, world_size=ngpus_per_node
        )
        model = DDP(model, device_ids=[gpu])
        module = model.module
        is_main_process = gpu == 0

    train_dataloader, valid_dataloader, test_dataloader = create_dataloaders(
        gpu,
        batch_size=config["batch_size"] // ngpus_per_node,
        is_distributed=is_distributed,
    )
    del test_dataloader

    optimizer = torch.optim.Adam(
        model.parameters(), lr=config["base_lr"], betas=(0.9, 0.98), eps=1e-9
    )
    lr_scheduler = LambdaLR(
        optimizer=optimizer,
        lr_lambda=lambda step: rate(
            step, d_model, factor=1, warmup=config["warmup"]
        ),
    )
    train_state = TrainState()

    latest_path = f'{config["file_prefix"]}latest.pt'
    if exists(latest_path):
        print('Model checkpoint found')
        try:
            # Load to CPU first: the state dicts are then copied onto the
            # right GPU, avoiding a second full copy of the model on the GPU.
            checkpoint = torch.load(latest_path, map_location="cpu")
            prefix = "module."
            # Older checkpoints written from a DDP wrapper carry a "module."
            # prefix on every key; strip it so both formats load.
            model_state = {
                (key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in checkpoint['model_state_dict'].items()
            }
            module.load_state_dict(model_state)
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            # Free the checkpoint promptly; keeping it around ran out of memory.
            del checkpoint
            torch.cuda.empty_cache()
        except Exception as err:
            # Best effort: report the problem and continue with the current
            # weights instead of aborting the run.
            print(f'Error Loading: {err}')
    else:
        print('Model Checkpoint not found.')

    for epoch in range(config["num_epochs"]):
        if is_distributed:
            for loader in (train_dataloader, valid_dataloader):
                # Only distributed samplers provide set_epoch.
                if hasattr(loader.sampler, "set_epoch"):
                    loader.sampler.set_epoch(epoch)

        model.train()
        print(f"[GPU{gpu}] Epoch {epoch} Training ====", flush=True)
        # Recall that Batch takes (src, tgt); ProcessRaw builds it from a raw row.
        _, train_state = run_epoch(
            (ProcessRaw(b) for b in train_dataloader),
            model,
            nn.MSELoss(),
            optimizer,
            lr_scheduler,
            mode="train+log",
            accum_iter=config["accum_iter"],
            train_state=train_state,
        )

        GPUtil.showUtilization()
        if is_main_process:
            torch.save({
                'model_state_dict': module.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, latest_path)
        torch.cuda.empty_cache()

        print(f"[GPU{gpu}] Epoch {epoch} Validation ====", flush=True)
        with torch.no_grad():
            model.eval()
            # run_epoch returns (loss, train_state); only the loss is reported.
            sloss, _ = run_epoch(
                (ProcessRaw(b) for b in valid_dataloader),
                model,
                nn.MSELoss(),
                DummyOptimizer(),
                DummyScheduler(),
                mode="eval",
            )
            print(sloss)
        torch.cuda.empty_cache()

    if is_main_process:
        file_path = "%sfinal.pt" % config["file_prefix"]
        torch.save(module.state_dict(), file_path)

In [23]:
def train_distributed_model(config):
    """Launch one training process per visible CUDA device.

    Sets the rendezvous address/port environment variables, counts the
    available GPUs, and hands ``train_worker`` to ``mp.spawn``. Each spawned
    process is called as ``train_worker(rank, ngpus, config, True)`` (the
    rank is prepended by ``mp.spawn``).

    Args:
        config: run configuration, forwarded untouched to every worker.
    """
    # NOTE(review): the worker is imported from the `the_annotated_transformer`
    # module rather than taken from this notebook's namespace -- presumably so
    # spawned processes can resolve it by module path. Confirm that module
    # exposes a `train_worker` accepting (gpu, ngpus, config, is_distributed).
    from the_annotated_transformer import train_worker

    # Rendezvous endpoint for the process group (single machine).
    os.environ.update(MASTER_ADDR="localhost", MASTER_PORT="12356")

    world_size = torch.cuda.device_count()
    print(f"Number of GPUs detected: {world_size}")
    print("Spawning training processes ...")

    worker_args = (world_size, config, True)
    mp.spawn(train_worker, nprocs=world_size, args=worker_args)


    
# Run settings shared by the training and evaluation helpers in this notebook.
config = dict(
    batch_size=16,  # handed to create_dataloaders
    distributed=False,  # True -> train_model spawns one worker per GPU
    num_epochs=8,  # number of iterations of the epoch loop
    accum_iter=10,  # forwarded to run_epoch as accum_iter
    base_lr=1.0,  # NOTE(review): consumer not visible here; presumably the optimizer's base LR
    max_padding=72,  # NOTE(review): consumer not visible here; confirm it is still used
    warmup=3000,  # NOTE(review): consumer not visible here; presumably LR warm-up steps
    file_prefix="norm_covid_model_",  # prefix of the saved '<prefix>latest.pt' / '<prefix>final.pt' files
)

    
def train_model(config):
    """Train according to ``config``, on a single device or on every GPU.

    When ``config["distributed"]`` is falsy the worker is run directly in
    this process as ``train_worker(0, 1, config, False)``; otherwise the work
    is delegated to ``train_distributed_model``.

    Args:
        config: run configuration; must contain the key ``"distributed"``.
    """
    if not config["distributed"]:
        # Single-process path: GPU index 0, one GPU in total, not distributed.
        train_worker(0, 1, config, False)
        return

    train_distributed_model(config)


def load_trained_model(config):
    """Return the trained model, training it first if no final checkpoint exists.

    The checkpoint path is derived from ``config["file_prefix"]`` so that it
    matches the ``"<file_prefix>final.pt"`` file the training code writes when
    it finishes. When the key is absent the previous hard-coded prefix
    ``"norm_covid_model_"`` is used, which keeps older callers working.

    Args:
        config: run configuration. ``"file_prefix"`` selects the checkpoint;
            the whole dict is forwarded to ``train_model`` when training is
            needed.

    Returns:
        The model built by ``make_model`` with the saved weights loaded.
    """
    model_path = "%sfinal.pt" % config.get("file_prefix", "norm_covid_model_")
    if not exists(model_path):
        train_model(config)

    # NOTE(review): these hyper-parameters are hard-coded and must match the
    # ones used when the checkpoint was trained -- confirm against the trainer.
    model = make_model(loc_num=189, var_num=3, past=14, pred=3, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1)

    # map_location="cpu" lets a checkpoint saved from a GPU be restored on any
    # host; load_state_dict then copies the values into the model's own
    # parameters, so the returned model is unaffected.
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    return model

def _run_eval_pass(label, dataloader, model):
    """Run one gradient-free ``run_epoch`` pass over ``dataloader`` and print its result."""
    torch.cuda.empty_cache()
    print(f"{label} ====", flush=True)
    with torch.no_grad():
        model.eval()
        result = run_epoch(
            (ProcessRaw(b) for b in dataloader),
            model,
            nn.MSELoss(),
            DummyOptimizer(),
            DummyScheduler(),
            mode="eval",
        )
        print(result)
    return result


def eval_model(gpu, path, batch_size=None):
    """Evaluate a saved checkpoint on the validation and test dataloaders.

    Builds the model, moves it to ``gpu``, restores the weights stored under
    ``"model_state_dict"`` in the checkpoint at ``path``, then runs one
    ``run_epoch`` pass in ``"eval"`` mode with ``nn.MSELoss`` over the
    validation loader and one over the test loader, printing whatever
    ``run_epoch`` returns for each.

    Args:
        gpu: CUDA device index to evaluate on.
        path: checkpoint file containing a ``"model_state_dict"`` entry.
        batch_size: batch size for the dataloaders. Defaults to the
            module-level ``config["batch_size"]``, as before.

    Returns:
        None; results are printed.
    """
    torch.cuda.set_device(gpu)
    # NOTE(review): hard-coded hyper-parameters must match the checkpoint's
    # architecture -- confirm against the training configuration.
    model = make_model(loc_num=189, var_num=3, past=14, pred=3, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1)
    model.cuda(gpu)

    # Deserialize onto the CPU: the checkpoint also carries optimizer state,
    # and loading all of it onto the GPU next to the model only raises peak
    # GPU memory. load_state_dict copies the weights into the GPU parameters.
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    del checkpoint

    if batch_size is None:
        batch_size = config["batch_size"]
    train_dataloader, valid_dataloader, test_dataloader = create_dataloaders(
        gpu,
        batch_size=batch_size,
        is_distributed=False,
    )
    del train_dataloader  # not needed for evaluation

    _run_eval_pass("Validation", valid_dataloader, model)
    _run_eval_pass("Test", test_dataloader, model)


Validation ====
(tensor(7.8195e-07, device='cuda:0'), <__main__.TrainState object at 0x0000011FE85B5B88>)
Test ====
(tensor(7.8053e-07, device='cuda:0'), <__main__.TrainState object at 0x0000011FE85B5B88>)


In [None]:
"Run the below functions"
# Release cached GPU memory left over from earlier cells before evaluating.
torch.cuda.empty_cache()

# Uncomment to (re)train; the epoch loop saves '<file_prefix>latest.pt' after each epoch.
#train_model(config)
# Evaluate the per-epoch checkpoint on GPU index 0 (prints validation and test results).
eval_model(0, 'norm_covid_model_latest.pt')

In [24]:
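# Sanity check that the dataloaders yield batches as intended (disabled: the code is wrapped in a string).
# NOTE(review): the first loop below iterates `test`, not `train`, so the training loader is never
# inspected -- presumably a typo; confirm before re-enabling.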
'''
train, valid, test = create_dataloaders(0)
for b in test:
    print(b[:10])
    break

for b in valid:
    print(b[:10])
    break
    
for b in test:
    print(b[:10])
    break
    
'''

tensor([[0.0000e+00, 8.9785e-02, 8.9970e-02, 9.0113e-02, 9.0196e-02, 9.0449e-02,
         9.0587e-02, 9.0744e-02, 9.0901e-02, 9.1111e-02, 9.1297e-02, 9.1404e-02,
         9.1823e-02, 9.2005e-02, 9.2193e-02, 2.1644e+00, 2.1651e+00, 2.1658e+00,
         2.1666e+00, 2.1673e+00, 2.1680e+00, 2.1687e+00, 2.1694e+00, 2.1700e+00,
         2.1707e+00, 2.1713e+00, 2.1721e+00, 2.1729e+00, 2.1737e+00, 9.7885e-07,
         1.9842e-06, 5.8202e-07, 7.6721e-07, 8.2012e-07, 6.6139e-07, 8.9949e-07,
         1.1905e-06, 1.2434e-06, 7.4075e-07, 3.4392e-07, 1.9312e-06, 8.9949e-07,
         5.2911e-07, 1.1905e-06, 1.3492e-06, 8.4657e-07],
        [0.0000e+00, 1.0806e-01, 1.0806e-01, 1.0806e-01, 1.0873e-01, 1.0873e-01,
         1.0873e-01, 1.0873e-01, 1.0873e-01, 1.0873e-01, 1.0873e-01, 1.0941e-01,
         1.0941e-01, 1.0941e-01, 1.0941e-01, 2.3221e+00, 2.3229e+00, 2.3237e+00,
         2.3245e+00, 2.3252e+00, 2.3260e+00, 2.3267e+00, 2.3276e+00, 2.3287e+00,
         2.3300e+00, 2.3312e+00, 2.3325e+00, 2.3338

In [27]:
# Record the installed PyTorch version alongside the notebook's results.
print(torch.__version__)

1.12.1
