In [123]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from skimage import io as image_io
from torch import optim

In [96]:
class TCGABags(Dataset):
    def __init__(self, df, labels, bag_len=10, mean_bag_length=10, var_bag_length=1, num_bag=1000, seed=7, train=True):
        self.mean_bag_length = mean_bag_length
        self.var_bag_length = var_bag_length
        self.num_bag = num_bag
        self.seed = seed
        self.train = train
        self.df = df
        self.idx = df.index.get_level_values(0).unique()
        self.labels = labels
        self.bag_len = bag_len

        self.r = np.random.RandomState(seed)
        
    def _form_bags(self, df: pd.Series):
        bags_list = []
        labels_list = []

        for i in df.index.get_level_values(0).unique():
            # labels_list.append()
            print(df.loc[i].iloc[0])
            df.loc[i].sample()
            for j in df.loc[i].index:
                print(j)
                img = image_io.imread(j)
            print(df.loc[i])

        return bags_list, labels_list

    def __len__(self):
        return len(self.idx)

    def __getitem__(self, index):
        batch = self.df.loc[self.idx[index]].sample(self.bag_len, replace=True)
        label = self.labels.index(batch.iloc[0, 0])
        # print(batch.shape[0])
        res = torch.empty((10, 224, 224, 3))
        for i, fname in enumerate(batch.index):
            # print(fname)
            img = image_io.imread(fname)
            res[i] = torch.from_numpy(img)
            # print(img.shape)
        return res, label

In [30]:
df = pd.read_csv('/mnt/d/uczelnia/magister/TCGA_breast_patch/metadata.csv', index_col=0)
df['patient_id'] = df['file_name'].apply(lambda x: x.split('/')[6])
df.set_index(['patient_id', 'file_name'], inplace=True)

In [61]:
df.loc[df.index.get_level_values(0).unique()[97]].iloc[0, 0]

'hr+her2-'

