In [None]:
# Put the bundled pytorch-image-models checkout on the import path
# (presumably so that the later `import timm` resolves to this copy).
import sys

pre_path = '../input/pytorch-image-models/pytorch-image-models-master'
sys.path.append(pre_path)

In [None]:
# at the top of the file, before other imports
import warnings

# Installs a blanket 'ignore' filter: every warning raised after this point is
# silenced, including pandas/library diagnostics from the cells below.
warnings.filterwarnings('ignore')

In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

import os

import matplotlib.pyplot as plt
import seaborn as sns;# sns.set()

from tqdm import tqdm

import cv2

from glob import glob

In [None]:
# Competition metadata tables; one CSV per frame, all read from the same input folder.
train_df = pd.read_csv('../input/recursion-cellular-image-classification/train.csv')
train_control = pd.read_csv('../input/recursion-cellular-image-classification/train_controls.csv')

test_df = pd.read_csv('../input/recursion-cellular-image-classification/test.csv')
test_control = pd.read_csv('../input/recursion-cellular-image-classification/test_controls.csv')

# `sub` is the sample submission and `pix` the per-image pixel statistics;
# neither is used again in the cells visible here.
sub = pd.read_csv('../input/recursion-cellular-image-classification/sample_submission.csv')
pix = pd.read_csv('../input/recursion-cellular-image-classification/pixel_stats.csv')

In [None]:
# Split the experiment id on '-' into its first part ("category") and second
# part ("branch") for both frames, using the vectorised .str accessor instead
# of a Python-level apply.
train_df['category'] = train_df['experiment'].str.split('-').str[0]
train_df['branch'] = train_df['experiment'].str.split('-').str[1]

test_df['category'] = test_df['experiment'].str.split('-').str[0]
test_df['branch'] = test_df['experiment'].str.split('-').str[1]


# Convert labels such as '<prefix>_<n>' to the integer n.
# Guarded so that re-running this cell is a no-op: once the column is already
# integer-typed there is nothing left to split (the unguarded version raised
# on the second execution because ints have no .split).
if not pd.api.types.is_integer_dtype(train_df['sirna']):
    train_df['sirna'] = train_df['sirna'].str.split('_').str[1].astype('int')

In [None]:
# Sanity check: column dtypes and non-null counts after the feature columns were added.
train_df.info()

In [None]:
# work on 2 train sites
# Every row of train_df yields two image rows, one per site suffix
# ('_s1' / '_s2'); the 'site' column is the image file stem used later.
# .assign() builds a new frame, so no column is written onto a slice of
# train_df (the previous form triggered SettingWithCopyWarning, which the
# global warnings filter was hiding).
site1 = train_df[['id_code', 'category', 'sirna']].assign(site=train_df['id_code'] + '_s1')
site2 = train_df[['id_code', 'category', 'sirna']].assign(site=train_df['id_code'] + '_s2')

train = pd.concat([site1, site2], ignore_index=True)
# NOTE(review): using nunique() as the class count assumes the integer labels
# are exactly 0..n_classes-1 — confirm against the label column.
n_classes = train['sirna'].nunique()

print(train.shape)
train.head(10)

In [None]:
# work on 2 test sites
# Same expansion as for the training frame, without a label column.
# .assign() returns a fresh frame instead of writing a column onto a slice of
# test_df (which raised a SettingWithCopyWarning hidden by the global filter).
test_site1 = test_df[['id_code', 'category']].assign(site=test_df['id_code'] + '_s1')
test_site2 = test_df[['id_code', 'category']].assign(site=test_df['id_code'] + '_s2')

test = pd.concat([test_site1, test_site2], ignore_index=True)
print(test_site1.shape)
print(test_site2.shape)

test.head(10)

In [None]:
def get_img(path):
    """Read an image from disk and return it with its channel axis reversed (BGR -> RGB).

    Args:
        path: File path handed to ``cv2.imread``.

    Returns:
        A C-contiguous array holding the image with the last axis reversed.

    Raises:
        FileNotFoundError: If ``cv2.imread`` returns ``None`` (it does not
            raise on a missing or unreadable file).
    """
    im_bgr = cv2.imread(path)
    if im_bgr is None:
        # Fail with the offending path instead of a "'NoneType' is not
        # subscriptable" TypeError on the next line.
        raise FileNotFoundError(f'could not read image: {path}')
    # The [::-1] slice is a negative-stride view; copying it into a contiguous
    # buffer keeps it usable by consumers that reject negative strides
    # (e.g. torch.from_numpy).
    im_rgb = np.ascontiguousarray(im_bgr[:, :, ::-1])
    return im_rgb

In [None]:
# Show one training image as a visual sanity check.
# `path` is the image directory reused by the dataset cells further down.
i = 150
path = '../input/recursion-cellular-image-classification-224-jpg/train/train/'
impath = f"{path}{train['site'][i]}.jpeg"

