# 分类器

在用AutoEncoder预训练之后，Encoder已经可以对图像进行编码。这个时候是否可以回归到分类有监督学习？  
有没有方法减少标注数量，做渐进式学习？  

考虑如下分类：  
1. 杂质  
2. 血细胞  
3. 染色体  

这些分类是互斥的，使用Softmax激活。  
对于混合样本，直接丢弃。  
这里要利用Encoder+特征工程对少例分类（血细胞和杂质）进行挖掘，尽量使得分类标注比较均匀。  


### 思路一
首先考虑锁定Encoder参数，只训练分类器，同时先标注简单样本，用与Softmax配套的CrossEntropy Loss函数（各分类互斥，而非BinaryCrossEntropy），从而建立起最初的特征-分类映射，作为训练的Baseline。  

此时网络在2个方面有所欠缺：  
1. 对于难样本缺乏区分能力  
2. Embedding本身的分布可能不够理想  

进一步考虑利用Baseline模型，选取难样本进行标注。同时放开Encoder参数，利用Center或Margin Loss函数，训练的目标是改善Embedding分布，提升难样本的区分能力。  
由于这个任务中，有些小染色体与杂质是比较像的，标注可能噪声较高，就不使用Focal Loss对难例进行加强训练。  

如果顺利的话，此时应该就获得了比较理想的分类模型。  
有没有办法进一步对Embedding进行监督，提高Embedding的分布质量？  
有没有办法实现增强学习？  

### 思路二

考虑直接在Embedding上进行训练，即利用Triplet Margin Loss直接计算样本对的距离。  
由于是多分类，三元组中anchor与positive属于同一类、negative属于另一类，3种分类就有3×2=6种（anchor类别, negative类别）组合，需要尽量平均生成。  
这样可以改善Embedding的分布。  

猜测此时Embedding会呈现聚类效应，难例分布在概率密度较小的区域。  
由于没有加入FC等分类头，需要在Embedding训练完成后，训练GMM等聚类模型，用于分类预测。  


In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import sys
import os
import pickle
import argparse
import itertools
from datetime import datetime
import gc

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.nn.init as init
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms

import cv2
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
from tqdm import tqdm
import multiprocessing as mp
from tensorboardX import SummaryWriter
import pandas as pd

from datasets.simple import *
from resnet import *
from transforms import *
from plot import *
from autoencoder import *

In [2]:
# Run configuration shared by the cells below.
batch_size = 32
device = torch.device('cuda:0')

# img_size is the side length passed to PadOrCrop; target_size is not used
# in the cells shown here.
img_size = target_size = 256

# Annotation CSVs (columns: path, classNo); their rows are concatenated.
anno_paths = ['./cluster.csv']

# Directory the chunk images are read from.
data_root = '/mnt/nvme/data/chromosome/neg-chunk'

In [3]:
# triplet chunk dataset

# Paths in the csv may be absolute, but only their file name is used: images
# are loaded from `data_root` (the module-level global unless overridden).
# Multiple anno paths (csv) are supported; their rows are combined into one
# large anno dataframe.
# csv format: path, classNo

