## Notebooks referenced (kudos to the creators!)

##### <https://www.kaggle.com/ihelon/molecular-translation-exploratory-data-analysis>
##### <https://www.kaggle.com/yasufuminakama/inchi-resnet-lstm-with-attention-starter>

## Import packages

In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import cv2
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
import torchvision.models as models
from torch.nn.parameter import Parameter
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
import Levenshtein
from albumentations import (
    Compose, OneOf, Normalize, Resize, RandomResizedCrop, RandomCrop, HorizontalFlip, VerticalFlip, 
    RandomBrightness, RandomContrast, RandomBrightnessContrast, Rotate, ShiftScaleRotate, Cutout, 
    IAAAdditiveGaussianNoise, Transpose, Blur
    )
from albumentations.pytorch import ToTensorV2
from albumentations import ImageOnlyTransform
import re
!pip install timm
import timm
import warnings 
warnings.filterwarnings('ignore')
import time
import math
import pickle

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
!pip install GPUtil
import torch
from GPUtil import showUtilization as gpu_usage
from numba import cuda

def free_gpu_cache():
    """Report GPU utilization, release cached GPU memory, then report again.

    Takes no arguments and returns nothing. The function only has side
    effects: two utilization printouts (via ``GPUtil.showUtilization``,
    imported here as ``gpu_usage``) and a reset of GPU state in between.
    """
    print("Initial GPU Usage")
    gpu_usage()                             

    # Ask PyTorch to hand back the memory held by its caching allocator.
    torch.cuda.empty_cache()

    # Through numba: select device 0, close its CUDA context, select it again.
    # NOTE(review): closing the context presumably invalidates any tensors or
    # models already placed on that GPU - confirm this is only ever called
    # before anything has been moved to the device.
    cuda.select_device(0)
    cuda.close()
    cuda.select_device(0)

    print("GPU Usage after emptying the cache")
    gpu_usage()

free_gpu_cache() 

## Load files

In [None]:
df_train_labels = pd.read_csv('../input/bms-molecular-translation/train_labels.csv')
df_sample_submission = pd.read_csv('../input/bms-molecular-translation/sample_submission.csv')
df_extra_InChIs = pd.read_csv('../input/bms-molecular-translation/extra_approved_InChIs.csv')

In [None]:
def get_file_path(image_id, train_flag = True):
    """Return the PNG path of ``image_id`` inside the competition input folder.

    Images are sharded by the first three characters of their id, so
    ``abc123`` is stored as ``<split>/a/b/c/abc123.png``.

    :param image_id: image identifier, at least three characters long
    :param train_flag: truthy selects the ``train`` folder, falsy the ``test`` folder
    :return: relative path to the image file
    """
    if train_flag:
        template = "../input/bms-molecular-translation/train/{}/{}/{}/{}.png"
    else:
        template = "../input/bms-molecular-translation/test/{}/{}/{}/{}.png"
    first, second, third = image_id[0], image_id[1], image_id[2]
    return template.format(first, second, third, image_id)

In [None]:
df_train_labels = df_train_labels.reset_index(drop = True)
df_sample_submission = df_sample_submission.reset_index(drop = True)

In [None]:
df_train_labels['filename'] = df_train_labels['image_id'].apply(lambda x: get_file_path(x, train_flag = True))
df_sample_submission['filename'] = df_sample_submission['image_id'].apply(lambda x: get_file_path(x, train_flag = False))

In [None]:
print(df_train_labels.shape)
df_train_labels.head()

In [None]:
print(df_sample_submission.shape)
df_sample_submission.head()

In [None]:
def convert_image_id_2_path(image_id: str) -> str:
    """Return the path of a *training* image given its id.

    The first three characters of the id name the three nested shard
    directories that hold the file.
    """
    shard_dirs = (image_id[0], image_id[1], image_id[2])
    template = "../input/bms-molecular-translation/train/{}/{}/{}/{}.png"
    return template.format(*shard_dirs, image_id)

In [None]:
convert_image_id_2_path('000011a64c74')