img = get_img(impath)
plt.imshow(img)

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset,DataLoader
from torch.cuda.amp import autocast, GradScaler

import timm

import albumentations as A
from albumentations.pytorch import ToTensorV2


from sklearn.metrics import log_loss
from sklearn.model_selection import StratifiedKFold, train_test_split

In [None]:
# Hyper-parameters shared by the cells below. In the cells visible here only
# 'seed', 'img_size' and 'batch_size' are read; the remaining keys are
# presumably consumed by training code that is not written yet.
config = {
    'num_fold': 5,
    'epoch': 10,
    # NOTE(review): the seed is only passed to train_test_split below; the
    # numpy/torch RNGs are not seeded anywhere in the visible cells.
    'seed': 42,
    # NOTE(review): the image folder name contains "224-jpg" while this is 244;
    # possibly a typo for 224 — confirm the intended input size.
    'img_size': 244,
    'lr': 1e-3,
    'weight_decay':1e-5,
    'batch_size':32,
    # Architecture name handed to timm.create_model by RecursionModel.
    'model_arc': 'tf_efficientnet_b3_ns'   
}

# Use the GPU when torch reports one, otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
class RecursionDataset(Dataset):
    """Map-style dataset that loads one ``<path><site>.jpeg`` image per dataframe row.

    Args:
        df: Frame with ``site`` and ``category`` columns, plus ``sirna`` when
            ``labels`` is true.
        path: Directory prefix that is string-concatenated with the site name
            (so it must already end with a path separator).
        labels: When true, items are ``(image, target)``; otherwise just the image.
        transform: Optional callable invoked as ``transform(image=image)`` and
            expected to return a mapping with an ``'image'`` entry.
    """

    def __init__(self, df, path, labels=True, transform=None):
        super().__init__()
        self.df = df
        self.path = path
        self.labels = labels
        self.transform = transform

        # Unlabeled frames (e.g. the test set) have no `sirna` column; only
        # require it when labels are actually requested.
        if labels or 'sirna' in df.columns:
            self.targets = df.sirna.values
        else:
            self.targets = None
        self.site = df.site.values
        self.cat = df.category.values

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        img_path = self.path + self.site[idx] + '.jpeg'
        image = get_img(img_path)

        if self.transform is not None:
            image = self.transform(image=image)['image']

        if self.labels:
            return image, self.targets[idx]
        return image

