In [5]:
from PIL import Image
from torchvision import transforms
from glob import iglob
import torch
from torch import nn
from torchvision import models
import os

In [6]:
class ResNet50_feat(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        resnet_50 = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        self.extractor = nn.Sequential(*(list(resnet_50.children())[:-1]))

    def forward(self, x):
        feat = self.extractor(x)
        feat = feat.transpose(1, 3)
        feat = nn.Linear(2048, 1024)(feat)
        return feat


In [7]:
model = ResNet50_feat()
model.eval()
# Maybe need normalization
transform = transforms.Compose([transforms.ToTensor()])
for folder in sorted(iglob('../BCNB_Dataset/patches/*'),key=lambda x:int(os.path.split(x)[1])):
    idx = os.path.split(folder)[1]
    print(idx,end='\t')
    img_list = []
    
    # each patch
    for pth in iglob(folder+'/*.jpg'):
        img = Image.open(pth)
        img_list.append(transform(img).unsqueeze(0))
    patch_tensor = torch.cat(img_list,dim=0)
    feat = model(patch_tensor).squeeze(1).transpose(0,1)
    # print(feat.shape,feat)
    feat = feat.clone().detach().requires_grad_(False)
    feat = feat.data
    print(feat.shape)
    # print(feat.shape,feat)
    break
    torch.save(feat,'TransMIL/pt_files/'+idx+'.pt')
    # with open('BCNB_Dataset/pt_files/'+idx+'.pt','wb') as f:
    #     pickle.dump(feat,f)

1	torch.Size([1, 26, 1024])


#### For label

In [47]:
import pandas as pd
import numpy as np

df = pd.read_csv('TransMIL/dataset_csv/bcnb/fold0.csv')
df.train_label = df.train_label.astype(int)
df.fillna(-999999, inplace=True)
df.val = df.val.astype(int)
df.val_label = df.val_label.astype(int)
df.test = df.test.astype(int)
df.test_label = df.test_label.astype(int)
# df.replace(-999999, np.nan,inplace=True)

df.to_csv('TransMIL/dataset_csv/bcnb/fold0.csv')
df.head(10000)


Unnamed: 0.2,Unnamed: 0,Unnamed: 0.1,Unnamed: 0.1.1,train,train_label,val,val_label,test,test_label
0,0,0,0,71,1,4,0,981,1
1,1,1,1,910,0,800,1,483,1
2,2,2,2,167,1,1034,1,1056,1
3,3,3,3,616,1,418,1,488,1
4,4,4,4,503,1,587,0,716,1
...,...,...,...,...,...,...,...,...,...
625,625,625,625,711,1,-999999,-999999,-999999,-999999
626,626,626,626,774,0,-999999,-999999,-999999,-999999
627,627,627,627,346,1,-999999,-999999,-999999,-999999
628,628,628,628,228,1,-999999,-999999,-999999,-999999


In [8]:
# modified from Pytorch official resnet.py
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torch
import torch.nn.functional as F

__all__ = ["ResNet", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152"]

model_urls = {
    "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
    "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
    "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
    "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
    "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
}


class Bottleneck_Baseline(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck_Baseline, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(
            planes, planes * self.expansion, kernel_size=1, bias=False
        )
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet_Baseline(nn.Module):
    def __init__(self, block, layers):
        self.inplanes = 64
        super(ResNet_Baseline, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)

        return x


def load_pretrained_weights(model, name):
    pretrained_dict = model_zoo.load_url(model_urls[name])
    model.load_state_dict(pretrained_dict, strict=False)
    return model

def resnet50_baseline(pretrained=False):
    """Constructs a Modified ResNet-50 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet_Baseline(Bottleneck_Baseline, [3, 4, 6, 3])
    if pretrained:
        model = load_pretrained_weights(model, "resnet50")
    return model


In [15]:
model = resnet50_baseline(pretrained=True)
model.eval()
# Maybe need normalization
transform = transforms.Compose([transforms.ToTensor()])
for folder in sorted(iglob('../BCNB_Dataset/patches/*'),key=lambda x:int(os.path.split(x)[1])):
    idx = os.path.split(folder)[1]
    print(idx,end='\t')
    img_list = []
    
    # each patch
    for pth in iglob(folder+'/*.jpg'):
        img = Image.open(pth)
        img_list.append(transform(img).unsqueeze(0))
    patch_tensor = torch.cat(img_list,dim=0)
    feat = model(patch_tensor)
    # print(feat.shape,feat)
    feat = feat.clone().detach().requires_grad_(False)
    feat = feat.data
    print(feat.shape)
    # print(feat.shape,feat)
    torch.save(feat,'pt_files/'+idx+'.pt')
    # with open('BCNB_Dataset/pt_files/'+idx+'.pt','wb') as f:
    #     pickle.dump(feat,f)

1	torch.Size([26, 1024])
2	torch.Size([16, 1024])
3	torch.Size([75, 1024])
4	torch.Size([27, 1024])
5	torch.Size([70, 1024])
6	torch.Size([240, 1024])
7	torch.Size([18, 1024])
8	torch.Size([21, 1024])
9	torch.Size([20, 1024])
10	torch.Size([40, 1024])
11	torch.Size([26, 1024])
12	torch.Size([10, 1024])
13	torch.Size([23, 1024])
14	torch.Size([11, 1024])
15	torch.Size([21, 1024])
16	torch.Size([115, 1024])
17	torch.Size([26, 1024])
18	torch.Size([22, 1024])
19	torch.Size([58, 1024])
20	torch.Size([1089, 1024])
21	torch.Size([22, 1024])
22	torch.Size([68, 1024])
23	torch.Size([61, 1024])
24	torch.Size([10, 1024])
25	torch.Size([39, 1024])
26	torch.Size([32, 1024])
27	torch.Size([73, 1024])
28	torch.Size([34, 1024])
29	torch.Size([31, 1024])
30	torch.Size([42, 1024])
31	torch.Size([154, 1024])
32	torch.Size([26, 1024])
33	torch.Size([77, 1024])
34	torch.Size([58, 1024])
35	torch.Size([18, 1024])
36	torch.Size([111, 1024])
37	torch.Size([272, 1024])
38	torch.Size([53, 1024])
39	torch.Size(