# 6-channel data loader

In [None]:
import numpy as np
import pandas as pd

from PIL import Image

import torch
import torch.nn as nn
import torch.utils.data as D
import torch.nn.functional as F

import torchvision
from torchvision import transforms as T
from torchvision import models

import tqdm

from sklearn.model_selection import train_test_split

from PIL import Image
import matplotlib.pyplot as plt

from tqdm.notebook import trange, tqdm
from pathlib import Path

from GPUtil import showUtilization as gpu_usage
from numba import cuda

def free_gpu_cache():
    """Print GPU utilisation, release cached CUDA memory, then print it again.

    When PyTorch reports no CUDA device this only prints a message and returns,
    so the notebook can still start on a CPU-only machine (the device selection
    further down already falls back to the CPU).
    """
    if not torch.cuda.is_available():
        print("CUDA is not available; skipping GPU cache cleanup")
        return

    print("Initial GPU Usage")
    gpu_usage()

    # Release memory held by PyTorch's caching allocator.
    torch.cuda.empty_cache()

    # Close numba's CUDA context(s) in this thread and re-select device 0.
    # NOTE(review): closing the context after PyTorch has put tensors on the
    # GPU may invalidate them -- only run this at the start of a session.
    cuda.select_device(0)
    cuda.close()
    cuda.select_device(0)

    print("GPU Usage after emptying the cache")
    gpu_usage()

free_gpu_cache()


# Seed PyTorch's global RNG so subsequent torch random draws are repeatable.
torch.manual_seed(0)
# Root folder of the dataset -- change this to point at your local copy!
path_data = 'D:\\Data\\DeepLearning\\recursion-cellular-image-classification'

## Dataset class that loads all 6 channels of each image, and the validation-accuracy function

In [None]:
class ImagesDS(D.Dataset):
    """Dataset that stacks one PNG per channel into a multi-channel image tensor.

    For row ``index`` and each channel ``c`` in ``channels`` the file
    ``<img_dir>/<folder>/<experiment>/Plate<plate>/<well>_s<site>_w<c>.png``
    is read, where ``folder`` is ``mode`` except that ``mode='valid'`` reads
    from ``train``. A missing channel file is replaced by an all-zero
    512x512 plane.

    Args:
        df: frame with ``experiment``, ``plate``, ``well`` and ``site`` columns;
            train/valid mode also needs ``sirna`` (labels like ``"sirna_42"``),
            any other mode needs ``id_code``.
        img_dir: dataset root directory.
        mode: ``'train'``, ``'valid'``, or anything else (treated as test).
        site: unused -- each row's own ``site`` column is used. Kept for
            backward compatibility.
        channels: channel numbers to load, in stacking order.
        transform: optional callable applied to the stacked tensor in
            train/valid mode.
        mapping: optional ``{sirna number: class index}`` dict. Pass the
            training dataset's mapping to validation/test datasets so every
            split uses the same class indices. When omitted it is built from
            the sorted unique sirna numbers of ``df`` (empty if ``df`` has no
            ``sirna`` column).

    Items are ``(image, class_index)`` in train/valid mode and
    ``(image, id_code)`` otherwise.
    """

    # Shape of the all-zero plane substituted for a missing channel file.
    _BLANK_SHAPE = (1, 512, 512)

    def __init__(self, df, img_dir, mode='train', site=1, channels=(1, 2, 3, 4, 5, 6),
                 transform=None, mapping=None):
        self.records = df.to_records(index=False)
        self.channels = list(channels)
        self.mode = mode
        self.img_dir = img_dir
        self.len = df.shape[0]
        self.transform = transform
        if mapping is None:
            # Unlabeled frames (no 'sirna' column) simply get an empty mapping.
            numbers = ([self._sirna_number(s) for s in df['sirna']]
                       if 'sirna' in df.columns else [])
            self.unique_list = np.unique(numbers)
            mapping = {int(val): i for i, val in enumerate(self.unique_list)}
        else:
            # Sirna numbers ordered by their class index.
            self.unique_list = np.array(sorted(mapping, key=mapping.get))
        self.mapping = dict(mapping)

    @staticmethod
    def _sirna_number(sirna):
        """Return the integer in the second '_'-separated field, e.g. 'sirna_42' -> 42."""
        return int(sirna.split("_")[1])

    @staticmethod
    def _load_img_as_tensor(file_name):
        with Image.open(file_name) as img:
            return T.ToTensor()(img)

    def _get_img_path(self, index, channel):
        """Return the PNG path of channel ``channel`` for row ``index``."""
        # Validation rows are a split of the training data, so they live under 'train'.
        folder = 'train' if self.mode == 'valid' else self.mode
        rec = self.records[index]
        return '/'.join([self.img_dir, folder, rec.experiment, f'Plate{rec.plate}',
                         f'{rec.well}_s{rec.site}_w{channel}.png'])

    def _load_channel(self, img_path):
        """Load one channel file, or a zero plane if the file does not exist."""
        if Path(img_path).exists():
            return self._load_img_as_tensor(img_path)
        return torch.zeros(self._BLANK_SHAPE)

    def __getitem__(self, index):
        img = torch.cat([self._load_channel(self._get_img_path(index, ch))
                         for ch in self.channels])
        record = self.records[index]
        if self.mode in ('train', 'valid'):
            if self.transform is not None:
                img = self.transform(img)
            return img, self.mapping[self._sirna_number(record.sirna)]
        # Test mode: no label, return the image identifier instead.
        return img, record.id_code

    def __len__(self):
        return self.len

