# Semi-Supervised Learning Assignment

## Imports

In [1]:
import numpy as np
import pandas as pd
import os, time, re
import pickle, gzip
import matplotlib.pyplot as plt
import seaborn as sns
color = sns.color_palette()
import matplotlib as mpl
%matplotlib inline
from sklearn import preprocessing as pp
from sklearn.model_selection import train_test_split 
from sklearn.model_selection import StratifiedKFold 
from sklearn.metrics import log_loss
from sklearn.metrics import precision_recall_curve, average_precision_score
from sklearn.metrics import roc_curve, auc, roc_auc_score
import lightgbm as lgb
import tensorflow as tf
import keras
from keras import backend as K
from keras.models import Sequential, Model
from keras.layers import Activation, Dense, Dropout
from keras.layers import BatchNormalization, Input, Lambda
from keras import regularizers
from keras.losses import mse, binary_crossentropy
from fastai.vision import *
from numbers import Integral
import seaborn as sns
from sklearn.datasets import make_moons, make_blobs, make_circles, make_classification
import pdb
import contextlib
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

  import pandas.util.testing as tm


## Data

In [2]:
# Number of augmented views generated per unlabeled image:
# MultiTransformLabelList applies the x transforms K times per item, and
# MixMatchTrainer repeats each guessed label K times.
K=2
class MultiTransformLabelList(LabelList):
    """`LabelList` whose integer indexing yields K transformed copies of x.

    Single-item access applies the configured x transforms ``K`` times
    (module-level ``K``) and returns the list of results together with y.
    Slicing/array indexing behaves like the base class and returns
    ``self.new(...)``.
    NOTE(review): appears adapted from fastai v1 ``LabelList.__getitem__``;
    re-check against the installed fastai version when upgrading.
    """
    def __getitem__(self,idxs:Union[int,np.ndarray])->'LabelList':
        """Return ``(x, y)`` for an integer index, or a new `LabelList` for a range.

        For an integer index with transforms configured, ``x`` is a list of
        ``K`` results of ``x.apply_tfms(...)`` rather than a single item.
        """
        idxs = try_int(idxs)
        if isinstance(idxs, Integral):
            # When `self.item` is set it is used as x, with a placeholder label of 0.
            if self.item is None: x,y = self.x[idxs],self.y[idxs]
            else:                 x,y = self.item   ,0
            if self.tfms or self.tfmargs:
                # Apply the x transforms K times; presumably random augmentations,
                # so the K views differ from one another.
                x = [x.apply_tfms(self.tfms, **self.tfmargs) for _ in range(K)]
            # Transform the target too when target transforms are enabled (dataset items only).
            if hasattr(self, 'tfms_y') and self.tfm_y and self.item is None:
                y = y.apply_tfms(self.tfms_y, **{**self.tfmargs_y, 'do_resolve':False})
            # Replace a missing label with 0 (presumably so the sample can still be collated).
            if y is None: y=0
            return x,y
        else: return self.new(self.x[idxs], self.y[idxs])
        
def MixmatchCollate(batch):
    """Collate samples whose input may be a list of augmented views.

    If the first sample's input is a list (as returned by
    `MultiTransformLabelList`), every sample's views are stacked into one
    tensor with a leading views dimension before PyTorch's default collate
    runs, so the batched input gains a views axis right after the batch axis.
    """
    samples = to_data(batch)
    has_views = isinstance(samples[0][0], list)
    if has_views:
        samples = [[torch.stack(sample[0]), sample[1]] for sample in samples]
    collate = torch.utils.data.dataloader.default_collate
    return collate(samples)

In [3]:
# Fetch and extract the fastai CIFAR-10 archive; `path` is the extracted dataset folder.
path = untar_data(URLs.CIFAR)

