# Yuhong: pad input tensors with zeros if model has to take inputs with constant length

# Yuhong: technically we're going for a vision transformer with 1x273 (or 273x1) patches
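The zero-padding note above can be sketched as follows. This assumes each observation is a tensor of shape `(6, t, 256)` (6 snippets, `t` timesteps, 256 frequency channels, as in the model-parameter comments later on) and that the fixed length is `6 * 273`. `pad_to_length` is a hypothetical helper, not part of this notebook. Besides the padded tensor it returns a boolean mask marking the padded timesteps, which an attention layer could use to ignore them.

```python
import torch
import torch.nn.functional as F

MAX_LEN = 6 * 273  # assumed fixed sequence length

def pad_to_length(x: torch.Tensor, max_len: int = MAX_LEN):
    """Zero-pad x of shape (6, t, 256) along the time axis to (6, max_len, 256).

    Returns the padded tensor and a boolean mask of shape (max_len,)
    that is True at padded timesteps.
    """
    t = x.shape[1]
    if t > max_len:
        raise ValueError(f"sequence length {t} exceeds max_len {max_len}")
    # F.pad lists padding from the last dimension backwards:
    # (left, right) for the frequency axis, then (left, right) for the time axis
    padded = F.pad(x, (0, 0, 0, max_len - t))
    pad_mask = torch.arange(max_len) >= t
    return padded, pad_mask

x = torch.randn(6, 100, 256)
padded, pad_mask = pad_to_length(x)
print(padded.shape)         # torch.Size([6, 1638, 256])
print(int(pad_mask.sum()))  # 1538
```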

## Useful links: [https://www.machinecurve.com/index.php/2020/12/28/introduction-to-transformers-in-machine-learning](https://www.machinecurve.com/index.php/2020/12/28/introduction-to-transformers-in-machine-learning), [https://pytorch.org/tutorials/beginner/transformer_tutorial.html](https://pytorch.org/tutorials/beginner/transformer_tutorial.html)

In [None]:
%matplotlib inline

*   YouTube video explaining Transformers: [https://www.youtube.com/watch?v=TQQlZhbC5ps&list=TLPQMDYwNzIwMjFuBc39xf3IYg&index=9&ab_channel=CodeEmporium](https://www.youtube.com/watch?v=TQQlZhbC5ps&list=TLPQMDYwNzIwMjFuBc39xf3IYg&index=9&ab_channel=CodeEmporium)
*   Original Transformers paper: [https://arxiv.org/pdf/1706.03762.pdf](https://arxiv.org/pdf/1706.03762.pdf)

# Import dependencies

In [None]:
# Data & storage
import os
import glob
import hashlib
from google.colab import drive
from torch.utils.data import random_split, DataLoader
from torch.utils.data.distributed import DistributedSampler 


# Analysis
import numpy as np
import pandas as pd
from pandas import read_csv

# Visualizations
from matplotlib import pyplot as plt
from tqdm import tqdm

# Deep learning
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Distributed training (TPUs)
!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.data_parallel as dp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils
import warnings
warnings.filterwarnings("ignore")

# Miscellaneous
from typing import Optional, Union

Collecting torch-xla==1.9
  Downloading https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl (149.9 MB)
Collecting cloud-tpu-client==0.10
  Downloading cloud_tpu_client-0.10-py3-none-any.whl (7.4 kB)
Collecting google-api-python-client==1.8.0
  Downloading google_api_python_client-1.8.0-py3-none-any.whl (57 kB)
Installing collected packages: google-api-python-client, torch-xla, cloud-tpu-client
  Attempting uninstall: google-api-python-client
    Found existing installation: google-api-python-client 1.12.8
    Uninstalling google-api-python-client-1.12.8:
      Successfully uninstalled google-api-python-client-1.12.8
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
earthengine-api 0.1.272 requir



# Download Data



To download the Kaggle dataset, we must first mount our Google Drive to this Colab notebook.

In [None]:
drive.mount('/content/drive')

Mounted at /content/drive


Then, we point `KAGGLE_CONFIG_DIR` at the folder containing our Kaggle API token (the `kaggle.json` file), and change the current working directory to the data folder the dataset will be downloaded into.

In [None]:
os.environ['KAGGLE_CONFIG_DIR'] = '/content/drive/MyDrive/Research/Dynamic Spectra Sequence Modeling/Data/Kaggle'
%cd '/content/drive/MyDrive/Research/Ongoing/Dynamic Spectra Sequence Modeling/Data'

/content/drive/MyDrive/Research/Transformers/Code/Data
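A small pre-flight check can catch a missing or malformed token before the `kaggle` command runs. The sketch below assumes the standard `kaggle.json` layout (a JSON object with `username` and `key` fields); `kaggle_token_path` is a hypothetical helper, not part of the Kaggle CLI.

```python
import json
import os
import tempfile
from pathlib import Path
from typing import Optional

def kaggle_token_path(config_dir: Optional[str] = None) -> Path:
    """Return the path to kaggle.json, checking it exists and has the expected keys."""
    config_dir = config_dir or os.environ.get("KAGGLE_CONFIG_DIR")
    if not config_dir:
        raise EnvironmentError("KAGGLE_CONFIG_DIR is not set and no directory was given")
    token = Path(config_dir) / "kaggle.json"
    if not token.is_file():
        raise FileNotFoundError(f"No kaggle.json found in {config_dir}")
    # kaggle.json holds the account name and API key
    credentials = json.loads(token.read_text())
    missing = {"username", "key"} - set(credentials)
    if missing:
        raise ValueError(f"kaggle.json is missing {sorted(missing)}")
    return token

# Usage with a throwaway token file (in the notebook: kaggle_token_path() once KAGGLE_CONFIG_DIR is set)
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "kaggle.json").write_text(json.dumps({"username": "me", "key": "abc"}))
    found = kaggle_token_path(tmp)
print(found.name)  # kaggle.json
```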


Finally, we run the Kaggle API command for the SETI Breakthrough Listen (BL) competition to download the dataset, then extract the downloaded zip archive into the current directory.

In [None]:
from zipfile import ZipFile

if not os.listdir():
  # Note, if you're getting the error message "429 - Too Many Requests", try running the following commands before the API command:
  # !pip uninstall -y kaggle
  # !pip install --upgrade pip
  # !pip install kaggle==1.5.6
  !kaggle competitions download -c seti-breakthrough-listen

  file_to_extract = 'seti-breakthrough-listen.zip'
  # Read in zip file
  with ZipFile(file_to_extract, 'r') as zip_ref:
    # Add progress bar
    for file in tqdm(iterable=zip_ref.namelist(), total=len(zip_ref.namelist())):
      # Extract and store in current directory
      zip_ref.extract(member=file)

# Prep Data

We want to create lookup tables in the form of Python dictionaries, with ID-target key-value pairs, for both the training and test data.

Doing so for the training data is straightforward. The test labels, however, have been stored under hashed IDs for security purposes (`masked_labels.csv`), so for each original test ID from `sample_submission.csv` we first recompute its MD5 hash from a keyword and then look up the target stored under that hash.

In [None]:
train_labels = read_csv('train_labels.csv')
train_dict = dict(zip(train_labels.id, train_labels.target))

original_labels = read_csv('sample_submission.csv')['id']
hash_labels = read_csv('masked_labels.csv')
test_dict = {}
keyword = input('Enter keyword: ')
for labels in tqdm(original_labels):
  m = hashlib.md5(keyword.encode("utf-8"))
  m.update(bytes.fromhex("0" + labels))
  hashed_id = m.hexdigest()
  test_dict[labels] = hash_labels.loc[hash_labels['id'] == hashed_id, 'target'].item()

Enter keyword: zach


100%|██████████| 39995/39995 [01:51<00:00, 359.27it/s]
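The loop above scans `hash_labels` once per test ID (`.loc[... == hashed_id]`), which is why it takes almost two minutes. Below is a sketch of the same hashing scheme with a dictionary lookup instead, assuming `hashed_targets` would be built once as `dict(zip(hash_labels.id, hash_labels.target))`. `masked_id` and `build_test_dict` are hypothetical helpers, and the toy IDs and keyword are made up.

```python
import hashlib

def masked_id(original_id: str, keyword: str) -> str:
    """Hash an original test ID the same way as the cell above."""
    m = hashlib.md5(keyword.encode("utf-8"))
    # Keeps the cell's "0" prefix, which assumes each ID has an odd number
    # of hex digits (bytes.fromhex needs an even count)
    m.update(bytes.fromhex("0" + original_id))
    return m.hexdigest()

def build_test_dict(original_ids, hashed_targets: dict, keyword: str) -> dict:
    """Map each original ID to its target with one O(1) dict lookup per ID."""
    return {oid: hashed_targets[masked_id(oid, keyword)] for oid in original_ids}

# Toy data standing in for sample_submission.csv and masked_labels.csv
keyword = "example"
hashed_targets = {masked_id("abc", keyword): 1, masked_id("def", keyword): 0}
print(build_test_dict(["abc", "def"], hashed_targets, keyword))  # {'abc': 1, 'def': 0}
```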


Split the training set into two non-overlapping subsets (80/20) for hold-out validation. Note that `x_train` and `x_valid` hold the ID values, whereas `y_train` and `y_valid` hold the target values; the training subsets contain 48,000 samples and the validation subsets 12,000. Since our model is self-supervised, we'll use `y_train` and `y_valid` only to evaluate the learned representations on downstream tasks.

In [None]:
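from torch.utils.data import Subset

len_train = int(len(train_labels) * 0.8)
len_valid = len(train_labels) - len_train  # lengths must sum to the dataset size

# Shuffle the row indices once and reuse them, so IDs and targets stay aligned
train_idx, valid_idx = random_split(range(len(train_labels)), (len_train, len_valid))
x_train, x_valid = Subset(train_labels['id'], train_idx.indices), Subset(train_labels['id'], valid_idx.indices)
y_train, y_valid = Subset(train_labels['target'], train_idx.indices), Subset(train_labels['target'], valid_idx.indices)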

# Class Definitions

## Positional Encoding

In [None]:
def positional_encoding(length: int, d_model: int) -> torch.Tensor:
    """
    Generate the sinusoidal positional encoding described in the original paper.
    Parameters
    ----------
    length:
        Time window length, i.e. K.
    d_model:
        Dimension of the model vector.
    Returns
    -------
        Tensor of shape (K, d_model).
    """
    PE = torch.zeros((length, d_model))
    pos = torch.arange(length, dtype=torch.float32).unsqueeze(1)

    # Each pair of dimensions (2i, 2i+1) shares the frequency term 10000^(2i/d_model)
    div_term = torch.pow(10000, torch.arange(0, d_model, 2, dtype=torch.float32) / d_model)
    PE[:, 0::2] = torch.sin(pos / div_term)
    PE[:, 1::2] = torch.cos(pos / div_term[: d_model // 2])
    
    return PE
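For reference, the original paper defines PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). The NumPy sketch below evaluates that formula entry by entry, which makes it a slow but easy-to-read reference for checking a vectorized version such as `positional_encoding` (e.g. with `np.allclose`). `reference_pe` is a hypothetical name.

```python
import math
import numpy as np

def reference_pe(length: int, d_model: int) -> np.ndarray:
    """Sinusoidal positional encoding from "Attention Is All You Need", entry by entry."""
    pe = np.zeros((length, d_model))
    for pos in range(length):
        for dim in range(d_model):
            i = dim // 2  # dims 2i and 2i+1 share the same frequency
            angle = pos / 10000 ** (2 * i / d_model)
            pe[pos, dim] = math.sin(angle) if dim % 2 == 0 else math.cos(angle)
    return pe

pe = reference_pe(length=4, d_model=6)
print(pe[0])  # [0. 1. 0. 1. 0. 1.]
```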

## Masks

In [None]:
def generate_local_masks(chunk_size: int,
                         attention_size: int,
                         mask_future=False,
                         device: torch.device = 'cpu') -> torch.BoolTensor:
    """
    Compute attention mask as attention_size wide diagonal.
    Parameters
    ----------
    chunk_size:
        Time dimension size.
    attention_size:
        Number of backward elements to apply attention.
    device:
        torch device. Default is ``'cpu'``.
    Returns
    -------
        Mask as a boolean tensor.
    """
    local_map = np.empty((chunk_size, chunk_size))
    i, j = np.indices(local_map.shape)

    if mask_future:
        local_map[i, j] = (i - j > attention_size) ^ (j - i > 0)
    else:
        local_map[i, j] = np.abs(i - j) > attention_size

    return torch.BoolTensor(local_map).to(device)
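To see which positions the local mask blocks, here is a NumPy-only sketch of the same rule (`local_mask` is a hypothetical stand-in that returns an array instead of a tensor). A `1` marks a score that gets filled with `-inf`, i.e. a key the query may not attend to; with `mask_future=True` each query sees only itself and up to `attention_size` earlier positions.

```python
import numpy as np

def local_mask(chunk_size: int, attention_size: int, mask_future: bool = False) -> np.ndarray:
    i, j = np.indices((chunk_size, chunk_size))  # i: query position, j: key position
    if mask_future:
        return (i - j > attention_size) | (j > i)
    return np.abs(i - j) > attention_size

print(local_mask(5, 1).astype(int))
# [[0 0 1 1 1]
#  [0 0 0 1 1]
#  [1 0 0 0 1]
#  [1 1 0 0 0]
#  [1 1 1 0 0]]
print(local_mask(5, 1, mask_future=True).astype(int))
# [[0 1 1 1 1]
#  [0 0 1 1 1]
#  [1 0 0 1 1]
#  [1 1 0 0 1]
#  [1 1 1 0 0]]
```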

## Multi-Headed Attention

In [None]:
class MultiHeadAttention(nn.Module):
    """
    Multi Head Attention block from Attention is All You Need.
    Given 3 inputs of shape (batch_size, K, d_model), that will be used
    to compute query, keys and values, we output a self attention
    tensor of shape (batch_size, K, d_model).
    Parameters
    ----------
    d_model:
        Dimension of the input vector.
    q:
        Dimension of each head's query and key projections.
    v:
        Dimension of each head's value projection.
    h:
        Number of heads.
    attention_size:
        Number of backward elements to apply attention.
        Deactivated if ``None``. Default is ``None``.
    """
    def __init__(self,
                 d_model: int,
                 q: int,
                 v: int,
                 h: int,
                 attention_size: int = None):
        """Initialize the Multi Head Block."""
        super().__init__()

        self._h = h
        self._attention_size = attention_size

        # Query, keys and value matrices
        self._W_q = nn.Linear(d_model, q*self._h)
        self._W_k = nn.Linear(d_model, q*self._h)
        self._W_v = nn.Linear(d_model, v*self._h)

        # Output linear function
        self._W_o = nn.Linear(self._h*v, d_model)

        # Score placeholder
        self._scores = None

    def forward(self,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[str] = None) -> torch.Tensor:
        """
        Propagate forward the input through the MHB.
        We compute for each head the queries, keys and values matrices,
        followed by the Scaled Dot-Product. The result is concatenated 
        and returned with shape (batch_size, K, d_model).
        Parameters
        ----------
        query:
            Input tensor with shape (batch_size, K, d_model) used to compute queries.
        key:
            Input tensor with shape (batch_size, K, d_model) used to compute keys.
        value:
            Input tensor with shape (batch_size, K, d_model) used to compute values.
        mask:
            Mask to apply on scores before computing attention.
            One of ``'subsequent'``, None. Default is None.
        Returns
        -------
            Self attention tensor with shape (batch_size, K, d_model).
        """
        K = query.shape[1]

        # Compute Q, K and V, concatenate heads on batch dimension
        queries = torch.cat(self._W_q(query).chunk(self._h, dim=-1), dim=0)
        keys = torch.cat(self._W_k(key).chunk(self._h, dim=-1), dim=0)
        values = torch.cat(self._W_v(value).chunk(self._h, dim=-1), dim=0)

        # Scaled Dot Product: scale by sqrt(d_k), the per-head query/key size, as in the original paper
        self._scores = torch.bmm(queries, keys.transpose(1, 2)) / np.sqrt(queries.shape[-1])

        # Compute local map mask
        if self._attention_size is not None:
            attention_mask = generate_local_masks(K, self._attention_size, mask_future=False, device=self._scores.device)
            self._scores = self._scores.masked_fill(attention_mask, float('-inf'))

        # Compute future mask
        if mask == "subsequent":
            future_mask = torch.triu(torch.ones((K, K)), diagonal=1).bool()
            future_mask = future_mask.to(self._scores.device)
            self._scores = self._scores.masked_fill(future_mask, float('-inf'))

        # Apply softmax
        self._scores = F.softmax(self._scores, dim=-1)

        attention = torch.bmm(self._scores, values)

        # Concatenate the heads
        attention_heads = torch.cat(attention.chunk(self._h, dim=0), dim=-1)

        # Apply linear transformation W^O
        self_attention = self._W_o(attention_heads)

        return self_attention

    @property
    def attention_map(self) -> torch.Tensor:
        """
        Attention map after a forward propagation,
        variable `score` in the original paper.
        """
        if self._scores is None:
            raise RuntimeError(
                "Evaluate the model once to generate attention map")
        return self._scores
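The core of `MultiHeadAttention` is scaled dot-product attention, softmax(QKᵀ / √d_k) V, where d_k is the per-head query/key size (`q` in this notebook). Below is a minimal standalone sketch for a single head with an optional causal ("subsequent") mask; `attention` is a hypothetical helper, not this notebook's API.

```python
import math
import torch
import torch.nn.functional as F

def attention(queries, keys, values, causal: bool = False):
    """queries/keys: (batch, K, d_k); values: (batch, K, d_v). Returns (output, weights)."""
    d_k = queries.shape[-1]
    # Scale by sqrt(d_k), the query/key dimension, not by the sequence length
    scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d_k)
    if causal:
        K = scores.shape[-1]
        # True above the diagonal: query i may not look at keys j > i
        future = torch.triu(torch.ones(K, K, dtype=torch.bool, device=scores.device), diagonal=1)
        scores = scores.masked_fill(future, float("-inf"))
    weights = F.softmax(scores, dim=-1)
    return torch.bmm(weights, values), weights

torch.manual_seed(0)
q, k, v = torch.randn(2, 4, 8), torch.randn(2, 4, 8), torch.randn(2, 4, 8)
out, w = attention(q, k, v, causal=True)
print(out.shape)  # torch.Size([2, 4, 8])
```

With the causal mask, the first query can only attend to the first key, so its output is exactly the first value vector.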

## Feed-Forward Network

In [None]:
class PositionwiseFeedForward(nn.Module):
    """
    Position-wise Feed Forward Network block from Attention is All You Need.
    Apply two linear transformations, with a ReLU in between, to each position
    separately and identically. They are implemented with ``nn.Linear`` layers acting
    on the last dimension. Input and output have shape (batch_size, K, d_model).
    Parameters
    ----------
    d_model:
        Dimension of input tensor.
    d_ff:
        Dimension of hidden layer, default is 2048.
    """
    def __init__(self,
                 d_model: int,
                 d_ff: Optional[int] = 2048):
        """Initialize the PFF block."""
        super().__init__()

        self._linear1 = nn.Linear(d_model, d_ff)
        self._linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Propagate forward the input through the PFF block.
        Apply the first linear transformation, then a ReLU activation,
        and the second linear transformation.
        Parameters
        ----------
        x:
            Input tensor with shape (batch_size, K, d_model).
        Returns
        -------
            Output tensor with shape (batch_size, K, d_model).
        """
        return self._linear2(F.relu(self._linear1(x)))
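Because `nn.Linear` acts on the last dimension only, applying it to a `(batch_size, K, d_model)` tensor transforms every position independently with the same weights, which is what "position-wise" means. The sketch below checks that this matches a `Conv1d` with `kernel_size=1` (the alternative formulation mentioned in the original paper) once the two layers share weights; the sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, K, d_model, d_ff = 2, 5, 16, 32

linear = nn.Linear(d_model, d_ff)
conv = nn.Conv1d(d_model, d_ff, kernel_size=1)
# Conv1d weights have shape (out_channels, in_channels, kernel_size); reuse the Linear weights
with torch.no_grad():
    conv.weight.copy_(linear.weight.unsqueeze(-1))
    conv.bias.copy_(linear.bias)

x = torch.randn(batch_size, K, d_model)
out_linear = linear(x)                              # (batch, K, d_ff)
out_conv = conv(x.transpose(1, 2)).transpose(1, 2)  # Conv1d expects (batch, channels, K)
print(torch.allclose(out_linear, out_conv, atol=1e-5))  # True
```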

## Loss

[SET LOSS FUNCTION SUCH THAT IT CALCULATES LOSS BETWEEN PREDICTED AND ACTUAL SPECTRA FOR NEXT TIMESTEP]

In [None]:
class OZELoss(nn.Module):
    """Custom loss for TRNSys metamodel.
    Compute, for temperature and consumptions, the intergral of the squared differences
    over time. Sum the log with a coeficient ``alpha``.
    .. math::
        \Delta_T = \sqrt{\int (y_{est}^T - y^T)^2}
        \Delta_Q = \sqrt{\int (y_{est}^Q - y^Q)^2}
        loss = log(1 + \Delta_T) + \\alpha \cdot log(1 + \Delta_Q)
    Parameters:
    -----------
    alpha:
        Coefficient for consumption. Default is ``0.3``.
    """
    def __init__(self, reduction: str = 'mean', alpha: float = 0.3):
        super().__init__()

        self.alpha = alpha
        self.reduction = reduction

        self.base_loss = nn.MSELoss(reduction=self.reduction)

    def forward(self,
                y_true: torch.Tensor,
                y_pred: torch.Tensor) -> torch.Tensor:
        """Compute the loss between a target value and a prediction.
        Parameters
        ----------
        y_true:
            Target value.
        y_pred:
            Estimated value.
        Returns
        -------
        Loss as a tensor with gradient attached.
        """
        delta_Q = self.base_loss(y_pred[..., :-1], y_true[..., :-1])
        delta_T = self.base_loss(y_pred[..., -1], y_true[..., -1])

        if self.reduction == 'none':
            delta_Q = delta_Q.mean(dim=(1, 2))
            delta_T = delta_T.mean(dim=(1))

        return torch.log(1 + delta_T) + self.alpha * torch.log(1 + delta_Q)
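The placeholder above asks for a loss between the predicted and actual spectra at the next timestep. One possible sketch, assuming the model output at position t is meant to predict the input spectrum at position t + 1 and that both tensors have shape `(batch_size, K, n_channels)`: shift the targets by one step and take the mean squared error. `NextStepSpectraLoss` is a hypothetical name, and MSE is only one reasonable choice of base loss.

```python
import torch
import torch.nn as nn

class NextStepSpectraLoss(nn.Module):
    """MSE between predictions at steps 0..K-2 and the true spectra at steps 1..K-1."""
    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, spectra: torch.Tensor, predictions: torch.Tensor) -> torch.Tensor:
        # spectra, predictions: (batch_size, K, n_channels), (y_true, y_pred) order as in OZELoss.
        # The prediction at position t is compared with the observed spectrum at t + 1,
        # so the last prediction has no target and is dropped.
        return self.mse(predictions[:, :-1], spectra[:, 1:])

spectra = torch.arange(12, dtype=torch.float32).reshape(1, 4, 3)
perfect = torch.roll(spectra, shifts=-1, dims=1)  # prediction at t equals spectrum at t + 1
print(NextStepSpectraLoss()(spectra, perfect).item())  # 0.0
```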

## Decoder

In [None]:
class Decoder(nn.Module):
    """
    Decoder block from Attention is All You Need.
    Apply two Multi Head Attention blocks followed by a Position-wise Feed Forward block.
    Residual sum and normalization are applied at each step.
    Parameters
    ----------
    d_model: 
        Dimension of the input vector.
    q:
        Dimension of each head's query and key projections.
    v:
        Dimension of each head's value projection.
    h:
        Number of heads.
    attention_size:
        Number of backward elements to apply attention.
        Deactivated if ``None``. Default is ``None``.
    dropout:
        Dropout probability after each MHA or PFF block.
        Default is ``0.3``.
    chunk_mode:
        Switch between different MultiHeadAttention blocks.
        One of ``'chunk'``, ``'window'`` or ``None``. Default is ``'chunk'``.
    """
    def __init__(self,
                 d_model: int,
                 q: int,
                 v: int,
                 h: int,
                 attention_size: int = None,
                 dropout: float = 0.3,
                 chunk_mode: str = 'chunk'):
        """Initialize the Decoder block"""
        super().__init__()

        chunk_mode_modules = {
            'chunk': MultiHeadAttentionChunk,
            'window': MultiHeadAttentionWindow,
        }

        if chunk_mode in chunk_mode_modules.keys():
            MHA = chunk_mode_modules[chunk_mode]
        elif chunk_mode is None:
            MHA = MultiHeadAttention
        else:
            raise NameError(
                f'chunk_mode "{chunk_mode}" not understood. Must be one of {", ".join(chunk_mode_modules.keys())} or None.')

        self._selfAttention = MHA(d_model, q, v, h, attention_size=attention_size)
        self._encoderDecoderAttention = MHA(d_model, q, v, h, attention_size=attention_size)
        self._feedForward = PositionwiseFeedForward(d_model)

        self._layerNorm1 = nn.LayerNorm(d_model)
        self._layerNorm2 = nn.LayerNorm(d_model)
        self._layerNorm3 = nn.LayerNorm(d_model)

        self._dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
        """Propagate the input through the Decoder block.
        Apply the self attention block, add residual and normalize.
        Apply the encoder-decoder attention block, add residual and normalize.
        Apply the feed forward network, add residual and normalize.
        Parameters
        ----------
        x:
            Input tensor with shape (batch_size, K, d_model).
        memory:
            Memory tensor with shape (batch_size, K, d_model)
            from encoder output.
        Returns
        -------
        x:
            Output tensor with shape (batch_size, K, d_model).
        """
        # Self attention
        residual = x
        x = self._selfAttention(query=x, key=x, value=x, mask="subsequent")
        x = self._dropout(x)
        x = self._layerNorm1(x + residual)

        # Encoder-decoder attention
        residual = x
        x = self._encoderDecoderAttention(query=x, key=memory, value=memory)
        x = self._dropout(x)
        x = self._layerNorm2(x + residual)

        # Feed forward
        residual = x
        x = self._feedForward(x)
        x = self._dropout(x)
        x = self._layerNorm3(x + residual)

        return x
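`Transformer` below stacks `Encoder` layers, but no `Encoder` class is defined in this notebook. A minimal sketch, assuming the encoder should mirror the `Decoder` block without its encoder-decoder attention step: self-attention, then a position-wise feed-forward network, each followed by dropout, a residual sum and LayerNorm. To stay self-contained it swaps in PyTorch's built-in `nn.MultiheadAttention`, so `q`, `v`, `attention_size` and `chunk_mode` are accepted only for signature compatibility with `Transformer` and are ignored; using this notebook's `MultiHeadAttention` and `PositionwiseFeedForward` instead would honor them.

```python
from typing import Optional

import torch
import torch.nn as nn

class Encoder(nn.Module):
    """Hypothetical Encoder block: self-attention + feed-forward, post-norm residuals."""
    def __init__(self, d_model: int, q: int, v: int, h: int,
                 attention_size: Optional[int] = None, dropout: float = 0.3,
                 chunk_mode: Optional[str] = None, d_ff: int = 2048):
        super().__init__()
        # q, v, attention_size and chunk_mode are ignored by this sketch.
        # batch_first=True (PyTorch >= 1.9) keeps the (batch_size, K, d_model) layout.
        self._selfAttention = nn.MultiheadAttention(d_model, h, batch_first=True)
        self._feedForward = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self._layerNorm1 = nn.LayerNorm(d_model)
        self._layerNorm2 = nn.LayerNorm(d_model)
        self._dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Self attention without a causal mask: the encoder sees the whole sequence
        attn, _ = self._selfAttention(x, x, x, need_weights=False)
        x = self._layerNorm1(x + self._dropout(attn))
        # Feed forward
        x = self._layerNorm2(x + self._dropout(self._feedForward(x)))
        return x

enc = Encoder(d_model=32, q=8, v=8, h=4, dropout=0.1)
x = torch.randn(2, 10, 32)
print(enc(x).shape)  # torch.Size([2, 10, 32])
```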

## Transformer

In [None]:
class Transformer(nn.Module):
    """
    Transformer model from Attention is All You Need.
    A classic transformer model adapted for sequential data.
    Embedding has been replaced with a fully connected layer,
    the last layer softmax is now a sigmoid.
    Attributes
    ----------
    layers_encoding: :py:class:`list` of :class:`Encoder.Encoder`
        stack of Encoder layers.
    layers_decoding: :py:class:`list` of :class:`Decoder.Decoder`
        stack of Decoder layers.
    Parameters
    ----------
    d_input:
        Model input dimension.
    d_model:
        Dimension of the input vector.
    d_output:
        Model output dimension.
    q:
        Dimension of queries and keys.
    v:
        Dimension of values.
    h:
        Number of heads.
    N:
        Number of encoder and decoder layers to stack.
    attention_size:
        Number of backward elements to apply attention.
        Deactivated if ``None``. Default is ``None``.
    dropout:
        Dropout probability after each MHA or PFF block.
        Default is ``0.3``.
    chunk_mode:
        Switch between different MultiHeadAttention blocks.
        One of ``'chunk'``, ``'window'`` or ``None``. Default is ``'chunk'``.
    pe:
        Type of positional encoding to add.
        Must be one of ``'original'``, ``'regular'`` or ``None``. Default is ``None``.
    pe_period:
        If using the ``'regular'`` pe, then we can define the period. Default is ``24``.
    """
    def __init__(self,
                 d_input: int,
                 d_model: int,
                 d_output: int,
                 q: int,
                 v: int,
                 h: int,
                 N: int,
                 attention_size: int = None,
                 dropout: float = 0.3,
                 chunk_mode: str = 'chunk',
                 pe: str = None,
                 pe_period: int = 24):
        """Create transformer structure from Encoder and Decoder blocks."""
        super().__init__()

        self._d_model = d_model

        self.layers_encoding = nn.ModuleList([Encoder(d_model,
                                                      q,
                                                      v,
                                                      h,
                                                      attention_size=attention_size,
                                                      dropout=dropout,
                                                      chunk_mode=chunk_mode) for _ in range(N)])
        self.layers_decoding = nn.ModuleList([Decoder(d_model,
                                                      q,
                                                      v,
                                                      h,
                                                      attention_size=attention_size,
                                                      dropout=dropout,
                                                      chunk_mode=chunk_mode) for _ in range(N)])

        self._embedding = nn.Linear(d_input, d_model)
        self._linear = nn.Linear(d_model, d_output)

        pe_functions = {
            'original': generate_original_PE,
            'regular': generate_regular_PE,
        }

        if pe in pe_functions.keys():
            self._generate_PE = pe_functions[pe]
            self._pe_period = pe_period
        elif pe is None:
            self._generate_PE = None
        else:
            raise NameError(
                f'PE "{pe}" not understood. Must be one of {", ".join(pe_functions.keys())} or None.')

        self.name = 'transformer'

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Propagate input through transformer
        Forward input through an embedding module,
        the encoder then decoder stacks, and an output module.
        Parameters
        ----------
        x:
            :class:`torch.Tensor` of shape (batch_size, K, d_input).
        Returns
        -------
            Output tensor with shape (batch_size, K, d_output).
        """
        K = x.shape[1]

        # Embedding module
        encoding = self._embedding(x)

        # Add position encoding
        if self._generate_PE is not None:
            pe_params = {'period': self._pe_period} if self._pe_period else {}
            positional_encoding = self._generate_PE(K, self._d_model, **pe_params)
            positional_encoding = positional_encoding.to(encoding.device)
            encoding = encoding + positional_encoding

        # Encoding stack
        for layer in self.layers_encoding:
            encoding = layer(encoding)

        # Decoding stack: start from the encoder output, adding the position encoding
        # out of place so that `encoding` (the decoder memory) is left unchanged
        decoding = encoding
        if self._generate_PE is not None:
            decoding = decoding + positional_encoding

        for layer in self.layers_decoding:
            decoding = layer(decoding, encoding)

        # Output module
        output = self._linear(decoding)
        output = torch.sigmoid(output)
        return output

# Train

Initialize the random seed.

In [None]:
# Random Seed Initialize
import random

RANDOM_SEED = 11
def seed_everything(seed=RANDOM_SEED):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # benchmark mode may pick different algorithms between runs
seed_everything()

Set the training and model hyperparameters.

In [None]:
training_params = {
    'checkpoint_path': '../Code/Checkpoints/.',
    'num_cores': 8,
    'num_workers': 0,
    'epochs': 30,
    'batch_size': 128,
    'learning_rate': 1e-4
}

# Dimensions for data are (6, 273, 256), i.e. 6 snippets of 273 timesteps and 256 frequency channels.
# The full input has shape (6, t, 256) with 0 <= t <= 6*273; the embedding layer acts on the
# last (frequency) axis, so d_input and d_output are the number of channels per timestep.
model_params = {
    'd_model': 256, # Latent dim
    'd_input': 256, # Input features per timestep (frequency channels)
    'd_output': 256, # Output features per timestep (predicted spectrum)
    'q': 8, # Query size
    'v': 8, # Value size
    'h': 8, # Number of heads
    'N': 4, # Number of encoder and decoder layers to stack
    'attention_size': 12, # Attention window size
    'dropout': 0.2, # Dropout rate
    'pe': None,
    'chunk_mode': None
}

## Configuring Colab's Cloud TPUs



Colab provides a free Cloud TPU system (a remote CPU host + four TPU chips with two cores each). To gain access to a TPU on Colab, on the main menu, click Runtime > Change runtime type > set "TPU" as the hardware accelerator.

The PyTorch/XLA package lets PyTorch connect to Cloud TPUs (it's named PyTorch/XLA, not PyTorch/TPU, because XLA is the name of the TPU compiler) and makes TPU cores available as PyTorch devices, so PyTorch can create and manipulate tensors on TPUs.

In [None]:
assert os.environ.get('COLAB_TPU_ADDR'), 'Make sure to select TPU from Runtime > Change runtime type > Hardware accelerator'

`torch.utils.data.distributed.DistributedSampler` splits the dataset into `num_replicas` non-overlapping shards, one per TPU core (8 on Colab), so that each core trains on a different part of the data. Here `xm.xrt_world_size()` returns the number of devices taking part in the replication (i.e. the number of cores), and `xm.get_ordinal()` returns the replication ordinal of the current process, ranging from `0` to `xrt_world_size()-1`. These calls report all 8 cores only inside the per-core processes started by `xmp.spawn`; in the main notebook process the world size is 1.

In [None]:
train_sampler = DistributedSampler(
    x_train,
    num_replicas=xm.xrt_world_size(),
    rank=xm.get_ordinal(),
    shuffle=True)
     
valid_sampler = DistributedSampler(
    x_valid,
    num_replicas=xm.xrt_world_size(),
    rank=xm.get_ordinal(),
    shuffle=False)
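To see how `DistributedSampler` shards data, the sketch below builds one sampler per rank for a toy dataset of 10 items with `num_replicas=5`. Passing `num_replicas` and `rank` explicitly means no process group or TPU is needed. With `shuffle=True` every rank uses the same seeded permutation, so the shards are disjoint and together cover the dataset. If the dataset size were not divisible by `num_replicas`, the sampler (with its default `drop_last=False`) would repeat a few samples so every rank gets the same count.

```python
from torch.utils.data.distributed import DistributedSampler

dataset = list(range(10))  # toy stand-in for x_train
num_replicas = 5

# One sampler per rank; explicit num_replicas/rank avoids needing a process group
shards = [
    list(DistributedSampler(dataset, num_replicas=num_replicas, rank=rank, shuffle=True))
    for rank in range(num_replicas)
]
print([len(shard) for shard in shards])  # [2, 2, 2, 2, 2]
```

When training for several epochs, calling `train_sampler.set_epoch(epoch)` at the start of each epoch changes the permutation between epochs; without it, every epoch uses the same order.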

With the samplers in place, we create a standard `DataLoader` for each split; `torch_xla`'s `ParallelLoader` can then wrap these loaders to feed batches to each TPU core.

In [None]:
train_loader = DataLoader(
    x_train,
    batch_size=training_params['batch_size'],
    sampler=train_sampler,  # the sampler already shuffles, so `shuffle` must not also be set
    num_workers=training_params['num_workers'],
    drop_last=True)

valid_loader = DataLoader(
    x_valid,
    batch_size=training_params['batch_size'],
    sampler=valid_sampler,
    num_workers=training_params['num_workers'],
    drop_last=True)

# drop_last=True drops the last incomplete batch if the dataset size is not divisible by the batch size
# drop_last=False keeps it, so the last batch is smaller when the dataset size is not divisible by the batch size
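The effect of `drop_last` can be checked on a toy dataset of 10 items with a batch size of 4. Keeping every batch the same shape also avoids XLA compiling a separate graph for a smaller final batch, which is one reason `drop_last=True` is common when training on TPUs.

```python
from torch.utils.data import DataLoader

data = list(range(10))  # dataset size not divisible by the batch size

print(len(DataLoader(data, batch_size=4, drop_last=True)))   # 2: the last 2 samples are dropped
print(len(DataLoader(data, batch_size=4, drop_last=False)))  # 3: the last batch is smaller
print([len(batch) for batch in DataLoader(data, batch_size=4, drop_last=False)])  # [4, 4, 2]
```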

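Set up the model, optimizer, and loss function for distributed training on the TPU cores. The learning rate is multiplied by `xm.xrt_world_size()` (the number of TPU cores, 8 in our case): each optimizer step then covers `batch_size` samples on every core, i.e. an 8 times larger global batch, and scaling the learning rate linearly with the batch is a common heuristic for that.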

In [None]:
# Scale learning rate to world size
lr = training_params['learning_rate'] * xm.xrt_world_size()

# Get device, model, optimizer, and loss function
device = xm.xla_device()
model = Transformer(**model_params).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
loss_function = OZELoss(alpha=0.3)

## Loop
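The training loop itself is not written yet. Below is a minimal, device-agnostic sketch of one epoch, assuming each batch is an `(inputs, targets)` pair and the loss takes `(y_true, y_pred)` like `OZELoss`; note the loaders above currently yield only IDs, so a dataset that loads the spectra would be needed first. `train_one_epoch` and the toy model are hypothetical. On TPU this would run inside the per-core function passed to `xmp.spawn`, with the loader wrapped as `pl.ParallelLoader(train_loader, [device]).per_device_loader(device)` and `optimizer.step()` replaced by `xm.optimizer_step(optimizer)`, which also reduces gradients across cores.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def train_one_epoch(model, loader, optimizer, loss_function, device) -> float:
    """Run one pass over loader and return the mean batch loss."""
    model.train()
    total, n_batches = 0.0, 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        predictions = model(inputs)
        loss = loss_function(targets, predictions)  # (y_true, y_pred) order, as in OZELoss
        loss.backward()
        optimizer.step()  # on TPU: xm.optimizer_step(optimizer)
        total += loss.item()
        n_batches += 1
    return total / max(n_batches, 1)

# Toy usage on CPU: learn y = 2x with a single linear layer
torch.manual_seed(0)
xs = torch.randn(64, 1)
loader = DataLoader(TensorDataset(xs, 2 * xs), batch_size=16, shuffle=True)
model = nn.Linear(1, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
mse = nn.MSELoss()
losses = [train_one_epoch(model, loader, optimizer, mse, torch.device("cpu")) for _ in range(20)]
print(losses[-1] < losses[0])  # True
```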