In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

# import numpy as np # linear algebra
# import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import warnings
# Silence every warning for the rest of the kernel session.
# NOTE(review): this is a blanket filter — it also hides deprecation warnings
# raised by the libraries imported in later cells; consider narrowing it
# (e.g. by category or module) so real problems stay visible.
warnings.filterwarnings("ignore")

In [None]:
#import matplotlib.pyplot as plt
import os, random, math
import pandas as pd
import numpy as np
from typing import Any, Callable, cast, Dict, List, Optional, Tuple

# !conda install -c conda-forge rdkit
# from rdkit import Chem

import seaborn as sns
from matplotlib import pyplot as plt

import PIL
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import torch.optim as optim
from torch.utils.data import DataLoader,Dataset

import torchvision
from torchvision import transforms
import torchvision.models as models
import torchvision.transforms as T
from torchvision.transforms.transforms import Compose, Normalize, Resize, ToTensor, RandomHorizontalFlip, RandomCrop

In [None]:
class EncoderCNN(nn.Module):
    """Convolutional feature extractor built from a randomly initialised ResNet-34.

    The backbone's 3-channel stem convolution is replaced by a 1-channel one
    with the same geometry (7x7 kernel, stride 2, padding 3, no bias), so the
    encoder expects single-channel images.

    The last three child modules of the torchvision ResNet (``layer4``,
    ``avgpool`` and ``fc``) are discarded, so ``forward`` returns a spatial
    feature map rather than logits. For ResNet-34 that map is ``layer3``'s
    output, presumably (batch, 256, H/16, W/16); confirm this if the backbone
    is changed.
    """

    def __init__(self):
        super().__init__()
        backbone = models.resnet34(pretrained= False)
        # Single-channel stem in place of the RGB one.
        backbone.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        trunk = list(backbone.children())
        del trunk[-3:]  # drop layer4, avgpool and fc
        # The attribute name ``resnet`` is the state_dict key prefix, so it must not change.
        self.resnet = nn.Sequential(*trunk)

    def forward(self, images):
        """Run ``images`` through the truncated ResNet trunk and return its feature map."""
        return self.resnet(images)

# https://github.com/wzlxjtu/PositionalEncoding2D/blob/master/positionalembedding2d.py
class PositionEncode2D(nn.Module):
    """Fixed 2-D sinusoidal positional encoding added to a feature map.

    Channel layout of the stored ``pos`` buffer (shape ``(1, dim, height, width)``):
    - The first ``dim // 2`` channels encode the column (x) index.
    - The last ``dim // 2`` channels encode the row (y) index.
    - Within each half, even channels hold ``sin`` and odd channels hold ``cos``
      of the index times a geometric series of inverse frequencies.

    Args:
        dim: channel count of the feature maps this is added to; must be divisible by 4.
        width: largest feature-map width supported (``forward`` crops the table).
        height: largest feature-map height supported.
    """

    def __init__(self, dim, width, height):
        super().__init__()
        assert (dim % 4 == 0)
        self.width = width
        self.height = height

        half = dim // 2
        # Inverse frequencies shared by both axes; there are half // 2 of them.
        inv_freq = torch.exp(torch.arange(0., half, 2) * -(math.log(10000.0) / half))
        phase_x = torch.arange(0., width).unsqueeze(1) * inv_freq    # (width, half // 2)
        phase_y = torch.arange(0., height).unsqueeze(1) * inv_freq   # (height, half // 2)

        # Broadcast each phase table to a (half // 2, height, width) grid:
        # the x terms vary along columns only, the y terms along rows only.
        sin_x = torch.sin(phase_x).t().unsqueeze(1).expand(-1, height, -1)
        cos_x = torch.cos(phase_x).t().unsqueeze(1).expand(-1, height, -1)
        sin_y = torch.sin(phase_y).t().unsqueeze(2).expand(-1, -1, width)
        cos_y = torch.cos(phase_y).t().unsqueeze(2).expand(-1, -1, width)

        table = torch.zeros(1, half * 2, height, width)
        table[0, 0:half:2] = sin_x
        table[0, 1:half:2] = cos_x
        table[0, half::2] = sin_y
        table[0, half + 1::2] = cos_y
        self.register_buffer('pos', table)

    def forward(self, x):
        """Return ``x`` plus the encoding cropped to x's spatial size.

        ``x`` must be 4-D (batch, channels, H, W); channels must equal ``dim``,
        and H and W must not exceed ``height`` and ``width``.
        """
        _, _, rows, cols = x.shape
        return x + self.pos[:, :, :rows, :cols]