class MixMatchImageList(ImageList):
    """`ImageList` that can subsample its training items to a small labeled set."""

    def filter_train(self, num_items, seed=2343, valid_name="test"):
        """Keep `num_items` random training images plus every validation image.

        An item counts as validation when the folder two levels above the file
        (``Path(o).parts[-3]``) equals `valid_name`; every other item counts as
        training. NumPy's global RNG is seeded with `seed`, so the same subset
        is drawn on every run. The original item order is preserved.

        Args:
            num_items: number of training items to keep, sampled without
                replacement (must not exceed the number of training items).
            seed: value passed to ``np.random.seed`` before sampling.
            valid_name: validation folder name (``"test"`` for CIFAR).

        Returns:
            ``self``, so the call can be chained.
        """
        # Parse each path once; the boolean mask replaces a per-item
        # `i in np.concatenate(...)` scan that was quadratic in the item count.
        is_valid = np.array([Path(o).parts[-3] == valid_name for o in self.items], dtype=bool)
        train_idxs = np.flatnonzero(~is_valid)
        np.random.seed(seed)
        keep_idxs = np.random.choice(train_idxs, num_items, replace=False)
        keep = is_valid.copy()
        keep[keep_idxs] = True
        self.items = np.array([o for o, k in zip(self.items, keep) if k])
        return self
    

# Labeled split: 500 seeded-random CIFAR training images plus the whole "test"
# folder as validation; augmented, resized to 32px, batch size 64.
data_labeled = (MixMatchImageList.from_folder(path)
                .filter_train(500) 
                .split_by_folder(valid="test") 
                .label_from_folder()
                .transform(get_transforms(),size=32)
                .databunch(bs=64,num_workers=0)
                .normalize(cifar_stats))

# Paths of the labeled training images, excluded from the unlabeled pool below.
train_set = set(data_labeled.train_ds.x.items)
# Unlabeled pool: every image not in `train_set`; the "test" folder is split off as validation.
src = (ImageList.from_folder(path)
        .filter_by_func(lambda x: x not in train_set)
        .split_by_folder(valid="test"))
# NOTE(review): overriding the private `_label_list` presumably makes label_from_folder
# build the training set as a MultiTransformLabelList (K views per item) — relies on
# fastai internals; confirm when upgrading fastai.
src.train._label_list = MultiTransformLabelList
# Unlabeled loader: MixmatchCollate stacks the K views of each image; batch size 128.
data_unlabeled = (src.label_from_folder()
         .transform(get_transforms(),size=32)
         .databunch(bs=128,collate_fn=MixmatchCollate,num_workers=0)
         .normalize(cifar_stats))


# Fully labeled CIFAR (all training images). NOTE(review): not referenced
# in the visible code — presumably intended as a supervised baseline.
data_full = (ImageList.from_folder(path)
        .split_by_folder(valid="test")
        .label_from_folder()
        .transform(get_transforms(),size=32)
        .databunch(bs=128,num_workers=0)
        .normalize(cifar_stats))

Downloading http://files.fast.ai/data/examples/cifar10.tgz


## Build Model

In [4]:
# fastai WideResNet for the 10 CIFAR classes with the hyper-parameters below.
# NOTE(review): not wrapped in a Learner anywhere in the visible code.
model = models.WideResNet(num_groups=3,N=4,num_classes=10,k=2,start_nf=32)


## Training