class TripletChunkDataset(Dataset):
    """Dataset yielding ``(anchor, positive, negative)`` image triplets.

    Item ``index`` uses row ``index`` of the combined annotation table as the
    anchor, a random row with the same ``classNo`` (possibly the anchor row
    itself) as the positive, and a random row with a different ``classNo`` as
    the negative. Each of the three images goes through ``transform``
    independently.

    Random choices use torch's RNG rather than numpy's global state (which
    ``DataFrame.sample`` uses by default). DataLoader reseeds torch in every
    worker process. On older PyTorch releases numpy's global state is copied
    unchanged into each worker, so all workers would draw identical
    positive/negative sequences, repeated every epoch.
    """

    def __init__(
        self,
        anno_paths,
        transform=None,
        data_root=None
    ):
        """
        Args:
            anno_paths: iterable of csv paths with columns ``path`` and
                ``classNo``.
            transform: optional callable applied to every loaded PIL image.
            data_root: directory containing the images. ``None`` (default)
                falls back to the module-level ``data_root`` at load time.

        Raises:
            ValueError: from ``pd.concat`` when ``anno_paths`` is empty.
        """
        self.anno_paths = anno_paths
        self.transform = transform
        self.data_root = data_root

        frames = [pd.read_csv(anno_path) for anno_path in self.anno_paths]
        # ignore_index: give the combined table a unique 0..N-1 index instead
        # of repeating each csv's own row numbers.
        self.anno_df = pd.concat(frames, axis=0, ignore_index=True)
        print(self.anno_df.head())
        self.total_len = len(self.anno_df)

        self._index_classes()

    def _index_classes(self):
        """Precompute, per class label, row positions with/without that label."""
        labels = self.anno_df['classNo'].values
        self._labels = labels
        self._same_class = {}
        self._other_class = {}
        for label in np.unique(labels):
            mask = labels == label
            self._same_class[label] = np.flatnonzero(mask)
            self._other_class[label] = np.flatnonzero(~mask)

    @staticmethod
    def _pick(positions):
        """Return one element of ``positions`` chosen uniformly with torch's RNG."""
        choice = torch.randint(len(positions), (1,)).item()
        return int(positions[choice])

    def _sample_triplet_positions(self, index):
        """Return row positions ``(anchor, positive, negative)`` for ``index``.

        Raises:
            ValueError: if no row has a class different from the anchor's.
        """
        anchor_class = self._labels[index]
        negatives = self._other_class[anchor_class]
        if len(negatives) == 0:
            raise ValueError(
                'cannot build a triplet: every row has classNo {}'.format(anchor_class)
            )
        positives = self._same_class[anchor_class]  # always contains the anchor
        return index, self._pick(positives), self._pick(negatives)

    def __len__(self):
        return self.total_len

    def __getitem__(self, index):
        """Load and transform the anchor/positive/negative images for ``index``."""
        root = self.data_root if self.data_root is not None else data_root
        triplet = []

        for position in self._sample_triplet_positions(index):
            row_path = self.anno_df['path'].iat[position]
            # only the file name is kept; the image is read from `root`
            file_path = os.path.join(root, os.path.basename(row_path))

            row_img = Image.open(file_path)

            if self.transform is not None:
                row_img = self.transform(row_img)

            triplet.append(row_img)

        return tuple(triplet)

In [4]:
# Augmentation pipeline; the dataset applies it independently to each image
# of a triplet.
transform = transforms.Compose(
    [
        # random flips plus mild brightness/contrast jitter
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ColorJitter(brightness=0.1, contrast=0.1),
        # project helper; presumably pads or crops to img_size x img_size — confirm
        PadOrCrop(img_size),
        # rotation in [-30, 30] degrees, shift up to 20%; uncovered area filled with 255
        transforms.RandomAffine(
            30,
            translate=(0.2, 0.2),
            resample=Image.BILINEAR,
            fillcolor=255,
        ),
        transforms.ToTensor(),
        # project helper; presumably expands a single channel to 3 — confirm
        ChannelExpand(),
    ]
)

triplet_dataset = TripletChunkDataset(anno_paths=anno_paths, transform=transform)

# anchors are visited in a fresh random order each epoch
triplet_loader = DataLoader(
    dataset=triplet_dataset,
    batch_size=batch_size,
    num_workers=8,
    shuffle=True,
)

                                                path  classNo
0  /media/ssd-ext4/neg-chunk/L1903042920.046.A_8.jpg        0
1  /media/ssd-ext4/neg-chunk/L1903083406.078.A_11...        0
2  /media/ssd-ext4/neg-chunk/L1903113536.040.A_16...        0
3  /media/ssd-ext4/neg-chunk/L1903123563.277.A_35...        0
4  /media/ssd-ext4/neg-chunk/L1903123702.038.A_17...        0