## Create the train/validation/test data frames using both imaging sites

In [None]:
# Training-time augmentation. RandomCrop gets only the crop size: its second
# positional argument is `padding`, so `T.RandomCrop(384, 384)` would pad every
# border with 384 zero pixels and yield mostly-empty crops.
transforms = T.Compose([
    T.RandomCrop(384),
    T.RandomHorizontalFlip(),
    T.RandomRotation(90),
    T.RandomVerticalFlip()
])


def _sirna_numbers(sirna):
    """Vectorised parse of labels like 'sirna_42' into their integer part (42)."""
    return sirna.str.split("_").str[1].astype(int)


def _with_both_sites(frame):
    """Stack two copies of ``frame``, one with site "1" and one with site "2"."""
    return pd.concat([frame.assign(site="1"), frame.assign(site="2")])


df = pd.read_csv(path_data+'/train.csv')
# Vectorised instead of a row-by-row chained assignment, which pandas may
# silently apply to a temporary copy.
df["label"] = _sirna_numbers(df['sirna'])
# NOTE(review): the model has 1108 output classes; confirm this filter keeps
# exactly 1108 distinct sirna numbers.
wells_df = df.loc[df['label'] <= 1108].drop(['label'], axis=1)
subset_df = _with_both_sites(wells_df)

print(subset_df)

# Split wells first and only then add the two sites, so both images of a well
# always land in the same split (no near-duplicate leakage into val/test).
wells_train, wells_test = train_test_split(wells_df, test_size=0.1, random_state=42)
wells_train, wells_val = train_test_split(wells_train, test_size=0.1, random_state=42)
df_train = _with_both_sites(wells_train)
df_val = _with_both_sites(wells_val)
df_test = _with_both_sites(wells_test)

print(df_train)

ds_train = ImagesDS(df_train, path_data, mode='train', transform=transforms)
ds_val = ImagesDS(df_val, path_data, mode='valid')
ds_test = ImagesDS(df_test, path_data, mode='valid')
# Each ImagesDS derives its label mapping from its own rows; reuse the training
# mapping so a sirna absent from a smaller split cannot shift the class indices.
ds_val.mapping = ds_train.mapping
ds_test.mapping = ds_train.mapping


batch_size = 16

train_loader = D.DataLoader(ds_train, batch_size=batch_size, shuffle=True)
test_loader = D.DataLoader(ds_test, batch_size=batch_size, shuffle=False)
valid_loader = D.DataLoader(ds_val, batch_size=batch_size, shuffle=False)

In [None]:
def validation_accuracy(model, valid_loader, p=False):
    """Return the top-1 accuracy of ``model`` over ``valid_loader``, in percent.

    Args:
        model: classifier producing one row of class scores per sample.
        valid_loader: iterable of ``(inputs, integer targets)`` batches.
        p: unused; kept so existing calls that pass it keep working.

    Returns:
        ``100 * correct / total``, or ``0.0`` if the loader yields no samples.

    The model runs in eval mode without gradients and is restored to its
    previous train/eval mode afterwards.
    """
    # Evaluate on whatever device the model's parameters are on.
    device = next(model.parameters()).device
    was_training = model.training
    total_t = 0
    correct_t = 0
    model.eval()
    try:
        with torch.no_grad():
            for data_t, target_t in tqdm(valid_loader, desc='Batches', leave=False):
                data_t, target_t = data_t.to(device), target_t.to(device)
                _, pred_t = torch.max(model(data_t), dim=1)
                correct_t += torch.sum(pred_t == target_t).item()
                total_t += target_t.size(0)
    finally:
        model.train(was_training)
    return 100 * correct_t / total_t if total_t else 0.0

## Load model

In [None]:

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device is: {device}")

model = models.resnet50(pretrained=True)

# Replace the 3-channel stem with a 6-channel one: every input channel starts
# from the pretrained kernel averaged over its input channels.
channels = 6
trained_kernel = model.conv1.weight
new_conv = nn.Conv2d(channels, 64, 7, 2, 3, bias=False)
with torch.no_grad():
    new_conv.weight[:,:] = torch.stack([torch.mean(trained_kernel, 1)] * channels, dim=1)