def get_train_transforms():
    """Build the albumentations pipeline used for training images.

    The output is a square crop of side ``config['img_size']`` that is randomly
    transposed, flipped, shifted/scaled/rotated and colour-jittered, then
    normalised per channel, randomly masked, and converted to a tensor.
    """
    side = config['img_size']
    steps = [
        # Geometric augmentations.
        A.RandomResizedCrop(side, side),
        A.Transpose(p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.ShiftScaleRotate(p=0.5),
        # Colour augmentations.
        A.HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
        A.RandomBrightnessContrast(brightness_limit=(-0.1,0.1), contrast_limit=(-0.1, 0.1), p=0.5),
        # Per-channel normalisation (these values match the usual ImageNet statistics).
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
        # Random region masking.
        # NOTE(review): Cutout overlaps with CoarseDropout and is deprecated in
        # recent albumentations releases — confirm against the installed version.
        A.CoarseDropout(p=0.5),
        A.Cutout(p=0.5),
        ToTensorV2(p=1.0),
    ]
    return A.Compose(steps, p=1.)

In [None]:
def get_valid_transforms():
    """Deterministic preprocessing for validation: resize, normalise, to tensor.

    Uses the same normalisation constants as the training pipeline but no
    random augmentation, so validation metrics are comparable between epochs.
    """
    return A.Compose([
            A.Resize(config['img_size'], config['img_size']),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
            ToTensorV2(p=1.0),
        ], p=1.)


# Split on id_code rather than on rows: each id_code appears twice in `train`
# (once per site), and a row-level split would place the two sites of the same
# sample on opposite sides of the split, leaking validation data into training.
trn_ids, val_ids = train_test_split(train['id_code'].unique(), test_size=0.2, random_state=config['seed'])
trn_idx = train.index[train['id_code'].isin(trn_ids)]
val_idx = train.index[train['id_code'].isin(val_ids)]

train_set = RecursionDataset(train.loc[trn_idx], path, True, transform=get_train_transforms())
# Validation must not use the random training augmentations.
valid_set = RecursionDataset(train.loc[val_idx], path, True, transform=get_valid_transforms())

# Per-site datasets over all training rows; not consumed by any cell visible here.
site1set = RecursionDataset(site1, path, True, transform=get_train_transforms())
site2set = RecursionDataset(site2, path, True, transform=get_train_transforms())

In [None]:
# Batches are reshuffled every epoch for training; validation keeps dataset order.
train_loader = DataLoader(train_set, batch_size=config['batch_size'], shuffle=True, num_workers=2)
valid_loader = DataLoader(valid_set, batch_size=config['batch_size'], shuffle=False, num_workers=2)

In [None]:
class RecursionModel(nn.Module):
    """timm backbone whose ``classifier`` attribute is swapped for Dropout + Linear.

    Args:
        model_arc: Architecture name forwarded to ``timm.create_model``. The
            created model must expose a ``classifier`` module with ``in_features``.
        pretrained: Forwarded to ``timm.create_model``.
        n_class: Output width of the new linear layer. The default is the value
            of the notebook-level ``n_classes`` at the time this class is defined.
    """

    def __init__(self, model_arc, pretrained=False, n_class=n_classes):
        super().__init__()
        self.backbone = timm.create_model(model_arc, pretrained=pretrained)

        # Reuse the input width of the stock head for the replacement head.
        head_in = self.backbone.classifier.in_features
        new_head = nn.Sequential(
            nn.Dropout(0.25),
            nn.Linear(head_in, n_class),
        )
        self.backbone.classifier = new_head

    def forward(self, x):
        logits = self.backbone(x)
        return logits

In [None]:
def train_loop(epoch, loader, model, loss_fn, opt, scheduler=None, device=device):
    """Run one training epoch over ``loader``.

    Args:
        epoch: Epoch number, only used in the progress-bar text.
        loader: Iterable of ``(image, label)`` batches that supports ``len()``.
        model: Module to train; switched to train mode.
        loss_fn: Callable ``loss_fn(predictions, labels)`` returning a scalar loss.
        opt: Optimizer; stepped once per batch.
        scheduler: Optional LR scheduler; when given it is stepped once per batch.
        device: Device the batches are moved to.
    """
    model.train()

    running_loss = None
    n_batches = len(loader)
    # `total` must be passed by keyword: tqdm's second positional parameter is
    # `desc`, and enumerate() has no length for tqdm to discover on its own.
    pbar = tqdm(enumerate(loader), total=n_batches)

    for i, (image, label) in pbar:
        image, label = image.to(device).float(), label.to(device).long()

        opt.zero_grad()
        y_pred = model(image)
        loss = loss_fn(y_pred, label)
        loss.backward()

        # Exponentially smoothed loss (0.9 / 0.1) for a steadier progress display.
        batch_loss = loss.item()
        if running_loss is None:
            running_loss = batch_loss
        else:
            running_loss = running_loss * .9 + batch_loss * .1

        opt.step()
        # The scheduler is optional (defaults to None), so only step it when supplied.
        if scheduler is not None:
            scheduler.step()

        if (i+1) % 2 == 0 or (i+1) == n_batches:
            description = f'epoch {epoch}, loss: {running_loss:.4f}'
            pbar.set_description(description)

def valid_loop(epoch, val_loader, model, loss_fn, scheduler=None, device=device):
    """Evaluate ``model`` on ``val_loader`` and return the multi-class accuracy.

    Args:
        epoch: Epoch number, only used in the progress-bar text.
        val_loader: Iterable of ``(images, labels)`` batches that supports ``len()``.
        model: Module to evaluate; switched to eval mode.
        loss_fn: Callable ``loss_fn(predictions, labels)`` returning a scalar loss.
        scheduler: Optional LR scheduler; when given it is stepped once per call.
            NOTE(review): train_loop steps its scheduler per batch — pass the
            same object to only one of the two loops.
        device: Device the batches are moved to.

    Returns:
        Fraction of samples whose arg-max prediction equals the target.
    """
    model.eval()

    # Accumulators must exist before the first `+=` in the loop.
    loss_sum = 0.0
    sample_num = 0
    image_preds_all = []
    image_targets_all = []

    pbar = tqdm(enumerate(val_loader), total=len(val_loader))

    # No gradients are needed for evaluation; skipping autograd saves memory.
    with torch.no_grad():
        for step, (imgs, image_labels) in pbar:
            imgs = imgs.to(device).float()
            image_labels = image_labels.to(device).long()

            image_preds = model(imgs)
            image_preds_all += [torch.argmax(image_preds, 1).detach().cpu().numpy()]
            image_targets_all += [image_labels.detach().cpu().numpy()]

            loss = loss_fn(image_preds, image_labels)

            # Weight each batch loss by its size so the displayed value is a per-sample mean.
            loss_sum += loss.item()*image_labels.shape[0]
            sample_num += image_labels.shape[0]

            if ((step + 1) % 2 == 0) or ((step + 1) == len(val_loader)):
                description = f'epoch {epoch} loss: {loss_sum/sample_num:.4f}'
                pbar.set_description(description)

    image_preds_all = np.concatenate(image_preds_all)
    image_targets_all = np.concatenate(image_targets_all)
    valid_acc = (image_preds_all==image_targets_all).mean()
    print('validation multi-class accuracy = {:.4f}'.format(valid_acc))

    # The scheduler is optional (defaults to None), so only step it when supplied.
    if scheduler is not None:
        scheduler.step()

    return valid_acc

## Code still under modification, stay tuned.