In [None]:
import numpy as np 
import pandas as pd
import cv2
import os
import re

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

from tqdm.auto import tqdm
# Register tqdm's pandas integration so Series.progress_apply (used in later cells) is available.
tqdm.pandas()

In [None]:
# Training hyper-parameters.
BATCH = 16
EPOCHS = 2

LR = 0.0001
IM_SIZE = 128  # images are resized to IM_SIZE x IM_SIZE in the Dataset

# Fix RNG seeds so weight initialisation and DataLoader shuffling are
# reproducible across "Restart Kernel -> Run All".
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Dataset locations.
TRAIN_LABELS_PATH = '../input/bms-molecular-translation/train_labels.csv'
TRAIN_DIR = '../input/bms-molecular-translation/train/'
TEST_DIR = '../input/bms-molecular-translation/test/'

In [None]:
# Load the training labels from the configured path instead of repeating the literal.
train_labels = pd.read_csv(TRAIN_LABELS_PATH)
train_labels.head()

In [None]:
def get_train_file_path(image_id, base_dir=None):
    """Return the PNG path of a training image.

    Images are nested by the first three characters of ``image_id``:
    ``<base_dir>/<id[0]>/<id[1]>/<id[2]>/<id>.png``.

    Args:
        image_id: Image identifier; must be at least three characters long.
        base_dir: Root directory of the training images. Defaults to
            ``TRAIN_DIR``. A trailing '/' is optional.

    Returns:
        The image path as a string.
    """
    if base_dir is None:
        base_dir = TRAIN_DIR
    # Normalise the trailing slash so callers may pass the root either way.
    root = base_dir.rstrip('/')
    return "{}/{}/{}/{}/{}.png".format(
        root, image_id[0], image_id[1], image_id[2], image_id
    )

def get_test_file_path(image_id, base_dir=None):
    """Return the PNG path of a test image.

    Images are nested by the first three characters of ``image_id``:
    ``<base_dir>/<id[0]>/<id[1]>/<id[2]>/<id>.png``.

    Args:
        image_id: Image identifier; must be at least three characters long.
        base_dir: Root directory of the test images. Defaults to
            ``TEST_DIR``. A trailing '/' is optional.

    Returns:
        The image path as a string.
    """
    if base_dir is None:
        base_dir = TEST_DIR
    # Normalise the trailing slash so callers may pass the root either way.
    root = base_dir.rstrip('/')
    return "{}/{}/{}/{}/{}.png".format(
        root, image_id[0], image_id[1], image_id[2], image_id
    )

# Map each image_id to the path of its PNG under the training directory.
train_labels['file_path'] = train_labels.image_id.progress_apply(get_train_file_path)

print(f"train.shape: {train_labels.shape}")
train_labels.head()

In [None]:
# The chemical formula is the second '/'-delimited field of each InChI string.
train_labels['formula'] = train_labels.InChI.progress_apply(
    lambda inchi: inchi.split('/')[1]
)
train_labels.head()

In [None]:
# Element symbols used as targets; the model has one output unit per symbol.
CL = [
    "N", "Br", "I", "S", "Cl", "H",
    "C", "P", "O", "Si", "F", "B",
]
NUM_CL = len(CL)
NUM_CL

In [None]:
# Split a digit-free formula string into element symbols.
def str_to_list(st):
    """Split a letters-only formula string into a list of element symbols.

    The two-letter symbols 'Br', 'Cl' and 'Si' are removed from the string
    first (every occurrence) and appended once each, in that order, after the
    remaining characters. Every leftover character is treated as a one-letter
    symbol, in its original order. Duplicates among one-letter symbols are kept.
    """
    remaining = st
    two_letter = []
    for symbol in ('Br', 'Cl', 'Si'):
        if symbol not in remaining:
            continue
        two_letter.append(symbol)
        remaining = remaining.replace(symbol, '')
    return list(remaining) + two_letter


# Keep only the letters of each formula (dropping counts), then split them
# into element symbols.
train_labels['f2'] = train_labels['formula'].progress_apply(
    lambda formula: str_to_list(''.join(re.findall("[a-zA-Z]+", formula)))
)
train_labels.head(10)

In [None]:
# Symbol -> class index, in the order the symbols appear in CL.
md = {symbol: idx for idx, symbol in enumerate(CL)}
md