In [None]:
def visualize_train_images(image_ids, labels):
    """Plot training images in a three-column grid, each titled with its label.

    :param image_ids: iterable of training image ids
    :param labels: iterable of label strings, aligned with ``image_ids``;
        only the first 30 characters of each label are shown
    :raises FileNotFoundError: if an image file cannot be read
    """
    samples = list(zip(image_ids, labels))
    n_cols = 3
    # Up to nine images keep the 3x3 layout (and 16x12 figure); beyond that the
    # grid grows by rows so the tenth image no longer makes plt.subplot raise.
    n_rows = max(3, -(-len(samples) // n_cols))
    plt.figure(figsize=(16, 4 * n_rows))

    for ind, (image_id, label) in enumerate(samples):
        plt.subplot(n_rows, n_cols, ind + 1)
        path = convert_image_id_2_path(image_id)
        image = cv2.imread(path)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        plt.imshow(image)
        plt.title(f"{label[:30]}", fontsize=10)
        plt.axis("off")

    plt.show()

In [None]:
sample_row = df_train_labels.sample(5)
visualize_train_images(
    sample_row['image_id'], sample_row["InChI"]
)

In [None]:
print(df_extra_InChIs.shape)
df_extra_InChIs.head()

## EDA & Data processing

In [None]:
img1 = Image.open("../input/bms-molecular-translation/train/0/0/0/000011a64c74.png")
img2 = Image.open("../input/bms-molecular-translation/train/0/0/0/000019cc0cd2.png")
print('Shapes:',np.asarray(img1).shape, np.asarray(img2).shape)
print('Palettes:',img1.palette, img2.palette)

In [None]:
h_shape = []
w_shape = []
sample_train_ids = df_train_labels['image_id'].sample(10000)
for image_id in tqdm(sample_train_ids):
    image = cv2.imread(convert_image_id_2_path(image_id))
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    h_shape.append(image.shape[0])
    w_shape.append(image.shape[1])

In [None]:
plt.figure(figsize=(16, 5))
plt.subplot(1, 3, 1)
plt.hist(np.array(h_shape) * np.array(w_shape), bins=50)
plt.xticks(rotation=45)
plt.title("Area Image Distribution", fontsize=14)
plt.subplot(1, 3, 2)
plt.hist(h_shape, bins=50)
plt.title("Height Image Distribution", fontsize=14)
plt.subplot(1, 3, 3)
plt.hist(w_shape, bins=50)
plt.title("Width Image Distribution", fontsize=14)

In [None]:
def split_form(form):
    """Tokenize the chemical-formula layer of an InChI string.

    The formula is cut at every capital letter; within each piece the
    leading non-digit run (the element symbol) is separated from whatever
    follows it (the atom count) by a space, e.g. ``C13H20OS`` becomes
    ``C 13 H 20 O S``.

    :param form: formula layer, e.g. ``"C13H20OS"``
    :return: space-separated tokens with trailing spaces removed
    """
    pieces = []
    for chunk in re.findall(r"[A-Z][^A-Z]*", form):
        symbol = re.match(r"\D+", chunk).group()
        count = chunk.replace(symbol, "")
        pieces.append(symbol if count == "" else f"{symbol} {str(count)}")
    return " ".join(pieces).rstrip(' ')

def split_form2(form):
    """Tokenize the InChI layers that follow the chemical formula.

    The input is cut at every lowercase letter (the layer prefix such as
    ``c`` or ``h``). Inside a layer, ``/`` is dropped, every digit run becomes
    one token and every other character that follows a digit run becomes
    its own token. Characters that appear before the first digit of a layer
    are not emitted.

    :param form: the layers after the formula joined by ``/``,
        e.g. ``"c1-2-3/h2H,1H3"``
    :return: token string such as ``"/c 1 - 2 - 3 /h 2 H , 1 H 3"``
    """
    layers = []
    for chunk in re.findall(r"[a-z][^a-z]*", form):
        prefix = chunk[0]
        body = chunk.replace(prefix, "").replace('/', "")
        tokens = []
        for group in re.findall(r"[0-9]+[^0-9]*", body):
            digit_runs = list(re.findall(r'\d+', group))
            assert len(digit_runs) == 1, f"len(num_list) != 1"
            digits = digit_runs[0]
            if group == digits:
                tokens.append(f"{digits} ")
            else:
                trailing = group.replace(digits, "")
                tokens.append(f"{digits} {' '.join(list(trailing))} ")
        layers.append(f"/{prefix} {''.join(tokens)}")
    return "".join(layers).rstrip(' ')

In [None]:
df_train_labels['InChI'].iloc[2], df_train_labels['InChI'].iloc[2].split('/')[1], split_form(df_train_labels['InChI'].iloc[2].split('/')[1]), split_form2('/'.join(df_train_labels['InChI'].iloc[2].split('/')[2:]))

In [None]:
class Tokenizer(object):
    """Space-delimited tokenizer that maps tokens to integer ids and back.

    The vocabulary is learned with :meth:`fit_on_texts`; the special tokens
    ``<sos>``, ``<eos>`` and ``<pad>`` are always added after the sorted
    regular tokens. Attribute names are kept stable because instances of
    this class are stored on disk and loaded back elsewhere in the notebook.
    """

    def __init__(self):
        self.stoi = {}  # token string -> integer id
        self.itos = {}  # integer id -> token string
        self.count = 0  # next id to assign

    def __len__(self):
        """Return the vocabulary size, special tokens included."""
        return len(self.stoi)

    def fit_on_texts(self, texts):
        """Add every space-separated token found in ``texts`` to the vocabulary.

        Tokens that are already known keep their id, so the method can be
        called more than once.

        :param texts: iterable of token strings separated by single spaces
        """
        vocab = set()
        for text in texts:
            vocab.update(text.split(' '))
        vocab = sorted(vocab)
        vocab.extend(['<sos>', '<eos>', '<pad>'])
        for token in vocab:
            # Direct dict membership is O(1); no need to materialise the keys.
            if token not in self.stoi:
                self.stoi[token] = self.count
                self.count += 1
        self.itos = {index: token for token, index in self.stoi.items()}

    def text_to_sequence(self, text):
        """Encode one text as ``[<sos>, token ids..., <eos>]``.

        :raises KeyError: if ``text`` contains a token outside the vocabulary
        """
        sequence = [self.stoi['<sos>']]
        sequence.extend(self.stoi[token] for token in text.split(' '))
        sequence.append(self.stoi['<eos>'])
        return sequence

    def texts_to_sequences(self, texts):
        """Encode every text in ``texts`` with :meth:`text_to_sequence`."""
        return [self.text_to_sequence(text) for text in texts]

    def sequence_to_text(self, sequence):
        """Decode all ids in ``sequence`` (special tokens included) and concatenate them."""
        return ''.join(self.itos[i] for i in sequence)

    def sequences_to_texts(self, sequences):
        """Decode every sequence in ``sequences`` with :meth:`sequence_to_text`."""
        return [self.sequence_to_text(sequence) for sequence in sequences]

    def predict_caption(self, sequence):
        """Decode ids up to (not including) the first ``<eos>`` or ``<pad>``."""
        eos_id = self.stoi['<eos>']
        pad_id = self.stoi['<pad>']
        tokens = []
        for i in sequence:
            if i == eos_id or i == pad_id:
                break
            tokens.append(self.itos[i])
        return ''.join(tokens)

    def predict_captions(self, sequences):
        """Decode every sequence in ``sequences`` with :meth:`predict_caption`."""
        return [self.predict_caption(sequence) for sequence in sequences]

In [None]:
# Build the tokenizable text for every label: the formula layer goes through
# split_form, all remaining layers go through split_form2, joined by a space.
df_train_labels['InChI_1'] = df_train_labels['InChI'].apply(lambda x: x.split('/')[1])
df_train_labels['InChI_text'] = df_train_labels['InChI_1'].apply(split_form) + ' ' + \
                        df_train_labels['InChI'].apply(lambda x: '/'.join(x.split('/')[2:])).apply(split_form2).values
# ====================================================
# create tokenizer
# ====================================================
# The stored tokenizer was produced once with the three lines below; they are
# kept as a recipe for regenerating it.
# tokenizer = Tokenizer()
# tokenizer.fit_on_texts(df_train_labels['InChI_text'].values)
# torch.save(tokenizer, 'tokenizer2.pth')

# NOTE: torch.load unpickles the object, so the Tokenizer class has to be
# defined in this session and the file must come from a trusted source.
tokenizer = torch.load('../input/bms-molecular-translation-tokenizer/tokenizer2.pth')
print('Loaded tokenizer')
# ====================================================
# preprocess df_train_labels.csv
# ====================================================
# Token count per label, excluding the <sos>/<eos> markers that
# text_to_sequence adds at both ends.
df_train_labels['InChI_length'] = [
    len(tokenizer.text_to_sequence(text)) - 2
    for text in tqdm(df_train_labels['InChI_text'].values, total=len(df_train_labels))
]

In [None]:
df_train_labels.head()

In [None]:
plt.hist(df_train_labels['InChI_length'], bins=50)
plt.title("InChI length Distribution (Train data)", fontsize=14)

In [None]:
df_train_labels['InChI_length'].max()

In [None]:
## All characters in InChI
# Accumulate into a set so each label is scanned once; the sorted list keeps
# the result a list (as before) with a deterministic order.
unique_characters = set()
for inchi in tqdm(df_train_labels['InChI'].values):
    unique_characters.update(inchi)
all_characters_in_InChI = sorted(unique_characters)

In [None]:
print(len(all_characters_in_InChI), 'characters in InChI')

In [None]:
def get_score(y_true, y_pred):
    """Return the mean Levenshtein distance between paired items.

    :param y_true: iterable of reference strings
    :param y_pred: iterable of predicted strings, paired with ``y_true`` by position
    :return: average edit distance over all pairs
    """
    distances = [
        Levenshtein.distance(truth, guess)
        for truth, guess in zip(y_true, y_pred)
    ]
    return np.mean(distances)

In [None]:
get_score('apple','paple')

In [None]:
def init_logger(log_file='train.log'):
    """Return this module's logger, writing bare messages to the console and ``log_file``.

    The function can be called repeatedly (for example when a notebook cell
    is re-run): handlers attached by an earlier call are removed and closed
    first, so each message is emitted once per destination instead of once
    per call.

    :param log_file: path of the log file to write to
    :return: the configured ``logging.Logger`` at level INFO
    """
    from logging import getLogger, INFO, FileHandler,  Formatter,  StreamHandler
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    # Iterate over a copy because removeHandler mutates logger.handlers.
    for stale in list(logger.handlers):
        logger.removeHandler(stale)
        stale.close()
    for handler in (StreamHandler(), FileHandler(filename=log_file)):
        handler.setFormatter(Formatter("%(message)s"))
        logger.addHandler(handler)
    return logger

LOGGER = init_logger()

In [None]:
class TrainDataset(Dataset):
    """Dataset yielding ``(image, token ids, sequence length)`` for training.

    :param df: DataFrame with a ``filename`` column (image paths) and an
        ``InChI_text`` column (space-separated label tokens)
    :param tokenizer: object whose ``text_to_sequence`` turns a label text
        into a list of integer ids
    :param transform: optional callable invoked as ``transform(image=...)``
        that returns a mapping with the result under ``'image'``
    """
    def __init__(self, df, tokenizer, transform=None):
        super().__init__()
        self.df = df
        self.tokenizer = tokenizer
        self.file_paths = df['filename'].values
        self.labels = df['InChI_text'].values
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        :return: ``(image, LongTensor of token ids, LongTensor([length]))``
        :raises FileNotFoundError: if the image file cannot be read
        """
        file_path = self.file_paths[idx]
        image = cv2.imread(file_path)
        if image is None:
            # cv2.imread returns None on failure; fail here with the path
            # instead of with an opaque error inside cvtColor.
            raise FileNotFoundError(f"Could not read image: {file_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        if self.transform:
            image = self.transform(image=image)['image']
        token_ids = self.tokenizer.text_to_sequence(self.labels[idx])
        # The length counts every id returned by the tokenizer.
        return image, torch.LongTensor(token_ids), torch.LongTensor([len(token_ids)])

In [None]:
class TestDataset(Dataset):
    """Dataset yielding images only, for inference.

    :param df: DataFrame with a ``filename`` column holding image paths
    :param transform: optional callable invoked as ``transform(image=...)``
        that returns a mapping with the result under ``'image'``
    """
    def __init__(self, df, transform=None):
        super().__init__()
        self.df = df
        self.file_paths = df['filename'].values
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Load image ``idx``.

        :return: the (optionally transformed) RGB image
        :raises FileNotFoundError: if the image file cannot be read
        """
        file_path = self.file_paths[idx]
        image = cv2.imread(file_path)
        if image is None:
            # cv2.imread returns None on failure; fail here with the path
            # instead of with an opaque error inside cvtColor.
            raise FileNotFoundError(f"Could not read image: {file_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        if self.transform:
            image = self.transform(image=image)['image']
        return image

In [None]:
def bms_collate(batch):
    """Collate ``(image, token ids, length)`` samples into batch tensors.

    Label sequences are right-padded to the longest one in the batch using
    the id of ``<pad>`` from the module-level ``tokenizer``.

    :param batch: list of samples, each indexable as
        ``(image tensor, label tensor, length tensor)``
    :return: ``(stacked images, padded labels, lengths reshaped to (-1, 1))``
    """
    images = [sample[0] for sample in batch]
    token_ids = [sample[1] for sample in batch]
    lengths = [sample[2] for sample in batch]
    padded = pad_sequence(token_ids, batch_first=True, padding_value=tokenizer.stoi["<pad>"])
    return torch.stack(images), padded, torch.stack(lengths).reshape(-1, 1)

In [None]:
# Normalization mean/std below follow the torchvision ResNet preprocessing: https://pytorch.org/hub/pytorch_vision_resnet/

def get_transforms(*, data):
    """Return the image preprocessing pipeline for a dataset split.

    ``'train'`` and ``'valid'`` currently share one pipeline: resize to
    500x500, normalize per channel, convert to a tensor. Any other value
    yields ``None``.

    :param data: split name (keyword-only)
    :return: an albumentations ``Compose`` or ``None``
    """
    if data == 'train' or data == 'valid':
        steps = [
            Resize(500, 500),
            Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
            ToTensorV2(),
        ]
        return Compose(steps)
    return None

In [None]:
df_train_labels.head()

In [None]:
train_dataset = TrainDataset(df_train_labels, tokenizer, transform=get_transforms(data='train'))

for i in range(1):
    image, label, label_length = train_dataset[i]
    text = tokenizer.sequence_to_text(label.numpy())
    plt.imshow(image.transpose(0, 1).transpose(1, 2))
    plt.title(f'label: {label}  text: {text}  label_length: {label_length}')
    plt.show() 

In [None]:
visualize_train_images(df_train_labels['image_id'].iloc[0:1], df_train_labels["InChI"].iloc[0:1])

In [None]:
list_InChI_versions = df_train_labels['InChI'].apply(lambda x: x.split('/')[0])

In [None]:
set(list_InChI_versions)

## Build Model

In [None]:
class Encoder(nn.Module):
    """CNN backbone that outputs a channels-last spatial feature map.

    The timm model's ``global_pool`` and ``fc`` heads are replaced by
    identity layers, so the backbone returns its feature map unpooled.

    :param model_name: timm model name; the model must expose an ``fc`` layer
    :param pretrained: forwarded to ``timm.create_model``
    """
    def __init__(self, model_name='resnet18', pretrained=False):
        super().__init__()
        backbone = timm.create_model(model_name, pretrained=pretrained)
        # Record the classifier's input width before the head is removed.
        self.n_features = backbone.fc.in_features
        backbone.global_pool = nn.Identity()
        backbone.fc = nn.Identity()
        self.cnn = backbone

    def forward(self, x):
        """Run the backbone and move axis 1 of its 4-D output to the last position."""
        return self.cnn(x).permute(0, 2, 3, 1)

In [None]:
class Attention(nn.Module):
    """
    Additive soft attention over the encoder's pixel features
    """
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        """
        :param encoder_dim: feature size of each encoded pixel
        :param decoder_dim: size of the decoder's hidden state
        :param attention_dim: size of the shared space both are projected into
        """
        super().__init__()
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)  # projects encoded pixels
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)  # projects the decoder state
        self.full_att = nn.Linear(attention_dim, 1)  # one score per pixel
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)  # normalizes scores across pixels

    def forward(self, encoder_out, decoder_hidden):
        """
        :param encoder_out: encoded image, (batch_size, num_pixels, encoder_dim)
        :param decoder_hidden: decoder hidden state, (batch_size, decoder_dim)
        :return: attention-weighted encoding (batch_size, encoder_dim) and
            the weights alpha (batch_size, num_pixels)
        """
        projected_pixels = self.encoder_att(encoder_out)  # (batch_size, num_pixels, attention_dim)
        projected_state = self.decoder_att(decoder_hidden).unsqueeze(1)  # (batch_size, 1, attention_dim)
        scores = self.full_att(self.relu(projected_pixels + projected_state)).squeeze(2)  # (batch_size, num_pixels)
        alpha = self.softmax(scores)
        context = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)  # (batch_size, encoder_dim)
        return context, alpha