In [5]:
# create an embedding resnet

class EmbeddingNet(nn.Module):
    """Wrap a torchvision-style ResNet and return its pooled features.

    ``forward`` runs the stem, the four residual stages and ``avgpool`` of the
    wrapped network, then flattens everything after the batch dimension. The
    final ``fc`` layer is never called, so the output is the feature vector
    that would normally feed the classifier.
    """

    def __init__(self, resnet):
        super().__init__()
        self.resnet = resnet

    def forward(self, x):
        net = self.resnet
        pipeline = (
            net.conv1, net.bn1, net.relu, net.maxpool,
            net.layer1, net.layer2, net.layer3, net.layer4,
            net.avgpool,
        )
        for stage in pipeline:
            x = stage(x)
        # keep dim 0 (batch), merge the rest into one feature axis
        return torch.flatten(x, 1)

# ImageNet-pretrained ResNet-34 as the embedding backbone; EmbeddingNet only
# uses its layers up to avgpool.
resnet = models.resnet34(pretrained=True)
model = EmbeddingNet(resnet).to(device)

In [6]:
# Optimiser hyper-parameters.
learning_rate, weight_decay = 1e-3, 1e-5

# Triplet margin loss on the raw embedding vectors (default p=2 distance);
# the margin is expressed in the same units as those distances.
criterion = nn.TripletMarginLoss(margin=30.)
optimizer = optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)

In [7]:
# train
# Each epoch makes one pass over triplet_loader, logs the triplet loss to
# TensorBoard and saves the model's state_dict to ./models.

train_id = 1
epoches = 10
iter_count = len(triplet_loader)

checkpoint_dir = './models'
os.makedirs(checkpoint_dir, exist_ok=True)

writer = SummaryWriter()
model.train()

try:
    for epoch in range(epoches):
        print('epoch: {}/{}'.format(epoch+1, epoches))

        with tqdm(total=iter_count, file=sys.stdout) as pbar:
            for iter_no, triplet in enumerate(triplet_loader):
                anchors, positives, negatives = triplet

                anchors = anchors.to(device)
                positives = positives.to(device)
                negatives = negatives.to(device)

                # Clear gradients from the previous step: backward() accumulates
                # into .grad, so without this every update would use the running
                # sum of all past gradients.
                optimizer.zero_grad()

                anchor_embeddings = model(anchors)
                positive_embeddings = model(positives)
                negative_embeddings = model(negatives)

                loss = criterion(
                    anchor_embeddings,
                    positive_embeddings,
                    negative_embeddings
                )

                loss.backward()
                optimizer.step()

                writer.add_scalar(
                    'train/loss',
                    loss.item(),
                    epoch*iter_count+iter_no
                )

                pbar.update(1)

        torch.save(
            model.state_dict(),
            os.path.join(checkpoint_dir, 'EmbeddingNet-{}-{}.pth'.format(train_id, epoch))
        )
finally:
    # flush pending TensorBoard events even if training is interrupted
    writer.close()

epoch: 1/10
100%|██████████| 260/260 [01:41<00:00,  2.56it/s]
epoch: 2/10
100%|██████████| 260/260 [01:40<00:00,  2.58it/s]
epoch: 3/10
100%|██████████| 260/260 [01:40<00:00,  2.60it/s]
epoch: 4/10
100%|██████████| 260/260 [01:40<00:00,  2.60it/s]
epoch: 5/10
100%|██████████| 260/260 [01:40<00:00,  2.60it/s]
epoch: 6/10
100%|██████████| 260/260 [01:39<00:00,  2.60it/s]
epoch: 7/10
100%|██████████| 260/260 [01:40<00:00,  2.60it/s]
epoch: 8/10
100%|██████████| 260/260 [01:39<00:00,  2.60it/s]
epoch: 9/10
100%|██████████| 260/260 [01:39<00:00,  2.60it/s]
epoch: 10/10
100%|██████████| 260/260 [01:39<00:00,  2.60it/s]