In [131]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    def __init__(self):
        super(Attention, self).__init__()
        self.M = 500
        self.L = 128
        self.ATTENTION_BRANCHES = 1

        self.feature_extractor_part1 = nn.Sequential(
            nn.Conv2d(3, 20, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(20, 50, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2)
        )

        self.feature_extractor_part2 = nn.Sequential(
            nn.Linear(50 * 53 * 53, self.M),
            nn.ReLU(),
        )

        self.attention = nn.Sequential(
            nn.Linear(self.M, self.L),  # matrix V
            nn.Tanh(),
            # matrix w (or vector w if self.ATTENTION_BRANCHES==1)
            nn.Linear(self.L, self.ATTENTION_BRANCHES)
        )

        self.classifier = nn.Sequential(
            nn.Linear(self.M*self.ATTENTION_BRANCHES, 3),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.squeeze(0)
        H = self.feature_extractor_part1(x)
        print(H.shape)
        H = H.reshape(-1, 50 * 53 * 53)
        H = self.feature_extractor_part2(H)  # KxM

        A = self.attention(H)  # KxATTENTION_BRANCHES
        A = torch.transpose(A, 1, 0)  # ATTENTION_BRANCHESxK
        A = F.softmax(A, dim=1)  # softmax over K

        Z = torch.mm(A, H)  # ATTENTION_BRANCHESxM

        Y_prob = self.classifier(Z)
        Y_hat = torch.ge(Y_prob, 0.5).float()

        return Y_prob, Y_hat, A

    # AUXILIARY METHODS
    def calculate_classification_error(self, X, Y):
        Y = Y.float()
        Y_Prob, Y_hat, _ = self.forward(X)
        print(Y_Prob, Y, Y_hat)
        error = 1. - Y_hat.eq(Y).cpu().float().mean().data.item()

        return error, Y_hat

    def calculate_objective(self, X, Y):
        Y = Y.float()
        Y_prob, _, A = self.forward(X)
        Y_prob = torch.clamp(Y_prob, min=1e-5, max=1. - 1e-5)
        neg_log_likelihood = -1. * \
            (Y * torch.log(Y_prob) + (1. - Y) *
             torch.log(1. - Y_prob))  # negative log bernoulli

        return neg_log_likelihood, A


In [132]:
dl = DataLoader(TCGABags(df, ['hr+her2-', 'hr-', 'hr+her2+']), batch_size=1)
model = Attention().to('cuda')
optimizer = optim.Adam(model.parameters(), lr=1e-3,
                       betas=(0.9, 0.999))

model.train()
train_loss = 0.
train_error = 0.
for i, (data, bag_label) in enumerate(dl):
    print(data.shape)
    bag_label = bag_label.to('cuda')
    p = data.permute(0, 1, 4, 2, 3).to('cuda')
    optimizer.zero_grad()
    # calculate loss and metrics
    loss, _ = model.calculate_objective(p, bag_label)
    train_loss += loss.data[0]
    print(bag_label)
    error, _ = model.calculate_classification_error(p, bag_label)
    train_error += error
    
    loss.backward()
        # step
    optimizer.step()
    # backward pass

torch.Size([1, 10, 224, 224, 3])
torch.Size([10, 50, 53, 53])
tensor([2], device='cuda:0')
torch.Size([10, 50, 53, 53])
tensor([[0.9962, 0.0739, 0.8012]], device='cuda:0', grad_fn=<SigmoidBackward0>) tensor([2.], device='cuda:0') tensor([[1., 0., 1.]], device='cuda:0')


RuntimeError: grad can be implicitly created only for scalar outputs

In [17]:
image_io.imread('/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R1/patch_9_7.png')

array([[[183, 131, 171],
        [200, 136, 173],
        [169, 131, 166],
        ...,
        [192, 171, 195],
        [232, 236, 244],
        [254, 237, 227]],

       [[171, 146, 184],
        [172, 117, 154],
        [ 86,  67, 106],
        ...,
        [133,  95, 142],
        [197, 192, 221],
        [243, 223, 220]],

       [[173, 130, 175],
        [152,  89, 137],
        [ 43,  24,  71],
        ...,
        [100,  52, 101],
        [128,  87, 153],
        [217, 210, 215]],

       ...,

       [[ 68,  52,  99],
        [ 94,  60, 116],
        [105,  54, 120],
        ...,
        [169, 116, 166],
        [184, 154, 184],
        [198, 190, 213]],

       [[100,  68, 125],
        [112,  87, 134],
        [105,  76, 124],
        ...,
        [206, 173, 207],
        [210, 173, 190],
        [220, 202, 214]],

       [[150, 101, 163],
        [136,  99, 148],
        [110,  68, 124],
        ...,
        [210, 180, 206],
        [201, 155, 185],
        [212, 179, 203]]

In [18]:
df.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,label
patient_id,file_name,Unnamed: 2_level_1
TCGA-3C-AALI-01A-01-T_R1,/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R1/patch_10_10.png,hr+her2+
TCGA-3C-AALI-01A-01-T_R1,/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R1/patch_10_11.png,hr+her2+
TCGA-3C-AALI-01A-01-T_R1,/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R1/patch_10_12.png,hr+her2+
TCGA-3C-AALI-01A-01-T_R1,/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R1/patch_10_13.png,hr+her2+
TCGA-3C-AALI-01A-01-T_R1,/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R1/patch_10_14.png,hr+her2+


label    hr+her2+
Name: /mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R1/patch_10_10.png, dtype: object
/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R1/patch_10_10.png
/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R1/patch_10_11.png
/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R1/patch_10_12.png
/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R1/patch_10_13.png
/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R1/patch_10_14.png
/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R1/patch_10_15.png
/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R1/patch_10_16.png
/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R1/patch_10_17.png
/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R1/patch_10_18.png
/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R1/patch_10_19.png
/mnt/d/uczelnia/magister/TCGA_breast_

KeyboardInterrupt: 

In [24]:
df.loc['TCGA-3C-AALI-01A-01-T_R1'].sample(10, random_state=10)#.loc['/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R1/patch_10_10.png']

Unnamed: 0_level_0,label
file_name,Unnamed: 1_level_1
/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R1/patch_5_14.png,hr+her2+
/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R1/patch_13_15.png,hr+her2+
/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R1/patch_7_6.png,hr+her2+
/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R1/patch_5_18.png,hr+her2+
/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R1/patch_3_16.png,hr+her2+
/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R1/patch_8_11.png,hr+her2+
/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R1/patch_10_20.png,hr+her2+
/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R1/patch_11_17.png,hr+her2+
/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R1/patch_9_15.png,hr+her2+
/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R1/patch_13_13.png,hr+her2+


In [26]:
df.loc['TCGA-3C-AALI-01A-01-T_R2'].sample(10, random_state=10)

Unnamed: 0_level_0,label
file_name,Unnamed: 1_level_1
/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R2/patch_5_11.png,hr+her2+
/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R2/patch_9_6.png,hr+her2+
/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R2/patch_12_7.png,hr+her2+
/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R2/patch_7_15.png,hr+her2+
/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R2/patch_9_17.png,hr+her2+
/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R2/patch_4_15.png,hr+her2+
/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R2/patch_6_17.png,hr+her2+
/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R2/patch_3_12.png,hr+her2+
/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R2/patch_5_15.png,hr+her2+
/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R2/patch_7_11.png,hr+her2+


In [8]:
df.index.get_level_values(0).unique()

Unnamed: 0_level_0,Unnamed: 1_level_0,label
patient_id,file_name,Unnamed: 2_level_1
TCGA-3C-AALI-01A-01-T_R1,/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R1/patch_10_10.png,hr+her2+
TCGA-3C-AALI-01A-01-T_R1,/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R1/patch_10_11.png,hr+her2+
TCGA-3C-AALI-01A-01-T_R1,/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R1/patch_10_12.png,hr+her2+
TCGA-3C-AALI-01A-01-T_R1,/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R1/patch_10_13.png,hr+her2+
TCGA-3C-AALI-01A-01-T_R1,/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-3C-AALI-01A-01-T_R1/patch_10_14.png,hr+her2+
...,...,...
TCGA-S3-AA17-01A-01-T_R2,/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-S3-AA17-01A-01-T_R2/patch_9_5.png,hr+her2-
TCGA-S3-AA17-01A-01-T_R2,/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-S3-AA17-01A-01-T_R2/patch_9_6.png,hr+her2-
TCGA-S3-AA17-01A-01-T_R2,/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-S3-AA17-01A-01-T_R2/patch_9_7.png,hr+her2-
TCGA-S3-AA17-01A-01-T_R2,/mnt/d/uczelnia/magister/TCGA_breast_patch/TCGA-S3-AA17-01A-01-T_R2/patch_9_8.png,hr+her2-
