In [None]:
!pip install -qU timm
!pip install -qU wandb

In [None]:
import os,glob,warnings,random
import cv2
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from pprint import pprint
from sklearn.model_selection import train_test_split


# PyTorch 
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.cuda import amp

import timm

# Albumentations for augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2
from albumentations import Normalize,Resize,Compose
warnings.filterwarnings("ignore")

# Synchronous CUDA launches (CUDA_LAUNCH_BLOCKING=1) give descriptive, correctly
# located tracebacks for device-side errors, but they make every kernel launch
# blocking and slow training down. Opt in only while debugging; the variable must
# be set before the first CUDA call to take effect.
DEBUG_CUDA_SYNC = False
if DEBUG_CUDA_SYNC:
    os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [None]:
import wandb

# Log in to Weights & Biases with the key stored as the Kaggle secret "WANDB".
# If that is not possible (not running on Kaggle, secret missing, login rejected),
# fall back to anonymous logging: `anonymous` is the value to pass to
# wandb.init(anonymous=...) — None when logged in, "must" otherwise.
try:
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    api_key = user_secrets.get_secret("WANDB")
    wandb.login(key=api_key)
    anonymous = None
except Exception as err:  # not a bare except: let KeyboardInterrupt/SystemExit through
    anonymous = "must"
    print(f'W&B login unavailable ({type(err).__name__}: {err}); continuing anonymously.')
    print('To use your W&B account,\nGo to Add-ons -> Secrets and provide your W&B access token. Use the Label name as WANDB. \nGet your W&B access token from here: https://wandb.ai/authorize')