In [None]:
class DecoderWithAttention(nn.Module):
    """
    LSTM decoder that attends over encoder features at every step.

    ``forward`` is the teacher-forced training path; ``predict`` decodes
    greedily for inference.
    """

    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, device, encoder_dim=512, dropout=0.5):
        """
        :param attention_dim: size of the attention network's hidden layer
        :param embed_dim: size of the token embeddings
        :param decoder_dim: size of the LSTM hidden and cell states
        :param vocab_size: number of tokens in the vocabulary
        :param device: device on which output buffers and start tokens are created
        :param encoder_dim: feature size of each encoded pixel
        :param dropout: dropout probability applied before the output layer
        """
        super().__init__()
        self.encoder_dim = encoder_dim
        self.attention_dim = attention_dim
        self.embed_dim = embed_dim
        self.decoder_dim = decoder_dim
        self.vocab_size = vocab_size
        self.device = device
        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)  # attention network
        self.embedding = nn.Embedding(vocab_size, embed_dim)  # embedding layer
        self.dropout = nn.Dropout(p=dropout)
        self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True)  # decoding LSTMCell
        self.init_h = nn.Linear(encoder_dim, decoder_dim)  # produces the initial hidden state
        self.init_c = nn.Linear(encoder_dim, decoder_dim)  # produces the initial cell state
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)  # produces the sigmoid gate on the context
        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(decoder_dim, vocab_size)  # scores over the vocabulary
        self.init_weights()  # initialize some layers with the uniform distribution

    def init_weights(self):
        """Initialize embedding and output weights uniformly in [-0.1, 0.1]; zero the output bias."""
        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)

    def load_pretrained_embeddings(self, embeddings):
        """Replace the embedding weights with the tensor ``embeddings``."""
        self.embedding.weight = nn.Parameter(embeddings)

    def fine_tune_embeddings(self, fine_tune=True):
        """Enable or disable gradient updates for the embedding layer."""
        for p in self.embedding.parameters():
            p.requires_grad = fine_tune

    def init_hidden_state(self, encoder_out):
        """
        Derive the initial LSTM state from the mean over pixels.

        :param encoder_out: (batch_size, num_pixels, encoder_dim)
        :return: ``(h, c)``, each (batch_size, decoder_dim)
        """
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)  # (batch_size, decoder_dim)
        c = self.init_c(mean_encoder_out)
        return h, c

    def forward(self, encoder_out, encoded_captions, caption_lengths):
        """
        Teacher-forced decoding for training.

        The batch is sorted by caption length (longest first) so that at
        step ``t`` only the leading rows that still have tokens are decoded.

        :param encoder_out: encoder features with ``encoder_dim`` as last axis
        :param encoded_captions: padded token ids, (batch_size, max_caption_length)
        :param caption_lengths: caption lengths, (batch_size, 1)
        :return: ``(predictions, sorted captions, decode lengths, alphas, sort indices)``
        """
        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        vocab_size = self.vocab_size
        encoder_out = encoder_out.view(batch_size, -1, encoder_dim)  # (batch_size, num_pixels, encoder_dim)
        num_pixels = encoder_out.size(1)
        caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True)
        encoder_out = encoder_out[sort_ind]
        encoded_captions = encoded_captions[sort_ind]
        embeddings = self.embedding(encoded_captions)  # (batch_size, max_caption_length, embed_dim)
        h, c = self.init_hidden_state(encoder_out)  # (batch_size, decoder_dim)
        # One step fewer than the caption length: the final token is never fed
        # back as an input, it is only a target.
        decode_lengths = (caption_lengths - 1).tolist()
        max_steps = max(decode_lengths)
        predictions = torch.zeros(batch_size, max_steps, vocab_size).to(self.device)
        alphas = torch.zeros(batch_size, max_steps, num_pixels).to(self.device)
        for t in range(max_steps):
            # Rows are sorted by length, so the still-active ones form a prefix.
            batch_size_t = sum(l > t for l in decode_lengths)
            attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t], h[:batch_size_t])
            gate = self.sigmoid(self.f_beta(h[:batch_size_t]))  # (batch_size_t, encoder_dim)
            attention_weighted_encoding = gate * attention_weighted_encoding
            h, c = self.decode_step(
                torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
                (h[:batch_size_t], c[:batch_size_t]))  # (batch_size_t, decoder_dim)
            preds = self.fc(self.dropout(h))  # (batch_size_t, vocab_size)
            predictions[:batch_size_t, t, :] = preds
            alphas[:batch_size_t, t, :] = alpha
        return predictions, encoded_captions, decode_lengths, alphas, sort_ind

    def predict(self, encoder_out, decode_lengths, tokenizer):
        """
        Greedy decoding for inference.

        Decoding runs for at most ``decode_lengths`` steps and stops early only
        once every sequence in the batch has produced ``<eos>``. Steps that
        were not run stay zero in the returned tensor.

        :param encoder_out: encoder features with ``encoder_dim`` as last axis
        :param decode_lengths: maximum number of decoding steps (int)
        :param tokenizer: object whose ``stoi`` maps ``<sos>`` and ``<eos>`` to ids
        :return: scores of shape (batch_size, decode_lengths, vocab_size)
        """
        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        encoder_out = encoder_out.view(batch_size, -1, encoder_dim)  # (batch_size, num_pixels, encoder_dim)
        eos_id = tokenizer.stoi["<eos>"]
        # every sequence starts from the <sos> token
        start_tokens = torch.ones(batch_size, dtype=torch.long).to(self.device) * tokenizer.stoi["<sos>"]
        embeddings = self.embedding(start_tokens)
        h, c = self.init_hidden_state(encoder_out)  # (batch_size, decoder_dim)
        predictions = torch.zeros(batch_size, decode_lengths, self.vocab_size).to(self.device)
        # Per-sequence flag. An argmax over the flattened batch would instead
        # stop the whole batch based on a single row.
        finished = torch.zeros(batch_size, dtype=torch.bool).to(self.device)
        for t in range(decode_lengths):
            attention_weighted_encoding, _ = self.attention(encoder_out, h)
            gate = self.sigmoid(self.f_beta(h))  # (batch_size, encoder_dim)
            attention_weighted_encoding = gate * attention_weighted_encoding
            h, c = self.decode_step(
                torch.cat([embeddings, attention_weighted_encoding], dim=1),
                (h, c))  # (batch_size, decoder_dim)
            preds = self.fc(self.dropout(h))  # (batch_size, vocab_size)
            predictions[:, t, :] = preds
            next_tokens = torch.argmax(preds, -1)
            finished |= (next_tokens == eos_id)
            if finished.all():
                break
            embeddings = self.embedding(next_tokens)
        return predictions