In [None]:
# Encode each element list as '|'-separated class indices (e.g. ['C', 'H'] -> "6|5").
train_labels['lab'] = train_labels['f2'].progress_apply(
    lambda symbols: "|".join(str(md[sym]) for sym in symbols)
)
train_labels.head(10)

In [None]:
# Use only the first 1000 rows to keep this demo run fast.
train_df = train_labels.iloc[:1000]

In [None]:
# Parallel arrays: image paths and their '|'-encoded label strings.
X_Train = train_df['file_path'].values
Y_Train = train_df['lab'].values

In [None]:
class GetData(Dataset):
    """Image dataset over molecule PNGs with optional multi-hot element labels.

    Args:
        Dir: Image root directory. When it contains the substring "train",
            items are ``(image, label)`` pairs; otherwise items are the image only.
        FNames: Sequence of image file paths, indexed in parallel with ``Labels``.
        Labels: Sequence of '|'-joined class-index strings (e.g. "6|5|0").
            Only read when ``Dir`` contains "train".
        Transform: Callable applied to the scaled image array before it is returned.
    """

    def __init__(self, Dir, FNames, Labels, Transform):
        self.dir = Dir
        self.fnames = FNames
        self.transform = Transform
        self.labels = Labels

    def __len__(self):
        return len(self.fnames)

    @staticmethod
    def _encode_labels(label_str, num_classes=None):
        """Convert a '|'-joined index string into a float64 multi-hot vector.

        Each listed class is set to exactly 1.0. An index that appears more
        than once therefore still yields a valid BCE target in [0, 1] rather
        than a count.

        Args:
            label_str: String such as "6|5|0".
            num_classes: Vector length; defaults to ``NUM_CL``.
        """
        if num_classes is None:
            num_classes = NUM_CL
        indices = [int(tok) for tok in label_str.split('|')]
        multi_hot = np.zeros(num_classes, dtype='float')
        multi_hot[indices] = 1.0
        return multi_hot

    def __getitem__(self, index):
        path = self.fnames[index]
        img = cv2.imread(path)
        if img is None:
            # cv2.imread returns None for missing/unreadable files; fail with the
            # path here instead of an opaque error inside cv2.resize.
            raise FileNotFoundError(f"Could not read image: {path}")
        img = cv2.resize(img, (IM_SIZE, IM_SIZE))
        X = img / 255.

        if "train" in self.dir:
            return self.transform(X), self._encode_labels(self.labels[index])
        # Unlabelled directories (e.g. the test set): return the image alone
        # rather than falling through to an implicit None.
        return self.transform(X)

In [None]:
# Tensor conversion only: no augmentation or normalisation is applied.
Transform = transforms.Compose([
    transforms.ToTensor(),
])

In [None]:
# Wrap the training subset in a shuffling, batched loader.
trainset = GetData(Dir=TRAIN_DIR, FNames=X_Train, Labels=Y_Train, Transform=Transform)
trainloader = DataLoader(dataset=trainset, shuffle=True, batch_size=BATCH)

In [None]:
# Sanity check: shape of one batch of images.
next(iter(trainloader))[0].size()

In [None]:
# ResNet-34 built without a weights argument, with its final classifier
# replaced by one logit per element class.
model = torchvision.models.resnet34()
# Size the new head from the layer it replaces rather than a hard-coded 512,
# so swapping the backbone does not silently break the head.
model.fc = nn.Linear(model.fc.in_features, NUM_CL, bias=True)
model = model.to(DEVICE)

# Multi-label objective: an independent sigmoid + BCE per element class.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

In [None]:
%%time

for epoch in range(EPOCHS):
    tr_loss = 0.0
    n_batches = 0

    model.train()

    for images, labels in trainloader:
        # GetData yields float64 arrays; cast both inputs and targets to the
        # model's float32 so logits and BCE targets share a dtype.
        images = images.to(DEVICE).float()
        labels = labels.to(DEVICE).float()
        logits = model(images)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        tr_loss += loss.item()
        n_batches += 1

    model.eval()
    # Mean over the batches actually seen. The previous `/ i` divided by the
    # last index (one too few) and failed when there was only one batch.
    print('Epoch: %d | Loss: %.4f' % (epoch, tr_loss / max(n_batches, 1)))