In [5]:
class MixMatchTrainer(LearnerCallback):
    """fastai v1 callback that turns each unlabeled batch into a MixMatch batch.

    For every training batch it:
      1. draws the next labeled batch from ``data_labeled.train_dl``,
         restarting that loader when it is exhausted;
      2. guesses labels for the unlabeled images by averaging the model's
         softmax over their augmented views and sharpening the average;
      3. concatenates labeled and unlabeled inputs/targets, mixes them with a
         random permutation of themselves via ``mixup`` (defined elsewhere),
         then shuffles the mixed batch.

    The replacement target is ``(mixed_target, unsort, ramp, n_labeled)``.
    The learner's loss function is expected to expose ``loss_x`` and
    ``loss_u`` attributes; these are smoothed and reported as the extra
    ``l_loss`` / ``ul_loss`` metrics.
    """
    _order=-20

    def on_train_begin(self, **kwargs):
        self.l_dl = iter(data_labeled.train_dl)
        self.smoothL, self.smoothUL = SmoothenValue(0.98), SmoothenValue(0.98)
        self.recorder.add_metric_names(["l_loss","ul_loss"])
        self.it = 0

    def _next_labeled(self):
        """Return the next labeled (x, y) batch, restarting the loader when it runs out."""
        try:
            return next(self.l_dl)
        except StopIteration:
            self.l_dl = iter(data_labeled.train_dl)
            return next(self.l_dl)

    def on_batch_begin(self, train, last_input, last_target, **kwargs):
        if not train: return
        x_l, y_l = self._next_labeled()

        # Presumably (batch, n_views, C, H, W), as stacked by MixmatchCollate.
        x_ul = last_input
        n_views = x_ul.shape[1]
        with torch.no_grad():
            logits = torch.stack([self.learn.model(x_ul[:, i]) for i in range(n_views)], dim=1)
            ul_labels = sharpen(torch.softmax(logits, dim=2).mean(dim=1))

        # Fold views into the batch axis, image-major (img0_v0, img0_v1, ..., img1_v0, ...),
        # and repeat each guessed label once per view so rows stay aligned.
        x_ul = x_ul.flatten(0, 1)
        ul_labels = ul_labels.repeat_interleave(n_views, dim=0)

        # One-hot labeled targets, built on the same device as the labels.
        l_labels = torch.eye(data_labeled.c, device=y_l.device)[y_l]

        w_x = torch.cat([x_l, x_ul])
        w_y = torch.cat([l_labels, ul_labels])
        idxs = torch.randperm(w_x.shape[0])

        mixed_input, mixed_target = mixup(w_x, w_y, w_x[idxs], w_y[idxs])
        # Shuffle the mixed batch; `unsort` is the inverse permutation
        # (unsort[bn_idxs[i]] == i) so the loss can restore the original order.
        bn_idxs = torch.randperm(mixed_input.shape[0])
        unsort = bn_idxs.argsort().tolist()
        mixed_input = mixed_input[bn_idxs]

        # Linear 0 -> 1 ramp over the first 3000 iterations, handed to the loss.
        ramp = self.it / 3000.0 if self.it < 3000 else 1.0
        return {"last_input": mixed_input, "last_target": (mixed_target, unsort, ramp, x_l.shape[0])}

    def on_batch_end(self, train, **kwargs):
        if not train: return
        self.smoothL.add_value(self.learn.loss_func.loss_x)
        self.smoothUL.add_value(self.learn.loss_func.loss_u)
        self.it += 1

    def on_epoch_end(self, last_metrics, **kwargs):
        return add_metrics(last_metrics,[self.smoothL.smooth,self.smoothUL.smooth])

## MNIST Dataset

In [7]:
# Fetch and extract the fastai MNIST PNG archive; rebinds `path` (previously the CIFAR folder).
path = untar_data(URLs.MNIST)

Downloading https://s3.amazonaws.com/fast-ai-imageclas/mnist_png.tgz


In [8]:
class MixMatchImageList(ImageList):
    """`ImageList` that can subsample its training items to a small labeled set."""

    def filter_train(self, num_items, seed=2343, valid_name="testing"):
        """Keep `num_items` random training images plus every validation image.

        An item counts as validation when the folder two levels above the file
        (``Path(o).parts[-3]``) equals `valid_name`; every other item counts as
        training. NumPy's global RNG is seeded with `seed`, so the same subset
        is drawn on every run. The original item order is preserved.

        Args:
            num_items: number of training items to keep, sampled without
                replacement (must not exceed the number of training items).
            seed: value passed to ``np.random.seed`` before sampling.
            valid_name: validation folder name (``"testing"`` for MNIST).

        Returns:
            ``self``, so the call can be chained.
        """
        # Parse each path once; the boolean mask replaces a per-item
        # `i in np.concatenate(...)` scan that was quadratic in the item count.
        is_valid = np.array([Path(o).parts[-3] == valid_name for o in self.items], dtype=bool)
        train_idxs = np.flatnonzero(~is_valid)
        np.random.seed(seed)
        keep_idxs = np.random.choice(train_idxs, num_items, replace=False)
        keep = is_valid.copy()
        keep[keep_idxs] = True
        self.items = np.array([o for o, k in zip(self.items, keep) if k])
        return self