In [None]:
class AverageMeter(object):
    """Keep the latest value of a metric together with its weighted running mean."""
    def __init__(self):
        self.reset()

    def reset(self):
        """Set the latest value, mean, weighted sum and sample count to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed over ``n`` samples and refresh the mean."""
        self.count += n
        self.sum += val * n
        self.val = val
        self.avg = self.sum / self.count

def asMinutes(s):
    """Format a duration given in seconds as ``'<minutes>m <seconds>s'``."""
    minutes = math.floor(s / 60)
    seconds = s - minutes * 60
    return '%dm %ds' % (minutes, seconds)


def timeSince(since, percent):
    """Describe elapsed time and the estimated time remaining.

    :param since: start timestamp as returned by ``time.time()``
    :param percent: fraction of the work completed so far (must be non-zero)
    :return: string of the form ``'<elapsed> (remain <remaining>)'``
    """
    elapsed = time.time() - since
    # Extrapolate the total duration from the fraction completed.
    estimated_total = elapsed / (percent)
    remaining = estimated_total - elapsed
    return '%s (remain %s)' % (asMinutes(elapsed), asMinutes(remaining))

In [None]:
def train_fn(train_loader, encoder, decoder, criterion, 
             encoder_optimizer, decoder_optimizer, epoch,
             device, encoder_scheduler=None, decoder_scheduler=None):
    """Train the encoder/decoder pair for one pass over ``train_loader``.

    :param train_loader: yields ``(images, labels, label_lengths)`` batches
    :param encoder: image encoder module
    :param decoder: caption decoder; its call returns
        ``(predictions, sorted captions, decode lengths, alphas, sort indices)``
    :param criterion: loss applied to the packed predictions and targets
    :param encoder_optimizer: optimizer stepping the encoder parameters
    :param decoder_optimizer: optimizer stepping the decoder parameters
    :param epoch: zero-based epoch index, used only in the progress printout
    :param device: device the batches are moved to
    :param encoder_scheduler: accepted but not used inside this function
    :param decoder_scheduler: accepted but not used inside this function
    :return: batch-size-weighted average loss over the pass
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    # switch to train mode
    encoder.train()
    decoder.train()
    start = end = time.time()
    # Counts optimizer steps; it is incremented below but not read in this function.
    global_step = 0
    for step, (images, labels, label_lengths) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        images = images.to(device)
        labels = labels.to(device)
        label_lengths = label_lengths.to(device)
        batch_size = images.size(0)
        features = encoder(images)
        predictions, caps_sorted, decode_lengths, alphas, sort_ind = decoder(features, labels, label_lengths)
        # Targets are the captions shifted by one: the first token is only ever an input.
        targets = caps_sorted[:, 1:]
        # Packing with the per-sample decode lengths flattens both tensors and
        # drops the positions beyond each sample's length.
        predictions = pack_padded_sequence(predictions, decode_lengths, batch_first=True).data
        targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data
        loss = criterion(predictions, targets)
        # record loss
        losses.update(loss.item(), batch_size)
        loss.backward()
        # Clip gradients to a max norm of 5; the returned norms are only used in the printout.
        encoder_grad_norm = torch.nn.utils.clip_grad_norm_(encoder.parameters(), 5)
        decoder_grad_norm = torch.nn.utils.clip_grad_norm_(decoder.parameters(), 5)
        encoder_optimizer.step()
        decoder_optimizer.step()
        # Gradients are cleared after the step, so each batch starts from zero.
        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()
        global_step += 1
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        # Progress report every 1000 steps and on the last step.
        # batch_time is passed to format() but the template does not reference it.
        if step % 1000 == 0 or step == (len(train_loader)-1):
            print('Epoch: [{0}][{1}/{2}] '
                  'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                  'Elapsed {remain:s} '
                  'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                  'Encoder Grad: {encoder_grad_norm:.4f}  '
                  'Decoder Grad: {decoder_grad_norm:.4f}  '
                  .format(
                   epoch+1, step, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses,
                   remain=timeSince(start, float(step+1)/len(train_loader)),
                   encoder_grad_norm=encoder_grad_norm,
                   decoder_grad_norm=decoder_grad_norm
                   ))
    return losses.avg

In [None]:
def valid_fn(valid_loader, encoder, decoder, tokenizer, criterion, device, max_length=275):
    """Greedily decode a caption for every image produced by ``valid_loader``.

    :param valid_loader: yields batches of images (no labels)
    :param encoder: image encoder module
    :param decoder: decoder exposing ``predict(features, max_length, tokenizer)``
    :param tokenizer: tokenizer exposing ``predict_captions``
    :param criterion: unused; kept so existing call sites keep working
    :param device: device the images are moved to
    :param max_length: maximum number of decoding steps per image
        (defaults to the previously hard-coded 275)
    :return: numpy array with one decoded string per image, in loader order
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    # switch to evaluation mode
    encoder.eval()
    decoder.eval()
    text_preds = []
    start = end = time.time()
    for step, (images) in enumerate(valid_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        images = images.to(device)
        with torch.no_grad():
            features = encoder(images)
            predictions = decoder.predict(features, max_length, tokenizer)
        # Most likely token at every step, then cut each row at <eos>/<pad>.
        predicted_sequence = torch.argmax(predictions.detach().cpu(), -1).numpy()
        _text_preds = tokenizer.predict_captions(predicted_sequence)
        text_preds.append(_text_preds)
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        # Progress report every 1000 steps and on the last step.
        if step % 1000 == 0 or step == (len(valid_loader)-1):
            print('EVAL: [{0}/{1}] '
                  'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                  'Elapsed {remain:s} '
                  .format(
                   step, len(valid_loader), batch_time=batch_time,
                   data_time=data_time,
                   remain=timeSince(start, float(step+1)/len(valid_loader)),
                   ))
    if not text_preds:
        # np.concatenate rejects an empty list; an empty loader gives an empty result.
        return np.array([], dtype=str)
    text_preds = np.concatenate(text_preds)
    return text_preds

In [None]:
# Train on a random subset of 100000 labelled rows; no random_state is passed,
# so the subset differs from run to run.
train_dataset = TrainDataset(df_train_labels.sample(100000), tokenizer, transform=get_transforms(data='train'))
# The "valid" dataset is built from the sample-submission rows (test-folder
# images) and therefore carries no labels.
valid_dataset = TestDataset(df_sample_submission, transform=get_transforms(data='valid'))

# Training batches are shuffled, collated with bms_collate (which pads the
# label sequences) and the last incomplete batch is dropped.
train_loader = DataLoader(train_dataset, 
                          batch_size=64, 
                          shuffle=True, 
                          num_workers=4, 
                          pin_memory=True,
                          drop_last=True, 
                          collate_fn=bms_collate)
# Inference keeps the original row order and every sample.
valid_loader = DataLoader(valid_dataset, 
                          batch_size=64, 
                          shuffle=False, 
                          num_workers=4,
                          pin_memory=True, 
                          drop_last=False)

In [None]:
# Encoder: pretrained resnet18 backbone with its own Adam optimizer.
encoder = Encoder(model_name='resnet18', pretrained=True)
encoder.to(device)
encoder_optimizer = Adam(encoder.parameters(), lr=1e-4, weight_decay=1e-6, amsgrad=False)

# Decoder: encoder_dim is not passed, so the class default of 512 is used.
# NOTE(review): presumably this matches encoder.n_features for resnet18 - confirm
# before switching to another backbone.
decoder = DecoderWithAttention(attention_dim=256,
                               embed_dim=256,
                               decoder_dim=512,
                               vocab_size=len(tokenizer),
                               dropout=0.5,
                               device=device)
decoder.to(device)
decoder_optimizer = Adam(decoder.parameters(), lr=1e-4, weight_decay=1e-6, amsgrad=False)
# Positions whose target is <pad> do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.stoi["<pad>"])

In [None]:
# Single-epoch run: train, decode the unlabeled images, then pickle the
# predicted InChI strings.
for epoch in range(1):
    start_time = time.time()

    # train
    avg_loss = train_fn(train_loader, encoder, decoder, criterion, 
                        encoder_optimizer, decoder_optimizer, epoch, device)

    # eval
    text_preds = valid_fn(valid_loader, encoder, decoder, tokenizer, criterion, device)
    # Decoded text lacks the leading "InChI=1S/" prefix, so it is added back here.
    text_preds = [f"InChI=1S/{text}" for text in text_preds]
#     LOGGER.info(f"labels: {valid_labels[:5]}")
#     LOGGER.info(f"preds: {text_preds[:5]}")

    # scoring
    # Scoring and logging are disabled: valid_labels is not defined in this
    # notebook because the images being decoded have no labels.
#     score = get_score(valid_labels, text_preds)
    elapsed = time.time() - start_time

#     LOGGER.info(f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  time: {elapsed:.0f}s')
#     LOGGER.info(f'Epoch {epoch+1} - Score: {score:.4f}')
    
    # Persist the predictions; the file is overwritten on every epoch.
    with open('InChI_output.pkl', 'wb') as f:
        pickle.dump(text_preds, f)