# Triplet Evaluation  

评价Triplet Loss表示训练的结果。  
1. 查看标签的Embedding分布  
2. 查看图片的Embedding分布  
3. GMM的分布  


In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib widget

import sys
import os
import pickle
import argparse
import itertools
from datetime import datetime
import gc
import csv

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.nn.init as init
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import Dataset

from apex import amp
import cv2
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
from tqdm import tqdm
import multiprocessing as mp
from tensorboardX import SummaryWriter
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis as QDA
from sklearn.manifold import TSNE
from sklearn import mixture
from sklearn.utils.fixes import logsumexp
from mpl_toolkits.mplot3d import Axes3D
import matplotlib

from datasets.utils import BalancedBatchSampler
from datasets.simple import *
from utils import AllTripletSelector, HardestNegativeTripletSelector, RandomNegativeTripletSelector, SemihardNegativeTripletSelector # Strategies for selecting triplets within a minibatch
from losses import OnlineTripletLoss
from metrics import AverageNonzeroTripletsMetric
from resnet import *
from transforms import *
from plot import *
from autoencoder import *
from network import *


In [2]:
# Constants

# Margin value handed to EmbeddingNet when the model is built.
margin = 1.0

# Class name -> integer id; ids follow the order of the names listed here.
class_mapping = {
    name: index
    for index, name in enumerate(('chromosome', 'cell', 'impurity'))
}
n_classes = len(class_mapping)

In [3]:
# config

# Batch size handed to the evaluation DataLoader. No balanced batch sampler is
# used in this notebook, so batches really contain up to 128 samples.
batch_size = 128

# Prefer the second GPU (the original setting), but fall back to the first GPU
# or to the CPU so the notebook still runs on hosts with fewer devices.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Side length passed to PadOrCrop in the validation transform.
img_size = 256

# Dataset location: root directory, annotation CSV files and image folder,
# all forwarded to ChunkDataset.
data_root = '/home/xd/data/chromosome'
anno_paths = [
    'anno_round-1.csv',
    'anno_round-2.csv',
]
img_path = 'neg-chunk'

# Trained EmbeddingNet weights to evaluate.
checkpoint = './models/EmbeddingNet-2-15.pth'

In [4]:
# Build the embedding network on a ResNet-34 backbone and restore its weights.

# NOTE(review): pretrained=True fetches ImageNet weights that the checkpoint
# loaded just below presumably overwrites -- confirm before dropping it.
resnet = models.resnet34(pretrained=True)
model = EmbeddingNet(resnet, margin)

# Deserialize onto the CPU first so loading does not depend on the GPU the
# checkpoint was saved from, then move the whole module to the target device.
model.load_state_dict(
    torch.load(checkpoint, map_location='cpu')
)
model.to(device)

In [5]:
# Validation data: transform pipeline, dataset and loader.

# Each image goes through PadOrCrop(img_size), ToTensor and ChannelExpand,
# in that order.
val_transform = transforms.Compose(
    [PadOrCrop(img_size), transforms.ToTensor(), ChannelExpand()]
)

val_dataset = ChunkDataset(
    data_root,
    img_path,
    anno_paths,
    class_mapping,
    transform=val_transform,
)

# shuffle=False keeps the loader order fixed from run to run.
val_loader = torch.utils.data.DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
)

4275
3732
                    filename     class
0  L1903012841.010.A_100.jpg  impurity
1    L1903012841.060.A_2.jpg  impurity
2   L1903012841.060.A_31.jpg  impurity
3   L1903012841.060.A_58.jpg  impurity
4    L1903012841.063.A_0.jpg  impurity


In [6]:
# Number of cluster ids to iterate over, and an endless palette to colour them.
cluster_num = 5
color_iter = itertools.cycle(['navy', 'c', 'cornflowerblue', 'gold', 'darkorange'])

def draw_clusters_2d(clusters, cluster_num, embeddings, colors):
    """Scatter-plot the first two embedding dimensions, one colour per cluster.

    Args:
        clusters: per-sample cluster ids; compared element-wise against each
            id in ``range(cluster_num)``.
        cluster_num: number of cluster ids to draw (ids ``0 .. cluster_num-1``).
        embeddings: 2-D array with one row per sample; only columns 0 and 1
            are plotted.
        colors: iterable of matplotlib colours; one colour is taken per
            cluster id.
    """
    points = embeddings[:, :2]

    _, ax = plt.subplots(figsize=(8, 6))

    for cluster_id, color in zip(range(cluster_num), colors):
        # Keep only the rows whose label equals the current cluster id.
        members = np.compress(clusters == cluster_id, points, axis=0)
        ax.scatter(members[:, 0], members[:, 1], c=color)

    plt.show()

In [7]:
cluster_num = 5
color_iter = itertools.cycle(['navy', 'c', 'cornflowerblue', 'gold', 'darkorange'])