model.conv1 = new_conv


### Load model
# strict=False ignores keys that are missing from, or unexpected in, the
# checkpoint; print them so a wrong file or key prefix does not go unnoticed.
# map_location lets a GPU-saved checkpoint load on a CPU-only machine.
state = torch.load('6chResNet50/resnet_100.pt', map_location=device)
incompatible = model.load_state_dict(state, strict=False)
print(f"Missing keys: {incompatible.missing_keys}")
print(f"Unexpected keys: {incompatible.unexpected_keys}")
# No model.eval() here: the next cell trains, so the model stays in training
# mode (BatchNorm must update its running statistics).

classes = 1108
model.fc = torch.nn.Sequential(
    torch.nn.Linear(model.fc.in_features, 1000),
    torch.nn.ReLU(),
    torch.nn.Linear(1000, classes)
)

model = model.to(device)

criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)
optimizer = torch.optim.Adam(model.parameters(), lr=0.00025)

def accuracy(out, labels):
    """Count the samples whose highest-scoring class in ``out`` equals ``labels``."""
    predictions = out.max(dim=1).indices
    matches = predictions == labels
    return matches.sum().item()


## Train the model

In [None]:
# path to save the best model
best_model_path = '6chResNet50/resnet_1108.pt'


# Optional warm-up: for the first `warmup_epochs` epochs only the new `fc`
# head is trained, using the larger batch size; from epoch `warmup_epochs`
# on, every layer is trained with the smaller batch size. 0 disables warm-up.
warmup_epochs = 0
large_batch = 64
small_batch = 20

n_epochs = 200
max_val_acc = 0
val_loss = []
val_acc = []
train_loss = []
train_acc = []


def _set_backbone_trainable(trainable):
    """Set requires_grad on every child module except `fc`, which always trains."""
    for name, child in model.named_children():
        for param in child.parameters():
            param.requires_grad = trainable or name == 'fc'


for epoch in trange(n_epochs, desc='Epochs', leave=False):
    if epoch == 0 and warmup_epochs > 0:
        # Warm-up: only the head trains, so a larger batch fits.
        train_loader = D.DataLoader(ds_train, batch_size=large_batch, shuffle=True)
        _set_backbone_trainable(False)
    if epoch == warmup_epochs:
        # Shrinks the batch size when training all layers
        train_loader = D.DataLoader(ds_train, batch_size=small_batch, shuffle=True)
        print("Turn on all the layers")
        _set_backbone_trainable(True)

    # Validation below switches to eval mode (and the model may arrive in eval
    # mode), so every epoch explicitly trains in training mode.
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for data_, target_ in tqdm(train_loader, desc='Batches', leave=False):
        data_, target_ = data_.to(device), target_.to(device)
        optimizer.zero_grad()

        outputs = model(data_)
        loss = criterion(outputs, target_)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, pred = torch.max(outputs, dim=1)
        correct += torch.sum(pred == target_).item()
        total += target_.size(0)

    train_acc.append(100 * correct / total)
    # Mean batch loss over the loader actually used in this epoch.
    train_loss.append(running_loss / len(train_loader))
    print(f'train-loss: {train_loss[-1]:.4f}, train-acc: {train_acc[-1]:.4f}')

    batch_loss = 0
    total_t = 0
    correct_t = 0
    model.eval()
    with torch.no_grad():
        for data_t, target_t in valid_loader:
            data_t, target_t = data_t.to(device), target_t.to(device)
            outputs_t = model(data_t)
            loss_t = criterion(outputs_t, target_t)
            batch_loss += loss_t.item()
            _, pred_t = torch.max(outputs_t, dim=1)
            correct_t += torch.sum(pred_t == target_t).item()
            total_t += target_t.size(0)
    val_acc.append(100 * correct_t / total_t)
    val_loss.append(batch_loss / len(valid_loader))
    print(f'validation loss: {val_loss[-1]:.4f}, validation acc: {val_acc[-1]:.4f}')

    # Keep a checkpoint of the best validation accuracy seen so far.
    if val_acc[-1] > max_val_acc:
        max_val_acc = val_acc[-1]
        torch.save(model.state_dict(), best_model_path)
        print('Improvement-Detected, save-model')

In [None]:
# Save the final-epoch model and the per-epoch histories (change the paths as
# needed). The best-validation weights, if any epoch improved, were written to
# `best_model_path` during training.
model_dir = Path('6chResNet50')
history_dir = Path('6chResNet50acc')
# Create the output folders if they do not exist yet.
model_dir.mkdir(parents=True, exist_ok=True)
history_dir.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), model_dir / 'resnet_50_1108.pt')
torch.save(val_loss, history_dir / 'val_loss_1108.pt')
torch.save(train_loss, history_dir / 'train_loss_1108.pt')
torch.save(val_acc, history_dir / 'val_acc_1108.pt')
torch.save(train_acc, history_dir / 'train_acc_1108.pt')