# https://pytorch.org/tutorials/beginner/transformer_tutorial.html
class TransformerModel(nn.Module):
    """Transformer encoder over the flattened cells of a 2-D feature map.

    A fixed 2-D sinusoidal positional encoding is added to the input map. The
    spatial grid is then flattened into a sequence and passed through a stack
    of ``nn.TransformerEncoderLayer`` blocks.

    Args:
        hidden: channel count of the input feature map and model width
            (``d_model``). Must be divisible by 4 (positional encoding) and by
            ``hidden // 64`` (head count), so at least 64.
        nhead: accepted for backward compatibility but IGNORED — the head count
            is always ``hidden // 64`` (see the note in ``__init__``).
        nlayers: number of stacked encoder layers.
        dropout: dropout rate used inside each encoder layer.
        max_width: largest feature-map width covered by the positional encoding.
        max_height: largest feature-map height covered by the positional encoding.
    """

    def __init__(self, hidden = 768,nhead = 4, nlayers = 3, dropout=0.1,
                 max_width = 40, max_height = 24):
        super(TransformerModel, self).__init__()
        # NOTE(review): the ``nhead`` argument is overridden, so callers cannot
        # choose the head count. This is kept as-is because honouring it would
        # change the computation of existing models and checkpoints.
        nhead = hidden // 64
        self.positionencode2D = PositionEncode2D(dim = hidden, width = max_width, height = max_height)
        # ``dropout`` used to be accepted but never forwarded. Its default (0.1)
        # matches the layer's own default, so default behavior is unchanged.
        self.encoder_layer = nn.TransformerEncoderLayer(d_model = hidden, nhead = nhead, dropout = dropout)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers = nlayers)

        # Not used by forward(). Kept (in the same registration order) so that
        # existing state_dict keys still load and parameter init is unchanged.
        self.fc_out = nn.Linear(hidden, 1)
        self.relu = nn.ReLU()
        self.src_mask = None

    def forward(self, src):
        """Encode a feature map into a sequence of per-cell vectors.

        Args:
            src: tensor of shape (batch, hidden, H, W) with H <= max_height and
                W <= max_width.

        Returns:
            Tensor of shape (batch, H * W, hidden).
        """
        src = self.positionencode2D(src)                  # (batch, hidden, H, W)
        src = src.view(src.size(0), src.size(1), -1)      # (batch, hidden, H*W)
        # The encoder layers use the default sequence-first layout (seq, batch, feature).
        src = src.permute(2, 0, 1).contiguous()           # (H*W, batch, hidden)

        output = self.transformer_encoder(src)            # (H*W, batch, hidden)
        output = output.permute(1, 0, 2).contiguous()     # (batch, H*W, hidden)
        return output
    

class Attention(nn.Module):
    """Additive (Bahdanau-style) attention over a set of encoder features.

    Each encoder position is scored against the decoder hidden state as
    ``A(tanh(U(features) + W(hidden_state)))``. The scores are normalised with
    a softmax over positions. Both the weights and the weighted sum of the
    features are returned.

    Args:
        encoder_dim: size of the last dimension of ``features``.
        decoder_dim: size of the last dimension of ``hidden_state``.
        attention_dim: size of the shared projection space.
    """

    def __init__(self, encoder_dim,decoder_dim,attention_dim):
        super().__init__()
        self.attention_dim = attention_dim
        # Creation order (W, U, A) is preserved so parameter init and state_dict order match.
        self.W = nn.Linear(decoder_dim, attention_dim)   # projects the decoder state
        self.U = nn.Linear(encoder_dim, attention_dim)   # projects every encoder position
        self.A = nn.Linear(attention_dim, 1)             # one scalar score per position

    def forward(self, features, hidden_state):
        """Attend over ``features`` conditioned on ``hidden_state``.

        Args:
            features: (batch, num_pixels, encoder_dim).
            hidden_state: (batch, decoder_dim).

        Returns:
            Tuple ``(alpha, context)``:
            - ``alpha``: (batch, num_pixels) weights that sum to 1 over pixels.
            - ``context``: (batch, encoder_dim) alpha-weighted sum of ``features``.
        """
        encoded = self.U(features)                               # (batch, num_pixels, attention_dim)
        projected_state = self.W(hidden_state).unsqueeze(1)      # (batch, 1, attention_dim)
        energy = torch.tanh(encoded + projected_state)           # (batch, num_pixels, attention_dim)
        scores = self.A(energy).squeeze(2)                       # (batch, num_pixels)
        alpha = F.softmax(scores, dim=1)
        # All channels of a pixel share that pixel's weight; then sum over pixels.
        context = (features * alpha.unsqueeze(2)).sum(dim=1)     # (batch, encoder_dim)
        return alpha, context