def draw_clusters_3d(clusters, cluster_num, embeddings):
    """Scatter-plot the first three embedding dimensions, coloured by cluster.

    Args:
        clusters: per-sample cluster ids, used directly as colour values
            through the ``Set1`` colormap.
        cluster_num: unused; kept so the signature stays parallel to
            ``draw_clusters_2d``.
        embeddings: 2-D array with one row per sample and at least three
            columns; columns 0, 1 and 2 become the X, Y and Z coordinates.
    """
    # A fresh figure per call (as draw_clusters_2d does) so that re-running
    # the cell does not stack new axes on top of an old figure.
    fig = plt.figure(figsize=(8, 6))
    # Create the 3-D axes through the figure: constructing Axes3D(fig)
    # directly no longer attaches the axes to the figure in current
    # matplotlib releases, which leaves the plot empty.
    ax = fig.add_subplot(1, 1, 1, projection='3d')

    ax.scatter(
        embeddings[:, 0],
        embeddings[:, 1],
        embeddings[:, 2],
        c=clusters,
        cmap=plt.cm.Set1,
        edgecolor='k',
        s=40
    )
    ax.set_title("draw_clusters_3d")
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")

    plt.show()

In [8]:
# Side length, in pixels, of the square canvas the thumbnails are pasted onto.
PIC_SIZE = 6400

def draw_pics(embeddings, filenames):
    """Paste each image onto a white canvas at its 2-D embedding position.

    The first two embedding coordinates are min-max normalised to the canvas.
    Coordinate 0 runs along the image rows (top to bottom) and coordinate 1
    along the columns, so the picture is laid out row-major rather than like
    an x/y scatter plot. The result is shown with matplotlib and written to
    ``pic.jpg`` in the current working directory.

    Args:
        embeddings: array-like with one row per image; only the first two
            columns are used.
        filenames: paths readable by ``cv2.imread``, aligned with
            ``embeddings``.

    Images that cannot be read are skipped (a count is printed), and images
    that would extend past the canvas edge are not drawn.
    """
    fig = plt.figure(figsize=(32, 32))
    ax = fig.add_subplot(1, 1, 1)

    # uint8 keeps the canvas at ~120 MB (the default integer dtype needs
    # eight times that) and is a pixel type cv2.imwrite can encode.
    img = np.full((PIC_SIZE, PIC_SIZE, 3), 255, dtype=np.uint8)

    coords = np.asarray(embeddings)
    if coords.size:
        # Use only the first two dimensions, as draw_clusters_2d does.
        coords = coords[:, :2]
        mins = coords.min(axis=0)
        spans = coords.max(axis=0) - mins
        # A degenerate axis (all values equal) would divide by zero.
        spans = np.where(spans == 0, 1, spans)
    else:
        # Nothing to place; the loop below will not run.
        coords = coords.reshape(0, 2)
        mins = spans = np.ones(2)

    unreadable = 0
    for point, filename in zip(coords, filenames):
        roi_img = cv2.imread(filename)
        if roi_img is None:
            # cv2.imread signals a missing or undecodable file with None.
            unreadable += 1
            continue
        # ndarray shape is (rows, cols, channels).
        rows, cols = roi_img.shape[:2]

        top = int(((point[0] - mins[0]) / spans[0]) * PIC_SIZE)
        left = int(((point[1] - mins[1]) / spans[1]) * PIC_SIZE)

        # Only draw thumbnails that fit completely inside the canvas.
        if (top + rows <= PIC_SIZE) and (left + cols <= PIC_SIZE):
            img[top:top + rows, left:left + cols, ...] = roi_img

    if unreadable:
        print('draw_pics: skipped', unreadable, 'unreadable image(s)')

    # OpenCV stores pixels as BGR while matplotlib expects RGB, so flip the
    # channel axis for display only; cv2.imwrite wants the BGR original.
    ax.imshow(img[..., ::-1])
    plt.show()
    cv2.imwrite('pic.jpg', img)

In [9]:
# get embeddings

# Put the network into inference mode before evaluating. torch.no_grad() only
# disables gradient tracking; without eval() the BatchNorm layers of the
# ResNet-34 backbone would normalise with per-batch statistics (and update
# their running averages), making the embeddings depend on batch composition.
model.eval()

epoch_logits = []

with torch.no_grad():
    with tqdm(total=len(val_loader), file=sys.stdout) as pbar:
        for imgs, labels in val_loader:
            imgs = imgs.to(device)

            # Move each batch of embeddings back to the CPU right away so GPU
            # memory does not grow with the size of the dataset.
            logits = model(imgs)
            logits = logits.detach().cpu()

            epoch_logits.append(logits)

            pbar.update(1)

# One tensor holding every sample's embedding, in loader order.
epoch_logits = torch.cat(epoch_logits)

100%|██████████| 30/30 [00:20<00:00,  1.45it/s]


In [10]:
# Integer class id for every row of the annotation table's 'class' column,
# in table order.
clusters = np.array(
    [class_mapping[name] for name in val_dataset.anno_df['class']]
)

In [11]:
# PCA

# Fit a PCA on the collected embeddings and project them onto the principal
# axes. n_components is left at its default, so no explicit truncation is
# requested here; the plotting helpers pick the leading columns themselves.
pca = PCA()
# NOTE(review): epoch_logits is the torch tensor produced by torch.cat above;
# this relies on scikit-learn converting it to a NumPy array -- confirm.
pca_logits = pca.fit_transform(epoch_logits)

In [12]:
# 3-D scatter of the leading three PCA components, coloured by class id.
draw_clusters_3d(
    clusters=clusters,
    cluster_num=cluster_num,
    embeddings=pca_logits,
)

FigureCanvasNbAgg()

In [13]:
# 2-D scatter of the leading two PCA components, one palette colour per class id.
draw_clusters_2d(
    clusters=clusters,
    cluster_num=cluster_num,
    embeddings=pca_logits,
    colors=color_iter,
)

FigureCanvasNbAgg()