## Load the saved model before running the next step!

In [None]:
# Rebuild the training architecture and load saved weights (for example the
# final-epoch file written above; adjust the path as needed).

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device is: {device}")

model = models.resnet50(pretrained=True)

# Same 6-channel stem as used for training.
channels = 6
trained_kernel = model.conv1.weight
new_conv = nn.Conv2d(channels, 64, 7, 2, 3, bias=False)
with torch.no_grad():
    new_conv.weight[:,:] = torch.stack([torch.mean(trained_kernel, 1)] * channels, dim=1)
model.conv1 = new_conv

classes = 1108
model.fc = torch.nn.Sequential(
    torch.nn.Linear(model.fc.in_features, 1000),
    torch.nn.ReLU(),
    torch.nn.Linear(1000, classes)
)

### Load model
# A checkpoint from the training cells above has exactly this architecture, so
# any missing/unexpected key means the wrong file was loaded. Report it instead
# of letting strict=False hide it. map_location allows CPU-only loading.
state = torch.load('6chResNet50/resnet_50_1108.pt', map_location=device)
incompatible = model.load_state_dict(state, strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print(f"WARNING: missing keys {incompatible.missing_keys}, "
          f"unexpected keys {incompatible.unexpected_keys}")
model.eval()


model = model.to(device)


criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)
optimizer = torch.optim.Adam(model.parameters(), lr=0.00025)

def accuracy(out, labels):
    """Return how many row-wise arg-max predictions of ``out`` match ``labels``."""
    hits = torch.eq(torch.argmax(out, dim=1), labels)
    return int(hits.sum())


## Get model accuracies

In [None]:
# NOTE: train_loader's dataset applies the random training augmentation, so the
# training figure is measured on augmented crops.
train_accuracy = validation_accuracy(model, train_loader, True)
valid_accuracy = validation_accuracy(model, valid_loader, True)
test_accuracy = validation_accuracy(model, test_loader, True)

print(f'Train Accuracy: {train_accuracy}\nValidation Accuracy: {valid_accuracy}\nTest Accuracy: {test_accuracy}')

## Make plots (only if the loss/accuracy histories have been saved)

In [None]:
def get_numbers(ch, name, versions):
    """Load the saved history ``name`` for each run in ``versions`` and join them.

    Reads ``<ch>chResNet50acc/<name>_<v>.pt`` for every ``v`` in order and
    returns all values as one flat NumPy array (empty if ``versions`` is empty).
    """
    arrays = [np.array([])]
    arrays.extend(np.array(torch.load(f'{ch}chResNet50acc/{name}_{v}.pt')) for v in versions)
    return np.concatenate(arrays)

def get_breaks(ch, name, versions):
    """Return the running total of history lengths, one entry per version.

    Entry ``i`` is the combined length of the saved histories for
    ``versions[0]`` .. ``versions[i]``, i.e. the x position where run ``i``
    ends once the histories are concatenated.
    """
    lengths = [len(torch.load(f'{ch}chResNet50acc/{name}_{v}.pt')) for v in versions]
    return np.cumsum(np.asarray(lengths, dtype=float))

# Version suffixes of the saved history files to plot, e.g. [1, 2, 3, 4].
versions = [1108]
ch = 6
val_loss = get_numbers(ch, 'val_loss', versions)
train_loss = get_numbers(ch, 'train_loss', versions)
val_acc = get_numbers(ch, 'val_acc', versions)
train_acc = get_numbers(ch, 'train_acc', versions)
# Vertical markers go where one run ends and the next begins; the final
# cumulative length is simply the end of the data, so it is dropped.
breaks = np.delete(get_breaks(ch, 'val_loss', versions), -1)


def _plot_history(train_series, val_series, title, ylabel, run_breaks, test_value=None):
    """Draw training vs. validation curves, optional test line and run markers, then show."""
    plt.plot(train_series, color="blue", label='Training')
    plt.plot(val_series, color="orange", label='Validation')
    if test_value is not None:
        plt.axhline(y=test_value, c='green', linestyle='dashed', label='Final Test', alpha=.45)
    plt.legend()
    for boundary in run_breaks:
        plt.axvline(x=boundary, alpha=0.5)
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.show()


_plot_history(train_loss, val_loss, f"1108 Classes {ch}ch ResNet50 Loss", "Loss", breaks)
# Set test_value to the test accuracy you want drawn as a reference line.
_plot_history(train_acc, val_acc, f"1108 Classes {ch}ch ResNet50 Accuracy", "Accuracy %",
              breaks, test_value=test_accuracy)