In [9]:
# Labeled MNIST subset: 500 seeded-random images from "training" plus the whole
# "testing" folder as validation; batch size 64, no transforms or normalization.
# EntropyMinTrainer draws its supervised batches from data.train_dl.
data = (MixMatchImageList.from_folder(path)
        .filter_train(500)
        .split_by_folder(train="training",valid="testing")
        .label_from_folder()
        .databunch(bs=64))

# Full MNIST, batch size 128. Used as the Learner's data below: its training
# targets are replaced by pseudo-labels in EntropyMinTrainer.on_batch_begin.
dataFull = (ImageList.from_folder(path)
            .split_by_folder(train="training",valid="testing")
            .label_from_folder()
            .databunch(bs=128))

In [10]:

class BasicNN(nn.Module):
    """Small convolutional classifier.

    Two stride-2 3x3 convolutions (3 -> 64 -> 128 channels) with a ReLU after
    each and BatchNorm on the final 128 channels, followed by global average
    pooling and a linear layer that produces `num_classes` logits. Expects
    3-channel input.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        layers = [
            nn.Conv2d(3, 64, 3, 2, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, 2, 1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(128),
        ]
        self.conv = nn.Sequential(*layers)
        self.out = nn.Linear(128, num_classes)

    def forward(self, x, noise=True):
        """Return class logits for `x`; `noise` is accepted but unused."""
        features = self.conv(x)
        pooled = F.adaptive_avg_pool2d(features, 1)
        return self.out(pooled.view(-1, 128))

In [11]:
def sharpen(p,T=0.5):
    """Temperature-sharpen probability rows.

    Raises every entry of `p` to the power ``1/T`` and renormalises each row
    (dim 1) to sum to one. ``T < 1`` sharpens; ``T == 1`` only renormalises.
    """
    powered = p.pow(1 / T)
    totals = powered.sum(dim=1, keepdim=True)
    return powered / totals


In [12]:
class EntropyMinTrainer(LearnerCallback):
    """fastai v1 callback for pseudo-label (entropy-minimisation) training.

    Every training batch from the learner's own data is treated as unlabeled:
    its targets become the model's sharpened softmax predictions plus a linear
    ramp value. Before each backward pass, a supervised cross-entropy term on
    the next batch of the module-level labeled ``data.train_dl`` is added to
    the loss.

    Args:
        learn: the fastai ``Learner`` this callback is attached to.
        T: sharpening temperature passed to ``sharpen``; ``T=1.0`` leaves the
            softmax unchanged.
        ramp_iters: number of training iterations over which the ramp value
            handed to the loss grows linearly from 0 to 1.
    """
    def __init__(self, learn, T=0.5, ramp_iters=800):
        super().__init__(learn)
        self.T = T
        self.ramp_iters = ramp_iters

    def on_train_begin(self, **kwargs):
        self.l_dl = iter(data.train_dl)
        self.it = 0

    def _next_labeled(self):
        """Return the next labeled (x, y) batch, restarting the loader when it runs out."""
        try:
            return next(self.l_dl)
        except StopIteration:
            self.l_dl = iter(data.train_dl)
            return next(self.l_dl)

    def on_batch_begin(self, train, last_input, **kwargs):
        if not train: return
        # Pseudo-labels come from the current model and carry no gradient.
        with torch.no_grad():
            ul_labels = sharpen(torch.softmax(self.learn.model(last_input),dim=1),T=self.T)

        self.it += 1
        ramp = self.it / self.ramp_iters if self.it < self.ramp_iters else 1.0
        return {"last_target": (ul_labels,ramp)}

    def on_backward_begin(self, last_loss, last_output, **kwargs):
        l_x, l_y = self._next_labeled()
        real_preds = self.learn.model(l_x)
        real_loss = F.cross_entropy(real_preds,l_y)
        return {"last_loss": last_loss + real_loss}
    
def entropy_min_loss(preds,target,ramp=None):
    """Loss shared by pseudo-label training and evaluation.

    Without `ramp` (e.g. validation batches, whose targets the callback
    leaves untouched), `target` holds class indices and the loss is plain
    cross-entropy. With `ramp`, `target` holds soft pseudo-labels and the
    loss is ``10 * ramp * MSE(softmax(preds), target)``.
    """
    if ramp is not None:
        probs = torch.softmax(preds, dim=1)
        return 10.0 * ramp * F.mse_loss(probs, target)
    return F.cross_entropy(preds, target)

### Train

In [13]:
# Pseudo-labelling learner with sharpening (EntropyMinTrainer's default T=0.5),
# trained on the full MNIST loader as the unlabeled stream.
learn = Learner(dataFull,BasicNN(),loss_func=entropy_min_loss,callback_fns=[EntropyMinTrainer],metrics=accuracy)

In [15]:
# One-cycle schedule: 5 epochs, max learning rate 3e-3, weight decay 1e-4.
learn.fit_one_cycle(5,3e-3,wd=1e-4)


epoch,train_loss,valid_loss,accuracy,time
0,0.006316,0.587751,0.8369,00:37
1,0.005953,0.729223,0.8548,00:53
2,0.004531,0.595933,0.8871,00:55
3,0.003093,0.63465,0.8964,00:45
4,0.002711,0.598873,0.9044,00:34


In [16]:
# Same setup with T=1.0, i.e. no sharpening (sharpen with T=1 only renormalises the softmax).
learn2 = Learner(dataFull,BasicNN(),loss_func=entropy_min_loss,callback_fns=[partial(EntropyMinTrainer,T=1.0)],metrics=accuracy)

In [17]:
# Train the T=1.0 variant with the same schedule as `learn`, for comparison.
learn2.fit_one_cycle(5,3e-3,wd=1e-4)


epoch,train_loss,valid_loss,accuracy,time
0,0.0,0.89121,0.6936,00:33
1,0.0,0.683977,0.7833,00:34
2,0.0,0.599547,0.8119,00:34
3,0.0,0.63127,0.8144,00:34
4,0.0,0.645716,0.8143,00:34


In [18]:
# Evaluate both pseudo-labelling models (learn: T=0.5, learn2: T=1.0) on the
# MNIST validation set, keeping per-image results for later analysis:
#   digits                     - channel 0 of each image, flattened per image
#   labels                     - ground-truth class indices
#   preds / preds2             - max softmax probability of learn / learn2
#   pred_labels / pred_labels2 - predicted class of learn / learn2
def _top_prediction(model, x):
    """Return ``(max softmax probability, argmax class)`` for each row of ``model(x)``.

    Device-agnostic and independent of the number of classes.
    """
    probs = torch.softmax(model(x), dim=1)
    return probs.max(dim=1)


digits, preds, pred_labels, labels = [], [], [], []
preds2, pred_labels2 = [], []
learn.model.eval()
learn2.model.eval()
with torch.no_grad():
    for x, y in progress_bar(dataFull.valid_dl, total=len(dataFull.valid_dl)):
        conf, cls = _top_prediction(learn.model, x)
        preds.append(conf)
        pred_labels.append(cls)
        labels.append(y)
        digits.append(x[:, 0].flatten(1))

        conf2, cls2 = _top_prediction(learn2.model, x)
        preds2.append(conf2)
        pred_labels2.append(cls2)

labels = torch.cat(labels)
digits = torch.cat(digits)
preds = torch.cat(preds)
pred_labels = torch.cat(pred_labels)
preds2 = torch.cat(preds2)
pred_labels2 = torch.cat(pred_labels2)
# Validation accuracy of each model (shown as the cell's output).
(pred_labels == labels).float().mean(), (pred_labels2 == labels).float().mean()

  # This is added back by InteractiveShellApp.init_path()


(tensor(0.9044, device='cuda:0'), tensor(0.8143, device='cuda:0'))