In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch

import torchvision
from torchvision import transforms as T
import torch.nn as nn
import torch.optim as optim
from skimage import io
import cv2

from torch.utils.data import Dataset, DataLoader
import seaborn as sns

In [None]:
# Global hyper-parameters shared by the cells below.
num_classes = 28  # length of the multi-hot target vector built in hpa_ds and width of the classifier head
batch_size = 16  # batch size used for the training DataLoader
num_epochs = 10  # number of epochs handed to train_model
im_size =  512  # NOTE(review): not referenced in the visible cells; presumably the image side length - confirm
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Last expression of the cell: displays the selected device.
device

In [None]:
# Load the training annotations. Later cells read the first column as the image id
# and the 'Target' column as a space-separated string of class indices.
labels = pd.read_csv("../input/human-protein-atlas-image-classification/train.csv")
# First column of the CSV as a NumPy array.
imag_f = labels.to_numpy()[:, 0]
# Peek at the first five entries.
imag_f[:5]

In [None]:
def op_img(id, base_dir="../input/human-protein-atlas-image-classification/train/"):
    """Load the four single-channel PNGs of one sample into a float32 array.

    Parameters
    ----------
    id : str
        Image id; the files ``<id>_red.png``, ``<id>_green.png``,
        ``<id>_blue.png`` and ``<id>_yellow.png`` are read from ``base_dir``.
    base_dir : str, optional
        Directory holding the PNG files. Defaults to the path used
        throughout this notebook.

    Returns
    -------
    numpy.ndarray
        float32 array scaled to [0, 1]. The channels are stacked on the last
        axis and then *all* axes are reversed, so the result is laid out as
        (channel, width, height) - i.e. spatially transposed relative to the
        file. Plotting cells undo this with ``permute(2, 1, 0)``.

    Raises
    ------
    FileNotFoundError
        If any of the four files cannot be read (``cv2.imread`` signals this
        by returning ``None`` instead of raising).
    """
    import os

    channels = []
    for color in ('red', 'green', 'blue', 'yellow'):
        path = os.path.join(base_dir, id + '_' + color + '.png')
        plane = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if plane is None:
            # Fail with the offending path instead of an AttributeError on None.
            raise FileNotFoundError("could not read image file: " + path)
        channels.append(plane.astype(np.float32) / 255)
    return np.transpose(np.stack(channels, axis=-1))