In [None]:
def fix_all_seeds(seed):
    """Seed every random source this notebook uses.

    Covers Python's `random`, NumPy's global RNG, and PyTorch's CPU and CUDA
    generators, and exports PYTHONHASHSEED (which only influences interpreters
    started afterwards, not the running one).

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

    torch.manual_seed(seed)
    for seed_cuda_rng in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_cuda_rng(seed)


fix_all_seeds(2021)

In [None]:
# Competition data root. Keep the trailing slash: later cells append file
# names to it with plain string concatenation.
INPUT_PATH = "../input/sartorius-cell-instance-segmentation/"

# Train on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print('DEVICE: ', DEVICE)

In [None]:

# Collapse train.csv to one row per image id. The groupby implies an id can
# appear on several rows (presumably one per annotated cell); the first
# cell_type seen for each id is kept.
df_train = (
    pd.read_csv(INPUT_PATH + 'train.csv')
    .groupby('id')[['cell_type']]
    .first()
    .reset_index()
)
display(df_train)

In [None]:
# Build one (image_path, cell_type) table from the annotated training images plus
# the semi-supervised images, whose cell type is the file-name prefix before '['.
train_image_paths = [INPUT_PATH + f'train/{i}.png' for i in df_train['id']]
# glob order is filesystem-dependent; sort it so the seeded train/valid split
# below selects the same images on every run.
semi_image_paths = sorted(glob.glob(INPUT_PATH + 'train_semi_supervised/*.png'))
train_image_paths.extend(semi_image_paths)

train_labels = df_train['cell_type'].to_list()
# basename instead of split('/') so the parsing also works with OS-native separators.
semi_labels = [os.path.basename(path).split('[')[0] for path in semi_image_paths]
# The semi-supervised file names use 'astros' where the labelled set uses 'astro'.
semi_labels = ['astro' if label == 'astros' else label for label in semi_labels]
train_labels.extend(semi_labels)

df = pd.DataFrame({'image_path': train_image_paths, 'cell_type': train_labels})

# Fail fast on an unexpected prefix here rather than with a ValueError from
# list.index() deep inside the DataLoader.
unknown_cell_types = set(df['cell_type']) - {'shsy5y', 'astro', 'cort'}
assert not unknown_cell_types, f'unexpected cell types: {sorted(unknown_cell_types)}'
display(df)

In [None]:
IMAGE_RESIZE = (224, 224)
RESNET_MEAN = (0.485, 0.456, 0.406)
RESNET_STD = (0.229, 0.224, 0.225)


class DatasetImageCelltype(Dataset):
    """Image-level cell-type classification dataset.

    Reads images listed in ``df['image_path']`` and returns, per index, a dict
    ``{'image': tensor, 'label': int}``. The image goes through ``transforms``
    (by default: resize to IMAGE_RESIZE, normalize with RESNET_MEAN/RESNET_STD,
    ToTensorV2). The label is the position of ``df['cell_type']`` in LABEL_LIST.

    Args:
        df: DataFrame with 'image_path' and 'cell_type' columns.
        transforms: optional albumentations-style callable taking ``image=`` and
            returning a dict with key 'image'; defaults to the pipeline above.
    """

    # The order defines the class ids the model is trained on (and that any
    # saved checkpoint uses); do not reorder without retraining.
    LABEL_LIST = ['shsy5y', 'astro', 'cort']

    def __init__(self, df, transforms=None):
        self.df = df
        self.images_paths = df['image_path']
        self.labels = df['cell_type']
        # Built once here instead of on every __getitem__ call.
        if transforms is None:
            transforms = Compose([Resize(IMAGE_RESIZE[0], IMAGE_RESIZE[1]),
                                  Normalize(mean=RESNET_MEAN, std=RESNET_STD, p=1),
                                  ToTensorV2()])
        self.transforms = transforms

    def __getitem__(self, idx):
        image_path = self.images_paths.iloc[idx]
        image = cv2.imread(image_path)
        if image is None:  # cv2 signals unreadable/missing files by returning None
            raise FileNotFoundError(f'could not read image: {image_path}')
        # cv2 decodes to BGR; RESNET_MEAN/RESNET_STD are given in RGB order.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transforms(image=image)['image']
        label_id = self.LABEL_LIST.index(self.labels.iloc[idx])
        return {'image': image, 'label': label_id}

    def __len__(self):
        return len(self.df)

In [None]:
# Split into train and validation
df_train, df_valid = train_test_split(df, test_size=0.20)

# Dataset
ds_train = DatasetImageCelltype(df_train)
ds_valid = DatasetImageCelltype(df_valid)
# Data loader
dl_train = DataLoader(ds_train, batch_size=64, num_workers=0, pin_memory=True, shuffle=True)
dl_valid = DataLoader(ds_valid, batch_size=64, num_workers=0, pin_memory=True, shuffle=False)

print(f'Number of train dataset {len(ds_train)}')
print(f'Number of valid dataset {len(ds_valid)}')

In [None]:
# Browse the Swin variants that ship pretrained weights (timm is imported in the
# imports cell). Printing the unfiltered registry dumps ~1000 model names.
print(timm.list_models('swin*', pretrained=True))

In [None]:
# Swin-B (patch 4, window 7, 224x224 input) with pretrained weights from timm;
# the classification head is replaced in the next cell.
model = timm.create_model(model_name='swin_base_patch4_window7_224', pretrained=True)
model

In [None]:
NUM_CLASSES = 3  # shsy5y, astro, cort — same order as DatasetImageCelltype's label list

# Swap the ImageNet head for a NUM_CLASSES-way classifier via timm's supported
# API. reset_classifier reads the feature width from the model (no hard-coded
# 1024) and keeps timm's own head module. In recent timm releases the Swin head
# also performs the global pooling of NHWC features, so assigning a bare
# nn.Linear to `model.head` would yield per-position logits instead of
# (batch, NUM_CLASSES).
model.reset_classifier(num_classes=NUM_CLASSES)
model

In [None]:
LEARNING_RATE = 3e-4
EPOCHS = 20
# torch.cuda.amp autocast/GradScaler only do anything on CUDA; on CPU they would
# just warn and disable themselves.
USE_AMP = DEVICE.type == 'cuda'

model.to(DEVICE)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Create the scaler once: it adapts its loss scale across steps, so re-creating
# it every epoch would discard that calibration.
scaler = amp.GradScaler(enabled=USE_AMP)

# `anonymous` comes from the W&B login cell (None when logged in, "must" otherwise).
wandb.init(anonymous=anonymous,
           config={'learning_rate': LEARNING_RATE, 'epochs': EPOCHS,
                   'batch_size': dl_train.batch_size})
wandb.watch(model, log_freq=100)


def _train_one_epoch():
    """Run one optimisation pass over dl_train; return (mean batch loss, accuracy)."""
    model.train()
    optimizer.zero_grad()
    loss_sum = 0.0
    n_correct = 0
    pbar = tqdm(enumerate(dl_train), total=len(dl_train), desc='Train ')
    for idx, data in pbar:
        images = data['image'].to(DEVICE)
        labels = data['label'].to(DEVICE)

        with amp.autocast(enabled=USE_AMP):
            outputs = model(images)  # raw logits, one column per class
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        # .item() gives plain Python numbers: summing the loss tensors themselves
        # would keep every batch's autograd nodes reachable for the whole epoch.
        loss_sum += loss.item()
        n_correct += (outputs.argmax(dim=1) == labels).sum().item()  # argmax -> class ids

        mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0
        # Running mean over the batches seen so far (not over the whole epoch).
        pbar.set_postfix(train_loss=f'{loss_sum / (idx + 1):0.4f}',
                         lr=optimizer.param_groups[0]['lr'],
                         gpu_memory=f'{mem:0.2f} GB')
    return loss_sum / len(dl_train), n_correct / len(ds_train)


def _validate():
    """Evaluate on dl_valid without gradients; return (mean batch loss, accuracy)."""
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    with torch.no_grad():
        for data in tqdm(dl_valid, total=len(dl_valid), desc='[valid]'):
            images = data['image'].to(DEVICE)
            labels = data['label'].to(DEVICE)
            outputs = model(images)  # raw logits
            loss_sum += criterion(outputs, labels).item()
            n_correct += (outputs.argmax(dim=1) == labels).sum().item()
    return loss_sum / len(dl_valid), n_correct / len(ds_valid)


for epoch in range(1, EPOCHS + 1):
    print(f'Epoch: {epoch}/{EPOCHS}')
    loss_train, acc_train = _train_one_epoch()
    print(f'Train loss: {loss_train:.4f}, Train accuracy: {acc_train*100:.2f}%')

    loss_valid, acc_valid = _validate()
    torch.cuda.empty_cache()
    print(f'Valid loss: {loss_valid:.4f}, Valid accuracy: {acc_valid*100:.2f}%\n')
    wandb.log({"Train Loss": loss_train,
               "Valid Loss": loss_valid,
               "Train Acc": acc_train,
               "Valid Acc": acc_valid,
               "LR": optimizer.param_groups[0]['lr']})

In [None]:
# Whole pickled module, kept at its original (misspelled) path so existing
# consumers still find it. Loading it needs the timm classes importable and, on
# recent PyTorch, torch.load(..., weights_only=False).
torch.save(model, 'swin_ransformer_crassifier.bin')
# Portable weights-only checkpoint (PyTorch's recommended format): load it into
# an identically constructed model with model.load_state_dict(torch.load(path)).
torch.save(model.state_dict(), 'swin_transformer_classifier_state_dict.pth')