In [None]:
class hpa_ds(Dataset):
    """Dataset pairing each image id in the training CSV with a multi-hot target.

    Attributes
    ----------
    labels : numpy.ndarray
        First CSV column (image ids passed to ``op_img``).
    vals : numpy.ndarray
        Second CSV column (space-separated class indices as strings).
    """

    def __init__(self, csv_path="../input/human-protein-atlas-image-classification/train.csv"):
        """Read the annotation CSV once and keep its first two columns."""
        table = pd.read_csv(csv_path).to_numpy()
        self.labels = table[:, 0]
        self.vals = table[:, 1]

    def __len__(self):
        """Number of rows in the annotation CSV."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, target)`` tensors for row ``idx``.

        ``image`` is whatever ``op_img`` returns, wrapped as a tensor;
        ``target`` is a float64 vector of length ``num_classes`` with 1.0 at
        every class index listed for the row.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        imgs = op_img(self.labels[idx])

        # np.int / np.float were removed in NumPy 1.24; use explicit dtypes.
        class_ids = np.array(self.vals[idx].split(' '), dtype=np.int64)
        labs = np.zeros(num_classes, dtype=np.float64)
        labs[class_ids] = 1.0  # multi-hot: repeated indices still yield 1.0

        return torch.from_numpy(imgs), torch.from_numpy(labs)

# Single dataset instance shared by the exploration, split and training cells below.
hpa = hpa_ds()

In [None]:
# Sanity check: image tensor shape, target tensor shape and dataset length.
# Each hpa[0] access reloads the sample from disk via op_img.
hpa[0][0].shape , hpa[0][1].shape , len(hpa)

In [None]:
# Hold out a fixed number of samples for validation and train on the rest.
# The training size is derived from the dataset so random_split cannot fail
# with a length mismatch when the CSV has a different number of rows.
val_size = min(572, len(hpa))
train_size = len(hpa) - val_size
train , val = torch.utils.data.random_split(hpa, (train_size, val_size))

dataloaders = {
    # Reshuffle every epoch so batches are not presented in a fixed order.
    'train': DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=16),
    'val': DataLoader(val, batch_size=8),
}

In [None]:
# Show the four channels of the first sample side by side.
# Load the sample once: every hpa[...] access re-reads four PNG files from disk.
first_img = hpa[0][0]

fig, ax = plt.subplots(1, 4, figsize=(20, 5))
for i in range(4):
    ax[i].imshow(first_img[i].cpu())
    ax[i].axis('off')

plt.show()

In [None]:
idx = 98

# Load the sample once; each hpa[...] access re-reads four PNG files from disk.
sample = hpa[idx][0]

# Channel order follows op_img: 0 = red, 1 = green, 2 = blue, 3 = yellow.
# The c/m/y/k names are kept as-is; they do NOT mean cyan/magenta/yellow/key.
c = sample[0, :, :]
m = sample[1, :, :]
y = sample[2, :, :]
k = sample[3, :, :]

# Blend pairs of channels into three display planes.
r = torch.unsqueeze((c + m) / 2, 0)
g = torch.unsqueeze((m + k) / 2, 0)
b = torch.unsqueeze((y + k) / 2, 0)

img = torch.cat([r, g, b])

# permute(2, 1, 0) moves the channel axis last for imshow (and reverses the
# two spatial axes, undoing the axis reversal done in op_img).
plt.figure(figsize=(10, 10))
plt.imshow(img.permute(2, 1, 0).cpu())
plt.axis('off')
plt.show()

# Same sample (idx, not a hard-coded index) using only its first three channels.
plt.figure(figsize=(10, 10))
plt.imshow(sample[:3].permute(2, 1, 0).cpu(), aspect='auto')
plt.axis('off')
plt.show()

In [None]:
# Human-readable class names; the list position is the class index used in
# the 'Target' column of the annotation CSV.
text_labels = [
    "Nucleoplasm",                    # 0
    "Nuclear membrane",               # 1
    "Nucleoli",                       # 2
    "Nucleoli fibrillar center",      # 3
    "Nuclear speckles",               # 4
    "Nuclear bodies",                 # 5
    "Endoplasmic reticulum",          # 6
    "Golgi apparatus",                # 7
    "Peroxisomes",                    # 8
    "Endosomes",                      # 9
    "Lysosomes",                      # 10
    "Intermediate filaments",         # 11
    "Actin filaments",                # 12
    "Focal adhesion sites",           # 13
    "Microtubules",                   # 14
    "Microtubule ends",               # 15
    "Cytokinetic bridge",             # 16
    "Mitotic spindle",                # 17
    "Microtubule organizing center",  # 18
    "Centrosome",                     # 19
    "Lipid droplets",                 # 20
    "Plasma membrane",                # 21
    "Cell junctions",                 # 22
    "Mitochondria",                   # 23
    "Aggresome",                      # 24
    "Cytosol",                        # 25
    "Cytoplasmic bodies",             # 26
    "Rods & rings",                   # 27
]

In [None]:
# unq_c[i]  : number of rows whose Target is exactly the single class i.
# full_c[i] : number of rows whose Target contains class i (alone or with others).
unq_c = [int((labels.Target == str(i)).sum()) for i in range(num_classes)]

full_c = np.zeros(num_classes, dtype=int)
for row in labels['Target']:
    # np.int was removed in NumPy 1.24; the builtin int gives the same dtype.
    full_c[np.array(row.split(' '), dtype=int)] += 1

In [None]:
# Overlaid horizontal bars per class: total occurrences (darker) with the
# single-label occurrences drawn on top (lighter).
sns.set(style='darkgrid')
plt.figure(figsize=(10, 10))
sns.set_color_codes("muted")
sns.barplot(x=list(full_c), y=text_labels, label="images containing the class", color='b', orient='h')
sns.set_color_codes("pastel")
sns.barplot(x=unq_c, y=text_labels, label="images with only this class", color='b', orient='h')
# The series labels are only visible once a legend is drawn.
plt.legend(loc="lower right")
plt.xlabel("Number of training images")
plt.show()

In [None]:
# Smoke test: push one sample through the stock pretrained ResNet-50.
# Only the first three channels of the last sample are kept here; the
# 4-channel input layer is built later in res_34.
tr = hpa[-1][0][:3, : , : ]
tr=  tr.to(device)
# Add a leading batch dimension.
tr = torch.unsqueeze( tr , 0 )

# `model` is reused below as the encoder wrapped by res_34.
model = torchvision.models.resnet50(pretrained=True)
model = model.to(device)
model.eval()

# Inference only, so gradient tracking is switched off.
with torch.no_grad():
    otps = model(tr)

# Last expression of the cell: displays the output shape.
otps.shape

In [None]:
class res_34(nn.Module):
    """ResNet encoder adapted to 4-channel input with a ``num_classes`` head.

    Despite the name, the notebook passes a ResNet-50 as the encoder: the
    head expects 2048 pooled features.

    Parameters
    ----------
    encoder : nn.Module, optional
        Backbone exposing ``conv1``, ``bn1``, ``layer1``..``layer4`` and
        ``avgpool``. Defaults to the module-level ``model``. Its sub-modules
        are reused by reference (not copied), so training this network also
        updates the encoder's weights.
    """

    def __init__(self, encoder=None):
        super().__init__()
        if encoder is None:
            encoder = model  # module-level pretrained backbone

        self.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)

        # Initialise the 4-channel stem from the 3-channel pretrained filters:
        # channels 0-2 are copied, channel 3 is the mean of channels 0 and 2.
        # detach() makes the new parameter a clean leaf, independent of the
        # encoder's autograd graph.
        w = encoder.conv1.weight.detach()
        self.conv1.weight = nn.Parameter(torch.cat((w, 0.5 * (w[:, :1, :, :] + w[:, 2:, :, :])), dim=1))

        self.bn1 = encoder.bn1
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Standard ResNet stem order: conv -> batch norm -> ReLU -> max pool.
        # The pretrained bn1 statistics were computed on pre-activation conv
        # outputs, so ReLU must come after the normalisation.
        self.layer0 = nn.Sequential(self.conv1, self.bn1, self.relu, self.maxpool)
        self.layer1 = encoder.layer1
        self.layer2 = encoder.layer2
        self.layer3 = encoder.layer3
        self.layer4 = encoder.layer4
        self.avgpool = encoder.avgpool
        self.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        """Map a 4-channel image batch to ``num_classes`` raw logits per sample."""
        x = self.layer0(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)
    
# Build the 4-channel network on top of the backbone loaded above.
res_next = res_34()
# Last expression of the cell: displays the module structure.
res_next

In [None]:
class focaloss(nn.Module):
    """Focal loss computed on raw logits for multi-label 0/1 targets.

    Each element's binary cross-entropy is scaled by ``(1 - p_t) ** gamma``,
    where ``p_t`` is the predicted probability of the true outcome. The
    result is summed over dim 1 and averaged over the batch.
    """

    def __init__(self, gamma=2):
        super().__init__()
        self.gamma = gamma

    def forward(self, input, target):
        """Return the scalar focal loss for ``input`` logits and ``target`` labels."""
        # Numerically stable binary cross-entropy with logits.
        shift = (-input).clamp(min=0)
        log_sum = ((-shift).exp() + (-input - shift).exp()).log()
        bce = input - input * target + shift + log_sum

        # log(1 - p_t): map targets {0, 1} to signs {-1, +1} and take the
        # log-sigmoid of the negated, signed logit.
        log_one_minus_pt = nn.functional.logsigmoid(-input * (target * 2.0 - 1.0))
        modulated = (log_one_minus_pt * self.gamma).exp() * bce

        per_sample = modulated.sum(dim=1)
        return per_sample.mean()

In [None]:
# Collect the parameters of the network that is actually trained (res_next).
# Taking them from the stock `model` instead would leave res_next's new
# 4-channel conv1 and its fc head out of the optimizer, so they would never
# be updated.
params_to_update = []
for name, param in res_next.named_parameters():
    param.requires_grad = True
    params_to_update.append(param)

# Preview of the cyclic learning-rate schedule. CyclicLR overwrites the
# optimizer's lr with base_lr on construction, so lr here is only a placeholder.
optimizer = torch.optim.SGD(params_to_update, lr=0.005)
scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.005, max_lr=0.05,step_size_up=5,mode="exp_range",gamma=0.85)
lrs = []

for i in range(100):
    # No gradients exist yet, so step() changes no weights; it is called only
    # to keep the optimizer/scheduler call order PyTorch expects.
    optimizer.step()
    lrs.append(optimizer.param_groups[0]["lr"])
    scheduler.step()

plt.plot(lrs)

In [None]:
import time
import copy
def train_model(model, dataloaders, criterion, optimizer, num_epochs , scheduler = None, logit_threshold = 0.9):
    """Train ``model`` and keep the weights with the best validation accuracy.

    Parameters
    ----------
    model : nn.Module
        Network to train; moved to the module-level ``device``.
    dataloaders : dict
        Must provide iterables under the keys ``'train'`` and ``'val'``
        yielding ``(inputs, labels)`` batches.
    criterion : callable
        ``criterion(outputs, labels)`` returning a scalar loss tensor.
    optimizer : torch.optim.Optimizer
        Optimizer stepped once per training batch.
    num_epochs : int
        Number of passes over both phases.
    scheduler : optional
        Object with a ``step()`` method, called once per epoch after the
        training phase.
    logit_threshold : float, optional
        A label is predicted positive when its raw model output exceeds this
        value. Note this is applied to logits, not probabilities
        (0.9 corresponds to a sigmoid probability of about 0.71).

    Returns
    -------
    (model, hist)
        ``model`` with the best validation weights loaded, and ``hist``, a
        dict with keys ``'train'`` and ``'val'`` holding one 0-d CPU float64
        tensor per epoch. Accuracy is the fraction of individual label
        decisions that are correct, so it lies in [0, 1].
    """
    # .to(device) also works on CPU-only machines, unlike an unconditional .cuda().
    model = model.to(device)

    since = time.time()
    val_acc_history = []
    train_acc_history = []
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}  - '.format(epoch, num_epochs - 1) , end = " ")
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0
            seen_samples = 0
            seen_labels = 0

            for i, (inputs, labels) in enumerate( dataloaders[phase]) :
                inputs = inputs.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()

                # Build the autograd graph only when it is needed.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                preds = outputs > logit_threshold

                running_loss += loss.item() * inputs.size(0)
                # Count correct per-label decisions and how many were made, so
                # the accuracy is normalised by labels rather than by samples.
                running_corrects += torch.sum(preds == (labels > 0.5)).item()
                seen_samples += inputs.size(0)
                seen_labels += labels.numel()

                if(i%100==0):
                    print("-" , end ='')

            # Advance the schedule once per epoch; stepping after validation
            # as well would run the schedule at twice the intended rate.
            if scheduler and is_train:
                scheduler.step()

            epoch_loss = running_loss / max(seen_samples, 1)
            epoch_acc = running_corrects / max(seen_labels, 1)

            print(' {} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc) , end = " ")

            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

            # History entries stay 0-d CPU tensors, as callers received before.
            acc_entry = torch.tensor(epoch_acc, dtype=torch.float64)
            if phase == 'val':
                val_acc_history.append(acc_entry)
            else:
                train_acc_history.append(acc_entry)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    model.load_state_dict(best_model_wts)
    hist = {}
    hist['train'] = train_acc_history
    hist['val'] = val_acc_history
    return model, hist

In [None]:
criterian = focaloss()

# Optimise the parameters of the network being trained. Using res_next's own
# parameters guarantees its new 4-channel conv1 and fc head are updated.
# CyclicLR overwrites lr with base_lr on construction, so lr is a placeholder.
optimizer_ft = torch.optim.SGD(res_next.parameters(), lr=0.005)
scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer_ft, base_lr=0.005, max_lr=0.05,step_size_up=5,mode="exp_range",gamma=0.85)

res_next , hist = train_model(res_next, dataloaders , criterian, optimizer_ft, num_epochs=num_epochs , scheduler = scheduler)

In [None]:
# Convert the per-epoch history entries (0-d tensors or plain numbers) to floats.
train_acc = np.array([float(a) for a in hist['train']])
val_acc = np.array([float(a) for a in hist['val']])

# Derive the x axis from the history so the plot works for any epoch count.
epoch_axis = list(range(1, len(train_acc) + 1))

sns.lineplot(x = epoch_axis ,y =train_acc , label = "Train accuracy")
sns.lineplot(x = epoch_axis ,y =val_acc , label = "val accuracy")

In [None]:
# Serialises the whole module object (not only its state_dict) to the working directory.
torch.save(res_next , './res_34.pt')